83 lines
2.4 KiB
Python
83 lines
2.4 KiB
Python
"""Configuration for rate limiting."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass, field
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from fastapi_traffic.core.algorithms import Algorithm
|
|
|
|
if TYPE_CHECKING:
|
|
from starlette.requests import Request
|
|
|
|
from fastapi_traffic.backends.base import Backend
|
|
|
|
|
|
# Type of a rate-limit key extractor: a callable that receives the incoming
# request and returns the string key that requests are counted against
# (see default_key_extractor for the built-in implementation).
# "Request" is written as a string because it is only imported under
# TYPE_CHECKING; this alias is evaluated at runtime, where the bare name
# would not exist.
KeyExtractor = Callable[["Request"], str]
|
|
|
|
|
|
def default_key_extractor(request: Request) -> str:
    """Extract the client IP address to use as the default rate limit key.

    Sources are tried in this order, and the first non-empty one wins:

    1. The first (left-most) entry of the ``X-Forwarded-For`` header.
    2. The ``X-Real-IP`` header.
    3. The host of the directly connected peer (``request.client.host``).
    4. The literal string ``"unknown"`` when none of the above is available.

    Header values are whitespace-stripped. An empty value (for example a
    malformed ``X-Forwarded-For: , 10.0.0.1``) falls through to the next
    source. Otherwise it would produce an empty key that unrelated clients
    would all share.

    Security note: ``X-Forwarded-For`` and ``X-Real-IP`` are ordinary request
    headers. Unless a trusted reverse proxy overwrites them, a client can set
    them to arbitrary values and thereby pick its own rate limit key. Supply
    a custom ``key_extractor`` when the service is reachable without such a
    proxy.

    Args:
        request: The incoming request.

    Returns:
        The string key identifying the client.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        # The header is a comma-separated hop list; the left-most entry is
        # the originating client as reported by the proxy chain.
        first_hop = forwarded.split(",")[0].strip()
        if first_hop:
            return first_hop

    real_ip = request.headers.get("X-Real-IP")
    if real_ip:
        real_ip = real_ip.strip()
        if real_ip:
            return real_ip

    if request.client:
        return request.client.host

    return "unknown"
|
|
|
|
|
|
@dataclass(slots=True)
class RateLimitConfig:
    """Configuration for a rate limit rule.

    Numeric settings are validated on construction (see ``__post_init__``).
    Field semantics beyond what this module shows are marked as assumptions.

    Attributes:
        limit: Allowance for the rule. Must be > 0.
        window_size: Window length, default 60.0. Must be > 0; NaN is
            rejected. NOTE(review): units presumably seconds; confirm
            against the algorithm implementations.
        algorithm: Algorithm used for this rule (default
            ``Algorithm.SLIDING_WINDOW_COUNTER``).
        key_prefix: Prefix for keys produced by this rule (default
            ``"ratelimit"``); presumably namespaces backend storage keys.
        key_extractor: Maps a request to its rate limit key. Defaults to
            ``default_key_extractor`` (client IP).
        burst_size: Optional burst allowance, ``None`` by default. Not
            validated here. NOTE(review): confirm how algorithms treat 0.
        include_headers: Presumably whether rate limit headers are added to
            responses.
        error_message: Message presumably reported when the limit is hit.
        status_code: Status code presumably returned when the limit is hit
            (default 429).
        skip_on_error: Presumably lets requests through when the backend
            fails; confirm with the limiter.
        cost: Units consumed per request. Must be > 0.
        exempt_when: Optional predicate on the request; presumably a
            ``True`` result exempts the request from limiting.
        on_blocked: Optional callback taking the request and one extra
            argument; presumably invoked when a request is blocked.

    Raises:
        ValueError: If ``limit``, ``window_size`` or ``cost`` is not a
            positive number.
    """

    limit: int
    window_size: float = 60.0
    algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER
    key_prefix: str = "ratelimit"
    key_extractor: KeyExtractor = field(default=default_key_extractor)
    burst_size: int | None = None
    include_headers: bool = True
    error_message: str = "Rate limit exceeded"
    status_code: int = 429
    skip_on_error: bool = False
    cost: int = 1
    exempt_when: Callable[[Request], bool] | None = None
    on_blocked: Callable[[Request, Any], Any] | None = None

    def __post_init__(self) -> None:
        """Validate the numeric settings.

        The checks are written as ``not (x > 0)`` rather than ``x <= 0``.
        Every comparison with NaN is False, so ``x <= 0`` would let NaN
        through, while ``not (x > 0)`` rejects it. The results for ordinary
        numbers are unchanged.

        Raises:
            ValueError: If ``limit``, ``window_size`` or ``cost`` is not a
                positive number.
        """
        if not (self.limit > 0):
            msg = "limit must be positive"
            raise ValueError(msg)
        if not (self.window_size > 0):
            msg = "window_size must be positive"
            raise ValueError(msg)
        if not (self.cost > 0):
            msg = "cost must be positive"
            raise ValueError(msg)
