Files
fastapi-traffic/tests/test_middleware.py

387 lines
12 KiB
Python

"""Tests for rate limiting middleware.
Comprehensive tests covering:
- Basic middleware functionality
- Path exemptions
- IP exemptions
- Custom key extractors
- Different algorithms
- Error handling and skip_on_error
- Header inclusion
- Multiple middleware configurations
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from fastapi_traffic import MemoryBackend
from fastapi_traffic.middleware import (
RateLimitMiddleware,
SlidingWindowMiddleware,
TokenBucketMiddleware,
)
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
class TestRateLimitMiddleware:
    """Tests for RateLimitMiddleware."""

    @pytest.fixture
    def app(self) -> FastAPI:
        """Create a test app limited to 5 requests per 60-second window."""
        app = FastAPI()
        backend = MemoryBackend()
        app.add_middleware(
            RateLimitMiddleware,
            limit=5,
            window_size=60,
            backend=backend,
        )

        @app.get("/api/resource")
        async def resource() -> dict[str, str]:
            return {"status": "ok"}

        @app.post("/api/create")
        async def create() -> dict[str, str]:
            return {"status": "created"}

        return app

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create an in-process HTTP client for ``app``."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client

    async def test_allows_requests_within_limit(self, client: AsyncClient) -> None:
        """Test that requests within limit are allowed."""
        for i in range(5):
            response = await client.get("/api/resource")
            assert response.status_code == 200, f"Request {i} should succeed"

    async def test_blocks_requests_over_limit(self, client: AsyncClient) -> None:
        """Test that requests over limit are blocked."""
        # Assert the warm-up requests too, so a setup failure is reported
        # here rather than surfacing as a confusing 429 mismatch below.
        for i in range(5):
            response = await client.get("/api/resource")
            assert response.status_code == 200, f"Request {i} should succeed"
        response = await client.get("/api/resource")
        assert response.status_code == 429

    async def test_rate_limit_headers_included(self, client: AsyncClient) -> None:
        """Test that rate limit headers are included."""
        response = await client.get("/api/resource")
        assert "X-RateLimit-Limit" in response.headers
        assert "X-RateLimit-Remaining" in response.headers
        assert "X-RateLimit-Reset" in response.headers

    async def test_different_endpoints_counted_separately(
        self, client: AsyncClient
    ) -> None:
        """Test that different endpoints are counted separately by path."""
        # Middleware includes path in the key by default, so exhausting one
        # endpoint's budget must leave another endpoint's budget untouched.
        # Splitting e.g. 3 + 2 requests across endpoints would stay within a
        # single shared limit of 5 and could not detect a shared counter.
        for i in range(5):
            response = await client.get("/api/resource")
            assert response.status_code == 200, f"Request {i} should succeed"
        response = await client.get("/api/resource")
        assert response.status_code == 429, "/api/resource should be exhausted"
        response = await client.post("/api/create")
        assert response.status_code == 200

    async def test_rate_limit_response_format(self, client: AsyncClient) -> None:
        """Test rate limit response format."""
        for _ in range(5):
            await client.get("/api/resource")
        response = await client.get("/api/resource")
        assert response.status_code == 429
        data = response.json()
        assert "detail" in data
        assert "retry_after" in data
class TestMiddlewareExemptions:
    """Exercise the path and IP exemption options of RateLimitMiddleware."""

    @pytest.fixture
    def app(self) -> FastAPI:
        """Build an app capped at 2 requests/minute with /health and /metrics exempt."""
        application = FastAPI()
        application.add_middleware(
            RateLimitMiddleware,
            limit=2,
            window_size=60,
            backend=MemoryBackend(),
            exempt_paths={"/health", "/metrics"},
            exempt_ips={"10.0.0.1"},
        )

        @application.get("/health")
        async def health() -> dict[str, str]:
            return {"status": "healthy"}

        @application.get("/metrics")
        async def metrics() -> dict[str, str]:
            return {"requests": "100"}

        @application.get("/api/data")
        async def data() -> dict[str, str]:
            return {"data": "value"}

        return application

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Yield an HTTP client wired directly to ``app`` over ASGI."""
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http_client:
            yield http_client

    async def test_exempt_paths_bypass_limit(self, client: AsyncClient) -> None:
        """Far more than ``limit`` requests to an exempt path all succeed."""
        for _ in range(10):
            reply = await client.get("/health")
            assert reply.status_code == 200

    async def test_multiple_exempt_paths(self, client: AsyncClient) -> None:
        """Every configured exempt path is honoured, not just the first."""
        for path in ("/health", "/metrics"):
            for _ in range(5):
                reply = await client.get(path)
                assert reply.status_code == 200

    async def test_non_exempt_paths_are_limited(self, client: AsyncClient) -> None:
        """A path outside the exemption set is still capped at ``limit``."""
        for _ in range(2):
            reply = await client.get("/api/data")
            assert reply.status_code == 200
        reply = await client.get("/api/data")
        assert reply.status_code == 429

    async def test_exempt_paths_dont_consume_limit(self, client: AsyncClient) -> None:
        """Traffic to exempt paths leaves the regular budget available."""
        for _ in range(10):
            await client.get("/health")
        for _ in range(2):
            reply = await client.get("/api/data")
            assert reply.status_code == 200
