Files
fastapi-traffic/fastapi_traffic/core/limiter.py
zanewalker fb23e3c7cf feat: add configuration loader for .env and JSON files
- Add ConfigLoader class for loading RateLimitConfig and GlobalConfig
- Support .env files with FASTAPI_TRAFFIC_* prefixed variables
- Support JSON configuration files with type validation
- Add convenience functions: load_rate_limit_config, load_global_config
- Add load_rate_limit_config_from_env, load_global_config_from_env
- Support custom environment variable prefixes
- Add comprehensive error handling with ConfigurationError
- Add 47 tests for configuration loading
- Add example 11_config_loader.py with 9 usage patterns
- Update examples/README.md with config loader documentation
- Update CHANGELOG.md with new feature
- Fix typo in limiter.py (errant 'fi' on line 4)
2026-02-01 13:59:32 +00:00

298 lines
9.0 KiB
Python

"""Core rate limiter implementation."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from fastapi_traffic.backends.memory import MemoryBackend
from fastapi_traffic.core.algorithms import Algorithm, BaseAlgorithm, get_algorithm
from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig
from fastapi_traffic.core.models import RateLimitInfo, RateLimitResult
from fastapi_traffic.exceptions import BackendError, RateLimitExceeded
if TYPE_CHECKING:
    # These names appear only in annotations in this module; with
    # ``from __future__ import annotations`` they are never evaluated at
    # runtime, so the imports are needed only by static type checkers.
    from starlette.requests import Request
    from fastapi_traffic.backends.base import Backend
# Module-level logger; RateLimiter.check uses it to report backend errors.
logger = logging.getLogger(__name__)
class RateLimiter:
    """Main rate limiter class that manages rate limiting logic.

    Holds a storage backend, the global configuration and a cache of
    algorithm instances keyed by their parameters. The backend is
    initialized lazily by ``check``, ``reset`` and ``get_state``.
    """

    __slots__ = ("_algorithms", "_backend", "_config", "_initialized")

    def __init__(
        self,
        backend: Backend | None = None,
        *,
        config: GlobalConfig | None = None,
    ) -> None:
        """Initialize the rate limiter.

        Args:
            backend: Storage backend for rate limit data. When ``None``,
                falls back to ``config.backend`` and then to a new
                ``MemoryBackend``.
            config: Global configuration options. Defaults to
                ``GlobalConfig()`` when ``None``.
        """
        # Explicit ``is None`` checks: a backend/config object that happens to
        # be falsy (e.g. defines ``__len__`` and is empty) must not be
        # silently replaced by a default.
        self._config = config if config is not None else GlobalConfig()
        if backend is None:
            backend = self._config.backend
        self._backend = backend if backend is not None else MemoryBackend()
        # Algorithm instances cached by "<algorithm>:<limit>:<window>:<burst>".
        self._algorithms: dict[str, BaseAlgorithm] = {}
        self._initialized = False

    @property
    def backend(self) -> Backend:
        """Get the storage backend."""
        return self._backend

    @property
    def config(self) -> GlobalConfig:
        """Get the global configuration."""
        return self._config

    async def initialize(self) -> None:
        """Initialize the rate limiter and backend.

        Calls the backend's optional ``initialize`` and ``start_cleanup``
        hooks if present. Returns immediately once already initialized.
        """
        if self._initialized:
            return
        # Hooks are optional on backends, hence the hasattr probes.
        if hasattr(self._backend, "initialize"):
            await self._backend.initialize()  # type: ignore[union-attr]
        if hasattr(self._backend, "start_cleanup"):
            await self._backend.start_cleanup()  # type: ignore[union-attr]
        self._initialized = True

    async def close(self) -> None:
        """Close the rate limiter and cleanup resources.

        Closes the backend, drops cached algorithms and marks the limiter as
        uninitialized, so a later call will run ``initialize`` again.
        """
        await self._backend.close()
        self._algorithms.clear()
        self._initialized = False

    def _get_algorithm(
        self,
        limit: int,
        window_size: float,
        algorithm: Algorithm,
        burst_size: int | None = None,
    ) -> BaseAlgorithm:
        """Get or create an algorithm instance for the given parameters."""
        cache_key = f"{algorithm.value}:{limit}:{window_size}:{burst_size}"
        if cache_key not in self._algorithms:
            self._algorithms[cache_key] = get_algorithm(
                algorithm,
                limit,
                window_size,
                self._backend,
                burst_size=burst_size,
            )
        return self._algorithms[cache_key]

    def _algorithm_for(self, config: RateLimitConfig) -> BaseAlgorithm:
        """Return the cached algorithm instance matching *config*."""
        return self._get_algorithm(
            config.limit,
            config.window_size,
            config.algorithm,
            config.burst_size,
        )

    @staticmethod
    def _fresh_info(config: RateLimitConfig) -> RateLimitInfo:
        """Build limit info reporting the full quota as remaining."""
        return RateLimitInfo(
            limit=config.limit,
            remaining=config.limit,
            reset_at=0,
            window_size=config.window_size,
        )

    def _build_key(
        self,
        request: Request,
        config: RateLimitConfig,
        identifier: str | None = None,
    ) -> str:
        """Build the rate limit key for a request.

        Format: ``<global prefix>:<config prefix>:<method>:<path>:<client>``,
        where ``<client>`` is *identifier* if given, else the output of
        ``config.key_extractor(request)``.
        """
        client_id = identifier or config.key_extractor(request)
        path = request.url.path
        method = request.method
        return (
            f"{self._config.key_prefix}:{config.key_prefix}:{method}:{path}:{client_id}"
        )

    def _is_exempt(self, request: Request, config: RateLimitConfig) -> bool:
        """Check if the request is exempt from rate limiting.

        A request is exempt when limiting is globally disabled, when
        ``config.exempt_when`` returns true, when the extracted client id is
        in ``exempt_ips``, or when the path is in ``exempt_paths``.
        """
        if not self._config.enabled:
            return True
        if config.exempt_when is not None and config.exempt_when(request):
            return True
        # NOTE(review): compared against exempt_ips, so this assumes the key
        # extractor yields a client IP — confirm for custom extractors.
        client_id = config.key_extractor(request)
        if client_id in self._config.exempt_ips:
            return True
        return request.url.path in self._config.exempt_paths

    async def check(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
        cost: int | None = None,
    ) -> RateLimitResult:
        """Check if a request is allowed under the rate limit.

        Args:
            request: The incoming request.
            config: Rate limit configuration for this endpoint.
            identifier: Optional custom identifier override.
            cost: Optional cost override for this request. ``None`` means
                use ``config.cost``; an explicit ``0`` consumes nothing.

        Returns:
            RateLimitResult with allowed status and limit info. Exempt
            requests get ``key="exempt"`` and full-quota info.

        Raises:
            BackendError: If the backend fails and ``config.skip_on_error``
                is false.
        """
        if not self._initialized:
            await self.initialize()

        if self._is_exempt(request, config):
            return RateLimitResult(
                allowed=True,
                info=self._fresh_info(config),
                key="exempt",
            )

        key = self._build_key(request, config, identifier)
        # ``cost or config.cost`` would turn an explicit cost of 0 into the
        # configured cost; only fall back when no override was given.
        actual_cost = config.cost if cost is None else cost

        info: RateLimitInfo | None = None
        try:
            algorithm = self._algorithm_for(config)
            # Units are consumed one at a time; units already taken before a
            # denial are not refunded.
            for _ in range(actual_cost):
                allowed, info = await algorithm.check(key)
                if not allowed:
                    return RateLimitResult(allowed=False, info=info, key=key)
        except BackendError as e:
            logger.warning("Backend error during rate limit check: %s", e)
            if config.skip_on_error:
                return RateLimitResult(
                    allowed=True,
                    info=self._fresh_info(config),
                    key=key,
                )
            raise

        if info is None:
            # Nothing was consumed (cost <= 0): report the full quota.
            info = self._fresh_info(config)
        return RateLimitResult(allowed=True, info=info, key=key)

    async def hit(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
        cost: int | None = None,
    ) -> RateLimitResult:
        """Check rate limit and raise exception if exceeded.

        Args:
            request: The incoming request.
            config: Rate limit configuration for this endpoint.
            identifier: Optional custom identifier override.
            cost: Optional cost override for this request.

        Returns:
            RateLimitResult if allowed.

        Raises:
            RateLimitExceeded: If the rate limit is exceeded (after invoking
                ``config.on_blocked`` when set).
        """
        result = await self.check(request, config, identifier=identifier, cost=cost)
        if not result.allowed:
            if config.on_blocked is not None:
                config.on_blocked(request, result)
            raise RateLimitExceeded(
                config.error_message,
                retry_after=result.info.retry_after,
                limit_info=result.info,
            )
        return result

    async def reset(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
    ) -> None:
        """Reset the rate limit for a specific key.

        Args:
            request: The request to reset limits for.
            config: Rate limit configuration.
            identifier: Optional custom identifier override.
        """
        # Same lazy initialization as check(), so the backend is ready.
        if not self._initialized:
            await self.initialize()
        key = self._build_key(request, config, identifier)
        await self._algorithm_for(config).reset(key)

    async def get_state(
        self,
        request: Request,
        config: RateLimitConfig,
        *,
        identifier: str | None = None,
    ) -> RateLimitInfo | None:
        """Get the current rate limit state without consuming a token.

        Args:
            request: The request to check.
            config: Rate limit configuration.
            identifier: Optional custom identifier override.

        Returns:
            Current rate limit info or None if no state exists.
        """
        # Same lazy initialization as check(), so the backend is ready.
        if not self._initialized:
            await self.initialize()
        key = self._build_key(request, config, identifier)
        return await self._algorithm_for(config).get_state(key)
# Process-wide default limiter; created lazily by get_limiter() or
# installed explicitly via set_limiter().
_default_limiter: RateLimiter | None = None


def get_limiter() -> RateLimiter:
    """Return the default rate limiter, creating it on first use.

    Returns:
        The instance installed by set_limiter(), or a RateLimiter built with
        default arguments the first time this is called.
    """
    global _default_limiter
    current = _default_limiter
    if current is None:
        current = RateLimiter()
        _default_limiter = current
    return current


def set_limiter(limiter: RateLimiter) -> None:
    """Replace the default rate limiter returned by get_limiter().

    Args:
        limiter: Instance that subsequent get_limiter() calls will return.
    """
    global _default_limiter
    _default_limiter = limiter