|
|
|
|
|
|
@dataclass(slots=True)
class GlobalConfig:
    """Global configuration for the rate limiter.

    Holds limiter-wide settings, including defaults (``default_limit``,
    ``default_window_size``, ``default_algorithm``). These are presumably
    applied to rules that do not specify their own; confirm with the
    limiter. The default numbers are validated on construction with the same
    rules ``RateLimitConfig`` applies to ``limit`` and ``window_size``.

    Attributes:
        backend: Storage backend, or ``None`` (default). Presumably the
            limiter then chooses its own.
        enabled: Presumably a master switch for rate limiting.
        default_limit: Default allowance (default 100). Must be > 0.
        default_window_size: Default window length (default 60.0). Must be
            > 0; NaN is rejected. NOTE(review): units presumably seconds.
        default_algorithm: Default algorithm (default
            ``Algorithm.SLIDING_WINDOW_COUNTER``).
        key_prefix: Key prefix (default ``"fastapi_traffic"``).
        include_headers: Presumably whether rate limit headers are added to
            responses.
        error_message: Message presumably reported when a limit is hit.
        status_code: Status code presumably returned when a limit is hit
            (default 429).
        skip_on_error: Presumably lets requests through when the backend
            fails.
        exempt_ips: Client addresses presumably exempt from limiting. Each
            instance gets its own empty set.
        exempt_paths: Request paths presumably exempt from limiting. Each
            instance gets its own empty set.
        headers_prefix: Prefix for rate limit header names (default
            ``"X-RateLimit"``).

    Raises:
        ValueError: If ``default_limit`` or ``default_window_size`` is not a
            positive number.
    """

    backend: Backend | None = None
    enabled: bool = True
    default_limit: int = 100
    default_window_size: float = 60.0
    default_algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER
    key_prefix: str = "fastapi_traffic"
    include_headers: bool = True
    error_message: str = "Rate limit exceeded. Please try again later."
    status_code: int = 429
    skip_on_error: bool = False
    # default_factory gives every instance its own set instead of one
    # shared mutable default.
    exempt_ips: set[str] = field(default_factory=set)
    exempt_paths: set[str] = field(default_factory=set)
    headers_prefix: str = "X-RateLimit"

    def __post_init__(self) -> None:
        """Reject non-positive (or NaN) default limit and window size.

        ``not (x > 0)`` is used instead of ``x <= 0`` so that NaN, for
        which every comparison is False, is also rejected.

        Raises:
            ValueError: If ``default_limit`` or ``default_window_size`` is
                not a positive number.
        """
        if not (self.default_limit > 0):
            msg = "default_limit must be positive"
            raise ValueError(msg)
        if not (self.default_window_size > 0):
            msg = "default_window_size must be positive"
            raise ValueError(msg)
|