# fastapi-traffic/fastapi_traffic/core/config_loader.py

"""Configuration loader for rate limiting settings from .env and .json files."""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
from fastapi_traffic.core.algorithms import Algorithm
from fastapi_traffic.core.config import GlobalConfig, RateLimitConfig
from fastapi_traffic.exceptions import ConfigurationError
if TYPE_CHECKING:
    from collections.abc import Mapping

T = TypeVar("T", RateLimitConfig, GlobalConfig)

# Every environment variable read by this module starts with this prefix.
ENV_PREFIX = "FASTAPI_TRAFFIC_"

# Loadable fields that appear in both the RateLimitConfig and the GlobalConfig
# tables below, mapped to the Python type their raw values are converted to.
_SHARED_FIELD_TYPES: dict[str, type[Any]] = {
    "key_prefix": str,
    "include_headers": bool,
    "error_message": str,
    "status_code": int,
    "skip_on_error": bool,
}

# RateLimitConfig fields that may be loaded from env vars, .env files or JSON,
# with their target types. Names missing from this table are not loadable.
_RATE_LIMIT_FIELD_TYPES: dict[str, type[Any]] = {
    "limit": int,
    "window_size": float,
    "algorithm": Algorithm,
    "burst_size": int,
    "cost": int,
    **_SHARED_FIELD_TYPES,
}

# GlobalConfig fields that may be loaded from env vars, .env files or JSON,
# with their target types. `set` fields are parsed from comma-separated
# strings or JSON lists.
_GLOBAL_FIELD_TYPES: dict[str, type[Any]] = {
    "enabled": bool,
    "default_limit": int,
    "default_window_size": float,
    "default_algorithm": Algorithm,
    **_SHARED_FIELD_TYPES,
    "exempt_ips": set,
    "exempt_paths": set,
    "headers_prefix": str,
}

