Files
SummaryDiscordBot/discord_bot/src/views.py
2025-12-12 15:31:27 +00:00

295 lines
11 KiB
Python

"""Discord UI components for summarization commands."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Coroutine
import disnake
from loguru import logger
from sdk import LlmInfo, SummarizerClient
if TYPE_CHECKING:
from src.bot import SummarizerBot
class ProviderSelect(disnake.ui.StringSelect):
    """Dropdown listing the LLM providers a user can summarize with.

    Args:
        selected: Option value to mark as the pre-selected provider.
    """

    def __init__(self, selected: str = "lmstudio") -> None:
        # One (label, value, description, emoji) tuple per provider, in display order.
        provider_specs = (
            ("Anthropic (Claude)", "anthropic", "Claude AI models", "🤖"),
            ("Ollama", "ollama", "Local Ollama models", "🦙"),
            ("LM Studio", "lmstudio", "Local LM Studio models", "💻"),
        )
        choices = [
            disnake.SelectOption(
                label=label,
                value=value,
                description=description,
                emoji=emoji,
                default=value == selected,
            )
            for label, value, description, emoji in provider_specs
        ]
        super().__init__(
            placeholder="Select AI Provider",
            options=choices,
            custom_id="provider_select",
        )

    async def callback(self, inter: disnake.MessageInteraction) -> None:
        """Record the chosen provider on the parent view, then let it rebuild the model list."""
        parent = self.view
        parent.selected_provider = self.values[0]  # type: ignore[union-attr]
        await parent.refresh_models(inter)  # type: ignore[union-attr]
class ModelSelect(disnake.ui.StringSelect):
    """Dropdown for selecting an LLM model.

    When no usable models are supplied, the menu shows a single placeholder
    option and is disabled.

    Args:
        models: Models to offer, or ``None``/empty for the placeholder state.
    """

    # Discord limits: at most 25 options per select menu, and each option's
    # label, value and description at most 100 characters.
    _OPTION_LIMIT = 25
    _TEXT_LIMIT = 100

    def __init__(self, models: list[LlmInfo] | None = None) -> None:
        options = self._options_from_models(models or [])
        has_options = bool(options)
        if not has_options:
            options = [
                disnake.SelectOption(
                    label="Select provider first",
                    value="none",
                    description="Choose a provider to see available models",
                )
            ]
        super().__init__(
            placeholder="Select Model (optional)",
            options=options,
            custom_id="model_select",
            disabled=not has_options,
        )

    @classmethod
    def _options_from_models(cls, models: list[LlmInfo]) -> list[disnake.SelectOption]:
        """Build select options for ``models`` that stay within Discord's limits."""
        options: list[disnake.SelectOption] = []
        seen: set[str] = set()
        for m in models:
            # The option value is sent back verbatim as the selected model id,
            # so it cannot be truncated: ids that are empty or too long are
            # skipped. Duplicate values would be indistinguishable in the
            # callback, so only the first occurrence is kept.
            if not m.model or len(m.model) > cls._TEXT_LIMIT or m.model in seen:
                continue
            seen.add(m.model)
            description = f"{m.provider}" + (f" v{m.version}" if m.version else "")
            options.append(
                disnake.SelectOption(
                    label=m.model,
                    value=m.model,
                    description=description[: cls._TEXT_LIMIT],
                )
            )
            # Apply the cap after filtering so skipped entries don't use up slots.
            if len(options) >= cls._OPTION_LIMIT:
                break
        return options

    async def callback(self, inter: disnake.MessageInteraction) -> None:
        """Store the chosen model on the parent view and acknowledge the interaction."""
        self.view.selected_model = self.values[0] if self.values[0] != "none" else None  # type: ignore[union-attr]
        await inter.response.defer()
