194 lines
5.4 KiB
Python
194 lines
5.4 KiB
Python
"""Shared fixtures and configuration for tests."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import TYPE_CHECKING
|
|
|
|
import pytest
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import JSONResponse
|
|
from httpx import ASGITransport, AsyncClient
|
|
|
|
from fastapi_traffic import (
|
|
Algorithm,
|
|
MemoryBackend,
|
|
RateLimiter,
|
|
RateLimitExceeded,
|
|
SQLiteBackend,
|
|
rate_limit,
|
|
)
|
|
from fastapi_traffic.core.config import RateLimitConfig
|
|
from fastapi_traffic.core.limiter import set_limiter
|
|
from fastapi_traffic.middleware import RateLimitMiddleware
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import AsyncGenerator, Generator
|
|
|
|
pass
|
|
|
|
|
|
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide a single asyncio event loop for the whole test session.

    Yields:
        A newly created event loop, which is closed when the fixture is
        torn down.
    """
    session_loop = asyncio.new_event_loop()
    yield session_loop
    # Teardown: release the loop's resources once the fixture is finalized.
    session_loop.close()
|
|
|
|
|
|
@pytest.fixture
async def memory_backend() -> AsyncGenerator[MemoryBackend, None]:
    """Yield a new ``MemoryBackend`` for each test and close it afterwards.

    The backend is built with ``max_size=1000`` and ``cleanup_interval=60.0``.
    """
    store = MemoryBackend(cleanup_interval=60.0, max_size=1000)
    yield store
    # Teardown: close the backend after the test has finished with it.
    await store.close()
|
|
|
|
|
|
@pytest.fixture
async def sqlite_backend(tmp_path: object) -> AsyncGenerator[SQLiteBackend, None]:
    """Yield an initialized ``SQLiteBackend`` opened on ``":memory:"``.

    ``tmp_path`` is requested by the signature but is not referenced in the
    body; the database location is always the literal ``":memory:"``.
    The backend is closed once the test is done.
    """
    store = SQLiteBackend(":memory:", cleanup_interval=60.0)
    await store.initialize()
    yield store
    # Teardown: close the backend after the test has finished with it.
    await store.close()
|
|
|
|
|
|
@pytest.fixture
async def limiter(memory_backend: MemoryBackend) -> AsyncGenerator[RateLimiter, None]:
    """Yield an initialized ``RateLimiter`` built on ``memory_backend``.

    Before it is handed to the test, the instance is also passed to
    ``set_limiter``. It is closed during teardown.
    """
    rate_limiter = RateLimiter(memory_backend)
    await rate_limiter.initialize()
    set_limiter(rate_limiter)
    yield rate_limiter
    # Teardown: close the limiter after the test has finished with it.
    await rate_limiter.close()
|
|
|
|
|
|
@pytest.fixture
def rate_limit_config() -> RateLimitConfig:
    """Return the baseline ``RateLimitConfig`` used by the tests.

    The config uses ``limit=10``, ``window_size=60.0`` and the
    ``Algorithm.SLIDING_WINDOW_COUNTER`` algorithm.
    """
    settings = {
        "limit": 10,
        "window_size": 60.0,
        "algorithm": Algorithm.SLIDING_WINDOW_COUNTER,
    }
    return RateLimitConfig(**settings)
|
|
|
|
|
|
@pytest.fixture
def app(limiter: RateLimiter) -> FastAPI:
    """Build a FastAPI application whose routes use the ``rate_limit`` decorator.

    The ``limiter`` argument is not referenced in the body; requesting it
    makes pytest run that fixture before this one.

    Registered on the returned app:
        * an exception handler turning ``RateLimitExceeded`` into a 429 JSON
          response,
        * ``/limited``, ``/custom-key``, ``/token-bucket`` and ``/high-cost``,
          each wrapped with ``rate_limit``,
        * ``/unlimited``, which has no ``rate_limit`` decorator.
    """
    application = FastAPI()

    def api_key_extractor(request: Request) -> str:
        # Key requests by the X-API-Key header, with "anon" when it is absent.
        return request.headers.get("X-API-Key", "anon")

    @application.exception_handler(RateLimitExceeded)
    async def rate_limit_handler(
        request: Request, exc: RateLimitExceeded
    ) -> JSONResponse:
        payload = {
            "detail": exc.message,
            "retry_after": exc.retry_after,
        }
        # Rate-limit headers are only attached when the exception has limit info.
        if exc.limit_info:
            extra_headers = exc.limit_info.to_headers()
        else:
            extra_headers = {}
        return JSONResponse(status_code=429, content=payload, headers=extra_headers)

    @application.get("/limited")
    @rate_limit(5, 60)
    async def limited_endpoint(request: Request) -> dict[str, str]:
        return {"message": "success"}

    @application.get("/unlimited")
    async def unlimited_endpoint() -> dict[str, str]:
        return {"message": "no limit"}

    @application.get("/custom-key")
    @rate_limit(5, window_size=60, key_extractor=api_key_extractor)
    async def custom_key_endpoint(request: Request) -> dict[str, str]:
        return {"message": "success"}

    @application.get("/token-bucket")
    @rate_limit(10, window_size=60, algorithm=Algorithm.TOKEN_BUCKET, burst_size=5)
    async def token_bucket_endpoint(request: Request) -> dict[str, str]:
        return {"message": "success"}

    @application.get("/high-cost")
    @rate_limit(10, window_size=60, cost=3)
    async def high_cost_endpoint(request: Request) -> dict[str, str]:
        return {"message": "success"}

    return application
