//! HTTP client for the supported LLM providers (Anthropic, Ollama, LMStudio),
//! used to run text-summarization requests and normalize their responses.

use anyhow::{anyhow, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::config::Config;
use crate::models::{SummarizeRequest, TokenUsage};
use crate::v1::enums::LlmProvider;
use crate::claude_models::ClaudeModelType;
/// Client that dispatches summarization requests to whichever LLM provider
/// the request selects, using the provider sections of `Config`.
pub struct LlmClient {
    // Shared reqwest HTTP client, reused across all provider calls.
    client: Client,
    // Provider configuration (anthropic / ollama / lmstudio sections).
    config: Config,
}
/// Request body for the Anthropic Messages API (`POST {base_url}/messages`).
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    temperature: f32,
    messages: Vec<AnthropicMessage>,
    /// Optional system prompt; omitted from the JSON body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    /// Optional nucleus-sampling parameter; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
}
/// One chat turn in an Anthropic request (`role` is e.g. "user").
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}
/// Subset of the Anthropic Messages API response this client consumes.
/// Fields not listed here are silently ignored by serde's default behavior.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
    /// Model id the API actually served the request with.
    model: String,
    usage: AnthropicUsage,
}
/// Token accounting reported by the Anthropic API.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicUsage {
    input_tokens: u32,
    output_tokens: u32,
}
/// A single content block from an Anthropic response; only the text payload
/// is captured (other block fields are ignored during deserialization).
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicContent {
    text: String,
}
/// Request body for Ollama's `POST /api/generate` endpoint.
#[derive(Debug, Serialize, Deserialize)]
struct OllamaRequest {
    model: String,
    prompt: String,
    /// Always sent as `false` by this client so the reply arrives as one JSON object.
    stream: bool,
    /// Optional system prompt; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    options: OllamaOptions,
}
/// Sampling options forwarded to Ollama under the `options` key.
#[derive(Debug, Serialize, Deserialize)]
struct OllamaOptions {
    temperature: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// Maximum tokens to generate (Ollama's `num_predict`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<i32>,
}
/// Subset of the Ollama `/api/generate` (non-streaming) response.
#[derive(Debug, Serialize, Deserialize)]
struct OllamaResponse {
    /// Generated text.
    response: String,
    model: String,
    /// Prompt token count; Ollama may omit this field.
    prompt_eval_count: Option<u32>,
    /// Completion token count; Ollama may omit this field.
    eval_count: Option<u32>,
}
/// One entry from Ollama's `/api/tags` model list. Underscore prefix marks
/// it as currently unused outside the `_list_*` helpers.
#[derive(Debug, Serialize, Deserialize)]
pub struct _OllamaModelInfo {
    pub name: String,
    pub modified_at: String,
    /// Model size in bytes as reported by Ollama.
    pub size: i64,
}
/// Envelope of Ollama's `/api/tags` response.
#[derive(Debug, Serialize, Deserialize)]
struct _OllamaListResponse {
    models: Vec<_OllamaModelInfo>,
}
impl LlmClient {
|
|
pub fn new(config: Config) -> Self {
|
|
Self {
|
|
client: Client::new(),
|
|
config,
|
|
}
|
|
}
|
|
|
|
pub async fn summarize(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
|
|
match request.provider {
|
|
LlmProvider::Anthropic => self.summarize_anthropic(request).await,
|
|
LlmProvider::Ollama => self.summarize_ollama(request).await,
|
|
LlmProvider::Lmstudio => self.summarize_lmstudio(request).await,
|
|
}
|
|
}
|
|
|
|
async fn summarize_anthropic(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
|
|
let anthropic_config = self.config.anthropic.as_ref()
|
|
.ok_or_else(|| anyhow!("Anthropic config not found"))?;
|
|
|
|
let model = request.model.clone()
|
|
.unwrap_or_else(|| anthropic_config.model.clone());
|
|
|
|
let max_tokens = request.max_tokens
|
|
.unwrap_or(anthropic_config.max_tokens);
|
|
|
|
let temperature = request.temperature
|
|
.unwrap_or(anthropic_config.temperature);
|
|
|
|
let system_prompt = if let Some(ref style) = request.style {
|
|
style.to_system_prompt()
|
|
} else {
|
|
request.system_prompt.clone()
|
|
.unwrap_or_else(|| "You are a helpful assistant that summarizes text concisely and accurately. Provide a clear, well-structured summary of the given text.".to_string())
|
|
};
|
|
|
|
let anthropic_request = AnthropicRequest {
|
|
model: model.clone(),
|
|
max_tokens,
|
|
temperature,
|
|
messages: vec![
|
|
AnthropicMessage {
|
|
role: "user".to_string(),
|
|
content: format!("Please summarize the following text:\n\n{}", request.text),
|
|
}
|
|
],
|
|
system: Some(system_prompt),
|
|
top_p: request.top_p,
|
|
};
|
|
|
|
let response = self.client
|
|
.post(format!("{}/messages", anthropic_config.base_url))
|
|
.header("x-api-key", &anthropic_config.api_key)
|
|
.header("anthropic-version", "2023-06-01")
|
|
.header("content-type", "application/json")
|
|
.json(&anthropic_request)
|
|
.send()
|
|
.await?;
|
|
|
|
if !response.status().is_success() {
|
|
let error_text = response.text().await?;
|
|
return Err(anyhow!("Anthropic API error: {}", error_text));
|
|
}
|
|
|
|
let anthropic_response: AnthropicResponse = response.json().await?;
|
|
|
|
let summary = anthropic_response.content
|
|
.first()
|
|
.map(|c| c.text.clone())
|
|
.ok_or_else(|| anyhow!("No content in Anthropic response"))?;
|
|
|
|
let token_usage = TokenUsage {
|
|
input_tokens: anthropic_response.usage.input_tokens,
|
|
output_tokens: anthropic_response.usage.output_tokens,
|
|
total_tokens: anthropic_response.usage.input_tokens + anthropic_response.usage.output_tokens,
|
|
estimated_cost_usd: Some(calculate_anthropic_cost(
|
|
&model,
|
|
anthropic_response.usage.input_tokens,
|
|
anthropic_response.usage.output_tokens,
|
|
)),
|
|
};
|
|
|
|
Ok((summary, model, token_usage))
|
|
}
|
|
|
|
async fn summarize_ollama(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
|
|
let ollama_config = self.config.ollama.as_ref()
|
|
.ok_or_else(|| anyhow!("Ollama config not found"))?;
|
|
|
|
let model = request.model.clone()
|
|
.unwrap_or_else(|| ollama_config.model.clone());
|
|
|
|
let temperature = request.temperature
|
|
.unwrap_or(ollama_config.temperature);
|
|
|
|
let system_prompt = if let Some(ref style) = request.style {
|
|
style.to_system_prompt()
|
|
} else {
|
|
request.system_prompt.clone()
|
|
.unwrap_or_else(|| "You are a helpful assistant that summarizes text concisely and accurately. Provide a clear, well-structured summary of the given text.".to_string())
|
|
};
|
|
|
|
let prompt = format!("Please summarize the following text:\n\n{}", request.text);
|
|
|
|
let ollama_request = OllamaRequest {
|
|
model: model.clone(),
|
|
prompt,
|
|
stream: false,
|
|
system: Some(system_prompt),
|
|
options: OllamaOptions {
|
|
temperature,
|
|
top_p: request.top_p,
|
|
num_predict: request.max_tokens.map(|v| v as i32),
|
|
},
|
|
};
|
|
|
|
let response = self.client
|
|
.post(format!("{}/api/generate", ollama_config.base_url))
|
|
.json(&ollama_request)
|
|
.send()
|
|
.await?;
|
|
|
|
if !response.status().is_success() {
|
|
let error_text = response.text().await?;
|
|
return Err(anyhow!("Ollama API error: {}", error_text));
|
|
}
|
|
|
|
let ollama_response: OllamaResponse = response.json().await?;
|
|
|
|
let token_usage = TokenUsage {
|
|
input_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
|
|
output_tokens: ollama_response.eval_count.unwrap_or(0),
|
|
total_tokens: ollama_response.prompt_eval_count.unwrap_or(0) + ollama_response.eval_count.unwrap_or(0),
|
|
estimated_cost_usd: None, // Ollama is local, no cost
|
|
};
|
|
|
|
Ok((ollama_response.response, model, token_usage))
|
|
}
|
|
|
|
pub async fn _list_ollama_models(&self) -> Result<Vec<_OllamaModelInfo>> {
|
|
let ollama_config = self.config.ollama.as_ref()
|
|
.ok_or_else(|| anyhow!("Ollama config not found"))?;
|
|
|
|
let response = self.client
|
|
.get(format!("{}/api/tags", ollama_config.base_url))
|
|
.send()
|
|
.await?;
|
|
|
|
if !response.status().is_success() {
|
|
return Err(anyhow!("Failed to list Ollama models"));
|
|
}
|
|
|
|
let list_response: _OllamaListResponse = response.json().await?;
|
|
Ok(list_response.models)
|
|
}
|
|
|
|
pub fn _get_anthropic_models(&self) -> Vec<String> {
|
|
// Return known Claude models from claude_models module
|
|
ClaudeModelType::all_available()
|
|
.into_iter()
|
|
.map(|m| m.id())
|
|
.collect()
|
|
}
|
|
|
|
async fn summarize_lmstudio(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
|
|
let lmstudio_config = self.config.lmstudio.as_ref()
|
|
.ok_or_else(|| anyhow!("LMStudio config not found"))?;
|
|
|
|
let model = request.model.clone()
|
|
.unwrap_or_else(|| lmstudio_config.model.clone());
|
|
|
|
let temperature = request.temperature
|
|
.unwrap_or(lmstudio_config.temperature);
|
|
|
|
let max_tokens = request.max_tokens
|
|
.unwrap_or(lmstudio_config.max_tokens);
|
|
|
|
let system_prompt = if let Some(ref style) = request.style {
|
|
style.to_system_prompt()
|
|
} else {
|
|
request.system_prompt.clone()
|
|
.unwrap_or_else(|| "You are a helpful assistant that summarizes text concisely and accurately. Provide a clear, well-structured summary of the given text.".to_string())
|
|
};
|
|
|
|
// LMStudio uses OpenAI-compatible API
|
|
#[derive(Serialize)]
|
|
struct LmStudioRequest {
|
|
model: String,
|
|
messages: Vec<LmStudioMessage>,
|
|
temperature: f32,
|
|
max_tokens: u32,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct LmStudioMessage {
|
|
role: String,
|
|
content: String,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct LmStudioResponse {
|
|
choices: Vec<LmStudioChoice>,
|
|
usage: Option<LmStudioUsage>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct LmStudioChoice {
|
|
message: LmStudioResponseMessage,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct LmStudioResponseMessage {
|
|
content: String,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct LmStudioUsage {
|
|
prompt_tokens: u32,
|
|
completion_tokens: u32,
|
|
total_tokens: u32,
|
|
}
|
|
|
|
let lmstudio_request = LmStudioRequest {
|
|
model: model.clone(),
|
|
messages: vec![
|
|
LmStudioMessage {
|
|
role: "system".to_string(),
|
|
content: system_prompt,
|
|
},
|
|
LmStudioMessage {
|
|
role: "user".to_string(),
|
|
content: format!("Please summarize the following text:\n\n{}", request.text),
|
|
},
|
|
],
|
|
temperature,
|
|
max_tokens,
|
|
};
|
|
|
|
let response = self.client
|
|
.post(format!("{}/chat/completions", lmstudio_config.base_url))
|
|
.json(&lmstudio_request)
|
|
.send()
|
|
.await?;
|
|
|
|
if !response.status().is_success() {
|
|
let error_text = response.text().await?;
|
|
return Err(anyhow!("LMStudio API error: {}", error_text));
|
|
}
|
|
|
|
let lmstudio_response: LmStudioResponse = response.json().await?;
|
|
|
|
let summary = lmstudio_response.choices
|
|
.first()
|
|
.map(|c| c.message.content.clone())
|
|
.ok_or_else(|| anyhow!("No content in LMStudio response"))?;
|
|
|
|
let token_usage = if let Some(usage) = lmstudio_response.usage {
|
|
TokenUsage {
|
|
input_tokens: usage.prompt_tokens,
|
|
output_tokens: usage.completion_tokens,
|
|
total_tokens: usage.total_tokens,
|
|
estimated_cost_usd: None, // LMStudio is local, no cost
|
|
}
|
|
} else {
|
|
TokenUsage {
|
|
input_tokens: 0,
|
|
output_tokens: 0,
|
|
total_tokens: 0,
|
|
estimated_cost_usd: None,
|
|
}
|
|
};
|
|
|
|
Ok((summary, model, token_usage))
|
|
}
|
|
|
|
pub async fn _list_lmstudio_models(&self) -> Result<Vec<_OllamaModelInfo>> {
|
|
let lmstudio_config = self.config.lmstudio.as_ref()
|
|
.ok_or_else(|| anyhow!("LMStudio config not found"))?;
|
|
|
|
let response = self.client
|
|
.get(format!("{}/models", lmstudio_config.base_url))
|
|
.send()
|
|
.await?;
|
|
|
|
if !response.status().is_success() {
|
|
return Err(anyhow!("Failed to list LMStudio models"));
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct LmStudioModelsResponse {
|
|
data: Vec<LmStudioModel>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct LmStudioModel {
|
|
id: String,
|
|
}
|
|
|
|
let list_response: LmStudioModelsResponse = response.json().await?;
|
|
|
|
// Convert to OllamaModelInfo format for compatibility
|
|
Ok(list_response.data.into_iter().map(|m| _OllamaModelInfo {
|
|
name: m.id,
|
|
modified_at: "".to_string(),
|
|
size: 0,
|
|
}).collect())
|
|
}
|
|
}
|
|
|
|
// Calculate Anthropic API costs based on model and token usage using ClaudeModelType
|
|
fn calculate_anthropic_cost(model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
|
|
let claude_model = ClaudeModelType::from_id(model);
|
|
claude_model.calculate_cost(input_tokens, output_tokens)
|
|
}
|