class ProviderModelView(disnake.ui.View):
    """View for selecting a provider and model before summarization.

    Args:
        bot: Bot instance; ``bot.config.bot.api_url`` is used to list models.
        on_confirm: Coroutine awaited with ``(provider, model)`` once the user
            proceeds; ``model`` is ``None`` when no specific model was chosen.
        timeout: Seconds of inactivity before the view stops listening.
        author_id: Discord user id allowed to use the controls.
    """

    def __init__(
        self,
        bot: SummarizerBot,
        on_confirm: Callable[[str, str | None], Coroutine[Any, Any, None]],
        *,
        timeout: float = 120.0,
        author_id: int,
    ) -> None:
        super().__init__(timeout=timeout)
        self.bot = bot
        self.on_confirm = on_confirm
        self.author_id = author_id
        self.selected_provider: str = "lmstudio"
        self.selected_model: str | None = None
        # Models grouped by lower-cased provider name; filled lazily.
        self._models_cache: dict[str, list[LlmInfo]] = {}
        self.provider_select = ProviderSelect()
        self.model_select = ModelSelect()
        self.add_item(self.provider_select)
        self.add_item(self.model_select)

    async def interaction_check(  # type: ignore[override]
        self, interaction: disnake.MessageInteraction
    ) -> bool:
        """Only allow the command author to interact."""
        if interaction.author.id != self.author_id:
            await interaction.response.send_message(
                "Only the command author can use these controls.", ephemeral=True
            )
            return False
        return True

    async def fetch_models(self) -> list[LlmInfo]:
        """Fetch all models from the API, returning ``[]`` on any failure (logged)."""
        try:
            async with SummarizerClient(base_url=self.bot.config.bot.api_url) as client:
                return await client.list_models()
        except Exception as e:  # noqa: BLE001 - best effort; the UI falls back to no models
            logger.warning("Failed to fetch models: {}", e)
            return []

    async def prefetch_models(self) -> None:
        """Populate the model cache, grouped by lower-cased provider name.

        Does nothing when the cache is already filled. A failed or empty fetch
        leaves the cache empty, so the next call tries again.
        """
        if self._models_cache:
            return
        for model in await self.fetch_models():
            self._models_cache.setdefault(model.provider.lower(), []).append(model)

    def get_provider_models(self, provider: str) -> list[LlmInfo]:
        """Return the cached models for ``provider`` that are marked available."""
        return [m for m in self._models_cache.get(provider, []) if m.available]

    async def refresh_models(self, inter: disnake.MessageInteraction) -> None:
        """Rebuild both selects for the currently selected provider and re-render."""
        # Acknowledge first: fetching models may involve a network round-trip.
        await inter.response.defer()
        await self.prefetch_models()
        available_models = self.get_provider_models(self.selected_provider)

        # Recreate the provider select so the current choice is shown as default.
        self.remove_item(self.provider_select)
        self.provider_select = ProviderSelect(selected=self.selected_provider)
        self.add_item(self.provider_select)

        self.remove_item(self.model_select)
        self.model_select = ModelSelect(available_models or None)
        self.add_item(self.model_select)
        # A model picked for the previous provider no longer applies.
        self.selected_model = None

        await inter.edit_original_response(view=self)

    @disnake.ui.button(label="Use Defaults", style=disnake.ButtonStyle.secondary, row=2)
    async def use_defaults(
        self, _button: disnake.ui.Button, inter: disnake.MessageInteraction  # type: ignore[type-arg]
    ) -> None:
        """Proceed with the default provider and no specific model."""
        self.stop()
        # on_confirm only receives (provider, model) and cannot answer this
        # click, so acknowledge it here.
        await inter.response.defer()
        await self.on_confirm("lmstudio", None)

    @disnake.ui.button(label="Confirm", style=disnake.ButtonStyle.primary, row=2)
    async def confirm(
        self, _button: disnake.ui.Button, inter: disnake.MessageInteraction  # type: ignore[type-arg]
    ) -> None:
        """Proceed with the selected provider and (optional) model."""
        self.stop()
        # See use_defaults: the click must be acknowledged here.
        await inter.response.defer()
        await self.on_confirm(self.selected_provider, self.selected_model)

    @disnake.ui.button(label="Cancel", style=disnake.ButtonStyle.danger, row=2)
    async def cancel(
        self, _button: disnake.ui.Button, inter: disnake.MessageInteraction  # type: ignore[type-arg]
    ) -> None:
        """Cancel the operation and strip the controls from the message."""
        self.stop()
        await inter.response.edit_message(
            content="Operation cancelled.", embed=None, view=None
        )
