409 lines
14 KiB
Python
409 lines
14 KiB
Python
"""Integration tests for fastapi-traffic.

End-to-end tests that verify the complete rate limiting flow
across different configurations and usage patterns.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
from typing import TYPE_CHECKING
|
|
|
|
import pytest
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import JSONResponse
|
|
from httpx import ASGITransport, AsyncClient
|
|
|
|
from fastapi_traffic import (
|
|
Algorithm,
|
|
MemoryBackend,
|
|
RateLimiter,
|
|
RateLimitExceeded,
|
|
rate_limit,
|
|
)
|
|
from fastapi_traffic.core.config import RateLimitConfig
|
|
from fastapi_traffic.core.limiter import set_limiter
|
|
from fastapi_traffic.middleware import RateLimitMiddleware
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import AsyncGenerator
|
|
|
|
|
|
class TestFullApplicationFlow:
    """Integration tests for a complete application setup."""

    @pytest.fixture
    async def full_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create a fully configured application.

        The app owns a private ``RateLimiter`` on a fresh ``MemoryBackend``.
        Its lifespan initializes the limiter, registers it with
        ``set_limiter`` and closes it on shutdown. The lifespan only runs
        when a caller enters it (the ``client`` fixture does).
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)

        @asynccontextmanager
        async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
            await limiter.initialize()
            set_limiter(limiter)
            try:
                yield
            finally:
                # Release the limiter even if a test failed mid-lifespan.
                await limiter.close()

        app = FastAPI(lifespan=lifespan)

        @app.exception_handler(RateLimitExceeded)
        async def rate_limit_handler(
            request: Request, exc: RateLimitExceeded
        ) -> JSONResponse:
            return JSONResponse(
                status_code=429,
                content={
                    "error": "rate_limit_exceeded",
                    "message": exc.message,
                    "retry_after": exc.retry_after,
                },
                headers=exc.limit_info.to_headers() if exc.limit_info else {},
            )

        @app.get("/api/v1/users")
        @rate_limit(10, 60)
        async def list_users(request: Request) -> dict[str, object]:
            return {"users": [], "count": 0}

        @app.get("/api/v1/users/{user_id}")
        @rate_limit(20, 60)
        async def get_user(request: Request, user_id: int) -> dict[str, object]:
            return {"id": user_id, "name": f"User {user_id}"}

        @app.post("/api/v1/users")
        @rate_limit(5, window_size=60, cost=2)
        async def create_user(request: Request) -> dict[str, object]:
            return {"id": 1, "created": True}

        def get_api_key(request: Request) -> str:
            return request.headers.get("X-API-Key", "anonymous")

        # The limit is small so the per-key test can exhaust one key's budget
        # cheaply and then prove that another key is tracked separately.
        @app.get("/api/v1/premium")
        @rate_limit(5, window_size=60, key_extractor=get_api_key)
        async def premium_endpoint(request: Request) -> dict[str, str]:
            return {"tier": "premium"}

        yield app

    @pytest.fixture
    async def client(self, full_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client with the app's lifespan running.

        httpx's ``ASGITransport`` only forwards HTTP requests and does not
        emit ASGI lifespan events. Without the lifespan, the fixture's limiter
        would never be initialized, registered via ``set_limiter``, or closed.
        The lifespan is therefore entered explicitly around the client.
        """
        async with full_app.router.lifespan_context(full_app):
            transport = ASGITransport(app=full_app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                yield client

    async def test_different_endpoints_have_separate_limits(
        self, client: AsyncClient
    ) -> None:
        """Test that different endpoints maintain separate rate limits."""
        for _ in range(10):
            response = await client.get("/api/v1/users")
            assert response.status_code == 200

        response = await client.get("/api/v1/users")
        assert response.status_code == 429

        response = await client.get("/api/v1/users/1")
        assert response.status_code == 200

    async def test_cost_based_limiting(self, client: AsyncClient) -> None:
        """Test that cost parameter affects rate limiting.

        With limit 5 and cost 2, two requests (4 units) fit and a third
        (6 units) does not.
        """
        for _ in range(2):
            response = await client.post("/api/v1/users")
            assert response.status_code == 200

        response = await client.post("/api/v1/users")
        assert response.status_code == 429

    async def test_api_key_based_limiting(self, client: AsyncClient) -> None:
        """Test that each API key gets its own independent budget."""
        key_a = {"X-API-Key": "key-a"}
        key_b = {"X-API-Key": "key-b"}

        for _ in range(5):
            response = await client.get("/api/v1/premium", headers=key_a)
            assert response.status_code == 200

        # key-a has used its whole budget...
        response = await client.get("/api/v1/premium", headers=key_a)
        assert response.status_code == 429

        # ...but key-b is counted separately and is still allowed.
        for _ in range(5):
            response = await client.get("/api/v1/premium", headers=key_b)
            assert response.status_code == 200

    async def test_basic_rate_limiting_works(self, client: AsyncClient) -> None:
        """Test that basic rate limiting is functional."""
        # Make a request and verify it works
        response = await client.get("/api/v1/users/1")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == 1

    async def test_retry_after_in_429_response(self, client: AsyncClient) -> None:
        """Test that 429 responses include Retry-After header."""
        for _ in range(10):
            await client.get("/api/v1/users")

        response = await client.get("/api/v1/users")
        assert response.status_code == 429
        assert "Retry-After" in response.headers
        data = response.json()
        assert data["error"] == "rate_limit_exceeded"
        assert data["retry_after"] is not None
|
|
|
|
|
|
class TestMixedDecoratorAndMiddleware:
    """Test combining decorator and middleware rate limiting."""

    @pytest.fixture
    async def mixed_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app with both middleware and decorator limiting.

        The middleware applies a global 20/60s limit (``/health`` exempt,
        key prefix ``"global"``). It shares its ``MemoryBackend`` with the
        limiter that the lifespan registers for the ``@rate_limit`` decorator.
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)

        @asynccontextmanager
        async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
            await limiter.initialize()
            set_limiter(limiter)
            try:
                yield
            finally:
                # Release the limiter even if a test failed mid-lifespan.
                await limiter.close()

        app = FastAPI(lifespan=lifespan)

        app.add_middleware(
            RateLimitMiddleware,
            limit=20,
            window_size=60,
            backend=backend,
            exempt_paths={"/health"},
            key_prefix="global",
        )

        @app.exception_handler(RateLimitExceeded)
        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        @app.get("/health")
        async def health() -> dict[str, str]:
            return {"status": "healthy"}

        @app.get("/api/strict")
        @rate_limit(3, 60)
        async def strict_endpoint(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        @app.get("/api/normal")
        async def normal_endpoint() -> dict[str, str]:
            return {"status": "ok"}

        yield app

    @pytest.fixture
    async def client(self, mixed_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client with the app's lifespan running.

        ``ASGITransport`` does not emit ASGI lifespan events, so the lifespan
        is entered explicitly. Otherwise the decorator's limiter would never
        be initialized or registered.
        """
        async with mixed_app.router.lifespan_context(mixed_app):
            transport = ASGITransport(app=mixed_app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                yield client

    async def test_health_bypasses_middleware(self, client: AsyncClient) -> None:
        """Test that health endpoint bypasses middleware limiting."""
        for _ in range(30):
            response = await client.get("/health")
            assert response.status_code == 200

    async def test_decorator_limit_stricter_than_middleware(
        self, client: AsyncClient
    ) -> None:
        """Test that decorator limit is enforced before middleware limit."""
        for _ in range(3):
            response = await client.get("/api/strict")
            assert response.status_code == 200

        response = await client.get("/api/strict")
        assert response.status_code == 429

    async def test_middleware_limit_applies_to_normal_endpoints(
        self, client: AsyncClient
    ) -> None:
        """Test that middleware limit applies to non-decorated endpoints."""
        for _ in range(20):
            response = await client.get("/api/normal")
            assert response.status_code == 200

        response = await client.get("/api/normal")
        assert response.status_code == 429
|
|
|
|
|
|
class TestConcurrentRequests:
    """Test rate limiting under concurrent load."""

    @pytest.fixture
    async def concurrent_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app for concurrent testing.

        ``/api/resource`` is limited to 10 requests per 60s and sleeps
        briefly in the handler so that concurrent requests overlap.
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)

        @asynccontextmanager
        async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
            await limiter.initialize()
            set_limiter(limiter)
            try:
                yield
            finally:
                # Release the limiter even if a test failed mid-lifespan.
                await limiter.close()

        app = FastAPI(lifespan=lifespan)

        @app.exception_handler(RateLimitExceeded)
        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        @app.get("/api/resource")
        @rate_limit(10, 60)
        async def resource(request: Request) -> dict[str, str]:
            await asyncio.sleep(0.01)
            return {"status": "ok"}

        yield app

    @pytest.fixture
    async def client(
        self, concurrent_app: FastAPI
    ) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client with the app's lifespan running.

        ``ASGITransport`` does not emit ASGI lifespan events, so the lifespan
        is entered explicitly. Otherwise the fixture's limiter would never be
        initialized or registered.
        """
        async with concurrent_app.router.lifespan_context(concurrent_app):
            transport = ASGITransport(app=concurrent_app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                yield client

    async def test_concurrent_requests_respect_limit(
        self, client: AsyncClient
    ) -> None:
        """Test that concurrent requests respect rate limit.

        Fifteen requests are issued at once against a limit of 10; exactly
        ten must succeed and the remaining five must be rejected with 429.
        """

        async def make_request() -> int:
            response = await client.get("/api/resource")
            return response.status_code

        results = await asyncio.gather(*(make_request() for _ in range(15)))

        assert results.count(200) == 10
        assert results.count(429) == 5
|
|
|
|
|
|
class TestLimiterStateManagement:
    """Test RateLimiter state management."""

    class _MockRequest:
        """Minimal stand-in for a Starlette ``Request``.

        Provides only ``url.path``, ``method``, ``client.host`` and
        ``headers``. It is passed where a ``Request`` is expected, hence the
        ``type: ignore[arg-type]`` comments at the call sites.
        """

        def __init__(self) -> None:
            self.url = type("URL", (), {"path": "/test"})()
            self.method = "GET"
            self.client = type("Client", (), {"host": "127.0.0.1"})()
            self.headers: dict[str, str] = {}

    async def test_limiter_reset_clears_state(self) -> None:
        """Test that reset clears rate limit state."""
        backend = MemoryBackend()
        limiter = RateLimiter(backend)
        await limiter.initialize()

        try:
            config = RateLimitConfig(limit=3, window_size=60)
            request = self._MockRequest()

            for _ in range(3):
                result = await limiter.check(request, config)  # type: ignore[arg-type]
                assert result.allowed

            result = await limiter.check(request, config)  # type: ignore[arg-type]
            assert not result.allowed

            await limiter.reset(request, config)  # type: ignore[arg-type]

            result = await limiter.check(request, config)  # type: ignore[arg-type]
            assert result.allowed
        finally:
            await limiter.close()

    async def test_get_state_returns_current_info(self) -> None:
        """Test that get_state returns current rate limit info.

        After two checks against a limit of 5, three requests remain.
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)
        await limiter.initialize()

        try:
            config = RateLimitConfig(limit=5, window_size=60)
            request = self._MockRequest()

            await limiter.check(request, config)  # type: ignore[arg-type]
            await limiter.check(request, config)  # type: ignore[arg-type]

            state = await limiter.get_state(request, config)  # type: ignore[arg-type]
            assert state is not None
            assert state.remaining == 3
        finally:
            await limiter.close()
|
|
|
|
|
|
class TestMultipleAlgorithms:
    """Test different algorithms in the same application."""

    @pytest.fixture
    async def multi_algo_app(self) -> AsyncGenerator[FastAPI, None]:
        """Create app with multiple algorithms.

        Three endpoints share a 5/60s limit but each uses a different
        ``Algorithm`` (sliding window, fixed window, token bucket).
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)

        @asynccontextmanager
        async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
            await limiter.initialize()
            set_limiter(limiter)
            try:
                yield
            finally:
                # Release the limiter even if a test failed mid-lifespan.
                await limiter.close()

        app = FastAPI(lifespan=lifespan)

        @app.exception_handler(RateLimitExceeded)
        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        @app.get("/sliding-window")
        @rate_limit(5, window_size=60, algorithm=Algorithm.SLIDING_WINDOW)
        async def sliding_window(request: Request) -> dict[str, str]:
            return {"algorithm": "sliding_window"}

        @app.get("/fixed-window")
        @rate_limit(5, window_size=60, algorithm=Algorithm.FIXED_WINDOW)
        async def fixed_window(request: Request) -> dict[str, str]:
            return {"algorithm": "fixed_window"}

        @app.get("/token-bucket")
        @rate_limit(5, window_size=60, algorithm=Algorithm.TOKEN_BUCKET)
        async def token_bucket(request: Request) -> dict[str, str]:
            return {"algorithm": "token_bucket"}

        yield app

    @pytest.fixture
    async def client(
        self, multi_algo_app: FastAPI
    ) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client with the app's lifespan running.

        ``ASGITransport`` does not emit ASGI lifespan events, so the lifespan
        is entered explicitly. Otherwise the fixture's limiter would never be
        initialized or registered.
        """
        async with multi_algo_app.router.lifespan_context(multi_algo_app):
            transport = ASGITransport(app=multi_algo_app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                yield client

    async def test_all_algorithms_enforce_limits(self, client: AsyncClient) -> None:
        """Test that all algorithms enforce their limits."""
        endpoints = ["/sliding-window", "/fixed-window", "/token-bucket"]

        for endpoint in endpoints:
            for i in range(5):
                response = await client.get(endpoint)
                assert response.status_code == 200, f"{endpoint} request {i} failed"

            response = await client.get(endpoint)
            assert response.status_code == 429, f"{endpoint} should be rate limited"
|