# Source: fastapi-traffic/fastapi_traffic/core/decorator.py
# (file listing metadata: 268 lines, 7.7 KiB, Python)

"""Rate limit decorator for FastAPI endpoints."""
from __future__ import annotations
import functools
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, TypeVar, overload
from fastapi_traffic.core.algorithms import Algorithm
from fastapi_traffic.core.config import (
KeyExtractor,
RateLimitConfig,
default_key_extractor,
)
from fastapi_traffic.core.limiter import get_limiter
if TYPE_CHECKING:
from starlette.requests import Request
from starlette.responses import Response
from fastapi_traffic.exceptions import RateLimitExceeded
# Type variable for the endpoint being decorated: ``rate_limit`` returns a
# decorator typed as handing back the same callable type it was given.
F = TypeVar("F", bound=Callable[..., Any])
# NOTE(review): an earlier note here read "Config loader from secrets .env", but
# nothing in this module loads configuration from a .env file — confirm where
# that happens and move or remove this note.
def rate_limit(
    limit: int,
    window_size: float = 60.0,
    *,
    algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER,
    key_prefix: str = "ratelimit",
    key_extractor: KeyExtractor = default_key_extractor,
    burst_size: int | None = None,
    include_headers: bool = True,
    error_message: str = "Rate limit exceeded",
    status_code: int = 429,
    skip_on_error: bool = False,
    cost: int = 1,
    exempt_when: Callable[[Request], bool] | None = None,
    on_blocked: Callable[[Request, Any], Any] | None = None,
) -> Callable[[F], F]:
    """Decorator to apply rate limiting to a FastAPI endpoint.

    Both ``async def`` and plain ``def`` endpoints are supported. On every call
    the wrapper looks for a Starlette ``Request`` among the endpoint's
    arguments. If one is found, a hit is recorded with the global limiter
    before the endpoint runs. If none is found, the endpoint is called with no
    rate limiting at all, so the endpoint must declare a ``Request`` parameter.

    Args:
        limit: Maximum number of requests allowed in the window.
        window_size: Time window in seconds.
        algorithm: Rate limiting algorithm to use.
        key_prefix: Prefix for the rate limit key.
        key_extractor: Function to extract the client identifier from request.
        burst_size: Maximum burst size (for token bucket/leaky bucket).
        include_headers: Whether to include rate limit headers in response.
            Headers are only added when the endpoint's return value exposes a
            ``headers`` attribute (e.g. a Response object).
        error_message: Error message when rate limit is exceeded.
        status_code: HTTP status code when rate limit is exceeded.
        skip_on_error: Skip rate limiting if backend errors occur.
        cost: Cost of each request (default 1).
        exempt_when: Function to determine if request should be exempt.
        on_blocked: Callback when a request is blocked.

    Returns:
        A decorator that wraps an endpoint with rate limiting.

    Raises:
        RuntimeError: If a decorated synchronous endpoint that received a
            request is called from a thread whose event loop is already
            running (the limiter cannot be driven synchronously there).

    Example:
        ```python
        from fastapi import FastAPI, Request
        from fastapi_traffic import rate_limit

        app = FastAPI()

        @app.get("/api/resource")
        @rate_limit(100, 60)  # 100 requests per minute
        async def get_resource(request: Request):
            return {"message": "Hello"}
        ```
    """
    config = RateLimitConfig(
        limit=limit,
        window_size=window_size,
        algorithm=algorithm,
        key_prefix=key_prefix,
        key_extractor=key_extractor,
        burst_size=burst_size,
        include_headers=include_headers,
        error_message=error_message,
        status_code=status_code,
        skip_on_error=skip_on_error,
        cost=cost,
        exempt_when=exempt_when,
        on_blocked=on_blocked,
    )

    def _apply_headers(response: Any, result: Any) -> None:
        """Copy the limiter's headers onto ``response`` when that is possible."""
        # Endpoints may return plain dicts/models; only objects exposing a
        # ``headers`` mapping can be annotated.
        if config.include_headers and hasattr(response, "headers"):
            for key, value in result.info.to_headers().items():
                response.headers[key] = value

    def decorator(func: F) -> F:
        @functools.wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            request = _extract_request(args, kwargs)
            if request is None:
                return await func(*args, **kwargs)
            result = await get_limiter().hit(request, config)
            response = await func(*args, **kwargs)
            _apply_headers(response, result)
            return response

        @functools.wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            import asyncio

            request = _extract_request(args, kwargs)
            if request is None:
                # ``func`` is synchronous: call it directly, never ``await`` it.
                return func(*args, **kwargs)

            # asyncio.run() cannot be used while this thread already runs a
            # loop; fail before creating a coroutine that would never be awaited.
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                pass
            else:
                raise RuntimeError(
                    "A rate-limited synchronous endpoint cannot be called from "
                    "inside a running event loop; declare it with 'async def' "
                    "or call it from a worker thread."
                )

            limiter = get_limiter()

            async def _check() -> Any:
                return await limiter.hit(request, config)

            # NOTE(review): the limiter runs on a short-lived event loop owned
            # by the calling thread. Backends holding loop-bound resources
            # (e.g. clients created on the server's main loop) may not support
            # this — confirm for non-in-memory backends.
            result = asyncio.run(_check())
            response = func(*args, **kwargs)
            _apply_headers(response, result)
            return response

        if _is_coroutine_function(func):
            return async_wrapper  # type: ignore[return-value]
        return sync_wrapper  # type: ignore[return-value]

    return decorator
