Created
March 7, 2026 10:36
-
-
Save BexTuychiev/3db83632a3e1c2eac14c450d1fa68589 to your computer and use it in GitHub Desktop.
Flask API auth.py for Claude Code Plan Mode tutorial - refactoring example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Flask API Authentication Module | |
| ================================ | |
| Monolithic auth module handling token validation, role-based access control, | |
| and session management. This file is the starting point for a Plan Mode | |
| refactoring tutorial: splitting it into token_validation.py, role_access.py, | |
| and session_management.py. | |
| Repository: https://www.datacamp.com/tutorial/claude-code-plan-mode | |
| """ | |
| import hashlib | |
| import hmac | |
| import logging | |
| import os | |
| import time | |
| import uuid | |
| from datetime import datetime, timedelta, timezone | |
| from functools import wraps | |
| from typing import Any | |
| import jwt | |
| from flask import Flask, g, jsonify, request | |
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configuration constants
# ---------------------------------------------------------------------------
# Placeholder used only when JWT_SECRET is missing or empty. Anyone who knows
# this public default can forge tokens, so its use is logged as a warning.
_DEFAULT_JWT_SECRET = "change-me-in-production"
# ``or`` (instead of a .get() default) also treats JWT_SECRET="" as unset, so a
# blank variable can never produce an empty HMAC signing key.
JWT_SECRET = os.environ.get("JWT_SECRET") or _DEFAULT_JWT_SECRET
if JWT_SECRET == _DEFAULT_JWT_SECRET:
    logger.warning(
        "JWT_SECRET is not set; using an insecure placeholder secret. "
        "Set the JWT_SECRET environment variable before deploying."
    )
JWT_ALGORITHM = "HS256"
TOKEN_EXPIRY_BUFFER = 30  # seconds before expiry to trigger refresh
ACCESS_TOKEN_TTL = timedelta(minutes=15)
REFRESH_TOKEN_TTL = timedelta(days=7)
SESSION_IDLE_TIMEOUT = timedelta(minutes=30)
MAX_ACTIVE_SESSIONS = 5  # per user; the oldest session is evicted beyond this
# Numeric level per role: a larger number means more privilege, and a user
# satisfies any requirement whose level is <= their own.
ROLE_HIERARCHY: dict[str, int] = {
    "viewer": 10,
    "editor": 20,
    "admin": 30,
    "superadmin": 40,
}
# Explicit permission names granted to each role.
ROLE_PERMISSIONS: dict[str, list[str]] = {
    "viewer": ["read"],
    "editor": ["read", "write"],
    "admin": ["read", "write", "delete", "manage_users"],
    "superadmin": ["read", "write", "delete", "manage_users", "manage_roles", "audit"],
}
# ---------------------------------------------------------------------------
# In-memory stores (replace with Redis/DB in production)
# ---------------------------------------------------------------------------
_token_cache: dict[str, dict[str, Any]] = {}  # jti -> token metadata
_revoked_tokens: set[str] = set()  # jti values that must be rejected
_active_sessions: dict[str, dict[str, Any]] = {}  # session_id -> session record
_rate_limit_counters: dict[str, list[float]] = {}  # identifier -> request times
| # --------------------------------------------------------------------------- | |
| # Token validation | |
| # --------------------------------------------------------------------------- | |
def decode_jwt(token: str) -> dict[str, Any]:
    """Verify *token*'s signature and expiry and return its claims.

    Raises:
        jwt.ExpiredSignatureError: the token has expired (logged as a warning
            together with the first 20 characters of the token).
        jwt.InvalidTokenError: any other verification failure (logged as an
            error), or the token's ``jti`` is in the revocation set.
    """
    try:
        # The algorithm list is pinned so a token cannot choose its own alg.
        claims = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.ExpiredSignatureError:
        # Must stay ahead of the InvalidTokenError clause: in PyJWT the
        # expired-signature error is a subclass of InvalidTokenError.
        logger.warning("Token expired: %s...", token[:20])
        raise
    except jwt.InvalidTokenError as err:
        logger.error("Invalid token: %s", err)
        raise
    else:
        token_id = claims.get("jti")
        if token_id and token_id in _revoked_tokens:
            raise jwt.InvalidTokenError(f"Token {token_id} has been revoked")
        return claims
| def _validate_expiry(payload: dict[str, Any]) -> bool: | |
| """Check whether a token is within the expiry buffer window.""" | |
| exp = payload.get("exp") | |
| if exp is None: | |
| return False | |
| remaining = exp - time.time() | |
| return remaining > TOKEN_EXPIRY_BUFFER | |
| def _generate_jti() -> str: | |
| """Generate a unique JWT ID.""" | |
| return uuid.uuid4().hex | |
def create_access_token(
    user_id: str,
    role: str,
    extra_claims: dict[str, Any] | None = None,
) -> str:
    """Create a signed, short-lived JWT access token and cache its metadata.

    Args:
        user_id: Stored as the ``sub`` claim and in the cache entry.
        role: Stored as the ``role`` claim and in the cache entry.
        extra_claims: Optional additional claims to embed. Keys that collide
            with the claims set here (``sub``, ``role``, ``iat``, ``exp``,
            ``jti``, ``type``) are ignored, and a warning is logged.

    Returns:
        The encoded JWT, valid for ACCESS_TOKEN_TTL from now.
    """
    now = datetime.now(timezone.utc)
    jti = _generate_jti()
    payload: dict[str, Any] = {
        "sub": user_id,
        "role": role,
        "iat": now,
        "exp": now + ACCESS_TOKEN_TTL,
        "jti": jti,
        "type": "access",
    }
    if extra_claims:
        # extra_claims used to be merged with update(), which let it override
        # the reserved claims. An overridden "jti" made the cache key below
        # differ from the token's real jti, so the token could never be found
        # by revoke_all_user_tokens(). Overriding "sub"/"role"/"type"/"exp"
        # desynchronised the token from its cache entry or changed its
        # lifetime or type. Reserved claims now always win.
        reserved = frozenset(payload)
        ignored = sorted(extra_claims.keys() & reserved)
        if ignored:
            logger.warning(
                "Ignoring reserved claims in extra_claims for user %s: %s",
                user_id,
                ", ".join(ignored),
            )
        payload.update(
            {key: value for key, value in extra_claims.items() if key not in reserved}
        )
    token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
    # Cache for fast lookup (revocation, session cross-checks).
    _token_cache[jti] = {
        "user_id": user_id,
        "role": role,
        "created_at": now.isoformat(),
        "type": "access",
    }
    logger.info("Access token created for user %s (jti=%s)", user_id, jti)
    return token
def create_refresh_token(user_id: str, role: str) -> str:
    """Issue a long-lived refresh token for *user_id* and record it in the cache.

    The token carries ``type="refresh"`` and expires REFRESH_TOKEN_TTL after
    issue. Its metadata is stored in ``_token_cache`` under the token's jti.
    """
    issued_at = datetime.now(timezone.utc)
    token_id = _generate_jti()
    claims = dict(
        sub=user_id,
        role=role,
        iat=issued_at,
        exp=issued_at + REFRESH_TOKEN_TTL,
        jti=token_id,
        type="refresh",
    )
    encoded = jwt.encode(claims, JWT_SECRET, algorithm=JWT_ALGORITHM)
    _token_cache[token_id] = dict(
        user_id=user_id,
        role=role,
        created_at=issued_at.isoformat(),
        type="refresh",
    )
    logger.info("Refresh token created for user %s (jti=%s)", user_id, token_id)
    return encoded
def verify_token(token: str) -> dict[str, Any]:
    """Decode *token* (signature, expiry and revocation via decode_jwt) and annotate it.

    If the token is inside the expiry buffer (or has no ``exp``), the
    returned claims gain ``"_near_expiry": True``. When the token's jti is
    cached, its entry's ``last_used`` is set to the current time.

    Raises:
        Whatever decode_jwt raises (jwt.ExpiredSignatureError /
        jwt.InvalidTokenError).
    """
    claims = decode_jwt(token)
    token_id = claims.get("jti")

    near_expiry = not _validate_expiry(claims)
    if near_expiry:
        logger.info("Token %s is within expiry buffer", token_id)
        claims["_near_expiry"] = True

    # Update cache with last-used timestamp
    cached = _token_cache.get(token_id) if token_id else None
    if cached is not None:
        cached["last_used"] = time.time()
    return claims
def revoke_token(jti: str) -> bool:
    """Blacklist the token identified by *jti* and drop its cache entry.

    Unknown JTIs are still added to the revocation set. Always returns True.
    """
    _token_cache.pop(jti, None)
    _revoked_tokens.add(jti)
    logger.info("Token revoked: %s", jti)
    return True
def revoke_all_user_tokens(user_id: str) -> int:
    """Revoke and uncache every token in ``_token_cache`` owned by *user_id*.

    Only tokens still present in the cache can be found this way; a token
    whose entry was already removed is not revoked here.

    Returns:
        The number of tokens revoked.
    """
    owned = [
        token_id
        for token_id, meta in _token_cache.items()
        if meta.get("user_id") == user_id
    ]
    _revoked_tokens.update(owned)
    for token_id in owned:
        del _token_cache[token_id]
    logger.info("Revoked %d tokens for user %s", len(owned), user_id)
    return len(owned)
def _cache_entry_expired(meta: dict[str, Any], now: datetime) -> bool:
    """Return True if the token described by cache entry *meta* is past its TTL."""
    created_raw = meta.get("created_at")
    if not created_raw:
        return False  # no issue time recorded: cannot tell, keep the entry
    try:
        created = datetime.fromisoformat(created_raw)
    except (TypeError, ValueError):
        return False
    if created.tzinfo is None:
        # NOTE(review): naive timestamps are assumed to be UTC — confirm.
        created = created.replace(tzinfo=timezone.utc)
    # Unknown types get the longer refresh TTL so nothing is dropped early.
    ttl = ACCESS_TOKEN_TTL if meta.get("type") == "access" else REFRESH_TOKEN_TTL
    return now - created > ttl


def cleanup_expired_cache() -> int:
    """Remove token-cache entries that no longer describe a usable token.

    An entry is removed when either:

    * ``created_at`` plus the TTL for its ``type`` (ACCESS_TOKEN_TTL for
      access tokens, REFRESH_TOKEN_TTL otherwise) has passed. This assumes
      ``exp`` was set from that TTL at issue time, so the token itself is
      already expired; or
    * ``last_used`` is older than REFRESH_TOKEN_TTL (the original rule).

    Previously only the second rule existed, so entries for tokens that
    were never verified (no ``last_used``) were never removed.

    Returns:
        The number of entries removed.
    """
    now_ts = time.time()
    now_dt = datetime.now(timezone.utc)
    stale_after = REFRESH_TOKEN_TTL.total_seconds()
    expired = [
        jti
        for jti, meta in _token_cache.items()
        if _cache_entry_expired(meta, now_dt)
        or (meta.get("last_used") and now_ts - meta["last_used"] > stale_after)
    ]
    for jti in expired:
        del _token_cache[jti]
    return len(expired)
| # --------------------------------------------------------------------------- | |
| # Role-based access control | |
| # --------------------------------------------------------------------------- | |
def check_role(required_role: str, user_role: str) -> bool:
    """Return True when *user_role*'s level in ROLE_HIERARCHY is >= *required_role*'s.

    Either role being absent from the hierarchy is logged as an error and
    yields False. The required role is reported first if both are unknown.
    """
    required_level, user_level = (
        ROLE_HIERARCHY.get(name) for name in (required_role, user_role)
    )
    if required_level is None:
        logger.error("Unknown required role: %s", required_role)
        return False
    if user_level is None:
        logger.error("Unknown user role: %s", user_role)
        return False
    return required_level <= user_level
def get_permissions(role: str) -> list[str]:
    """Return a copy of the permission list configured for *role*.

    Unknown roles, or roles with an empty list, yield ``[]`` and log a
    warning. A fresh list is returned because the original handed out the
    list stored in ROLE_PERMISSIONS itself, so a caller that appended to or
    sorted the result silently changed permissions process-wide.
    """
    perms = list(ROLE_PERMISSIONS.get(role, ()))
    if not perms:
        logger.warning("No permissions defined for role: %s", role)
    return perms
def has_permission(role: str, permission: str) -> bool:
    """Return True if *permission* appears in the permission list for *role*."""
    granted = get_permissions(role)
    return permission in granted
def get_role_level(role: str) -> int:
    """Look up *role* in ROLE_HIERARCHY, returning -1 for an unknown role."""
    try:
        return ROLE_HIERARCHY[role]
    except KeyError:
        return -1
def validate_role_transition(current_role: str, target_role: str) -> bool:
    """Return True if changing from *current_role* to *target_role* is allowed.

    Both roles must exist in ROLE_HIERARCHY. Demotions and same-level moves
    are always allowed; a promotion may climb at most one rank.

    "One rank" means one position among the distinct levels of
    ROLE_HIERARCHY, not a fixed numeric gap. The previous
    ``current_level + 10`` check only matched that rule while levels
    happened to be spaced exactly 10 apart; adding, say, a level-15 role
    would have silently changed it. For the current 10/20/30/40 levels the
    result is identical.
    """
    current_level = get_role_level(current_role)
    target_level = get_role_level(target_role)
    if current_level < 0 or target_level < 0:
        return False
    if target_level <= current_level:
        return True
    # Users can only be promoted one level at a time.
    ranks = sorted(set(ROLE_HIERARCHY.values()))
    return ranks.index(target_level) <= ranks.index(current_level) + 1
def require_role(role: str):
    """Build a Flask route decorator that requires at least *role*.

    The wrapped view reads ``g.current_user_role``. It responds 401 when that
    is missing or None, 403 when check_role rejects it, and otherwise calls
    the view.
    """
    def decorator(view):
        @wraps(view)
        def guarded(*args, **kwargs):
            current_role = getattr(g, "current_user_role", None)
            if current_role is None:
                return jsonify({"error": "Authentication required"}), 401
            if check_role(role, current_role):
                return view(*args, **kwargs)
            return jsonify({"error": "Insufficient permissions"}), 403
        return guarded
    return decorator
def require_permission(permission: str):
    """Build a Flask route decorator that requires *permission*.

    The wrapped view reads ``g.current_user_role``. It responds 401 when that
    is missing or None, 403 when the role lacks *permission*, and otherwise
    calls the view.
    """
    def decorator(view):
        @wraps(view)
        def guarded(*args, **kwargs):
            current_role = getattr(g, "current_user_role", None)
            if current_role is None:
                return jsonify({"error": "Authentication required"}), 401
            if has_permission(current_role, permission):
                return view(*args, **kwargs)
            return jsonify({"error": f"Missing permission: {permission}"}), 403
        return guarded
    return decorator
| # --------------------------------------------------------------------------- | |
| # Session management | |
| # --------------------------------------------------------------------------- | |
def create_session(user_id: str, metadata: dict[str, Any] | None = None) -> str:
    """Create a new session for *user_id* and return its id.

    At most MAX_ACTIVE_SESSIONS sessions are kept per user. Before the new
    one is added, the user's oldest sessions (by ``created_at``) are
    invalidated until there is room for it.

    Args:
        user_id: Owner of the session.
        metadata: Optional caller data stored with the session. A shallow
            copy is stored, so later changes to the caller's dict do not
            leak into the session record (the original stored the caller's
            dict itself).

    Returns:
        The new session id (32 hex characters).
    """
    user_sessions = [
        (sid, s) for sid, s in _active_sessions.items() if s["user_id"] == user_id
    ]
    # Evict as many of the oldest sessions as needed to make room. The
    # original removed exactly one, which could leave the user above the
    # limit if it already was (e.g. MAX_ACTIVE_SESSIONS lowered at runtime).
    # isoformat() strings from the same UTC clock compare chronologically;
    # the stable sort keeps min()'s tie-breaking.
    excess = len(user_sessions) - MAX_ACTIVE_SESSIONS + 1
    if excess > 0:
        user_sessions.sort(key=lambda item: item[1]["created_at"])
        for oldest_sid, _ in user_sessions[:excess]:
            invalidate_session(oldest_sid)
            logger.info(
                "Evicted oldest session %s for user %s (limit=%d)",
                oldest_sid,
                user_id,
                MAX_ACTIVE_SESSIONS,
            )
    session_id = uuid.uuid4().hex
    now = datetime.now(timezone.utc).isoformat()
    _active_sessions[session_id] = {
        "user_id": user_id,
        "created_at": now,
        "last_activity": now,
        "metadata": dict(metadata) if metadata else {},
        # Request details are recorded only when a request is bound.
        "ip_address": request.remote_addr if request else None,
        "user_agent": request.headers.get("User-Agent") if request else None,
    }
    logger.info("Session created: %s for user %s", session_id, user_id)
    return session_id
def _has_live_access_token(user_id: str, now: datetime) -> bool:
    """Return True if _token_cache holds an unexpired access token for *user_id*."""
    for meta in _token_cache.values():
        if meta.get("user_id") != user_id or meta.get("type") != "access":
            continue
        created_raw = meta.get("created_at")
        if not created_raw:
            return True  # no issue time recorded: fall back to presence-only
        try:
            created = datetime.fromisoformat(created_raw)
        except (TypeError, ValueError):
            return True
        if created.tzinfo is None:
            # NOTE(review): naive timestamps are assumed to be UTC — confirm.
            created = created.replace(tzinfo=timezone.utc)
        if now - created <= ACCESS_TOKEN_TTL:
            return True
    return False


def refresh_session(session_id: str) -> bool:
    """Mark a session as active now, or invalidate it if it is no longer valid.

    The session is invalidated when it has been idle longer than
    SESSION_IDLE_TIMEOUT, or when its user has no unexpired access token in
    the token cache.

    Returns:
        True if the session was refreshed. False if it was unknown or has
        just been invalidated.
    """
    session = _active_sessions.get(session_id)
    if session is None:
        logger.warning("Attempted to refresh unknown session: %s", session_id)
        return False
    now = datetime.now(timezone.utc)
    # Check idle timeout
    last = datetime.fromisoformat(session["last_activity"])
    if now - last > SESSION_IDLE_TIMEOUT:
        logger.info("Session %s timed out (idle)", session_id)
        invalidate_session(session_id)
        return False
    # Only an *unexpired* access token counts. Cache entries outlive their
    # tokens (they are not removed when ACCESS_TOKEN_TTL elapses), so the old
    # presence-only check kept sessions alive after every token had expired.
    user_id = session["user_id"]
    if not _has_live_access_token(user_id, now):
        logger.info("No valid access token for user %s, invalidating session", user_id)
        invalidate_session(session_id)
        return False
    session["last_activity"] = now.isoformat()
    return True
def invalidate_session(session_id: str) -> bool:
    """Delete *session_id* from the active-session store.

    Returns:
        True (and logs) if a session was removed; False if none existed.
    """
    removed = _active_sessions.pop(session_id, None)
    if not removed:
        return False
    logger.info("Session invalidated: %s", session_id)
    return True
def invalidate_all_user_sessions(user_id: str) -> int:
    """Delete every active session owned by *user_id*.

    Sessions are removed directly, so only one summary line is logged
    (not a per-session message).

    Returns:
        How many sessions were removed.
    """
    doomed = [
        sid for sid, record in _active_sessions.items() if record["user_id"] == user_id
    ]
    for sid in doomed:
        _active_sessions.pop(sid)
    count = len(doomed)
    logger.info("Invalidated %d sessions for user %s", count, user_id)
    return count
def get_active_sessions(user_id: str) -> list[dict[str, Any]]:
    """Return a snapshot of *user_id*'s sessions, each with its ``session_id``.

    Each item is a new top-level dict; nested values such as ``metadata``
    are shared with the stored session.
    """
    sessions: list[dict[str, Any]] = []
    for sid, data in _active_sessions.items():
        if data["user_id"] != user_id:
            continue
        sessions.append({"session_id": sid, **data})
    return sessions
def get_session_info(session_id: str) -> dict[str, Any] | None:
    """Return a copy of one session record plus its ``session_id``, or None if unknown."""
    found = _active_sessions.get(session_id)
    return {"session_id": session_id, **found} if found else None
| # --------------------------------------------------------------------------- | |
| # Rate limiting (simple in-memory sliding window) | |
| # --------------------------------------------------------------------------- | |
def check_rate_limit(
    identifier: str, max_requests: int = 60, window_seconds: int = 60
) -> bool:
    """Sliding-window limiter: allow at most *max_requests* per *window_seconds*.

    Timestamps older than the window are discarded on every call. An allowed
    request records the current time; a rejected one records nothing.

    Returns:
        True if this request is allowed, False if the limit is reached.
    """
    current = time.time()
    cutoff = current - window_seconds
    # Prune old entries
    recent = [
        stamp for stamp in _rate_limit_counters.get(identifier, []) if stamp > cutoff
    ]
    _rate_limit_counters[identifier] = recent
    allowed = len(recent) < max_requests
    if allowed:
        recent.append(current)
    return allowed
| # --------------------------------------------------------------------------- | |
| # Flask middleware / request hooks | |
| # --------------------------------------------------------------------------- | |
def init_auth(app: Flask) -> None:
    """Register the authentication ``before_request`` hook on *app*.

    For every request the hook:

    1. Reads ``Authorization: Bearer <token>``. A missing or malformed
       header, or a token rejected by verify_token, leaves
       ``g.current_user`` and ``g.current_user_role`` as None. Routes then
       decide via their decorators.
    2. Accepts only access tokens. Refresh tokens live for REFRESH_TOKEN_TTL
       and must not authenticate API calls; previously any token type was
       accepted here. Tokens without a ``type`` claim are still accepted.
    3. Applies the per-user rate limit and responds 429 when it is exceeded.
    4. Refreshes the ``X-Session-ID`` session only when it belongs to the
       authenticated user. Otherwise a caller could keep another user's
       session alive, or make refresh_session invalidate it.
    """
    @app.before_request
    def _authenticate():
        g.current_user = None
        g.current_user_role = None
        auth_header = request.headers.get("Authorization", "")
        if not auth_header.startswith("Bearer "):
            return None
        token = auth_header[len("Bearer "):]
        try:
            payload = verify_token(token)
        except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
            return None
        token_type = payload.get("type", "access")
        if token_type != "access":
            logger.info("Rejected %s token used for request authentication", token_type)
            return None
        g.current_user = payload.get("sub")
        g.current_user_role = payload.get("role")
        # Rate-limit per user
        if g.current_user and not check_rate_limit(g.current_user):
            return jsonify({"error": "Rate limit exceeded"}), 429
        # Auto-refresh session if one exists and belongs to this user
        session_id = request.headers.get("X-Session-ID")
        if session_id:
            info = get_session_info(session_id)
            if info is not None and info.get("user_id") != g.current_user:
                logger.warning(
                    "Session %s does not belong to user %s; not refreshing",
                    session_id,
                    g.current_user,
                )
            else:
                # Unknown ids still go through refresh_session, which logs them.
                refresh_session(session_id)
        return None
def login_required(fn):
    """Route decorator: respond 401 unless ``g.current_user`` is set and not None."""
    @wraps(fn)
    def _guarded(*args, **kwargs):
        user = getattr(g, "current_user", None)
        if user is not None:
            return fn(*args, **kwargs)
        return jsonify({"error": "Authentication required"}), 401
    return _guarded
| # --------------------------------------------------------------------------- | |
| # Dead code (kept "for later" but never referenced) | |
| # --------------------------------------------------------------------------- | |
| def _legacy_hash(password: str) -> str: | |
| """ | |
| Old password hashing from v1. Replaced by bcrypt in v2 but never removed. | |
| No callers exist outside this file; no test coverage. | |
| """ | |
| salt = "static-salt-v1" | |
| return hashlib.sha256(f"{salt}{password}".encode()).hexdigest() | |
def _migrate_v1_tokens(old_secret: str) -> int:
    """Relabel every ``version == "v1"`` token-cache entry as ``"v2"``.

    Only the cache metadata is changed; no token is re-signed. *old_secret*
    is accepted but not used by this code. The original docstring describes
    this as a one-off migration helper that is no longer called.

    Returns:
        The number of entries relabelled.
    """
    v1_ids = [jti for jti, meta in _token_cache.items() if meta.get("version") == "v1"]
    for jti in v1_ids:
        _token_cache[jti]["version"] = "v2"
    return len(v1_ids)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment