Skip to content

Instantly share code, notes, and snippets.

@RyadPasha
Created February 26, 2025 04:38
Show Gist options
  • Select an option

  • Save RyadPasha/c357144d1757a367d0c74ef25094e1a0 to your computer and use it in GitHub Desktop.

Select an option

Save RyadPasha/c357144d1757a367d0c74ef25094e1a0 to your computer and use it in GitHub Desktop.
Middleware for Logging AWS Lambda Requests
"""
Middleware for Logging AWS Lambda Requests
This module implements a middleware for logging AWS Lambda function
requests and responses using AWS Lambda Powertools. It provides
automatic logging of request metadata, including headers, query
parameters, source IP, execution time, and response details.
Requires: Python 3.10+
Author: Mohamed Riyad
Date: 2024-02-21
Dependencies:
- aws_lambda_powertools
- datetime
- os
- time
- typing
- urllib
Usage:
The middleware is built with AWS Lambda Powertools' `@lambda_handler_decorator`;
apply it by decorating a Lambda handler with `@logging_middleware`. Set the
`LOG_LEVEL` environment variable to "NONE", "BASIC" (default) or "ADVANCED"
to control how much is logged.
Example:
```python
from aws_lambda_powertools.utilities.typing import LambdaContext
# Import logging_middleware from wherever this module lives in your project.
from logging_middleware import logging_middleware
@logging_middleware
def my_lambda_handler(event: dict, context: LambdaContext):
    return {"statusCode": 200, "body": "Hello World"}
```
"""
import os
import time
from typing import Callable, TypedDict
from datetime import datetime, timezone
from urllib.parse import urlencode
from aws_lambda_powertools import Tracer
from aws_lambda_powertools.middleware_factory import lambda_handler_decorator
from aws_lambda_powertools.utilities.typing import LambdaContext
from aws_lambda_powertools import Logger
# NOTE(review): `tracer` is not referenced elsewhere in this module's visible
# code; kept because other modules may import it.
tracer = Tracer()
logger = Logger()

# Verbosity of the request/response logging done by `logging_middleware`.
# Options: "NONE" (disabled), "BASIC" (one summary line per request start/end)
# or "ADVANCED" (structured entries including headers and bodies).
# NOTE(review): AWS Lambda Powertools' Logger may also read the LOG_LEVEL
# environment variable for its own log level; confirm that values such as
# "BASIC"/"ADVANCED" do not conflict with it in the Powertools version in use.
LOG_LEVEL = os.getenv("LOG_LEVEL", "BASIC").upper()

# Lower-case header names whose values are replaced with "****" before logging.
# The same list is applied to both request and response headers, so it must
# also cover response-side credentials such as Set-Cookie.
SENSITIVE_HEADERS = [
    "authorization",
    "proxy-authorization",
    "cookie",
    "set-cookie",
    "x-api-key",
    "x-amz-security-token",
]
class MetadataDetails(TypedDict):
    """Per-request metadata handed to the start/end request logging helpers."""

    # Request headers, with values of names listed in SENSITIVE_HEADERS masked.
    headers: dict[str, str]
    # Request path followed by the query string (if any).
    request_uri: str
    # Query-string portion of the URI.
    # NOTE(review): confirm the code building this dict actually sets this key;
    # TypedDict does not enforce required keys at runtime.
    query_string: str
    # HTTP method such as "GET"; "NA" when it cannot be determined.
    http_method: str
    # Request identifier taken from the event or the Lambda context, if found.
    request_id: str | None
    # Client IP address, if it could be determined from the event.
    source_ip: str | None
    # Request body exactly as found in the event, if present.
    body: str | None
