Files
fastapi-traffic/fastapi_traffic/middleware.py

191 lines
6.0 KiB
Python

"""Rate limiting middleware for Starlette/FastAPI applications."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from fastapi_traffic.backends.memory import MemoryBackend
from fastapi_traffic.core.algorithms import Algorithm
from fastapi_traffic.core.config import (
GlobalConfig,
RateLimitConfig,
default_key_extractor,
)
from fastapi_traffic.core.limiter import RateLimiter
from fastapi_traffic.exceptions import RateLimitExceeded
if TYPE_CHECKING:
    # Names used only in annotations. With ``from __future__ import
    # annotations`` in effect, annotations are not evaluated at runtime,
    # so these imports are never executed outside static type checking.
    from collections.abc import Awaitable, Callable

    from starlette.requests import Request
    from starlette.responses import Response
    from starlette.types import ASGIApp

    from fastapi_traffic.backends.base import Backend

# Module-level logger used to report errors raised by the rate limiter.
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware for global rate limiting across all endpoints.

    Every request that passes through the middleware is checked against one
    ``RateLimitConfig`` built from the constructor arguments.
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        limit: int = 100,
        window_size: float = 60.0,
        algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER,
        backend: Backend | None = None,
        key_prefix: str = "middleware",
        include_headers: bool = True,
        error_message: str = "Rate limit exceeded. Please try again later.",
        status_code: int = 429,
        skip_on_error: bool = False,
        exempt_paths: set[str] | None = None,
        exempt_ips: set[str] | None = None,
        key_extractor: Callable[[Request], str] = default_key_extractor,
    ) -> None:
        """Initialize the rate limit middleware.

        Args:
            app: The ASGI application.
            limit: Maximum requests per window.
            window_size: Time window in seconds.
            algorithm: Rate limiting algorithm.
            backend: Storage backend (defaults to MemoryBackend).
            key_prefix: Prefix for rate limit keys.
            include_headers: Include rate limit headers in successful
                responses.
            error_message: Error message when rate limited.
            status_code: HTTP status code when rate limited.
            skip_on_error: If the rate limiter itself raises, log the error
                and serve the request without rate limiting instead of
                propagating the exception.
            exempt_paths: Paths to exempt from rate limiting.
            exempt_ips: IP addresses to exempt from rate limiting.
            key_extractor: Function to extract client identifier.
        """
        super().__init__(app)
        # Compare against None explicitly: a caller-supplied backend that
        # happens to evaluate falsy (e.g. one defining ``__len__`` while
        # empty) must not be silently replaced by a fresh MemoryBackend.
        self._backend = backend if backend is not None else MemoryBackend()
        self._config = RateLimitConfig(
            limit=limit,
            window_size=window_size,
            algorithm=algorithm,
            key_prefix=key_prefix,
            key_extractor=key_extractor,
            include_headers=include_headers,
            error_message=error_message,
            status_code=status_code,
            skip_on_error=skip_on_error,
        )
        global_config = GlobalConfig(
            backend=self._backend,
            exempt_paths=exempt_paths or set(),
            exempt_ips=exempt_ips or set(),
        )
        self._limiter = RateLimiter(self._backend, config=global_config)
        self._include_headers = include_headers
        self._error_message = error_message
        self._status_code = status_code

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Process the request with rate limiting.

        Only the limiter check is covered by the ``skip_on_error`` fail-open
        path. Exceptions raised by the downstream application are never
        mistaken for limiter failures, so the downstream app is invoked at
        most once per request.

        Args:
            request: The incoming request.
            call_next: Starlette callback that runs the rest of the app.

        Returns:
            The downstream response (with rate limit headers added when
            ``include_headers`` is enabled), or a JSON error response when
            the request is rate limited.

        Raises:
            Exception: Any error from the limiter when ``skip_on_error`` is
                False, and any non-``RateLimitExceeded`` error raised by the
                downstream application.
        """
        try:
            result = await self._limiter.check(request, self._config)
        except RateLimitExceeded as exc:
            return self._create_exception_response(exc)
        except Exception as e:
            logger.exception("Error in rate limit middleware: %s", e)
            if self._config.skip_on_error:
                # Fail open: serve the request without rate limiting. This is
                # the only place ``call_next`` runs on the error path.
                return await call_next(request)
            raise

        if not result.allowed:
            return self._create_rate_limit_response(result)

        try:
            response = await call_next(request)
        except RateLimitExceeded as exc:
            # A RateLimitExceeded raised while handling the request further
            # down the stack is still rendered as a rate-limit response.
            return self._create_exception_response(exc)

        if self._include_headers:
            for key, value in result.info.to_headers().items():
                response.headers[key] = value
        return response

    def _create_exception_response(self, exc: RateLimitExceeded) -> JSONResponse:
        """Build the error response for a raised ``RateLimitExceeded``.

        The body uses the exception's own message and ``retry_after``. The
        headers come from ``exc.limit_info`` when it is present.
        """
        return JSONResponse(
            status_code=self._status_code,
            content={
                "detail": exc.message,
                "retry_after": exc.retry_after,
            },
            headers=exc.limit_info.to_headers() if exc.limit_info else {},
        )

    def _create_rate_limit_response(self, result: object) -> JSONResponse:
        """Create a rate limit exceeded response.

        Args:
            result: The limiter result. When it is a ``RateLimitResult``, its
                headers and ``retry_after`` are included. Otherwise the
                response carries no headers and a null ``retry_after``.

        Returns:
            A JSON response carrying the configured status code and message.
        """
        # Imported locally, as in the original code. Presumably this avoids
        # an import cycle; confirm before hoisting it to module level.
        from fastapi_traffic.core.models import RateLimitResult

        if isinstance(result, RateLimitResult):
            headers = result.info.to_headers()
            retry_after = result.info.retry_after
        else:
            headers = {}
            retry_after = None
        return JSONResponse(
            status_code=self._status_code,
            content={
                "detail": self._error_message,
                "retry_after": retry_after,
            },
            headers=headers,
        )
