Files
fastapi-traffic/examples/04_key_extractors.py

157 lines
4.5 KiB
Python

"""Examples demonstrating custom key extractors for rate limiting."""
from __future__ import annotations
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi_traffic import (
MemoryBackend,
RateLimiter,
RateLimitExceeded,
rate_limit,
)
from fastapi_traffic.core.limiter import set_limiter
# Counter storage for every limit in this example. Judging by the class name it
# is in-memory (presumably per-process, so counts are not shared across
# workers — confirm against the MemoryBackend docs).
backend = MemoryBackend()
# The limiter used by the @rate_limit routes; the app's lifespan handler
# initializes it, registers it with set_limiter(), and closes it on shutdown.
limiter = RateLimiter(backend)
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Manage the module-level ``limiter`` for the lifetime of the app.

    On startup the limiter is initialized and registered as the global limiter
    via ``set_limiter`` so ``@rate_limit`` routes can use it. On shutdown it is
    closed. The close runs in ``finally`` so it still happens if cancellation
    or an exception is thrown into this context manager at the ``yield``
    (previously the cleanup line was skipped in that case).
    """
    await limiter.initialize()
    set_limiter(limiter)
    try:
        yield
    finally:
        # Always release backend resources, even on abnormal shutdown.
        await limiter.close()
# The demo application; `lifespan` ties limiter startup/shutdown to the app.
app = FastAPI(
    title="Custom Key Extractors",
    lifespan=lifespan,
)
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Translate a ``RateLimitExceeded`` error into an HTTP 429 response.

    The JSON body carries ``retry_after`` exactly as the exception reports it.
    When that value is a finite number, it is also sent as the standard
    ``Retry-After`` response header (whole seconds, rounded up, never negative).
    This lets generic HTTP clients back off without parsing the body.
    """
    import math

    retry_after = exc.retry_after
    headers: dict[str, str] = {}
    # Retry-After (RFC 9110) must be a non-negative integer count of seconds.
    # Assumes exc.retry_after is expressed in seconds (the window sizes in this
    # file are) — TODO confirm against RateLimitExceeded. Skip the header when
    # the value is missing, non-numeric or non-finite.
    if (
        isinstance(retry_after, (int, float))
        and not isinstance(retry_after, bool)
        and math.isfinite(retry_after)
    ):
        headers["Retry-After"] = str(max(0, math.ceil(retry_after)))
    return JSONResponse(
        status_code=429,
        content={"error": "rate_limit_exceeded", "retry_after": retry_after},
        headers=headers or None,
    )
# 1. Default: Rate limit by IP address
@app.get("/by-ip")
@rate_limit(limit=10, window_size=60)  # Uses default IP-based key extractor
async def by_ip(request: Request) -> dict[str, str]:
    """Report the caller's IP address, which is what this route is limited by.

    No ``key_extractor`` is given, so the decorator's default (IP-based) key is
    used. ``"unknown"`` is reported when Starlette has no client address.
    """
    peer = request.client
    return {
        "limited_by": "ip",
        "client_ip": peer.host if peer else "unknown",
    }
# 2. Rate limit by API key
def api_key_extractor(request: Request) -> str:
    """Build an ``api_key:<key>`` bucket key from the ``X-API-Key`` header.

    Requests without the header all share the ``api_key:anonymous`` bucket.
    """
    supplied_key = request.headers.get("X-API-Key", "anonymous")
    return "api_key:{}".format(supplied_key)
@app.get("/by-api-key")
@rate_limit(
    limit=100,
    window_size=3600,  # 100 requests per hour per API key
    key_extractor=api_key_extractor,
)
async def by_api_key(request: Request) -> dict[str, str]:
    """Echo the ``X-API-Key`` value (or ``"anonymous"``) that keys this limit."""
    key = request.headers.get("X-API-Key", "anonymous")
    response = {"limited_by": "api_key"}
    response["api_key"] = key
    return response
