Files
fastapi-traffic/examples/07_redis_distributed.py

198 lines
5.4 KiB
Python

"""Example demonstrating Redis backend for distributed rate limiting.
This example shows how to use Redis for rate limiting across multiple
application instances (e.g., in a Kubernetes deployment or load-balanced setup).

Requirements:
    pip install redis

Environment variables:
    REDIS_URL: Redis connection URL (default: redis://localhost:6379/0)
"""
from __future__ import annotations
import os
from contextlib import asynccontextmanager
from typing import Annotated
from fastapi import Depends, FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi_traffic import (
Algorithm,
MemoryBackend,
RateLimiter,
RateLimitExceeded,
rate_limit,
)
from fastapi_traffic.backends.redis import RedisBackend
from fastapi_traffic.core.limiter import set_limiter
async def create_redis_backend():
    """Return a Redis-backed store when reachable, else an in-memory one.

    The connection URL is read from the ``REDIS_URL`` environment variable
    (default ``redis://localhost:6379/0``) and the backend is created with
    the key prefix ``"myapp"``. The new backend is pinged once; a truthy
    ping result means it is returned to the caller.

    Returns:
        The connected Redis backend, or a fresh ``MemoryBackend`` when the
        import fails, the connection attempt raises, or the ping is falsy.
        Failures are reported with ``print`` rather than raised.
    """
    try:
        from fastapi_traffic import RedisBackend

        url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
        candidate = await RedisBackend.from_url(url, key_prefix="myapp")

        # Guard clause: a backend that does not answer the ping is discarded.
        if not await candidate.ping():
            print("Redis ping failed, falling back to memory backend")
            return MemoryBackend()

        print(f"Connected to Redis at {url}")
        return candidate
    except ImportError:
        print("Redis package not installed. Install with: pip install redis")
        print("Falling back to memory backend")
    except Exception as exc:
        # Deliberately broad: any connection problem degrades to memory
        # instead of preventing the application from starting.
        print(f"Failed to connect to Redis: {exc}")
        print("Falling back to memory backend")

    # Shared fallback for both error paths above.
    # NOTE(review): presumably MemoryBackend keeps counters per process, so
    # limits would not be shared across instances in this mode; confirm.
    return MemoryBackend()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the rate limiter on startup and close it on shutdown.

    Startup:
        * obtains a backend from ``create_redis_backend()`` and stores it on
          ``app.state.backend``;
        * wraps it in a ``RateLimiter`` stored on ``app.state.limiter`` and
          awaits its ``initialize()``;
        * registers that limiter through ``set_limiter``.

    Shutdown:
        * awaits ``app.state.limiter.close()``.

    Args:
        app: The FastAPI application whose ``state`` receives the backend
            and the limiter.
    """
    app.state.backend = await create_redis_backend()
    app.state.limiter = RateLimiter(app.state.backend)
    await app.state.limiter.initialize()
    set_limiter(app.state.limiter)
    try:
        yield
    finally:
        # Without try/finally the close() call is skipped whenever an
        # exception or cancellation is thrown into the generator at the
        # yield point, leaving the limiter open.
        await app.state.limiter.close()
# Application instance; the lifespan context manager runs the startup and
# shutdown steps around the serving phase.
app = FastAPI(lifespan=lifespan, title="Distributed Rate Limiting with Redis")
def get_backend(request: Request) -> RedisBackend | MemoryBackend:
    """Return the backend stored on the application state.

    Args:
        request: The incoming request; only ``request.app.state`` is read.

    Returns:
        Whatever object is stored as ``app.state.backend``.
    """
    app_state = request.app.state
    return app_state.backend
def get_limiter(request: Request) -> RateLimiter:
    """Return the rate limiter stored on the application state.

    Args:
        request: The incoming request; only ``request.app.state`` is read.

    Returns:
        Whatever object is stored as ``app.state.limiter``.
    """
    app_state = request.app.state
    return app_state.limiter
# Annotated aliases so route signatures can request the limiter or the
# backend by type instead of repeating ``Depends(...)`` each time.
LimiterDep = Annotated[RateLimiter, Depends(get_limiter)]
BackendDep = Annotated[RedisBackend | MemoryBackend, Depends(get_backend)]
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Translate a ``RateLimitExceeded`` error into an HTTP 429 JSON response.

    Args:
        _: The request that triggered the error (unused).
        exc: The raised exception; its ``message``, ``retry_after`` and
            ``limit_info`` attributes are read.

    Returns:
        A 429 ``JSONResponse`` whose body carries the error code, message and
        retry delay. Headers come from ``exc.limit_info.to_headers()`` when
        ``limit_info`` is truthy, otherwise no extra headers are set.
    """
    payload = {
        "error": "rate_limit_exceeded",
        "message": exc.message,
        "retry_after": exc.retry_after,
    }

    if exc.limit_info:
        extra_headers = exc.limit_info.to_headers()
    else:
        extra_headers = {}

    return JSONResponse(status_code=429, content=payload, headers=extra_headers)
