476 lines
15 KiB
Python
476 lines
15 KiB
Python
"""Tests for rate limiting algorithms.
|
|
|
|
Comprehensive tests covering:
|
|
- Basic allow/block behavior
|
|
- Limit boundaries and edge cases
|
|
- Token refill and window reset timing
|
|
- Concurrent access patterns
|
|
- State inspection (get_state) and reset
|
|
- Different key isolation
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import TYPE_CHECKING
|
|
|
|
import pytest
|
|
|
|
from fastapi_traffic.backends.memory import MemoryBackend
|
|
from fastapi_traffic.core.algorithms import (
|
|
Algorithm,
|
|
FixedWindowAlgorithm,
|
|
LeakyBucketAlgorithm,
|
|
SlidingWindowAlgorithm,
|
|
SlidingWindowCounterAlgorithm,
|
|
TokenBucketAlgorithm,
|
|
get_algorithm,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import AsyncGenerator
|
|
|
|
|
|
@pytest.fixture
async def backend() -> AsyncGenerator[MemoryBackend, None]:
    """Provide a fresh in-memory backend to each test and close it on teardown."""
    memory_backend = MemoryBackend()
    yield memory_backend
    # Code after the yield is pytest fixture teardown, run once the test ends.
    await memory_backend.close()
|
|
|
|
|
|
@pytest.mark.asyncio
class TestTokenBucketAlgorithm:
    """Tests for TokenBucketAlgorithm."""

    async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None:
        """Test that every request up to the limit is allowed on a single key."""
        algo = TokenBucketAlgorithm(10, 60.0, backend)

        # All requests share one key so the bucket is actually drained down to
        # its limit; alternating between two keys would only consume half of
        # each bucket and never reach the boundary this test is about.
        for i in range(10):
            allowed, _ = await algo.check("test_key")
            assert allowed, f"Request {i} should be allowed"

    async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None:
        """Test that the request after the limit is blocked with a retry hint."""
        algo = TokenBucketAlgorithm(3, 60.0, backend)

        for _ in range(3):
            allowed, _ = await algo.check("test_key")
            assert allowed

        allowed, info = await algo.check("test_key")
        assert not allowed
        assert info.retry_after is not None
        assert info.retry_after > 0

    async def test_reset(self, backend: MemoryBackend) -> None:
        """Test that reset() makes a drained key usable again."""
        algo = TokenBucketAlgorithm(3, 60.0, backend)

        for _ in range(3):
            await algo.check("test_key")

        allowed, _ = await algo.check("test_key")
        assert not allowed

        await algo.reset("test_key")

        allowed, _ = await algo.check("test_key")
        assert allowed
|
|
|
|
|
|
@pytest.mark.asyncio
class TestSlidingWindowAlgorithm:
    """Basic allow/block behaviour of SlidingWindowAlgorithm."""

    async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None:
        """The first ``limit`` requests on a fresh key are all admitted."""
        limiter = SlidingWindowAlgorithm(5, 60.0, backend)

        for _attempt in range(5):
            is_allowed, _info = await limiter.check("test_key")
            assert is_allowed

    async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None:
        """The request after the limit is rejected and reports nothing remaining."""
        limit = 3
        limiter = SlidingWindowAlgorithm(limit, 60.0, backend)

        for _attempt in range(limit):
            is_allowed, _info = await limiter.check("test_key")
            assert is_allowed

        is_allowed, info = await limiter.check("test_key")
        assert not is_allowed
        assert info.remaining == 0
|
|
|
|
|
|
@pytest.mark.asyncio
class TestFixedWindowAlgorithm:
    """Basic allow/block behaviour of FixedWindowAlgorithm."""

    async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None:
        """A fresh key admits every request up to its limit."""
        limiter = FixedWindowAlgorithm(5, 60.0, backend)

        for _attempt in range(5):
            is_allowed, _info = await limiter.check("test_key")
            assert is_allowed

    async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None:
        """Once the window's quota is used up, the next request is refused."""
        limit = 3
        limiter = FixedWindowAlgorithm(limit, 60.0, backend)

        for _attempt in range(limit):
            is_allowed, _info = await limiter.check("test_key")
            assert is_allowed

        is_allowed, info = await limiter.check("test_key")
        assert not is_allowed
        assert info.remaining == 0