|
|
|
|
|
|
@pytest.fixture
async def client(app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
    """Yield an ``AsyncClient`` wired to ``app`` through an ``ASGITransport``.

    The client uses ``http://test`` as its base URL and is closed when the
    ``async with`` block exits.
    """
    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as http_client:
        yield http_client
|
|
|
|
|
|
@pytest.fixture
def app_with_middleware(memory_backend: MemoryBackend) -> FastAPI:
    """Build a FastAPI application with ``RateLimitMiddleware`` installed.

    The middleware is configured with ``limit=10``, ``window_size=60``, the
    supplied ``memory_backend``, an exempt path of ``/health`` and an exempt
    IP of ``192.168.1.100``. The app exposes ``/api/resource`` and
    ``/health``.
    """
    application = FastAPI()

    middleware_options = {
        "limit": 10,
        "window_size": 60,
        "backend": memory_backend,
        "exempt_paths": {"/health"},
        "exempt_ips": {"192.168.1.100"},
    }
    application.add_middleware(RateLimitMiddleware, **middleware_options)

    @application.get("/api/resource")
    async def resource() -> dict[str, str]:
        return {"message": "success"}

    @application.get("/health")
    async def health() -> dict[str, str]:
        return {"status": "ok"}

    return application
|
|
|
|
|
|
@pytest.fixture
async def middleware_client(
    app_with_middleware: FastAPI,
) -> AsyncGenerator[AsyncClient, None]:
    """Yield an ``AsyncClient`` wired to ``app_with_middleware``.

    Requests go through an ``ASGITransport`` with ``http://test`` as the base
    URL. The client is closed when the ``async with`` block exits.
    """
    async with AsyncClient(
        transport=ASGITransport(app=app_with_middleware),
        base_url="http://test",
    ) as http_client:
        yield http_client
|
|
|
|
|
|
class MockRequest:
    """Lightweight stand-in for a request object in unit tests.

    Exposes ``url.path``, ``method``, ``client.host`` and ``headers``, plus a
    ``get`` helper that looks a key up in the headers.
    """

    def __init__(
        self,
        path: str = "/test",
        method: str = "GET",
        client_host: str = "127.0.0.1",
        headers: dict[str, str] | None = None,
    ) -> None:
        """Store the request attributes.

        Args:
            path: Value exposed as ``url.path``.
            method: Value exposed as ``method``.
            client_host: Value exposed as ``client.host``.
            headers: Header mapping; ``None`` or an empty mapping is replaced
                by a new empty dict.
        """
        # One-off classes let ``url.path`` and ``client.host`` resolve as
        # plain attributes without importing a real request implementation.
        url_cls = type("URL", (), {"path": path})
        client_cls = type("Client", (), {"host": client_host})

        self.url = url_cls()
        self.method = method
        self.client = client_cls()
        self._headers = headers if headers else {}

    @property
    def headers(self) -> dict[str, str]:
        """Return the header mapping held by this request."""
        return self._headers

    def get(self, key: str, default: str | None = None) -> str | None:
        """Return the header stored under ``key``, or ``default`` if absent."""
        return self.headers.get(key, default)
|
|
|
|
|
|
@pytest.fixture
def mock_request() -> MockRequest:
    """Return a ``MockRequest`` constructed with all default arguments."""
    default_request = MockRequest()
    return default_request
|