"""Example demonstrating Redis backend for distributed rate limiting.
|
|
|
|
This example shows how to use Redis for rate limiting across multiple
|
|
application instances (e.g., in a Kubernetes deployment or load-balanced setup).
|
|
|
|
Requirements:
|
|
pip install redis
|
|
|
|
Environment variables:
|
|
REDIS_URL: Redis connection URL (default: redis://localhost:6379/0)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
from typing import Annotated
|
|
|
|
from fastapi import Depends, FastAPI, Request
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from fastapi_traffic import (
|
|
Algorithm,
|
|
MemoryBackend,
|
|
RateLimiter,
|
|
RateLimitExceeded,
|
|
rate_limit,
|
|
)
|
|
from fastapi_traffic.backends.redis import RedisBackend
|
|
from fastapi_traffic.core.limiter import set_limiter
|
|
|
|
|
|
async def create_redis_backend():
    """Create the rate-limiting backend, preferring Redis.

    Connects to the Redis instance named by the ``REDIS_URL`` environment
    variable (default: ``redis://localhost:6379/0``) and verifies the
    connection with a ping.  On any failure -- missing ``redis`` package,
    connection error, or failed ping -- an in-process ``MemoryBackend`` is
    returned instead so the example still runs, albeit without
    cross-instance limit sharing.

    Returns:
        ``RedisBackend`` on success, otherwise ``MemoryBackend``.
    """
    try:
        # Import from the same module path used at the top of the file
        # (the original imported `from fastapi_traffic import RedisBackend`
        # here, inconsistently).
        # NOTE(review): the module-level import of RedisBackend already ran
        # at import time; this local import only guards against a missing
        # `redis` package if that class import is lazy -- confirm against
        # fastapi_traffic.
        from fastapi_traffic.backends.redis import RedisBackend

        redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
        backend = await RedisBackend.from_url(
            redis_url,
            key_prefix="myapp",
        )

        # Verify the connection before committing to the Redis backend.
        if await backend.ping():
            print(f"Connected to Redis at {redis_url}")
            return backend

        print("Redis ping failed, falling back to memory backend")
        return MemoryBackend()

    except ImportError:
        print("Redis package not installed. Install with: pip install redis")
        print("Falling back to memory backend")
        return MemoryBackend()

    except Exception as e:
        # Broad catch is deliberate: this is a best-effort degradation path.
        print(f"Failed to connect to Redis: {e}")
        print("Falling back to memory backend")
        return MemoryBackend()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize rate limiting at startup and release it at shutdown."""
    backend = await create_redis_backend()
    limiter = RateLimiter(backend)

    app.state.backend = backend
    app.state.limiter = limiter

    await limiter.initialize()
    set_limiter(limiter)

    yield

    await app.state.limiter.close()
|
|
|
|
|
|
# Application instance; the rate-limiting backend and limiter are wired up
# by the lifespan handler at startup.
app = FastAPI(
    title="Distributed Rate Limiting with Redis",
    lifespan=lifespan,
)
|
|
|
|
|
|
def get_backend(request: Request) -> RedisBackend | MemoryBackend:
    """Resolve the rate-limiting backend stored on the application state."""
    state = request.app.state
    return state.backend
|
|
|
|
|
|
def get_limiter(request: Request) -> RateLimiter:
    """Resolve the shared RateLimiter stored on the application state."""
    state = request.app.state
    return state.limiter
|
|
|
|
|
|
# Dependency aliases so endpoints can declare the backend/limiter via Annotated.
BackendDep = Annotated[RedisBackend | MemoryBackend, Depends(get_backend)]
LimiterDep = Annotated[RateLimiter, Depends(get_limiter)]
|
|
|
|
|
|
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Translate RateLimitExceeded into a 429 JSON response with limit headers."""
    # Propagate standard rate-limit headers (e.g. Retry-After) when available.
    headers = exc.limit_info.to_headers() if exc.limit_info else {}
    body = {
        "error": "rate_limit_exceeded",
        "message": exc.message,
        "retry_after": exc.retry_after,
    }
    return JSONResponse(status_code=429, content=body, headers=headers)
|
|
|
|
|
|
# Rate limits are shared across all instances when using Redis
@app.get("/api/shared-limit")
@rate_limit(limit=100, window_size=60, key_prefix="shared")
async def shared_limit(_: Request) -> dict[str, str]:
    """This rate limit is shared across all application instances."""
    payload = {
        "message": "Success",
        "note": "Rate limit counter is shared via Redis",
    }
    return payload
|
|
|
|
|
|
# Per-user limits also work across instances
def user_extractor(request: Request) -> str:
    """Build a per-user rate-limit key from the X-User-ID header."""
    # Unauthenticated callers share the "anonymous" bucket.
    return "user:" + request.headers.get("X-User-ID", "anonymous")
|
|
|
|
|
|
@app.get("/api/user-limit")
@rate_limit(
    limit=50,
    window_size=60,
    key_extractor=user_extractor,
    key_prefix="user_api",
)
async def user_limit(request: Request) -> dict[str, str]:
    """Per-user rate limit shared across instances."""
    uid = request.headers.get("X-User-ID", "anonymous")
    return {"message": "Success", "user_id": uid}
|
|
|
|
|
|
# Token bucket works well with Redis for burst handling
@app.get("/api/burst-allowed")
@rate_limit(
    limit=100,
    window_size=60,
    algorithm=Algorithm.TOKEN_BUCKET,
    burst_size=20,
    key_prefix="burst",
)
async def burst_allowed(_: Request) -> dict[str, str]:
    """Token bucket with Redis allows controlled bursts across instances."""
    response = {"message": "Burst request successful"}
    return response
|
|
|
|
|
|
@app.get("/health")
async def health(backend: BackendDep) -> dict[str, object]:
    """Health check reporting the backend type and Redis connectivity."""
    connected = False

    # Only Redis-style backends expose ping(); a ping error counts as down.
    if hasattr(backend, "ping"):
        try:
            connected = await backend.ping()
        except Exception:
            connected = False

    return {
        "status": "healthy",
        "backend": type(backend).__name__,
        "redis_connected": connected,
    }
|
|
|
|
|
|
@app.get("/stats")
async def stats(backend: BackendDep) -> dict[str, object]:
    """Return backend rate-limiting statistics when the backend supports them."""
    getter = getattr(backend, "get_stats", None)
    if getter is None:
        return {"message": "Stats not available for this backend"}
    try:
        return await getter()
    except Exception as e:
        return {"error": str(e)}
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Run multiple instances on different ports to test distributed limiting:
    # REDIS_URL=redis://localhost:6379/0 python 07_redis_distributed.py
    # In another terminal:
    # uvicorn 07_redis_distributed:app --port 8001
    # Binds to all interfaces so other hosts/instances can reach this node.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|