# Fields that cannot come from configuration sources (callables and complex
# objects); they can only be supplied as keyword overrides to the loaders.
_NON_LOADABLE_FIELDS: frozenset[str] = frozenset(
    ("key_extractor", "exempt_when", "on_blocked", "backend")
)
class ConfigLoader:
    """Loader for rate limiting configuration from various sources.

    Supports loading configuration from:
    - Environment variables (with FASTAPI_TRAFFIC_ prefix)
    - .env files
    - JSON files

    Example usage:
        >>> loader = ConfigLoader()
        >>> global_config = loader.load_global_config_from_env()
        >>> rate_config = loader.load_rate_limit_config_from_json("config.json")
    """

    __slots__ = ("_env_prefix",)

    # Accepted spellings of booleans (compared after strip() + lower()).
    # Anything else is rejected instead of silently becoming False, so a typo
    # such as "ture" cannot quietly switch a flag like `enabled` off. The empty
    # string stays False for compatibility with bare `KEY=` lines.
    _TRUE_STRINGS = frozenset(("true", "1", "yes", "on"))
    _FALSE_STRINGS = frozenset(("false", "0", "no", "off", ""))

    def __init__(self, env_prefix: str = ENV_PREFIX) -> None:
        """Initialize the config loader.

        Args:
            env_prefix: Prefix for environment variables. Defaults to "FASTAPI_TRAFFIC_".
        """
        self._env_prefix = env_prefix

    def _parse_value(self, value: str, target_type: type[Any]) -> Any:
        """Parse a string value to the target type.

        Booleans accept true/1/yes/on and false/0/no/off (case-insensitive,
        surrounding whitespace ignored; an empty string is False). Sets are
        parsed from comma-separated items with blank items dropped. Algorithm
        names are matched case-insensitively.

        Args:
            value: The string value to parse.
            target_type: The target type to convert to.

        Returns:
            The parsed value.

        Raises:
            ConfigurationError: If the value cannot be parsed or the target
                type is not supported.
        """
        try:
            if target_type is bool:
                normalized = value.strip().lower()
                if normalized in self._TRUE_STRINGS:
                    return True
                if normalized in self._FALSE_STRINGS:
                    return False
                # Re-raised below as ConfigurationError with the raw value.
                reason = "expected one of true/false, 1/0, yes/no, on/off"
                raise ValueError(reason)
            if target_type is int:
                return int(value)
            if target_type is float:
                return float(value)
            if target_type is str:
                return value
            if target_type is Algorithm:
                return Algorithm(value.strip().lower())
            if target_type is set:
                # Comma-separated values; blank/whitespace-only items are dropped,
                # so an empty string yields an empty set.
                return {item.strip() for item in value.split(",") if item.strip()}
        except (ValueError, KeyError) as e:
            msg = f"Cannot parse value '{value}' as {target_type.__name__}: {e}"
            raise ConfigurationError(msg) from e
        msg = f"Unsupported type: {target_type}"
        raise ConfigurationError(msg)

    def _convert_field(self, key: str, value: Any, target_type: type[Any]) -> Any:
        """Convert one raw (string or JSON-typed) value to *target_type*.

        Args:
            key: Field name, used in error messages.
            value: Raw value from an env var, .env file or JSON document.
            target_type: The expected type for this field.

        Returns:
            The converted value.

        Raises:
            ConfigurationError: If the value has the wrong type or cannot be parsed.
        """
        if isinstance(value, str):
            return self._parse_value(value, target_type)
        if target_type is set and isinstance(value, list):
            # Items must be strings: anything else can never match a client
            # IP/path, and unhashable items (nested lists/objects) would raise
            # a raw TypeError from set().
            if all(isinstance(item, str) for item in value):
                return set(value)
            msg = f"Invalid type for '{key}': expected a list of strings"
            raise ConfigurationError(msg)
        # bool is a subclass of int, so without this guard a JSON `true` would
        # be accepted as limit=1 (or window_size=1.0) for numeric fields.
        stray_bool = isinstance(value, bool) and target_type is not bool
        if not stray_bool:
            if isinstance(value, target_type):
                return value
            if target_type is float and isinstance(value, int):
                return float(value)
        msg = f"Invalid type for '{key}': expected {target_type.__name__}, got {type(value).__name__}"
        raise ConfigurationError(msg)

    def _validate_and_convert(
        self,
        data: Mapping[str, Any],
        field_types: dict[str, type[Any]],
    ) -> dict[str, Any]:
        """Validate and convert configuration data.

        Args:
            data: Raw configuration data.
            field_types: Mapping of field names to their expected types.

        Returns:
            Validated and converted configuration dictionary.

        Raises:
            ConfigurationError: If a field is non-loadable, unknown, or has an
                invalid value.
        """
        result: dict[str, Any] = {}
        for key, value in data.items():
            if key in _NON_LOADABLE_FIELDS:
                msg = f"Field '{key}' cannot be loaded from configuration files"
                raise ConfigurationError(msg)
            if key not in field_types:
                msg = f"Unknown configuration field: '{key}'"
                raise ConfigurationError(msg)
            result[key] = self._convert_field(key, value, field_types[key])
        return result

    def _load_dotenv_file(self, file_path: Path) -> dict[str, str]:
        """Load environment variables from a .env file.

        Blank lines and lines starting with '#' are skipped. Each remaining
        line must be ``KEY=value`` (optionally prefixed with ``export``); one
        pair of matching surrounding quotes is removed from the value.

        Args:
            file_path: Path to the .env file.

        Returns:
            Dictionary of environment variable names to values.

        Raises:
            ConfigurationError: If the file is missing, unreadable, not valid
                UTF-8, or contains a line without '='.
        """
        if not file_path.exists():
            msg = f"Configuration file not found: {file_path}"
            raise ConfigurationError(msg)
        env_vars: dict[str, str] = {}
        try:
            with file_path.open(encoding="utf-8") as f:
                for line_num, raw_line in enumerate(f, start=1):
                    line = raw_line.strip()
                    # Skip empty lines and comments
                    if not line or line.startswith("#"):
                        continue
                    if "=" not in line:
                        msg = f"Invalid line {line_num} in {file_path}: missing '='"
                        raise ConfigurationError(msg)
                    key, _, value = line.partition("=")
                    key = key.strip()
                    # Accept shell-style `export KEY=value` lines, which many
                    # .env files use so they can also be `source`d.
                    key_parts = key.split(None, 1)
                    if len(key_parts) == 2 and key_parts[0] == "export":
                        key = key_parts[1]
                    value = value.strip()
                    # Remove one pair of matching surrounding quotes if present
                    if (
                        len(value) >= 2
                        and value[0] == value[-1]
                        and value[0] in ('"', "'")
                    ):
                        value = value[1:-1]
                    env_vars[key] = value
        except UnicodeDecodeError as e:
            # Not an OSError, so it needs its own handler to stay a ConfigurationError.
            msg = f"Configuration file {file_path} is not valid UTF-8: {e}"
            raise ConfigurationError(msg) from e
        except OSError as e:
            msg = f"Failed to read configuration file {file_path}: {e}"
            raise ConfigurationError(msg) from e
        return env_vars

    def _load_json_file(self, file_path: Path) -> Any:
        """Load configuration from a JSON file.

        Args:
            file_path: Path to the JSON file.

        Returns:
            Parsed JSON data (could be any JSON type).

        Raises:
            ConfigurationError: If the file is missing, unreadable, not valid
                UTF-8, or not valid JSON.
        """
        if not file_path.exists():
            msg = f"Configuration file not found: {file_path}"
            raise ConfigurationError(msg)
        try:
            with file_path.open(encoding="utf-8") as f:
                data: Any = json.load(f)
        except json.JSONDecodeError as e:
            msg = f"Invalid JSON in {file_path}: {e}"
            raise ConfigurationError(msg) from e
        except UnicodeDecodeError as e:
            # Raised while reading the text, before JSON parsing even starts.
            msg = f"Configuration file {file_path} is not valid UTF-8: {e}"
            raise ConfigurationError(msg) from e
        except OSError as e:
            msg = f"Failed to read configuration file {file_path}: {e}"
            raise ConfigurationError(msg) from e
        return data

    def _extract_env_config(
        self,
        prefix: str,
        field_types: dict[str, type[Any]],
        env_source: Mapping[str, str] | None = None,
    ) -> dict[str, str]:
        """Extract configuration from environment variables.

        Variables named ``<env_prefix><prefix><FIELD>`` are collected; the field
        part is lower-cased and kept only if it appears in *field_types*.

        Args:
            prefix: The prefix to look for (e.g., "RATE_LIMIT_" or "GLOBAL_").
            field_types: Mapping of field names to their expected types.
            env_source: Optional source of environment variables. Defaults to os.environ.

        Returns:
            Dictionary of field names to their string values.
        """
        source = os.environ if env_source is None else env_source
        full_prefix = f"{self._env_prefix}{prefix}"
        start = len(full_prefix)
        result: dict[str, str] = {}
        for key, value in source.items():
            if not key.startswith(full_prefix):
                continue
            field_name = key[start:].lower()
            if field_name in field_types:
                result[field_name] = value
        return result

    @staticmethod
    def _apply_overrides(
        config_dict: dict[str, Any],
        overrides: Mapping[str, Any],
        field_types: dict[str, type[Any]],
    ) -> None:
        """Merge keyword overrides into *config_dict* in place.

        Keys that are loadable fields or non-loadable fields (callables,
        backends) are applied; any other key is ignored.
        """
        for key, value in overrides.items():
            if key in _NON_LOADABLE_FIELDS or key in field_types:
                config_dict[key] = value

    @staticmethod
    def _require_json_object(raw_config: Any) -> dict[str, Any]:
        """Return *raw_config* if it is a JSON object, else raise ConfigurationError."""
        if not isinstance(raw_config, dict):
            msg = "JSON root must be an object"
            raise ConfigurationError(msg)
        return raw_config

    @staticmethod
    def _build_rate_limit_config(
        config_dict: dict[str, Any], source_name: str
    ) -> RateLimitConfig:
        """Check that 'limit' is present, then build the RateLimitConfig.

        Args:
            config_dict: Converted fields plus overrides.
            source_name: Human-readable source ("environment", "JSON") for the
                error message.

        Raises:
            ConfigurationError: If the required 'limit' field is missing.
        """
        if "limit" not in config_dict:
            msg = f"Required field 'limit' not found in {source_name} configuration"
            raise ConfigurationError(msg)
        return RateLimitConfig(**config_dict)

    def load_rate_limit_config_from_env(
        self,
        env_source: Mapping[str, str] | None = None,
        **overrides: Any,
    ) -> RateLimitConfig:
        """Load RateLimitConfig from environment variables.

        Environment variables should be prefixed with FASTAPI_TRAFFIC_RATE_LIMIT_
        (e.g., FASTAPI_TRAFFIC_RATE_LIMIT_LIMIT=100).

        Args:
            env_source: Optional source of environment variables. Defaults to os.environ.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured RateLimitConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        raw_config = self._extract_env_config(
            "RATE_LIMIT_", _RATE_LIMIT_FIELD_TYPES, env_source
        )
        config_dict = self._validate_and_convert(raw_config, _RATE_LIMIT_FIELD_TYPES)
        self._apply_overrides(config_dict, overrides, _RATE_LIMIT_FIELD_TYPES)
        return self._build_rate_limit_config(config_dict, "environment")

    def load_rate_limit_config_from_dotenv(
        self,
        file_path: str | Path,
        **overrides: Any,
    ) -> RateLimitConfig:
        """Load RateLimitConfig from a .env file.

        Args:
            file_path: Path to the .env file.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured RateLimitConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        env_vars = self._load_dotenv_file(Path(file_path))
        return self.load_rate_limit_config_from_env(env_vars, **overrides)

    def load_rate_limit_config_from_json(
        self,
        file_path: str | Path,
        **overrides: Any,
    ) -> RateLimitConfig:
        """Load RateLimitConfig from a JSON file.

        Args:
            file_path: Path to the JSON file (root must be an object).
            **overrides: Additional values to override loaded config.

        Returns:
            Configured RateLimitConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        raw_config = self._require_json_object(self._load_json_file(Path(file_path)))
        config_dict = self._validate_and_convert(raw_config, _RATE_LIMIT_FIELD_TYPES)
        self._apply_overrides(config_dict, overrides, _RATE_LIMIT_FIELD_TYPES)
        return self._build_rate_limit_config(config_dict, "JSON")

    def load_global_config_from_env(
        self,
        env_source: Mapping[str, str] | None = None,
        **overrides: Any,
    ) -> GlobalConfig:
        """Load GlobalConfig from environment variables.

        Environment variables should be prefixed with FASTAPI_TRAFFIC_GLOBAL_
        (e.g., FASTAPI_TRAFFIC_GLOBAL_ENABLED=true).

        Args:
            env_source: Optional source of environment variables. Defaults to os.environ.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured GlobalConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        raw_config = self._extract_env_config(
            "GLOBAL_", _GLOBAL_FIELD_TYPES, env_source
        )
        config_dict = self._validate_and_convert(raw_config, _GLOBAL_FIELD_TYPES)
        self._apply_overrides(config_dict, overrides, _GLOBAL_FIELD_TYPES)
        return GlobalConfig(**config_dict)

    def load_global_config_from_dotenv(
        self,
        file_path: str | Path,
        **overrides: Any,
    ) -> GlobalConfig:
        """Load GlobalConfig from a .env file.

        Args:
            file_path: Path to the .env file.
            **overrides: Additional values to override loaded config.

        Returns:
            Configured GlobalConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid.
        """
        env_vars = self._load_dotenv_file(Path(file_path))
        return self.load_global_config_from_env(env_vars, **overrides)

    def load_global_config_from_json(
        self,
        file_path: str | Path,
        **overrides: Any,
    ) -> GlobalConfig:
        """Load GlobalConfig from a JSON file.

        Args:
            file_path: Path to the JSON file (root must be an object).
            **overrides: Additional values to override loaded config.

        Returns:
            Configured GlobalConfig instance.

        Raises:
            ConfigurationError: If configuration is invalid, including a JSON
                root that is not an object.
        """
        # Same root check as the rate-limit loader: a list/scalar root would
        # otherwise fail with AttributeError inside _validate_and_convert.
        raw_config = self._require_json_object(self._load_json_file(Path(file_path)))
        config_dict = self._validate_and_convert(raw_config, _GLOBAL_FIELD_TYPES)
        self._apply_overrides(config_dict, overrides, _GLOBAL_FIELD_TYPES)
        return GlobalConfig(**config_dict)
