Files
fastapi-traffic/examples/08_tiered_api.py

270 lines
7.4 KiB
Python

"""Example of a production-ready tiered API with different rate limits per plan."""
from __future__ import annotations
from contextlib import asynccontextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Any
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi_traffic import (
Algorithm,
MemoryBackend,
RateLimiter,
RateLimitExceeded,
)
from fastapi_traffic.core.decorator import RateLimitDependency
from fastapi_traffic.core.limiter import set_limiter
# Single in-memory backend/limiter pair shared by the whole example app.
backend = MemoryBackend()
limiter = RateLimiter(backend)


@asynccontextmanager
async def lifespan(_: FastAPI):
    """Manage the shared rate limiter for the lifetime of the application.

    On startup the limiter is initialised and registered with
    ``set_limiter``; on shutdown it is closed. ``close()`` sits in a
    ``finally`` clause so it still runs when the lifespan context exits via
    an exception (or when ``set_limiter`` itself fails) after
    ``initialize()`` has succeeded.
    """
    await limiter.initialize()
    try:
        set_limiter(limiter)
        yield
    finally:
        # Release backend resources even on abnormal shutdown.
        await limiter.close()
# Application instance; ``lifespan`` drives limiter startup and shutdown.
app = FastAPI(
    lifespan=lifespan,
    title="Tiered API Example",
    description="API with different rate limits based on subscription tier",
)
class Tier(str, Enum):
    """Subscription plans offered by the API.

    Mixing in ``str`` lets members be used directly as their string values.
    Declaration order is iteration order, which determines the order of
    ``upgrade_options`` returned by ``get_tier_info``.
    """

    FREE = "free"
    STARTER = "starter"
    PRO = "pro"
    ENTERPRISE = "enterprise"
@dataclass
class TierConfig:
    """Rate limits and feature flags for a single subscription tier."""

    # Used as the token-bucket limit (per 60-second window) for the tier.
    requests_per_minute: int
    # The hourly and daily figures are only reported by the informational
    # endpoints in this file; no rate-limit dependency here enforces them.
    requests_per_hour: int
    requests_per_day: int
    # Passed as ``burst_size`` to the tier's token-bucket limiter.
    burst_size: int
    # Feature names checked by endpoints (e.g. "analytics" gates /api/v1/analytics).
    features: list[str]
# Tier configurations, listed from most to least restricted.
TIER_CONFIGS: dict[Tier, TierConfig] = {
    Tier.FREE: TierConfig(
        burst_size=5,
        requests_per_minute=10,
        requests_per_hour=100,
        requests_per_day=500,
        features=["basic_api"],
    ),
    Tier.STARTER: TierConfig(
        burst_size=20,
        requests_per_minute=60,
        requests_per_hour=1_000,
        requests_per_day=10_000,
        features=["basic_api", "webhooks"],
    ),
    Tier.PRO: TierConfig(
        burst_size=50,
        requests_per_minute=300,
        requests_per_hour=10_000,
        requests_per_day=100_000,
        features=["basic_api", "webhooks", "analytics", "priority_support"],
    ),
    Tier.ENTERPRISE: TierConfig(
        burst_size=200,
        requests_per_minute=1_000,
        requests_per_hour=50_000,
        requests_per_day=500_000,
        features=[
            "basic_api",
            "webhooks",
            "analytics",
            "priority_support",
            "sla",
            "custom_integrations",
        ],
    ),
}
# Simulated API key database: API key -> {"tier": Tier, "user_id": str}.
API_KEYS: dict[str, dict[str, Any]] = {
    key: {"tier": plan, "user_id": owner}
    for key, plan, owner in (
        ("free-key-123", Tier.FREE, "user_1"),
        ("starter-key-456", Tier.STARTER, "user_2"),
        ("pro-key-789", Tier.PRO, "user_3"),
        ("enterprise-key-000", Tier.ENTERPRISE, "user_4"),
    )
}
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Render a 429 response that includes the caller's tier and an upgrade hint.

    Unknown or missing API keys are reported as the free tier. Enterprise
    callers get ``upgrade_url: None`` because there is no higher plan.
    """
    record = API_KEYS.get(request.headers.get("X-API-Key", ""), {})
    current_tier = record.get("tier", Tier.FREE)

    upgrade_url = None if current_tier == Tier.ENTERPRISE else "https://example.com/pricing"
    body = {
        "error": "rate_limit_exceeded",
        "message": exc.message,
        "retry_after": exc.retry_after,
        "tier": current_tier.value,
        "upgrade_url": upgrade_url,
    }
    extra_headers = {} if not exc.limit_info else exc.limit_info.to_headers()
    return JSONResponse(status_code=429, content=body, headers=extra_headers)
def get_api_key_info(request: Request) -> dict[str, Any]:
    """Authenticate the request via its ``X-API-Key`` header.

    Returns:
        A new dict with ``api_key`` plus every field stored for that key.

    Raises:
        HTTPException: 401 when the header is absent/empty or the key is unknown.
    """
    supplied_key = request.headers.get("X-API-Key")
    if not supplied_key:
        raise HTTPException(status_code=401, detail="API key required")

    stored = API_KEYS.get(supplied_key)
    if not stored:
        raise HTTPException(status_code=401, detail="Invalid API key")

    # Stored fields are applied last so they take precedence, as before.
    merged: dict[str, Any] = {"api_key": supplied_key}
    merged.update(stored)
    return merged
def get_tier_config(key_info: dict[str, Any] = Depends(get_api_key_info)) -> TierConfig:
    """Look up the ``TierConfig`` for the authenticated caller (free tier if unset)."""
    return TIER_CONFIGS[key_info.get("tier", Tier.FREE)]
def api_key_extractor(request: Request) -> str:
    """Build the rate-limit bucket key from the caller's API key.

    Requests without an ``X-API-Key`` header all share the
    ``api:anonymous`` bucket.
    """
    api_key = request.headers.get("X-API-Key", "anonymous")
    return f"api:{api_key}"


# One rate-limit dependency per tier, keyed by API key so that every
# customer gets their own quota. Previously ``api_key_extractor`` was defined
# but never wired in, so buckets fell back to the library's default key
# (NOTE(review): presumably client address — confirm) and unrelated
# customers of the same tier could share one quota.
#
# Only the per-minute limit (with ``burst_size`` as the token-bucket burst)
# is enforced here; ``requests_per_hour`` / ``requests_per_day`` are
# advertised but not checked.
tier_rate_limits: dict[Tier, RateLimitDependency] = {
    tier: RateLimitDependency(
        limit=config.requests_per_minute,
        window_size=60,
        algorithm=Algorithm.TOKEN_BUCKET,
        burst_size=config.burst_size,
        key_prefix=f"tier_{tier.value}",
        key_extractor=api_key_extractor,
    )
    for tier, config in TIER_CONFIGS.items()
}
async def apply_tier_rate_limit(
    request: Request,
    key_info: dict[str, Any] = Depends(get_api_key_info),
) -> dict[str, Any]:
    """Enforce the caller's tier rate limit and bundle the context for endpoints.

    Returns a dict with the caller's ``tier``, its ``config``, the
    ``rate_info`` produced by the tier's rate-limit dependency, and the
    authenticated ``key_info``.
    """
    caller_tier = key_info.get("tier", Tier.FREE)
    limit_check = tier_rate_limits[caller_tier]
    outcome = await limit_check(request)
    return dict(
        tier=caller_tier,
        config=TIER_CONFIGS[caller_tier],
        rate_info=outcome,
        key_info=key_info,
    )
@app.get("/api/v1/data")
async def get_data(
    _: Request,
    limit_info: dict[str, Any] = Depends(apply_tier_rate_limit),
) -> dict[str, Any]:
    """Return sample data along with the caller's tier and current rate-limit state."""
    info = limit_info["rate_info"]
    usage = {"limit": info.limit, "remaining": info.remaining, "reset_at": info.reset_at}
    payload = {"items": ["item1", "item2", "item3"]}
    return {"data": payload, "tier": limit_info["tier"].value, "rate_limit": usage}
