387 lines
12 KiB
Python
387 lines
12 KiB
Python
"""Tests for rate limiting middleware.

Comprehensive tests covering:

- Basic middleware functionality
- Path exemptions
- IP exemptions
- Custom key extractors
- Different algorithms
- Error handling and skip_on_error
- Header inclusion
- Multiple middleware configurations
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
from httpx import ASGITransport, AsyncClient
|
|
|
|
from fastapi_traffic import MemoryBackend
|
|
from fastapi_traffic.middleware import (
|
|
RateLimitMiddleware,
|
|
SlidingWindowMiddleware,
|
|
TokenBucketMiddleware,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import AsyncGenerator
|
|
|
|
|
|
class TestRateLimitMiddleware:
    """Tests for RateLimitMiddleware configured with 5 requests per 60 seconds."""

    @pytest.fixture
    def app(self) -> FastAPI:
        """Create a test app with rate limit middleware.

        The middleware allows 5 requests per 60-second window and uses a
        ``MemoryBackend`` created inside this (function-scoped) fixture, so
        each test starts with empty counters. One GET and one POST route are
        exposed so per-path counting can be exercised.
        """
        app = FastAPI()
        backend = MemoryBackend()

        app.add_middleware(
            RateLimitMiddleware,
            limit=5,
            window_size=60,
            backend=backend,
        )

        @app.get("/api/resource")
        async def resource() -> dict[str, str]:
            return {"status": "ok"}

        @app.post("/api/create")
        async def create() -> dict[str, str]:
            return {"status": "created"}

        return app

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client that drives ``app`` in-process via ASGI."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client

    async def test_allows_requests_within_limit(self, client: AsyncClient) -> None:
        """Test that requests within limit are allowed."""
        for i in range(5):
            response = await client.get("/api/resource")
            assert response.status_code == 200, f"Request {i} should succeed"

    async def test_blocks_requests_over_limit(self, client: AsyncClient) -> None:
        """Test that requests over limit are blocked."""
        for _ in range(5):
            await client.get("/api/resource")

        response = await client.get("/api/resource")
        assert response.status_code == 429

    async def test_rate_limit_headers_included(self, client: AsyncClient) -> None:
        """Test that rate limit headers are included."""
        response = await client.get("/api/resource")
        assert "X-RateLimit-Limit" in response.headers
        assert "X-RateLimit-Remaining" in response.headers
        assert "X-RateLimit-Reset" in response.headers

    async def test_different_endpoints_counted_separately(
        self, client: AsyncClient
    ) -> None:
        """Test that different endpoints are counted separately by path."""
        # Middleware includes path in the key by default. Spend the whole
        # 5-request budget on one path first. A mix of requests whose total
        # stays within the limit would pass even with a single shared counter,
        # so it would not prove per-path counting.
        for i in range(5):
            response = await client.get("/api/resource")
            assert response.status_code == 200, f"Request {i} should succeed"

        response = await client.get("/api/resource")
        assert response.status_code == 429

        # A different path must still have its own, untouched budget.
        response = await client.post("/api/create")
        assert response.status_code == 200

    async def test_rate_limit_response_format(self, client: AsyncClient) -> None:
        """Test rate limit response format."""
        for _ in range(5):
            await client.get("/api/resource")

        response = await client.get("/api/resource")
        assert response.status_code == 429
        data = response.json()
        assert "detail" in data
        assert "retry_after" in data
|
|
|
|
|
|
class TestMiddlewareExemptions:
    """Tests for middleware path and IP exemptions."""

    @pytest.fixture
    def app(self) -> FastAPI:
        """Create app with exemptions configured.

        Non-exempt traffic is limited to 2 requests per 60 seconds.
        ``/health`` and ``/metrics`` are exempt paths, and ``10.0.0.1`` is an
        exempt client IP.
        """
        app = FastAPI()
        backend = MemoryBackend()

        app.add_middleware(
            RateLimitMiddleware,
            limit=2,
            window_size=60,
            backend=backend,
            exempt_paths={"/health", "/metrics"},
            exempt_ips={"10.0.0.1"},
        )

        @app.get("/health")
        async def health() -> dict[str, str]:
            return {"status": "healthy"}

        @app.get("/metrics")
        async def metrics() -> dict[str, str]:
            return {"requests": "100"}

        @app.get("/api/data")
        async def data() -> dict[str, str]:
            return {"data": "value"}

        return app

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create a test client that drives ``app`` in-process via ASGI."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client

    async def test_exempt_paths_bypass_limit(self, client: AsyncClient) -> None:
        """Test that exempt paths bypass rate limiting."""
        for _ in range(10):
            response = await client.get("/health")
            assert response.status_code == 200

    async def test_multiple_exempt_paths(self, client: AsyncClient) -> None:
        """Test multiple exempt paths work correctly."""
        for _ in range(5):
            response = await client.get("/health")
            assert response.status_code == 200

        for _ in range(5):
            response = await client.get("/metrics")
            assert response.status_code == 200

    async def test_non_exempt_paths_are_limited(self, client: AsyncClient) -> None:
        """Test that non-exempt paths are rate limited."""
        for _ in range(2):
            response = await client.get("/api/data")
            assert response.status_code == 200

        response = await client.get("/api/data")
        assert response.status_code == 429

    async def test_exempt_paths_dont_consume_limit(self, client: AsyncClient) -> None:
        """Test that exempt path requests don't consume rate limit."""
        for _ in range(10):
            await client.get("/health")

        for _ in range(2):
            response = await client.get("/api/data")
            assert response.status_code == 200

    async def test_exempt_ips_bypass_limit(self, app: FastAPI) -> None:
        """Test that requests from an exempt IP bypass rate limiting.

        ``exempt_ips`` is configured in the fixture. Here the transport is
        given an explicit client address so requests appear to come from the
        exempt IP, and the client sends more than the 2-request limit to a
        non-exempt path.
        """
        transport = ASGITransport(app=app, client=("10.0.0.1", 123))
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            for i in range(5):
                response = await client.get("/api/data")
                assert response.status_code == 200, f"Request {i} should succeed"