# Shared loader behind the module-level convenience functions; created lazily
# on first use so importing this module has no side effects.
_default_loader: ConfigLoader | None = None


def _get_default_loader() -> ConfigLoader:
    """Return the shared module-level ConfigLoader, creating it on first call."""
    global _default_loader
    if _default_loader is not None:
        return _default_loader
    _default_loader = ConfigLoader()
    return _default_loader
def load_rate_limit_config(
    source: str | Path,
    **overrides: Any,
) -> RateLimitConfig:
    """Load a RateLimitConfig from a file, picking the parser from its name.

    A ``.json`` suffix (any case) selects JSON. A ``.env`` suffix, no suffix,
    or a file name starting with ``.env`` (e.g. ``.env.local``) selects dotenv.

    Args:
        source: Path to configuration file (.env or .json).
        **overrides: Additional values to override loaded config.

    Returns:
        Configured RateLimitConfig instance.

    Raises:
        ConfigurationError: If configuration is invalid or format unknown.
    """
    loader = _get_default_loader()
    path = Path(source)
    suffix = path.suffix.lower()
    if suffix == ".json":
        return loader.load_rate_limit_config_from_json(path, **overrides)
    is_dotenv = suffix in (".env", "") or path.name.startswith(".env")
    if not is_dotenv:
        msg = f"Unknown configuration file format: {path.suffix}"
        raise ConfigurationError(msg)
    return loader.load_rate_limit_config_from_dotenv(path, **overrides)
