# File listing metadata: 214 lines, 7.5 KiB, Python.
"""Models for Summarizer API."""
|
|
from __future__ import annotations
|
|
|
|
from enum import StrEnum
|
|
from typing import Annotated
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
class LlmProvider(StrEnum):
|
|
"""LLM provider options."""
|
|
|
|
ANTHROPIC = "anthropic"
|
|
OLLAMA = "ollama"
|
|
LMSTUDIO = "lmstudio"
|
|
|
|
|
|
class SummarizationStyle(StrEnum):
|
|
"""Predefined summarization styles."""
|
|
|
|
BRIEF = "brief"
|
|
DETAILED = "detailed"
|
|
FUNNY = "funny"
|
|
PROFESSIONAL = "professional"
|
|
TECHNICAL = "technical"
|
|
ELI5 = "eli5"
|
|
BULLETS = "bullets"
|
|
ACADEMIC = "academic"
|
|
ROAST = "roast"
|
|
|
|
|
|
class TokenUsage(BaseModel):
    """Token usage information from LLM response.

    The three token counts are required integers. ``estimated_cost_usd`` is
    optional and defaults to ``None``.
    """

    # NOTE(review): no validator ties total_tokens to
    # input_tokens + output_tokens, and no lower bound is enforced on the
    # counts -- confirm the producer guarantees consistent values.
    input_tokens: Annotated[int, Field(description="Input tokens (prompt)")]
    output_tokens: Annotated[int, Field(description="Output tokens (completion)")]
    total_tokens: Annotated[int, Field(description="Total tokens used")]
    estimated_cost_usd: Annotated[
        float | None, Field(default=None, description="Estimated cost in USD")
    ]
|
|
|
|
|
|
class ImageInfo(BaseModel):
    """Image metadata from OCR processing.

    All four fields are required. ``format`` is a free-form string (the field
    description gives PNG and JPEG as examples); it is not restricted to a
    fixed set of values here.
    """

    width: Annotated[int, Field(description="Image width in pixels")]
    height: Annotated[int, Field(description="Image height in pixels")]
    # The name shadows the ``format`` builtin inside the class body only; it
    # is part of the payload schema and must keep this name.
    format: Annotated[str, Field(description="Image format (e.g., PNG, JPEG)")]
    size_bytes: Annotated[int, Field(description="File size in bytes")]
|
|
|
|
|
|
class SummarizeRequest(BaseModel):
    """Request payload for text summarization.

    Required fields: ``text`` (1 to 100000 characters), ``provider``,
    ``guild_id`` and ``user_id`` (1 to 50 characters each).

    Every other field is optional and defaults to ``None``. When provided,
    values are validated against these bounds:

    * ``model``: at most 100 characters.
    * ``temperature``: 0.0 to 2.0 inclusive.
    * ``max_tokens``: 1 to 100000 inclusive.
    * ``top_p``: 0.0 to 1.0 inclusive.
    * ``style``: a ``SummarizationStyle`` member.
    * ``system_prompt``: at most 5000 characters.
    * ``channel_id``: at most 50 characters.
    """

    # --- Required fields -------------------------------------------------
    text: Annotated[
        str, Field(min_length=1, max_length=100000, description="Text to summarize")
    ]
    provider: Annotated[LlmProvider, Field(description="LLM provider")]
    guild_id: Annotated[
        str, Field(min_length=1, max_length=50, description="Discord guild ID")
    ]
    user_id: Annotated[
        str, Field(min_length=1, max_length=50, description="Discord user ID")
    ]

    # --- Optional generation settings (None when omitted) ----------------
    model: Annotated[
        str | None, Field(default=None, max_length=100, description="Model name")
    ]
    temperature: Annotated[
        float | None,
        Field(default=None, ge=0.0, le=2.0, description="Temperature (0.0 to 2.0)"),
    ]
    max_tokens: Annotated[
        int | None,
        Field(default=None, ge=1, le=100000, description="Max tokens for response"),
    ]
    top_p: Annotated[
        float | None,
        Field(default=None, ge=0.0, le=1.0, description="Top P sampling (0.0 to 1.0)"),
    ]
    style: Annotated[
        SummarizationStyle | None,
        Field(default=None, description="Summarization style"),
    ]
    system_prompt: Annotated[
        str | None,
        Field(default=None, max_length=5000, description="Custom system prompt"),
    ]

    # Unlike guild_id/user_id, channel_id is optional and has no min_length.
    channel_id: Annotated[
        str | None, Field(default=None, max_length=50, description="Discord channel ID")
    ]
