270 lines
7.4 KiB
Python
270 lines
7.4 KiB
Python
"""Example of a production-ready tiered API with different rate limits per plan."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import Any
|
|
|
|
from fastapi import Depends, FastAPI, HTTPException, Request
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from fastapi_traffic import (
|
|
Algorithm,
|
|
MemoryBackend,
|
|
RateLimiter,
|
|
RateLimitExceeded,
|
|
)
|
|
from fastapi_traffic.core.decorator import RateLimitDependency
|
|
from fastapi_traffic.core.limiter import set_limiter
|
|
|
|
# Counter storage for every rate limit in this module. It is an in-memory backend,
# so presumably state is per-process and lost on restart (NOTE(review): confirm
# for MemoryBackend before running multiple workers).
backend: MemoryBackend = MemoryBackend()

# Shared limiter wrapping that storage. It is initialised, registered with
# ``set_limiter`` and closed by ``lifespan`` below.
limiter: RateLimiter = RateLimiter(backend)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Own the shared rate limiter for the application's lifetime.

    Initialises ``limiter`` and registers it via ``set_limiter`` before the app
    starts serving. ``close()`` sits in a ``finally`` clause, so the limiter is
    released in every case after a successful initialisation:

    - normal shutdown;
    - an exception thrown into this context manager at ``yield``;
    - ``set_limiter`` itself raising.
    """
    await limiter.initialize()
    try:
        set_limiter(limiter)
        yield
    finally:
        # Without try/finally an exception delivered at ``yield`` would skip this
        # and leak whatever resources the limiter/backend holds.
        await limiter.close()
|
|
|
|
|
|
# The ASGI application. ``lifespan`` sets up the shared limiter before serving
# and closes it on shutdown.
app = FastAPI(
    lifespan=lifespan,
    title="Tiered API Example",
    description="API with different rate limits based on subscription tier",
)
|
|
|
|
|
|
class Tier(str, Enum):
    """Subscription plans, declared from the lowest plan to the highest.

    The ``str`` mix-in makes each member compare equal to its raw value
    (``Tier.PRO == "pro"``). ``.value`` is what the API puts in JSON responses.
    Declaration order is also iteration order (``for t in Tier``).
    """

    FREE = "free"  # entry-level plan
    STARTER = "starter"
    PRO = "pro"
    ENTERPRISE = "enterprise"  # highest plan
|
|
|
|
|
|
@dataclass
class TierConfig:
    """Request budgets and feature flags granted by one subscription tier."""

    # Maximum requests allowed in a one-minute window.
    requests_per_minute: int
    # Maximum requests allowed in a one-hour window.
    requests_per_hour: int
    # Maximum requests allowed in a one-day window.
    requests_per_day: int
    # Short-term burst allowance for the token-bucket per-minute limit.
    burst_size: int
    # Names of the features the tier unlocks, e.g. "analytics".
    features: list[str]
