Files
SummaryDiscordBot/discord_bot/sdk/client.py
2025-12-12 15:31:27 +00:00

375 lines
12 KiB
Python

"""Async client for the Summarizer API."""
from __future__ import annotations
from types import TracebackType
from typing import Any, Self
import httpx
from .exceptions import (
SummarizerAPIError,
SummarizerBadRequestError,
SummarizerConnectionError,
SummarizerRateLimitError,
SummarizerServerError,
SummarizerTimeoutError,
)
from .models import (
CacheEntry,
CacheStats,
DeleteCacheRequest,
ErrorResponse,
HealthResponse,
LlmInfo,
LlmProvider,
OcrSummarizeResponse,
SuccessResponse,
SummarizationStyle,
SummarizeRequest,
SummarizeResponse,
)
# Default API root; points at the Summarizer service's local development bind.
DEFAULT_BASE_URL = "http://127.0.0.1:3001/api"
# Default per-request timeout in seconds; LLM summarization calls can be slow.
DEFAULT_TIMEOUT = 60.0
class SummarizerClient:
    """Async client for the Summarizer API.

    Wraps an ``httpx.AsyncClient`` and must be used as an async context
    manager so the underlying HTTP session is opened and closed properly;
    calling request methods outside the context raises ``RuntimeError``.

    Usage:
        async with SummarizerClient() as client:
            response = await client.summarize(
                text="Long text to summarize...",
                provider=LlmProvider.ANTHROPIC,
                guild_id="123456789",
                user_id="987654321",
            )
            print(response.summary)
    """
def __init__(
    self,
    base_url: str = DEFAULT_BASE_URL,
    timeout: float = DEFAULT_TIMEOUT,
    headers: dict[str, str] | None = None,
) -> None:
    """Initialize the client.

    Args:
        base_url: Base URL for the API (default: http://127.0.0.1:3001/api).
            A trailing slash is stripped so relative paths join predictably.
        timeout: Request timeout in seconds (default: 60.0)
        headers: Additional headers to include in requests; copied so later
            mutation by the caller cannot change this client's behavior.
    """
    self._base_url = base_url.rstrip("/")
    self._timeout = timeout
    # Defensive copy: storing the caller's dict by reference would let
    # external mutation silently alter every future request's headers.
    self._headers = dict(headers) if headers else {}
    # Created lazily in __aenter__; None means "not inside an async context".
    self._client: httpx.AsyncClient | None = None
async def __aenter__(self) -> Self:
    """Open the underlying HTTP session and return this client."""
    session = httpx.AsyncClient(
        base_url=self._base_url,
        headers=self._headers,
        timeout=self._timeout,
    )
    self._client = session
    return self
async def __aexit__(
    self,
    exc_type: type[BaseException] | None,
    exc_val: BaseException | None,
    exc_tb: TracebackType | None,
) -> None:
    """Close the HTTP session (if one is open) on context-manager exit."""
    if self._client is None:
        return
    await self._client.aclose()
    self._client = None
def _ensure_client(self) -> httpx.AsyncClient:
    """Return the live HTTP session, failing fast outside the context manager."""
    client = self._client
    if client is not None:
        return client
    msg = (
        "Client not initialized. Use 'async with SummarizerClient() as client:'"
    )
    raise RuntimeError(msg)
async def _handle_response(self, response: httpx.Response) -> httpx.Response:
    """Handle HTTP response and raise appropriate exceptions.

    Args:
        response: Raw HTTP response from the API.

    Returns:
        The response unchanged when the status indicates success.

    Raises:
        SummarizerBadRequestError: On HTTP 400.
        SummarizerRateLimitError: On HTTP 429.
        SummarizerServerError: On HTTP 5xx.
        SummarizerAPIError: On any other non-success status.
    """
    # Accept the whole 2xx range, not just 200: a 201/204 from the API
    # would otherwise be misreported as an error.
    if 200 <= response.status_code < 300:
        return response
    try:
        # Prefer the API's structured error payload when it parses.
        error_data = response.json()
        error_response = ErrorResponse.model_validate(error_data)
        error_message = error_response.error
    except Exception:  # noqa: BLE001
        # Fall back to the raw body (or a generic label) on malformed errors.
        error_message = response.text or f"HTTP {response.status_code}"
    if response.status_code == 400:
        raise SummarizerBadRequestError(error_message, response)
    if response.status_code == 429:
        raise SummarizerRateLimitError(error_message)
    if response.status_code >= 500:
        raise SummarizerServerError(error_message, response.status_code)
    raise SummarizerAPIError(error_message, response.status_code)
