- Add ConfigLoader class for loading RateLimitConfig and GlobalConfig
- Support .env files with FASTAPI_TRAFFIC_* prefixed variables
- Support JSON configuration files with type validation
- Add convenience functions: load_rate_limit_config, load_global_config
- Add load_rate_limit_config_from_env, load_global_config_from_env
- Support custom environment variable prefixes
- Add comprehensive error handling with ConfigurationError
- Add 47 tests for configuration loading
- Add example 11_config_loader.py with 9 usage patterns
- Update examples/README.md with config loader documentation
- Update CHANGELOG.md with new feature
- Fix typo in limiter.py (errant 'fi' on line 4)
298 lines
9.0 KiB
Python
298 lines
9.0 KiB
Python
"""Core rate limiter implementation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import TYPE_CHECKING
|
|
|
|
from fastapi_traffic.backends.memory import MemoryBackend
|
|
from fastapi_traffic.core.algorithms import Algorithm, BaseAlgorithm, get_algorithm
|
|
from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig
|
|
from fastapi_traffic.core.models import RateLimitInfo, RateLimitResult
|
|
from fastapi_traffic.exceptions import BackendError, RateLimitExceeded
|
|
|
|
if TYPE_CHECKING:
|
|
from starlette.requests import Request
|
|
|
|
from fastapi_traffic.backends.base import Backend
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RateLimiter:
    """Main rate limiter class that manages rate limiting logic.

    A ``RateLimiter`` owns a storage backend plus a cache of algorithm
    instances (one per distinct algorithm/limit/window/burst combination).
    The backend is initialized lazily by the first call that needs it.
    """

    __slots__ = ("_algorithms", "_backend", "_config", "_initialized")

    def __init__(
        self,
        backend: Backend | None = None,
        *,
        config: GlobalConfig | None = None,
    ) -> None:
        """Initialize the rate limiter.

        Args:
            backend: Storage backend for rate limit data. When ``None``, the
                backend from ``config`` is used, falling back to a new
                ``MemoryBackend``.
            config: Global configuration options. ``None`` means
                ``GlobalConfig()``.
        """
        # Explicit ``is None`` checks rather than ``or``: an object that is
        # falsy (e.g. a backend defining ``__len__`` that returns 0) must not
        # be silently replaced by the fallback.
        self._config = config if config is not None else GlobalConfig()
        if backend is None:
            backend = self._config.backend
        self._backend = backend if backend is not None else MemoryBackend()
        # Algorithm instances keyed by "<algorithm>:<limit>:<window>:<burst>".
        self._algorithms: dict[str, BaseAlgorithm] = {}
        self._initialized = False

    @property
    def backend(self) -> Backend:
        """Get the storage backend."""
        return self._backend

    @property
    def config(self) -> GlobalConfig:
        """Get the global configuration."""
        return self._config

    async def initialize(self) -> None:
        """Initialize the rate limiter and backend.

        Awaits the backend's ``initialize()`` and ``start_cleanup()`` hooks
        when the backend defines them. Repeated calls are no-ops until
        :meth:`close` resets the initialized flag.
        """
        if self._initialized:
            return

        if hasattr(self._backend, "initialize"):
            await self._backend.initialize()  # type: ignore[union-attr]

        if hasattr(self._backend, "start_cleanup"):
            await self._backend.start_cleanup()  # type: ignore[union-attr]

        self._initialized = True

    async def close(self) -> None:
        """Close the rate limiter and cleanup resources.

        The algorithm cache is cleared and the initialized flag reset even if
        ``backend.close()`` raises, so the limiter never keeps reporting
        itself as initialized against a half-closed backend.
        """
        try:
            await self._backend.close()
        finally:
            self._algorithms.clear()
            self._initialized = False

    async def _ensure_initialized(self) -> None:
        """Run :meth:`initialize` if it has not completed yet."""
        if not self._initialized:
            await self.initialize()

    def _get_algorithm(
        self,
        limit: int,
        window_size: float,
        algorithm: Algorithm,
        burst_size: int | None = None,
    ) -> BaseAlgorithm:
        """Get or create an algorithm instance.

        Instances are cached per (algorithm, limit, window_size, burst_size)
        so endpoints sharing the same parameters share one instance.
        """
        cache_key = f"{algorithm.value}:{limit}:{window_size}:{burst_size}"
        if cache_key not in self._algorithms:
            self._algorithms[cache_key] = get_algorithm(
                algorithm,
                limit,
                window_size,
                self._backend,
                burst_size=burst_size,
            )
        return self._algorithms[cache_key]

    def _algorithm_for(self, config: RateLimitConfig) -> BaseAlgorithm:
        """Return the cached algorithm instance matching ``config``."""
        return self._get_algorithm(
            config.limit,
            config.window_size,
            config.algorithm,
            config.burst_size,
        )

    @staticmethod
    def _unlimited_info(config: RateLimitConfig) -> RateLimitInfo:
        """Build a RateLimitInfo reporting the whole limit as remaining."""
        return RateLimitInfo(
            limit=config.limit,
            remaining=config.limit,
            reset_at=0,
            window_size=config.window_size,
        )

    def _build_key(
        self,
        request: Request,
        config: RateLimitConfig,
        identifier: str | None = None,
    ) -> str:
        """Build the rate limit key for a request.

        The key has the form
        ``<global prefix>:<endpoint prefix>:<METHOD>:<path>:<client id>``,
        where the client id is ``identifier`` if truthy, otherwise the output
        of ``config.key_extractor(request)``.
        """
        client_id = identifier or config.key_extractor(request)
        path = request.url.path
        method = request.method
        return (
            f"{self._config.key_prefix}:{config.key_prefix}:{method}:{path}:{client_id}"
        )

    def _is_exempt(self, request: Request, config: RateLimitConfig) -> bool:
        """Check if the request is exempt from rate limiting.

        A request is exempt when any of the following holds:

        - rate limiting is globally disabled;
        - the endpoint's ``exempt_when`` predicate returns true;
        - the ``key_extractor`` output is in ``exempt_ips``;
        - the request path is in ``exempt_paths``.
        """
        if not self._config.enabled:
            return True

        if config.exempt_when is not None and config.exempt_when(request):
            return True

        # NOTE(review): exempt_ips is matched against key_extractor's output,
        # which is only an IP if the extractor returns one — confirm this is
        # intended for custom (e.g. API-key based) extractors.
        client_ip = config.key_extractor(request)
        if client_ip in self._config.exempt_ips:
            return True

        return request.url.path in self._config.exempt_paths

    async def check(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
        cost: int | None = None,
    ) -> RateLimitResult:
        """Check if a request is allowed under the rate limit.

        Args:
            request: The incoming request.
            config: Rate limit configuration for this endpoint.
            identifier: Optional custom identifier override.
            cost: Optional cost override for this request. ``None`` uses
                ``config.cost``. An explicit ``0`` is honoured and makes no
                algorithm calls.

        Returns:
            RateLimitResult with allowed status and limit info.

        Raises:
            BackendError: If the backend fails and ``config.skip_on_error``
                is false.
        """
        await self._ensure_initialized()

        if self._is_exempt(request, config):
            return RateLimitResult(
                allowed=True,
                info=self._unlimited_info(config),
                key="exempt",
            )

        key = self._build_key(request, config, identifier)
        # Explicit None check: ``cost or config.cost`` would silently replace
        # an explicit ``cost=0`` with the configured cost.
        actual_cost = config.cost if cost is None else cost

        try:
            algorithm = self._algorithm_for(config)

            info: RateLimitInfo | None = None
            # NOTE(review): cost is applied as one algorithm.check() call per
            # unit. If check() consumes capacity (presumably it does — confirm
            # against BaseAlgorithm), a request rejected part-way through
            # keeps the units already consumed.
            for _ in range(actual_cost):
                allowed, info = await algorithm.check(key)
                if not allowed:
                    return RateLimitResult(allowed=False, info=info, key=key)

            if info is None:
                # Zero (or negative) cost: no algorithm call produced info.
                info = self._unlimited_info(config)
            return RateLimitResult(allowed=True, info=info, key=key)

        except BackendError as e:
            logger.warning("Backend error during rate limit check: %s", e)
            if config.skip_on_error:
                # Fail open: allow the request when the backend is unavailable.
                return RateLimitResult(
                    allowed=True,
                    info=self._unlimited_info(config),
                    key=key,
                )
            raise

    async def hit(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
        cost: int | None = None,
    ) -> RateLimitResult:
        """Check rate limit and raise exception if exceeded.

        Args:
            request: The incoming request.
            config: Rate limit configuration for this endpoint.
            identifier: Optional custom identifier override.
            cost: Optional cost override for this request.

        Returns:
            RateLimitResult if allowed.

        Raises:
            RateLimitExceeded: If the rate limit is exceeded. The
                ``config.on_blocked`` callback, if set, runs first.
        """
        result = await self.check(request, config, identifier=identifier, cost=cost)

        if not result.allowed:
            if config.on_blocked is not None:
                config.on_blocked(request, result)

            raise RateLimitExceeded(
                config.error_message,
                retry_after=result.info.retry_after,
                limit_info=result.info,
            )

        return result

    async def reset(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
    ) -> None:
        """Reset the rate limit for a specific key.

        Initializes the limiter first if needed, matching :meth:`check`.

        Args:
            request: The request to reset limits for.
            config: Rate limit configuration.
            identifier: Optional custom identifier override.
        """
        await self._ensure_initialized()
        key = self._build_key(request, config, identifier)
        await self._algorithm_for(config).reset(key)

    async def get_state(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
    ) -> RateLimitInfo | None:
        """Get the current rate limit state without consuming a token.

        Initializes the limiter first if needed, matching :meth:`check`.

        Args:
            request: The request to check.
            config: Rate limit configuration.
            identifier: Optional custom identifier override.

        Returns:
            Current rate limit info or None if no state exists.
        """
        await self._ensure_initialized()
        key = self._build_key(request, config, identifier)
        return await self._algorithm_for(config).get_state(key)
|
|
|
|
|
|
# Process-wide default limiter. It is created lazily by get_limiter() and can
# be swapped out with set_limiter().
_default_limiter: RateLimiter | None = None


def get_limiter() -> RateLimiter:
    """Return the shared default :class:`RateLimiter`.

    On first use a limiter is built with default arguments and stored. Later
    calls return that same object unless :func:`set_limiter` replaced it.
    """
    global _default_limiter
    current = _default_limiter
    if current is None:
        current = RateLimiter()
        _default_limiter = current
    return current


def set_limiter(limiter: RateLimiter) -> None:
    """Replace the shared default :class:`RateLimiter`.

    Args:
        limiter: The instance that later :func:`get_limiter` calls return.
    """
    global _default_limiter
    _default_limiter = limiter
|