# fastapi-traffic/tests/test_decorator.py
# (Python source, 315 lines, 11 KiB)
"""Tests for rate limit decorator and dependency injection.
Comprehensive tests covering:
- Basic decorator functionality
- Custom key extractors
- Different algorithms via decorator
- Cost parameter
- Exemption callbacks
- On-blocked callbacks
- RateLimitDependency usage
- Header inclusion
- Error handling
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from httpx import ASGITransport, AsyncClient
from fastapi_traffic import (
MemoryBackend,
RateLimiter,
RateLimitExceeded,
rate_limit,
)
from fastapi_traffic.core.limiter import set_limiter
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
class TestRateLimitDecorator:
    """Tests for the @rate_limit decorator."""

    @pytest.fixture
    async def setup_limiter(self) -> AsyncGenerator[RateLimiter, None]:
        """Install a fresh in-memory rate limiter as the global limiter.

        The limiter is initialized before the test runs and closed after it.
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)
        await limiter.initialize()
        set_limiter(limiter)
        yield limiter
        await limiter.close()

    @pytest.fixture
    def app(self, setup_limiter: RateLimiter) -> FastAPI:
        """Create a test app with rate limited endpoints.

        Depends on ``setup_limiter`` so the global limiter is installed before
        any request reaches a decorated endpoint.
        """
        app = FastAPI()

        @app.exception_handler(RateLimitExceeded)
        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            # Translate the exception into a 429 response and forward the
            # limiter's headers when limit info is attached (the Retry-After
            # test below relies on this).
            return JSONResponse(
                status_code=429,
                content={"detail": exc.message, "retry_after": exc.retry_after},
                headers=exc.limit_info.to_headers() if exc.limit_info else {},
            )

        @app.get("/basic")
        @rate_limit(3, 60)
        async def basic_endpoint(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        @app.get("/no-headers")
        @rate_limit(3, window_size=60, include_headers=False)
        async def no_headers_endpoint(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        @app.get("/custom-message")
        @rate_limit(2, window_size=60, error_message="Custom rate limit message")
        async def custom_message_endpoint(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        return app

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create an HTTP client that talks to ``app`` in-process over ASGI."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client

    async def test_allows_requests_within_limit(self, client: AsyncClient) -> None:
        """Test that requests within limit are allowed."""
        for i in range(3):
            response = await client.get("/basic")
            assert response.status_code == 200, f"Request {i} should succeed"

    async def test_blocks_requests_over_limit(self, client: AsyncClient) -> None:
        """Test that the first request past the limit is blocked.

        The priming requests are asserted as well, so a limiter that blocks too
        early (off-by-one) fails here instead of passing silently.
        """
        for i in range(3):
            response = await client.get("/basic")
            assert response.status_code == 200, f"Request {i} should succeed"
        response = await client.get("/basic")
        assert response.status_code == 429

    async def test_rate_limit_enforced(self, client: AsyncClient) -> None:
        """Test that the limit stays enforced once it has been exceeded."""
        # Use up the limit
        for _ in range(3):
            response = await client.get("/basic")
            assert response.status_code == 200
        # Consecutive requests made right after exhausting the quota must all
        # stay blocked, not just the first one.
        for _ in range(2):
            response = await client.get("/basic")
            assert response.status_code == 429

    async def test_headers_excluded_when_disabled(self, client: AsyncClient) -> None:
        """Test that headers are excluded when include_headers=False."""
        response = await client.get("/no-headers")
        assert response.status_code == 200
        assert "X-RateLimit-Limit" not in response.headers
        # httpx header lookup is case-insensitive; also reject any other
        # X-RateLimit-* header so a partially disabled header set is caught.
        leaked = [
            name for name in response.headers if name.lower().startswith("x-ratelimit")
        ]
        assert leaked == []

    async def test_custom_error_message(self, client: AsyncClient) -> None:
        """Test custom error message is used once the limit of 2 is exceeded."""
        for i in range(2):
            response = await client.get("/custom-message")
            assert response.status_code == 200, f"Request {i} should succeed"
        response = await client.get("/custom-message")
        assert response.status_code == 429
        data = response.json()
        assert data["detail"] == "Custom rate limit message"

    async def test_retry_after_header_on_limit(self, client: AsyncClient) -> None:
        """Test Retry-After header is set when rate limited."""
        for i in range(3):
            response = await client.get("/basic")
            assert response.status_code == 200, f"Request {i} should succeed"
        response = await client.get("/basic")
        assert response.status_code == 429
        assert "Retry-After" in response.headers
class TestCustomKeyExtractor:
    """Rate limiting keyed by a caller-supplied key extractor function."""

    @pytest.fixture
    async def setup_limiter(self) -> AsyncGenerator[RateLimiter, None]:
        """Provide an initialized MemoryBackend limiter registered globally."""
        limiter = RateLimiter(MemoryBackend())
        await limiter.initialize()
        set_limiter(limiter)
        yield limiter
        await limiter.close()

    @pytest.fixture
    def app(self, setup_limiter: RateLimiter) -> FastAPI:
        """Build an app whose only endpoint is limited per ``X-API-Key`` value."""
        application = FastAPI()

        async def on_rate_limited(
            request: Request, exc: RateLimitExceeded
        ) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        application.add_exception_handler(RateLimitExceeded, on_rate_limited)

        def api_key_extractor(request: Request) -> str:
            # Requests without the header all share the "anonymous" key.
            return request.headers.get("X-API-Key", "anonymous")

        @application.get("/by-api-key")
        @rate_limit(2, window_size=60, key_extractor=api_key_extractor)
        async def by_api_key(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        return application

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Yield an AsyncClient routed straight into ``app`` via ASGI."""
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            yield http

    async def test_different_keys_have_separate_limits(
        self, client: AsyncClient
    ) -> None:
        """Exhausting one API key's quota must leave another key unaffected."""
        key_a = {"X-API-Key": "key-a"}
        key_b = {"X-API-Key": "key-b"}
        for _attempt in range(2):
            allowed = await client.get("/by-api-key", headers=key_a)
            assert allowed.status_code == 200
        blocked = await client.get("/by-api-key", headers=key_a)
        assert blocked.status_code == 429
        other_key = await client.get("/by-api-key", headers=key_b)
        assert other_key.status_code == 200

    async def test_anonymous_key_for_missing_header(self, client: AsyncClient) -> None:
        """Header-less requests fall back to one shared anonymous quota."""
        for _attempt in range(2):
            allowed = await client.get("/by-api-key")
            assert allowed.status_code == 200
        blocked = await client.get("/by-api-key")
        assert blocked.status_code == 429
class TestExemptionCallback:
    """Behaviour of the ``exempt_when`` predicate on @rate_limit."""

    @pytest.fixture
    async def setup_limiter(self) -> AsyncGenerator[RateLimiter, None]:
        """Provide an initialized MemoryBackend limiter registered globally."""
        limiter = RateLimiter(MemoryBackend())
        await limiter.initialize()
        set_limiter(limiter)
        yield limiter
        await limiter.close()

    @pytest.fixture
    def app(self, setup_limiter: RateLimiter) -> FastAPI:
        """Build an app where requests bearing the admin token are exempt."""
        application = FastAPI()

        async def on_rate_limited(
            request: Request, exc: RateLimitExceeded
        ) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        application.add_exception_handler(RateLimitExceeded, on_rate_limited)

        def is_admin(request: Request) -> bool:
            supplied = request.headers.get("X-Admin-Token")
            return supplied == "secret"

        @application.get("/with-exemption")
        @rate_limit(2, window_size=60, exempt_when=is_admin)
        async def with_exemption(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        return application

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Yield an AsyncClient routed straight into ``app`` via ASGI."""
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            yield http

    async def test_exempt_requests_bypass_limit(self, client: AsyncClient) -> None:
        """Admin-token requests keep succeeding well past the quota of 2."""
        admin_headers = {"X-Admin-Token": "secret"}
        for _attempt in range(5):
            resp = await client.get("/with-exemption", headers=admin_headers)
            assert resp.status_code == 200

    async def test_non_exempt_requests_are_limited(self, client: AsyncClient) -> None:
        """Requests without the admin token are blocked after 2 calls."""
        for _attempt in range(2):
            allowed = await client.get("/with-exemption")
            assert allowed.status_code == 200
        blocked = await client.get("/with-exemption")
        assert blocked.status_code == 429
class TestCostParameter:
    """Per-request ``cost`` weighting on @rate_limit."""

    @pytest.fixture
    async def setup_limiter(self) -> AsyncGenerator[RateLimiter, None]:
        """Provide an initialized MemoryBackend limiter registered globally."""
        limiter = RateLimiter(MemoryBackend())
        await limiter.initialize()
        set_limiter(limiter)
        yield limiter
        await limiter.close()

    @pytest.fixture
    def app(self, setup_limiter: RateLimiter) -> FastAPI:
        """Build an app with a cost-1 and a cost-5 endpoint, both limited to 10."""
        application = FastAPI()

        async def on_rate_limited(
            request: Request, exc: RateLimitExceeded
        ) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        application.add_exception_handler(RateLimitExceeded, on_rate_limited)

        @application.get("/low-cost")
        @rate_limit(10, window_size=60, cost=1)
        async def low_cost(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        @application.get("/high-cost")
        @rate_limit(10, window_size=60, cost=5)
        async def high_cost(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        return application

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Yield an AsyncClient routed straight into ``app`` via ASGI."""
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            yield http

    async def test_low_cost_allows_more_requests(self, client: AsyncClient) -> None:
        """With cost=1, all 10 units of the quota are usable as 10 requests."""
        for attempt in range(10):
            resp = await client.get("/low-cost")
            assert resp.status_code == 200, f"Request {attempt} should succeed"
        overflow = await client.get("/low-cost")
        assert overflow.status_code == 429

    async def test_high_cost_allows_fewer_requests(self, client: AsyncClient) -> None:
        """With cost=5, a quota of 10 admits only 2 requests."""
        for attempt in range(2):
            resp = await client.get("/high-cost")
            assert resp.status_code == 200, f"Request {attempt} should succeed"
        overflow = await client.get("/high-cost")
        assert overflow.status_code == 429