333 lines
10 KiB
Python
333 lines
10 KiB
Python
"""Advanced patterns and real-world use cases for rate limiting."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import time
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any
|
|
|
|
from fastapi import Depends, FastAPI, Request
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from fastapi_traffic import (
|
|
Algorithm,
|
|
MemoryBackend,
|
|
RateLimiter,
|
|
RateLimitExceeded,
|
|
rate_limit,
|
|
)
|
|
from fastapi_traffic.core.decorator import RateLimitDependency
|
|
from fastapi_traffic.core.limiter import set_limiter
|
|
|
|
# Storage backend handed to the limiter below.
backend = MemoryBackend()

# Module-wide limiter; the lifespan handler initializes, registers and closes it.
limiter = RateLimiter(backend)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Manage the module-level limiter for the lifetime of the application.

    On startup the limiter is initialized and registered via ``set_limiter``.
    On shutdown it is closed, including when the serving phase ends with an
    exception or cancellation.

    Args:
        _: The FastAPI application instance (unused).

    Yields:
        None. Control returns to the framework while the app is serving.
    """
    await limiter.initialize()
    set_limiter(limiter)
    try:
        yield
    finally:
        # An error or cancellation is re-raised at the ``yield``; without the
        # ``finally`` the limiter would never be closed in that case.
        await limiter.close()
