Files
fastapi-traffic/examples/basic_usage.py

175 lines
4.7 KiB
Python

"""Basic usage examples for fastapi-traffic."""
from __future__ import annotations

import hashlib
import secrets
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING

from fastapi import Depends, FastAPI, Request
from fastapi.responses import JSONResponse

from fastapi_traffic import (
    Algorithm,
    RateLimiter,
    RateLimitExceeded,
    SQLiteBackend,
    rate_limit,
)
from fastapi_traffic.core.decorator import RateLimitDependency
from fastapi_traffic.core.limiter import set_limiter
if TYPE_CHECKING:
from collections.abc import AsyncIterator
# Configure global rate limiter with SQLite backend for persistence.
# NOTE(review): "rate_limits.db" is a relative path; presumably SQLiteBackend
# resolves it against the process's current working directory — confirm.
backend = SQLiteBackend("rate_limits.db")
limiter = RateLimiter(backend)
# Register this limiter as the library-wide default. Presumably the
# @rate_limit decorator and RateLimitDependency look it up from here rather
# than receiving it explicitly — verify in fastapi_traffic.core.limiter.
set_limiter(limiter)
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
    """Manage the application lifespan: initialize the limiter, then close it.

    The limiter is closed in a ``finally`` clause. Without it, an exception
    or cancellation thrown into this generator at the ``yield`` (for example
    during an abnormal server shutdown) would skip the cleanup and leave the
    backend open.
    """
    # Startup: prepare the limiter (and its backend) before serving requests.
    await limiter.initialize()
    try:
        yield
    finally:
        # Shutdown: always release the limiter, even on abnormal exit.
        await limiter.close()
# Create the application. ``lifespan`` initializes the rate limiter on startup
# and closes it on shutdown.
app = FastAPI(title="FastAPI Traffic Example", lifespan=lifespan)
# Translate RateLimitExceeded into an HTTP 429 JSON response.
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Convert a ``RateLimitExceeded`` error into a 429 JSON response.

    The body carries a fixed error code, the exception's message and its
    ``retry_after`` value. When the exception has ``limit_info``, the headers
    produced by ``limit_info.to_headers()`` are attached to the response.
    """
    if exc.limit_info:
        response_headers = exc.limit_info.to_headers()
    else:
        response_headers = {}

    payload = {
        "error": "rate_limit_exceeded",
        "message": exc.message,
        "retry_after": exc.retry_after,
    }
    return JSONResponse(
        status_code=429,
        content=payload,
        headers=response_headers,
    )
# Example 1: Basic decorator usage (100 requests per 60-second window)
@app.get("/api/basic")
@rate_limit(limit=100, window_size=60)
async def basic_endpoint(_: Request) -> dict[str, str]:
    """Return a greeting, limited to 100 requests per minute."""
    greeting = "Hello, World!"
    return {"message": greeting}
# Example 2: Custom algorithm (token bucket that tolerates short bursts)
@app.get("/api/token-bucket")
@rate_limit(
    algorithm=Algorithm.TOKEN_BUCKET,
    limit=50,
    window_size=60,
    burst_size=10,  # bursts of up to 10 requests are allowed
)
async def token_bucket_endpoint(_: Request) -> dict[str, str]:
    """Serve a response throttled by the token-bucket algorithm."""
    body = {"message": "Token bucket rate limiting"}
    return body
# Example 3: Sliding-window algorithm for more precise rate limiting
@app.get("/api/sliding-window")
@rate_limit(
    algorithm=Algorithm.SLIDING_WINDOW,
    limit=30,
    window_size=60,
)
async def sliding_window_endpoint(_: Request) -> dict[str, str]:
    """Serve a response limited to 30 requests per 60-second sliding window."""
    body = {"message": "Sliding window rate limiting"}
    return body
# Example 4: Custom key extractor (rate limit by API key)
def api_key_extractor(request: Request) -> str:
    """Build the rate-limit key for a request from its ``X-API-Key`` header.

    Requests without the header, or with an empty value, share the
    ``"api_key:anonymous"`` bucket. A real key is replaced by its SHA-256 hex
    digest. This keeps the key a fixed length whatever the header size, and
    it avoids using the raw credential as a storage key (this example
    persists counters in SQLite).

    Args:
        request: Incoming request whose headers are inspected.

    Returns:
        ``"api_key:anonymous"`` or ``"api_key:<sha256-hex-of-key>"``.
    """
    api_key = request.headers.get("X-API-Key")
    if not api_key:
        # Missing and blank headers both fall into the shared anonymous bucket.
        return "api_key:anonymous"
    digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()
    return f"api_key:{digest}"
@app.get("/api/by-api-key")
@rate_limit(
    key_extractor=api_key_extractor,
    limit=1000,
    window_size=3600,  # one hour: 1000 requests per hour per key
)
async def api_key_endpoint(_: Request) -> dict[str, str]:
    """Serve a response whose quota is tracked per ``X-API-Key`` value."""
    return dict(message="Rate limited by API key")
# Example 5: Rate limiting through FastAPI dependency injection
rate_limit_dep = RateLimitDependency(limit=20, window_size=60)
@app.get("/api/dependency")
async def dependency_endpoint(
    _: Request,
    rate_info: dict[str, object] = Depends(rate_limit_dep),
) -> dict[str, object]:
    """Echo the rate-limit state injected by ``rate_limit_dep``.

    The dependency allows 20 requests per 60-second window; whatever it
    returns is included in the response under ``"rate_limit"``.
    """
    response: dict[str, object] = {"message": "Rate limit info available"}
    response["rate_limit"] = rate_info
    return response
# Example 6: Exempt certain requests
def is_admin(request: Request) -> bool:
    """Return True when the request carries the demo admin token.

    The comparison uses ``secrets.compare_digest`` on UTF-8 bytes, for two
    reasons. Its running time does not reveal how much of a guessed token
    was correct. And non-ASCII header values compare as unequal instead of
    raising ``TypeError``, which ``compare_digest`` does for non-ASCII str.

    Args:
        request: Incoming request whose ``X-Admin-Token`` header is checked.

    Returns:
        True if the header equals the admin token; False otherwise, including
        when the header is absent.
    """
    # NOTE: hard-coded for demonstration only; load real secrets from config.
    expected = b"secret-admin-token"
    supplied = request.headers.get("X-Admin-Token")
    if supplied is None:
        return False
    return secrets.compare_digest(supplied.encode("utf-8"), expected)
@app.get("/api/admin-exempt")
@rate_limit(
    exempt_when=is_admin,
    limit=10,
    window_size=60,
)
async def admin_exempt_endpoint(_: Request) -> dict[str, str]:
    """Serve a response limited to 10 requests per minute for non-admins.

    Requests for which ``is_admin`` returns True are exempt from the limit.
    """
    notice = "Admins are exempt from rate limiting"
    return {"message": notice}
# Example 7: Charging a higher cost for expensive operations
@app.post("/api/expensive")
@rate_limit(
    cost=10,  # each request is charged 10 units against the limit of 100
    limit=100,
    window_size=60,
)
async def expensive_endpoint(_: Request) -> dict[str, str]:
    """Run an expensive operation charged 10 units per request.

    Presumably this allows about 10 calls per 60-second window (100 / 10);
    confirm against the limiter's cost semantics.
    """
    outcome = "Expensive operation completed"
    return {"message": outcome}
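# Example 8: Global middleware rate limiting
# To apply one limit to every route, uncomment the call below. It also needs
# RateLimitMiddleware to be imported, which this file does not currently do
# (presumably exported by fastapi_traffic — confirm the import path).
# app.add_middleware(
#     RateLimitMiddleware,
#     limit=1000,
#     window_size=60,
#     exempt_paths={"/health", "/docs", "/openapi.json"},
# )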
@app.get("/health")
async def health_check() -> dict[str, str]:
    """Report that the service is up; no per-route rate limit is applied."""
    return dict(status="healthy")
if __name__ == "__main__":
    import uvicorn

    # Development entry point. Host "0.0.0.0" listens on every network
    # interface; use "127.0.0.1" to keep the demo reachable only locally.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)