"""Discord UI components for summarization commands.""" from __future__ import annotations from typing import TYPE_CHECKING, Any, Callable, Coroutine import disnake from loguru import logger from sdk import LlmInfo, SummarizerClient if TYPE_CHECKING: from src.bot import SummarizerBot class ProviderSelect(disnake.ui.StringSelect): """Dropdown for selecting LLM provider.""" def __init__(self, selected: str = "lmstudio") -> None: options = [ disnake.SelectOption( label="Anthropic (Claude)", value="anthropic", description="Claude AI models", emoji="🤖", default=(selected == "anthropic"), ), disnake.SelectOption( label="Ollama", value="ollama", description="Local Ollama models", emoji="🦙", default=(selected == "ollama"), ), disnake.SelectOption( label="LM Studio", value="lmstudio", description="Local LM Studio models", emoji="💻", default=(selected == "lmstudio"), ), ] super().__init__( placeholder="Select AI Provider", options=options, custom_id="provider_select", ) async def callback(self, inter: disnake.MessageInteraction) -> None: """Handle provider selection - triggers model list refresh.""" # The view will handle this self.view.selected_provider = self.values[0] # type: ignore[union-attr] await self.view.refresh_models(inter) # type: ignore[union-attr] class ModelSelect(disnake.ui.StringSelect): """Dropdown for selecting LLM model.""" def __init__(self, models: list[LlmInfo] | None = None) -> None: if models: options = [ disnake.SelectOption( label=m.model[:100], value=m.model, description=f"{m.provider}" + (f" v{m.version}" if m.version else ""), # noqa: E501 ) for m in models[:25] # Discord limit ] else: options = [ disnake.SelectOption( label="Select provider first", value="none", description="Choose a provider to see available models", ) ] super().__init__( placeholder="Select Model (optional)", options=options, custom_id="model_select", disabled=not models, ) async def callback(self, inter: disnake.MessageInteraction) -> None: """Handle model selection.""" self.view.selected_model = self.values[0] if self.values[0] != "none" else None # type: ignore[union-attr] await inter.response.defer() class ProviderModelView(disnake.ui.View): """View for selecting provider and model before summarization.""" def __init__( self, bot: SummarizerBot, on_confirm: Callable[[str, str | None], Coroutine[Any, Any, None]], *, timeout: float = 120.0, author_id: int, ) -> None: super().__init__(timeout=timeout) self.bot = bot self.on_confirm = on_confirm self.author_id = author_id self.selected_provider: str = "lmstudio" self.selected_model: str | None = None self._models_cache: dict[str, list[LlmInfo]] = {} # Add components self.provider_select = ProviderSelect() self.model_select = ModelSelect() self.add_item(self.provider_select) self.add_item(self.model_select) async def interaction_check( # type: ignore[override] self, interaction: disnake.MessageInteraction ) -> bool: """Only allow the command author to interact.""" if interaction.author.id != self.author_id: await interaction.response.send_message( "Only the command author can use these controls.", ephemeral=True ) return False return True async def fetch_models(self) -> list[LlmInfo]: """Fetch available models from API.""" try: async with SummarizerClient(base_url=self.bot.config.bot.api_url) as client: return await client.list_models() except Exception as e: # noqa: BLE001 logger.warning(f"Failed to fetch models: {e}") return [] async def prefetch_models(self) -> None: """Prefetch and cache all models.""" if not self._models_cache: all_models = await self.fetch_models() for model in all_models: provider = model.provider.lower() if provider not in self._models_cache: self._models_cache[provider] = [] self._models_cache[provider].append(model) def get_provider_models(self, provider: str) -> list[LlmInfo]: """Get available models for a specific provider.""" return [m for m in self._models_cache.get(provider, []) if m.available] async def refresh_models(self, inter: disnake.MessageInteraction) -> None: """Refresh model list based on selected provider.""" await inter.response.defer() # Fetch models if not cached if not self._models_cache: all_models = await self.fetch_models() for model in all_models: provider = model.provider.lower() if provider not in self._models_cache: self._models_cache[provider] = [] self._models_cache[provider].append(model) # Filter models for selected provider provider_models = self._models_cache.get(self.selected_provider, []) available_models = [m for m in provider_models if m.available] # Update provider select to show current selection self.remove_item(self.provider_select) self.provider_select = ProviderSelect(selected=self.selected_provider) self.add_item(self.provider_select) # Update model select self.remove_item(self.model_select) self.model_select = ModelSelect(available_models if available_models else None) self.add_item(self.model_select) self.selected_model = None await inter.edit_original_response(view=self) @disnake.ui.button(label="Use Defaults", style=disnake.ButtonStyle.secondary, row=2) async def use_defaults( self, _button: disnake.ui.Button, _inter: disnake.MessageInteraction # type: ignore[type-arg] ) -> None: """Use default provider and model.""" self.stop() await self.on_confirm("lmstudio", None) @disnake.ui.button(label="Confirm", style=disnake.ButtonStyle.primary, row=2) async def confirm( self, _button: disnake.ui.Button, _inter: disnake.MessageInteraction # type: ignore[type-arg] ) -> None: """Confirm selection and proceed.""" self.stop() await self.on_confirm(self.selected_provider, self.selected_model) @disnake.ui.button(label="Cancel", style=disnake.ButtonStyle.danger, row=2) async def cancel( self, _button: disnake.ui.Button, inter: disnake.MessageInteraction # type: ignore[type-arg] ) -> None: """Cancel the operation.""" self.stop() await inter.response.edit_message( content="Operation cancelled.", embed=None, view=None ) class QuickSummarizeView(disnake.ui.View): """View with quick summarize button and advanced options.""" def __init__( self, bot: SummarizerBot, summarize_callback: Callable[[str, str | None], Coroutine[Any, Any, None]], *, timeout: float = 60.0, author_id: int, ) -> None: super().__init__(timeout=timeout) self.bot = bot self.summarize_callback = summarize_callback self.author_id = author_id async def interaction_check( # type: ignore[override] self, interaction: disnake.MessageInteraction ) -> bool: """Only allow the command author to interact.""" if interaction.author.id != self.author_id: await interaction.response.send_message( "Only the command author can use these controls.", ephemeral=True ) return False return True @disnake.ui.button( label="Summarize (Default)", style=disnake.ButtonStyle.primary, emoji="⚡" ) async def quick_summarize( self, _button: disnake.ui.Button, inter: disnake.MessageInteraction # type: ignore[type-arg] ) -> None: """Quick summarize with defaults.""" self.stop() await inter.response.defer() await self.summarize_callback("lmstudio", None) @disnake.ui.button( label="Choose Provider/Model", style=disnake.ButtonStyle.secondary, emoji="⚙️" ) async def advanced_options( self, _button: disnake.ui.Button, inter: disnake.MessageInteraction # type: ignore[type-arg] ) -> None: """Show advanced provider/model selection.""" self.stop() async def on_confirm(provider: str, model: str | None) -> None: await self.summarize_callback(provider, model) view = ProviderModelView(self.bot, on_confirm, author_id=self.author_id) # Prefetch models await view.prefetch_models() # Set up initial model select with lmstudio models lmstudio_models = view.get_provider_models("lmstudio") view.remove_item(view.model_select) view.model_select = ModelSelect(lmstudio_models if lmstudio_models else None) view.add_item(view.model_select) embed = disnake.Embed( title="🔧 Advanced Options", description="Select your preferred AI provider and model.", color=disnake.Color.blurple(), ) embed.add_field( name="Provider", value="Choose the AI service to use for summarization.", inline=False, ) embed.add_field( name="Model", value="Optionally select a specific model (leave empty for default).", inline=False, ) await inter.response.edit_message(embed=embed, view=view) @disnake.ui.button(label="Cancel", style=disnake.ButtonStyle.danger, emoji="✖️") async def cancel( self, _button: disnake.ui.Button, inter: disnake.MessageInteraction # type: ignore[type-arg] ) -> None: """Cancel the operation.""" self.stop() await inter.response.edit_message( content="Operation cancelled.", embed=None, view=None )