# With the Redis backend, every application instance draws from the same
# counter for this route.
@app.get("/api/shared-limit")
@rate_limit(limit=100, window_size=60, key_prefix="shared")
async def shared_limit(_: Request) -> dict[str, str]:
    """Endpoint limited to 100 requests per 60-unit window under the ``shared`` prefix."""
    body = dict(
        message="Success",
        note="Rate limit counter is shared via Redis",
    )
    return body
# Per-user limits are keyed on a request header, so they apply across
# instances as well.
def user_extractor(request: Request) -> str:
    """Build the rate-limit key for the caller.

    Args:
        request: The incoming request; only its ``X-User-ID`` header is read.

    Returns:
        ``"user:<id>"``, where ``<id>`` is the header value or ``"anonymous"``
        when the header is absent.
    """
    # NOTE(review): the header is client-supplied; confirm it is set by a
    # trusted layer before relying on it to separate users.
    caller = request.headers.get("X-User-ID", "anonymous")
    return "user:{}".format(caller)
@app.get("/api/user-limit")
@rate_limit(
    limit=50,
    window_size=60,
    key_extractor=user_extractor,
    key_prefix="user_api",
)
async def user_limit(request: Request) -> dict[str, str]:
    """Endpoint limited per user: 50 requests per 60-unit window.

    The limiter key comes from ``user_extractor``; the response echoes the
    ``X-User-ID`` header value, or ``"anonymous"`` when it is absent.
    """
    caller = request.headers.get("X-User-ID", "anonymous")
    return dict(message="Success", user_id=caller)
# The token bucket algorithm pairs well with Redis when bursts must be
# absorbed.
@app.get("/api/burst-allowed")
@rate_limit(
    algorithm=Algorithm.TOKEN_BUCKET,
    limit=100,
    window_size=60,
    burst_size=20,
    key_prefix="burst",
)
async def burst_allowed(_: Request) -> dict[str, str]:
    """Endpoint limited by a token bucket (limit 100, window 60, burst size 20)."""
    reply = {"message": "Burst request successful"}
    return reply
@app.get("/health")
async def health(backend: BackendDep) -> dict[str, object]:
    """Report service health together with the backend's ping result.

    Args:
        backend: The active rate-limiting backend (injected dependency).

    Returns:
        A dict with a constant ``"status": "healthy"``, the backend's class
        name, and ``redis_connected``: the result of ``backend.ping()`` when
        the backend has such a method and the call succeeds, else ``False``.
    """
    ping_result = False

    if hasattr(backend, "ping"):
        try:
            ping_result = await backend.ping()
        except Exception:
            # Best-effort probe: a failing ping is reported as "not
            # connected" rather than failing the health endpoint.
            ping_result = False

    return {
        "status": "healthy",
        "backend": type(backend).__name__,
        "redis_connected": ping_result,
    }
@app.get("/stats")
async def stats(backend: BackendDep) -> dict[str, object]:
    """Return the backend's rate-limiting statistics when it exposes them.

    Args:
        backend: The active rate-limiting backend (injected dependency).

    Returns:
        The result of ``backend.get_stats()``; ``{"error": <text>}`` when
        that call raises; or a fixed message when the backend has no
        ``get_stats`` attribute.
    """
    # Guard clause: backends without statistics support answer immediately.
    if not hasattr(backend, "get_stats"):
        return {"message": "Stats not available for this backend"}

    try:
        return await backend.get_stats()
    except Exception as exc:
        return {"error": str(exc)}
if __name__ == "__main__":
    import uvicorn

    # To try distributed limiting, start this script as the first instance:
    #   REDIS_URL=redis://localhost:6379/0 python 07_redis_distributed.py
    # then, in a second terminal, start another instance on its own port:
    #   uvicorn 07_redis_distributed:app --port 8001
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)