async def _request(
    self,
    method: str,
    path: str,
    *,
    json: dict[str, Any] | None = None,
    params: dict[str, str | int | float] | None = None,
    content: bytes | None = None,
    content_type: str | None = None,
) -> httpx.Response:
    """Send one HTTP request and translate transport errors to SDK exceptions.

    Args:
        method: HTTP verb, e.g. "GET" or "POST".
        path: Path relative to the client's base URL.
        json: Optional JSON request body.
        params: Optional query parameters.
        content: Optional raw request body bytes.
        content_type: Optional Content-Type header for raw bodies.

    Returns:
        The validated, successful HTTP response.
    """
    client = self._ensure_client()
    # Only set an explicit Content-Type when the caller supplied one.
    extra_headers = {"Content-Type": content_type} if content_type else None
    try:
        raw = await client.request(
            method,
            path,
            json=json,
            params=params,
            content=content,
            headers=extra_headers,
        )
        return await self._handle_response(raw)
    except httpx.ConnectError as e:
        raise SummarizerConnectionError(str(e)) from e
    except httpx.TimeoutException as e:
        raise SummarizerTimeoutError(str(e)) from e
    except httpx.HTTPError as e:
        raise SummarizerAPIError(str(e)) from e
async def summarize(
    self,
    text: str,
    provider: LlmProvider,
    guild_id: str,
    user_id: str,
    *,
    model: str | None = None,
    temperature: float | None = None,
    max_tokens: int | None = None,
    top_p: float | None = None,
    style: SummarizationStyle | None = None,
    system_prompt: str | None = None,
    channel_id: str | None = None,
) -> SummarizeResponse:
    """Summarize text using the specified LLM provider.

    Args:
        text: The text to summarize.
        provider: LLM provider (anthropic, ollama, or lmstudio).
        guild_id: Discord guild ID.
        user_id: Discord user ID.
        model: Model name (optional, uses default from config).
        temperature: Temperature (0.0 to 2.0).
        max_tokens: Max tokens for the response.
        top_p: Top P sampling (0.0 to 1.0).
        style: Predefined summarization style.
        system_prompt: Custom system prompt (ignored if style is set).
        channel_id: Discord channel ID.

    Returns:
        SummarizeResponse with the generated summary.
    """
    # Validate through the request model, then drop unset fields so the
    # API only sees explicitly provided options.
    payload = SummarizeRequest(
        text=text,
        provider=provider,
        guild_id=guild_id,
        user_id=user_id,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        style=style,
        system_prompt=system_prompt,
        channel_id=channel_id,
    ).model_dump(exclude_none=True)
    raw = await self._request("POST", "/summarize", json=payload)
    return SummarizeResponse.model_validate(raw.json())