def _mask_sensitive_headers(headers: dict) -> dict:
"""Mask sensitive headers."""
masked_headers = {}
for key, value in headers.items():
if key.lower() in SENSITIVE_HEADERS:
masked_headers[key] = "****"
else:
masked_headers[key] = value
return masked_headers
def _get_source_ip(event: dict) -> str | None:
"""
Extract source IP from different types of AWS Lambda event sources.
Args:
event (dict): AWS Lambda event.
Returns:
str | None: Source IP address if found, None otherwise.
"""
# API Gateway REST API
if event.get("requestContext", {}).get("identity", {}).get("sourceIp"):
return event["requestContext"]["identity"]["sourceIp"]
# API Gateway HTTP API / ALB
if event.get("requestContext", {}).get("http", {}).get("sourceIp"):
return event["requestContext"]["http"]["sourceIp"]
# CloudFront
records = event.get("Records", [])
if records and isinstance(records, list) and len(records) > 0:
cf_request = records[0].get("cf", {}).get("request", {})
if cf_request.get("clientIp"):
return cf_request["clientIp"]
# VPC Lambda (X-Forwarded-For)
if event.get("headers", {}).get("x-forwarded-for"):
return event["headers"]["x-forwarded-for"].split(",")[0].strip()
logger.warning("Could not extract source IP from event", event=event)
return None
def _get_request_id(event: dict, context: LambdaContext | None = None) -> str | None:
"""
Extract a the request ID from Lambda event or context.
Prioritizes IDs in the following order:
1. API Gateway requestId
2. CloudFront request ID
3. Lambda context request ID
Args:
event (dict): AWS Lambda event
context (str | None): AWS Lambda context
Returns:
str | None: Request ID if found, None otherwise
"""
# API Gateway (REST and HTTP API)
if "requestContext" in event:
request_id = event["requestContext"].get("requestId")
if request_id:
return request_id
# CloudFront
records = event.get("Records", [])
if records and isinstance(records, list) and len(records) > 0:
cf_request = records[0].get("cf", {}).get("request", {})
if request_id := cf_request.get("id"):
return request_id
# Fallback to Lambda context request ID
if context:
return context.aws_request_id
logger.warning("No request ID found in event or context", event=event)
return None
def _get_status_code(response: object, default: int = 500) -> int:
    """
    Extract the HTTP status code from a Lambda handler response.

    Args:
        response: Value returned by the handler. Only a dict carrying a
            non-null "statusCode" key provides a status code.
        default: Value reported when no status code is available (non-dict
            response, or missing/null "statusCode"). Defaults to 500.

    Returns:
        int: The response's "statusCode", or `default`.
    """
    if not isinstance(response, dict):
        return default
    status_code = response.get("statusCode")
    # A present-but-null statusCode would otherwise leak None despite `-> int`.
    return default if status_code is None else status_code
def _extract_metadata_details(event: dict, context: LambdaContext) -> MetadataDetails:
    """
    Extract and construct detailed metadata from the AWS Lambda event and context objects.

    This function performs the following steps:
    - Masks sensitive headers to prevent logging of sensitive information (e.g., Authorization tokens).
    - Resolves the HTTP method and request path from the locations used by API Gateway
      REST APIs (`requestContext.path`/`httpMethod`), HTTP APIs (`requestContext.http`)
      and ALB / proxy events (top-level `path`/`httpMethod`).
    - Builds the query string, preferring the HTTP API `rawQueryString` when present and
      otherwise URL-encoding `queryStringParameters`.
    - Determines the request ID from available sources (e.g., API Gateway, CloudFront, Lambda context).
    - Extracts the source IP address from various event structures (API Gateway REST/HTTP, CloudFront, etc.).
    - Captures the request body, if provided.

    Null sections in the event (e.g. `"requestContext": null`, `"headers": null`) are
    treated as empty rather than raising.

    Args:
        event (dict): The AWS Lambda event payload containing request details.
        context (LambdaContext): The Lambda context object providing execution context and request identifiers.

    Returns:
        MetadataDetails: A structured dictionary containing detailed metadata about the request, including:
            - headers (dict[str, str]): Masked request headers.
            - request_uri (str): Full request path with query parameters.
            - query_string (str): "?<query>" when query parameters exist, else "".
            - http_method (str): Upper-cased HTTP method (e.g., GET, POST), or "NA" if unknown.
            - request_id (str | None): Unique identifier for the request.
            - source_ip (str | None): IP address of the requester.
            - body (str | None): Request body payload, if present.
    """
    request_context = event.get("requestContext") or {}
    http_context = request_context.get("http") or {}

    request_path = (
        request_context.get("path")
        or http_context.get("path")
        or event.get("path")
        or ""
    )
    # Parenthesised so .upper() applies to whichever source supplied the method.
    http_method = (
        request_context.get("httpMethod")
        or http_context.get("method")
        or event.get("httpMethod")
        or "NA"
    ).upper()

    # rawQueryString (HTTP API) keeps repeated keys and original encoding.
    raw_query = event.get("rawQueryString")
    if raw_query:
        query_string = f"?{raw_query}"
    else:
        query_params = event.get("queryStringParameters") or {}
        query_string = f"?{urlencode(query_params)}" if query_params else ""

    return MetadataDetails(
        headers=_mask_sensitive_headers(event.get("headers") or {}),
        request_uri=f"{request_path}{query_string}",
        query_string=query_string,
        http_method=http_method,
        request_id=_get_request_id(event, context),
        source_ip=_get_source_ip(event),
        body=event.get("body", ""),
    )
