"""Async client for the Summarizer API.""" from __future__ import annotations from types import TracebackType from typing import Any, Self import httpx from .exceptions import ( SummarizerAPIError, SummarizerBadRequestError, SummarizerConnectionError, SummarizerRateLimitError, SummarizerServerError, SummarizerTimeoutError, ) from .models import ( CacheEntry, CacheStats, DeleteCacheRequest, ErrorResponse, HealthResponse, LlmInfo, LlmProvider, OcrSummarizeResponse, SuccessResponse, SummarizationStyle, SummarizeRequest, SummarizeResponse, ) DEFAULT_BASE_URL = "http://127.0.0.1:3001/api" DEFAULT_TIMEOUT = 60.0 class SummarizerClient: """Async client for the Summarizer API. Usage: async with SummarizerClient() as client: response = await client.summarize( text="Long text to summarize...", provider=LlmProvider.ANTHROPIC, guild_id="123456789", user_id="987654321", ) print(response.summary) """ def __init__( self, base_url: str = DEFAULT_BASE_URL, timeout: float = DEFAULT_TIMEOUT, headers: dict[str, str] | None = None, ) -> None: """Initialize the client. Args: base_url: Base URL for the API (default: http://127.0.0.1:3001/api) timeout: Request timeout in seconds (default: 60.0) headers: Additional headers to include in requests """ self._base_url = base_url.rstrip("/") self._timeout = timeout self._headers = headers or {} self._client: httpx.AsyncClient | None = None async def __aenter__(self) -> Self: """Enter async context manager and create HTTP session.""" self._client = httpx.AsyncClient( base_url=self._base_url, timeout=self._timeout, headers=self._headers, ) return self async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: """Exit async context manager and close HTTP session.""" if self._client: await self._client.aclose() self._client = None def _ensure_client(self) -> httpx.AsyncClient: """Ensure client is initialized.""" if self._client is None: msg = ( "Client not initialized. Use 'async with SummarizerClient() as client:'" ) raise RuntimeError(msg) return self._client async def _handle_response(self, response: httpx.Response) -> httpx.Response: """Handle HTTP response and raise appropriate exceptions.""" if response.status_code == 200: return response try: error_data = response.json() error_response = ErrorResponse.model_validate(error_data) error_message = error_response.error except Exception: # noqa: BLE001 error_message = response.text or f"HTTP {response.status_code}" if response.status_code == 400: raise SummarizerBadRequestError(error_message, response) if response.status_code == 429: raise SummarizerRateLimitError(error_message) if response.status_code >= 500: raise SummarizerServerError(error_message, response.status_code) raise SummarizerAPIError(error_message, response.status_code) async def _request( self, method: str, path: str, *, json: dict[str, Any] | None = None, params: dict[str, str | int | float] | None = None, content: bytes | None = None, content_type: str | None = None, ) -> httpx.Response: """Make an HTTP request.""" client = self._ensure_client() headers: dict[str, str] = {} if content_type: headers["Content-Type"] = content_type try: response = await client.request( method, path, json=json, params=params, content=content, headers=headers if headers else None, ) return await self._handle_response(response) except httpx.ConnectError as e: raise SummarizerConnectionError(str(e)) from e except httpx.TimeoutException as e: raise SummarizerTimeoutError(str(e)) from e except httpx.HTTPError as e: raise SummarizerAPIError(str(e)) from e async def summarize( self, text: str, provider: LlmProvider, guild_id: str, user_id: str, *, model: str | None = None, temperature: float | None = None, max_tokens: int | None = None, top_p: float | None = None, style: SummarizationStyle | None = None, system_prompt: str | None = None, channel_id: str | None = None, ) -> SummarizeResponse: """Summarize text using the specified LLM provider. Args: text: The text to summarize provider: LLM provider (anthropic, ollama, or lmstudio) guild_id: Discord guild ID user_id: Discord user ID model: Model name (optional, uses default from config) temperature: Temperature (0.0 to 2.0) max_tokens: Max tokens for the response top_p: Top P sampling (0.0 to 1.0) style: Predefined summarization style system_prompt: Custom system prompt (ignored if style is set) channel_id: Discord channel ID Returns: SummarizeResponse with the generated summary """ request = SummarizeRequest( text=text, provider=provider, guild_id=guild_id, user_id=user_id, model=model, temperature=temperature, max_tokens=max_tokens, top_p=top_p, style=style, system_prompt=system_prompt, channel_id=channel_id, ) response = await self._request( "POST", "/summarize", json=request.model_dump(exclude_none=True), ) return SummarizeResponse.model_validate(response.json()) async def ocr_summarize( self, image_data: bytes, provider: LlmProvider, guild_id: str, user_id: str, *, model: str | None = None, temperature: float | None = None, max_tokens: int | None = None, top_p: float | None = None, style: SummarizationStyle | None = None, channel_id: str | None = None, ) -> OcrSummarizeResponse: """OCR and summarize an image. Args: image_data: Raw image bytes provider: LLM provider guild_id: Discord guild ID user_id: Discord user ID model: Model name (optional) temperature: Temperature (0.0 to 2.0) max_tokens: Max tokens for the response top_p: Top P sampling (0.0 to 1.0) style: Predefined summarization style channel_id: Discord channel ID Returns: OcrSummarizeResponse with extracted text and summary """ params: dict[str, str | int | float] = { "provider": provider.value, "guild_id": guild_id, "user_id": user_id, } if model is not None: params["model"] = model if temperature is not None: params["temperature"] = temperature if max_tokens is not None: params["max_tokens"] = max_tokens if top_p is not None: params["top_p"] = top_p if style is not None: params["style"] = style.value if channel_id is not None: params["channel_id"] = channel_id response = await self._request( "POST", "/ocr-summarize", params=params, content=image_data, content_type="application/octet-stream", ) return OcrSummarizeResponse.model_validate(response.json()) async def list_models(self) -> list[LlmInfo]: """List available LLM models. Returns: List of available LLM models """ response = await self._request("GET", "/models") data = response.json() return [LlmInfo.model_validate(item) for item in data] async def get_cache_stats(self) -> CacheStats: """Get cache statistics. Returns: CacheStats with cache metrics """ response = await self._request("GET", "/cache/stats") return CacheStats.model_validate(response.json()) async def list_cache_entries( self, *, limit: int | None = None, offset: int | None = None, ) -> list[CacheEntry]: """List cache entries with pagination. Args: limit: Maximum number of entries to return offset: Number of entries to skip Returns: List of cache entries """ params: dict[str, str | int | float] = {} if limit is not None: params["limit"] = limit if offset is not None: params["offset"] = offset response = await self._request("GET", "/cache/entries", params=params) data = response.json() return [CacheEntry.model_validate(item) for item in data] async def get_guild_cache(self, guild_id: str) -> list[CacheEntry]: """Get cache entries for a specific guild. Args: guild_id: Discord guild ID Returns: List of cache entries for the guild """ response = await self._request("GET", f"/cache/guild/{guild_id}") data = response.json() return [CacheEntry.model_validate(item) for item in data] async def get_user_cache(self, user_id: str) -> list[CacheEntry]: """Get cache entries for a specific user. Args: user_id: Discord user ID Returns: List of cache entries for the user """ response = await self._request("GET", f"/cache/user/{user_id}") data = response.json() return [CacheEntry.model_validate(item) for item in data] async def delete_cache( self, *, entry_id: int | None = None, guild_id: str | None = None, user_id: str | None = None, delete_all: bool = False, ) -> SuccessResponse: """Delete cache entries. Args: entry_id: Specific cache entry ID to delete guild_id: Delete all entries for a guild user_id: Delete all entries for a user delete_all: Delete all entries (use with caution) Returns: SuccessResponse confirming deletion """ request = DeleteCacheRequest( id=entry_id, guild_id=guild_id, user_id=user_id, delete_all=delete_all if delete_all else None, ) response = await self._request( "POST", "/cache/delete", json=request.model_dump(exclude_none=True), ) return SuccessResponse.model_validate(response.json()) async def health_check(self) -> HealthResponse: """Check API health. Returns: HealthResponse with status and version """ response = await self._request("GET", "/health") return HealthResponse.model_validate(response.json())