191 lines
6.0 KiB
Python
191 lines
6.0 KiB
Python
"""Rate limiting middleware for Starlette/FastAPI applications."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import TYPE_CHECKING
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.responses import JSONResponse
|
|
|
|
from fastapi_traffic.backends.memory import MemoryBackend
|
|
from fastapi_traffic.core.algorithms import Algorithm
|
|
from fastapi_traffic.core.config import (
|
|
GlobalConfig,
|
|
RateLimitConfig,
|
|
default_key_extractor,
|
|
)
|
|
from fastapi_traffic.core.limiter import RateLimiter
|
|
from fastapi_traffic.exceptions import RateLimitExceeded
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Awaitable, Callable
|
|
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
from starlette.types import ASGIApp
|
|
|
|
from fastapi_traffic.backends.base import Backend
|
|
|
|
# Module-level logger, named after this module so log records can be
# filtered or configured per module by the host application.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware for global rate limiting across all endpoints.

    Each incoming request is checked with a ``RateLimiter`` against a single
    ``RateLimitConfig`` built from the constructor arguments. Requests that
    are not allowed are answered directly with a JSON error response;
    allowed requests are forwarded to the wrapped application.
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        limit: int = 100,
        window_size: float = 60.0,
        algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER,
        backend: Backend | None = None,
        key_prefix: str = "middleware",
        include_headers: bool = True,
        error_message: str = "Rate limit exceeded. Please try again later.",
        status_code: int = 429,
        skip_on_error: bool = False,
        exempt_paths: set[str] | None = None,
        exempt_ips: set[str] | None = None,
        key_extractor: Callable[[Request], str] = default_key_extractor,
    ) -> None:
        """Initialize the rate limit middleware.

        Args:
            app: The ASGI application.
            limit: Maximum requests per window.
            window_size: Time window in seconds.
            algorithm: Rate limiting algorithm.
            backend: Storage backend (defaults to MemoryBackend).
            key_prefix: Prefix for rate limit keys.
            include_headers: Include rate limit headers in response.
            error_message: Error message when rate limited.
            status_code: HTTP status code when rate limited.
            skip_on_error: Skip rate limiting on backend errors.
            exempt_paths: Paths to exempt from rate limiting.
            exempt_ips: IP addresses to exempt from rate limiting.
            key_extractor: Function to extract client identifier.
        """
        super().__init__(app)

        # Compare against None explicitly: a caller-supplied backend that
        # happens to be falsy (e.g. one defining ``__len__``) must not be
        # silently swapped for a fresh MemoryBackend.
        self._backend = MemoryBackend() if backend is None else backend
        self._config = RateLimitConfig(
            limit=limit,
            window_size=window_size,
            algorithm=algorithm,
            key_prefix=key_prefix,
            key_extractor=key_extractor,
            include_headers=include_headers,
            error_message=error_message,
            status_code=status_code,
            skip_on_error=skip_on_error,
        )

        global_config = GlobalConfig(
            backend=self._backend,
            exempt_paths=exempt_paths or set(),
            exempt_ips=exempt_ips or set(),
        )

        self._limiter = RateLimiter(self._backend, config=global_config)
        self._include_headers = include_headers
        self._error_message = error_message
        self._status_code = status_code

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Process the request with rate limiting.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request downstream and
                returns its response.

        Returns:
            The downstream response (with rate limit headers added when
            ``include_headers`` is enabled), or a JSON error response when
            the request is rate limited.

        Raises:
            Exception: Any error raised by the limiter check when
                ``skip_on_error`` is disabled, and any non-rate-limit error
                raised by the downstream application.
        """
        # Only the limiter check is guarded by the generic handler. Keeping
        # ``call_next`` out of this ``try`` ensures that a failure inside the
        # downstream application is neither reported as a middleware error
        # nor causes the application to be invoked a second time when
        # ``skip_on_error`` is enabled.
        try:
            result = await self._limiter.check(request, self._config)
        except RateLimitExceeded as exc:
            return self._create_exceeded_response(exc)
        except Exception as e:
            logger.exception("Error in rate limit middleware: %s", e)
            if self._config.skip_on_error:
                return await call_next(request)
            raise

        if not result.allowed:
            return self._create_rate_limit_response(result)

        # RateLimitExceeded raised further down the stack is still turned
        # into a rate limit response here; every other downstream exception
        # propagates untouched.
        try:
            response = await call_next(request)
        except RateLimitExceeded as exc:
            return self._create_exceeded_response(exc)

        if self._include_headers:
            for key, value in result.info.to_headers().items():
                response.headers[key] = value

        return response

    def _create_exceeded_response(self, exc: RateLimitExceeded) -> JSONResponse:
        """Create a JSON response from a ``RateLimitExceeded`` exception."""
        return JSONResponse(
            status_code=self._status_code,
            content={
                "detail": exc.message,
                "retry_after": exc.retry_after,
            },
            headers=exc.limit_info.to_headers() if exc.limit_info else {},
        )

    def _create_rate_limit_response(self, result: object) -> JSONResponse:
        """Create a rate limit exceeded response.

        Args:
            result: The value returned by the limiter check. Headers and
                ``retry_after`` are taken from it only when it is a
                ``RateLimitResult``; otherwise the response carries no
                extra headers and a null ``retry_after``.

        Returns:
            A JSON response using the configured status code and error
            message.
        """
        # Imported inside the method, as in the original implementation.
        from fastapi_traffic.core.models import RateLimitResult

        if isinstance(result, RateLimitResult):
            headers = result.info.to_headers()
            retry_after = result.info.retry_after
        else:
            headers = {}
            retry_after = None

        return JSONResponse(
            status_code=self._status_code,
            content={
                "detail": self._error_message,
                "retry_after": retry_after,
            },
            headers=headers,
        )
|
|
|
|
|
|
class SlidingWindowMiddleware(RateLimitMiddleware):
    """Convenience middleware using sliding window algorithm."""

    def __init__(
        self,
        app: ASGIApp,
        *,
        limit: int = 100,
        window_size: float = 60.0,
        **kwargs: object,
    ) -> None:
        """Initialize the middleware with the algorithm fixed to sliding window.

        Args:
            app: The ASGI application.
            limit: Maximum requests per window.
            window_size: Time window in seconds.
            **kwargs: Remaining keyword arguments, passed through unchanged
                to ``RateLimitMiddleware.__init__``. ``algorithm`` is set by
                this class and must not be supplied here.
        """
        fixed_algorithm = Algorithm.SLIDING_WINDOW
        super().__init__(
            app,
            limit=limit,
            window_size=window_size,
            algorithm=fixed_algorithm,
            **kwargs,  # type: ignore[arg-type]
        )
|
|
|
|
|
|
class TokenBucketMiddleware(RateLimitMiddleware):
    """Convenience middleware using token bucket algorithm."""

    def __init__(
        self,
        app: ASGIApp,
        *,
        limit: int = 100,
        window_size: float = 60.0,
        **kwargs: object,
    ) -> None:
        """Initialize the middleware with the algorithm fixed to token bucket.

        Args:
            app: The ASGI application.
            limit: Maximum requests per window.
            window_size: Time window in seconds.
            **kwargs: Remaining keyword arguments, passed through unchanged
                to ``RateLimitMiddleware.__init__``. ``algorithm`` is set by
                this class and must not be supplied here.
        """
        fixed_algorithm = Algorithm.TOKEN_BUCKET
        super().__init__(
            app,
            limit=limit,
            window_size=window_size,
            algorithm=fixed_algorithm,
            **kwargs,  # type: ignore[arg-type]
        )
|