def _log_start_request(metadata: MetadataDetails) -> None:
    """
    Emit the "Start request" log entry for an invocation.

    With LOG_LEVEL "BASIC" a single pipe-separated summary line is logged.
    With "ADVANCED" a structured entry additionally carries the source IP, the
    masked request headers and, when non-empty, the request body. Any other
    LOG_LEVEL value logs nothing.

    Args:
        metadata (MetadataDetails): Request metadata extracted from event and context.
    """
    if LOG_LEVEL not in ("BASIC", "ADVANCED"):
        return

    started_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
    headline = f"Start request: {metadata['http_method']}: {metadata['request_uri']}"

    if LOG_LEVEL == "BASIC":
        logger.info(
            f"{headline} | Request ID: {metadata['request_id']} | Timestamp: {started_at}"
        )
        return

    details = {
        "timestamp": started_at,
        "request_id": metadata["request_id"],
        "source_ip": metadata["source_ip"],
        "request_headers": metadata["headers"],
    }
    request_body = metadata.get("body")
    # Only attach the body when there is one, to keep entries compact.
    if request_body:
        details["request_body"] = request_body
    logger.info(headline, extra=details)
def _log_end_request(response: object, execution_time: float, metadata: MetadataDetails) -> None:
    """
    Log the end of a Lambda request.

    Captures response details like status code, headers, execution time, and
    timestamp. Logging verbosity depends on LOG_LEVEL ('BASIC' or 'ADVANCED');
    other values log nothing.

    Handler responses that are not dicts (e.g. `None` or a plain string) are
    logged with whatever status `_get_status_code` reports, no headers and an
    empty body, instead of raising AttributeError.

    Args:
        response: Lambda function's response (typically an API Gateway/ALB style dict).
        execution_time (float): Total execution duration in seconds.
        metadata (MetadataDetails): Request metadata extracted from event and context.
    """
    status_code = _get_status_code(response)
    # Only dict responses carry headers/body; `or {}` also covers "headers": null.
    response_dict = response if isinstance(response, dict) else {}
    response_headers = _mask_sensitive_headers(response_dict.get("headers") or {})
    end_time = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
    if LOG_LEVEL == "ADVANCED":
        logger.info(
            f"End request: {metadata['http_method']}: {metadata['request_uri']}",
            extra={
                "timestamp": end_time,
                "request_id": metadata["request_id"],
                "source_ip": metadata["source_ip"],
                "status_code": status_code,
                "execution_time": f"{execution_time:.4f}s",
                "response_headers": response_headers,
                "response_body": response_dict.get("body", ""),
            },
        )
    elif LOG_LEVEL == "BASIC":
        logger.info(
            f"End request: {metadata['http_method']}: {metadata['request_uri']} | "
            f"Request ID: {metadata['request_id']} | Status: {status_code} | Time: {execution_time:.4f}s"
        )
@lambda_handler_decorator(trace_execution=True)
def logging_middleware(
    handler: Callable[[dict, LambdaContext], dict],
    event: dict,
    context: LambdaContext,
) -> dict:
    """
    Middleware for logging Lambda requests and responses.

    Logging is best-effort. Any exception raised while extracting request
    metadata or writing the start/end log entries is logged via
    `logger.exception` and swallowed, so it can never fail the invocation.
    Exceptions raised by `handler` itself propagate unchanged.

    Args:
        handler: The Lambda handler function to wrap
        event: AWS Lambda event dictionary
        context: AWS Lambda context object

    Returns:
        dict: The response from the Lambda handler, returned unchanged
    """
    metadata: MetadataDetails | None = None
    if LOG_LEVEL != "NONE":
        try:
            # Extraction sits inside the try: malformed events must not break
            # the handler just because logging could not parse them.
            metadata = _extract_metadata_details(event, context)
            _log_start_request(metadata)
        except Exception:
            logger.exception("Logging middleware error during request start")

    # perf_counter is monotonic, so wall-clock adjustments cannot skew the duration.
    execution_start = time.perf_counter()
    response = handler(event, context)
    execution_time = time.perf_counter() - execution_start

    # Skip end logging when metadata could not be built (or logging is off).
    if metadata is not None:
        try:
            _log_end_request(response, execution_time, metadata)
        except Exception:
            logger.exception("Logging middleware error during request end")

    return response
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment