Skip to content

Instantly share code, notes, and snippets.

@BexTuychiev
Created March 7, 2026 10:36
Show Gist options
  • Select an option

  • Save BexTuychiev/3db83632a3e1c2eac14c450d1fa68589 to your computer and use it in GitHub Desktop.

Select an option

Save BexTuychiev/3db83632a3e1c2eac14c450d1fa68589 to your computer and use it in GitHub Desktop.
Flask API auth.py for Claude Code Plan Mode tutorial - refactoring example
"""
Flask API Authentication Module
================================
Monolithic auth module handling token validation, role-based access control,
and session management. This file is the starting point for a Plan Mode
refactoring tutorial: splitting it into token_validation.py, role_access.py,
and session_management.py.
Tutorial: https://www.datacamp.com/tutorial/claude-code-plan-mode
"""
import hashlib
import hmac
import logging
import os
import time
import uuid
from datetime import datetime, timedelta, timezone
from functools import wraps
from typing import Any
import jwt
from flask import Flask, g, jsonify, request
# Module-level logger; this module configures no handlers or levels itself.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configuration constants
# ---------------------------------------------------------------------------
# NOTE(review): the fallback secret is published in source. If JWT_SECRET is
# unset in the environment, anyone can forge validly signed tokens; consider
# refusing to start instead of falling back.
JWT_SECRET = os.environ.get("JWT_SECRET", "change-me-in-production")
# HMAC-SHA256: the same secret both signs and verifies tokens.
JWT_ALGORITHM = "HS256"
TOKEN_EXPIRY_BUFFER = 30  # seconds; tokens with this much life or less left are treated as near expiry
ACCESS_TOKEN_TTL = timedelta(minutes=15)  # lifetime ("exp" - "iat") of access tokens
REFRESH_TOKEN_TTL = timedelta(days=7)  # lifetime ("exp" - "iat") of refresh tokens
SESSION_IDLE_TIMEOUT = timedelta(minutes=30)  # inactivity after which refresh_session drops a session
MAX_ACTIVE_SESSIONS = 5  # per user; create_session evicts the oldest to stay within it
# Numeric rank per role; a higher number means more privilege. The levels are
# currently spaced 10 apart.
ROLE_HIERARCHY: dict[str, int] = {
    "viewer": 10,
    "editor": 20,
    "admin": 30,
    "superadmin": 40,
}
# Explicit permission list per role. Nothing is inherited automatically from
# lower roles, so each list repeats the permissions of the roles below it.
ROLE_PERMISSIONS: dict[str, list[str]] = {
    "viewer": ["read"],
    "editor": ["read", "write"],
    "admin": ["read", "write", "delete", "manage_users"],
    "superadmin": ["read", "write", "delete", "manage_users", "manage_roles", "audit"],
}
# ---------------------------------------------------------------------------
# In-memory stores (replace with Redis/DB in production)
# ---------------------------------------------------------------------------
# These are plain module-level containers: per-process (not shared between
# worker processes) and mutated without any locking.
# JTI -> token metadata ("user_id", "role", "created_at", "type", plus
# "last_used" once the token has been verified).
_token_cache: dict[str, dict[str, Any]] = {}
# JTIs of revoked tokens, consulted by decode_jwt. Nothing in this module
# ever removes entries from it.
_revoked_tokens: set[str] = set()
# session_id -> session record ("user_id", "created_at", "last_activity",
# "metadata", "ip_address", "user_agent").
_active_sessions: dict[str, dict[str, Any]] = {}
# Rate-limit identifier -> time.time() stamps of accepted requests in the
# current sliding window.
_rate_limit_counters: dict[str, list[float]] = {}
# ---------------------------------------------------------------------------
# Token validation
# ---------------------------------------------------------------------------
def decode_jwt(token: str) -> dict[str, Any]:
    """Decode and verify a JWT, rejecting revoked tokens.

    ``jwt.decode`` checks the signature against JWT_SECRET (only
    JWT_ALGORITHM is accepted) and the ``exp`` claim. If the payload's
    ``jti`` is in ``_revoked_tokens``, the token is rejected as well.

    Args:
        token: The encoded token string, typically supplied by a client.

    Returns:
        The decoded claims.

    Raises:
        jwt.ExpiredSignatureError: The token has expired.
        jwt.InvalidTokenError: The token is malformed, badly signed, or
            revoked.
    """
    try:
        payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.ExpiredSignatureError:
        # Never write any part of a bearer credential to the logs.
        logger.warning("Token expired")
        raise
    except jwt.InvalidTokenError as exc:
        # Bad tokens come from clients; they are not a server-side error.
        logger.warning("Invalid token: %s", exc)
        raise
    jti = payload.get("jti")
    if jti and jti in _revoked_tokens:
        raise jwt.InvalidTokenError(f"Token {jti} has been revoked")
    return payload
def _validate_expiry(payload: dict[str, Any]) -> bool:
    """Report whether a token still has comfortable life left.

    Returns True only when the payload has an ``exp`` claim and more than
    TOKEN_EXPIRY_BUFFER seconds remain before it. A missing ``exp`` counts as
    "not comfortably valid" (False).
    """
    expires_at = payload.get("exp")
    return expires_at is not None and (expires_at - time.time()) > TOKEN_EXPIRY_BUFFER
def _generate_jti() -> str:
    """Return a fresh random JWT ID: a version-4 UUID as 32 lowercase hex digits."""
    return f"{uuid.uuid4().int:032x}"
def create_access_token(
    user_id: str,
    role: str,
    extra_claims: dict[str, Any] | None = None,
) -> str:
    """Create a signed JWT access token and record it in ``_token_cache``.

    Args:
        user_id: Stored as the ``sub`` claim.
        role: Stored as the ``role`` claim.
        extra_claims: Optional additional claims to embed. Keys that collide
            with the reserved claims (``sub``, ``role``, ``iat``, ``exp``,
            ``jti``, ``type``) are ignored, and a warning is logged.

    Returns:
        The encoded token. It expires ACCESS_TOKEN_TTL after issue.
    """
    now = datetime.now(timezone.utc)
    jti = _generate_jti()
    reserved: dict[str, Any] = {
        "sub": user_id,
        "role": role,
        "iat": now,
        "exp": now + ACCESS_TOKEN_TTL,
        "jti": jti,
        "type": "access",
    }
    payload: dict[str, Any] = dict(extra_claims or {})
    clobbered = sorted(reserved.keys() & payload.keys())
    if clobbered:
        # Overriding these would desynchronise the token from its cache entry.
        # For example, a replaced "jti" would escape revoke_all_user_tokens.
        logger.warning(
            "Ignoring extra_claims that collide with reserved claims: %s",
            clobbered,
        )
    # Reserved claims are applied last so they always win.
    payload.update(reserved)
    token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
    # Cache for fast lookup (revocation, session cross-checks).
    _token_cache[jti] = {
        "user_id": user_id,
        "role": role,
        "created_at": now.isoformat(),
        "type": "access",
    }
    logger.info("Access token created for user %s (jti=%s)", user_id, jti)
    return token
def create_refresh_token(user_id: str, role: str) -> str:
    """Issue a signed, long-lived refresh token for *user_id*.

    The token carries ``type="refresh"`` and expires REFRESH_TOKEN_TTL after
    issue. Its metadata is stored in ``_token_cache`` under the token's JTI,
    which is where revoke_all_user_tokens looks for a user's tokens.
    """
    issued_at = datetime.now(timezone.utc)
    token_id = _generate_jti()
    signed = jwt.encode(
        {
            "sub": user_id,
            "role": role,
            "iat": issued_at,
            "exp": issued_at + REFRESH_TOKEN_TTL,
            "jti": token_id,
            "type": "refresh",
        },
        JWT_SECRET,
        algorithm=JWT_ALGORITHM,
    )
    _token_cache[token_id] = dict(
        user_id=user_id,
        role=role,
        created_at=issued_at.isoformat(),
        type="refresh",
    )
    logger.info("Refresh token created for user %s (jti=%s)", user_id, token_id)
    return signed
