302 lines
10 KiB
Python
302 lines
10 KiB
Python
"""SQLite backend for rate limiting - persistent storage for single-node deployments."""
|
|
|
|
from __future__ import annotations

import asyncio
import contextlib
import json
import logging
import sqlite3
import time
from typing import TYPE_CHECKING, Any

from fastapi_traffic.backends.base import Backend
from fastapi_traffic.exceptions import BackendError

if TYPE_CHECKING:
    from collections.abc import Callable
    from pathlib import Path
|
|
|
|
|
|
class SQLiteBackend(Backend):
    """SQLite-based rate-limit backend with async support.

    State is kept in a single ``rate_limits`` table holding, per key, a JSON
    payload and an absolute expiry timestamp (``time.time()`` based).  One
    shared connection is opened lazily; every blocking ``sqlite3`` call runs
    in the event loop's default executor.

    The connection is created with ``check_same_thread=False`` and is
    therefore touched from several executor threads.  To keep that safe, all
    statements are funnelled through :meth:`_run_serialized`, which holds a
    single ``asyncio.Lock`` for the duration of each operation.

    NOTE: ``pool_size`` is accepted and stored, and ``_connections`` is
    closed by :meth:`close`, but nothing in this class ever opens pooled
    connections -- only the one shared connection is used.
    """

    __slots__ = (
        "_cleanup_interval",
        "_cleanup_task",
        "_connection",
        "_connections",
        "_db_path",
        "_lock",
        "_pool_size",
    )

    def __init__(
        self,
        db_path: str | Path = ":memory:",
        *,
        cleanup_interval: float = 300.0,
        pool_size: int = 5,
    ) -> None:
        """Initialize the SQLite backend.

        No database work happens here; the connection is opened on first use
        (or by :meth:`initialize`).

        Args:
            db_path: Path to SQLite database file or ":memory:" for in-memory.
            cleanup_interval: Interval in seconds for cleaning expired entries.
            pool_size: Number of connections in the pool.  Stored only; the
                class currently never opens more than one connection.
        """
        self._db_path = str(db_path)
        self._connection: sqlite3.Connection | None = None
        # Serializes every statement issued on the shared connection.
        self._lock = asyncio.Lock()
        self._cleanup_interval = cleanup_interval
        self._cleanup_task: asyncio.Task[None] | None = None
        self._pool_size = pool_size
        self._connections: list[sqlite3.Connection] = []

    async def initialize(self) -> None:
        """Open the connection, create the schema and start periodic cleanup.

        Safe to call more than once: the schema statements use
        ``IF NOT EXISTS`` and the cleanup task is only started when none is
        registered yet.
        """
        await self._ensure_connection()
        await self._create_tables()
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def _ensure_connection(self) -> sqlite3.Connection:
        """Return the shared connection, opening it on first use.

        The connection is created in the default executor.  Because this
        coroutine yields while the connection is being opened, two callers
        can both observe ``_connection is None``; the attribute is therefore
        re-checked after the await and the losing connection is closed
        instead of silently replacing the winner (which, for ``:memory:``,
        would swap in a different, empty database).
        """
        existing = self._connection
        if existing is not None:
            return existing

        loop = asyncio.get_running_loop()
        created = await loop.run_in_executor(None, self._create_connection)
        if self._connection is None:
            self._connection = created
            return created

        # Another task finished first while we were awaiting; keep theirs.
        created.close()
        return self._connection

    def _create_connection(self) -> sqlite3.Connection:
        """Create a new SQLite connection with optimized settings.

        ``isolation_level=None`` puts the connection in autocommit mode and
        ``check_same_thread=False`` allows it to be used from executor
        threads.  If any PRAGMA fails the half-configured connection is
        closed before the error propagates.
        """
        conn = sqlite3.connect(
            self._db_path,
            check_same_thread=False,
            isolation_level=None,
        )
        try:
            for pragma in (
                "PRAGMA journal_mode=WAL",
                "PRAGMA synchronous=NORMAL",
                "PRAGMA cache_size=10000",
                "PRAGMA temp_store=MEMORY",
            ):
                conn.execute(pragma)
        except sqlite3.Error:
            conn.close()
            raise
        # Rows are accessed by column name throughout this class.
        conn.row_factory = sqlite3.Row
        return conn

    async def _run_serialized(self, op: Callable[[sqlite3.Connection], Any]) -> Any:
        """Run ``op(connection)`` in the default executor while holding the lock.

        Holding ``_lock`` for the whole call guarantees that only one
        operation at a time uses the shared connection, which also makes
        multi-statement operations (such as :meth:`increment`) atomic with
        respect to every other method of this class.

        Args:
            op: Blocking callable that receives the shared connection.

        Returns:
            Whatever ``op`` returns.  Exceptions raised by ``op`` propagate
            unchanged; callers wrap them as appropriate.
        """
        async with self._lock:
            conn = await self._ensure_connection()
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, op, conn)

    async def _create_tables(self) -> None:
        """Create the rate limit tables."""
        await self._run_serialized(self._create_tables_sync)

    def _create_tables_sync(self, conn: sqlite3.Connection) -> None:
        """Synchronously create the ``rate_limits`` table and its expiry index."""
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS rate_limits (
                key TEXT PRIMARY KEY,
                data TEXT NOT NULL,
                expires_at REAL NOT NULL
            )
            """
        )
        # The index backs the periodic "expires_at <= now" sweep.
        conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_expires_at ON rate_limits(expires_at)
            """
        )

    async def _cleanup_loop(self) -> None:
        """Background task that purges expired entries every ``cleanup_interval`` s.

        The loop ends quietly when the task is cancelled.  Any other failure
        is logged and the loop keeps running, so one bad sweep does not stop
        future cleanups.
        """
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_expired()
            except asyncio.CancelledError:
                break
            except Exception:
                # Best effort: report the failure instead of hiding it, then retry
                # on the next interval.
                logging.getLogger(__name__).exception(
                    "Periodic cleanup of expired rate limit entries failed"
                )

    async def _cleanup_expired(self) -> None:
        """Remove every entry whose expiry time has passed.

        Raises:
            BackendError: If the delete statement fails.
        """

        def _purge(conn: sqlite3.Connection) -> None:
            conn.execute(
                "DELETE FROM rate_limits WHERE expires_at <= ?", (time.time(),)
            )

        try:
            await self._run_serialized(_purge)
        except Exception as exc:
            raise BackendError(
                "Failed to cleanup expired entries", original_error=exc
            ) from exc

    async def get(self, key: str) -> dict[str, Any] | None:
        """Get the current state for a key.

        Args:
            key: Rate-limit key to look up.

        Returns:
            The decoded JSON payload, or ``None`` when the key is missing or
            expired.  An expired row is deleted as a side effect.

        Raises:
            BackendError: If the lookup or JSON decoding fails.
        """

        def _get(conn: sqlite3.Connection) -> dict[str, Any] | None:
            row = conn.execute(
                "SELECT data, expires_at FROM rate_limits WHERE key = ?",
                (key,),
            ).fetchone()
            if row is None:
                return None

            now = time.time()
            if row["expires_at"] <= now:
                # Only remove the row if it is still expired, so a value that
                # was refreshed after the SELECT is never deleted.
                conn.execute(
                    "DELETE FROM rate_limits WHERE key = ? AND expires_at <= ?",
                    (key, now),
                )
                return None

            data: dict[str, Any] = json.loads(row["data"])
            return data

        try:
            result: dict[str, Any] | None = await self._run_serialized(_get)
        except Exception as exc:
            raise BackendError(f"Failed to get key {key}", original_error=exc) from exc
        return result

    async def set(self, key: str, value: dict[str, Any], *, ttl: float) -> None:
        """Set the state for a key with TTL.

        Args:
            key: Rate-limit key to write.
            value: JSON-serializable payload.
            ttl: Lifetime in seconds, added to the current time to form the
                stored expiry timestamp.

        Raises:
            BackendError: If serialization or the write fails.
        """
        try:
            expires_at = time.time() + ttl
            payload = json.dumps(value)

            def _set(conn: sqlite3.Connection) -> None:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO rate_limits (key, data, expires_at)
                    VALUES (?, ?, ?)
                    """,
                    (key, payload, expires_at),
                )

            await self._run_serialized(_set)
        except Exception as exc:
            raise BackendError(f"Failed to set key {key}", original_error=exc) from exc

    async def delete(self, key: str) -> None:
        """Delete the state for a key (a missing key is not an error).

        Raises:
            BackendError: If the delete statement fails.
        """

        def _delete(conn: sqlite3.Connection) -> None:
            conn.execute("DELETE FROM rate_limits WHERE key = ?", (key,))

        try:
            await self._run_serialized(_delete)
        except Exception as exc:
            raise BackendError(
                f"Failed to delete key {key}", original_error=exc
            ) from exc

    async def exists(self, key: str) -> bool:
        """Check if a key exists and is not expired.

        Raises:
            BackendError: If the query fails.
        """

        def _exists(conn: sqlite3.Connection) -> bool:
            cursor = conn.execute(
                "SELECT 1 FROM rate_limits WHERE key = ? AND expires_at > ?",
                (key, time.time()),
            )
            return cursor.fetchone() is not None

        try:
            found: bool = await self._run_serialized(_exists)
        except Exception as exc:
            raise BackendError(
                f"Failed to check key {key}", original_error=exc
            ) from exc
        return found

    async def increment(self, key: str, amount: int = 1) -> int:
        """Atomically increment the ``"count"`` field stored for a key.

        The read-modify-write runs under the backend lock, so it cannot
        interleave with any other operation issued through this instance.
        The row's expiry timestamp is left unchanged.

        Args:
            key: Rate-limit key whose counter is incremented.
            amount: Value added to the current count.

        Returns:
            The new count.  When the key is missing or expired, ``amount`` is
            returned and nothing is written.

        Raises:
            BackendError: If the read, decode or update fails.
        """

        def _increment(conn: sqlite3.Connection) -> int:
            row = conn.execute(
                "SELECT data, expires_at FROM rate_limits WHERE key = ?",
                (key,),
            ).fetchone()

            # NOTE(review): no row is created here because no TTL is available;
            # presumably the caller follows up with set() -- confirm.
            if row is None or row["expires_at"] <= time.time():
                return amount

            data: dict[str, Any] = json.loads(row["data"])
            new_value = int(data.get("count", 0)) + amount
            data["count"] = new_value

            conn.execute(
                "UPDATE rate_limits SET data = ? WHERE key = ?",
                (json.dumps(data), key),
            )
            return new_value

        try:
            updated: int = await self._run_serialized(_increment)
        except Exception as exc:
            raise BackendError(
                f"Failed to increment key {key}", original_error=exc
            ) from exc
        return updated

    async def clear(self) -> None:
        """Clear all rate limit data.

        Raises:
            BackendError: If the delete statement fails.
        """

        def _clear(conn: sqlite3.Connection) -> None:
            conn.execute("DELETE FROM rate_limits")

        try:
            await self._run_serialized(_clear)
        except Exception as exc:
            raise BackendError(
                "Failed to clear rate limits", original_error=exc
            ) from exc

    async def close(self) -> None:
        """Stop the cleanup task and close every open connection.

        The backend lock is taken before closing so that an operation that is
        currently using the shared connection can finish first.
        """
        if self._cleanup_task is not None:
            self._cleanup_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._cleanup_task
            self._cleanup_task = None

        async with self._lock:
            if self._connection is not None:
                self._connection.close()
                self._connection = None

            for conn in self._connections:
                conn.close()
            self._connections.clear()

    async def vacuum(self) -> None:
        """Optimize the database by running VACUUM.

        Raises:
            BackendError: If the VACUUM statement fails.
        """

        def _vacuum(conn: sqlite3.Connection) -> None:
            conn.execute("VACUUM")

        try:
            await self._run_serialized(_vacuum)
        except Exception as exc:
            raise BackendError("Failed to vacuum database", original_error=exc) from exc

    async def get_stats(self) -> dict[str, Any]:
        """Get statistics about the rate limit storage.

        Returns:
            A dict with ``total_entries``, ``active_entries`` (not yet
            expired), ``expired_entries`` and ``db_path``.

        Raises:
            BackendError: If the statistics query fails.
        """

        def _stats(conn: sqlite3.Connection) -> dict[str, Any]:
            # One statement yields both counts from the same snapshot, so
            # "expired" can never come out negative.  COALESCE covers the empty
            # table, where SUM() returns NULL.
            row = conn.execute(
                """
                SELECT COUNT(*) AS total,
                       COALESCE(SUM(expires_at > ?), 0) AS active
                FROM rate_limits
                """,
                (time.time(),),
            ).fetchone()
            total = row["total"]
            active = row["active"]

            return {
                "total_entries": total,
                "active_entries": active,
                "expired_entries": total - active,
                "db_path": self._db_path,
            }

        try:
            stats: dict[str, Any] = await self._run_serialized(_stats)
        except Exception as exc:
            raise BackendError("Failed to get stats", original_error=exc) from exc
        return stats
|