class TestMiddlewareCustomKeyExtractor:
    """Exercise RateLimitMiddleware with a caller-supplied key extractor."""

    @pytest.fixture
    def app(self) -> FastAPI:
        """Build an app that buckets requests by the ``X-User-ID`` header."""
        application = FastAPI()

        def user_id_extractor(request: object) -> str:
            # Fall back to a shared "anonymous" bucket when the request has
            # no mapping-like ``headers`` or the header is missing.
            hdrs = getattr(request, "headers", {})
            if not hasattr(hdrs, "get"):
                return "anonymous"
            return hdrs.get("X-User-ID", "anonymous")

        application.add_middleware(
            RateLimitMiddleware,
            limit=3,
            window_size=60,
            backend=MemoryBackend(),
            key_extractor=user_id_extractor,
        )

        @application.get("/api/resource")
        async def resource() -> dict[str, str]:
            return {"status": "ok"}

        return application

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Yield an HTTP client wired directly to ``app`` over ASGI."""
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http_client:
            yield http_client

    async def test_different_users_have_separate_limits(
        self, client: AsyncClient
    ) -> None:
        """One user exhausting their budget does not block another user."""
        first_user = {"X-User-ID": "user-1"}
        for _ in range(3):
            reply = await client.get("/api/resource", headers=first_user)
            assert reply.status_code == 200
        blocked = await client.get("/api/resource", headers=first_user)
        assert blocked.status_code == 429
        other = await client.get("/api/resource", headers={"X-User-ID": "user-2"})
        assert other.status_code == 200
class TestConvenienceMiddleware:
    """Tests for convenience middleware classes."""

    @staticmethod
    async def _assert_enforces_limit(middleware_cls: type, limit: int = 3) -> None:
        """Mount ``middleware_cls`` on a fresh app and check it caps ``/test``.

        Sends ``limit`` requests that must all return 200, then one more that
        must be rejected with 429. Each call builds its own app and
        ``MemoryBackend`` so tests never share counters.

        Args:
            middleware_cls: Middleware class accepting ``limit``, ``window_size``
                and ``backend`` keyword arguments.
            limit: Number of requests allowed within the 60-second window.
        """
        app = FastAPI()
        app.add_middleware(
            middleware_cls,
            limit=limit,
            window_size=60,
            backend=MemoryBackend(),
        )

        @app.get("/test")
        async def test_endpoint() -> dict[str, str]:
            return {"status": "ok"}

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            for i in range(limit):
                response = await client.get("/test")
                assert response.status_code == 200, f"Request {i} should succeed"
            response = await client.get("/test")
            assert response.status_code == 429

    async def test_sliding_window_middleware(self) -> None:
        """Test SlidingWindowMiddleware."""
        await self._assert_enforces_limit(SlidingWindowMiddleware)

    async def test_token_bucket_middleware(self) -> None:
        """Test TokenBucketMiddleware."""
        await self._assert_enforces_limit(TokenBucketMiddleware)
class TestMiddlewareErrorHandling:
    """Exercise RateLimitMiddleware configured with ``skip_on_error=True``."""

    @pytest.fixture
    def app_skip_on_error(self) -> FastAPI:
        """Build an app capped at 5 requests/minute with ``skip_on_error`` on."""
        application = FastAPI()
        application.add_middleware(
            RateLimitMiddleware,
            limit=5,
            window_size=60,
            backend=MemoryBackend(),
            skip_on_error=True,
        )

        @application.get("/api/resource")
        async def resource() -> dict[str, str]:
            return {"status": "ok"}

        return application

    @pytest.fixture
    async def client(
        self, app_skip_on_error: FastAPI
    ) -> AsyncGenerator[AsyncClient, None]:
        """Yield an HTTP client wired directly to the skip-on-error app."""
        async with AsyncClient(
            transport=ASGITransport(app=app_skip_on_error), base_url="http://test"
        ) as http_client:
            yield http_client

    async def test_normal_operation_with_skip_on_error(
        self, client: AsyncClient
    ) -> None:
        """With a working backend, ``skip_on_error`` still enforces the limit."""
        for attempt in range(5):
            reply = await client.get("/api/resource")
            assert reply.status_code == 200, f"Request {attempt} should succeed"
        reply = await client.get("/api/resource")
        assert reply.status_code == 429
class TestMiddlewareHeaderConfiguration:
    """Tests for middleware header configuration."""

    async def test_headers_disabled(self) -> None:
        """Test that no rate limit headers are sent when include_headers=False."""
        app = FastAPI()
        backend = MemoryBackend()
        app.add_middleware(
            RateLimitMiddleware,
            limit=5,
            window_size=60,
            backend=backend,
            include_headers=False,
        )

        @app.get("/test")
        async def test_endpoint() -> dict[str, str]:
            return {"status": "ok"}

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/test")
            assert response.status_code == 200
            # Check every header that the enabled case emits, not just one.
            for header in (
                "X-RateLimit-Limit",
                "X-RateLimit-Remaining",
                "X-RateLimit-Reset",
            ):
                assert header not in response.headers, f"{header} should be omitted"

    async def test_custom_error_message(self) -> None:
        """Test custom error message in middleware."""
        app = FastAPI()
        backend = MemoryBackend()
        app.add_middleware(
            RateLimitMiddleware,
            limit=1,
            window_size=60,
            backend=backend,
            error_message="Custom: Too many requests",
        )

        @app.get("/test")
        async def test_endpoint() -> dict[str, str]:
            return {"status": "ok"}

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            # The single allowed request must succeed; otherwise the 429 below
            # would not prove that the limit (rather than a failure) triggered it.
            first = await client.get("/test")
            assert first.status_code == 200
            response = await client.get("/test")
            assert response.status_code == 429
            data = response.json()
            assert data["detail"] == "Custom: Too many requests"