Files
fastapi-traffic/tests/test_integration.py

409 lines
14 KiB
Python

"""Integration tests for fastapi-traffic.
End-to-end tests that verify the complete rate limiting flow
across different configurations and usage patterns.
"""
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING
import pytest
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from httpx import ASGITransport, AsyncClient
from fastapi_traffic import (
Algorithm,
MemoryBackend,
RateLimiter,
RateLimitExceeded,
rate_limit,
)
from fastapi_traffic.core.config import RateLimitConfig
from fastapi_traffic.core.limiter import set_limiter
from fastapi_traffic.middleware import RateLimitMiddleware
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
class TestFullApplicationFlow:
    """Integration tests for a complete application setup.

    The application under test registers a ``RateLimitExceeded`` handler and
    four decorated routes with different limits, costs and key extractors.
    """

    @pytest.fixture
    async def full_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create a fully configured application.

        The returned app owns a dedicated ``MemoryBackend``/``RateLimiter``
        pair; the limiter is initialized and installed via ``set_limiter``
        when the app's lifespan context is entered, and closed when it exits.
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)

        @asynccontextmanager
        async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
            await limiter.initialize()
            set_limiter(limiter)
            yield
            await limiter.close()

        app = FastAPI(lifespan=lifespan)

        @app.exception_handler(RateLimitExceeded)
        async def rate_limit_handler(
            request: Request, exc: RateLimitExceeded
        ) -> JSONResponse:
            # Translate the library exception into a JSON 429 response,
            # forwarding the rate limit headers when limit info is present.
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": exc.message,
                    "retry_after": exc.retry_after,
                },
                headers=exc.limit_info.to_headers() if exc.limit_info else {},
            )

        @app.get("/api/v1/users")
        @rate_limit(10, 60)
        async def list_users(request: Request) -> dict[str, object]:
            return {"users": [], "count": 0}

        @app.get("/api/v1/users/{user_id}")
        @rate_limit(20, 60)
        async def get_user(request: Request, user_id: int) -> dict[str, object]:
            return {"id": user_id, "name": f"User {user_id}"}

        @app.post("/api/v1/users")
        @rate_limit(5, window_size=60, cost=2)
        async def create_user(request: Request) -> dict[str, object]:
            return {"id": 1, "created": True}

        def get_api_key(request: Request) -> str:
            # Requests without the header all share the "anonymous" key.
            return request.headers.get("X-API-Key", "anonymous")

        @app.get("/api/v1/premium")
        @rate_limit(100, window_size=60, key_extractor=get_api_key)
        async def premium_endpoint(request: Request) -> dict[str, str]:
            return {"tier": "premium"}

        yield app

    @pytest.fixture
    async def client(self, full_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client with the application lifespan running.

        httpx's ``ASGITransport`` does not send ASGI lifespan events, so the
        app's lifespan context is entered explicitly here. Without this the
        limiter built in ``full_app`` would never be initialized or installed
        through ``set_limiter``.
        """
        transport = ASGITransport(app=full_app)
        async with full_app.router.lifespan_context(full_app):
            async with AsyncClient(
                transport=transport, base_url="http://test"
            ) as client:
                yield client

    async def test_different_endpoints_have_separate_limits(
        self, client: AsyncClient
    ) -> None:
        """Test that different endpoints maintain separate rate limits."""
        # Exhaust the 10-request limit on the list endpoint.
        for _ in range(10):
            response = await client.get("/api/v1/users")
            assert response.status_code == 200

        response = await client.get("/api/v1/users")
        assert response.status_code == 429

        # The detail endpoint must still be served.
        response = await client.get("/api/v1/users/1")
        assert response.status_code == 200

    async def test_cost_based_limiting(self, client: AsyncClient) -> None:
        """Test that cost parameter affects rate limiting."""
        # Limit is 5 with cost 2 per request: two requests fit, a third does not.
        for _ in range(2):
            response = await client.post("/api/v1/users")
            assert response.status_code == 200

        response = await client.post("/api/v1/users")
        assert response.status_code == 429

    async def test_api_key_based_limiting(self, client: AsyncClient) -> None:
        """Test that requests sent under two different API keys are served.

        Both keys stay well below the endpoint's limit of 100, so this only
        checks that keyed requests succeed, not that the keys are isolated.
        """
        for _ in range(5):
            response = await client.get(
                "/api/v1/premium", headers={"X-API-Key": "key-a"}
            )
            assert response.status_code == 200

        for _ in range(5):
            response = await client.get(
                "/api/v1/premium", headers={"X-API-Key": "key-b"}
            )
            assert response.status_code == 200

    async def test_basic_rate_limiting_works(self, client: AsyncClient) -> None:
        """Test that basic rate limiting is functional."""
        # Make a request and verify it works
        response = await client.get("/api/v1/users/1")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == 1

    async def test_retry_after_in_429_response(self, client: AsyncClient) -> None:
        """Test that 429 responses include Retry-After header."""
        # Use up the limit; the status of these requests is not asserted.
        for _ in range(10):
            await client.get("/api/v1/users")

        response = await client.get("/api/v1/users")
        assert response.status_code == 429
        assert "Retry-After" in response.headers
        data = response.json()
        assert data["error"] == "rate_limit_exceeded"
        assert data["retry_after"] is not None
class TestMixedDecoratorAndMiddleware:
    """Test combining decorator and middleware rate limiting."""

    @pytest.fixture
    async def mixed_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app with both middleware and decorator limiting.

        The middleware is given the same ``MemoryBackend`` instance that backs
        the limiter installed by the lifespan, with ``/health`` exempted.
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)

        @asynccontextmanager
        async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
            await limiter.initialize()
            set_limiter(limiter)
            yield
            await limiter.close()

        app = FastAPI(lifespan=lifespan)
        app.add_middleware(
            RateLimitMiddleware,
            limit=20,
            window_size=60,
            backend=backend,
            exempt_paths={"/health"},
            key_prefix="global",
        )

        @app.exception_handler(RateLimitExceeded)
        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        @app.get("/health")
        async def health() -> dict[str, str]:
            return {"status": "healthy"}

        @app.get("/api/strict")
        @rate_limit(3, 60)
        async def strict_endpoint(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        @app.get("/api/normal")
        async def normal_endpoint() -> dict[str, str]:
            return {"status": "ok"}

        yield app

    @pytest.fixture
    async def client(self, mixed_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client with the application lifespan running.

        httpx's ``ASGITransport`` does not send ASGI lifespan events, so the
        lifespan context is entered explicitly; otherwise the limiter created
        in ``mixed_app`` would never be initialized or installed.
        """
        transport = ASGITransport(app=mixed_app)
        async with mixed_app.router.lifespan_context(mixed_app):
            async with AsyncClient(
                transport=transport, base_url="http://test"
            ) as client:
                yield client

    async def test_health_bypasses_middleware(self, client: AsyncClient) -> None:
        """Test that health endpoint bypasses middleware limiting."""
        # 30 requests exceed the middleware limit of 20, yet all must pass.
        for _ in range(30):
            response = await client.get("/health")
            assert response.status_code == 200

    async def test_decorator_limit_stricter_than_middleware(
        self, client: AsyncClient
    ) -> None:
        """Test that decorator limit is enforced before middleware limit."""
        for _ in range(3):
            response = await client.get("/api/strict")
            assert response.status_code == 200

        # Fourth request exceeds the decorator limit of 3, well below the
        # middleware limit of 20.
        response = await client.get("/api/strict")
        assert response.status_code == 429

    async def test_middleware_limit_applies_to_normal_endpoints(
        self, client: AsyncClient
    ) -> None:
        """Test that middleware limit applies to non-decorated endpoints."""
        for _ in range(20):
            response = await client.get("/api/normal")
            assert response.status_code == 200

        response = await client.get("/api/normal")
        assert response.status_code == 429
class TestConcurrentRequests:
    """Test rate limiting under concurrent load."""

    @pytest.fixture
    async def concurrent_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app for concurrent testing.

        The single route sleeps briefly so that concurrently issued requests
        overlap while being handled.
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)

        @asynccontextmanager
        async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
            await limiter.initialize()
            set_limiter(limiter)
            yield
            await limiter.close()

        app = FastAPI(lifespan=lifespan)

        @app.exception_handler(RateLimitExceeded)
        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        @app.get("/api/resource")
        @rate_limit(10, 60)
        async def resource(request: Request) -> dict[str, str]:
            await asyncio.sleep(0.01)
            return {"status": "ok"}

        yield app

    async def test_concurrent_requests_respect_limit(
        self, concurrent_app: FastAPI
    ) -> None:
        """Test that concurrent requests respect rate limit.

        Fires 15 requests at once against a limit of 10 and expects exactly
        10 successes and 5 rejections.
        """
        transport = ASGITransport(app=concurrent_app)
        # httpx's ASGITransport does not send lifespan events, so enter the
        # lifespan explicitly to initialize and install this app's limiter.
        async with concurrent_app.router.lifespan_context(concurrent_app):
            async with AsyncClient(
                transport=transport, base_url="http://test"
            ) as client:

                async def make_request() -> int:
                    response = await client.get("/api/resource")
                    return response.status_code

                results = await asyncio.gather(*[make_request() for _ in range(15)])

        success_count = sum(1 for r in results if r == 200)
        rate_limited_count = sum(1 for r in results if r == 429)
        assert success_count == 10
        assert rate_limited_count == 5
