Files
fastapi-traffic/fastapi_traffic/core/config.py

83 lines
2.4 KiB
Python

"""Configuration for rate limiting."""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from fastapi_traffic.core.algorithms import Algorithm
if TYPE_CHECKING:
from starlette.requests import Request
from fastapi_traffic.backends.base import Backend
# Signature of a rate limit key extractor: it receives the incoming request
# and returns the string key. "Request" is written as a string because the
# class is imported only under TYPE_CHECKING and does not exist at runtime.
KeyExtractor = Callable[["Request"], str]
def default_key_extractor(request: Request) -> str:
    """Extract the client IP to use as the default rate limit key.

    Sources are tried in this order:

    1. The first comma-separated entry of the ``X-Forwarded-For`` header.
    2. The ``X-Real-IP`` header.
    3. ``request.client.host``.

    Header values are stripped of surrounding whitespace, and a header whose
    value is blank after stripping is skipped so that it can never produce an
    empty key.

    Args:
        request: The incoming request; only ``headers`` and ``client`` are read.

    Returns:
        The extracted client identifier, or ``"unknown"`` when neither header
        yields a value and ``request.client`` is falsy.
    """
    # NOTE(review): both headers are sent by the peer and can be forged unless
    # a trusted proxy overwrites them -- confirm the deployment topology.
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        # Only the first entry is used; partition avoids splitting the rest.
        first_hop = forwarded.partition(",")[0].strip()
        if first_hop:
            return first_hop

    real_ip = request.headers.get("X-Real-IP")
    if real_ip:
        real_ip = real_ip.strip()
        if real_ip:
            return real_ip

    if request.client:
        return request.client.host
    return "unknown"
@dataclass(slots=True)
class RateLimitConfig:
    """Settings describing a single rate limit rule.

    ``limit``, ``window_size`` and ``cost`` are validated on construction and
    must each be strictly greater than zero; the remaining fields are stored
    as given.
    """

    # Required: the only field without a default.
    limit: int
    window_size: float = 60.0  # presumably seconds -- TODO confirm
    algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER
    key_prefix: str = "ratelimit"
    key_extractor: KeyExtractor = field(default=default_key_extractor)
    burst_size: int | None = None
    include_headers: bool = True
    error_message: str = "Rate limit exceeded"
    status_code: int = 429
    skip_on_error: bool = False
    cost: int = 1
    # Optional callbacks; how they are invoked is decided by the consumer of
    # this config and is not visible in this module.
    exempt_when: Callable[[Request], bool] | None = None
    on_blocked: Callable[[Request, Any], Any] | None = None

    def __post_init__(self) -> None:
        """Reject a non-positive ``limit``, ``window_size`` or ``cost``.

        Raises:
            ValueError: For the first offending field, checked in the order
                ``limit``, ``window_size``, ``cost``.
        """
        for name in ("limit", "window_size", "cost"):
            if getattr(self, name) <= 0:
                raise ValueError(f"{name} must be positive")
@dataclass(slots=True)
class GlobalConfig:
    """Limiter-wide settings.

    Every field has a default, so ``GlobalConfig()`` is a valid instance.
    No validation is performed on construction.
    """

    # Backend instance, or None when none has been supplied.
    backend: Backend | None = None
    enabled: bool = True

    # Defaults for rule parameters.
    default_limit: int = 100
    default_window_size: float = 60.0  # presumably seconds -- TODO confirm
    default_algorithm: Algorithm = Algorithm.SLIDING_WINDOW_COUNTER
    key_prefix: str = "fastapi_traffic"

    # Response settings.
    include_headers: bool = True
    error_message: str = "Rate limit exceeded. Please try again later."
    status_code: int = 429
    skip_on_error: bool = False

    # Exemption sets; default_factory gives each instance its own empty set.
    exempt_ips: set[str] = field(default_factory=set)
    exempt_paths: set[str] = field(default_factory=set)

    headers_prefix: str = "X-RateLimit"