295 lines
11 KiB
Python
295 lines
11 KiB
Python
"""Discord UI components for summarization commands."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Any, Callable, Coroutine
|
|
|
|
import disnake
|
|
from loguru import logger
|
|
|
|
from sdk import LlmInfo, SummarizerClient
|
|
|
|
if TYPE_CHECKING:
|
|
from src.bot import SummarizerBot
|
|
|
|
|
|
class ProviderSelect(disnake.ui.StringSelect):
    """Dropdown listing the LLM providers the user can pick from.

    Args:
        selected: Provider value whose option is pre-selected
            (``"lmstudio"`` unless told otherwise).
    """

    def __init__(self, selected: str = "lmstudio") -> None:
        # (label, value, description, emoji) for every provider, in display order.
        choices = (
            ("Anthropic (Claude)", "anthropic", "Claude AI models", "🤖"),
            ("Ollama", "ollama", "Local Ollama models", "🦙"),
            ("LM Studio", "lmstudio", "Local LM Studio models", "💻"),
        )
        provider_options = [
            disnake.SelectOption(
                label=label,
                value=value,
                description=description,
                emoji=emoji,
                default=value == selected,
            )
            for label, value, description, emoji in choices
        ]
        super().__init__(
            placeholder="Select AI Provider",
            options=provider_options,
            custom_id="provider_select",
        )

    async def callback(self, inter: disnake.MessageInteraction) -> None:
        """Store the chosen provider on the parent view and let it reload models.

        The parent view is expected to expose ``selected_provider`` and
        ``refresh_models`` (as ``ProviderModelView`` does); the latter is
        responsible for responding to ``inter``.
        """
        parent = self.view
        parent.selected_provider = self.values[0]  # type: ignore[union-attr]
        await parent.refresh_models(inter)  # type: ignore[union-attr]
|
|
|
|
|
|
class ModelSelect(disnake.ui.StringSelect):
    """Dropdown for selecting LLM model.

    When no usable model is supplied, the dropdown shows a single placeholder
    option (value ``"none"``) and is disabled.
    """

    # Discord component limits: at most 25 options per select menu, and each
    # option's label, value and description are capped at 100 characters.
    _MAX_OPTIONS = 25
    _MAX_TEXT = 100

    def __init__(self, models: list[LlmInfo] | None = None) -> None:
        """Build the dropdown.

        Args:
            models: Models to offer. ``None`` or an empty list produces the
                disabled "Select provider first" placeholder.
        """
        options = self._build_options(models or [])
        has_models = bool(options)
        if not has_models:
            options = [
                disnake.SelectOption(
                    label="Select provider first",
                    value="none",
                    description="Choose a provider to see available models",
                )
            ]
        super().__init__(
            placeholder="Select Model (optional)",
            options=options,
            custom_id="model_select",
            disabled=not has_models,
        )

    @classmethod
    def _build_options(cls, models: list[LlmInfo]) -> list[disnake.SelectOption]:
        """Convert models to select options, dropping ones Discord would reject."""
        options: list[disnake.SelectOption] = []
        seen: set[str] = set()
        for m in models:
            # The option value is sent back as the model name, so it cannot be
            # truncated. Empty, over-long or duplicate values would make Discord
            # reject the entire component, so such models are skipped instead.
            if not m.model or len(m.model) > cls._MAX_TEXT or m.model in seen:
                continue
            seen.add(m.model)
            description = f"{m.provider}" + (f" v{m.version}" if m.version else "")
            options.append(
                disnake.SelectOption(
                    label=m.model,
                    value=m.model,
                    description=description[: cls._MAX_TEXT],
                )
            )
            if len(options) == cls._MAX_OPTIONS:
                break
        return options

    async def callback(self, inter: disnake.MessageInteraction) -> None:
        """Handle model selection.

        Stores the chosen model on the parent view (``None`` for the
        placeholder) and acknowledges the interaction without changing the
        message.
        """
        self.view.selected_model = self.values[0] if self.values[0] != "none" else None  # type: ignore[union-attr]
        await inter.response.defer()
|
|
|
|
|
|
class ProviderModelView(disnake.ui.View):
    """View for selecting provider and model before summarization.

    Shows a provider dropdown, a model dropdown populated from the API's model
    list for the chosen provider, and Use Defaults / Confirm / Cancel buttons.
    Only the user with ``author_id`` may interact with it.
    """

    def __init__(
        self,
        bot: SummarizerBot,
        on_confirm: Callable[[str, str | None], Coroutine[Any, Any, None]],
        *,
        timeout: float = 120.0,
        author_id: int,
    ) -> None:
        """Create the view.

        Args:
            bot: Bot instance; ``bot.config.bot.api_url`` is used to list models.
            on_confirm: Awaited with ``(provider, model_or_None)`` when the user
                confirms or picks the defaults.
            timeout: Seconds of inactivity before the view times out.
            author_id: ID of the only user allowed to use the controls.
        """
        super().__init__(timeout=timeout)
        self.bot = bot
        self.on_confirm = on_confirm
        self.author_id = author_id
        self.selected_provider: str = "lmstudio"
        self.selected_model: str | None = None
        # Models grouped by lower-cased provider name, filled lazily from the API.
        self._models_cache: dict[str, list[LlmInfo]] = {}

        # Add components
        self.provider_select = ProviderSelect(selected=self.selected_provider)
        self.model_select = ModelSelect()
        self.add_item(self.provider_select)
        self.add_item(self.model_select)

    async def interaction_check(  # type: ignore[override]
        self, interaction: disnake.MessageInteraction
    ) -> bool:
        """Only allow the command author to interact."""
        if interaction.author.id != self.author_id:
            await interaction.response.send_message(
                "Only the command author can use these controls.", ephemeral=True
            )
            return False
        return True

    async def fetch_models(self) -> list[LlmInfo]:
        """Fetch available models from API.

        Best-effort: any failure is logged and an empty list is returned.
        """
        try:
            async with SummarizerClient(base_url=self.bot.config.bot.api_url) as client:
                return await client.list_models()
        except Exception as e:  # noqa: BLE001
            logger.warning(f"Failed to fetch models: {e}")
            return []

    async def prefetch_models(self) -> None:
        """Fetch all models once and cache them grouped by lower-cased provider.

        Does nothing if the cache is already populated. If the fetch fails or
        returns no models, the cache stays empty and a later call retries.
        """
        if self._models_cache:
            return
        for model in await self.fetch_models():
            self._models_cache.setdefault(model.provider.lower(), []).append(model)

    def get_provider_models(self, provider: str) -> list[LlmInfo]:
        """Get available models for a specific provider."""
        return [m for m in self._models_cache.get(provider, []) if m.available]

    async def refresh_models(self, inter: disnake.MessageInteraction) -> None:
        """Rebuild both dropdowns for ``selected_provider`` and redraw the message.

        Also clears ``selected_model``, because the previous choice may not
        belong to the new provider.
        """
        # Acknowledge first: fetching models may take longer than Discord's
        # initial-response window.
        await inter.response.defer()

        await self.prefetch_models()
        available_models = self.get_provider_models(self.selected_provider)

        # Update provider select to show current selection
        self.remove_item(self.provider_select)
        self.provider_select = ProviderSelect(selected=self.selected_provider)
        self.add_item(self.provider_select)

        # Update model select
        self.remove_item(self.model_select)
        self.model_select = ModelSelect(available_models if available_models else None)
        self.add_item(self.model_select)
        self.selected_model = None

        await inter.edit_original_response(view=self)

    @disnake.ui.button(label="Use Defaults", style=disnake.ButtonStyle.secondary, row=2)
    async def use_defaults(
        self, _button: disnake.ui.Button, inter: disnake.MessageInteraction  # type: ignore[type-arg]
    ) -> None:
        """Use default provider and model."""
        self.stop()
        # on_confirm never sees this interaction, so acknowledge it here;
        # otherwise Discord reports "This interaction failed".
        await inter.response.defer()
        await self.on_confirm("lmstudio", None)

    @disnake.ui.button(label="Confirm", style=disnake.ButtonStyle.primary, row=2)
    async def confirm(
        self, _button: disnake.ui.Button, inter: disnake.MessageInteraction  # type: ignore[type-arg]
    ) -> None:
        """Confirm selection and proceed."""
        self.stop()
        # Same reason as in use_defaults: on_confirm cannot respond to this.
        await inter.response.defer()
        await self.on_confirm(self.selected_provider, self.selected_model)

    @disnake.ui.button(label="Cancel", style=disnake.ButtonStyle.danger, row=2)
    async def cancel(
        self, _button: disnake.ui.Button, inter: disnake.MessageInteraction  # type: ignore[type-arg]
    ) -> None:
        """Cancel the operation."""
        self.stop()
        await inter.response.edit_message(
            content="Operation cancelled.", embed=None, view=None
        )
|
|
|
|
|
|
class QuickSummarizeView(disnake.ui.View):
    """View with quick summarize button and advanced options.

    Offers a one-click summarize using the defaults (provider ``"lmstudio"``,
    no model), a switch to ``ProviderModelView`` for choosing provider/model,
    and a cancel button. Only the user with ``author_id`` may interact.
    """

    def __init__(
        self,
        bot: SummarizerBot,
        summarize_callback: Callable[[str, str | None], Coroutine[Any, Any, None]],
        *,
        timeout: float = 60.0,
        author_id: int,
    ) -> None:
        """Create the view.

        Args:
            bot: Bot instance, passed on to ``ProviderModelView``.
            summarize_callback: Awaited with ``(provider, model_or_None)`` to
                run the summarization.
            timeout: Seconds of inactivity before the view times out.
            author_id: ID of the only user allowed to use the controls.
        """
        super().__init__(timeout=timeout)
        self.bot = bot
        self.summarize_callback = summarize_callback
        self.author_id = author_id

    async def interaction_check(  # type: ignore[override]
        self, interaction: disnake.MessageInteraction
    ) -> bool:
        """Only allow the command author to interact."""
        if interaction.author.id != self.author_id:
            await interaction.response.send_message(
                "Only the command author can use these controls.", ephemeral=True
            )
            return False
        return True

    @disnake.ui.button(
        label="Summarize (Default)", style=disnake.ButtonStyle.primary, emoji="⚡"
    )
    async def quick_summarize(
        self, _button: disnake.ui.Button, inter: disnake.MessageInteraction  # type: ignore[type-arg]
    ) -> None:
        """Quick summarize with defaults."""
        self.stop()
        await inter.response.defer()
        await self.summarize_callback("lmstudio", None)

    @disnake.ui.button(
        label="Choose Provider/Model", style=disnake.ButtonStyle.secondary, emoji="⚙️"
    )
    async def advanced_options(
        self, _button: disnake.ui.Button, inter: disnake.MessageInteraction  # type: ignore[type-arg]
    ) -> None:
        """Replace this view with the provider/model selection view."""
        self.stop()
        # Acknowledge before the model prefetch. That is a network round-trip
        # that can exceed Discord's 3-second limit for the initial response.
        await inter.response.defer()

        view = ProviderModelView(self.bot, self.summarize_callback, author_id=self.author_id)

        # Prefetch models
        await view.prefetch_models()

        # Set up initial model select with lmstudio models
        lmstudio_models = view.get_provider_models("lmstudio")
        view.remove_item(view.model_select)
        view.model_select = ModelSelect(lmstudio_models if lmstudio_models else None)
        view.add_item(view.model_select)

        embed = disnake.Embed(
            title="🔧 Advanced Options",
            description="Select your preferred AI provider and model.",
            color=disnake.Color.blurple(),
        )
        embed.add_field(
            name="Provider",
            value="Choose the AI service to use for summarization.",
            inline=False,
        )
        embed.add_field(
            name="Model",
            value="Optionally select a specific model (leave empty for default).",
            inline=False,
        )

        # After a deferred update, the "original response" is the message
        # carrying these buttons.
        await inter.edit_original_response(embed=embed, view=view)

    @disnake.ui.button(label="Cancel", style=disnake.ButtonStyle.danger, emoji="✖️")
    async def cancel(
        self, _button: disnake.ui.Button, inter: disnake.MessageInteraction  # type: ignore[type-arg]
    ) -> None:
        """Cancel the operation."""
        self.stop()
        await inter.response.edit_message(
            content="Operation cancelled.", embed=None, view=None
        )
|