async def ocr_summarize(
    self,
    image_data: bytes,
    provider: LlmProvider,
    guild_id: str,
    user_id: str,
    *,
    model: str | None = None,
    temperature: float | None = None,
    max_tokens: int | None = None,
    top_p: float | None = None,
    style: SummarizationStyle | None = None,
    channel_id: str | None = None,
) -> OcrSummarizeResponse:
    """OCR and summarize an image.

    Args:
        image_data: Raw image bytes.
        provider: LLM provider.
        guild_id: Discord guild ID.
        user_id: Discord user ID.
        model: Model name (optional).
        temperature: Temperature (0.0 to 2.0).
        max_tokens: Max tokens for the response.
        top_p: Top P sampling (0.0 to 1.0).
        style: Predefined summarization style.
        channel_id: Discord channel ID.

    Returns:
        OcrSummarizeResponse with extracted text and summary.
    """
    # Required query parameters first; optional ones are filtered below so
    # unset options never reach the query string.
    query: dict[str, str | int | float] = {
        "provider": provider.value,
        "guild_id": guild_id,
        "user_id": user_id,
    }
    optional: dict[str, str | int | float | None] = {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
        "style": style.value if style is not None else None,
        "channel_id": channel_id,
    }
    query.update({key: val for key, val in optional.items() if val is not None})
    raw = await self._request(
        "POST",
        "/ocr-summarize",
        params=query,
        content=image_data,
        content_type="application/octet-stream",
    )
    return OcrSummarizeResponse.model_validate(raw.json())
async def list_models(self) -> list[LlmInfo]:
    """Fetch the LLM models the API currently exposes.

    Returns:
        List of available LLM models.
    """
    raw = await self._request("GET", "/models")
    return [LlmInfo.model_validate(entry) for entry in raw.json()]
async def get_cache_stats(self) -> CacheStats:
    """Fetch aggregate cache metrics.

    Returns:
        CacheStats with cache metrics.
    """
    raw = await self._request("GET", "/cache/stats")
    return CacheStats.model_validate(raw.json())
async def list_cache_entries(
    self,
    *,
    limit: int | None = None,
    offset: int | None = None,
) -> list[CacheEntry]:
    """List cache entries with pagination.

    Args:
        limit: Maximum number of entries to return.
        offset: Number of entries to skip.

    Returns:
        List of cache entries.
    """
    # Only forward pagination options the caller actually set.
    query: dict[str, str | int | float] = {}
    for key, val in (("limit", limit), ("offset", offset)):
        if val is not None:
            query[key] = val
    raw = await self._request("GET", "/cache/entries", params=query)
    return [CacheEntry.model_validate(entry) for entry in raw.json()]
async def get_guild_cache(self, guild_id: str) -> list[CacheEntry]:
    """Fetch cache entries scoped to one guild.

    Args:
        guild_id: Discord guild ID.

    Returns:
        List of cache entries for the guild.
    """
    raw = await self._request("GET", f"/cache/guild/{guild_id}")
    return [CacheEntry.model_validate(entry) for entry in raw.json()]
async def get_user_cache(self, user_id: str) -> list[CacheEntry]:
    """Fetch cache entries scoped to one user.

    Args:
        user_id: Discord user ID.

    Returns:
        List of cache entries for the user.
    """
    raw = await self._request("GET", f"/cache/user/{user_id}")
    return [CacheEntry.model_validate(entry) for entry in raw.json()]
async def delete_cache(
    self,
    *,
    entry_id: int | None = None,
    guild_id: str | None = None,
    user_id: str | None = None,
    delete_all: bool = False,
) -> SuccessResponse:
    """Delete cache entries.

    Args:
        entry_id: Specific cache entry ID to delete.
        guild_id: Delete all entries for a guild.
        user_id: Delete all entries for a user.
        delete_all: Delete all entries (use with caution).

    Returns:
        SuccessResponse confirming deletion.
    """
    # delete_all is omitted from the payload entirely unless it is True,
    # matching the API's expectation of absent-or-truthy.
    body = DeleteCacheRequest(
        id=entry_id,
        guild_id=guild_id,
        user_id=user_id,
        delete_all=True if delete_all else None,
    )
    raw = await self._request(
        "POST",
        "/cache/delete",
        json=body.model_dump(exclude_none=True),
    )
    return SuccessResponse.model_validate(raw.json())
async def health_check(self) -> HealthResponse:
    """Check API health.

    Returns:
        HealthResponse with status and version.
    """
    raw = await self._request("GET", "/health")
    return HealthResponse.model_validate(raw.json())