class QuickSummarizeView(disnake.ui.View):
    """Prompt offering a one-click summary or the advanced provider/model picker.

    Args:
        bot: Bot instance, passed on to ``ProviderModelView``.
        summarize_callback: Coroutine awaited with ``(provider, model)`` to run
            the summarization; ``model`` is ``None`` for the provider default.
        timeout: Seconds of inactivity before the view stops listening.
        author_id: Discord user id allowed to use the controls.
    """

    def __init__(
        self,
        bot: SummarizerBot,
        summarize_callback: Callable[[str, str | None], Coroutine[Any, Any, None]],
        *,
        timeout: float = 60.0,
        author_id: int,
    ) -> None:
        super().__init__(timeout=timeout)
        self.bot = bot
        self.summarize_callback = summarize_callback
        self.author_id = author_id

    async def interaction_check(  # type: ignore[override]
        self, interaction: disnake.MessageInteraction
    ) -> bool:
        """Only allow the command author to interact."""
        if interaction.author.id != self.author_id:
            await interaction.response.send_message(
                "Only the command author can use these controls.", ephemeral=True
            )
            return False
        return True

    # The emoji was previously an empty string, which is not a valid emoji.
    @disnake.ui.button(
        label="Summarize (Default)", style=disnake.ButtonStyle.primary, emoji="⚡"
    )
    async def quick_summarize(
        self, _button: disnake.ui.Button, inter: disnake.MessageInteraction  # type: ignore[type-arg]
    ) -> None:
        """Summarize immediately with the default provider and model."""
        self.stop()
        await inter.response.defer()
        await self.summarize_callback("lmstudio", None)

    @disnake.ui.button(
        label="Choose Provider/Model", style=disnake.ButtonStyle.secondary, emoji="⚙️"
    )
    async def advanced_options(
        self, _button: disnake.ui.Button, inter: disnake.MessageInteraction  # type: ignore[type-arg]
    ) -> None:
        """Replace this prompt with the provider/model selection view."""
        self.stop()
        # Acknowledge before the model fetch below: it is a network call, and
        # Discord expects interactions to be acknowledged promptly. The message
        # is then edited via edit_original_response.
        await inter.response.defer()

        view = ProviderModelView(
            self.bot, self.summarize_callback, author_id=self.author_id
        )
        await view.prefetch_models()

        # Seed the model select with the default provider's models.
        lmstudio_models = view.get_provider_models("lmstudio")
        view.remove_item(view.model_select)
        view.model_select = ModelSelect(lmstudio_models or None)
        view.add_item(view.model_select)

        embed = disnake.Embed(
            title="🔧 Advanced Options",
            description="Select your preferred AI provider and model.",
            color=disnake.Color.blurple(),
        )
        embed.add_field(
            name="Provider",
            value="Choose the AI service to use for summarization.",
            inline=False,
        )
        embed.add_field(
            name="Model",
            value="Optionally select a specific model (leave empty for default).",
            inline=False,
        )
        await inter.edit_original_response(embed=embed, view=view)

    @disnake.ui.button(label="Cancel", style=disnake.ButtonStyle.danger, emoji="✖️")
    async def cancel(
        self, _button: disnake.ui.Button, inter: disnake.MessageInteraction  # type: ignore[type-arg]
    ) -> None:
        """Cancel the operation and strip the controls from the message."""
        self.stop()
        await inter.response.edit_message(
            content="Operation cancelled.", embed=None, view=None
        )