268 lines
7.7 KiB
Python
268 lines
7.7 KiB
Python
"""Rate limit decorator for FastAPI endpoints."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import functools
|
|
from collections.abc import Callable
|
|
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
|
|
|
from fastapi_traffic.core.algorithms import Algorithm
|
|
from fastapi_traffic.core.config import (
|
|
KeyExtractor,
|
|
RateLimitConfig,
|
|
default_key_extractor,
|
|
)
|
|
from fastapi_traffic.core.limiter import get_limiter
|
|
|
|
if TYPE_CHECKING:
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
|
|
from fastapi_traffic.exceptions import RateLimitExceeded
|
|
|
|
# Type variable for decorated callables: decorators typed ``Callable[[F], F]``
# return a value of the same type they were given, so type checkers keep the
# decorated endpoint's original signature.
F = TypeVar("F", bound=Callable[..., Any])
|
|
|
|
# Note: Config loader from secrets .env
|
|
|
|
|
|
# Keyword-style calls: ``rate_limit(100)``, ``rate_limit(limit=100)`` or
# ``rate_limit(100, window_size=60, cost=2)``.
@overload
def rate_limit(
    limit: int,
    *,
    window_size: float = ...,
    algorithm: Algorithm = ...,
    key_prefix: str = ...,
    key_extractor: KeyExtractor = ...,
    burst_size: int | None = ...,
    include_headers: bool = ...,
    error_message: str = ...,
    status_code: int = ...,
    skip_on_error: bool = ...,
    cost: int = ...,
    exempt_when: Callable[[Request], bool] | None = ...,
    on_blocked: Callable[[Request, Any], Any] | None = ...,
) -> Callable[[F], F]: ...


# Positional window size, e.g. ``rate_limit(100, 60)`` or
# ``rate_limit(100, 60, cost=2)``. The keyword-only options are repeated here
# because the overload above makes ``window_size`` keyword-only; without them
# a positional window plus any option would match neither overload for a type
# checker, even though the implementation accepts it.
@overload
def rate_limit(
    limit: int,
    window_size: float,
    /,
    *,
    algorithm: Algorithm = ...,
    key_prefix: str = ...,
    key_extractor: KeyExtractor = ...,
    burst_size: int | None = ...,
    include_headers: bool = ...,
    error_message: str = ...,
    status_code: int = ...,
    skip_on_error: bool = ...,
    cost: int = ...,
    exempt_when: Callable[[Request], bool] | None = ...,
    on_blocked: Callable[[Request, Any], Any] | None = ...,
) -> Callable[[F], F]: ...
|
|
|
|
|
|
def rate_limit(
    limit: int,
    window_size: float = 60.0,
    *,
    algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER,
    key_prefix: str = "ratelimit",
    key_extractor: KeyExtractor = default_key_extractor,
    burst_size: int | None = None,
    include_headers: bool = True,
    error_message: str = "Rate limit exceeded",
    status_code: int = 429,
    skip_on_error: bool = False,
    cost: int = 1,
    exempt_when: Callable[[Request], bool] | None = None,
    on_blocked: Callable[[Request, Any], Any] | None = None,
) -> Callable[[F], F]:
    """Decorator to apply rate limiting to a FastAPI endpoint.

    The endpoint must receive the Starlette ``Request`` as an argument
    (positional or keyword). If no ``Request`` is found among the call's
    arguments, the endpoint runs without any rate-limit check.

    Both ``async def`` and plain ``def`` endpoints are supported. For a plain
    endpoint, the limiter's async check is driven on a fresh event loop with
    ``asyncio.run`` and the endpoint itself is called synchronously.

    Args:
        limit: Maximum number of requests allowed in the window.
        window_size: Time window in seconds.
        algorithm: Rate limiting algorithm to use.
        key_prefix: Prefix for the rate limit key.
        key_extractor: Function to extract the client identifier from request.
        burst_size: Maximum burst size (for token bucket/leaky bucket).
        include_headers: Whether to include rate limit headers in response.
            Headers are only added when the endpoint returns an object with a
            ``headers`` attribute (e.g. a ``Response``). Other return values,
            such as dicts, are passed through unchanged.
        error_message: Error message when rate limit is exceeded.
        status_code: HTTP status code when rate limit is exceeded.
        skip_on_error: Skip rate limiting if backend errors occur.
        cost: Cost of each request (default 1).
        exempt_when: Function to determine if request should be exempt.
        on_blocked: Callback when a request is blocked.

    Returns:
        A decorator that wraps an endpoint with rate limiting.

    Raises:
        RuntimeError: From the wrapper of a plain ``def`` endpoint when it is
            called on a thread that is already running an event loop. The
            async limiter check cannot be driven synchronously there.

    Example:
        ```python
        from fastapi import FastAPI, Request
        from fastapi_traffic import rate_limit

        app = FastAPI()

        @app.get("/api/resource")
        @rate_limit(100, 60)  # 100 requests per minute
        async def get_resource(request: Request):
            return {"message": "Hello"}
        ```
    """
    config = RateLimitConfig(
        limit=limit,
        window_size=window_size,
        algorithm=algorithm,
        key_prefix=key_prefix,
        key_extractor=key_extractor,
        burst_size=burst_size,
        include_headers=include_headers,
        error_message=error_message,
        status_code=status_code,
        skip_on_error=skip_on_error,
        cost=cost,
        exempt_when=exempt_when,
        on_blocked=on_blocked,
    )

    async def check(request: Request) -> Any:
        # Record one hit for this request. The endpoint is called
        # unconditionally afterwards, so presumably the limiter raises
        # (e.g. RateLimitExceeded) when the request is blocked.
        # TODO confirm in the limiter.
        return await get_limiter().hit(request, config)

    def apply_headers(response: Any, result: Any) -> None:
        # Only objects exposing ``headers`` (e.g. Starlette responses) can
        # carry the rate-limit headers; anything else is left untouched.
        if config.include_headers and hasattr(response, "headers"):
            for key, value in result.info.to_headers().items():
                response.headers[key] = value

    def decorator(func: F) -> F:
        @functools.wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            request = _extract_request(args, kwargs)
            if request is None:
                return await func(*args, **kwargs)

            result = await check(request)
            response = await func(*args, **kwargs)
            apply_headers(response, result)
            return response

        @functools.wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            import asyncio

            request = _extract_request(args, kwargs)
            if request is None:
                # ``func`` is a plain function: its result is a value, not an
                # awaitable, so it is called directly.
                return func(*args, **kwargs)

            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No loop is running on this thread, so the async check can be
                # driven to completion on a fresh loop.
                # NOTE(review): a new loop is used per call; confirm the
                # limiter backend holds no resources bound to another loop.
                result = asyncio.run(check(request))
            else:
                raise RuntimeError(
                    "rate_limit cannot run the rate limit check for a "
                    "synchronous endpoint from inside a running event loop"
                )

            response = func(*args, **kwargs)
            apply_headers(response, result)
            return response

        if _is_coroutine_function(func):
            return async_wrapper  # type: ignore[return-value]
        return sync_wrapper  # type: ignore[return-value]

    return decorator
|
|
|
|
|
|
def _extract_request(
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Request | None:
    """Return the first Starlette ``Request`` among a call's arguments.

    Positional arguments are searched first, then keyword-argument values in
    insertion order. A keyword argument named ``request`` needs no special
    case: it is found by the same scan.

    Args:
        args: Positional arguments the wrapped endpoint was called with.
        kwargs: Keyword arguments the wrapped endpoint was called with.

    Returns:
        The first argument that is a ``Request`` instance, or ``None`` if
        there is none.
    """
    from itertools import chain

    from starlette.requests import Request

    return next(
        (value for value in chain(args, kwargs.values()) if isinstance(value, Request)),
        None,
    )
|
|
|
|
|
|
def _is_coroutine_function(func: Callable[..., Any]) -> bool:
    """Tell whether ``func`` should be wrapped with the ``async`` wrapper.

    The asyncio predicate is consulted first and the inspect predicate second.
    A callable counts as a coroutine function if either one says so.
    """
    import asyncio
    import inspect

    if asyncio.iscoroutinefunction(func):
        return True
    return inspect.iscoroutinefunction(func)
|
|
|
|
|
|
class RateLimitDependency:
    """Rate limiting exposed as a FastAPI dependency.

    Each time the dependency is resolved, one hit is recorded for the incoming
    request. The limiter's info object for that hit is handed to the endpoint.
    Compared with the ``rate_limit`` decorator, the configuration always has
    ``include_headers=True``, and no ``on_blocked`` callback can be supplied.

    Example:
        ```python
        from fastapi import FastAPI, Depends
        from fastapi_traffic import RateLimitDependency

        app = FastAPI()
        rate_limiter = RateLimitDependency(limit=100, window_size=60)

        @app.get("/api/resource")
        async def get_resource(rate_limit_info = Depends(rate_limiter)):
            return {"remaining": rate_limit_info.remaining}
        ```
    """

    __slots__ = ("_config",)

    def __init__(
        self,
        limit: int,
        window_size: float = 60.0,
        *,
        algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER,
        key_prefix: str = "ratelimit",
        key_extractor: KeyExtractor = default_key_extractor,
        burst_size: int | None = None,
        error_message: str = "Rate limit exceeded",
        status_code: int = 429,
        skip_on_error: bool = False,
        cost: int = 1,
        exempt_when: Callable[[Request], bool] | None = None,
    ) -> None:
        """Build the rate-limit configuration used by every call."""
        self._config = RateLimitConfig(
            limit=limit,
            window_size=window_size,
            algorithm=algorithm,
            key_prefix=key_prefix,
            key_extractor=key_extractor,
            burst_size=burst_size,
            error_message=error_message,
            status_code=status_code,
            skip_on_error=skip_on_error,
            cost=cost,
            exempt_when=exempt_when,
            # Not configurable on the dependency.
            include_headers=True,
        )

    async def __call__(self, request: Request) -> Any:
        """Record a hit for ``request`` and return the limiter's info for it."""
        # NOTE(review): this module uses ``from __future__ import annotations``
        # and imports ``Request`` only under TYPE_CHECKING. The ``request``
        # annotation is therefore the string "Request" at runtime; confirm
        # FastAPI resolves it to the Starlette Request for class-instance
        # dependencies.
        outcome = await get_limiter().hit(request, self._config)
        return outcome.info
|
|
|
|
|
|
def create_rate_limit_response(
    exc: RateLimitExceeded,
    *,
    include_headers: bool = True,
    status_code: int = 429,
) -> Response:
    """Create a rate limit exceeded response.

    Args:
        exc: The RateLimitExceeded exception. Its ``message``, ``retry_after``
            and ``limit_info`` attributes are read.
        include_headers: Whether to include rate limit headers. The headers
            come from ``exc.limit_info.to_headers()`` and are only added when
            ``limit_info`` is not ``None``.
        status_code: HTTP status code of the response. It defaults to 429, so
            existing callers are unaffected. Pass the limit's configured
            ``status_code`` to keep the response consistent with it.

    Returns:
        A JSONResponse with body
        ``{"detail": exc.message, "retry_after": exc.retry_after}``.
    """
    from starlette.responses import JSONResponse

    headers: dict[str, str] = {}
    if include_headers and exc.limit_info is not None:
        headers = exc.limit_info.to_headers()

    return JSONResponse(
        status_code=status_code,
        content={
            "detail": exc.message,
            "retry_after": exc.retry_after,
        },
        headers=headers,
    )
|