@app.get("/api/v1/analytics")
async def get_analytics(
    _: Request,
    limit_info: dict[str, Any] = Depends(apply_tier_rate_limit),
) -> dict[str, Any]:
    """Return sample analytics; only tiers with the "analytics" feature may call it.

    The rate limit has already been applied by the dependency, so a rejected
    (403) call still counts against the caller's quota.
    """
    current_tier = limit_info["tier"]
    if "analytics" in limit_info["config"].features:
        stats = {
            "total_requests": 12345,
            "unique_users": 567,
            "avg_response_time_ms": 45,
        }
        return {"analytics": stats, "tier": current_tier.value}
    raise HTTPException(
        status_code=403,
        detail=f"Analytics requires Pro tier or higher. Current tier: {current_tier.value}",
    )
@app.get("/api/v1/tier-info")
async def get_tier_info(
    key_info: dict[str, Any] = Depends(get_api_key_info),
) -> dict[str, Any]:
    """Describe the caller's tier: limits, features, and tiers with a higher per-minute limit."""
    current = key_info.get("tier", Tier.FREE)
    cfg = TIER_CONFIGS[current]
    current_rpm = cfg.requests_per_minute

    upgrades = []
    for candidate in Tier:
        if TIER_CONFIGS[candidate].requests_per_minute > current_rpm:
            upgrades.append(candidate.value)

    limits = {
        "requests_per_minute": current_rpm,
        "requests_per_hour": cfg.requests_per_hour,
        "requests_per_day": cfg.requests_per_day,
        "burst_size": cfg.burst_size,
    }
    return {
        "tier": current.value,
        "limits": limits,
        "features": cfg.features,
        "upgrade_options": upgrades,
    }
@app.get("/pricing")
async def pricing() -> dict[str, Any]:
    """Public (unauthenticated) summary of each tier's limits and features."""
    catalogue: dict[str, dict[str, Any]] = {}
    for plan, plan_config in TIER_CONFIGS.items():
        catalogue[plan.value] = {
            "requests_per_minute": plan_config.requests_per_minute,
            "requests_per_day": plan_config.requests_per_day,
            "features": plan_config.features,
        }
    return {"tiers": catalogue}
if __name__ == "__main__":
    # Try the example with the sample keys, e.g.:
    #   curl -H "X-API-Key: free-key-123" http://localhost:8000/api/v1/data
    #   curl -H "X-API-Key: pro-key-789" http://localhost:8000/api/v1/analytics
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)