Files
fastapi-traffic/fastapi_traffic/backends/memory.py
zanewalker 9fe700296d Add get_stats to MemoryBackend and update pytest config
Added get_stats() method for consistency with RedisBackend. Also added
httpx and pytest-asyncio as dev dependencies.
2026-01-09 00:50:43 +00:00

153 lines
5.0 KiB
Python

"""In-memory backend for rate limiting - suitable for single-process applications."""
from __future__ import annotations

import asyncio
import logging
import time
from collections import OrderedDict
from typing import Any

from fastapi_traffic.backends.base import Backend
class MemoryBackend(Backend):
    """In-memory backend with LRU eviction and per-key TTL support.

    All state-touching operations serialise on a single ``asyncio.Lock``.
    That makes them safe for concurrent coroutines on one event loop, but
    ``asyncio.Lock`` is not thread-safe. Do not share an instance across
    threads or event loops. State lives in process memory, so it is not
    shared between worker processes either.
    """

    __slots__ = ("_data", "_lock", "_max_size", "_cleanup_interval", "_cleanup_task")

    def __init__(
        self,
        *,
        max_size: int = 10000,
        cleanup_interval: float = 60.0,
    ) -> None:
        """Initialize the memory backend.

        Args:
            max_size: Maximum number of entries to store. Once exceeded, the
                least recently used entries are evicted.
            cleanup_interval: Seconds between sweeps of the background task
                started by :meth:`start_cleanup`.
        """
        # key -> (state dict, absolute expiry timestamp from time.time()).
        # Dict order doubles as recency order: the first item is the least
        # recently used entry and is the first to be evicted.
        self._data: OrderedDict[str, tuple[dict[str, Any], float]] = OrderedDict()
        self._lock = asyncio.Lock()
        self._max_size = max_size
        self._cleanup_interval = cleanup_interval
        self._cleanup_task: asyncio.Task[None] | None = None

    async def start_cleanup(self) -> None:
        """Start the background cleanup task unless one is already running.

        A previous task that has already finished (for example because it
        was cancelled from outside :meth:`close`) is replaced by a new one.
        """
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def _cleanup_loop(self) -> None:
        """Purge expired entries every ``cleanup_interval`` seconds until cancelled."""
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_expired()
            except asyncio.CancelledError:
                break
            except Exception:
                # Cleanup is best-effort: keep the loop alive, but leave a
                # trace in the logs instead of hiding the failure.
                logging.getLogger(__name__).exception(
                    "MemoryBackend cleanup sweep failed"
                )

    async def _cleanup_expired(self) -> None:
        """Remove every entry whose expiry time has passed."""
        now = time.time()
        async with self._lock:
            # Collect first: an OrderedDict cannot be mutated while iterating.
            expired_keys = [
                key for key, (_, expires_at) in self._data.items() if expires_at <= now
            ]
            for key in expired_keys:
                del self._data[key]

    def _evict_if_needed(self) -> None:
        """Evict least recently used entries until at most ``max_size`` remain.

        Must be called with ``self._lock`` held.
        """
        # The emptiness check stops a negative max_size from calling popitem()
        # on an empty dict, which would raise KeyError.
        while self._data and len(self._data) > self._max_size:
            self._data.popitem(last=False)

    async def get(self, key: str) -> dict[str, Any] | None:
        """Return a copy of the state for ``key``, or ``None`` if missing or expired.

        A hit marks the key as most recently used. An expired entry is
        deleted on access.
        """
        async with self._lock:
            if key not in self._data:
                return None
            value, expires_at = self._data[key]
            if expires_at <= time.time():
                del self._data[key]
                return None
            self._data.move_to_end(key)
            return value.copy()

    async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None:
        """Store a shallow copy of ``value`` under ``key`` for ``ttl`` seconds.

        The key becomes the most recently used entry. Older entries are
        evicted if the store grows past ``max_size``.
        """
        expires_at = time.time() + ttl
        async with self._lock:
            self._data[key] = (value.copy(), expires_at)
            self._data.move_to_end(key)
            self._evict_if_needed()

    async def delete(self, key: str) -> None:
        """Delete the state for ``key``. Missing keys are ignored."""
        async with self._lock:
            self._data.pop(key, None)

    async def exists(self, key: str) -> bool:
        """Return whether ``key`` exists and has not expired.

        An expired entry is deleted on access. Unlike :meth:`get`, this does
        not change the key's LRU position.
        """
        async with self._lock:
            if key not in self._data:
                return False
            _, expires_at = self._data[key]
            if expires_at <= time.time():
                del self._data[key]
                return False
            return True

    async def increment(self, key: str, amount: int = 1) -> int:
        """Atomically add ``amount`` to the ``"count"`` field of a live entry.

        Returns:
            The new count. If the key is missing or expired, nothing is
            stored, because there is no TTL to create the entry with, and
            ``amount`` is returned unchanged. An expired entry is dropped
            along the way.
        """
        async with self._lock:
            entry = self._data.get(key)
            if entry is None:
                return amount
            value, expires_at = entry
            if expires_at <= time.time():
                # Same lazy expiry as get() / exists().
                del self._data[key]
                return amount
            new_value = int(value.get("count", 0)) + amount
            # The stored dict is the private copy made by set(), so updating
            # it in place is safe and no re-insertion is needed.
            value["count"] = new_value
            # Incrementing is an access: refresh the key's LRU position so an
            # actively counted key is not the next eviction victim.
            self._data.move_to_end(key)
            return new_value

    async def clear(self) -> None:
        """Clear all rate limit data."""
        async with self._lock:
            self._data.clear()

    async def close(self) -> None:
        """Stop the cleanup task (if any) and clear all stored data."""
        if self._cleanup_task is not None:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None
        await self.clear()

    async def ping(self) -> bool:
        """Check if the backend is available. Always returns True for memory backend."""
        return True

    async def get_stats(self) -> dict[str, Any]:
        """Return storage statistics.

        ``total_keys`` counts every stored entry, including expired ones
        that have not been purged yet.
        """
        async with self._lock:
            return {
                "total_keys": len(self._data),
                "max_size": self._max_size,
                "backend": "memory",
            }

    def __len__(self) -> int:
        """Return the number of stored entries, including not-yet-purged expired ones."""
        return len(self._data)