def verify_token(token: str, *, expected_type: str | None = None) -> dict[str, Any]:
    """Run the full verification pipeline on *token* and return its claims.

    Steps:

    1. decode_jwt (signature, expiry, revocation).
    2. Optional token-type check.
    3. Near-expiry flagging.
    4. Recording ``last_used`` in the token cache.

    Args:
        token: The encoded token string.
        expected_type: If given (e.g. ``"access"``), the token's ``type``
            claim must equal it. The default ``None`` accepts any type, which
            is the historical behaviour.

    Returns:
        The decoded payload. ``payload["_near_expiry"]`` is set to True when
        no more than TOKEN_EXPIRY_BUFFER seconds remain (or there is no
        ``exp``).

    Raises:
        jwt.ExpiredSignatureError: The token has expired.
        jwt.InvalidTokenError: The token is invalid, revoked, or of the wrong
            type.
    """
    payload = decode_jwt(token)
    jti = payload.get("jti")
    if expected_type is not None and payload.get("type") != expected_type:
        # Keeps refresh tokens from being accepted where access tokens are required.
        raise jwt.InvalidTokenError(
            f"Expected a {expected_type} token, got {payload.get('type')!r}"
        )
    if not _validate_expiry(payload):
        logger.info("Token %s is within expiry buffer", jti)
        payload["_near_expiry"] = True
    # Update cache with last-used timestamp
    if jti and jti in _token_cache:
        _token_cache[jti]["last_used"] = time.time()
    return payload
def revoke_token(jti: str) -> bool:
    """Mark the token identified by *jti* as revoked.

    The JTI is added to ``_revoked_tokens`` (so decode_jwt rejects the token
    from now on) and its metadata is dropped from ``_token_cache``.
    Always returns True.
    """
    _revoked_tokens.add(jti)
    if jti in _token_cache:
        del _token_cache[jti]
    logger.info("Token revoked: %s", jti)
    return True
def revoke_all_user_tokens(user_id: str) -> int:
    """Revoke every token in ``_token_cache`` owned by *user_id*.

    Only tokens still present in the cache can be found. Returns the number
    of tokens revoked.
    """
    owned = [jti for jti, meta in _token_cache.items() if meta.get("user_id") == user_id]
    _revoked_tokens.update(owned)
    for jti in owned:
        _token_cache.pop(jti)
    logger.info("Revoked %d tokens for user %s", len(owned), user_id)
    return len(owned)
def cleanup_expired_cache() -> int:
    """Remove expired entries from ``_token_cache`` and return how many were removed.

    An entry is removed when either:

    * its ``created_at`` plus the TTL for its ``type`` (ACCESS_TOKEN_TTL for
      ``"access"``, REFRESH_TOKEN_TTL otherwise) lies in the past, or
    * its ``last_used`` timestamp is older than REFRESH_TOKEN_TTL.

    NOTE(review): this assumes each token's ``exp`` equals ``iat`` plus the
    TTL for its type, as create_access_token / create_refresh_token set it.
    Removing a still-valid token's entry would hide it from
    revoke_all_user_tokens.
    """
    now = datetime.now(timezone.utc)
    ttl_by_type = {"access": ACCESS_TOKEN_TTL, "refresh": REFRESH_TOKEN_TTL}
    idle_limit = REFRESH_TOKEN_TTL.total_seconds()

    def _is_expired(meta: dict[str, Any]) -> bool:
        # Age since creation also covers tokens that were never verified.
        # Those have no "last_used" and would otherwise accumulate forever.
        created_raw = meta.get("created_at")
        if created_raw:
            ttl = ttl_by_type.get(meta.get("type"), REFRESH_TOKEN_TTL)
            if now - datetime.fromisoformat(created_raw) > ttl:
                return True
        last_used = meta.get("last_used")
        return bool(last_used) and now.timestamp() - last_used > idle_limit

    expired = [jti for jti, meta in _token_cache.items() if _is_expired(meta)]
    for jti in expired:
        del _token_cache[jti]
    return len(expired)
# ---------------------------------------------------------------------------
# Role-based access control
# ---------------------------------------------------------------------------
def check_role(required_role: str, user_role: str) -> bool:
    """Return True if *user_role* ranks at or above *required_role*.

    Ranks come from ROLE_HIERARCHY. If either role is not in the hierarchy,
    an error is logged and the check fails (False). The required role is
    reported first when both are unknown.
    """
    needed = ROLE_HIERARCHY.get(required_role)
    actual = ROLE_HIERARCHY.get(user_role)
    if needed is None:
        logger.error("Unknown required role: %s", required_role)
        return False
    if actual is None:
        logger.error("Unknown user role: %s", user_role)
        return False
    return actual >= needed
def get_permissions(role: str) -> list[str]:
    """Return the permissions granted to *role*.

    A fresh list is returned on every call, so callers may modify it without
    altering the shared ROLE_PERMISSIONS table. Unknown roles, and roles with
    an empty permission list, yield ``[]`` and log a warning.
    """
    # Copy: handing out the table's own list let callers corrupt global config.
    perms = list(ROLE_PERMISSIONS.get(role, ()))
    if not perms:
        logger.warning("No permissions defined for role: %s", role)
    return perms