|
|
|
|
|
|
# Tier configurations: the budget and features of every subscription plan.
# Each row is: tier, requests/minute, requests/hour, requests/day, burst size, features.
# ``list(...)`` gives every config its own, independent features list.
TIER_CONFIGS: dict[Tier, TierConfig] = {
    plan: TierConfig(
        requests_per_minute=per_minute,
        requests_per_hour=per_hour,
        requests_per_day=per_day,
        burst_size=burst,
        features=list(features),
    )
    for plan, per_minute, per_hour, per_day, burst, features in (
        (Tier.FREE, 10, 100, 500, 5, ("basic_api",)),
        (Tier.STARTER, 60, 1_000, 10_000, 20, ("basic_api", "webhooks")),
        (
            Tier.PRO,
            300,
            10_000,
            100_000,
            50,
            ("basic_api", "webhooks", "analytics", "priority_support"),
        ),
        (
            Tier.ENTERPRISE,
            1_000,
            50_000,
            500_000,
            200,
            (
                "basic_api",
                "webhooks",
                "analytics",
                "priority_support",
                "sla",
                "custom_integrations",
            ),
        ),
    )
}
|
|
|
|
|
|
# Simulated API key database: API key -> {"tier": Tier, "user_id": str}.
# Demo data only; hard-coded keys in source are not suitable for production.
API_KEYS: dict[str, dict[str, Any]] = {
    key: {"tier": plan, "user_id": owner}
    for key, plan, owner in (
        ("free-key-123", Tier.FREE, "user_1"),
        ("starter-key-456", Tier.STARTER, "user_2"),
        ("pro-key-789", Tier.PRO, "user_3"),
        ("enterprise-key-000", Tier.ENTERPRISE, "user_4"),
    )
}
|
|
|
|
|
|
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Turn ``RateLimitExceeded`` into a 429 JSON response that names the caller's tier.

    A missing or unknown ``X-API-Key`` is reported as the free tier. Enterprise
    callers get ``upgrade_url: None`` because there is no higher plan. Rate-limit
    headers are copied from ``exc.limit_info`` when the exception carries it.
    """
    caller = API_KEYS.get(request.headers.get("X-API-Key", ""), {})
    current_tier = caller.get("tier", Tier.FREE)

    if current_tier == Tier.ENTERPRISE:
        upgrade_url = None
    else:
        upgrade_url = "https://example.com/pricing"

    body = {
        "error": "rate_limit_exceeded",
        "message": exc.message,
        "retry_after": exc.retry_after,
        "tier": current_tier.value,
        "upgrade_url": upgrade_url,
    }
    extra_headers = exc.limit_info.to_headers() if exc.limit_info else {}
    return JSONResponse(status_code=429, content=body, headers=extra_headers)
|
|
|
|
|
|
def get_api_key_info(request: Request) -> dict[str, Any]:
    """Validate API key and return info.

    Resolves the ``X-API-Key`` header against ``API_KEYS``.

    Returns:
        A new dict holding ``"api_key"`` plus every field stored for that key.

    Raises:
        HTTPException: 401 "API key required" when the header is missing or
            empty; 401 "Invalid API key" when the key is unknown.
    """
    supplied_key = request.headers.get("X-API-Key")
    if supplied_key:
        record = API_KEYS.get(supplied_key)
        if record:
            return {"api_key": supplied_key, **record}
        raise HTTPException(status_code=401, detail="Invalid API key")
    raise HTTPException(status_code=401, detail="API key required")
|
|
|
|
|
|
def get_tier_config(key_info: dict[str, Any] = Depends(get_api_key_info)) -> TierConfig:
    """Get rate limit config for user's tier.

    Falls back to the free tier when the key record carries no ``"tier"``.
    """
    return TIER_CONFIGS[key_info.get("tier", Tier.FREE)]
|
|
|
|
|
|
def api_key_extractor(request: Request) -> str:
    """Extract API key for rate limiting.

    Every distinct ``X-API-Key`` value gets its own counter key (``"api:<key>"``).
    Requests without the header all share ``"api:anonymous"``.
    """
    api_key = request.headers.get("X-API-Key", "anonymous")
    return f"api:{api_key}"


# Create rate limit dependencies for each tier.
#
# ``key_extractor=api_key_extractor`` keys every counter on the caller's API key,
# so each customer gets an individual budget. Without it, the library's default
# key extractor decides who shares a bucket (NOTE(review): presumably client
# address; confirm).

# Per-minute limit: a token bucket, so callers may burst up to ``burst_size``.
tier_rate_limits: dict[Tier, RateLimitDependency] = {
    tier: RateLimitDependency(
        limit=config.requests_per_minute,
        window_size=60,
        algorithm=Algorithm.TOKEN_BUCKET,
        burst_size=config.burst_size,
        key_prefix=f"tier_{tier.value}",
        key_extractor=api_key_extractor,
    )
    for tier, config in TIER_CONFIGS.items()
}

# Hourly and daily quotas from TierConfig. /api/v1/tier-info advertises these
# limits, so they must be enforced too. They use the library's default
# algorithm. The distinct key prefixes keep the three windows' counters apart
# (assumes the prefix is part of the storage key).
tier_quota_limits: dict[Tier, tuple[RateLimitDependency, ...]] = {
    tier: (
        RateLimitDependency(
            limit=config.requests_per_hour,
            window_size=3600,
            key_prefix=f"tier_{tier.value}_hour",
            key_extractor=api_key_extractor,
        ),
        RateLimitDependency(
            limit=config.requests_per_day,
            window_size=86400,
            key_prefix=f"tier_{tier.value}_day",
            key_extractor=api_key_extractor,
        ),
    )
    for tier, config in TIER_CONFIGS.items()
}


async def apply_tier_rate_limit(
    request: Request,
    key_info: dict[str, Any] = Depends(get_api_key_info),
) -> dict[str, Any]:
    """Apply rate limit based on user's tier.

    Enforces the caller's per-minute, hourly and daily limits in that order. The
    limit dependencies are presumably expected to raise ``RateLimitExceeded`` when
    a window is exhausted; ``rate_limit_handler`` turns that into a 429.

    Returns:
        ``tier``, its ``config``, the per-minute ``rate_info`` and the caller's
        ``key_info``.
    """
    tier = key_info.get("tier", Tier.FREE)

    # Check the per-minute bucket first: it trips most often, and if it raises,
    # the hourly/daily checks below never run.
    rate_info = await tier_rate_limits[tier](request)
    for quota_limit in tier_quota_limits[tier]:
        await quota_limit(request)

    return {
        "tier": tier,
        "config": TIER_CONFIGS[tier],
        "rate_info": rate_info,
        "key_info": key_info,
    }
|
|
|
|
|
|
@app.get("/api/v1/data")
async def get_data(
    _: Request,
    limit_info: dict[str, Any] = Depends(apply_tier_rate_limit),
) -> dict[str, Any]:
    """Get data with tier-based rate limiting.

    Returns sample items plus the caller's tier and the ``limit``/``remaining``/
    ``reset_at`` values from the ``rate_info`` the dependency produced.
    """
    window = limit_info["rate_info"]
    quota = {
        "limit": window.limit,
        "remaining": window.remaining,
        "reset_at": window.reset_at,
    }
    return {
        "data": {"items": ["item1", "item2", "item3"]},
        "tier": limit_info["tier"].value,
        "rate_limit": quota,
    }
|
|
|
|
|
|
@app.get("/api/v1/analytics")
async def get_analytics(
    _: Request,
    limit_info: dict[str, Any] = Depends(apply_tier_rate_limit),
) -> dict[str, Any]:
    """Analytics endpoint - requires Pro tier or higher.

    Access is granted when the tier's ``features`` include ``"analytics"``;
    otherwise the caller gets a 403. The ``apply_tier_rate_limit`` dependency has
    already run by the time this check happens.
    """
    current_tier = limit_info["tier"]
    tier_config = limit_info["config"]

    if "analytics" in tier_config.features:
        return {
            "analytics": {
                "total_requests": 12345,
                "unique_users": 567,
                "avg_response_time_ms": 45,
            },
            "tier": current_tier.value,
        }

    raise HTTPException(
        status_code=403,
        detail=f"Analytics requires Pro tier or higher. Current tier: {current_tier.value}",
    )
|
|
|
|
|
|
@app.get("/api/v1/tier-info")
async def get_tier_info(
    key_info: dict[str, Any] = Depends(get_api_key_info),
) -> dict[str, Any]:
    """Get information about current tier and limits.

    A tier is listed under ``upgrade_options`` when its per-minute limit is
    strictly greater than the caller's. Upgrades appear in ``Tier`` declaration
    order.
    """
    current_tier = key_info.get("tier", Tier.FREE)
    current = TIER_CONFIGS[current_tier]

    limits = {
        "requests_per_minute": current.requests_per_minute,
        "requests_per_hour": current.requests_per_hour,
        "requests_per_day": current.requests_per_day,
        "burst_size": current.burst_size,
    }

    upgrades = []
    for candidate in Tier:
        if TIER_CONFIGS[candidate].requests_per_minute > current.requests_per_minute:
            upgrades.append(candidate.value)

    return {
        "tier": current_tier.value,
        "limits": limits,
        "features": current.features,
        "upgrade_options": upgrades,
    }
|
|
|
|
|
|
@app.get("/pricing")
async def pricing() -> dict[str, Any]:
    """Public pricing information.

    No API key is required. Lists each tier's per-minute and per-day request
    limits and its features, keyed by the tier's string value.
    """
    catalogue: dict[str, dict[str, Any]] = {}
    for plan, plan_config in TIER_CONFIGS.items():
        catalogue[plan.value] = {
            "requests_per_minute": plan_config.requests_per_minute,
            "requests_per_day": plan_config.requests_per_day,
            "features": plan_config.features,
        }
    return {"tiers": catalogue}
|
|
|
|
|
|
if __name__ == "__main__":
    # Local demo server, listening on all interfaces at port 8000.
    # Try it with different API keys, for example:
    #   curl -H "X-API-Key: free-key-123" http://localhost:8000/api/v1/data
    #   curl -H "X-API-Key: pro-key-789" http://localhost:8000/api/v1/analytics
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
|