"""Async client for the Summarizer API."""
|
|
from __future__ import annotations
|
|
|
|
from types import TracebackType
|
|
from typing import Any, Self
|
|
|
|
import httpx
|
|
|
|
from .exceptions import (
|
|
SummarizerAPIError,
|
|
SummarizerBadRequestError,
|
|
SummarizerConnectionError,
|
|
SummarizerRateLimitError,
|
|
SummarizerServerError,
|
|
SummarizerTimeoutError,
|
|
)
|
|
from .models import (
|
|
CacheEntry,
|
|
CacheStats,
|
|
DeleteCacheRequest,
|
|
ErrorResponse,
|
|
HealthResponse,
|
|
LlmInfo,
|
|
LlmProvider,
|
|
OcrSummarizeResponse,
|
|
SuccessResponse,
|
|
SummarizationStyle,
|
|
SummarizeRequest,
|
|
SummarizeResponse,
|
|
)
|
|
|
|
# Default endpoint for a locally hosted Summarizer API instance.
DEFAULT_BASE_URL = "http://127.0.0.1:3001/api"
# Per-request timeout in seconds; generous because LLM summarization
# calls can take tens of seconds to complete.
DEFAULT_TIMEOUT = 60.0
|
|
|
|
|
|
class SummarizerClient:
|
|
"""Async client for the Summarizer API.
|
|
|
|
Usage:
|
|
async with SummarizerClient() as client:
|
|
response = await client.summarize(
|
|
text="Long text to summarize...",
|
|
provider=LlmProvider.ANTHROPIC,
|
|
guild_id="123456789",
|
|
user_id="987654321",
|
|
)
|
|
print(response.summary)
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
base_url: str = DEFAULT_BASE_URL,
|
|
timeout: float = DEFAULT_TIMEOUT,
|
|
headers: dict[str, str] | None = None,
|
|
) -> None:
|
|
"""Initialize the client.
|
|
|
|
Args:
|
|
base_url: Base URL for the API (default: http://127.0.0.1:3001/api)
|
|
timeout: Request timeout in seconds (default: 60.0)
|
|
headers: Additional headers to include in requests
|
|
"""
|
|
self._base_url = base_url.rstrip("/")
|
|
self._timeout = timeout
|
|
self._headers = headers or {}
|
|
self._client: httpx.AsyncClient | None = None
|
|
|
|
async def __aenter__(self) -> Self:
|
|
"""Enter async context manager and create HTTP session."""
|
|
self._client = httpx.AsyncClient(
|
|
base_url=self._base_url,
|
|
timeout=self._timeout,
|
|
headers=self._headers,
|
|
)
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: TracebackType | None,
|
|
) -> None:
|
|
"""Exit async context manager and close HTTP session."""
|
|
if self._client:
|
|
await self._client.aclose()
|
|
self._client = None
|
|
|
|
def _ensure_client(self) -> httpx.AsyncClient:
|
|
"""Ensure client is initialized."""
|
|
if self._client is None:
|
|
msg = (
|
|
"Client not initialized. Use 'async with SummarizerClient() as client:'"
|
|
)
|
|
raise RuntimeError(msg)
|
|
return self._client
|
|
|
|
async def _handle_response(self, response: httpx.Response) -> httpx.Response:
|
|
"""Handle HTTP response and raise appropriate exceptions."""
|
|
if response.status_code == 200:
|
|
return response
|
|
|
|
try:
|
|
error_data = response.json()
|
|
error_response = ErrorResponse.model_validate(error_data)
|
|
error_message = error_response.error
|
|
except Exception: # noqa: BLE001
|
|
error_message = response.text or f"HTTP {response.status_code}"
|
|
|
|
if response.status_code == 400:
|
|
raise SummarizerBadRequestError(error_message, response)
|
|
if response.status_code == 429:
|
|
raise SummarizerRateLimitError(error_message)
|
|
if response.status_code >= 500:
|
|
raise SummarizerServerError(error_message, response.status_code)
|
|
|
|
raise SummarizerAPIError(error_message, response.status_code)
|
|
|
|
async def _request(
|
|
self,
|
|
method: str,
|
|
path: str,
|
|
*,
|
|
json: dict[str, Any] | None = None,
|
|
params: dict[str, str | int | float] | None = None,
|
|
content: bytes | None = None,
|
|
content_type: str | None = None,
|
|
) -> httpx.Response:
|
|
"""Make an HTTP request."""
|
|
client = self._ensure_client()
|
|
|
|
headers: dict[str, str] = {}
|
|
if content_type:
|
|
headers["Content-Type"] = content_type
|
|
|
|
try:
|
|
response = await client.request(
|
|
method,
|
|
path,
|
|
json=json,
|
|
params=params,
|
|
content=content,
|
|
headers=headers if headers else None,
|
|
)
|
|
return await self._handle_response(response)
|
|
except httpx.ConnectError as e:
|
|
raise SummarizerConnectionError(str(e)) from e
|
|
except httpx.TimeoutException as e:
|
|
raise SummarizerTimeoutError(str(e)) from e
|
|
except httpx.HTTPError as e:
|
|
raise SummarizerAPIError(str(e)) from e
|
|
|
|
async def summarize(
|
|
self,
|
|
text: str,
|
|
provider: LlmProvider,
|
|
guild_id: str,
|
|
user_id: str,
|
|
*,
|
|
model: str | None = None,
|
|
temperature: float | None = None,
|
|
max_tokens: int | None = None,
|
|
top_p: float | None = None,
|
|
style: SummarizationStyle | None = None,
|
|
system_prompt: str | None = None,
|
|
channel_id: str | None = None,
|
|
) -> SummarizeResponse:
|
|
"""Summarize text using the specified LLM provider.
|
|
|
|
Args:
|
|
text: The text to summarize
|
|
provider: LLM provider (anthropic, ollama, or lmstudio)
|
|
guild_id: Discord guild ID
|
|
user_id: Discord user ID
|
|
model: Model name (optional, uses default from config)
|
|
temperature: Temperature (0.0 to 2.0)
|
|
max_tokens: Max tokens for the response
|
|
top_p: Top P sampling (0.0 to 1.0)
|
|
style: Predefined summarization style
|
|
system_prompt: Custom system prompt (ignored if style is set)
|
|
channel_id: Discord channel ID
|
|
|
|
Returns:
|
|
SummarizeResponse with the generated summary
|
|
"""
|
|
request = SummarizeRequest(
|
|
text=text,
|
|
provider=provider,
|
|
guild_id=guild_id,
|
|
user_id=user_id,
|
|
model=model,
|
|
temperature=temperature,
|
|
max_tokens=max_tokens,
|
|
top_p=top_p,
|
|
style=style,
|
|
system_prompt=system_prompt,
|
|
channel_id=channel_id,
|
|
)
|
|
|
|
response = await self._request(
|
|
"POST",
|
|
"/summarize",
|
|
json=request.model_dump(exclude_none=True),
|
|
)
|
|
return SummarizeResponse.model_validate(response.json())
|
|
|
|
async def ocr_summarize(
|
|
self,
|
|
image_data: bytes,
|
|
provider: LlmProvider,
|
|
guild_id: str,
|
|
user_id: str,
|
|
*,
|
|
model: str | None = None,
|
|
temperature: float | None = None,
|
|
max_tokens: int | None = None,
|
|
top_p: float | None = None,
|
|
style: SummarizationStyle | None = None,
|
|
channel_id: str | None = None,
|
|
) -> OcrSummarizeResponse:
|
|
"""OCR and summarize an image.
|
|
|
|
Args:
|
|
image_data: Raw image bytes
|
|
provider: LLM provider
|
|
guild_id: Discord guild ID
|
|
user_id: Discord user ID
|
|
model: Model name (optional)
|
|
temperature: Temperature (0.0 to 2.0)
|
|
max_tokens: Max tokens for the response
|
|
top_p: Top P sampling (0.0 to 1.0)
|
|
style: Predefined summarization style
|
|
channel_id: Discord channel ID
|
|
|
|
Returns:
|
|
OcrSummarizeResponse with extracted text and summary
|
|
"""
|
|
params: dict[str, str | int | float] = {
|
|
"provider": provider.value,
|
|
"guild_id": guild_id,
|
|
"user_id": user_id,
|
|
}
|
|
if model is not None:
|
|
params["model"] = model
|
|
if temperature is not None:
|
|
params["temperature"] = temperature
|
|
if max_tokens is not None:
|
|
params["max_tokens"] = max_tokens
|
|
if top_p is not None:
|
|
params["top_p"] = top_p
|
|
if style is not None:
|
|
params["style"] = style.value
|
|
if channel_id is not None:
|
|
params["channel_id"] = channel_id
|
|
|
|
response = await self._request(
|
|
"POST",
|
|
"/ocr-summarize",
|
|
params=params,
|
|
content=image_data,
|
|
content_type="application/octet-stream",
|
|
)
|
|
return OcrSummarizeResponse.model_validate(response.json())
|
|
|
|
async def list_models(self) -> list[LlmInfo]:
|
|
"""List available LLM models.
|
|
|
|
Returns:
|
|
List of available LLM models
|
|
"""
|
|
response = await self._request("GET", "/models")
|
|
data = response.json()
|
|
return [LlmInfo.model_validate(item) for item in data]
|
|
|
|
async def get_cache_stats(self) -> CacheStats:
|
|
"""Get cache statistics.
|
|
|
|
Returns:
|
|
CacheStats with cache metrics
|
|
"""
|
|
response = await self._request("GET", "/cache/stats")
|
|
return CacheStats.model_validate(response.json())
|
|
|
|
async def list_cache_entries(
|
|
self,
|
|
*,
|
|
limit: int | None = None,
|
|
offset: int | None = None,
|
|
) -> list[CacheEntry]:
|
|
"""List cache entries with pagination.
|
|
|
|
Args:
|
|
limit: Maximum number of entries to return
|
|
offset: Number of entries to skip
|
|
|
|
Returns:
|
|
List of cache entries
|
|
"""
|
|
params: dict[str, str | int | float] = {}
|
|
if limit is not None:
|
|
params["limit"] = limit
|
|
if offset is not None:
|
|
params["offset"] = offset
|
|
|
|
response = await self._request("GET", "/cache/entries", params=params)
|
|
data = response.json()
|
|
return [CacheEntry.model_validate(item) for item in data]
|
|
|
|
async def get_guild_cache(self, guild_id: str) -> list[CacheEntry]:
|
|
"""Get cache entries for a specific guild.
|
|
|
|
Args:
|
|
guild_id: Discord guild ID
|
|
|
|
Returns:
|
|
List of cache entries for the guild
|
|
"""
|
|
response = await self._request("GET", f"/cache/guild/{guild_id}")
|
|
data = response.json()
|
|
return [CacheEntry.model_validate(item) for item in data]
|
|
|
|
async def get_user_cache(self, user_id: str) -> list[CacheEntry]:
|
|
"""Get cache entries for a specific user.
|
|
|
|
Args:
|
|
user_id: Discord user ID
|
|
|
|
Returns:
|
|
List of cache entries for the user
|
|
"""
|
|
response = await self._request("GET", f"/cache/user/{user_id}")
|
|
data = response.json()
|
|
return [CacheEntry.model_validate(item) for item in data]
|
|
|
|
async def delete_cache(
|
|
self,
|
|
*,
|
|
entry_id: int | None = None,
|
|
guild_id: str | None = None,
|
|
user_id: str | None = None,
|
|
delete_all: bool = False,
|
|
) -> SuccessResponse:
|
|
"""Delete cache entries.
|
|
|
|
Args:
|
|
entry_id: Specific cache entry ID to delete
|
|
guild_id: Delete all entries for a guild
|
|
user_id: Delete all entries for a user
|
|
delete_all: Delete all entries (use with caution)
|
|
|
|
Returns:
|
|
SuccessResponse confirming deletion
|
|
"""
|
|
request = DeleteCacheRequest(
|
|
id=entry_id,
|
|
guild_id=guild_id,
|
|
user_id=user_id,
|
|
delete_all=delete_all if delete_all else None,
|
|
)
|
|
|
|
response = await self._request(
|
|
"POST",
|
|
"/cache/delete",
|
|
json=request.model_dump(exclude_none=True),
|
|
)
|
|
return SuccessResponse.model_validate(response.json())
|
|
|
|
async def health_check(self) -> HealthResponse:
|
|
"""Check API health.
|
|
|
|
Returns:
|
|
HealthResponse with status and version
|
|
"""
|
|
response = await self._request("GET", "/health")
|
|
return HealthResponse.model_validate(response.json())
|