|
|
|
|
|
|
class TestMiddlewareCustomKeyExtractor:
    """Tests for RateLimitMiddleware keyed by a caller-supplied extractor."""

    @pytest.fixture
    def app(self) -> FastAPI:
        """Create app with custom key extractor.

        The extractor buckets requests by the ``X-User-ID`` header. It falls
        back to ``"anonymous"`` when the header is missing or the request has
        no mapping-like ``headers``. The limit is 3 requests per 60 seconds.
        """
        app = FastAPI()
        backend = MemoryBackend()

        def user_id_extractor(request: object) -> str:
            request_headers = getattr(request, "headers", {})
            if not hasattr(request_headers, "get"):
                return "anonymous"
            return request_headers.get("X-User-ID", "anonymous")

        app.add_middleware(
            RateLimitMiddleware,
            limit=3,
            window_size=60,
            backend=backend,
            key_extractor=user_id_extractor,
        )

        @app.get("/api/resource")
        async def resource() -> dict[str, str]:
            return {"status": "ok"}

        return app

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Yield an in-process ASGI client bound to ``app``."""
        asgi_transport = ASGITransport(app=app)
        async with AsyncClient(
            transport=asgi_transport, base_url="http://test"
        ) as http_client:
            yield http_client

    async def test_different_users_have_separate_limits(
        self, client: AsyncClient
    ) -> None:
        """Test that different users have separate rate limits."""

        async def status_for(user_id: str) -> int:
            # Issue one request attributed to ``user_id`` and return its status.
            reply = await client.get("/api/resource", headers={"X-User-ID": user_id})
            return reply.status_code

        # user-1 spends its whole budget, then gets rejected.
        for _ in range(3):
            assert await status_for("user-1") == 200
        assert await status_for("user-1") == 429

        # user-2 has an independent budget.
        assert await status_for("user-2") == 200
|
|
|
|
|
|
class TestConvenienceMiddleware:
    """Tests for convenience middleware classes."""

    @staticmethod
    async def _check_limit_of_three(middleware_cls: type) -> None:
        """Mount ``middleware_cls`` (3 requests / 60 s) and verify enforcement.

        The first three GET ``/test`` calls must return 200 and the fourth
        must return 429.
        """
        app = FastAPI()
        backend = MemoryBackend()

        app.add_middleware(
            middleware_cls,
            limit=3,
            window_size=60,
            backend=backend,
        )

        @app.get("/test")
        async def ping() -> dict[str, str]:
            return {"status": "ok"}

        asgi_transport = ASGITransport(app=app)
        async with AsyncClient(
            transport=asgi_transport, base_url="http://test"
        ) as http_client:
            for attempt in range(1, 5):
                reply = await http_client.get("/test")
                expected = 200 if attempt <= 3 else 429
                assert reply.status_code == expected

    async def test_sliding_window_middleware(self) -> None:
        """Test SlidingWindowMiddleware."""
        await self._check_limit_of_three(SlidingWindowMiddleware)

    async def test_token_bucket_middleware(self) -> None:
        """Test TokenBucketMiddleware."""
        await self._check_limit_of_three(TokenBucketMiddleware)
|
|
|
|
|
|
class TestMiddlewareErrorHandling:
    """Tests for middleware error handling."""

    @pytest.fixture
    def app_skip_on_error(self) -> FastAPI:
        """Build an app whose middleware has ``skip_on_error=True``.

        The limit is 5 requests per 60 seconds, backed by a ``MemoryBackend``.
        """
        app = FastAPI()
        backend = MemoryBackend()

        app.add_middleware(
            RateLimitMiddleware,
            limit=5,
            window_size=60,
            backend=backend,
            skip_on_error=True,
        )

        @app.get("/api/resource")
        async def resource() -> dict[str, str]:
            return {"status": "ok"}

        return app

    @pytest.fixture
    async def client(
        self, app_skip_on_error: FastAPI
    ) -> AsyncGenerator[AsyncClient, None]:
        """Yield an in-process ASGI client bound to the skip-on-error app."""
        asgi_transport = ASGITransport(app=app_skip_on_error)
        async with AsyncClient(
            transport=asgi_transport, base_url="http://test"
        ) as http_client:
            yield http_client

    async def test_normal_operation_with_skip_on_error(
        self, client: AsyncClient
    ) -> None:
        """Test normal operation when skip_on_error is enabled."""
        # NOTE(review): no backend failure is injected here. This only checks
        # that enabling skip_on_error leaves ordinary limiting intact.
        for attempt in range(5):
            reply = await client.get("/api/resource")
            assert reply.status_code == 200, f"Request {attempt} should succeed"

        over_limit = await client.get("/api/resource")
        assert over_limit.status_code == 429
|
|
|
|
|
|
class TestMiddlewareHeaderConfiguration:
    """Tests for middleware header configuration."""

    async def test_headers_disabled(self) -> None:
        """Test that headers can be disabled.

        With ``include_headers=False`` none of the rate-limit headers
        (limit, remaining, reset) should appear on a successful response.
        """
        app = FastAPI()
        backend = MemoryBackend()

        app.add_middleware(
            RateLimitMiddleware,
            limit=5,
            window_size=60,
            backend=backend,
            include_headers=False,
        )

        @app.get("/test")
        async def test_endpoint() -> dict[str, str]:
            return {"status": "ok"}

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/test")
            assert response.status_code == 200
            # Check every header the default configuration emits, not just one.
            for header in (
                "X-RateLimit-Limit",
                "X-RateLimit-Remaining",
                "X-RateLimit-Reset",
            ):
                assert header not in response.headers, f"{header} should be absent"

    async def test_custom_error_message(self) -> None:
        """Test custom error message in middleware.

        With a limit of 1, the first request must succeed and the second must
        be rejected with the configured ``error_message`` as ``detail``.
        """
        app = FastAPI()
        backend = MemoryBackend()

        app.add_middleware(
            RateLimitMiddleware,
            limit=1,
            window_size=60,
            backend=backend,
            error_message="Custom: Too many requests",
        )

        @app.get("/test")
        async def test_endpoint() -> dict[str, str]:
            return {"status": "ok"}

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            # Assert the first request succeeds so a failure there is visible
            # instead of being silently ignored.
            first = await client.get("/test")
            assert first.status_code == 200

            response = await client.get("/test")
            assert response.status_code == 429
            data = response.json()
            assert data["detail"] == "Custom: Too many requests"
|