157 lines
4.5 KiB
Python
157 lines
4.5 KiB
Python
"""Examples demonstrating custom key extractors for rate limiting."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from fastapi_traffic import (
|
|
MemoryBackend,
|
|
RateLimiter,
|
|
RateLimitExceeded,
|
|
rate_limit,
|
|
)
|
|
from fastapi_traffic.core.limiter import set_limiter
|
|
|
|
# A single storage backend and a single limiter are shared by every endpoint
# below. Counters presumably live in this process's memory (hence
# MemoryBackend) -- confirm before running with multiple workers.
backend = MemoryBackend()
limiter = RateLimiter(backend)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Own the shared limiter for the lifetime of the application.

    On startup the limiter is initialized and registered with ``set_limiter``
    (presumably how the ``@rate_limit`` decorators locate it -- confirm). On
    shutdown it is closed.

    Everything after a successful ``initialize()`` runs inside ``try``/
    ``finally``. That way ``close()`` still runs if ``set_limiter`` fails or
    an exception is raised into this context at the ``yield`` (for example
    cancellation or a failure elsewhere during startup/shutdown), so the
    backend is not leaked.
    """
    await limiter.initialize()
    try:
        set_limiter(limiter)
        yield
    finally:
        await limiter.close()
|
|
|
|
|
|
# The lifespan hook initializes/registers the limiter at startup and closes it
# at shutdown.
app = FastAPI(
    title="Custom Key Extractors",
    lifespan=lifespan,
)
|
|
|
|
|
|
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Translate ``RateLimitExceeded`` into an HTTP 429 JSON response.

    The JSON body is unchanged: ``{"error": "rate_limit_exceeded",
    "retry_after": exc.retry_after}``.

    When ``exc.retry_after`` is a finite number, it is also sent in the
    standard ``Retry-After`` header, which clients and proxies honour without
    parsing the body. The header carries whole seconds only, so the value is
    rounded up (never telling a client to retry early) and clamped at zero.
    """
    import math  # local import keeps this example's top-level imports unchanged

    retry_after = exc.retry_after
    headers: dict[str, str] = {}
    # NOTE(review): assumes retry_after is a number of seconds (or None); any
    # other type is left out of the header rather than risking a 500 here.
    if isinstance(retry_after, (int, float)) and math.isfinite(retry_after):
        headers["Retry-After"] = str(max(0, math.ceil(retry_after)))
    return JSONResponse(
        status_code=429,
        content={"error": "rate_limit_exceeded", "retry_after": retry_after},
        headers=headers,
    )
|
|
|
|
|
|
# 1. Default: Rate limit by IP address
@app.get("/by-ip")
@rate_limit(10, 60)  # no key_extractor passed: the default IP-based extractor applies
async def by_ip(request: Request) -> dict[str, str]:
    """Report the caller's IP address.

    Limited with the decorator's default (IP-based) key. ``(10, 60)`` presumably
    means 10 requests per 60-second window, matching the keyword form used by
    the other endpoints.
    """
    client = request.client
    if not client:
        client_ip = "unknown"
    else:
        client_ip = client.host
    return {"limited_by": "ip", "client_ip": client_ip}
|
|
|
|
|
|
# 2. Rate limit by API key
def api_key_extractor(request: Request) -> str:
    """Build the limiter key ``api_key:<X-API-Key header>``.

    A missing header maps to ``api_key:anonymous``, so every caller without an
    API key shares one bucket.
    """
    supplied = request.headers.get("X-API-Key")
    if supplied is None:
        supplied = "anonymous"
    return f"api_key:{supplied}"
|
|
|
|
|
|
@app.get("/by-api-key")
@rate_limit(
    limit=100,
    window_size=3600,  # one-hour window: 100 requests per hour per API key
    key_extractor=api_key_extractor,
)
async def by_api_key(request: Request) -> dict[str, str]:
    """Echo the caller's API key; limited per key by ``api_key_extractor``.

    Callers without an ``X-API-Key`` header are reported (and bucketed) as
    ``"anonymous"``. Echoing the key back is for demonstration only; avoid it
    in production responses and logs.
    """
    supplied_key = request.headers.get("X-API-Key", "anonymous")
    response = {"limited_by": "api_key"}
    response["api_key"] = supplied_key
    return response
|
|
|
|
|
|
# 3. Rate limit by user ID (from JWT or session)
def user_id_extractor(request: Request) -> str:
    """Build the limiter key ``user:<X-User-ID header>`` (``user:anonymous`` if absent).

    In real apps the user ID would come from a decoded JWT or the session.
    Here it is read from a plain request header, which any client can set.
    """
    try:
        user_id = request.headers["X-User-ID"]
    except KeyError:
        user_id = "anonymous"
    return f"user:{user_id}"
|
|
|
|
|
|
@app.get("/by-user")
@rate_limit(
    limit=50,
    window_size=60,
    key_extractor=user_id_extractor,
)
async def by_user(request: Request) -> dict[str, str]:
    """Report the caller's user ID; limited per ``X-User-ID`` via ``user_id_extractor``."""
    uid = request.headers.get("X-User-ID", "anonymous")
    return dict(limited_by="user_id", user_id=uid)
