// SummaryDiscordBot/backend_api/src/llm.rs
use anyhow::{anyhow, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};

use crate::config::Config;
use crate::models::{SummarizeRequest, TokenUsage};
use crate::v1::enums::LlmProvider;
use crate::claude_models::ClaudeModelType;

/// HTTP client that forwards summarization requests to the configured LLM provider
/// (Anthropic, Ollama, or LM Studio).
pub struct LlmClient {
    client: Client,
    config: Config,
}
// Wire types for the Anthropic Messages API.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    temperature: f32,
    messages: Vec<AnthropicMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
}

#[derive(Debug, Serialize, Deserialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
    model: String,
    usage: AnthropicUsage,
}

#[derive(Debug, Serialize, Deserialize)]
struct AnthropicUsage {
    input_tokens: u32,
    output_tokens: u32,
}

#[derive(Debug, Serialize, Deserialize)]
struct AnthropicContent {
    text: String,
}
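
// Illustrative sketch (values are examples, not produced by this service) of the JSON
// these types map to. Shapes follow the public Anthropic Messages API; the real response
// carries extra fields (e.g. `id`, `stop_reason`) that serde ignores here.
//
//   Request body (AnthropicRequest):
//     {
//       "model": "claude-3-5-sonnet-latest",
//       "max_tokens": 1024,
//       "temperature": 0.7,
//       "system": "You are a helpful assistant...",
//       "messages": [{ "role": "user", "content": "Please summarize the following text:..." }]
//     }
//
//   Response body (AnthropicResponse):
//     {
//       "content": [{ "type": "text", "text": "The summary..." }],
//       "model": "claude-3-5-sonnet-latest",
//       "usage": { "input_tokens": 812, "output_tokens": 143 }
//     }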
// Wire types for the Ollama /api/generate endpoint.
#[derive(Debug, Serialize, Deserialize)]
struct OllamaRequest {
    model: String,
    prompt: String,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    options: OllamaOptions,
}

#[derive(Debug, Serialize, Deserialize)]
struct OllamaOptions {
    temperature: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<i32>,
}

#[derive(Debug, Serialize, Deserialize)]
struct OllamaResponse {
    response: String,
    model: String,
    prompt_eval_count: Option<u32>,
    eval_count: Option<u32>,
}
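
// Illustrative sketch of the non-streaming Ollama exchange these types describe.
// Field names follow Ollama's documented /api/generate contract; values are examples only.
//
//   Request:  { "model": "llama3.1", "prompt": "Please summarize the following text:...",
//               "stream": false, "system": "You are a helpful assistant...",
//               "options": { "temperature": 0.7, "num_predict": 512 } }
//   Response: { "model": "llama3.1", "response": "The summary...",
//               "prompt_eval_count": 812, "eval_count": 143, "done": true }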
#[derive(Debug, Serialize, Deserialize)]
pub struct _OllamaModelInfo {
    pub name: String,
    pub modified_at: String,
    pub size: i64,
}

#[derive(Debug, Serialize, Deserialize)]
struct _OllamaListResponse {
    models: Vec<_OllamaModelInfo>,
}
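
// Illustrative /api/tags response these list types deserialize (values are examples):
//   { "models": [ { "name": "llama3.1:8b", "modified_at": "2024-05-01T12:00:00Z",
//                   "size": 4920000000 } ] }
// Extra fields Ollama returns (digest, details, ...) are ignored by serde's defaults.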
impl LlmClient {
    pub fn new(config: Config) -> Self {
        Self {
            client: Client::new(),
            config,
        }
    }

    /// Dispatches to the provider selected on the request and returns
    /// `(summary, model_used, token_usage)`.
    pub async fn summarize(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
        match request.provider {
            LlmProvider::Anthropic => self.summarize_anthropic(request).await,
            LlmProvider::Ollama => self.summarize_ollama(request).await,
            LlmProvider::Lmstudio => self.summarize_lmstudio(request).await,
        }
    }
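
    // A minimal caller-side sketch, assuming `Config` is loaded elsewhere and that
    // `SummarizeRequest` exposes the fields read below (text, provider, model, max_tokens,
    // temperature, top_p, style, system_prompt); its actual constructor lives in
    // crate::models and may differ.
    //
    //   let llm = LlmClient::new(config);
    //   let (summary, model, usage) = llm.summarize(&request).await?;
    //   println!("{model} used {} tokens: {summary}", usage.total_tokens);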
    async fn summarize_anthropic(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
        let anthropic_config = self.config.anthropic.as_ref()
            .ok_or_else(|| anyhow!("Anthropic config not found"))?;

        let model = request.model.clone()
            .unwrap_or_else(|| anthropic_config.model.clone());
        let max_tokens = request.max_tokens
            .unwrap_or(anthropic_config.max_tokens);
        let temperature = request.temperature
            .unwrap_or(anthropic_config.temperature);

        // A requested summary style takes precedence over any caller-supplied system prompt.
        let system_prompt = if let Some(ref style) = request.style {
            style.to_system_prompt()
        } else {
            request.system_prompt.clone()
                .unwrap_or_else(|| "You are a helpful assistant that summarizes text concisely and accurately. Provide a clear, well-structured summary of the given text.".to_string())
        };

        let anthropic_request = AnthropicRequest {
            model: model.clone(),
            max_tokens,
            temperature,
            messages: vec![
                AnthropicMessage {
                    role: "user".to_string(),
                    content: format!("Please summarize the following text:\n\n{}", request.text),
                }
            ],
            system: Some(system_prompt),
            top_p: request.top_p,
        };

        let response = self.client
            .post(format!("{}/messages", anthropic_config.base_url))
            .header("x-api-key", &anthropic_config.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&anthropic_request)
            .send()
            .await?;

        if !response.status().is_success() {
            let error_text = response.text().await?;
            return Err(anyhow!("Anthropic API error: {}", error_text));
        }

        let anthropic_response: AnthropicResponse = response.json().await?;
        let summary = anthropic_response.content
            .first()
            .map(|c| c.text.clone())
            .ok_or_else(|| anyhow!("No content in Anthropic response"))?;

        let token_usage = TokenUsage {
            input_tokens: anthropic_response.usage.input_tokens,
            output_tokens: anthropic_response.usage.output_tokens,
            total_tokens: anthropic_response.usage.input_tokens + anthropic_response.usage.output_tokens,
            estimated_cost_usd: Some(calculate_anthropic_cost(
                &model,
                anthropic_response.usage.input_tokens,
                anthropic_response.usage.output_tokens,
            )),
        };

        Ok((summary, model, token_usage))
    }
    async fn summarize_ollama(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
        let ollama_config = self.config.ollama.as_ref()
            .ok_or_else(|| anyhow!("Ollama config not found"))?;

        let model = request.model.clone()
            .unwrap_or_else(|| ollama_config.model.clone());
        let temperature = request.temperature
            .unwrap_or(ollama_config.temperature);

        let system_prompt = if let Some(ref style) = request.style {
            style.to_system_prompt()
        } else {
            request.system_prompt.clone()
                .unwrap_or_else(|| "You are a helpful assistant that summarizes text concisely and accurately. Provide a clear, well-structured summary of the given text.".to_string())
        };

        let prompt = format!("Please summarize the following text:\n\n{}", request.text);
        let ollama_request = OllamaRequest {
            model: model.clone(),
            prompt,
            stream: false,
            system: Some(system_prompt),
            options: OllamaOptions {
                temperature,
                top_p: request.top_p,
                num_predict: request.max_tokens.map(|v| v as i32),
            },
        };

        let response = self.client
            .post(format!("{}/api/generate", ollama_config.base_url))
            .json(&ollama_request)
            .send()
            .await?;

        if !response.status().is_success() {
            let error_text = response.text().await?;
            return Err(anyhow!("Ollama API error: {}", error_text));
        }

        let ollama_response: OllamaResponse = response.json().await?;
        let token_usage = TokenUsage {
            input_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
            output_tokens: ollama_response.eval_count.unwrap_or(0),
            total_tokens: ollama_response.prompt_eval_count.unwrap_or(0) + ollama_response.eval_count.unwrap_or(0),
            estimated_cost_usd: None, // Ollama is local, no cost
        };

        Ok((ollama_response.response, model, token_usage))
    }
    pub async fn _list_ollama_models(&self) -> Result<Vec<_OllamaModelInfo>> {
        let ollama_config = self.config.ollama.as_ref()
            .ok_or_else(|| anyhow!("Ollama config not found"))?;

        let response = self.client
            .get(format!("{}/api/tags", ollama_config.base_url))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(anyhow!("Failed to list Ollama models"));
        }

        let list_response: _OllamaListResponse = response.json().await?;
        Ok(list_response.models)
    }

    pub fn _get_anthropic_models(&self) -> Vec<String> {
        // Return known Claude models from claude_models module
        ClaudeModelType::all_available()
            .into_iter()
            .map(|m| m.id())
            .collect()
    }
    async fn summarize_lmstudio(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
        let lmstudio_config = self.config.lmstudio.as_ref()
            .ok_or_else(|| anyhow!("LMStudio config not found"))?;

        let model = request.model.clone()
            .unwrap_or_else(|| lmstudio_config.model.clone());
        let temperature = request.temperature
            .unwrap_or(lmstudio_config.temperature);
        let max_tokens = request.max_tokens
            .unwrap_or(lmstudio_config.max_tokens);

        let system_prompt = if let Some(ref style) = request.style {
            style.to_system_prompt()
        } else {
            request.system_prompt.clone()
                .unwrap_or_else(|| "You are a helpful assistant that summarizes text concisely and accurately. Provide a clear, well-structured summary of the given text.".to_string())
        };

        // LM Studio exposes an OpenAI-compatible chat completions API.
        #[derive(Serialize)]
        struct LmStudioRequest {
            model: String,
            messages: Vec<LmStudioMessage>,
            temperature: f32,
            max_tokens: u32,
        }

        #[derive(Serialize)]
        struct LmStudioMessage {
            role: String,
            content: String,
        }

        #[derive(Deserialize)]
        struct LmStudioResponse {
            choices: Vec<LmStudioChoice>,
            usage: Option<LmStudioUsage>,
        }

        #[derive(Deserialize)]
        struct LmStudioChoice {
            message: LmStudioResponseMessage,
        }

        #[derive(Deserialize)]
        struct LmStudioResponseMessage {
            content: String,
        }

        #[derive(Deserialize)]
        struct LmStudioUsage {
            prompt_tokens: u32,
            completion_tokens: u32,
            total_tokens: u32,
        }

        let lmstudio_request = LmStudioRequest {
            model: model.clone(),
            messages: vec![
                LmStudioMessage {
                    role: "system".to_string(),
                    content: system_prompt,
                },
                LmStudioMessage {
                    role: "user".to_string(),
                    content: format!("Please summarize the following text:\n\n{}", request.text),
                },
            ],
            temperature,
            max_tokens,
        };

        let response = self.client
            .post(format!("{}/chat/completions", lmstudio_config.base_url))
            .json(&lmstudio_request)
            .send()
            .await?;

        if !response.status().is_success() {
            let error_text = response.text().await?;
            return Err(anyhow!("LMStudio API error: {}", error_text));
        }

        let lmstudio_response: LmStudioResponse = response.json().await?;
        let summary = lmstudio_response.choices
            .first()
            .map(|c| c.message.content.clone())
            .ok_or_else(|| anyhow!("No content in LMStudio response"))?;

        let token_usage = if let Some(usage) = lmstudio_response.usage {
            TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                total_tokens: usage.total_tokens,
                estimated_cost_usd: None, // LMStudio is local, no cost
            }
        } else {
            // LM Studio may omit usage; fall back to zeros rather than failing the request.
            TokenUsage {
                input_tokens: 0,
                output_tokens: 0,
                total_tokens: 0,
                estimated_cost_usd: None,
            }
        };

        Ok((summary, model, token_usage))
    }
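
    // Illustrative sketch of the OpenAI-style payload built above. LM Studio's local
    // server mirrors the /v1/chat/completions contract; values here are examples only.
    //
    //   Request:  { "model": "local-model", "temperature": 0.7, "max_tokens": 512,
    //               "messages": [{ "role": "system", "content": "You are a helpful assistant..." },
    //                            { "role": "user",   "content": "Please summarize the following text:..." }] }
    //   Response: { "choices": [{ "message": { "role": "assistant", "content": "The summary..." } }],
    //               "usage": { "prompt_tokens": 812, "completion_tokens": 143, "total_tokens": 955 } }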
    pub async fn _list_lmstudio_models(&self) -> Result<Vec<_OllamaModelInfo>> {
        let lmstudio_config = self.config.lmstudio.as_ref()
            .ok_or_else(|| anyhow!("LMStudio config not found"))?;

        let response = self.client
            .get(format!("{}/models", lmstudio_config.base_url))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(anyhow!("Failed to list LMStudio models"));
        }

        #[derive(Deserialize)]
        struct LmStudioModelsResponse {
            data: Vec<LmStudioModel>,
        }

        #[derive(Deserialize)]
        struct LmStudioModel {
            id: String,
        }

        let list_response: LmStudioModelsResponse = response.json().await?;

        // Convert to OllamaModelInfo format for compatibility
        Ok(list_response.data.into_iter().map(|m| _OllamaModelInfo {
            name: m.id,
            modified_at: "".to_string(),
            size: 0,
        }).collect())
    }
}
// Calculate Anthropic API costs based on model and token usage using ClaudeModelType
fn calculate_anthropic_cost(model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
    let claude_model = ClaudeModelType::from_id(model);
    claude_model.calculate_cost(input_tokens, output_tokens)
}
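
// Worked example (assumed rates, not read from ClaudeModelType): if a Sonnet-class model
// is billed at $3 per million input tokens and $15 per million output tokens, a call that
// used 1_000 input and 500 output tokens would cost
//   1_000 / 1_000_000 * 3.0 + 500 / 1_000_000 * 15.0 = 0.003 + 0.0075 = 0.0105 USD.
// The authoritative per-model rates live in crate::claude_models.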