def has_permission(role: str, permission: str) -> bool:
    """Return True when *permission* is among those granted to *role*.

    Unknown roles have no permissions, so the answer for them is False.
    In that case get_permissions also logs a warning.
    """
    granted = get_permissions(role)
    return permission in granted
def get_role_level(role: str) -> int:
    """Look up *role*'s numeric rank in ROLE_HIERARCHY; unknown roles map to -1."""
    try:
        return ROLE_HIERARCHY[role]
    except KeyError:
        return -1
def validate_role_transition(current_role: str, target_role: str) -> bool:
    """Return True if changing *current_role* to *target_role* is allowed.

    Both roles must exist in ROLE_HIERARCHY. Demotions and lateral moves are
    always allowed. Promotions may go at most one level up, where "one level"
    means the next-higher rank actually present in ROLE_HIERARCHY.
    """
    current_level = get_role_level(current_role)
    target_level = get_role_level(target_role)
    if current_level < 0 or target_level < 0:
        return False
    # Derive the next rank from the hierarchy itself rather than assuming the
    # levels are exactly 10 apart. At the top rank there is nothing higher,
    # so the ceiling is the current level.
    next_level = min(
        (level for level in ROLE_HIERARCHY.values() if level > current_level),
        default=current_level,
    )
    return target_level <= next_level
def require_role(role: str):
    """Build a Flask view decorator that enforces a minimum role.

    The decorated view runs only when ``g.current_user_role`` is set and
    ranks at or above *role* according to check_role. Otherwise the view is
    skipped, with one of two responses:

    * 401 ``{"error": "Authentication required"}`` when no role is set.
    * 403 ``{"error": "Insufficient permissions"}`` when the role is too low
      or unknown.
    """
    def decorator(fn):
        @wraps(fn)
        def guarded(*args, **kwargs):
            current = getattr(g, "current_user_role", None)
            if current is None:
                return jsonify({"error": "Authentication required"}), 401
            if check_role(role, current):
                return fn(*args, **kwargs)
            return jsonify({"error": "Insufficient permissions"}), 403
        return guarded
    return decorator
def require_permission(permission: str):
    """Build a Flask view decorator that demands one specific permission.

    The decorated view runs only when ``g.current_user_role`` is set and
    that role grants *permission* (per has_permission). Otherwise the view is
    skipped, with one of two responses:

    * 401 ``{"error": "Authentication required"}`` when no role is set.
    * 403 ``{"error": "Missing permission: <permission>"}`` when the role
      lacks the permission.
    """
    def decorator(fn):
        @wraps(fn)
        def guarded(*args, **kwargs):
            current = getattr(g, "current_user_role", None)
            if current is None:
                return jsonify({"error": "Authentication required"}), 401
            if has_permission(current, permission):
                return fn(*args, **kwargs)
            return jsonify({"error": f"Missing permission: {permission}"}), 403
        return guarded
    return decorator
# ---------------------------------------------------------------------------
# Session management
# ---------------------------------------------------------------------------
def create_session(user_id: str, metadata: dict[str, Any] | None = None) -> str:
    """Create a new session for *user_id* and return its ID.

    At most MAX_ACTIVE_SESSIONS sessions are kept per user. If the user is
    at (or above) the limit, their oldest sessions by ``created_at`` are
    invalidated first so that the new one fits.

    Args:
        user_id: Owner of the session.
        metadata: Optional caller data stored under ``"metadata"``
            (``{}`` when omitted).

    Returns:
        The new session ID (a uuid4 rendered as 32 hex characters).
    """
    user_sessions = [
        (sid, s) for sid, s in _active_sessions.items() if s["user_id"] == user_id
    ]
    # Evict as many sessions as needed, not just one, so a user already above
    # the limit is brought back within it. This also avoids calling min() on
    # an empty list when the limit is <= 0.
    overflow = len(user_sessions) - MAX_ACTIVE_SESSIONS + 1
    if overflow > 0:
        # ISO-8601 strings with the same UTC offset sort chronologically.
        oldest_first = sorted(user_sessions, key=lambda item: item[1]["created_at"])
        for oldest_sid, _ in oldest_first[:overflow]:
            invalidate_session(oldest_sid)
            logger.info(
                "Evicted oldest session %s for user %s (limit=%d)",
                oldest_sid,
                user_id,
                MAX_ACTIVE_SESSIONS,
            )
    session_id = uuid.uuid4().hex
    now = datetime.now(timezone.utc)
    _active_sessions[session_id] = {
        "user_id": user_id,
        "created_at": now.isoformat(),
        "last_activity": now.isoformat(),
        "metadata": metadata or {},
        # Assumes the Flask `request` proxy is falsy outside a request
        # context, so these fall back to None there.
        "ip_address": request.remote_addr if request else None,
        "user_agent": (
            request.headers.get("User-Agent") if request else None
        ),
    }
    logger.info("Session created: %s for user %s", session_id, user_id)
    return session_id
