Files
fastapi-traffic/tests/conftest.py

194 lines
5.4 KiB
Python

"""Shared fixtures and configuration for tests."""
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
import pytest
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from httpx import ASGITransport, AsyncClient
from fastapi_traffic import (
Algorithm,
MemoryBackend,
RateLimiter,
RateLimitExceeded,
SQLiteBackend,
rate_limit,
)
from fastapi_traffic.core.config import RateLimitConfig
from fastapi_traffic.core.limiter import set_limiter
from fastapi_traffic.middleware import RateLimitMiddleware
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Generator
pass
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Create one event loop shared by every test in the session.

    On teardown the loop first finalizes any still-open async generators
    (``loop.shutdown_asyncgens()``) and only then closes. Closing without
    that step leaves async generators unfinalized.

    NOTE(review): newer pytest-asyncio releases deprecate overriding the
    ``event_loop`` fixture in favour of loop-scope configuration; verify
    against the pinned pytest-asyncio version.
    """
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        try:
            # Mirror the cleanup asyncio.run() performs before closing.
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()
@pytest.fixture
async def memory_backend() -> AsyncGenerator[MemoryBackend, None]:
    """Yield a brand-new ``MemoryBackend`` for a single test.

    The backend is built with ``max_size=1000`` and ``cleanup_interval=60.0``
    and is closed once the test has finished.
    """
    mem = MemoryBackend(cleanup_interval=60.0, max_size=1000)
    yield mem
    # Teardown: release whatever the backend holds.
    await mem.close()
@pytest.fixture
async def sqlite_backend() -> AsyncGenerator[SQLiteBackend, None]:
    """Yield an initialized in-memory SQLite backend for a single test.

    The backend uses the ``":memory:"`` database path, so no file (and no
    temporary directory) is needed. ``initialize()`` is awaited before the
    test runs and ``close()`` after it finishes.
    """
    backend = SQLiteBackend(":memory:", cleanup_interval=60.0)
    await backend.initialize()
    yield backend
    await backend.close()
@pytest.fixture
async def limiter(memory_backend: MemoryBackend) -> AsyncGenerator[RateLimiter, None]:
    """Yield an initialized ``RateLimiter`` backed by ``memory_backend``.

    The limiter is also installed via ``set_limiter`` so code that relies on
    the module-level limiter picks up this instance during the test.
    """
    rate_limiter = RateLimiter(memory_backend)
    await rate_limiter.initialize()
    # NOTE(review): the global set here is not reset after close(); code that
    # reads it after this fixture's teardown would see a closed limiter.
    set_limiter(rate_limiter)
    yield rate_limiter
    # NOTE(review): memory_backend's own fixture also closes the backend; if
    # RateLimiter.close() closes it too, backend.close() must be idempotent.
    await rate_limiter.close()
@pytest.fixture
def rate_limit_config() -> RateLimitConfig:
    """Return a sliding-window-counter config allowing 10 hits per 60.0 window."""
    config = RateLimitConfig(
        algorithm=Algorithm.SLIDING_WINDOW_COUNTER,
        limit=10,
        window_size=60.0,
    )
    return config
@pytest.fixture
def app(limiter: RateLimiter) -> FastAPI:
    """Create a FastAPI app with rate limiting configured.

    ``limiter`` is not referenced in the body. Depending on it makes pytest
    run the ``limiter`` fixture (which calls ``set_limiter``) before this app
    is built. Presumably the ``@rate_limit`` decorators look up that global
    limiter at request time; confirm against ``fastapi_traffic``.

    Routes registered:
        * ``GET /limited``: ``rate_limit(5, 60)``.
        * ``GET /unlimited``: no rate limit.
        * ``GET /custom-key``: limit 5, ``window_size=60``, keyed by the
          ``X-API-Key`` header (``"anon"`` when the header is absent).
        * ``GET /token-bucket``: limit 10, ``window_size=60``,
          ``Algorithm.TOKEN_BUCKET`` with ``burst_size=5``.
        * ``GET /high-cost``: limit 10, ``window_size=60``, ``cost=3``.

    A ``RateLimitExceeded`` raised by any route is turned into a 429 JSON
    response by the exception handler registered below.
    """
    app = FastAPI()

    # Turn RateLimitExceeded into a 429 response. The JSON body carries the
    # exception's message and retry_after. The headers come from
    # exc.limit_info.to_headers(), or are empty when limit_info is falsy.
    @app.exception_handler(RateLimitExceeded)
    async def rate_limit_handler(
        request: Request, exc: RateLimitExceeded
    ) -> JSONResponse:
        return JSONResponse(
            status_code=429,
            content={
                "detail": exc.message,
                "retry_after": exc.retry_after,
            },
            headers=exc.limit_info.to_headers() if exc.limit_info else {},
        )

    # Decorators apply bottom-up: @rate_limit wraps the handler first, then
    # @app.get registers that wrapped function, so FastAPI serves the
    # rate-limited version. The unused ``request`` parameters are presumably
    # how @rate_limit locates the incoming request. NOTE(review): confirm
    # before removing them.
    @app.get("/limited")
    @rate_limit(5, 60)
    async def limited_endpoint(request: Request) -> dict[str, str]:
        return {"message": "success"}

    # Control route with no rate limit applied.
    @app.get("/unlimited")
    async def unlimited_endpoint() -> dict[str, str]:
        return {"message": "no limit"}

    # Key extractor: returns the X-API-Key header value, or "anon" when the
    # header is missing. Presumably requests with the same returned key are
    # counted together; verify against the rate_limit implementation.
    def api_key_extractor(request: Request) -> str:
        return request.headers.get("X-API-Key", "anon")

    @app.get("/custom-key")
    @rate_limit(5, window_size=60, key_extractor=api_key_extractor)
    async def custom_key_endpoint(request: Request) -> dict[str, str]:
        return {"message": "success"}

    @app.get("/token-bucket")
    @rate_limit(10, window_size=60, algorithm=Algorithm.TOKEN_BUCKET, burst_size=5)
    async def token_bucket_endpoint(request: Request) -> dict[str, str]:
        return {"message": "success"}

    # cost=3 presumably charges 3 units of the limit per request. Confirm.
    @app.get("/high-cost")
    @rate_limit(10, window_size=60, cost=3)
    async def high_cost_endpoint(request: Request) -> dict[str, str]:
        return {"message": "success"}

    return app
@pytest.fixture
async def client(app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
    """Yield an ``httpx.AsyncClient`` that calls ``app`` in-process.

    Requests are dispatched through ``ASGITransport`` rather than a socket.
    Relative URLs resolve against ``http://test``. The client is closed by
    the ``async with`` block when the test finishes.
    """
    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as http_client:
        yield http_client
@pytest.fixture
def app_with_middleware(memory_backend: MemoryBackend) -> FastAPI:
    """Build a FastAPI app whose every route passes through ``RateLimitMiddleware``.

    The middleware allows ``limit=10`` per ``window_size=60`` and stores state
    in ``memory_backend``. ``/health`` and client IP ``192.168.1.100`` are
    exempt. Two routes are exposed: ``GET /api/resource`` and ``GET /health``.
    """
    application = FastAPI()

    middleware_options = {
        "limit": 10,
        "window_size": 60,
        "backend": memory_backend,
        "exempt_paths": {"/health"},
        "exempt_ips": {"192.168.1.100"},
    }
    application.add_middleware(RateLimitMiddleware, **middleware_options)

    @application.get("/api/resource")
    async def resource() -> dict[str, str]:
        return {"message": "success"}

    @application.get("/health")
    async def health() -> dict[str, str]:
        return {"status": "ok"}

    return application
@pytest.fixture
async def middleware_client(
    app_with_middleware: FastAPI,
) -> AsyncGenerator[AsyncClient, None]:
    """Yield an in-process ``httpx.AsyncClient`` for the middleware-wrapped app.

    Same wiring as the ``client`` fixture (``ASGITransport``, base URL
    ``http://test``) but targeting ``app_with_middleware``.
    """
    async with AsyncClient(
        transport=ASGITransport(app=app_with_middleware),
        base_url="http://test",
    ) as http_client:
        yield http_client
class MockRequest:
    """Lightweight stand-in for a request object in unit tests.

    Provides just ``url.path``, ``method``, ``client.host``, a ``headers``
    mapping (a plain, case-sensitive ``dict``) and a ``get`` helper that
    reads from those headers.
    """

    def __init__(
        self,
        path: str = "/test",
        method: str = "GET",
        client_host: str = "127.0.0.1",
        headers: dict[str, str] | None = None,
    ) -> None:
        # Ad-hoc classes with one class attribute each, so attribute access
        # like ``request.url.path`` / ``request.client.host`` works.
        url_cls = type("URL", (), {"path": path})
        client_cls = type("Client", (), {"host": client_host})
        self.url = url_cls()
        self.method = method
        self.client = client_cls()
        # None (or any falsy mapping) is replaced by a fresh empty dict;
        # a non-empty dict is stored by reference, not copied.
        self._headers = {} if not headers else headers

    @property
    def headers(self) -> dict[str, str]:
        """Return the header mapping held by this mock."""
        return self._headers

    def get(self, key: str, default: str | None = None) -> str | None:
        """Return header *key*, or *default* when it is not present."""
        stored = self._headers
        return stored.get(key, default)
@pytest.fixture
def mock_request() -> MockRequest:
    """Provide a default ``MockRequest``: ``GET /test`` from ``127.0.0.1``, no headers."""
    request = MockRequest()
    return request