class TestLimiterStateManagement:
    """Test RateLimiter state management."""

    @staticmethod
    def _make_request() -> object:
        """Build a minimal stand-in for a request object.

        The stub carries only the ``url.path``, ``method``, ``client.host``
        and ``headers`` attributes; both tests share it instead of each
        declaring an identical mock class.
        """

        class MockRequest:
            def __init__(self) -> None:
                self.url = type("URL", (), {"path": "/test"})()
                self.method = "GET"
                self.client = type("Client", (), {"host": "127.0.0.1"})()
                self.headers: dict[str, str] = {}

        return MockRequest()

    async def test_limiter_reset_clears_state(self) -> None:
        """Test that reset clears rate limit state."""
        backend = MemoryBackend()
        limiter = RateLimiter(backend)
        await limiter.initialize()
        try:
            config = RateLimitConfig(limit=3, window_size=60)
            request = self._make_request()

            # Consume the whole allowance of three requests.
            for _ in range(3):
                result = await limiter.check(request, config)  # type: ignore[arg-type]
                assert result.allowed

            result = await limiter.check(request, config)  # type: ignore[arg-type]
            assert not result.allowed

            # After a reset the same request must be allowed again.
            await limiter.reset(request, config)  # type: ignore[arg-type]
            result = await limiter.check(request, config)  # type: ignore[arg-type]
            assert result.allowed
        finally:
            await limiter.close()

    async def test_get_state_returns_current_info(self) -> None:
        """Test that get_state returns current rate limit info."""
        backend = MemoryBackend()
        limiter = RateLimiter(backend)
        await limiter.initialize()
        try:
            config = RateLimitConfig(limit=5, window_size=60)
            request = self._make_request()

            await limiter.check(request, config)  # type: ignore[arg-type]
            await limiter.check(request, config)  # type: ignore[arg-type]

            # Two of five requests used, so three should remain.
            state = await limiter.get_state(request, config)  # type: ignore[arg-type]
            assert state is not None
            assert state.remaining == 3
        finally:
            await limiter.close()