def _has_live_access_token(user_id: str, now: datetime) -> bool:
    """Return True if ``_token_cache`` holds an unexpired access token for *user_id*.

    An entry counts as expired once ``created_at + ACCESS_TOKEN_TTL`` has
    passed, which normally mirrors the ``exp`` that create_access_token
    signs. Entries lacking ``created_at`` are treated as live.
    """
    for meta in _token_cache.values():
        if meta.get("user_id") != user_id or meta.get("type") != "access":
            continue
        created_raw = meta.get("created_at")
        if created_raw is None:
            return True
        if datetime.fromisoformat(created_raw) + ACCESS_TOKEN_TTL > now:
            return True
    return False


def refresh_session(session_id: str) -> bool:
    """
    Update the last-activity timestamp for a session. Also checks
    the token cache to verify the user's tokens are still valid.

    Returns:
        True if the session was refreshed. False if the session is unknown,
        has been idle longer than SESSION_IDLE_TIMEOUT, or its user has no
        unexpired access token in the cache. In the last two cases the
        session is also invalidated.
    """
    session = _active_sessions.get(session_id)
    if session is None:
        logger.warning("Attempted to refresh unknown session: %s", session_id)
        return False
    now = datetime.now(timezone.utc)
    # Check idle timeout
    last = datetime.fromisoformat(session["last_activity"])
    if now - last > SESSION_IDLE_TIMEOUT:
        logger.info("Session %s timed out (idle)", session_id)
        invalidate_session(session_id)
        return False
    # Cross-check: does the user still hold an access token that has not
    # expired? Merely being present in the cache is not enough, because
    # expired entries can linger there.
    user_id = session["user_id"]
    if not _has_live_access_token(user_id, now):
        logger.info("No valid access token for user %s, invalidating session", user_id)
        invalidate_session(session_id)
        return False
    session["last_activity"] = now.isoformat()
    return True
def invalidate_session(session_id: str) -> bool:
    """Drop *session_id* from ``_active_sessions``.

    Returns True when a session record was removed, and False when nothing
    was stored under that ID. Unknown IDs are ignored without logging.
    """
    removed = _active_sessions.pop(session_id, None)
    if not removed:
        return False
    logger.info("Session invalidated: %s", session_id)
    return True
def invalidate_all_user_sessions(user_id: str) -> int:
    """Delete every session owned by *user_id* and return how many were deleted.

    Sessions are removed directly from ``_active_sessions``, so no
    per-session "invalidated" log line is emitted; a single summary line is
    logged instead.
    """
    doomed = [sid for sid, data in _active_sessions.items() if data["user_id"] == user_id]
    for sid in doomed:
        _active_sessions.pop(sid)
    removed = len(doomed)
    logger.info("Invalidated %d sessions for user %s", removed, user_id)
    return removed
def get_active_sessions(user_id: str) -> list[dict[str, Any]]:
    """Return one record per active session owned by *user_id*.

    Each record is a shallow copy of the stored session with its
    ``"session_id"`` added. Nested values such as ``"metadata"`` are still
    shared with the store.
    """
    matches: list[dict[str, Any]] = []
    for sid, data in _active_sessions.items():
        if data["user_id"] != user_id:
            continue
        matches.append({"session_id": sid, **data})
    return matches
def get_session_info(session_id: str) -> dict[str, Any] | None:
    """Return a shallow copy of one session's record plus its ``"session_id"``, or None if unknown."""
    session = _active_sessions.get(session_id)
    if not session:
        return None
    return {"session_id": session_id, **session}