|
|
|
|
|
|
# 4. Rate limit by endpoint + IP (separate limits per endpoint)
def endpoint_ip_extractor(request: Request) -> str:
    """Build the limiter key ``endpoint:<path>:ip:<client ip>``.

    Including the URL path gives each endpoint its own counter per client.
    A request with no client address uses ``unknown`` as the IP. Behind a
    reverse proxy, ``request.client`` is presumably the proxy's address --
    confirm your deployment.
    """
    ip = "unknown"
    if request.client:
        ip = request.client.host
    return "endpoint:{}:ip:{}".format(request.url.path, ip)
|
|
|
|
|
|
@app.get("/endpoint-specific")
@rate_limit(
    limit=5,
    window_size=60,
    key_extractor=endpoint_ip_extractor,
)
async def endpoint_specific(_: Request) -> dict[str, str]:
    """Return a fixed payload; this endpoint gets its own per-IP counter.

    NOTE(review): the request parameter is named ``_``; confirm that
    ``rate_limit`` locates the Request by type rather than by parameter name.
    """
    return dict(limited_by="endpoint+ip")
|
|
|
|
|
|
# 5. Rate limit by organization/tenant (multi-tenant apps)
def tenant_extractor(request: Request) -> str:
    """Build the limiter key ``tenant:<X-Tenant-ID header>``.

    Requests without the header fall into the shared ``tenant:default``
    bucket. The tenant could also be parsed from a subdomain
    (``tenant.example.com``).
    """
    tenant_id = request.headers.get("X-Tenant-ID", "default")
    return "tenant:{}".format(tenant_id)
|
|
|
|
|
|
@app.get("/by-tenant")
@rate_limit(
    limit=1000,
    window_size=3600,  # one-hour window: 1000 requests per hour per tenant
    key_extractor=tenant_extractor,
)
async def by_tenant(request: Request) -> dict[str, str]:
    """Report the caller's tenant; limited per ``X-Tenant-ID`` via ``tenant_extractor``."""
    tenant_id = request.headers.get("X-Tenant-ID", "default")
    payload: dict[str, str] = {"limited_by": "tenant"}
    payload["tenant_id"] = tenant_id
    return payload
|
|
|
|
|
|
# 6. Composite key: User + Action type
def user_action_extractor(request: Request) -> str:
    """Build the limiter key ``user:<X-User-ID>:action:<action query param>``.

    Missing values default to ``anonymous`` (user) and ``default`` (action).

    NOTE(review): ``action`` comes straight from the query string, so every
    distinct value gets its own bucket. A client can presumably sidestep the
    limit by varying it; consider mapping it onto a fixed set of known actions.
    """
    who = request.headers.get("X-User-ID", "anonymous")
    what = request.query_params.get("action", "default")
    return "user:{}:action:{}".format(who, what)
|
|
|
|
|
|
@app.get("/user-action")
@rate_limit(
    limit=10,
    window_size=60,
    key_extractor=user_action_extractor,
)
async def user_action(
    request: Request,
    action: str = "default",
) -> dict[str, str]:
    """Report the caller and requested action; limited per (user, action) pair.

    ``action`` is the same query parameter that ``user_action_extractor`` reads
    to build the limiter key.
    """
    caller = request.headers.get("X-User-ID", "anonymous")
    return dict(limited_by="user+action", user_id=caller, action=action)
|
|
|
|
|
|
if __name__ == "__main__":
    # uvicorn is imported only here, so importing this module (e.g. under a
    # different ASGI server) does not require it.
    from uvicorn import run

    # 0.0.0.0 listens on all interfaces; use 127.0.0.1 for local-only testing.
    run(app, host="0.0.0.0", port=8000)
|