def _extract_request(
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Request | None:
    """Extract the Request object from function arguments.

    Positional arguments are scanned first, then keyword-argument values in
    insertion order; the first Starlette ``Request`` instance wins. A keyword
    named ``request`` needs no special case, because its value is already
    covered by the keyword scan.

    Returns:
        The first ``Request`` found, or ``None`` if there is none.
    """
    from starlette.requests import Request

    return next(
        (
            candidate
            for candidate in (*args, *kwargs.values())
            if isinstance(candidate, Request)
        ),
        None,
    )
def _is_coroutine_function(func: Callable[..., Any]) -> bool:
"""Check if a function is a coroutine function."""
import asyncio
import inspect
return asyncio.iscoroutinefunction(func) or inspect.iscoroutinefunction(func)
class RateLimitDependency:
    """FastAPI dependency for rate limiting.

    Each call records a hit for the incoming request with the global limiter
    and returns the limit info from the result, so an endpoint can report,
    for example, its remaining quota.

    Example:
        ```python
        from fastapi import FastAPI, Depends
        from fastapi_traffic import RateLimitDependency

        app = FastAPI()
        rate_limiter = RateLimitDependency(limit=100, window_size=60)

        @app.get("/api/resource")
        async def get_resource(rate_limit_info = Depends(rate_limiter)):
            return {"remaining": rate_limit_info.remaining}
        ```
    """

    __slots__ = ("_config",)

    def __init__(
        self,
        limit: int,
        window_size: float = 60.0,
        *,
        algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER,
        key_prefix: str = "ratelimit",
        key_extractor: KeyExtractor = default_key_extractor,
        burst_size: int | None = None,
        error_message: str = "Rate limit exceeded",
        status_code: int = 429,
        skip_on_error: bool = False,
        cost: int = 1,
        exempt_when: Callable[[Request], bool] | None = None,
        on_blocked: Callable[[Request, Any], Any] | None = None,
    ) -> None:
        """Build the rate limit configuration used on every call.

        Args:
            limit: Maximum number of requests allowed in the window.
            window_size: Time window in seconds.
            algorithm: Rate limiting algorithm to use.
            key_prefix: Prefix for the rate limit key.
            key_extractor: Function to extract the client identifier.
            burst_size: Maximum burst size (for token bucket/leaky bucket).
            error_message: Error message when rate limit is exceeded.
            status_code: HTTP status code when rate limit is exceeded.
            skip_on_error: Skip rate limiting if backend errors occur.
            cost: Cost of each request (default 1).
            exempt_when: Function to determine if request should be exempt.
            on_blocked: Callback when a request is blocked (same option as
                ``rate_limit`` accepts; defaults to ``None``).
        """
        self._config = RateLimitConfig(
            limit=limit,
            window_size=window_size,
            algorithm=algorithm,
            key_prefix=key_prefix,
            key_extractor=key_extractor,
            burst_size=burst_size,
            # Kept as before. This dependency only returns the limit info;
            # it never writes response headers itself.
            include_headers=True,
            error_message=error_message,
            status_code=status_code,
            skip_on_error=skip_on_error,
            cost=cost,
            exempt_when=exempt_when,
            on_blocked=on_blocked,
        )

    async def __call__(self, request: Request) -> Any:
        """Record a hit for ``request`` and return the resulting limit info."""
        limiter = get_limiter()
        result = await limiter.hit(request, self._config)
        return result.info

    # With ``from __future__ import annotations`` the annotation above is the
    # string "Request", and ``Request`` is only imported under TYPE_CHECKING.
    # FastAPI evaluates string annotations against the callable's
    # ``__globals__``, and an instance of this class has none, so the string
    # cannot be resolved. Without a resolved class, ``request`` would not be
    # injected as the incoming request. Pinning the real class here is
    # harmless if it would have resolved anyway.
    # NOTE(review): verify against the installed FastAPI version.
    from starlette.requests import Request as _StarletteRequest

    __call__.__annotations__["request"] = _StarletteRequest
    del _StarletteRequest
def create_rate_limit_response(
    exc: RateLimitExceeded,
    *,
    include_headers: bool = True,
    status_code: int = 429,
) -> Response:
    """Create a rate limit exceeded response.

    Args:
        exc: The RateLimitExceeded exception.
        include_headers: Whether to include rate limit headers. They are only
            added when ``exc.limit_info`` is not ``None``.
        status_code: HTTP status code of the response. Defaults to 429, as
            before. Pass the value given to ``rate_limit(status_code=...)``
            when a custom code is configured, so the two agree.

    Returns:
        A JSONResponse whose body has ``detail`` (``exc.message``) and
        ``retry_after`` (``exc.retry_after``).
    """
    from starlette.responses import JSONResponse

    headers: dict[str, str] = (
        exc.limit_info.to_headers()
        if include_headers and exc.limit_info is not None
        else {}
    )
    return JSONResponse(
        status_code=status_code,
        content={
            "detail": exc.message,
            "retry_after": exc.retry_after,
        },
        headers=headers,
    )