# ---------------------------------------------------------------------------
# Rate limiting (simple in-memory sliding window)
# ---------------------------------------------------------------------------
def check_rate_limit(
    identifier: str, max_requests: int = 60, window_seconds: int = 60
) -> bool:
    """Sliding-window limiter: allow at most *max_requests* per *window_seconds*.

    Timestamps older than the window are discarded on each call. An accepted
    request is recorded and True is returned. A rejected request is not
    recorded, and False is returned.
    """
    now = time.time()
    cutoff = now - window_seconds
    recent = [stamp for stamp in _rate_limit_counters.get(identifier, []) if stamp > cutoff]
    _rate_limit_counters[identifier] = recent
    if len(recent) >= max_requests:
        return False
    recent.append(now)
    return True
# ---------------------------------------------------------------------------
# Flask middleware / request hooks
# ---------------------------------------------------------------------------
def init_auth(app: Flask) -> None:
    """Register the authentication ``before_request`` hook on *app*.

    For every request the hook:

    1. Reads an ``Authorization: Bearer <token>`` header (the scheme name is
       matched case-insensitively). If the header is absent, or the token
       fails verify_token, the request proceeds anonymously:
       ``g.current_user`` and ``g.current_user_role`` are None.
    2. Accepts only tokens whose ``type`` claim is ``"access"``. Refresh
       tokens, and any other type, are treated as anonymous.
    3. Applies check_rate_limit per user and answers 429 when it is exceeded.
    4. Refreshes the session named by ``X-Session-ID`` via refresh_session,
       but only if that session belongs to the authenticated user.
    """
    @app.before_request
    def _authenticate():
        # Start anonymous; identity is only filled in once all checks pass.
        g.current_user = None
        g.current_user_role = None

        # RFC 7235: the authentication scheme name is case-insensitive.
        scheme, _, token = request.headers.get("Authorization", "").partition(" ")
        if scheme.lower() != "bearer" or not token:
            return None
        try:
            payload = verify_token(token)
        except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
            return None
        # Refresh tokens live for days. Accepting them as bearer credentials
        # would defeat the short ACCESS_TOKEN_TTL.
        if payload.get("type") != "access":
            logger.warning(
                "Rejected %s token presented as bearer credential", payload.get("type")
            )
            return None

        g.current_user = payload.get("sub")
        g.current_user_role = payload.get("role")

        # Rate-limit per user
        if g.current_user and not check_rate_limit(g.current_user):
            return jsonify({"error": "Rate limit exceeded"}), 429

        # Auto-refresh the session only when it belongs to this user.
        # Otherwise any authenticated caller could keep alive (or cause the
        # invalidation of) someone else's session just by sending its ID.
        session_id = request.headers.get("X-Session-ID")
        if session_id:
            session = get_session_info(session_id)
            if (
                session is not None
                and g.current_user is not None
                and session.get("user_id") == g.current_user
            ):
                refresh_session(session_id)
            else:
                logger.warning(
                    "Ignoring X-Session-ID: unknown session or not owned by the authenticated user"
                )
        return None
def login_required(fn):
    """Decorate a Flask view so it only runs for an authenticated caller.

    "Authenticated" means ``g.current_user`` is set to something other than
    None. Otherwise the view is skipped and
    ``{"error": "Authentication required"}`` is returned with status 401.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        user = getattr(g, "current_user", None)
        if user is not None:
            return fn(*args, **kwargs)
        return jsonify({"error": "Authentication required"}), 401
    return wrapper
# ---------------------------------------------------------------------------
# Dead code (kept "for later" but never referenced)
# ---------------------------------------------------------------------------
def _legacy_hash(password: str) -> str:
    """Return the v1-era password digest: SHA-256 hex of a fixed salt plus the password.

    Old password hashing from v1, replaced by bcrypt in v2 but never removed.
    No callers exist outside this file, and there is no test coverage.

    WARNING: one static salt shared by every password, and a single fast
    SHA-256 round, make this unsuitable for storing passwords. Do not use it
    for new hashes.
    """
    salted = f"static-salt-v1{password}"
    digest = hashlib.sha256(salted.encode())
    return digest.hexdigest()
def _migrate_v1_tokens(old_secret: str) -> int:
    """Relabel cached v1 token metadata as v2 and return how many entries changed.

    This is a one-time migration helper for the v1 to v2 token format. It
    was run once in production in 2023 and has been dead code since; there
    is no test coverage.

    It only rewrites the ``"version"`` field of ``_token_cache`` entries.
    ``old_secret`` is accepted but never used, and no tokens are re-signed.
    """
    stale = [meta for meta in _token_cache.values() if meta.get("version") == "v1"]
    for meta in stale:
        meta["version"] = "v2"
    return len(stale)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment