- Add ConfigLoader class for loading RateLimitConfig and GlobalConfig
- Support .env files with FASTAPI_TRAFFIC_* prefixed variables
- Support JSON configuration files with type validation
- Add convenience functions: load_rate_limit_config, load_global_config
- Add load_rate_limit_config_from_env, load_global_config_from_env
- Support custom environment variable prefixes
- Add comprehensive error handling with ConfigurationError
- Add 47 tests for configuration loading
- Add example 11_config_loader.py with 9 usage patterns
- Update examples/README.md with config loader documentation
- Update CHANGELOG.md with new feature
- Fix typo in limiter.py (errant 'fi' on line 4)
533 lines
17 KiB
Python
533 lines
17 KiB
Python
"""Configuration loader for rate limiting settings from .env and .json files."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, TypeVar
|
|
|
|
from fastapi_traffic.core.algorithms import Algorithm
|
|
from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig
|
|
from fastapi_traffic.exceptions import ConfigurationError
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Mapping
|
|
|
|
# Type variable constrained to the two config classes this module builds.
# NOTE(review): T is not referenced anywhere else in this module; presumably
# kept for external use - confirm before removing.
T = TypeVar("T", RateLimitConfig, GlobalConfig)

# Default leading prefix for every environment variable the loader reads.
# ConfigLoader accepts a different prefix through its ``env_prefix`` argument.
ENV_PREFIX = "FASTAPI_TRAFFIC_"
|
|
|
|
# Loadable RateLimitConfig fields and the Python type each raw value is
# validated against / converted to.  Fields missing from this table are
# rejected as unknown when they appear in a configuration source.
_RATE_LIMIT_FIELD_TYPES: dict[str, type[Any]] = {
    "limit": int,
    "window_size": float,
    "algorithm": Algorithm,
    "key_prefix": str,
    "burst_size": int,
    "include_headers": bool,
    "error_message": str,
    "status_code": int,
    "skip_on_error": bool,
    "cost": int,
}
|
|
|
|
# Loadable GlobalConfig fields and their expected Python types.
# The two ``set`` entries come from comma-separated strings in environment
# sources or from arrays in JSON.
_GLOBAL_FIELD_TYPES: dict[str, type[Any]] = {
    "enabled": bool,
    "default_limit": int,
    "default_window_size": float,
    "default_algorithm": Algorithm,
    "key_prefix": str,
    "include_headers": bool,
    "error_message": str,
    "status_code": int,
    "skip_on_error": bool,
    "exempt_ips": set,
    "exempt_paths": set,
    "headers_prefix": str,
}
|
|
|
|
# Fields holding callables or other complex objects.  They can never come from
# a text-based source (env, .env, JSON) and are refused there, but callers may
# still pass them programmatically as ``**overrides`` to the load methods.
_NON_LOADABLE_FIELDS: frozenset[str] = frozenset(
    (
        "key_extractor",
        "exempt_when",
        "on_blocked",
        "backend",
    )
)
|
|
|
|
|
|
class ConfigLoader:
    """Loader for rate limiting configuration from various sources.

    Supports loading configuration from:
    - Environment variables (prefixed with ``FASTAPI_TRAFFIC_`` by default)
    - .env files
    - JSON files

    Environment-style sources (``os.environ``, an explicit mapping, or a parsed
    .env file) are read as ``<env_prefix>RATE_LIMIT_<FIELD>`` for
    :class:`RateLimitConfig` and ``<env_prefix>GLOBAL_<FIELD>`` for
    :class:`GlobalConfig``; the ``<FIELD>`` part is lower-cased before lookup,
    so it matches case-insensitively.  Every loaded value is type-checked and
    converted before the config object is built, and problems are reported as
    :class:`ConfigurationError`.

    Example usage:
        >>> loader = ConfigLoader()
        >>> global_config = loader.load_global_config_from_env()
        >>> rate_config = loader.load_rate_limit_config_from_json("config.json")
    """

    __slots__ = ("_env_prefix",)

    # Case-insensitive spellings accepted for boolean fields.  Anything else is
    # rejected so that a typo such as "ture" cannot silently become False (and,
    # e.g., turn rate limiting off).
    _TRUE_VALUES = frozenset({"true", "1", "yes", "on"})
    _FALSE_VALUES = frozenset({"false", "0", "no", "off"})

    # Container types accepted for set-valued fields (JSON arrays load as list).
    _SET_LIKE_TYPES = (list, tuple, set, frozenset)

    def __init__(self, env_prefix: str = ENV_PREFIX) -> None:
        """Initialize the config loader.

        Args:
            env_prefix: Prefix for environment variables. Defaults to "FASTAPI_TRAFFIC_".
        """
        self._env_prefix = env_prefix

    @classmethod
    def _parse_bool(cls, value: str) -> bool:
        """Parse a boolean spelling, raising ValueError for unrecognised text.

        Args:
            value: Raw text; surrounding whitespace and case are ignored.

        Returns:
            True for "true"/"1"/"yes"/"on", False for "false"/"0"/"no"/"off".

        Raises:
            ValueError: For any other text (converted to ConfigurationError by
                ``_parse_value``).
        """
        normalized = value.strip().lower()
        if normalized in cls._TRUE_VALUES:
            return True
        if normalized in cls._FALSE_VALUES:
            return False
        msg = "expected one of true/false, yes/no, on/off, 1/0"
        raise ValueError(msg)

    def _parse_value(self, value: str, target_type: type[Any]) -> Any:
        """Parse a string value to the target type.

        Args:
            value: The string value to parse.
            target_type: The target type to convert to. Supported types are
                ``bool``, ``int``, ``float``, ``str``, ``Algorithm`` and ``set``
                (a comma-separated list of items).

        Returns:
            The parsed value.

        Raises:
            ConfigurationError: If the value cannot be parsed or the target
                type is not supported.
        """
        try:
            if target_type is bool:
                return self._parse_bool(value)
            if target_type is int:
                return int(value)
            if target_type is float:
                return float(value)
            if target_type is str:
                # Strings are returned verbatim (no stripping) so messages and
                # prefixes keep any intentional whitespace.
                return value
            if target_type is Algorithm:
                return Algorithm(value.strip().lower())
            if target_type is set:
                # Comma-separated items; whitespace around items and empty
                # items are dropped, so a blank value yields an empty set.
                return {item.strip() for item in value.split(",") if item.strip()}
        except (ValueError, KeyError) as e:
            msg = f"Cannot parse value '{value}' as {target_type.__name__}: {e}"
            raise ConfigurationError(msg) from e

        msg = f"Unsupported type: {target_type}"
        raise ConfigurationError(msg)

    @staticmethod
    def _type_error(key: str, target_type: type[Any], value: Any) -> ConfigurationError:
        """Build the error used when ``value`` has the wrong type for ``key``."""
        msg = (
            f"Invalid type for '{key}': expected {target_type.__name__}, "
            f"got {type(value).__name__}"
        )
        return ConfigurationError(msg)

    def _convert_field(self, key: str, value: Any, target_type: type[Any]) -> Any:
        """Convert one raw configuration value to ``target_type``.

        Strings are parsed with ``_parse_value``.  Non-string values are
        accepted when they already have the right type; ints are widened for
        float fields and list/tuple/set/frozenset values become sets for set
        fields.

        Args:
            key: Field name (used in error messages).
            value: Raw value from the configuration source.
            target_type: Expected type for the field.

        Returns:
            The converted value.

        Raises:
            ConfigurationError: If the value has an incompatible type.
        """
        if isinstance(value, str):
            return self._parse_value(value, target_type)

        # ``bool`` subclasses ``int``: without this guard a JSON ``true`` would
        # be accepted as ``limit=True`` (i.e. 1) or widened to 1.0 for floats.
        if isinstance(value, bool) and target_type is not bool:
            raise self._type_error(key, target_type, value)

        if target_type is set and isinstance(value, self._SET_LIKE_TYPES):
            for item in value:
                # Items are compared against request IPs/paths (strings);
                # non-strings would never match and unhashables crash set().
                if not isinstance(item, str):
                    msg = (
                        f"Invalid item in '{key}': expected str, "
                        f"got {type(item).__name__}"
                    )
                    raise ConfigurationError(msg)
            return set(value)

        if isinstance(value, target_type):
            return value

        if target_type is float and isinstance(value, int):
            return float(value)

        raise self._type_error(key, target_type, value)

    def _validate_and_convert(
        self,
        data: Mapping[str, Any],
        field_types: dict[str, type[Any]],
    ) -> dict[str, Any]:
        """Validate and convert configuration data.

        Args:
            data: Raw configuration data.
            field_types: Mapping of field names to their expected types.

        Returns:
            Validated and converted configuration dictionary.

        Raises:
            ConfigurationError: If a field is non-loadable, unknown, or has a
                value that cannot be converted to its expected type.
        """
        result: dict[str, Any] = {}

        for key, value in data.items():
            if key in _NON_LOADABLE_FIELDS:
                msg = f"Field '{key}' cannot be loaded from configuration files"
                raise ConfigurationError(msg)

            if key not in field_types:
                msg = f"Unknown configuration field: '{key}'"
                raise ConfigurationError(msg)

            result[key] = self._convert_field(key, value, field_types[key])

        return result

    def _load_dotenv_file(self, file_path: Path) -> dict[str, str]:
        """Load environment variables from a .env file.

        Blank lines and ``#`` comment lines are skipped, an optional leading
        ``export`` keyword is accepted, and one pair of matching surrounding
        quotes is removed from values.  A UTF-8 byte-order mark is tolerated.

        Args:
            file_path: Path to the .env file.

        Returns:
            Dictionary of environment variable names to values.

        Raises:
            ConfigurationError: If the file is missing, cannot be read or
                decoded, or contains a malformed line.
        """
        if not file_path.exists():
            msg = f"Configuration file not found: {file_path}"
            raise ConfigurationError(msg)

        env_vars: dict[str, str] = {}

        try:
            # utf-8-sig strips a leading BOM, which would otherwise become part
            # of the first key and stop it from matching the prefix.
            with file_path.open(encoding="utf-8-sig") as f:
                for line_num, raw_line in enumerate(f, start=1):
                    line = raw_line.strip()

                    # Skip empty lines and comments
                    if not line or line.startswith("#"):
                        continue

                    # Parse key=value pairs
                    if "=" not in line:
                        msg = f"Invalid line {line_num} in {file_path}: missing '='"
                        raise ConfigurationError(msg)

                    key, _, value = line.partition("=")
                    key = key.strip()

                    # Accept shell-style "export KEY=value" lines.
                    parts = key.split(maxsplit=1)
                    if len(parts) == 2 and parts[0] == "export":
                        key = parts[1]

                    if not key:
                        msg = f"Invalid line {line_num} in {file_path}: missing key before '='"
                        raise ConfigurationError(msg)

                    value = value.strip()

                    # Remove surrounding quotes if present
                    if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
                        value = value[1:-1]

                    env_vars[key] = value

        except UnicodeDecodeError as e:
            msg = f"Configuration file {file_path} is not valid UTF-8: {e}"
            raise ConfigurationError(msg) from e
        except OSError as e:
            msg = f"Failed to read configuration file {file_path}: {e}"
            raise ConfigurationError(msg) from e

        return env_vars

    def _load_json_file(self, file_path: Path) -> dict[str, Any]:
        """Load configuration from a JSON file.

        Args:
            file_path: Path to the JSON file.

        Returns:
            Configuration dictionary (the file's top-level JSON object).

        Raises:
            ConfigurationError: If the file is missing, cannot be read or
                decoded, is not valid JSON, or its top level is not an object.
        """
        if not file_path.exists():
            msg = f"Configuration file not found: {file_path}"
            raise ConfigurationError(msg)

        try:
            # utf-8-sig: json refuses a leading BOM, which some editors add.
            with file_path.open(encoding="utf-8-sig") as f:
                data = json.load(f)
        except json.JSONDecodeError as e:
            msg = f"Invalid JSON in {file_path}: {e}"
            raise ConfigurationError(msg) from e
        except UnicodeDecodeError as e:
            msg = f"Configuration file {file_path} is not valid UTF-8: {e}"
            raise ConfigurationError(msg) from e
        except OSError as e:
            msg = f"Failed to read configuration file {file_path}: {e}"
            raise ConfigurationError(msg) from e

        # A top-level array/number/string would otherwise fail later with an
        # AttributeError when the data is iterated with ``.items()``.
        if not isinstance(data, dict):
            msg = (
                f"Invalid configuration in {file_path}: expected a JSON object "
                f"at the top level, got {type(data).__name__}"
            )
            raise ConfigurationError(msg)

        return data

    def _extract_env_config(
        self,
        prefix: str,
        field_types: dict[str, type[Any]],
        env_source: Mapping[str, str] | None = None,
    ) -> dict[str, str]:
        """Extract configuration from environment variables.

        Only variables starting with ``self._env_prefix + prefix`` whose
        lower-cased remainder names a field in ``field_types`` are kept; all
        other variables are ignored.

        Args:
            prefix: The prefix to look for (e.g., "RATE_LIMIT_" or "GLOBAL_").
            field_types: Mapping of field names to their expected types.
            env_source: Optional source of environment variables. Defaults to os.environ.

        Returns:
            Dictionary of field names to their string values.
        """
        source = env_source if env_source is not None else os.environ
        full_prefix = f"{self._env_prefix}{prefix}"
        result: dict[str, str] = {}

        for key, value in source.items():
            if key.startswith(full_prefix):
                field_name = key[len(full_prefix):].lower()
                if field_name in field_types:
                    result[field_name] = value

        return result

    @staticmethod
    def _apply_overrides(
        config_dict: dict[str, Any],
        overrides: Mapping[str, Any],
        field_types: dict[str, type[Any]],
    ) -> None:
        """Merge caller overrides into ``config_dict`` in place.

        Overrides may set loadable fields as well as the non-loadable
        (callable/object) fields.  Other keys are ignored.
        NOTE(review): a misspelled override key is therefore silently dropped.
        """
        for key, value in overrides.items():
            if key in _NON_LOADABLE_FIELDS or key in field_types:
                config_dict[key] = value

    def _build_rate_limit_config(
        self,
        config_dict: dict[str, Any],
        overrides: Mapping[str, Any],
        source_label: str,
    ) -> RateLimitConfig:
        """Apply overrides, require ``limit``, and construct a RateLimitConfig.

        Args:
            config_dict: Validated values loaded from the source (mutated).
            overrides: Caller-supplied overrides.
            source_label: Source name used in the missing-limit error message.

        Raises:
            ConfigurationError: If ``limit`` is absent after applying overrides.
        """
        self._apply_overrides(config_dict, overrides, _RATE_LIMIT_FIELD_TYPES)

        # Ensure required field 'limit' is present
        if "limit" not in config_dict:
            msg = f"Required field 'limit' not found in {source_label} configuration"
            raise ConfigurationError(msg)

        return RateLimitConfig(**config_dict)

    def load_rate_limit_config_from_env(
        self,
        env_source: Mapping[str, str] | None = None,
        **overrides: Any,
    ) -> RateLimitConfig:
        """Load RateLimitConfig from environment variables.

        Environment variables should be prefixed with FASTAPI_TRAFFIC_RATE_LIMIT_
        (e.g., FASTAPI_TRAFFIC_RATE_LIMIT_LIMIT=100), or with this loader's
        custom ``env_prefix`` followed by ``RATE_LIMIT_``.

        Args:
            env_source: Optional source of environment variables. Defaults to os.environ.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured RateLimitConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        raw_config = self._extract_env_config(
            "RATE_LIMIT_", _RATE_LIMIT_FIELD_TYPES, env_source
        )
        config_dict = self._validate_and_convert(raw_config, _RATE_LIMIT_FIELD_TYPES)
        return self._build_rate_limit_config(config_dict, overrides, "environment")

    def load_rate_limit_config_from_dotenv(
        self,
        file_path: str | Path,
        **overrides: Any,
    ) -> RateLimitConfig:
        """Load RateLimitConfig from a .env file.

        Only the variables in the file are consulted (not ``os.environ``).

        Args:
            file_path: Path to the .env file.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured RateLimitConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        env_vars = self._load_dotenv_file(Path(file_path))
        return self.load_rate_limit_config_from_env(env_vars, **overrides)

    def load_rate_limit_config_from_json(
        self,
        file_path: str | Path,
        **overrides: Any,
    ) -> RateLimitConfig:
        """Load RateLimitConfig from a JSON file.

        The file's top-level object maps field names (lower-case, unprefixed)
        to values.

        Args:
            file_path: Path to the JSON file.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured RateLimitConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        raw_config = self._load_json_file(Path(file_path))
        config_dict = self._validate_and_convert(raw_config, _RATE_LIMIT_FIELD_TYPES)
        return self._build_rate_limit_config(config_dict, overrides, "JSON")

    def load_global_config_from_env(
        self,
        env_source: Mapping[str, str] | None = None,
        **overrides: Any,
    ) -> GlobalConfig:
        """Load GlobalConfig from environment variables.

        Environment variables should be prefixed with FASTAPI_TRAFFIC_GLOBAL_
        (e.g., FASTAPI_TRAFFIC_GLOBAL_ENABLED=true), or with this loader's
        custom ``env_prefix`` followed by ``GLOBAL_``.

        Args:
            env_source: Optional source of environment variables. Defaults to os.environ.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured GlobalConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        raw_config = self._extract_env_config(
            "GLOBAL_", _GLOBAL_FIELD_TYPES, env_source
        )
        config_dict = self._validate_and_convert(raw_config, _GLOBAL_FIELD_TYPES)
        self._apply_overrides(config_dict, overrides, _GLOBAL_FIELD_TYPES)
        return GlobalConfig(**config_dict)

    def load_global_config_from_dotenv(
        self,
        file_path: str | Path,
        **overrides: Any,
    ) -> GlobalConfig:
        """Load GlobalConfig from a .env file.

        Only the variables in the file are consulted (not ``os.environ``).

        Args:
            file_path: Path to the .env file.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured GlobalConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        env_vars = self._load_dotenv_file(Path(file_path))
        return self.load_global_config_from_env(env_vars, **overrides)

    def load_global_config_from_json(
        self,
        file_path: str | Path,
        **overrides: Any,
    ) -> GlobalConfig:
        """Load GlobalConfig from a JSON file.

        The file's top-level object maps field names (lower-case, unprefixed)
        to values.

        Args:
            file_path: Path to the JSON file.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured GlobalConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        raw_config = self._load_json_file(Path(file_path))
        config_dict = self._validate_and_convert(raw_config, _GLOBAL_FIELD_TYPES)
        self._apply_overrides(config_dict, overrides, _GLOBAL_FIELD_TYPES)
        return GlobalConfig(**config_dict)
|
|
|
|
|
|
# Convenience functions for direct usage.
# Shared loader used by the module-level helpers; built lazily on first use.
_default_loader: ConfigLoader | None = None


def _get_default_loader() -> ConfigLoader:
    """Return the shared ConfigLoader, constructing it on the first call.

    The instance is cached in the module-level ``_default_loader`` and uses the
    default environment prefix.
    """
    global _default_loader
    loader = _default_loader
    if loader is None:
        loader = ConfigLoader()
        _default_loader = loader
    return loader
|
|
|
|
|
|
def load_rate_limit_config(
    source: str | Path,
    **overrides: Any,
) -> RateLimitConfig:
    """Load a RateLimitConfig from a file, picking the parser from its name.

    ``*.json`` (case-insensitive) is parsed as JSON.  A ``.env`` suffix, no
    suffix at all, or a name starting with ``.env`` (e.g. ``.env.local``) is
    parsed as a dotenv file.  Any other suffix is rejected.

    Args:
        source: Path to configuration file (.env or .json).
        **overrides: Additional values to override loaded config.

    Returns:
        Configured RateLimitConfig instance.

    Raises:
        ConfigurationError: If configuration is invalid or format unknown.
    """
    loader = _get_default_loader()
    path = Path(source)
    suffix = path.suffix.lower()

    if suffix == ".json":
        return loader.load_rate_limit_config_from_json(path, **overrides)

    looks_like_dotenv = suffix in (".env", "") or path.name.startswith(".env")
    if not looks_like_dotenv:
        msg = f"Unknown configuration file format: {path.suffix}"
        raise ConfigurationError(msg)

    return loader.load_rate_limit_config_from_dotenv(path, **overrides)
|
|
|
|
|
|
def load_global_config(
    source: str | Path,
    **overrides: Any,
) -> GlobalConfig:
    """Load a GlobalConfig from a file, picking the parser from its name.

    ``*.json`` (case-insensitive) is parsed as JSON.  A ``.env`` suffix, no
    suffix at all, or a name starting with ``.env`` (e.g. ``.env.local``) is
    parsed as a dotenv file.  Any other suffix is rejected.

    Args:
        source: Path to configuration file (.env or .json).
        **overrides: Additional values to override loaded config.

    Returns:
        Configured GlobalConfig instance.

    Raises:
        ConfigurationError: If configuration is invalid or format unknown.
    """
    loader = _get_default_loader()
    path = Path(source)
    suffix = path.suffix.lower()

    if suffix == ".json":
        return loader.load_global_config_from_json(path, **overrides)

    looks_like_dotenv = suffix in (".env", "") or path.name.startswith(".env")
    if not looks_like_dotenv:
        msg = f"Unknown configuration file format: {path.suffix}"
        raise ConfigurationError(msg)

    return loader.load_global_config_from_dotenv(path, **overrides)
|
|
|
|
|
|
def load_rate_limit_config_from_env(**overrides: Any) -> RateLimitConfig:
    """Load a RateLimitConfig from environment variables via the shared loader.

    Reads ``FASTAPI_TRAFFIC_RATE_LIMIT_<FIELD>`` variables; all keyword
    arguments are forwarded unchanged to
    ``ConfigLoader.load_rate_limit_config_from_env``.

    Args:
        **overrides: Additional values to override loaded config.

    Returns:
        Configured RateLimitConfig instance.

    Raises:
        ConfigurationError: If configuration is invalid.
    """
    loader = _get_default_loader()
    return loader.load_rate_limit_config_from_env(**overrides)
|
|
|
|
|
|
def load_global_config_from_env(**overrides: Any) -> GlobalConfig:
    """Load a GlobalConfig from environment variables via the shared loader.

    Reads ``FASTAPI_TRAFFIC_GLOBAL_<FIELD>`` variables; all keyword arguments
    are forwarded unchanged to ``ConfigLoader.load_global_config_from_env``.

    Args:
        **overrides: Additional values to override loaded config.

    Returns:
        Configured GlobalConfig instance.

    Raises:
        ConfigurationError: If configuration is invalid.
    """
    loader = _get_default_loader()
    return loader.load_global_config_from_env(**overrides)
|