315 lines
11 KiB
Python
315 lines
11 KiB
Python
"""Tests for rate limit decorator and dependency injection.

Comprehensive tests covering:

- Basic decorator functionality
- Custom key extractors
- Different algorithms via decorator
- Cost parameter
- Exemption callbacks
- On-blocked callbacks
- RateLimitDependency usage
- Header inclusion
- Error handling
"""

from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
import pytest
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import JSONResponse
|
|
from httpx import ASGITransport, AsyncClient
|
|
|
|
from fastapi_traffic import (
|
|
MemoryBackend,
|
|
RateLimiter,
|
|
RateLimitExceeded,
|
|
rate_limit,
|
|
)
|
|
from fastapi_traffic.core.limiter import set_limiter
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import AsyncGenerator
|
|
|
|
|
|
class TestRateLimitDecorator:
    """Tests for the ``@rate_limit`` decorator applied directly to routes.

    All requests are issued by the same test client, so repeated calls to an
    endpoint count against the same limit.
    """

    @pytest.fixture
    async def setup_limiter(self) -> AsyncGenerator[RateLimiter, None]:
        """Install a fresh in-memory limiter as the global limiter.

        A new ``MemoryBackend`` is created for every test so request counts
        never carry over between tests. The limiter is closed on teardown.
        """
        backend = MemoryBackend()
        limiter = RateLimiter(backend)
        await limiter.initialize()
        set_limiter(limiter)
        yield limiter
        await limiter.close()

    @pytest.fixture
    def app(self, setup_limiter: RateLimiter) -> FastAPI:
        """Create a test app with rate limited endpoints.

        ``setup_limiter`` is requested only so the global limiter is installed
        before any request reaches the decorated endpoints.
        """
        app = FastAPI()

        @app.exception_handler(RateLimitExceeded)
        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            # Forward the limiter's own headers when the exception carries
            # limit info, so the header tests below can inspect them.
            return JSONResponse(
                status_code=429,
                content={"detail": exc.message, "retry_after": exc.retry_after},
                headers=exc.limit_info.to_headers() if exc.limit_info else {},
            )

        @app.get("/basic")
        @rate_limit(3, 60)
        async def basic_endpoint(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        @app.get("/no-headers")
        @rate_limit(3, window_size=60, include_headers=False)
        async def no_headers_endpoint(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        @app.get("/custom-message")
        @rate_limit(2, window_size=60, error_message="Custom rate limit message")
        async def custom_message_endpoint(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        return app

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Create an in-process test client for ``app``."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client

    async def _consume(self, client: AsyncClient, path: str, count: int) -> None:
        """Send ``count`` requests to ``path``, asserting each one is allowed.

        Priming the limit through this helper means a test cannot pass
        vacuously because an earlier request was already rejected.
        """
        for i in range(count):
            response = await client.get(path)
            assert response.status_code == 200, f"Request {i} should succeed"

    async def test_allows_requests_within_limit(self, client: AsyncClient) -> None:
        """Test that requests within limit are allowed."""
        await self._consume(client, "/basic", 3)

    async def test_blocks_requests_over_limit(self, client: AsyncClient) -> None:
        """Test that every request over the limit stays blocked within the window."""
        await self._consume(client, "/basic", 3)

        # Not just the first over-limit request: all further requests in the
        # 60s window must be rejected too.
        for i in range(3):
            response = await client.get("/basic")
            assert response.status_code == 429, f"Over-limit request {i} should be blocked"

    async def test_rate_limit_enforced(self, client: AsyncClient) -> None:
        """Test that rate limit is enforced."""
        # Use up the limit
        await self._consume(client, "/basic", 3)

        # Next request should be blocked
        response = await client.get("/basic")
        assert response.status_code == 429

    async def test_headers_excluded_when_disabled(self, client: AsyncClient) -> None:
        """Test that headers are excluded when include_headers=False."""
        response = await client.get("/no-headers")
        assert response.status_code == 200
        assert "X-RateLimit-Limit" not in response.headers

    async def test_limit_enforced_when_headers_disabled(
        self, client: AsyncClient
    ) -> None:
        """Test that include_headers=False only drops headers, not the limit itself."""
        await self._consume(client, "/no-headers", 3)

        response = await client.get("/no-headers")
        assert response.status_code == 429

    async def test_custom_error_message(self, client: AsyncClient) -> None:
        """Test custom error message is used."""
        await self._consume(client, "/custom-message", 2)

        response = await client.get("/custom-message")
        assert response.status_code == 429
        data = response.json()
        assert data["detail"] == "Custom rate limit message"

    async def test_retry_after_header_on_limit(self, client: AsyncClient) -> None:
        """Test Retry-After header is set when rate limited."""
        await self._consume(client, "/basic", 3)

        response = await client.get("/basic")
        assert response.status_code == 429
        assert "Retry-After" in response.headers


class TestCustomKeyExtractor:
    """Rate limiting keyed on a caller-supplied ``X-API-Key`` header."""

    @pytest.fixture
    async def setup_limiter(self) -> AsyncGenerator[RateLimiter, None]:
        """Yield a freshly initialized in-memory limiter registered globally."""
        active = RateLimiter(MemoryBackend())
        await active.initialize()
        set_limiter(active)
        yield active
        await active.close()

    @pytest.fixture
    def app(self, setup_limiter: RateLimiter) -> FastAPI:
        """Build an app whose single endpoint is limited per API key."""

        def api_key_extractor(request: Request) -> str:
            # Requests without the header all share one "anonymous" bucket.
            return request.headers.get("X-API-Key", "anonymous")

        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        application = FastAPI()
        application.add_exception_handler(RateLimitExceeded, handler)

        @application.get("/by-api-key")
        @rate_limit(2, window_size=60, key_extractor=api_key_extractor)
        async def by_api_key(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        return application

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Yield an HTTP client that talks to ``app`` in-process."""
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http_client:
            yield http_client

    async def test_different_keys_have_separate_limits(
        self, client: AsyncClient
    ) -> None:
        """Exhausting one API key's quota must not affect a different key."""
        key_a = {"X-API-Key": "key-a"}
        key_b = {"X-API-Key": "key-b"}

        # Two allowed calls, then the third for the same key is rejected.
        codes_a = [
            (await client.get("/by-api-key", headers=key_a)).status_code
            for _ in range(3)
        ]
        assert codes_a == [200, 200, 429]

        # key-b has its own, untouched budget.
        reply_b = await client.get("/by-api-key", headers=key_b)
        assert reply_b.status_code == 200

    async def test_anonymous_key_for_missing_header(self, client: AsyncClient) -> None:
        """Requests lacking the header are limited together under "anonymous"."""
        codes = [(await client.get("/by-api-key")).status_code for _ in range(3)]
        assert codes == [200, 200, 429]


class TestExemptionCallback:
    """Behaviour of the ``exempt_when`` hook on ``@rate_limit``."""

    @pytest.fixture
    async def setup_limiter(self) -> AsyncGenerator[RateLimiter, None]:
        """Provide an initialized in-memory limiter as the global one."""
        store = MemoryBackend()
        current = RateLimiter(store)
        await current.initialize()
        set_limiter(current)
        yield current
        await current.close()

    @pytest.fixture
    def app(self, setup_limiter: RateLimiter) -> FastAPI:
        """Build an app whose endpoint skips limiting for admin callers."""
        application = FastAPI()

        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        application.add_exception_handler(RateLimitExceeded, handler)

        def is_admin(request: Request) -> bool:
            # Callers presenting the shared admin token are exempt.
            return request.headers.get("X-Admin-Token") == "secret"

        @application.get("/with-exemption")
        @rate_limit(2, window_size=60, exempt_when=is_admin)
        async def with_exemption(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        return application

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Yield an in-process HTTP client bound to ``app``."""
        asgi = ASGITransport(app=app)
        async with AsyncClient(transport=asgi, base_url="http://test") as http_client:
            yield http_client

    async def test_exempt_requests_bypass_limit(self, client: AsyncClient) -> None:
        """Admin requests keep succeeding well past the configured limit."""
        admin_headers = {"X-Admin-Token": "secret"}
        codes = [
            (await client.get("/with-exemption", headers=admin_headers)).status_code
            for _ in range(5)
        ]
        assert codes == [200] * 5

    async def test_non_exempt_requests_are_limited(self, client: AsyncClient) -> None:
        """Callers without the admin token are held to the normal limit."""
        codes = [(await client.get("/with-exemption")).status_code for _ in range(3)]
        assert codes == [200, 200, 429]


class TestCostParameter:
    """How the ``cost`` argument scales each request's share of the limit."""

    @pytest.fixture
    async def setup_limiter(self) -> AsyncGenerator[RateLimiter, None]:
        """Register a fresh, initialized in-memory limiter for the test."""
        rl = RateLimiter(MemoryBackend())
        await rl.initialize()
        set_limiter(rl)
        yield rl
        await rl.close()

    @pytest.fixture
    def app(self, setup_limiter: RateLimiter) -> FastAPI:
        """Build an app with a cheap and an expensive endpoint sharing limit 10."""
        application = FastAPI()

        async def handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
            return JSONResponse(status_code=429, content={"detail": exc.message})

        application.add_exception_handler(RateLimitExceeded, handler)

        # Each call consumes one unit: ten calls fit in the limit.
        @application.get("/low-cost")
        @rate_limit(10, window_size=60, cost=1)
        async def low_cost(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        # Each call consumes five units: only two calls fit in the limit.
        @application.get("/high-cost")
        @rate_limit(10, window_size=60, cost=5)
        async def high_cost(request: Request) -> dict[str, str]:
            return {"status": "ok"}

        return application

    @pytest.fixture
    async def client(self, app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
        """Yield an HTTP client that drives ``app`` without a network."""
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http_client:
            yield http_client

    async def test_low_cost_allows_more_requests(self, client: AsyncClient) -> None:
        """A cost of 1 lets the full limit of ten requests through."""
        codes = [(await client.get("/low-cost")).status_code for _ in range(10)]
        for i, code in enumerate(codes):
            assert code == 200, f"Request {i} should succeed"

        overflow = await client.get("/low-cost")
        assert overflow.status_code == 429

    async def test_high_cost_allows_fewer_requests(self, client: AsyncClient) -> None:
        """A cost of 5 exhausts the limit of ten after just two requests."""
        codes = [(await client.get("/high-cost")).status_code for _ in range(2)]
        for i, code in enumerate(codes):
            assert code == 200, f"Request {i} should succeed"

        overflow = await client.get("/high-cost")
        assert overflow.status_code == 429