|
|
|
|
|
|
@pytest.mark.asyncio
class TestLeakyBucketAlgorithm:
    """Basic behaviour of LeakyBucketAlgorithm."""

    async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None:
        """A fresh key accepts ``limit`` back-to-back requests."""
        limiter = LeakyBucketAlgorithm(5, 60.0, backend)

        for _attempt in range(5):
            is_allowed, _info = await limiter.check("test_key")
            assert is_allowed

    async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None:
        """After a burst of ``limit`` requests, the next check reports the limit."""
        limiter = LeakyBucketAlgorithm(3, 60.0, backend)

        # A burst of `limit` requests on a fresh key is admitted.
        for _attempt in range(3):
            is_allowed, _info = await limiter.check("test_key")
            assert is_allowed

        # NOTE(review): despite the test name, the allow/block outcome of this
        # fourth request is intentionally not asserted — whether it is blocked
        # depends on how much the bucket leaks between calls. Confirm the
        # implementation's semantics before tightening to `assert not ...`.
        _outcome, info = await limiter.check("test_key")
        assert info.limit == 3
|
|
|
|
|
|
@pytest.mark.asyncio
class TestSlidingWindowCounterAlgorithm:
    """Basic allow/block behaviour of SlidingWindowCounterAlgorithm."""

    async def test_allows_requests_within_limit(self, backend: MemoryBackend) -> None:
        """Requests up to the configured limit are admitted on a fresh key."""
        limiter = SlidingWindowCounterAlgorithm(5, 60.0, backend)

        for _attempt in range(5):
            is_allowed, _info = await limiter.check("test_key")
            assert is_allowed

    async def test_blocks_requests_over_limit(self, backend: MemoryBackend) -> None:
        """The first request beyond the limit is refused."""
        limit = 3
        limiter = SlidingWindowCounterAlgorithm(limit, 60.0, backend)

        for _attempt in range(limit):
            is_allowed, _info = await limiter.check("test_key")
            assert is_allowed

        is_allowed, _info = await limiter.check("test_key")
        assert not is_allowed
|
|
|
|
|
|
@pytest.mark.asyncio
class TestGetAlgorithm:
    """The get_algorithm factory maps each Algorithm member to its class."""

    @staticmethod
    def _make(kind: Algorithm, backend: MemoryBackend) -> object:
        """Build an algorithm of *kind* with the limit/window shared by all cases."""
        return get_algorithm(kind, 10, 60.0, backend)

    async def test_get_token_bucket(self, backend: MemoryBackend) -> None:
        """TOKEN_BUCKET yields a TokenBucketAlgorithm."""
        built = self._make(Algorithm.TOKEN_BUCKET, backend)
        assert isinstance(built, TokenBucketAlgorithm)

    async def test_get_sliding_window(self, backend: MemoryBackend) -> None:
        """SLIDING_WINDOW yields a SlidingWindowAlgorithm."""
        built = self._make(Algorithm.SLIDING_WINDOW, backend)
        assert isinstance(built, SlidingWindowAlgorithm)

    async def test_get_fixed_window(self, backend: MemoryBackend) -> None:
        """FIXED_WINDOW yields a FixedWindowAlgorithm."""
        built = self._make(Algorithm.FIXED_WINDOW, backend)
        assert isinstance(built, FixedWindowAlgorithm)

    async def test_get_leaky_bucket(self, backend: MemoryBackend) -> None:
        """LEAKY_BUCKET yields a LeakyBucketAlgorithm."""
        built = self._make(Algorithm.LEAKY_BUCKET, backend)
        assert isinstance(built, LeakyBucketAlgorithm)

    async def test_get_sliding_window_counter(self, backend: MemoryBackend) -> None:
        """SLIDING_WINDOW_COUNTER yields a SlidingWindowCounterAlgorithm."""
        built = self._make(Algorithm.SLIDING_WINDOW_COUNTER, backend)
        assert isinstance(built, SlidingWindowCounterAlgorithm)