class TestMultipleAlgorithms:
    """Test different algorithms in the same application."""

    @pytest.fixture
    async def multi_algo_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app with multiple algorithms.

        Each route uses the same limit (5 per 60) but a different
        ``Algorithm`` member.
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)

        @asynccontextmanager
        async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
            await limiter.initialize()
            set_limiter(limiter)
            yield
            await limiter.close()

        app = FastAPI(lifespan=lifespan)

        @app.exception_handler(RateLimitExceeded)
        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        @app.get("/sliding-window")
        @rate_limit(5, window_size=60, algorithm=Algorithm.SLIDING_WINDOW)
        async def sliding_window(request: Request) -> dict[str, str]:
            return {"algorithm": "sliding_window"}

        @app.get("/fixed-window")
        @rate_limit(5, window_size=60, algorithm=Algorithm.FIXED_WINDOW)
        async def fixed_window(request: Request) -> dict[str, str]:
            return {"algorithm": "fixed_window"}

        @app.get("/token-bucket")
        @rate_limit(5, window_size=60, algorithm=Algorithm.TOKEN_BUCKET)
        async def token_bucket(request: Request) -> dict[str, str]:
            return {"algorithm": "token_bucket"}

        yield app

    @pytest.fixture
    async def client(
        self, multi_algo_app: FastAPI
    ) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client with the application lifespan running.

        httpx's ``ASGITransport`` does not send ASGI lifespan events, so the
        lifespan context is entered explicitly; otherwise the limiter created
        in ``multi_algo_app`` would never be initialized or installed.
        """
        transport = ASGITransport(app=multi_algo_app)
        async with multi_algo_app.router.lifespan_context(multi_algo_app):
            async with AsyncClient(
                transport=transport, base_url="http://test"
            ) as client:
                yield client

    async def test_all_algorithms_enforce_limits(self, client: AsyncClient) -> None:
        """Test that all algorithms enforce their limits."""
        endpoints = ["/sliding-window", "/fixed-window", "/token-bucket"]
        for endpoint in endpoints:
            # Five requests fit within the limit; the sixth must be rejected.
            for i in range(5):
                response = await client.get(endpoint)
                assert response.status_code == 200, f"{endpoint} request {i} failed"

            response = await client.get(endpoint)
            assert response.status_code == 429, f"{endpoint} should be rate limited"