def load_global_config(
    source: str | Path,
    **overrides: Any,
) -> GlobalConfig:
    """Load a GlobalConfig from a file, picking the parser from its name.

    A ``.json`` suffix (any case) selects JSON. A ``.env`` suffix, no suffix,
    or a file name starting with ``.env`` (e.g. ``.env.local``) selects dotenv.

    Args:
        source: Path to configuration file (.env or .json).
        **overrides: Additional values to override loaded config.

    Returns:
        Configured GlobalConfig instance.

    Raises:
        ConfigurationError: If configuration is invalid or format unknown.
    """
    loader = _get_default_loader()
    path = Path(source)
    suffix = path.suffix.lower()
    if suffix == ".json":
        return loader.load_global_config_from_json(path, **overrides)
    is_dotenv = suffix in (".env", "") or path.name.startswith(".env")
    if not is_dotenv:
        msg = f"Unknown configuration file format: {path.suffix}"
        raise ConfigurationError(msg)
    return loader.load_global_config_from_dotenv(path, **overrides)
def load_rate_limit_config_from_env(**overrides: Any) -> RateLimitConfig:
    """Build a RateLimitConfig from FASTAPI_TRAFFIC_RATE_LIMIT_* variables.

    Delegates to the shared module-level ConfigLoader.

    Args:
        **overrides: Additional values to override loaded config.

    Returns:
        Configured RateLimitConfig instance.
    """
    loader = _get_default_loader()
    return loader.load_rate_limit_config_from_env(**overrides)
def load_global_config_from_env(**overrides: Any) -> GlobalConfig:
    """Build a GlobalConfig from FASTAPI_TRAFFIC_GLOBAL_* variables.

    Delegates to the shared module-level ConfigLoader.

    Args:
        **overrides: Additional values to override loaded config.

    Returns:
        Configured GlobalConfig instance.
    """
    loader = _get_default_loader()
    return loader.load_global_config_from_env(**overrides)