|
|
|
|
|
|
@pytest.mark.asyncio
class TestTokenBucketAdvanced:
    """Refill timing, burst sizing, key isolation, concurrency and info fields
    for TokenBucketAlgorithm."""

    async def test_token_refill_over_time(self, backend: MemoryBackend) -> None:
        """A drained bucket admits a request again after a short wait."""
        # With a 1.0s window the 0.3s sleep below is expected to refill at
        # least one token.
        bucket = TokenBucketAlgorithm(5, 1.0, backend)

        for _round in range(5):
            ok, _info = await bucket.check("refill_key")
            assert ok

        ok, _info = await bucket.check("refill_key")
        assert not ok

        await asyncio.sleep(0.3)

        ok, _info = await bucket.check("refill_key")
        assert ok

    async def test_burst_size_configuration(self, backend: MemoryBackend) -> None:
        """burst_size caps how many requests pass before the first rejection."""
        bucket = TokenBucketAlgorithm(100, 60.0, backend, burst_size=5)

        for i in range(5):
            ok, _info = await bucket.check("burst_key")
            assert ok, f"Request {i} should be allowed"

        ok, _info = await bucket.check("burst_key")
        assert not ok

    async def test_key_isolation(self, backend: MemoryBackend) -> None:
        """Draining one key leaves a different key's bucket untouched."""
        bucket = TokenBucketAlgorithm(3, 60.0, backend)

        for _round in range(3):
            await bucket.check("key_a")

        a_ok, _info = await bucket.check("key_a")
        assert not a_ok

        b_ok, _info = await bucket.check("key_b")
        assert b_ok

    async def test_concurrent_requests(self, backend: MemoryBackend) -> None:
        """Of 15 concurrently scheduled checks on one key, exactly 10 succeed."""
        bucket = TokenBucketAlgorithm(10, 60.0, backend)

        async def attempt() -> bool:
            ok, _info = await bucket.check("concurrent_key")
            return ok

        outcomes = await asyncio.gather(*(attempt() for _ in range(15)))
        allowed_total = sum(outcomes)
        assert allowed_total == 10

    async def test_rate_limit_info_accuracy(self, backend: MemoryBackend) -> None:
        """limit, remaining and retry_after reflect the bucket's state."""
        bucket = TokenBucketAlgorithm(5, 60.0, backend)

        ok, info = await bucket.check("info_key")
        assert ok
        assert info.limit == 5
        assert info.remaining == 4

        # Spend the four tokens that are left.
        for _round in range(4):
            await bucket.check("info_key")

        ok, info = await bucket.check("info_key")
        assert not ok
        assert info.remaining == 0
        assert info.retry_after is not None
        assert info.retry_after > 0
|
|
|
|
|
|
@pytest.mark.asyncio
class TestSlidingWindowAdvanced:
    """Advanced tests for SlidingWindowAlgorithm."""

    async def test_window_expiration(self, backend: MemoryBackend) -> None:
        """Test that old requests expire from the window."""
        algo = SlidingWindowAlgorithm(3, 0.5, backend)

        for _ in range(3):
            allowed, _ = await algo.check("expire_key")
            assert allowed

        allowed, _ = await algo.check("expire_key")
        assert not allowed

        # Sleep past the 0.5s window so every earlier request has aged out.
        await asyncio.sleep(0.6)

        allowed, _ = await algo.check("expire_key")
        assert allowed

    async def test_sliding_behavior(self, backend: MemoryBackend) -> None:
        """Test that the window slides per request instead of resetting wholesale.

        Timeline (limit 2, window 1.0s):
          t~0.0  request A allowed
          t~0.5  request B allowed; the next request is blocked
          t~1.1  A has aged out, freeing one slot -> allowed;
                 B (t~0.5) is still inside the window -> next request blocked
        An implementation that simply cleared its count after one window would
        admit both requests at t~1.1, so the final assertion is what makes
        this a test of *sliding* behaviour.
        """
        algo = SlidingWindowAlgorithm(2, 1.0, backend)

        allowed, _ = await algo.check("slide_key")  # A
        assert allowed

        await asyncio.sleep(0.5)

        allowed, _ = await algo.check("slide_key")  # B
        assert allowed

        allowed, _ = await algo.check("slide_key")
        assert not allowed

        # Total elapsed ~1.1s: A is outside the 1.0s window, B is not (it only
        # expires around t~1.5, leaving a generous margin for timer jitter).
        await asyncio.sleep(0.6)

        allowed, _ = await algo.check("slide_key")
        assert allowed

        allowed, _ = await algo.check("slide_key")
        assert not allowed
|
|
|
|
|
|
@pytest.mark.asyncio
class TestFixedWindowAdvanced:
    """Window-rollover tests for FixedWindowAlgorithm."""

    async def test_window_boundary_reset(self, backend: MemoryBackend) -> None:
        """The request count starts over once the window has elapsed."""
        limiter = FixedWindowAlgorithm(3, 0.5, backend)

        for _attempt in range(3):
            is_allowed, _info = await limiter.check("boundary_key")
            assert is_allowed

        is_allowed, _info = await limiter.check("boundary_key")
        assert not is_allowed

        # Sleep past the 0.5s window.
        await asyncio.sleep(0.6)

        is_allowed, _info = await limiter.check("boundary_key")
        assert is_allowed

    async def test_multiple_windows(self, backend: MemoryBackend) -> None:
        """A fresh window grants a full quota after the previous one is exhausted."""
        limiter = FixedWindowAlgorithm(2, 0.3, backend)

        for _attempt in range(2):
            is_allowed, _info = await limiter.check("multi_key")
            assert is_allowed

        is_allowed, _info = await limiter.check("multi_key")
        assert not is_allowed

        # Sleep past the 0.3s window so the next requests land in a new one.
        await asyncio.sleep(0.35)

        for _attempt in range(2):
            is_allowed, _info = await limiter.check("multi_key")
            assert is_allowed