# 3. Rate limit by user ID (from JWT or session)
def user_id_extractor(request: Request) -> str:
    """Return a ``user:<id>`` bucket key taken from the ``X-User-ID`` header.

    A missing header maps to ``user:anonymous``. Demo only: the header is
    client-supplied, so in real apps the ID would come from a decoded JWT or
    session rather than a header.
    """
    return "user:{}".format(request.headers.get("X-User-ID", "anonymous"))
@app.get("/by-user")
@rate_limit(
    limit=50,
    window_size=60,
    key_extractor=user_id_extractor,
)
async def by_user(request: Request) -> dict[str, str]:
    """Echo the ``X-User-ID`` value (or ``"anonymous"``) that keys this limit."""
    return {
        "limited_by": "user_id",
        "user_id": request.headers.get("X-User-ID", "anonymous"),
    }
# 4. Rate limit by endpoint + IP (separate limits per endpoint)
def endpoint_ip_extractor(request: Request) -> str:
    """Key requests by URL path *and* client IP: ``endpoint:<path>:ip:<ip>``.

    Falls back to ``"unknown"`` for the IP when no client address is known.
    """
    client = request.client
    client_ip = client.host if client else "unknown"
    return "endpoint:{}:ip:{}".format(request.url.path, client_ip)
@app.get("/endpoint-specific")
@rate_limit(
    limit=5,
    window_size=60,
    key_extractor=endpoint_ip_extractor,
)
async def endpoint_specific(_: Request) -> dict[str, str]:
    """Return a fixed marker; limited to 5 requests/60 s per (path, client IP)."""
    result = {"limited_by": "endpoint+ip"}
    return result
# 5. Rate limit by organization/tenant (multi-tenant apps)
def tenant_extractor(request: Request) -> str:
    """Return a ``tenant:<id>`` bucket key from the ``X-Tenant-ID`` header.

    Requests without the header share ``tenant:default``. The tenant could
    also be parsed from a subdomain (tenant.example.com); this example only
    reads the header.
    """
    tenant_id = request.headers.get("X-Tenant-ID", "default")
    return "tenant:{}".format(tenant_id)
@app.get("/by-tenant")
@rate_limit(
    limit=1000,
    window_size=3600,  # 1000 requests per hour per tenant
    key_extractor=tenant_extractor,
)
async def by_tenant(request: Request) -> dict[str, str]:
    """Echo the ``X-Tenant-ID`` value (or ``"default"``) that keys this limit."""
    tenant_id = request.headers.get("X-Tenant-ID", "default")
    return dict(limited_by="tenant", tenant_id=tenant_id)
# 6. Composite key: User + Action type
def user_action_extractor(request: Request) -> str:
    """Key requests per (user, action): ``user:<id>:action:<action>``.

    The user comes from the ``X-User-ID`` header (default ``"anonymous"``).
    The action comes from the ``action`` query parameter (default
    ``"default"``). Both are client-controlled in this demo.
    """
    who = request.headers.get("X-User-ID", "anonymous")
    what = request.query_params.get("action", "default")
    return "user:{}:action:{}".format(who, what)
@app.get("/user-action")
@rate_limit(
    limit=10,
    window_size=60,
    key_extractor=user_action_extractor,
)
async def user_action(
    request: Request,
    action: str = "default",
) -> dict[str, str]:
    """Echo the (user, action) pair that selects this request's bucket.

    FastAPI fills ``action`` from the ``action`` query parameter. The key
    extractor reads that same parameter directly from ``request.query_params``.
    """
    caller = request.headers.get("X-User-ID", "anonymous")
    return dict(limited_by="user+action", user_id=caller, action=action)
if __name__ == "__main__":
    import uvicorn

    # Serve the demo on port 8000. Binding 0.0.0.0 exposes it on every
    # network interface; use "127.0.0.1" for local-only testing.
    uvicorn.run(app=app, host="0.0.0.0", port=8000)