|
|
|
|
|
|
# Application instance; limiter startup/shutdown is delegated to ``lifespan``.
app = FastAPI(
    title="Advanced Patterns",
    lifespan=lifespan,
)
|
|
|
|
|
|
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(_: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Translate a ``RateLimitExceeded`` error into an HTTP 429 JSON response.

    The body carries a fixed error code plus the exception's ``retry_after``
    value. Response headers come from ``exc.limit_info.to_headers()`` when
    limit info is present; otherwise no extra headers are added.
    """
    body = {"error": "rate_limit_exceeded", "retry_after": exc.retry_after}
    info = exc.limit_info
    extra_headers = info.to_headers() if info else {}
    return JSONResponse(status_code=429, content=body, headers=extra_headers)
|
|
|
|
|
|
# =============================================================================
|
|
# Pattern 1: Cost-based rate limiting
|
|
# Different operations consume different amounts of quota
|
|
# =============================================================================
|
|
|
|
|
|
@app.get("/api/list")
@rate_limit(limit=100, window_size=60, cost=1)
async def list_items(_: Request) -> dict[str, Any]:
    """Return a small item listing; the decorator charges this call a cost of 1."""
    payload: dict[str, Any] = {"items": ["a", "b", "c"]}
    payload["cost"] = 1
    return payload
|
|
|
|
|
|
@app.get("/api/details/{item_id}")
@rate_limit(limit=100, window_size=60, cost=5)
async def get_details(_: Request, item_id: str) -> dict[str, Any]:
    """Return placeholder details for ``item_id``; the decorator charges a cost of 5."""
    payload: dict[str, Any] = {"item_id": item_id}
    payload["details"] = "..."
    payload["cost"] = 5
    return payload
|
|
|
|
|
|
@app.post("/api/generate")
@rate_limit(limit=100, window_size=60, cost=20)
async def generate_content(_: Request) -> dict[str, Any]:
    """Return placeholder generated content; the decorator charges a cost of 20."""
    payload: dict[str, Any] = {"generated": "AI-generated content..."}
    payload["cost"] = 20
    return payload
|
|
|
|
|
|
@app.post("/api/bulk-export")
@rate_limit(limit=100, window_size=60, cost=50)
async def bulk_export(_: Request) -> dict[str, Any]:
    """Return a placeholder export URL; the decorator charges a cost of 50."""
    payload: dict[str, Any] = {"export_url": "https://..."}
    payload["cost"] = 50
    return payload
|
|
|
|
|
|
# =============================================================================
|
|
# Pattern 2: Sliding scale exemptions
# Exempt high-priority (premium) requests; all other requests are rate limited normally
|
|
# =============================================================================
|
|
|
|
|
|
def get_request_priority(request: Request) -> int:
    """Map a request to a numeric priority; larger values are more important.

    Returns:
        100 when the ``X-Premium-User`` header equals ``"true"``,
        50 when a non-empty ``Authorization`` header is present,
        10 otherwise (anonymous traffic).
    """
    headers = request.headers
    is_premium = headers.get("X-Premium-User") == "true"
    if is_premium:
        return 100
    # Any truthy Authorization value counts as "authenticated" here.
    return 50 if headers.get("Authorization") else 10
|
|
|
|
|
|
def should_exempt_high_priority(request: Request) -> bool:
    """Return True when the request's priority is 100 or above.

    Passed as ``exempt_when`` so that such requests skip rate limiting.
    """
    priority = get_request_priority(request)
    return priority >= 100
|
|
|
|
|
|
@app.get("/api/priority-based")
@rate_limit(limit=10, window_size=60, exempt_when=should_exempt_high_priority)
async def priority_endpoint(request: Request) -> dict[str, Any]:
    """Report the caller's priority and whether it is high enough to be exempt.

    Requests with priority 100 or above bypass the 10-per-60-second limit via
    ``exempt_when``; the response repeats that threshold check for visibility.
    """
    priority = get_request_priority(request)
    response: dict[str, Any] = {"message": "Success", "priority": priority}
    response["exempt"] = priority >= 100
    return response
|
|
|
|
|
|
# =============================================================================
|
|
# Pattern 3: Rate limit by resource, not just user
|
|
# Prevent abuse of specific resources
|
|
# =============================================================================
|
|
|
|
|
|
def resource_key_extractor(request: Request) -> str:
    """Build a rate-limit key scoped to one resource and one user.

    The resource comes from the ``resource_id`` path parameter (``"unknown"``
    when absent) and the user from the ``X-User-ID`` header (``"anonymous"``
    when absent).
    """
    resource = request.path_params.get("resource_id", "unknown")
    user = request.headers.get("X-User-ID", "anonymous")
    return "resource:{}:user:{}".format(resource, user)
|
|
|
|
|
|
@app.get("/api/resources/{resource_id}")
@rate_limit(limit=10, window_size=60, key_extractor=resource_key_extractor)
async def get_resource(_: Request, resource_id: str) -> dict[str, str]:
    """Return placeholder data for a resource.

    The limit is keyed per resource and per user, so each user may fetch each
    resource 10 times per 60-second window.
    """
    payload: dict[str, str] = {"resource_id": resource_id}
    payload["data"] = "..."
    return payload
|
|
|
|
|
|
# =============================================================================
|
|
# Pattern 4: Login/authentication rate limiting
|
|
# Prevent brute force attacks
|
|
# =============================================================================
|
|
|
|
|
|
def login_key_extractor(request: Request) -> str:
    """Build the login rate-limit key from the client IP and the username.

    The IP falls back to ``"unknown"`` when the request has no client info.
    The username is read from the ``X-Username`` header (``"unknown"`` when
    absent).
    """
    client = request.client
    if client:
        ip = client.host
    else:
        ip = "unknown"
    # Demo shortcut: a real application would parse the username from the body.
    username = request.headers.get("X-Username", "unknown")
    return "login:{}:{}".format(ip, username)
|
|
|
|
|
|
@app.post("/auth/login")
@rate_limit(
    # 5 attempts per 300-second (5 minute) window.
    limit=5,
    window_size=300,
    # Sliding window chosen for precise tracking on a security-sensitive route.
    algorithm=Algorithm.SLIDING_WINDOW,
    key_extractor=login_key_extractor,
    error_message="Too many login attempts. Please try again in 5 minutes.",
)
async def login(_: Request) -> dict[str, str]:
    """Placeholder login endpoint guarded against brute-force attempts.

    Attempts are counted per key produced by ``login_key_extractor``.
    """
    result: dict[str, str] = {"message": "Login successful"}
    result["token"] = "..."
    return result
|
|
|
|
|
|
# Password reset - even stricter limits
|
|
def password_reset_key(request: Request) -> str:
    """Build the password-reset rate-limit key from the client IP.

    Falls back to ``"unknown"`` when the request carries no client info.
    """
    client = request.client
    if not client:
        return "password_reset:unknown"
    return "password_reset:{}".format(client.host)
|
|
|
|
|
|
@app.post("/auth/password-reset")
@rate_limit(
    # 3 attempts per 3600-second (1 hour) window.
    limit=3,
    window_size=3600,
    key_extractor=password_reset_key,
    error_message="Too many password reset requests. Please try again later.",
)
async def password_reset(_: Request) -> dict[str, str]:
    """Placeholder password-reset endpoint with a strict limit.

    Requests are counted per key produced by ``password_reset_key``.
    """
    result: dict[str, str] = {"message": "Password reset email sent"}
    return result
|
|
|
|
|
|
# =============================================================================
|
|
# Pattern 5: Webhook/callback rate limiting
|
|
# Limit outgoing requests to prevent overwhelming external services
|
|
# =============================================================================
|
|
|
|
# Dependency for webhook sends: limit of 100 per 60-second window, with keys
# prefixed by "webhook".
webhook_rate_limit = RateLimitDependency(limit=100, window_size=60, key_prefix="webhook")
|
|
|
|
|
|
async def check_webhook_limit(
    _: Request,
    webhook_url: str,
) -> None:
    """Placeholder for a per-destination webhook limit check.

    Derives a ``webhook:<domain>`` key from the destination URL but does not
    consult any limiter, so it always returns ``None``. A production version
    would pass the key to the limiter directly.
    """
    from urllib.parse import urlparse

    destination = urlparse(webhook_url).netloc
    # Key a real implementation would hand to the limiter; intentionally unused.
    _key = f"webhook:{destination}"
    del _key
|
|
|
|
|
|
@app.post("/api/send-webhook")
async def send_webhook(
    _: Request,
    webhook_url: str = "https://example.com/webhook",
    rate_info: Any = Depends(webhook_rate_limit),
) -> dict[str, Any]:
    """Pretend to send a webhook, limited by the ``webhook_rate_limit`` dependency.

    The response echoes the destination URL and the ``remaining`` value of the
    rate-limit info injected by the dependency.
    """
    # NOTE: the per-destination check (check_webhook_limit) is not wired in here.
    remaining = rate_info.remaining
    response: dict[str, Any] = {"message": "Webhook sent", "destination": webhook_url}
    response["remaining_quota"] = remaining
    return response
|
|
|
|
|
|
# =============================================================================
|
|
# Pattern 6: Request fingerprinting
|
|
# Detect and limit similar requests (e.g., spam prevention)
|
|
# =============================================================================
|
|
|
|
|
|
def request_fingerprint(request: Request) -> str:
    """Create a rate-limit key from coarse request characteristics.

    The fingerprint combines the client IP (``"unknown"`` when the request has
    no client info) with the ``User-Agent`` and ``Accept-Language`` headers
    (empty string when absent), hashes the result and keeps the first 16 hex
    characters.

    Args:
        request: The incoming request.

    Returns:
        A key of the form ``"fingerprint:<16 hex chars>"``.
    """
    ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("User-Agent", "")
    accept_language = request.headers.get("Accept-Language", "")

    fingerprint_data = f"{ip}:{user_agent}:{accept_language}"
    # MD5 only buckets requests here; it is not a security primitive. Flagging
    # it with usedforsecurity=False keeps the digest unchanged while avoiding
    # a ValueError on builds that restrict MD5 (e.g. FIPS mode).
    digest = hashlib.md5(fingerprint_data.encode(), usedforsecurity=False).hexdigest()

    return f"fingerprint:{digest[:16]}"
|
|
|
|
|
|
@app.post("/api/submit-form")
@rate_limit(
    limit=5,
    window_size=60,
    key_extractor=request_fingerprint,
    error_message="Too many submissions from this device.",
)
async def submit_form(_: Request) -> dict[str, str]:
    """Placeholder form submission, limited to 5 per 60-second window.

    Requests are counted per key produced by ``request_fingerprint``.
    """
    result: dict[str, str] = {"message": "Form submitted successfully"}
    return result
|
|
|
|
|
|
# =============================================================================
|
|
# Pattern 7: Time-of-day based limits
|
|
# Different limits during peak vs off-peak hours
|
|
# =============================================================================
|
|
|
|
|
|
def is_peak_hours() -> bool:
    """Report whether the current UTC hour falls in the 09:00-18:00 peak window.

    The window is half-open: hour 9 counts as peak, hour 18 does not.
    """
    hour_utc = time.gmtime().tm_hour
    return hour_utc in range(9, 18)
|
|
|
|
|
|
def peak_aware_exempt(_: Request) -> bool:
    """Return True outside peak hours so that off-peak requests are exempt."""
    if is_peak_hours():
        return False
    return True
|
|
|
|
|
|
@app.get("/api/peak-aware")
@rate_limit(
    limit=10,  # Strict limit during peak hours
    window_size=60,
    exempt_when=peak_aware_exempt,  # No limit during off-peak
)
async def peak_aware_endpoint(_: Request) -> dict[str, Any]:
    """Report whether the request arrived during peak hours.

    During peak hours the decorator enforces 10 requests per 60-second window;
    outside peak hours ``peak_aware_exempt`` exempts the request.

    Returns:
        A dict with ``message``, ``is_peak_hours`` and ``rate_limited``. The
        last two always carry the same value.
    """
    # Sample the clock once so both fields agree even when the response is
    # built across an hour boundary.
    peak = is_peak_hours()
    return {
        "message": "Success",
        "is_peak_hours": peak,
        "rate_limited": peak,
    }
|
|
|
|
|
|
# =============================================================================
|
|
# Pattern 8: Cascading limits (multiple tiers)
|
|
# =============================================================================
|
|
|
|
# Three independent tiers, each with its own key prefix.
per_second = RateLimitDependency(
    limit=5,
    window_size=1,
    key_prefix="sec",
)
per_minute = RateLimitDependency(
    limit=100,
    window_size=60,
    key_prefix="min",
)
per_hour = RateLimitDependency(
    limit=1000,
    window_size=3600,
    key_prefix="hour",
)
|
|
|
|
|
|
async def cascading_limits(
    _: Request,
    sec_info: Any = Depends(per_second),
    min_info: Any = Depends(per_minute),
    hour_info: Any = Depends(per_hour),
) -> dict[str, Any]:
    """Depend on all three limit tiers and summarize what is left in each.

    Returns:
        A dict keyed ``per_second``, ``per_minute`` and ``per_hour``, each
        mapping to ``{"remaining": <value>}`` taken from that tier's info.
    """
    tiers = (
        ("per_second", sec_info),
        ("per_minute", min_info),
        ("per_hour", hour_info),
    )
    return {name: {"remaining": info.remaining} for name, info in tiers}
|
|
|
|
|
|
@app.get("/api/cascading")
async def cascading_endpoint(
    _: Request,
    limits: dict[str, Any] = Depends(cascading_limits),
) -> dict[str, Any]:
    """Endpoint guarded by the per-second, per-minute and per-hour tiers.

    The response echoes the per-tier summary produced by ``cascading_limits``.
    """
    response: dict[str, Any] = {"message": "Success"}
    response["limits"] = limits
    return response
|
|
|
|
|
|
if __name__ == "__main__":
    # Local import: uvicorn is only needed when this file is run as a script.
    import uvicorn

    # NOTE(review): 0.0.0.0 listens on every interface - confirm that is
    # intended outside local demos.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8001,
    )
|