|
|
|
|
|
|
@pytest.mark.asyncio
class TestLeakyBucketAdvanced:
    """Advanced tests for LeakyBucketAlgorithm."""

    async def test_leak_rate(self, backend: MemoryBackend) -> None:
        """Test that a filled bucket leaks over time and admits a new request."""
        algo = LeakyBucketAlgorithm(3, 1.0, backend)

        # Fill the bucket with `limit` requests.
        for _ in range(3):
            allowed, _ = await algo.check("leak_key")
            assert allowed

        # Let half of the 1.0s window pass so part of the bucket drains.
        await asyncio.sleep(0.5)

        # The drained capacity must admit another request; asserting `allowed`
        # is what lets this test fail if leaking does not happen.
        allowed, info = await algo.check("leak_key")
        assert allowed
        assert info.limit == 3

    async def test_steady_rate_enforcement(self, backend: MemoryBackend) -> None:
        """Test that each of `limit` requests is admitted and reports the limit."""
        algo = LeakyBucketAlgorithm(5, 1.0, backend)

        for _ in range(5):
            allowed, info = await algo.check("steady_key")
            assert allowed
            assert info.limit == 5
|
|
|
|
|
|
@pytest.mark.asyncio
class TestSlidingWindowCounterAdvanced:
    """Timing-dependent tests for SlidingWindowCounterAlgorithm."""

    async def test_weighted_counting(self, backend: MemoryBackend) -> None:
        """A key with spare capacity still admits requests part-way into a window."""
        counter = SlidingWindowCounterAlgorithm(10, 1.0, backend)

        for _round in range(8):
            ok, _info = await counter.check("weighted_key")
            assert ok

        await asyncio.sleep(0.6)

        ok, info = await counter.check("weighted_key")
        assert ok
        assert info.remaining > 0

    async def test_precision_vs_fixed_window(self, backend: MemoryBackend) -> None:
        """A key rejected at its limit is admitted again after a full window.

        Despite the name, this test does not compare against
        FixedWindowAlgorithm; it checks block-then-recover behaviour only.
        """
        counter = SlidingWindowCounterAlgorithm(4, 1.0, backend)

        for _round in range(4):
            ok, _info = await counter.check("precision_key")
            assert ok

        ok, _info = await counter.check("precision_key")
        assert not ok

        # Sleep slightly longer than the 1.0s window so the earlier burst has
        # largely aged out of the weighted count before retrying.
        await asyncio.sleep(1.1)

        ok, _info = await counter.check("precision_key")
        assert ok
|
|
|
|
|
|
@pytest.mark.asyncio
class TestAlgorithmStateManagement:
    """get_state() and reset() behaviour, exercised through TokenBucketAlgorithm."""

    async def test_get_state_without_consuming(self, backend: MemoryBackend) -> None:
        """Reading state twice reports the same remaining count both times."""
        bucket = TokenBucketAlgorithm(5, 60.0, backend)

        for _round in range(2):
            await bucket.check("state_key")

        # Repeated reads must not spend tokens, so each one sees 5 - 2 = 3.
        for _read in range(2):
            snapshot = await bucket.get_state("state_key")
            assert snapshot is not None
            assert snapshot.remaining == 3

    async def test_get_state_nonexistent_key(self, backend: MemoryBackend) -> None:
        """get_state() returns None for a key that was never checked."""
        bucket = TokenBucketAlgorithm(5, 60.0, backend)
        assert await bucket.get_state("nonexistent_key") is None

    async def test_reset_restores_full_capacity(self, backend: MemoryBackend) -> None:
        """After reset(), a drained key starts over from full capacity."""
        bucket = TokenBucketAlgorithm(5, 60.0, backend)

        for _round in range(5):
            await bucket.check("reset_key")

        ok, _info = await bucket.check("reset_key")
        assert not ok

        await bucket.reset("reset_key")

        ok, info = await bucket.check("reset_key")
        assert ok
        # The check just made spent one token out of the restored five.
        assert info.remaining == 4
|