class SlidingWindowMiddleware(RateLimitMiddleware):
    """Rate limiting middleware preset to the sliding window algorithm.

    Behaves like ``RateLimitMiddleware(..., algorithm=Algorithm.SLIDING_WINDOW)``.
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        limit: int = 100,
        window_size: float = 60.0,
        **kwargs: object,
    ) -> None:
        """Create the middleware with ``Algorithm.SLIDING_WINDOW`` fixed.

        Args:
            app: The ASGI application to wrap.
            limit: Maximum requests per window.
            window_size: Time window in seconds.
            **kwargs: Any other keyword accepted by ``RateLimitMiddleware``.
                Passing ``algorithm`` here raises ``TypeError``, because the
                algorithm is already supplied explicitly.
        """
        chosen_algorithm = Algorithm.SLIDING_WINDOW
        super().__init__(
            app,
            algorithm=chosen_algorithm,
            window_size=window_size,
            limit=limit,
            **kwargs,  # type: ignore[arg-type]
        )
class TokenBucketMiddleware(RateLimitMiddleware):
    """Rate limiting middleware preset to the token bucket algorithm.

    Behaves like ``RateLimitMiddleware(..., algorithm=Algorithm.TOKEN_BUCKET)``.
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        limit: int = 100,
        window_size: float = 60.0,
        **kwargs: object,
    ) -> None:
        """Create the middleware with ``Algorithm.TOKEN_BUCKET`` fixed.

        Args:
            app: The ASGI application to wrap.
            limit: Maximum requests per window.
            window_size: Time window in seconds.
            **kwargs: Any other keyword accepted by ``RateLimitMiddleware``.
                Passing ``algorithm`` here raises ``TypeError``, because the
                algorithm is already supplied explicitly.
        """
        chosen_algorithm = Algorithm.TOKEN_BUCKET
        super().__init__(
            app,
            algorithm=chosen_algorithm,
            window_size=window_size,
            limit=limit,
            **kwargs,  # type: ignore[arg-type]
        )