|
|
|
|
|
|
class SummarizeResponse(BaseModel):
    """Response from text summarization endpoint.

    Required fields: ``summary``, ``model``, ``provider``, ``from_cache``,
    ``checksum`` and ``timestamp``. ``token_usage``, ``processing_time_ms``
    and ``style_used`` are optional and default to ``None``.
    """

    summary: Annotated[str, Field(description="Generated summary")]
    model: Annotated[str, Field(description="Model used")]
    # Plain strings here: provider and style_used are not validated against
    # the LlmProvider / SummarizationStyle enums.
    provider: Annotated[str, Field(description="Provider used")]
    from_cache: Annotated[bool, Field(description="Whether served from cache")]
    checksum: Annotated[str, Field(description="Request checksum")]
    # NOTE(review): timestamp is an unvalidated str; its format is not
    # defined in this module -- confirm against the API.
    timestamp: Annotated[str, Field(description="Timestamp")]
    token_usage: Annotated[
        TokenUsage | None, Field(default=None, description="Token usage info")
    ]
    processing_time_ms: Annotated[
        int | None, Field(default=None, description="Processing time in milliseconds")
    ]
    style_used: Annotated[
        str | None, Field(default=None, description="Summarization style used")
    ]
|
|
|
|
|
|
class OcrSummarizeResponse(BaseModel):
    """Response from OCR summarization endpoint.

    Required fields: ``extracted_text``, ``summary``, ``model``,
    ``provider``, ``from_cache``, ``checksum``, ``timestamp``,
    ``ocr_time_ms``, ``total_time_ms`` and ``image_info``.
    ``token_usage``, ``summarization_time_ms`` and ``style_used`` are
    optional and default to ``None``.
    """

    extracted_text: Annotated[str, Field(description="Extracted text from OCR")]
    summary: Annotated[str, Field(description="Generated summary")]
    model: Annotated[str, Field(description="Model used")]
    # Plain strings here: provider and style_used are not validated against
    # the LlmProvider / SummarizationStyle enums.
    provider: Annotated[str, Field(description="Provider used")]
    from_cache: Annotated[bool, Field(description="Whether served from cache")]
    checksum: Annotated[str, Field(description="Request checksum")]
    # NOTE(review): timestamp is an unvalidated str; its format is not
    # defined in this module -- confirm against the API.
    timestamp: Annotated[str, Field(description="Timestamp")]
    token_usage: Annotated[
        TokenUsage | None, Field(default=None, description="Token usage info")
    ]

    # Timing breakdown: the OCR and total timings are required, while the
    # summarization timing may be absent (None).
    ocr_time_ms: Annotated[
        int, Field(description="OCR processing time in milliseconds")
    ]
    summarization_time_ms: Annotated[
        int | None,
        Field(default=None, description="Summarization time in milliseconds"),
    ]
    total_time_ms: Annotated[
        int, Field(description="Total processing time in milliseconds")
    ]

    style_used: Annotated[
        str | None, Field(default=None, description="Summarization style used")
    ]
    image_info: Annotated[ImageInfo, Field(description="Image metadata")]
|
|
|
|
|
|
class LlmInfo(BaseModel):
    """Information about an available LLM model.

    ``provider``, ``model`` and ``available`` are required; ``version`` is
    optional and defaults to ``None``.
    """

    # provider is a plain string; it is not validated against LlmProvider.
    provider: Annotated[str, Field(description="Provider name")]
    model: Annotated[str, Field(description="Model name")]
    version: Annotated[str | None, Field(default=None, description="Model version")]
    available: Annotated[bool, Field(description="Whether model is available")]
|
|
|
|
|
|
class CacheEntry(BaseModel):
    """A single cache entry.

    Every field is required except ``channel_id``, which defaults to
    ``None``.
    """

    # The name shadows the ``id`` builtin inside the class body only; it is
    # part of the payload schema and must keep this name.
    id: Annotated[int, Field(description="Cache entry ID")]
    checksum: Annotated[str, Field(description="Request checksum")]
    text: Annotated[str, Field(description="Original text")]
    summary: Annotated[str, Field(description="Generated summary")]
    provider: Annotated[str, Field(description="Provider used")]
    model: Annotated[str, Field(description="Model used")]
    guild_id: Annotated[str, Field(description="Guild ID")]
    user_id: Annotated[str, Field(description="User ID")]
    channel_id: Annotated[str | None, Field(default=None, description="Channel ID")]
    # NOTE(review): both timestamps are unvalidated str values; their format
    # is not defined in this module -- confirm against the API.
    created_at: Annotated[str, Field(description="Created timestamp")]
    last_accessed: Annotated[str, Field(description="Last accessed timestamp")]
    access_count: Annotated[int, Field(description="Access count")]
|
|
|
|
|
|
class CacheStats(BaseModel):
    """Cache statistics.

    All five fields are required. No relationship between the fields (for
    example between ``hit_rate`` and the hit/miss counters) is validated
    here.
    """

    total_entries: Annotated[int, Field(description="Total cache entries")]
    total_hits: Annotated[int, Field(description="Total cache hits")]
    total_misses: Annotated[int, Field(description="Total cache misses")]
    # NOTE(review): described as a percentage, but no range is enforced, so
    # whether values run 0-100 or 0-1 is not visible here -- confirm.
    hit_rate: Annotated[float, Field(description="Cache hit rate (percentage)")]
    total_size_bytes: Annotated[int, Field(description="Total size in bytes")]
|
|
|
|
|
|
class DeleteCacheRequest(BaseModel):
    """Request payload for deleting cache entries.

    All four selectors are optional and default to ``None``.
    """

    # NOTE(review): nothing in this model requires at least one selector to
    # be set, or forbids combining them, so an all-None request validates.
    # Confirm how the API treats that case.
    id: Annotated[
        int | None, Field(default=None, description="Cache entry ID to delete")
    ]
    # Unlike SummarizeRequest, guild_id/user_id carry no length bounds here.
    guild_id: Annotated[
        str | None, Field(default=None, description="Delete all for guild")
    ]
    user_id: Annotated[
        str | None, Field(default=None, description="Delete all for user")
    ]
    delete_all: Annotated[
        bool | None, Field(default=None, description="Delete all entries")
    ]
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Health check response.

    Both fields are required strings. ``status`` is free-form (the field
    description gives ``'healthy'`` as an example value).
    """

    status: Annotated[str, Field(description="Health status (e.g., 'healthy')")]
    version: Annotated[str, Field(description="API version")]
|
|
|
|
|
|
class SuccessResponse(BaseModel):
    """Generic success response.

    Both fields are required: a boolean ``success`` flag and a ``message``
    string.
    """

    success: Annotated[bool, Field(description="Operation success status")]
    message: Annotated[str, Field(description="Success message")]
|
|
|
|
|
|
class ErrorResponse(BaseModel):
    """Error response from API.

    Both fields are required: a boolean ``success`` flag and an ``error``
    message string.
    """

    # NOTE(review): described as always False, but the type is a plain bool,
    # so True also validates -- confirm whether this should be pinned.
    success: Annotated[bool, Field(description="Always False for errors")]
    error: Annotated[str, Field(description="Error message")]
|