(Feat): Initial Commit

This commit is contained in:
2025-12-12 15:31:27 +00:00
commit 5b13236b80
48 changed files with 13146 additions and 0 deletions

4070
backend_api/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

30
backend_api/Cargo.toml Normal file
View File

@@ -0,0 +1,30 @@
[package]
name = "backend_api"
version = "0.1.0"
edition = "2021"
[dependencies]
# HTTP server framework plus OpenAPI schema generation with an embedded Swagger UI.
poem = "3.1.12"
poem-openapi = { version = "5.1.16", features = ["swagger-ui"] }
# Async runtime; "full" enables all tokio subsystems (net, time, sync, fs, ...).
tokio = { version = "1.48.0", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
# Outbound HTTP client for the LLM provider APIs; rustls avoids a native TLS dependency.
reqwest = { version = "0.11", default-features = false, features = ["json", "rustls-tls", "multipart"] }
# SHA-256 + hex encoding: used to derive cache checksums for summarize requests.
sha2 = "0.10"
hex = "0.4"
chrono = { version = "0.4", features = ["serde"] }
# Config-file parsing (config.toml).
toml = "0.8"
anyhow = "1.0"
thiserror = "1.0"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
image = "0.25"
base64 = "0.22"
# OCR backend; optional so builds without the native tesseract library still work.
tesseract = { version = "0.14", optional = true }
validator = { version = "0.18", features = ["derive"] }
strum = { version = "0.26", features = ["derive"] }
# Both sqlite and postgres are compiled in; the backend is chosen at runtime via config.
sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "sqlite", "postgres", "chrono"] }
[features]
default = []
# Enable OCR support (requires the tesseract system library).
ocr = ["tesseract"]

0
backend_api/README.md Normal file
View File

170
backend_api/src/cache.rs Normal file
View File

@@ -0,0 +1,170 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use chrono::{DateTime, Utc, Duration};
use crate::models::CacheEntry;
/// A cached summary plus its expiry deadline.
///
/// Wraps the persistent `CacheEntry` with the in-memory-only TTL bookkeeping.
#[derive(Clone)]
pub struct CacheItem {
    // The cached payload (summary, provenance, access stats).
    pub entry: CacheEntry,
    // Absolute deadline after which the item counts as a miss.
    pub expires_at: DateTime<Utc>,
}
/// In-memory TTL cache keyed by request checksum, with hit/miss counters.
///
/// All state is behind `Arc<RwLock<...>>` so the cache can be shared across
/// async tasks by cloning the `Arc`s.
pub struct Cache {
    // checksum -> cached item.
    store: Arc<RwLock<HashMap<String, CacheItem>>>,
    // Lifetime of each entry, in seconds, applied at insert time.
    ttl_seconds: i64,
    // Hard cap on entry count; oldest entry is evicted when full.
    max_size: usize,
    // Global hit counter (separate lock from `store`).
    hits: Arc<RwLock<u64>>,
    // Global miss counter (expired lookups count as misses).
    misses: Arc<RwLock<u64>>,
}
impl Cache {
    /// Create an empty cache whose entries live for `ttl_seconds` and which
    /// holds at most `max_size` entries (oldest evicted first on overflow).
    pub fn new(ttl_seconds: i64, max_size: usize) -> Self {
        Self {
            store: Arc::new(RwLock::new(HashMap::new())),
            ttl_seconds,
            max_size,
            hits: Arc::new(RwLock::new(0)),
            misses: Arc::new(RwLock::new(0)),
        }
    }

    /// Look up an entry by checksum.
    ///
    /// A hit bumps the entry's `access_count`/`last_accessed` and the global
    /// hit counter. An absent entry — or one whose TTL has elapsed, which is
    /// removed eagerly here — counts as a miss.
    pub async fn get(&self, checksum: &str) -> Option<CacheEntry> {
        let mut store = self.store.write().await;
        if let Some(item) = store.get_mut(checksum) {
            if Utc::now() > item.expires_at {
                // Expired: drop it and record a miss.
                store.remove(checksum);
                *self.misses.write().await += 1;
                return None;
            }
            item.entry.access_count += 1;
            item.entry.last_accessed = Utc::now().to_rfc3339();
            *self.hits.write().await += 1;
            return Some(item.entry.clone());
        }
        *self.misses.write().await += 1;
        None
    }

    /// Insert (or replace) an entry under `checksum`.
    ///
    /// When the cache is at capacity, the entry with the smallest `created_at`
    /// is evicted first (RFC 3339 strings order chronologically, so string
    /// comparison is sufficient).
    pub async fn insert(&self, checksum: String, entry: CacheEntry) {
        let mut store = self.store.write().await;
        if store.len() >= self.max_size {
            if let Some(oldest_key) = store
                .iter()
                .min_by_key(|(_, item)| &item.entry.created_at)
                .map(|(k, _)| k.clone())
            {
                store.remove(&oldest_key);
            }
        }
        let expires_at = Utc::now() + Duration::seconds(self.ttl_seconds);
        store.insert(checksum, CacheItem { entry, expires_at });
    }

    /// Collect every live (non-expired) entry matching `pred`, newest first.
    /// Shared implementation for the read-side query methods below.
    async fn collect_live<F>(&self, pred: F) -> Vec<CacheEntry>
    where
        F: Fn(&CacheEntry) -> bool,
    {
        let store = self.store.read().await;
        let now = Utc::now();
        let mut entries: Vec<CacheEntry> = store
            .values()
            .filter(|item| now <= item.expires_at && pred(&item.entry))
            .map(|item| item.entry.clone())
            .collect();
        // Sort by created_at descending (newest first).
        entries.sort_by(|a, b| b.created_at.cmp(&a.created_at));
        entries
    }

    /// All live entries, newest first, with optional pagination.
    pub async fn get_all(&self, limit: Option<usize>, offset: Option<usize>) -> Vec<CacheEntry> {
        let entries = self.collect_live(|_| true).await;
        let offset = offset.unwrap_or(0);
        let limit = limit.unwrap_or(entries.len());
        entries.into_iter().skip(offset).take(limit).collect()
    }

    /// All live entries belonging to `guild_id`, newest first.
    pub async fn get_by_guild(&self, guild_id: &str) -> Vec<CacheEntry> {
        self.collect_live(|entry| entry.guild_id == guild_id).await
    }

    /// All live entries belonging to `user_id`, newest first.
    pub async fn get_by_user(&self, user_id: &str) -> Vec<CacheEntry> {
        self.collect_live(|entry| entry.user_id == user_id).await
    }

    /// Remove every entry matching `pred`; returns the number removed.
    /// Shared implementation for the `delete_by_*` methods below.
    async fn delete_where<F>(&self, pred: F) -> u64
    where
        F: Fn(&CacheEntry) -> bool,
    {
        let mut store = self.store.write().await;
        let before = store.len();
        store.retain(|_, item| !pred(&item.entry));
        (before - store.len()) as u64
    }

    /// Delete the entry with database id `id`; returns 1 if it existed, else 0.
    pub async fn delete_by_id(&self, id: i64) -> u64 {
        self.delete_where(|entry| entry.id == id).await
    }

    /// Delete every entry belonging to `guild_id`; returns the count removed.
    pub async fn delete_by_guild(&self, guild_id: &str) -> u64 {
        self.delete_where(|entry| entry.guild_id == guild_id).await
    }

    /// Delete every entry belonging to `user_id`; returns the count removed.
    pub async fn delete_by_user(&self, user_id: &str) -> u64 {
        self.delete_where(|entry| entry.user_id == user_id).await
    }

    /// Clear the cache entirely; returns the number of entries dropped.
    pub async fn delete_all(&self) -> u64 {
        let mut store = self.store.write().await;
        let count = store.len() as u64;
        store.clear();
        count
    }

    /// Snapshot of `(hits, misses, current size)`.
    ///
    /// The three values are read under separate locks, so the triple is not a
    /// single atomic snapshot — adequate for monitoring purposes.
    pub async fn get_stats(&self) -> (u64, u64, usize) {
        let hits = *self.hits.read().await;
        let misses = *self.misses.read().await;
        let size = self.store.read().await.len();
        (hits, misses, size)
    }

    /// Drop every expired entry (periodic maintenance hook).
    pub async fn cleanup_expired(&self) {
        let mut store = self.store.write().await;
        let now = Utc::now();
        store.retain(|_, item| now <= item.expires_at);
    }
}

View File

@@ -0,0 +1,33 @@
use sha2::{Sha256, Digest};
use crate::models::SummarizeRequest;
pub fn calculate_checksum(request: &SummarizeRequest) -> String {
let mut hasher = Sha256::new();
// Include all relevant fields in the checksum
hasher.update(request.text.as_bytes());
hasher.update(format!("{:?}", request.provider).as_bytes());
if let Some(model) = &request.model {
hasher.update(model.as_bytes());
}
if let Some(temp) = request.temperature {
hasher.update(temp.to_string().as_bytes());
}
if let Some(max_tokens) = request.max_tokens {
hasher.update(max_tokens.to_string().as_bytes());
}
if let Some(top_p) = request.top_p {
hasher.update(top_p.to_string().as_bytes());
}
if let Some(system_prompt) = &request.system_prompt {
hasher.update(system_prompt.as_bytes());
}
let result = hasher.finalize();
hex::encode(result)
}

View File

@@ -0,0 +1,153 @@
use serde::{Deserialize, Serialize};
use std::fmt;
/// Claude model configuration
///
/// Pricing fields are USD per one million tokens, as used by
/// `ClaudeModelType::calculate_cost`.
// NOTE(review): the hard-coded prices below should be re-verified against
// Anthropic's published pricing before relying on cost figures.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaudeModel {
    /// Model identifier used in API calls
    pub id: String,
    /// Human-readable model name
    pub name: String,
    /// Input token cost per million tokens (USD)
    pub input_price_per_million: f64,
    /// Output token cost per million tokens (USD)
    pub output_price_per_million: f64,
    /// Maximum context window size
    pub max_context_tokens: u32,
    /// Whether this model is currently available
    pub available: bool,
}
/// Enum of available Claude models
///
/// `from_id` maps unknown id strings to `Custom`, so parsing never fails;
/// `Custom` models are priced with Sonnet defaults (see `info`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ClaudeModelType {
    /// Claude Sonnet 4.5 (2025-09-29) - Balanced performance and cost
    Sonnet4_5,
    /// Claude Haiku 4.5 (2025-10-01) - Fast and cost-effective
    Haiku4_5,
    /// Claude Opus 4.5 (2025-10-01) - Most capable, highest cost
    Opus4_5,
    /// Claude Sonnet 3.7 (2025-02-19) - Previous generation balanced model
    Sonnet3_7,
    /// Custom model (for flexibility)
    Custom(String),
}
impl ClaudeModelType {
    /// The model id string passed to the Anthropic API.
    pub fn id(&self) -> String {
        let fixed = match self {
            Self::Sonnet4_5 => "claude-sonnet-4-5-20250929",
            Self::Haiku4_5 => "claude-haiku-4-5-20251001",
            Self::Opus4_5 => "claude-opus-4-5-20251001",
            Self::Sonnet3_7 => "claude-3-7-sonnet-20250219",
            Self::Custom(custom) => return custom.clone(),
        };
        fixed.to_string()
    }

    /// Full model metadata, including per-million-token pricing.
    ///
    /// Custom models fall back to Sonnet pricing ($3 in / $15 out).
    pub fn info(&self) -> ClaudeModel {
        // All variants share the same context window and availability; only
        // the display name and pricing differ.
        let build = |name: String, input: f64, output: f64| ClaudeModel {
            id: self.id(),
            name,
            input_price_per_million: input,
            output_price_per_million: output,
            max_context_tokens: 200_000,
            available: true,
        };
        match self {
            Self::Sonnet4_5 => build("Claude Sonnet 4.5".to_string(), 3.0, 15.0),
            Self::Haiku4_5 => build("Claude Haiku 4.5".to_string(), 0.8, 4.0),
            Self::Opus4_5 => build("Claude Opus 4.5".to_string(), 15.0, 75.0),
            Self::Sonnet3_7 => build("Claude 3.7 Sonnet".to_string(), 3.0, 15.0),
            Self::Custom(custom) => build(format!("Custom Model: {}", custom), 3.0, 15.0),
        }
    }

    /// Parse a model id string; unknown ids become `Custom`.
    pub fn from_id(id: &str) -> Self {
        match id {
            "claude-sonnet-4-5-20250929" => Self::Sonnet4_5,
            "claude-haiku-4-5-20251001" => Self::Haiku4_5,
            "claude-opus-4-5-20251001" => Self::Opus4_5,
            "claude-3-7-sonnet-20250219" => Self::Sonnet3_7,
            other => Self::Custom(other.to_string()),
        }
    }

    /// Every first-party model (excludes `Custom`).
    pub fn all_available() -> Vec<ClaudeModelType> {
        vec![
            Self::Sonnet4_5,
            Self::Haiku4_5,
            Self::Opus4_5,
            Self::Sonnet3_7,
        ]
    }

    /// Total USD cost for a given token usage at this model's rates.
    pub fn calculate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
        const PER_MILLION: f64 = 1_000_000.0;
        let pricing = self.info();
        (f64::from(input_tokens) / PER_MILLION) * pricing.input_price_per_million
            + (f64::from(output_tokens) / PER_MILLION) * pricing.output_price_per_million
    }
}
impl fmt::Display for ClaudeModelType {
    /// Renders as the raw API model id (identical to `Self::id`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.id())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Known variants map to their dated API ids.
    #[test]
    fn test_model_id() {
        assert_eq!(ClaudeModelType::Sonnet4_5.id(), "claude-sonnet-4-5-20250929");
        assert_eq!(ClaudeModelType::Haiku4_5.id(), "claude-haiku-4-5-20251001");
    }

    /// Round-trip: a known id parses back to its variant.
    #[test]
    fn test_from_id() {
        assert_eq!(
            ClaudeModelType::from_id("claude-sonnet-4-5-20250929"),
            ClaudeModelType::Sonnet4_5
        );
    }

    /// Unknown ids must fall back to the `Custom` variant, preserving the id.
    #[test]
    fn test_from_id_unknown_falls_back_to_custom() {
        assert_eq!(
            ClaudeModelType::from_id("some-unknown-model"),
            ClaudeModelType::Custom("some-unknown-model".to_string())
        );
    }

    /// `Display` must emit the same string as `id()`.
    #[test]
    fn test_display_matches_id() {
        assert_eq!(
            ClaudeModelType::Sonnet3_7.to_string(),
            "claude-3-7-sonnet-20250219"
        );
    }

    #[test]
    fn test_cost_calculation() {
        // 1M input + 1M output at $3/M + $15/M; both rates are exactly
        // representable in f64, so strict equality is safe here.
        let cost = ClaudeModelType::Sonnet4_5.calculate_cost(1_000_000, 1_000_000);
        assert_eq!(cost, 18.0); // 3.0 + 15.0
    }

    /// Zero usage must cost exactly zero for any model.
    #[test]
    fn test_zero_tokens_cost_nothing() {
        assert_eq!(ClaudeModelType::Opus4_5.calculate_cost(0, 0), 0.0);
    }
}

103
backend_api/src/config.rs Normal file
View File

@@ -0,0 +1,103 @@
use serde::{Deserialize, Serialize};
use std::fs;
use anyhow::Result;
/// Top-level application configuration, deserialized from a TOML file
/// (see `Config::load`). Provider sections are optional so deployments can
/// enable only the backends they use.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
    pub server: ServerConfig,
    pub database: DatabaseConfig,
    // Each provider block is optional; absent means that provider is disabled.
    pub anthropic: Option<AnthropicConfig>,
    pub ollama: Option<OllamaConfig>,
    pub lmstudio: Option<LmStudioConfig>,
    pub rate_limiting: RateLimitingConfig,
    pub cache: CacheConfig,
    pub credits: CreditsConfig,
}
/// Credit/monetization settings (`[credits]` section).
///
/// The per-action costs default via serde when omitted from the TOML file
/// (1 credit per summary, 2 per OCR summary).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CreditsConfig {
    /// Whether the credit system is enabled
    pub enabled: bool,
    /// User IDs that bypass the credit system (bot owners)
    #[serde(default)]
    pub bypass_user_ids: Vec<String>,
    /// Credits cost per summarization
    #[serde(default = "default_credits_per_summary")]
    pub credits_per_summary: i64,
    /// Credits cost per OCR summarization
    #[serde(default = "default_credits_per_ocr")]
    pub credits_per_ocr: i64,
}
/// Serde default: credit cost of one plain-text summarization.
fn default_credits_per_summary() -> i64 {
    1
}
/// Serde default: credit cost of one OCR-backed summarization.
fn default_credits_per_ocr() -> i64 {
    2
}
/// HTTP listener settings (`[server]` section).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ServerConfig {
    // Bind address, e.g. "127.0.0.1".
    pub host: String,
    pub port: u16,
}
/// Which database backend to connect to; serialized as "sqlite"/"postgres"
/// in the TOML file via `rename_all = "lowercase"`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum DatabaseType {
    Sqlite,
    Postgres,
}
/// Database connection settings (`[database]` section).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DatabaseConfig {
    // TOML key is "type" ("type" is a reserved word in Rust).
    #[serde(rename = "type")]
    pub db_type: DatabaseType,
    // Connection string, e.g. "sqlite:summarizer.db" or a postgres URL.
    pub url: String,
}
/// Anthropic provider settings (`[anthropic]` section).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AnthropicConfig {
    // Secret API key; keep out of version control.
    pub api_key: String,
    pub base_url: String,
    // Default model id used for requests that don't specify one.
    pub model: String,
    pub max_tokens: u32,
    pub temperature: f32,
}
/// Ollama provider settings (`[ollama]` section); no API key, local server.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OllamaConfig {
    pub base_url: String,
    pub model: String,
    pub temperature: f32,
    pub max_tokens: u32,
}
/// LM Studio provider settings (`[lmstudio]` section); mirrors `OllamaConfig`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LmStudioConfig {
    pub base_url: String,
    pub model: String,
    pub temperature: f32,
    pub max_tokens: u32,
}
/// Rate-limiting settings (`[rate_limiting]` section).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RateLimitingConfig {
    pub enabled: bool,
    // Requests allowed per window, per guild.
    pub guild_rate_limit: u32,
    // Requests allowed per window, per user.
    pub user_rate_limit: u32,
    // Window length in seconds; None lets the consumer pick its default.
    pub window_seconds: Option<i64>,
}
/// Summary-cache settings (`[cache]` section).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CacheConfig {
    pub enabled: bool,
    // Entry time-to-live in seconds; None lets the consumer pick its default.
    pub ttl_seconds: Option<i64>,
    // Maximum number of cached entries; None lets the consumer pick its default.
    pub max_size: Option<usize>,
}
impl Config {
    /// Read the TOML file at `path` and deserialize it into a `Config`.
    ///
    /// Errors bubble up from either the filesystem read or the TOML parse.
    pub fn load(path: &str) -> Result<Self> {
        let raw = fs::read_to_string(path)?;
        Ok(toml::from_str::<Config>(&raw)?)
    }
}

View File

@@ -0,0 +1,46 @@
# Example configuration for backend_api. Copy, fill in secrets, and point the
# server at it. Sections map 1:1 onto the structs in src/config.rs.

[server]
# Bind address and port for the HTTP API.
host = "127.0.0.1"
port = 3001

[anthropic]
# Anthropic provider; remove this section to disable the provider.
api_key = "your-anthropic-api-key-here"
base_url = "https://api.anthropic.com/v1"
model = "claude-3-5-sonnet-20241022"
max_tokens = 4096
temperature = 1.0

[ollama]
# Local Ollama provider; remove this section to disable the provider.
base_url = "http://localhost:11434"
model = "llama3.2"
temperature = 0.7
max_tokens = 4096

[rate_limiting]
enabled = true
# Requests per minute per guild
guild_rate_limit = 30
# Requests per minute per user
user_rate_limit = 10
# Rate limit window in seconds (default: 60)
window_seconds = 60

[cache]
enabled = true
# Cache TTL in seconds (default: 3600 = 1 hour)
ttl_seconds = 3600
# Maximum number of cache entries (default: 1000)
max_size = 1000

[credits]
# Enable credit-based monetization
enabled = true
# User IDs that bypass the credit system (bot owners)
bypass_user_ids = ["YOUR_DISCORD_USER_ID_HERE"]
# Credits cost per summarization (default: 1)
credits_per_summary = 1
# Credits cost per OCR summarization (default: 2)
credits_per_ocr = 2

[database]
# type is "sqlite" or "postgres"; url is the matching connection string.
type = "sqlite"
url = "sqlite:summarizer.db"

621
backend_api/src/credits.rs Normal file
View File

@@ -0,0 +1,621 @@
//! Credit system for monetization
//!
//! Handles user credits, tiers, transactions, and bypass functionality.
use anyhow::Result;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use sqlx::Row;
use tracing::{debug, info};
use crate::database::DatabasePool;
/// User tier levels
///
/// Discriminant values (1..=4) are the `tier_id` values stored in the
/// `user_credits` table, so they must stay stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[allow(dead_code)]
pub enum TierLevel {
    Free = 1,
    Basic = 2,
    Pro = 3,
    Unlimited = 4,
}
#[allow(dead_code)]
impl TierLevel {
pub fn from_id(id: i32) -> Self {
match id {
1 => TierLevel::Free,
2 => TierLevel::Basic,
3 => TierLevel::Pro,
4 => TierLevel::Unlimited,
_ => TierLevel::Free,
}
}
pub fn name(&self) -> &'static str {
match self {
TierLevel::Free => "Free",
TierLevel::Basic => "Basic",
TierLevel::Pro => "Pro",
TierLevel::Unlimited => "Unlimited",
}
}
pub fn monthly_credits(&self) -> i64 {
match self {
TierLevel::Free => 50,
TierLevel::Basic => 500,
TierLevel::Pro => 2000,
TierLevel::Unlimited => -1, // Unlimited
}
}
pub fn max_messages(&self) -> i32 {
match self {
TierLevel::Free => 50,
TierLevel::Basic => 100,
TierLevel::Pro => 250,
TierLevel::Unlimited => 500,
}
}
pub fn price_usd(&self) -> f64 {
match self {
TierLevel::Free => 0.0,
TierLevel::Basic => 4.99,
TierLevel::Pro => 9.99,
TierLevel::Unlimited => 19.99,
}
}
}
/// User credit information
///
/// Mirrors one row of the `user_credits` table; timestamps are stored as
/// RFC 3339 strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserCredits {
    pub user_id: String,
    // Current spendable balance.
    pub credits: i64,
    // Foreign reference into TierLevel's discriminants (1..=4).
    pub tier_id: i32,
    // Running total of credits ever spent by this user.
    pub lifetime_credits_used: i64,
    pub created_at: String,
    pub updated_at: String,
}
impl UserCredits {
    /// Resolve the stored `tier_id` into a `TierLevel`.
    pub fn tier(&self) -> TierLevel {
        TierLevel::from_id(self.tier_id)
    }

    /// True when this user's tier is never charged credits.
    pub fn has_unlimited(&self) -> bool {
        matches!(self.tier(), TierLevel::Unlimited)
    }
}
/// Credit transaction record
///
/// Mirrors one row of the `credit_transactions` audit table. `amount` is
/// negative for usage deductions and positive for grants/purchases.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct CreditTransaction {
    pub id: i64,
    pub user_id: String,
    // Guild the action happened in, when applicable.
    pub guild_id: Option<String>,
    pub amount: i64,
    // e.g. "usage" for deductions, "grant" for the initial free-tier credit.
    pub transaction_type: String,
    pub description: String,
    // Balance snapshot after this transaction was applied.
    pub balance_after: i64,
    pub created_at: String,
}
/// Result of a credit check
///
/// `credits_remaining == -1` means the user is not metered (bypass list or
/// unlimited tier).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreditCheckResult {
    pub allowed: bool,
    // Human-readable denial reason; None when allowed.
    pub reason: Option<String>,
    pub credits_remaining: i64,
    pub is_bypass: bool,
    // Tier display name, or "Bypass" for bypass-listed users.
    pub tier: String,
}
/// Credit system manager
///
/// Stateless apart from the configured bypass list; all balances live in the
/// database passed to each method.
pub struct CreditSystem {
    // User ids exempt from all credit checks and charges.
    bypass_user_ids: Vec<String>,
}
#[allow(dead_code)]
impl CreditSystem {
    /// Build a credit system with the configured bypass list (typically bot
    /// owners who are never metered).
    pub fn new(bypass_user_ids: Vec<String>) -> Self {
        info!("Credit system initialized with {} bypass users", bypass_user_ids.len());
        Self { bypass_user_ids }
    }

    /// Check if a user can bypass the credit system
    // NOTE(review): allocates a String on every call;
    // `self.bypass_user_ids.iter().any(|id| id == user_id)` would avoid it.
    pub fn is_bypass_user(&self, user_id: &str) -> bool {
        self.bypass_user_ids.contains(&user_id.to_string())
    }

    /// Check if user has sufficient credits
    ///
    /// Decision order: bypass list, then unlimited tier, then balance
    /// comparison. The user row is created on first sight with the free-tier
    /// grant (see `get_or_create_user`).
    // NOTE(review): this check and the later `deduct_credits` are separate
    // statements, so two concurrent requests can both pass the check before
    // either deducts — confirm whether that over-spend window is acceptable.
    pub async fn check_credits(
        &self,
        pool: &DatabasePool,
        user_id: &str,
        credits_needed: i64,
    ) -> Result<CreditCheckResult> {
        // Check bypass first
        if self.is_bypass_user(user_id) {
            debug!("User {} is a bypass user, allowing request", user_id);
            return Ok(CreditCheckResult {
                allowed: true,
                reason: None,
                credits_remaining: -1, // -1 signals "not metered"
                is_bypass: true,
                tier: "Bypass".to_string(),
            });
        }
        // Get or create user credits
        let user_credits = self.get_or_create_user(pool, user_id).await?;
        // Unlimited tier users always pass
        if user_credits.has_unlimited() {
            return Ok(CreditCheckResult {
                allowed: true,
                reason: None,
                credits_remaining: -1,
                is_bypass: false,
                tier: user_credits.tier().name().to_string(),
            });
        }
        // Check if user has enough credits
        if user_credits.credits >= credits_needed {
            Ok(CreditCheckResult {
                allowed: true,
                reason: None,
                credits_remaining: user_credits.credits,
                is_bypass: false,
                tier: user_credits.tier().name().to_string(),
            })
        } else {
            Ok(CreditCheckResult {
                allowed: false,
                reason: Some(format!(
                    "Insufficient credits. You have {} but need {}",
                    user_credits.credits, credits_needed
                )),
                credits_remaining: user_credits.credits,
                is_bypass: false,
                tier: user_credits.tier().name().to_string(),
            })
        }
    }

    /// Deduct credits from a user
    ///
    /// Returns the new balance, or -1 for bypass/unlimited users (who are
    /// never charged). The balance is clamped at zero rather than going
    /// negative. Also appends a "usage" row to `credit_transactions`.
    // NOTE(review): the balance UPDATE and the audit INSERT are not wrapped in
    // a DB transaction; a crash between them leaves the ledger inconsistent.
    // NOTE(review): the read-modify-write is not atomic — concurrent deducts
    // for the same user can lose updates. A single
    // `SET credits = credits - ?` (or a transaction) would be safer.
    pub async fn deduct_credits(
        &self,
        pool: &DatabasePool,
        user_id: &str,
        amount: i64,
        guild_id: Option<&str>,
        description: &str,
    ) -> Result<i64> {
        // Bypass users don't get charged
        if self.is_bypass_user(user_id) {
            debug!("Bypass user {}, not deducting credits", user_id);
            return Ok(-1);
        }
        let user_credits = self.get_or_create_user(pool, user_id).await?;
        // Unlimited tier users don't get charged
        if user_credits.has_unlimited() {
            debug!("Unlimited tier user {}, not deducting credits", user_id);
            return Ok(-1);
        }
        let new_balance = (user_credits.credits - amount).max(0);
        let now = Utc::now().to_rfc3339();
        match pool {
            DatabasePool::Sqlite(p) => {
                sqlx::query(
                    "UPDATE user_credits SET credits = ?, lifetime_credits_used = lifetime_credits_used + ?, updated_at = ? WHERE user_id = ?"
                )
                .bind(new_balance)
                .bind(amount)
                .bind(&now)
                .bind(user_id)
                .execute(p)
                .await?;
                // Record transaction
                sqlx::query(
                    "INSERT INTO credit_transactions (user_id, guild_id, amount, transaction_type, description, balance_after, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)"
                )
                .bind(user_id)
                .bind(guild_id)
                .bind(-amount) // stored negative: this is a deduction
                .bind("usage")
                .bind(description)
                .bind(new_balance)
                .bind(&now)
                .execute(p)
                .await?;
            }
            DatabasePool::Postgres(p) => {
                // Same two statements as the sqlite arm, with $n placeholders.
                sqlx::query(
                    "UPDATE user_credits SET credits = $1, lifetime_credits_used = lifetime_credits_used + $2, updated_at = $3 WHERE user_id = $4"
                )
                .bind(new_balance)
                .bind(amount)
                .bind(&now)
                .bind(user_id)
                .execute(p)
                .await?;
                sqlx::query(
                    "INSERT INTO credit_transactions (user_id, guild_id, amount, transaction_type, description, balance_after, created_at) VALUES ($1, $2, $3, $4, $5, $6, $7)"
                )
                .bind(user_id)
                .bind(guild_id)
                .bind(-amount)
                .bind("usage")
                .bind(description)
                .bind(new_balance)
                .bind(&now)
                .execute(p)
                .await?;
            }
        }
        debug!("Deducted {} credits from user {}, new balance: {}", amount, user_id, new_balance);
        Ok(new_balance)
    }

    /// Add credits to a user (for refunds, purchases, etc.)
    ///
    /// Returns the new balance and appends a transaction row of the given
    /// `transaction_type`. Unlike `deduct_credits`, bypass/unlimited users are
    /// NOT special-cased here — their balances are credited normally.
    // NOTE(review): same non-transactional UPDATE+INSERT pairing as
    // `deduct_credits`; see the notes there.
    pub async fn add_credits(
        &self,
        pool: &DatabasePool,
        user_id: &str,
        amount: i64,
        guild_id: Option<&str>,
        transaction_type: &str,
        description: &str,
    ) -> Result<i64> {
        let user_credits = self.get_or_create_user(pool, user_id).await?;
        let new_balance = user_credits.credits + amount;
        let now = Utc::now().to_rfc3339();
        match pool {
            DatabasePool::Sqlite(p) => {
                sqlx::query(
                    "UPDATE user_credits SET credits = ?, updated_at = ? WHERE user_id = ?"
                )
                .bind(new_balance)
                .bind(&now)
                .bind(user_id)
                .execute(p)
                .await?;
                sqlx::query(
                    "INSERT INTO credit_transactions (user_id, guild_id, amount, transaction_type, description, balance_after, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)"
                )
                .bind(user_id)
                .bind(guild_id)
                .bind(amount)
                .bind(transaction_type)
                .bind(description)
                .bind(new_balance)
                .bind(&now)
                .execute(p)
                .await?;
            }
            DatabasePool::Postgres(p) => {
                sqlx::query(
                    "UPDATE user_credits SET credits = $1, updated_at = $2 WHERE user_id = $3"
                )
                .bind(new_balance)
                .bind(&now)
                .bind(user_id)
                .execute(p)
                .await?;
                sqlx::query(
                    "INSERT INTO credit_transactions (user_id, guild_id, amount, transaction_type, description, balance_after, created_at) VALUES ($1, $2, $3, $4, $5, $6, $7)"
                )
                .bind(user_id)
                .bind(guild_id)
                .bind(amount)
                .bind(transaction_type)
                .bind(description)
                .bind(new_balance)
                .bind(&now)
                .execute(p)
                .await?;
            }
        }
        info!("Added {} credits to user {}, new balance: {}", amount, user_id, new_balance);
        Ok(new_balance)
    }

    /// Get or create user credits
    ///
    /// On first sight of a user, inserts a row with the free-tier monthly
    /// allowance and records a "grant" transaction for it.
    // NOTE(review): SELECT-then-INSERT is racy — two concurrent first requests
    // for the same user can both attempt the INSERT; an upsert
    // (ON CONFLICT DO NOTHING + re-select) would close that window.
    pub async fn get_or_create_user(&self, pool: &DatabasePool, user_id: &str) -> Result<UserCredits> {
        let now = Utc::now().to_rfc3339();
        let default_credits = TierLevel::Free.monthly_credits();
        match pool {
            DatabasePool::Sqlite(p) => {
                // Try to get existing user
                let row = sqlx::query(
                    "SELECT user_id, credits, tier_id, lifetime_credits_used, created_at, updated_at FROM user_credits WHERE user_id = ?"
                )
                .bind(user_id)
                .fetch_optional(p)
                .await?;
                if let Some(r) = row {
                    // Columns are read positionally, in SELECT order.
                    return Ok(UserCredits {
                        user_id: r.get(0),
                        credits: r.get(1),
                        tier_id: r.get(2),
                        lifetime_credits_used: r.get(3),
                        created_at: r.get(4),
                        updated_at: r.get(5),
                    });
                }
                // Create new user
                sqlx::query(
                    "INSERT INTO user_credits (user_id, credits, tier_id, lifetime_credits_used, created_at, updated_at) VALUES (?, ?, 1, 0, ?, ?)"
                )
                .bind(user_id)
                .bind(default_credits)
                .bind(&now)
                .bind(&now)
                .execute(p)
                .await?;
                // Record initial grant
                sqlx::query(
                    "INSERT INTO credit_transactions (user_id, guild_id, amount, transaction_type, description, balance_after, created_at) VALUES (?, NULL, ?, ?, ?, ?, ?)"
                )
                .bind(user_id)
                .bind(default_credits)
                .bind("grant")
                .bind("Initial free tier credits")
                .bind(default_credits)
                .bind(&now)
                .execute(p)
                .await?;
                // Return the row we just inserted without re-querying.
                Ok(UserCredits {
                    user_id: user_id.to_string(),
                    credits: default_credits,
                    tier_id: 1,
                    lifetime_credits_used: 0,
                    created_at: now.clone(),
                    updated_at: now,
                })
            }
            DatabasePool::Postgres(p) => {
                // Mirrors the sqlite arm with $n placeholders.
                let row = sqlx::query(
                    "SELECT user_id, credits, tier_id, lifetime_credits_used, created_at, updated_at FROM user_credits WHERE user_id = $1"
                )
                .bind(user_id)
                .fetch_optional(p)
                .await?;
                if let Some(r) = row {
                    return Ok(UserCredits {
                        user_id: r.get(0),
                        credits: r.get(1),
                        tier_id: r.get(2),
                        lifetime_credits_used: r.get(3),
                        created_at: r.get(4),
                        updated_at: r.get(5),
                    });
                }
                sqlx::query(
                    "INSERT INTO user_credits (user_id, credits, tier_id, lifetime_credits_used, created_at, updated_at) VALUES ($1, $2, 1, 0, $3, $4)"
                )
                .bind(user_id)
                .bind(default_credits)
                .bind(&now)
                .bind(&now)
                .execute(p)
                .await?;
                sqlx::query(
                    "INSERT INTO credit_transactions (user_id, guild_id, amount, transaction_type, description, balance_after, created_at) VALUES ($1, NULL, $2, $3, $4, $5, $6)"
                )
                .bind(user_id)
                .bind(default_credits)
                .bind("grant")
                .bind("Initial free tier credits")
                .bind(default_credits)
                .bind(&now)
                .execute(p)
                .await?;
                Ok(UserCredits {
                    user_id: user_id.to_string(),
                    credits: default_credits,
                    tier_id: 1,
                    lifetime_credits_used: 0,
                    created_at: now.clone(),
                    updated_at: now,
                })
            }
        }
    }

    /// Get user's transaction history
    ///
    /// Returns up to `limit` rows, newest first (ordered by `created_at`).
    pub async fn get_transactions(
        &self,
        pool: &DatabasePool,
        user_id: &str,
        limit: i64,
    ) -> Result<Vec<CreditTransaction>> {
        match pool {
            DatabasePool::Sqlite(p) => {
                let rows = sqlx::query(
                    "SELECT id, user_id, guild_id, amount, transaction_type, description, balance_after, created_at FROM credit_transactions WHERE user_id = ? ORDER BY created_at DESC LIMIT ?"
                )
                .bind(user_id)
                .bind(limit)
                .fetch_all(p)
                .await?;
                Ok(rows.into_iter().map(|r| CreditTransaction {
                    id: r.get(0),
                    user_id: r.get(1),
                    guild_id: r.get(2),
                    amount: r.get(3),
                    transaction_type: r.get(4),
                    description: r.get(5),
                    balance_after: r.get(6),
                    created_at: r.get(7),
                }).collect())
            }
            DatabasePool::Postgres(p) => {
                let rows = sqlx::query(
                    "SELECT id, user_id, guild_id, amount, transaction_type, description, balance_after, created_at FROM credit_transactions WHERE user_id = $1 ORDER BY created_at DESC LIMIT $2"
                )
                .bind(user_id)
                .bind(limit)
                .fetch_all(p)
                .await?;
                Ok(rows.into_iter().map(|r| CreditTransaction {
                    id: r.get(0),
                    user_id: r.get(1),
                    guild_id: r.get(2),
                    amount: r.get(3),
                    transaction_type: r.get(4),
                    description: r.get(5),
                    balance_after: r.get(6),
                    created_at: r.get(7),
                }).collect())
            }
        }
    }

    /// Set user tier
    ///
    /// Persists the tier's discriminant (`TierLevel::Free = 1`, etc.) as
    /// `tier_id`. No-op if the user row does not exist yet.
    pub async fn set_user_tier(
        &self,
        pool: &DatabasePool,
        user_id: &str,
        tier: TierLevel,
    ) -> Result<()> {
        let now = Utc::now().to_rfc3339();
        match pool {
            DatabasePool::Sqlite(p) => {
                sqlx::query(
                    "UPDATE user_credits SET tier_id = ?, updated_at = ? WHERE user_id = ?"
                )
                .bind(tier as i32)
                .bind(&now)
                .bind(user_id)
                .execute(p)
                .await?;
            }
            DatabasePool::Postgres(p) => {
                sqlx::query(
                    "UPDATE user_credits SET tier_id = $1, updated_at = $2 WHERE user_id = $3"
                )
                .bind(tier as i32)
                .bind(&now)
                .bind(user_id)
                .execute(p)
                .await?;
            }
        }
        info!("Set user {} tier to {:?}", user_id, tier);
        Ok(())
    }
}
/// Create credit system tables
///
/// Idempotent (`CREATE ... IF NOT EXISTS`) schema setup for `user_credits`,
/// `credit_transactions`, and the per-user transaction index. The two match
/// arms keep the schemas equivalent across backends; only the integer types
/// and auto-increment syntax differ (SQLite INTEGER/AUTOINCREMENT vs
/// Postgres BIGINT/BIGSERIAL).
pub async fn create_credit_tables(pool: &DatabasePool) -> Result<()> {
    match pool {
        DatabasePool::Sqlite(p) => {
            sqlx::query(
                r#"
                CREATE TABLE IF NOT EXISTS user_credits (
                    user_id TEXT PRIMARY KEY,
                    credits INTEGER NOT NULL DEFAULT 50,
                    tier_id INTEGER NOT NULL DEFAULT 1,
                    lifetime_credits_used INTEGER NOT NULL DEFAULT 0,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                )
                "#
            )
            .execute(p)
            .await?;
            sqlx::query(
                r#"
                CREATE TABLE IF NOT EXISTS credit_transactions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT NOT NULL,
                    guild_id TEXT,
                    amount INTEGER NOT NULL,
                    transaction_type TEXT NOT NULL,
                    description TEXT NOT NULL,
                    balance_after INTEGER NOT NULL,
                    created_at TEXT NOT NULL
                )
                "#
            )
            .execute(p)
            .await?;
            // Transaction history is always queried by user (see get_transactions).
            sqlx::query("CREATE INDEX IF NOT EXISTS idx_credit_transactions_user ON credit_transactions(user_id)")
                .execute(p)
                .await?;
        }
        DatabasePool::Postgres(p) => {
            sqlx::query(
                r#"
                CREATE TABLE IF NOT EXISTS user_credits (
                    user_id TEXT PRIMARY KEY,
                    credits BIGINT NOT NULL DEFAULT 50,
                    tier_id INTEGER NOT NULL DEFAULT 1,
                    lifetime_credits_used BIGINT NOT NULL DEFAULT 0,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                )
                "#
            )
            .execute(p)
            .await?;
            sqlx::query(
                r#"
                CREATE TABLE IF NOT EXISTS credit_transactions (
                    id BIGSERIAL PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    guild_id TEXT,
                    amount BIGINT NOT NULL,
                    transaction_type TEXT NOT NULL,
                    description TEXT NOT NULL,
                    balance_after BIGINT NOT NULL,
                    created_at TEXT NOT NULL
                )
                "#
            )
            .execute(p)
            .await?;
            sqlx::query("CREATE INDEX IF NOT EXISTS idx_credit_transactions_user ON credit_transactions(user_id)")
                .execute(p)
                .await?;
        }
    }
    info!("Credit system tables created");
    Ok(())
}

781
backend_api/src/database.rs Normal file
View File

@@ -0,0 +1,781 @@
use anyhow::Result;
use chrono::Utc;
use sqlx::{sqlite::SqlitePool, postgres::PgPool, Row};
use crate::models::{CacheEntry, CacheStats};
use crate::config::DatabaseType;
/// Backend-specific connection pool; query sites match on the variant to pick
/// the right placeholder syntax (`?` for sqlite, `$n` for postgres).
pub enum DatabasePool {
    Sqlite(SqlitePool),
    Postgres(PgPool),
}
/// Owner of the database connection pool; constructed via `Database::new`,
/// which also creates tables, seeds the stats row, and builds indexes.
pub struct Database {
    pool: DatabasePool,
}
impl Database {
/// Get a reference to the underlying database pool
pub fn pool(&self) -> &DatabasePool {
&self.pool
}
pub async fn new(database_url: &str, db_type: DatabaseType) -> Result<Self> {
let pool = match db_type {
DatabaseType::Sqlite => {
let pool = SqlitePool::connect(database_url).await?;
DatabasePool::Sqlite(pool)
}
DatabaseType::Postgres => {
let pool = PgPool::connect(database_url).await?;
DatabasePool::Postgres(pool)
}
};
let db = Self { pool };
db.create_tables().await?;
db.initialize_stats().await?;
db.create_indexes().await?;
Ok(db)
}
async fn create_tables(&self) -> Result<()> {
match &self.pool {
DatabasePool::Sqlite(pool) => {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS cache (
id INTEGER PRIMARY KEY AUTOINCREMENT,
checksum TEXT NOT NULL UNIQUE,
text TEXT NOT NULL,
summary TEXT NOT NULL,
provider TEXT NOT NULL,
model TEXT NOT NULL,
guild_id TEXT NOT NULL,
user_id TEXT NOT NULL,
channel_id TEXT,
temperature REAL,
max_tokens INTEGER,
top_p REAL,
system_prompt TEXT,
created_at TEXT NOT NULL,
last_accessed TEXT NOT NULL,
access_count INTEGER DEFAULT 1
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS rate_limits (
id INTEGER PRIMARY KEY AUTOINCREMENT,
guild_id TEXT NOT NULL,
user_id TEXT NOT NULL,
window_start TEXT NOT NULL,
request_count INTEGER DEFAULT 1,
UNIQUE(guild_id, user_id, window_start)
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS stats (
id INTEGER PRIMARY KEY CHECK (id = 1),
total_hits INTEGER DEFAULT 0,
total_misses INTEGER DEFAULT 0
)
"#,
)
.execute(pool)
.await?;
}
DatabasePool::Postgres(pool) => {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS cache (
id BIGSERIAL PRIMARY KEY,
checksum TEXT NOT NULL UNIQUE,
text TEXT NOT NULL,
summary TEXT NOT NULL,
provider TEXT NOT NULL,
model TEXT NOT NULL,
guild_id TEXT NOT NULL,
user_id TEXT NOT NULL,
channel_id TEXT,
temperature REAL,
max_tokens INTEGER,
top_p REAL,
system_prompt TEXT,
created_at TEXT NOT NULL,
last_accessed TEXT NOT NULL,
access_count BIGINT DEFAULT 1
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS rate_limits (
id BIGSERIAL PRIMARY KEY,
guild_id TEXT NOT NULL,
user_id TEXT NOT NULL,
window_start TEXT NOT NULL,
request_count BIGINT DEFAULT 1,
UNIQUE(guild_id, user_id, window_start)
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS stats (
id INTEGER PRIMARY KEY CHECK (id = 1),
total_hits BIGINT DEFAULT 0,
total_misses BIGINT DEFAULT 0
)
"#,
)
.execute(pool)
.await?;
}
}
Ok(())
}
async fn initialize_stats(&self) -> Result<()> {
match &self.pool {
DatabasePool::Sqlite(pool) => {
sqlx::query("INSERT OR IGNORE INTO stats (id, total_hits, total_misses) VALUES (1, 0, 0)")
.execute(pool)
.await?;
}
DatabasePool::Postgres(pool) => {
sqlx::query("INSERT INTO stats (id, total_hits, total_misses) VALUES (1, 0, 0) ON CONFLICT (id) DO NOTHING")
.execute(pool)
.await?;
}
}
Ok(())
}
async fn create_indexes(&self) -> Result<()> {
match &self.pool {
DatabasePool::Sqlite(pool) => {
sqlx::query("CREATE INDEX IF NOT EXISTS idx_cache_checksum ON cache(checksum)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_cache_guild ON cache(guild_id)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_cache_user ON cache(user_id)")
.execute(pool)
.await?;
}
DatabasePool::Postgres(pool) => {
sqlx::query("CREATE INDEX IF NOT EXISTS idx_cache_checksum ON cache(checksum)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_cache_guild ON cache(guild_id)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_cache_user ON cache(user_id)")
.execute(pool)
.await?;
}
}
Ok(())
}
/// Look up a cached summary by request checksum, updating access metadata.
///
/// On a hit this bumps the entry's `access_count`/`last_accessed` plus the
/// global `total_hits` counter, then returns the entry. Misses return
/// `Ok(None)` and are NOT counted here — `insert_cache` records the miss
/// when the freshly generated summary is stored.
pub async fn get_cache_by_checksum(&self, checksum: &str) -> Result<Option<CacheEntry>> {
    let now = Utc::now().to_rfc3339();
    match &self.pool {
        DatabasePool::Sqlite(pool) => {
            // Touch the entry first; rows_affected tells us whether it exists.
            let touched = sqlx::query(
                "UPDATE cache SET access_count = access_count + 1, last_accessed = ? WHERE checksum = ?"
            )
            .bind(&now)
            .bind(checksum)
            .execute(pool)
            .await?;
            // Fix: only count a hit when the entry actually exists. The hit
            // counter was previously incremented unconditionally, inflating
            // the reported hit rate on lookups that found nothing.
            if touched.rows_affected() == 0 {
                return Ok(None);
            }
            sqlx::query("UPDATE stats SET total_hits = total_hits + 1 WHERE id = 1")
                .execute(pool)
                .await?;
            let row = sqlx::query(
                "SELECT id, checksum, text, summary, provider, model, guild_id, user_id, channel_id, created_at, last_accessed, access_count FROM cache WHERE checksum = ?"
            )
            .bind(checksum)
            .fetch_optional(pool)
            .await?;
            Ok(row.map(|r| CacheEntry {
                id: r.get(0),
                checksum: r.get(1),
                text: r.get(2),
                summary: r.get(3),
                provider: r.get(4),
                model: r.get(5),
                guild_id: r.get(6),
                user_id: r.get(7),
                channel_id: r.get(8),
                created_at: r.get(9),
                last_accessed: r.get(10),
                access_count: r.get(11),
            }))
        }
        DatabasePool::Postgres(pool) => {
            let touched = sqlx::query(
                "UPDATE cache SET access_count = access_count + 1, last_accessed = $1 WHERE checksum = $2"
            )
            .bind(&now)
            .bind(checksum)
            .execute(pool)
            .await?;
            // Same hit-accounting fix as the SQLite branch.
            if touched.rows_affected() == 0 {
                return Ok(None);
            }
            sqlx::query("UPDATE stats SET total_hits = total_hits + 1 WHERE id = 1")
                .execute(pool)
                .await?;
            let row = sqlx::query(
                "SELECT id, checksum, text, summary, provider, model, guild_id, user_id, channel_id, created_at, last_accessed, access_count FROM cache WHERE checksum = $1"
            )
            .bind(checksum)
            .fetch_optional(pool)
            .await?;
            Ok(row.map(|r| CacheEntry {
                id: r.get(0),
                checksum: r.get(1),
                text: r.get(2),
                summary: r.get(3),
                provider: r.get(4),
                model: r.get(5),
                guild_id: r.get(6),
                user_id: r.get(7),
                channel_id: r.get(8),
                created_at: r.get(9),
                last_accessed: r.get(10),
                access_count: r.get(11),
            }))
        }
    }
}
/// Persist a freshly generated summary and record a cache miss.
///
/// Returns the row id of the new cache entry. The miss counter lives here
/// (not in the lookup path) because callers insert only after
/// `get_cache_by_checksum` came back empty, so each insert corresponds to
/// exactly one miss.
///
/// The positional bind order below must match the column list of the INSERT
/// exactly — do not reorder.
pub async fn insert_cache(
    &self,
    checksum: &str,
    text: &str,
    summary: &str,
    provider: &str,
    model: &str,
    guild_id: &str,
    user_id: &str,
    channel_id: Option<&str>,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    top_p: Option<f32>,
    system_prompt: Option<&str>,
) -> Result<i64> {
    // Timestamps are stored as RFC 3339 text for both backends.
    let now = Utc::now().to_rfc3339();
    match &self.pool {
        DatabasePool::Sqlite(pool) => {
            // Increment cache miss
            sqlx::query("UPDATE stats SET total_misses = total_misses + 1 WHERE id = 1")
                .execute(pool)
                .await?;
            let result = sqlx::query(
                r#"
                INSERT INTO cache (checksum, text, summary, provider, model, guild_id, user_id, channel_id, temperature, max_tokens, top_p, system_prompt, created_at, last_accessed)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                "#
            )
            .bind(checksum)
            .bind(text)
            .bind(summary)
            .bind(provider)
            .bind(model)
            .bind(guild_id)
            .bind(user_id)
            .bind(channel_id)
            .bind(temperature)
            // SQLite integers are 64-bit; widen the optional token budget.
            .bind(max_tokens.map(|v| v as i64))
            .bind(top_p)
            .bind(system_prompt)
            .bind(&now)
            .bind(&now)
            .execute(pool)
            .await?;
            Ok(result.last_insert_rowid())
        }
        DatabasePool::Postgres(pool) => {
            // Increment cache miss
            sqlx::query("UPDATE stats SET total_misses = total_misses + 1 WHERE id = 1")
                .execute(pool)
                .await?;
            // Postgres has no last_insert_rowid; RETURNING id fetches the new key.
            let result = sqlx::query(
                r#"
                INSERT INTO cache (checksum, text, summary, provider, model, guild_id, user_id, channel_id, temperature, max_tokens, top_p, system_prompt, created_at, last_accessed)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
                RETURNING id
                "#
            )
            .bind(checksum)
            .bind(text)
            .bind(summary)
            .bind(provider)
            .bind(model)
            .bind(guild_id)
            .bind(user_id)
            .bind(channel_id)
            .bind(temperature)
            // NOTE(review): bound as i32 here but i64 on SQLite — presumably the
            // Postgres column is INTEGER (32-bit); confirm against the schema.
            .bind(max_tokens.map(|v| v as i32))
            .bind(top_p)
            .bind(system_prompt)
            .bind(&now)
            .bind(&now)
            .fetch_one(pool)
            .await?;
            Ok(result.get(0))
        }
    }
}
/// Page through every cache entry, most recently accessed first.
/// `limit` defaults to 100 and `offset` to 0 when omitted.
pub async fn get_all_cache_entries(&self, limit: Option<i64>, offset: Option<i64>) -> Result<Vec<CacheEntry>> {
    let page_size = limit.unwrap_or(100);
    let page_start = offset.unwrap_or(0);
    match &self.pool {
        DatabasePool::Sqlite(pool) => {
            let records = sqlx::query(
                "SELECT id, checksum, text, summary, provider, model, guild_id, user_id, channel_id, created_at, last_accessed, access_count FROM cache ORDER BY last_accessed DESC LIMIT ? OFFSET ?"
            )
            .bind(page_size)
            .bind(page_start)
            .fetch_all(pool)
            .await?;
            let mut entries = Vec::with_capacity(records.len());
            for row in records {
                entries.push(CacheEntry {
                    id: row.get(0),
                    checksum: row.get(1),
                    text: row.get(2),
                    summary: row.get(3),
                    provider: row.get(4),
                    model: row.get(5),
                    guild_id: row.get(6),
                    user_id: row.get(7),
                    channel_id: row.get(8),
                    created_at: row.get(9),
                    last_accessed: row.get(10),
                    access_count: row.get(11),
                });
            }
            Ok(entries)
        }
        DatabasePool::Postgres(pool) => {
            let records = sqlx::query(
                "SELECT id, checksum, text, summary, provider, model, guild_id, user_id, channel_id, created_at, last_accessed, access_count FROM cache ORDER BY last_accessed DESC LIMIT $1 OFFSET $2"
            )
            .bind(page_size)
            .bind(page_start)
            .fetch_all(pool)
            .await?;
            let mut entries = Vec::with_capacity(records.len());
            for row in records {
                entries.push(CacheEntry {
                    id: row.get(0),
                    checksum: row.get(1),
                    text: row.get(2),
                    summary: row.get(3),
                    provider: row.get(4),
                    model: row.get(5),
                    guild_id: row.get(6),
                    user_id: row.get(7),
                    channel_id: row.get(8),
                    created_at: row.get(9),
                    last_accessed: row.get(10),
                    access_count: row.get(11),
                });
            }
            Ok(entries)
        }
    }
}
/// Fetch every cache entry belonging to a guild, most recently accessed first.
pub async fn get_cache_by_guild(&self, guild_id: &str) -> Result<Vec<CacheEntry>> {
    match &self.pool {
        DatabasePool::Sqlite(pool) => {
            let records = sqlx::query(
                "SELECT id, checksum, text, summary, provider, model, guild_id, user_id, channel_id, created_at, last_accessed, access_count FROM cache WHERE guild_id = ? ORDER BY last_accessed DESC"
            )
            .bind(guild_id)
            .fetch_all(pool)
            .await?;
            let mut entries = Vec::with_capacity(records.len());
            for row in records {
                entries.push(CacheEntry {
                    id: row.get(0),
                    checksum: row.get(1),
                    text: row.get(2),
                    summary: row.get(3),
                    provider: row.get(4),
                    model: row.get(5),
                    guild_id: row.get(6),
                    user_id: row.get(7),
                    channel_id: row.get(8),
                    created_at: row.get(9),
                    last_accessed: row.get(10),
                    access_count: row.get(11),
                });
            }
            Ok(entries)
        }
        DatabasePool::Postgres(pool) => {
            let records = sqlx::query(
                "SELECT id, checksum, text, summary, provider, model, guild_id, user_id, channel_id, created_at, last_accessed, access_count FROM cache WHERE guild_id = $1 ORDER BY last_accessed DESC"
            )
            .bind(guild_id)
            .fetch_all(pool)
            .await?;
            let mut entries = Vec::with_capacity(records.len());
            for row in records {
                entries.push(CacheEntry {
                    id: row.get(0),
                    checksum: row.get(1),
                    text: row.get(2),
                    summary: row.get(3),
                    provider: row.get(4),
                    model: row.get(5),
                    guild_id: row.get(6),
                    user_id: row.get(7),
                    channel_id: row.get(8),
                    created_at: row.get(9),
                    last_accessed: row.get(10),
                    access_count: row.get(11),
                });
            }
            Ok(entries)
        }
    }
}
/// Fetch every cache entry created by a user, most recently accessed first.
pub async fn get_cache_by_user(&self, user_id: &str) -> Result<Vec<CacheEntry>> {
    match &self.pool {
        DatabasePool::Sqlite(pool) => {
            let records = sqlx::query(
                "SELECT id, checksum, text, summary, provider, model, guild_id, user_id, channel_id, created_at, last_accessed, access_count FROM cache WHERE user_id = ? ORDER BY last_accessed DESC"
            )
            .bind(user_id)
            .fetch_all(pool)
            .await?;
            let mut entries = Vec::with_capacity(records.len());
            for row in records {
                entries.push(CacheEntry {
                    id: row.get(0),
                    checksum: row.get(1),
                    text: row.get(2),
                    summary: row.get(3),
                    provider: row.get(4),
                    model: row.get(5),
                    guild_id: row.get(6),
                    user_id: row.get(7),
                    channel_id: row.get(8),
                    created_at: row.get(9),
                    last_accessed: row.get(10),
                    access_count: row.get(11),
                });
            }
            Ok(entries)
        }
        DatabasePool::Postgres(pool) => {
            let records = sqlx::query(
                "SELECT id, checksum, text, summary, provider, model, guild_id, user_id, channel_id, created_at, last_accessed, access_count FROM cache WHERE user_id = $1 ORDER BY last_accessed DESC"
            )
            .bind(user_id)
            .fetch_all(pool)
            .await?;
            let mut entries = Vec::with_capacity(records.len());
            for row in records {
                entries.push(CacheEntry {
                    id: row.get(0),
                    checksum: row.get(1),
                    text: row.get(2),
                    summary: row.get(3),
                    provider: row.get(4),
                    model: row.get(5),
                    guild_id: row.get(6),
                    user_id: row.get(7),
                    channel_id: row.get(8),
                    created_at: row.get(9),
                    last_accessed: row.get(10),
                    access_count: row.get(11),
                });
            }
            Ok(entries)
        }
    }
}
/// Delete a single cache entry by primary key; returns the number of rows
/// removed (0 or 1).
pub async fn delete_cache_by_id(&self, id: i64) -> Result<u64> {
    let affected = match &self.pool {
        DatabasePool::Sqlite(pool) => {
            sqlx::query("DELETE FROM cache WHERE id = ?")
                .bind(id)
                .execute(pool)
                .await?
                .rows_affected()
        }
        DatabasePool::Postgres(pool) => {
            sqlx::query("DELETE FROM cache WHERE id = $1")
                .bind(id)
                .execute(pool)
                .await?
                .rows_affected()
        }
    };
    Ok(affected)
}
/// Purge every cache entry belonging to a guild; returns the count removed.
pub async fn delete_cache_by_guild(&self, guild_id: &str) -> Result<u64> {
    let affected = match &self.pool {
        DatabasePool::Sqlite(pool) => {
            sqlx::query("DELETE FROM cache WHERE guild_id = ?")
                .bind(guild_id)
                .execute(pool)
                .await?
                .rows_affected()
        }
        DatabasePool::Postgres(pool) => {
            sqlx::query("DELETE FROM cache WHERE guild_id = $1")
                .bind(guild_id)
                .execute(pool)
                .await?
                .rows_affected()
        }
    };
    Ok(affected)
}
/// Purge every cache entry created by a user; returns the count removed.
pub async fn delete_cache_by_user(&self, user_id: &str) -> Result<u64> {
    let affected = match &self.pool {
        DatabasePool::Sqlite(pool) => {
            sqlx::query("DELETE FROM cache WHERE user_id = ?")
                .bind(user_id)
                .execute(pool)
                .await?
                .rows_affected()
        }
        DatabasePool::Postgres(pool) => {
            sqlx::query("DELETE FROM cache WHERE user_id = $1")
                .bind(user_id)
                .execute(pool)
                .await?
                .rows_affected()
        }
    };
    Ok(affected)
}
/// Empty the cache table entirely; returns the number of rows removed.
/// Note: this does not reset the hit/miss counters in `stats`.
pub async fn delete_all_cache(&self) -> Result<u64> {
    let affected = match &self.pool {
        DatabasePool::Sqlite(pool) => {
            sqlx::query("DELETE FROM cache")
                .execute(pool)
                .await?
                .rows_affected()
        }
        DatabasePool::Postgres(pool) => {
            sqlx::query("DELETE FROM cache")
                .execute(pool)
                .await?
                .rows_affected()
        }
    };
    Ok(affected)
}
/// Aggregate cache statistics: entry count, hit/miss counters from the
/// `stats` singleton row, hit rate as a percentage, and the total payload
/// size (sum of text + summary lengths).
pub async fn get_cache_stats(&self) -> Result<CacheStats> {
    // Gather the raw counters from whichever backend is active; the derived
    // values are computed once below.
    let (total_entries, total_hits, total_misses, total_size_bytes) = match &self.pool {
        DatabasePool::Sqlite(pool) => {
            let entries: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM cache")
                .fetch_one(pool)
                .await?;
            let counters = sqlx::query("SELECT total_hits, total_misses FROM stats WHERE id = 1")
                .fetch_one(pool)
                .await?;
            let size: i64 = sqlx::query_scalar(
                "SELECT COALESCE(SUM(LENGTH(text) + LENGTH(summary)), 0) FROM cache"
            )
            .fetch_one(pool)
            .await?;
            (entries, counters.get::<i64, _>(0), counters.get::<i64, _>(1), size)
        }
        DatabasePool::Postgres(pool) => {
            let entries: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM cache")
                .fetch_one(pool)
                .await?;
            let counters = sqlx::query("SELECT total_hits, total_misses FROM stats WHERE id = 1")
                .fetch_one(pool)
                .await?;
            let size: i64 = sqlx::query_scalar(
                "SELECT COALESCE(SUM(LENGTH(text) + LENGTH(summary)), 0) FROM cache"
            )
            .fetch_one(pool)
            .await?;
            (entries, counters.get::<i64, _>(0), counters.get::<i64, _>(1), size)
        }
    };
    let total_requests = total_hits + total_misses;
    // Guard against division by zero when no requests have been served yet.
    let hit_rate = if total_requests > 0 {
        (total_hits as f64 / total_requests as f64) * 100.0
    } else {
        0.0
    };
    Ok(CacheStats {
        total_entries,
        total_hits,
        total_misses,
        hit_rate,
        total_size_bytes,
    })
}
/// Fixed-window rate limiting keyed on (guild, user, UTC minute).
///
/// Returns `(guild_limited, user_limited)`. The request counter is only
/// incremented when neither limit is exceeded, so rejected requests do not
/// consume quota. The window key is the wall-clock minute formatted as
/// "YYYY-MM-DD HH:MM" in UTC.
pub async fn _check_rate_limit(&self, guild_id: &str, user_id: &str, guild_limit: u32, user_limit: u32) -> Result<(bool, bool)> {
    let now = Utc::now();
    let window_start = now.format("%Y-%m-%d %H:%M").to_string();
    match &self.pool {
        DatabasePool::Sqlite(pool) => {
            // Clean old rate limit entries (older than 1 minute).
            // NOTE(review): this compares the "YYYY-MM-DD HH:MM" window key
            // against SQLite's "YYYY-MM-DD HH:MM:SS" datetime output
            // lexicographically; ordering works, but a window can be dropped a
            // few seconds early at the boundary — confirm that is acceptable.
            sqlx::query("DELETE FROM rate_limits WHERE window_start < datetime('now', '-1 minute')")
                .execute(pool)
                .await?;
            // Check guild rate limit
            let guild_count: i64 = sqlx::query_scalar(
                "SELECT COALESCE(SUM(request_count), 0) FROM rate_limits WHERE guild_id = ? AND window_start = ?"
            )
            .bind(guild_id)
            .bind(&window_start)
            .fetch_one(pool)
            .await?;
            // Check user rate limit
            let user_count: i64 = sqlx::query_scalar(
                "SELECT COALESCE(SUM(request_count), 0) FROM rate_limits WHERE user_id = ? AND window_start = ?"
            )
            .bind(user_id)
            .bind(&window_start)
            .fetch_one(pool)
            .await?;
            let guild_limited = guild_count >= guild_limit as i64;
            let user_limited = user_count >= user_limit as i64;
            if !guild_limited && !user_limited {
                // Increment rate limit counter; upsert keyed on the UNIQUE
                // (guild_id, user_id, window_start) constraint.
                sqlx::query(
                    r#"
                    INSERT INTO rate_limits (guild_id, user_id, window_start, request_count)
                    VALUES (?, ?, ?, 1)
                    ON CONFLICT(guild_id, user_id, window_start) DO UPDATE SET request_count = request_count + 1
                    "#
                )
                .bind(guild_id)
                .bind(user_id)
                .bind(&window_start)
                .execute(pool)
                .await?;
            }
            Ok((guild_limited, user_limited))
        }
        DatabasePool::Postgres(pool) => {
            // Clean old rate limit entries (older than 1 minute).
            // NOTE(review): NOW() uses the database session's timezone while the
            // window key is generated in UTC — if the session is not UTC this
            // cleanup comparison may be skewed; verify the deployment pins
            // timezone=UTC.
            sqlx::query("DELETE FROM rate_limits WHERE window_start < (NOW() - INTERVAL '1 minute')::TEXT")
                .execute(pool)
                .await?;
            // Check guild rate limit
            let guild_count: i64 = sqlx::query_scalar(
                "SELECT COALESCE(SUM(request_count), 0) FROM rate_limits WHERE guild_id = $1 AND window_start = $2"
            )
            .bind(guild_id)
            .bind(&window_start)
            .fetch_one(pool)
            .await?;
            // Check user rate limit
            let user_count: i64 = sqlx::query_scalar(
                "SELECT COALESCE(SUM(request_count), 0) FROM rate_limits WHERE user_id = $1 AND window_start = $2"
            )
            .bind(user_id)
            .bind(&window_start)
            .fetch_one(pool)
            .await?;
            let guild_limited = guild_count >= guild_limit as i64;
            let user_limited = user_count >= user_limit as i64;
            if !guild_limited && !user_limited {
                // Increment rate limit counter; Postgres requires the table
                // qualifier on the excluded-row arithmetic.
                sqlx::query(
                    r#"
                    INSERT INTO rate_limits (guild_id, user_id, window_start, request_count)
                    VALUES ($1, $2, $3, 1)
                    ON CONFLICT(guild_id, user_id, window_start) DO UPDATE SET request_count = rate_limits.request_count + 1
                    "#
                )
                .bind(guild_id)
                .bind(user_id)
                .bind(&window_start)
                .execute(pool)
                .await?;
            }
            Ok((guild_limited, user_limited))
        }
    }
}
/// Check if the database connection is healthy.
///
/// Issues a trivial `SELECT 1` round trip; any connection or query failure
/// propagates as `Err`, otherwise `Ok(true)` is returned.
pub async fn health_check(&self) -> Result<bool> {
    match &self.pool {
        DatabasePool::Sqlite(pool) => {
            sqlx::query("SELECT 1").fetch_one(pool).await?;
        }
        DatabasePool::Postgres(pool) => {
            sqlx::query("SELECT 1").fetch_one(pool).await?;
        }
    }
    Ok(true)
}
}

402
backend_api/src/llm.rs Normal file
View File

@@ -0,0 +1,402 @@
use anyhow::{anyhow, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::config::Config;
use crate::models::{SummarizeRequest, TokenUsage};
use crate::v1::enums::LlmProvider;
use crate::claude_models::ClaudeModelType;
/// Thin client over the configured LLM backends (Anthropic, Ollama, LMStudio).
pub struct LlmClient {
    /// Shared HTTP client reused across requests.
    client: Client,
    /// Application configuration holding per-provider settings.
    config: Config,
}
/// Request body for Anthropic's `/messages` endpoint.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    temperature: f32,
    messages: Vec<AnthropicMessage>,
    /// Optional system prompt; omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
}
/// A single chat turn ("user"/"assistant") in an Anthropic request.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}
/// Subset of Anthropic's response we consume: content blocks, echoed model,
/// and token accounting.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
    model: String,
    usage: AnthropicUsage,
}
/// Token counts reported by Anthropic for a completed request.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicUsage {
    input_tokens: u32,
    output_tokens: u32,
}
/// A text content block from an Anthropic response.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicContent {
    text: String,
}
/// Request body for Ollama's `/api/generate` endpoint (non-streaming).
#[derive(Debug, Serialize, Deserialize)]
struct OllamaRequest {
    model: String,
    prompt: String,
    /// Always `false` here; we wait for the complete response.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    options: OllamaOptions,
}
/// Sampling options forwarded to Ollama.
#[derive(Debug, Serialize, Deserialize)]
struct OllamaOptions {
    temperature: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// Ollama's name for the max-token budget.
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<i32>,
}
/// Subset of Ollama's generate response we consume.
#[derive(Debug, Serialize, Deserialize)]
struct OllamaResponse {
    response: String,
    model: String,
    /// Prompt token count; Ollama may omit it.
    prompt_eval_count: Option<u32>,
    /// Completion token count; Ollama may omit it.
    eval_count: Option<u32>,
}
/// One locally installed model as reported by Ollama's `/api/tags`.
/// Also reused as the common shape for LMStudio model listings.
#[derive(Debug, Serialize, Deserialize)]
pub struct _OllamaModelInfo {
    pub name: String,
    pub modified_at: String,
    pub size: i64,
}
/// Envelope of Ollama's `/api/tags` response.
#[derive(Debug, Serialize, Deserialize)]
struct _OllamaListResponse {
    models: Vec<_OllamaModelInfo>,
}
impl LlmClient {
    /// Build a client with a fresh `reqwest::Client` and the given config.
    pub fn new(config: Config) -> Self {
        Self {
            client: Client::new(),
            config,
        }
    }

    /// Summarize `request.text` with the provider named in the request.
    /// Returns `(summary, model_used, token_usage)`.
    pub async fn summarize(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
        match request.provider {
            LlmProvider::Anthropic => self.summarize_anthropic(request).await,
            LlmProvider::Ollama => self.summarize_ollama(request).await,
            LlmProvider::Lmstudio => self.summarize_lmstudio(request).await,
        }
    }

    /// Summarize via the Anthropic Messages API.
    ///
    /// Model, max_tokens, and temperature fall back to the configured
    /// defaults when not set on the request. A `style`, when present, takes
    /// precedence over any custom `system_prompt`.
    async fn summarize_anthropic(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
        let anthropic_config = self.config.anthropic.as_ref()
            .ok_or_else(|| anyhow!("Anthropic config not found"))?;
        let model = request.model.clone()
            .unwrap_or_else(|| anthropic_config.model.clone());
        let max_tokens = request.max_tokens
            .unwrap_or(anthropic_config.max_tokens);
        let temperature = request.temperature
            .unwrap_or(anthropic_config.temperature);
        // style > system_prompt > built-in default prompt.
        let system_prompt = if let Some(ref style) = request.style {
            style.to_system_prompt()
        } else {
            request.system_prompt.clone()
                .unwrap_or_else(|| "You are a helpful assistant that summarizes text concisely and accurately. Provide a clear, well-structured summary of the given text.".to_string())
        };
        let anthropic_request = AnthropicRequest {
            model: model.clone(),
            max_tokens,
            temperature,
            messages: vec![
                AnthropicMessage {
                    role: "user".to_string(),
                    content: format!("Please summarize the following text:\n\n{}", request.text),
                }
            ],
            system: Some(system_prompt),
            top_p: request.top_p,
        };
        let response = self.client
            .post(format!("{}/messages", anthropic_config.base_url))
            .header("x-api-key", &anthropic_config.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&anthropic_request)
            .send()
            .await?;
        if !response.status().is_success() {
            let error_text = response.text().await?;
            return Err(anyhow!("Anthropic API error: {}", error_text));
        }
        let anthropic_response: AnthropicResponse = response.json().await?;
        // Only the first content block is used as the summary.
        let summary = anthropic_response.content
            .first()
            .map(|c| c.text.clone())
            .ok_or_else(|| anyhow!("No content in Anthropic response"))?;
        let token_usage = TokenUsage {
            input_tokens: anthropic_response.usage.input_tokens,
            output_tokens: anthropic_response.usage.output_tokens,
            total_tokens: anthropic_response.usage.input_tokens + anthropic_response.usage.output_tokens,
            estimated_cost_usd: Some(calculate_anthropic_cost(
                &model,
                anthropic_response.usage.input_tokens,
                anthropic_response.usage.output_tokens,
            )),
        };
        Ok((summary, model, token_usage))
    }

    /// Summarize via a local Ollama server's `/api/generate` endpoint.
    /// Token counts come from Ollama's eval counters (0 when omitted);
    /// cost is always `None` since inference is local.
    async fn summarize_ollama(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
        let ollama_config = self.config.ollama.as_ref()
            .ok_or_else(|| anyhow!("Ollama config not found"))?;
        let model = request.model.clone()
            .unwrap_or_else(|| ollama_config.model.clone());
        let temperature = request.temperature
            .unwrap_or(ollama_config.temperature);
        // style > system_prompt > built-in default prompt.
        let system_prompt = if let Some(ref style) = request.style {
            style.to_system_prompt()
        } else {
            request.system_prompt.clone()
                .unwrap_or_else(|| "You are a helpful assistant that summarizes text concisely and accurately. Provide a clear, well-structured summary of the given text.".to_string())
        };
        let prompt = format!("Please summarize the following text:\n\n{}", request.text);
        let ollama_request = OllamaRequest {
            model: model.clone(),
            prompt,
            stream: false,
            system: Some(system_prompt),
            options: OllamaOptions {
                temperature,
                top_p: request.top_p,
                num_predict: request.max_tokens.map(|v| v as i32),
            },
        };
        let response = self.client
            .post(format!("{}/api/generate", ollama_config.base_url))
            .json(&ollama_request)
            .send()
            .await?;
        if !response.status().is_success() {
            let error_text = response.text().await?;
            return Err(anyhow!("Ollama API error: {}", error_text));
        }
        let ollama_response: OllamaResponse = response.json().await?;
        let token_usage = TokenUsage {
            input_tokens: ollama_response.prompt_eval_count.unwrap_or(0),
            output_tokens: ollama_response.eval_count.unwrap_or(0),
            total_tokens: ollama_response.prompt_eval_count.unwrap_or(0) + ollama_response.eval_count.unwrap_or(0),
            estimated_cost_usd: None, // Ollama is local, no cost
        };
        Ok((ollama_response.response, model, token_usage))
    }

    /// List the models installed on the configured Ollama server.
    pub async fn _list_ollama_models(&self) -> Result<Vec<_OllamaModelInfo>> {
        let ollama_config = self.config.ollama.as_ref()
            .ok_or_else(|| anyhow!("Ollama config not found"))?;
        let response = self.client
            .get(format!("{}/api/tags", ollama_config.base_url))
            .send()
            .await?;
        if !response.status().is_success() {
            return Err(anyhow!("Failed to list Ollama models"));
        }
        let list_response: _OllamaListResponse = response.json().await?;
        Ok(list_response.models)
    }

    /// Known Anthropic model ids, sourced from the local `claude_models`
    /// table rather than a network call.
    pub fn _get_anthropic_models(&self) -> Vec<String> {
        // Return known Claude models from claude_models module
        ClaudeModelType::all_available()
            .into_iter()
            .map(|m| m.id())
            .collect()
    }

    /// Summarize via LMStudio's OpenAI-compatible `/chat/completions` API.
    /// Usage may be absent from the response, in which case zeroed token
    /// counts are returned; cost is always `None` (local inference).
    async fn summarize_lmstudio(&self, request: &SummarizeRequest) -> Result<(String, String, TokenUsage)> {
        let lmstudio_config = self.config.lmstudio.as_ref()
            .ok_or_else(|| anyhow!("LMStudio config not found"))?;
        let model = request.model.clone()
            .unwrap_or_else(|| lmstudio_config.model.clone());
        let temperature = request.temperature
            .unwrap_or(lmstudio_config.temperature);
        let max_tokens = request.max_tokens
            .unwrap_or(lmstudio_config.max_tokens);
        // style > system_prompt > built-in default prompt.
        let system_prompt = if let Some(ref style) = request.style {
            style.to_system_prompt()
        } else {
            request.system_prompt.clone()
                .unwrap_or_else(|| "You are a helpful assistant that summarizes text concisely and accurately. Provide a clear, well-structured summary of the given text.".to_string())
        };
        // LMStudio uses OpenAI-compatible API; the DTOs are local to this
        // method since no other code path needs them.
        #[derive(Serialize)]
        struct LmStudioRequest {
            model: String,
            messages: Vec<LmStudioMessage>,
            temperature: f32,
            max_tokens: u32,
        }
        #[derive(Serialize)]
        struct LmStudioMessage {
            role: String,
            content: String,
        }
        #[derive(Deserialize)]
        struct LmStudioResponse {
            choices: Vec<LmStudioChoice>,
            usage: Option<LmStudioUsage>,
        }
        #[derive(Deserialize)]
        struct LmStudioChoice {
            message: LmStudioResponseMessage,
        }
        #[derive(Deserialize)]
        struct LmStudioResponseMessage {
            content: String,
        }
        #[derive(Deserialize)]
        struct LmStudioUsage {
            prompt_tokens: u32,
            completion_tokens: u32,
            total_tokens: u32,
        }
        let lmstudio_request = LmStudioRequest {
            model: model.clone(),
            messages: vec![
                LmStudioMessage {
                    role: "system".to_string(),
                    content: system_prompt,
                },
                LmStudioMessage {
                    role: "user".to_string(),
                    content: format!("Please summarize the following text:\n\n{}", request.text),
                },
            ],
            temperature,
            max_tokens,
        };
        let response = self.client
            .post(format!("{}/chat/completions", lmstudio_config.base_url))
            .json(&lmstudio_request)
            .send()
            .await?;
        if !response.status().is_success() {
            let error_text = response.text().await?;
            return Err(anyhow!("LMStudio API error: {}", error_text));
        }
        let lmstudio_response: LmStudioResponse = response.json().await?;
        // Only the first choice is used as the summary.
        let summary = lmstudio_response.choices
            .first()
            .map(|c| c.message.content.clone())
            .ok_or_else(|| anyhow!("No content in LMStudio response"))?;
        let token_usage = if let Some(usage) = lmstudio_response.usage {
            TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                total_tokens: usage.total_tokens,
                estimated_cost_usd: None, // LMStudio is local, no cost
            }
        } else {
            TokenUsage {
                input_tokens: 0,
                output_tokens: 0,
                total_tokens: 0,
                estimated_cost_usd: None,
            }
        };
        Ok((summary, model, token_usage))
    }

    /// List the models loaded in LMStudio via its OpenAI-compatible `/models`
    /// endpoint, adapted into the Ollama model-info shape for a uniform
    /// caller-facing type.
    pub async fn _list_lmstudio_models(&self) -> Result<Vec<_OllamaModelInfo>> {
        let lmstudio_config = self.config.lmstudio.as_ref()
            .ok_or_else(|| anyhow!("LMStudio config not found"))?;
        let response = self.client
            .get(format!("{}/models", lmstudio_config.base_url))
            .send()
            .await?;
        if !response.status().is_success() {
            return Err(anyhow!("Failed to list LMStudio models"));
        }
        #[derive(Deserialize)]
        struct LmStudioModelsResponse {
            data: Vec<LmStudioModel>,
        }
        #[derive(Deserialize)]
        struct LmStudioModel {
            id: String,
        }
        let list_response: LmStudioModelsResponse = response.json().await?;
        // Convert to OllamaModelInfo format for compatibility; LMStudio does
        // not report size or modification time, so those fields are zeroed.
        Ok(list_response.data.into_iter().map(|m| _OllamaModelInfo {
            name: m.id,
            modified_at: "".to_string(),
            size: 0,
        }).collect())
    }
}
/// Estimate the USD cost of an Anthropic call from the model id and token
/// counts, delegating the per-model pricing to `ClaudeModelType`.
fn calculate_anthropic_cost(model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
    ClaudeModelType::from_id(model).calculate_cost(input_tokens, output_tokens)
}

300
backend_api/src/main.rs Normal file
View File

@@ -0,0 +1,300 @@
mod config;
mod models;
mod llm;
mod checksum;
mod ocr;
mod cache;
mod database;
mod rate_limiter;
mod claude_models;
mod credits;
mod v1;
use anyhow::Result;
use poem::{listener::TcpListener, Route, EndpointExt};
use poem::middleware::Tracing;
use poem_openapi::OpenApiService;
use std::sync::Arc;
use tracing::{info, warn, error};
use config::Config;
use v1::{AppState, create_apis};
/// Result of LLM provider health check
struct LlmHealthResult {
    /// Whether the provider answered the probe successfully.
    available: bool,
    /// Human-readable failure detail when `available` is false.
    error_message: Option<String>,
}
/// Performs startup health checks and returns whether the server can start.
///
/// Checks, in order: the database connection (fatal if down), then each
/// configured LLM provider (Anthropic, Ollama, LMStudio). Unconfigured
/// providers only produce a warning. Returns Err if the database is down or
/// if no LLM provider is available.
async fn perform_startup_health_checks(state: &Arc<AppState>) -> Result<()> {
    // Check database
    match state.database.health_check().await {
        Ok(_) => info!("✓ Database connection healthy"),
        Err(e) => {
            error!("✗ Database connection failed: {}", e);
            return Err(anyhow::anyhow!("Database connection failed: {}", e));
        }
    }
    let mut anthropic_available = false;
    let mut ollama_available = false;
    let mut lmstudio_available = false;
    // Anthropic's error text is kept so the fatal-path diagnostics below can
    // distinguish bad keys from exhausted credits.
    let mut anthropic_error: Option<String> = None;
    // Check Anthropic API if configured
    if let Some(ref config) = state.config.anthropic {
        info!("Checking Anthropic API...");
        let result = check_anthropic_health(config).await;
        anthropic_available = result.available;
        anthropic_error = result.error_message.clone();
        if result.available {
            info!("✓ Anthropic API healthy");
        } else {
            error!("✗ Anthropic API check failed: {}", result.error_message.unwrap_or_default());
        }
    } else {
        warn!("Anthropic API not configured");
    }
    // Check Ollama API if configured
    if let Some(ref config) = state.config.ollama {
        info!("Checking Ollama API...");
        let result = check_ollama_health(config).await;
        ollama_available = result.available;
        if result.available {
            info!("✓ Ollama API healthy");
        } else {
            error!("✗ Ollama API check failed: {}", result.error_message.unwrap_or_default());
        }
    } else {
        warn!("Ollama API not configured");
    }
    // Check LMStudio API if configured
    if let Some(ref config) = state.config.lmstudio {
        info!("Checking LMStudio API...");
        let result = check_lmstudio_health(config).await;
        lmstudio_available = result.available;
        if result.available {
            info!("✓ LMStudio API healthy");
        } else {
            error!("✗ LMStudio API check failed: {}", result.error_message.unwrap_or_default());
        }
    } else {
        warn!("LMStudio API not configured");
    }
    // Check if at least one LLM provider is available; the server refuses to
    // start without one since every summarize request would fail.
    if !anthropic_available && !ollama_available && !lmstudio_available {
        error!("═══════════════════════════════════════════════════════════════");
        error!("FATAL: No LLM provider is available!");
        error!("═══════════════════════════════════════════════════════════════");
        if state.config.anthropic.is_some() {
            // Best-effort classification of the Anthropic failure by substring.
            if let Some(ref err) = anthropic_error {
                if err.contains("401") || err.contains("invalid") || err.contains("authentication") {
                    error!("Anthropic API key appears to be invalid");
                } else if err.contains("insufficient") || err.contains("credit") || err.contains("balance") {
                    error!("Anthropic API credits may be exhausted");
                }
            }
            error!("- Anthropic: UNAVAILABLE");
        } else {
            error!("- Anthropic: NOT CONFIGURED");
        }
        if state.config.ollama.is_some() {
            error!("- Ollama: UNAVAILABLE (is the server running?)");
        } else {
            error!("- Ollama: NOT CONFIGURED");
        }
        if state.config.lmstudio.is_some() {
            error!("- LMStudio: UNAVAILABLE (is the server running?)");
        } else {
            error!("- LMStudio: NOT CONFIGURED");
        }
        error!("═══════════════════════════════════════════════════════════════");
        error!("Please configure at least one working LLM provider in config.toml");
        error!("═══════════════════════════════════════════════════════════════");
        return Err(anyhow::anyhow!("No LLM provider available. Server cannot start."));
    }
    info!("Startup health checks completed successfully");
    Ok(())
}
/// Probe the Anthropic API with a minimal one-message request.
///
/// A tiny `max_tokens: 10` completion is the cheapest way to verify both
/// connectivity and that the configured API key is accepted.
async fn check_anthropic_health(config: &config::AnthropicConfig) -> LlmHealthResult {
    let client = reqwest::Client::new();
    #[derive(serde::Serialize)]
    struct ProbeRequest {
        model: String,
        max_tokens: u32,
        messages: Vec<ProbeMessage>,
    }
    #[derive(serde::Serialize)]
    struct ProbeMessage {
        role: String,
        content: String,
    }
    let probe = ProbeRequest {
        model: "claude-haiku-4-5-20251001".to_string(),
        max_tokens: 10,
        messages: vec![ProbeMessage {
            role: "user".to_string(),
            content: "Hello".to_string(),
        }],
    };
    let outcome = client
        .post(format!("{}/messages", config.base_url))
        .header("x-api-key", &config.api_key)
        .header("anthropic-version", "2023-06-01")
        .header("content-type", "application/json")
        .json(&probe)
        .send()
        .await;
    match outcome {
        Err(e) => LlmHealthResult {
            available: false,
            error_message: Some(format!("Connection failed: {}", e)),
        },
        Ok(resp) if resp.status().is_success() => LlmHealthResult {
            available: true,
            error_message: None,
        },
        Ok(resp) => {
            // Keep both the status and the body so callers can classify the
            // failure (bad key vs. exhausted credits).
            let status = resp.status();
            let error_text = resp.text().await.unwrap_or_else(|_| "Unknown error".to_string());
            LlmHealthResult {
                available: false,
                error_message: Some(format!("HTTP {} - {}", status, error_text)),
            }
        }
    }
}
/// Probe a local Ollama server; listing installed models via `/api/tags`
/// doubles as a cheap liveness check.
async fn check_ollama_health(config: &config::OllamaConfig) -> LlmHealthResult {
    let client = reqwest::Client::new();
    let outcome = client
        .get(format!("{}/api/tags", config.base_url))
        .send()
        .await;
    let (available, error_message) = match outcome {
        Ok(resp) if resp.status().is_success() => (true, None),
        Ok(resp) => (false, Some(format!("HTTP {}", resp.status()))),
        Err(e) => (false, Some(format!("Connection failed: {}", e))),
    };
    LlmHealthResult { available, error_message }
}
/// Probe a local LMStudio server via its OpenAI-compatible `/models`
/// listing endpoint.
async fn check_lmstudio_health(config: &config::LmStudioConfig) -> LlmHealthResult {
    let client = reqwest::Client::new();
    let outcome = client
        .get(format!("{}/models", config.base_url))
        .send()
        .await;
    let (available, error_message) = match outcome {
        Ok(resp) if resp.status().is_success() => (true, None),
        Ok(resp) => (false, Some(format!("HTTP {}", resp.status()))),
        Err(e) => (false, Some(format!("Connection failed: {}", e))),
    };
    LlmHealthResult { available, error_message }
}
#[tokio::main]
async fn main() -> Result<()> {
    // Initialize tracing; RUST_LOG overrides the "info" default filter.
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
        )
        .init();
    info!("Starting Summarizer API v1...");
    // Load configuration from a path relative to the working directory.
    // NOTE(review): "./src/config/config.toml" assumes the server is launched
    // from the crate root — confirm for the deployment environment.
    let config = Config::load("./src/config/config.toml")?;
    info!("Configuration loaded");
    // Create app state with in-memory cache, database, and rate limiter
    let state = Arc::new(AppState::new(config.clone()).await?);
    info!("App state initialized with hybrid cache (memory + database) and rate limiter");
    // Perform startup health checks - exit if no LLM is available
    info!("Performing startup health checks...");
    perform_startup_health_checks(&state).await?;
    // Create versioned APIs
    let (summarize_api, cache_api, models_api, ocr_api, health_api, credits_api) = create_apis(state.clone());
    // Combine all APIs into a single OpenAPI service
    let api_service = OpenApiService::new(
        (summarize_api, cache_api, models_api, ocr_api, health_api, credits_api),
        "Summarizer API",
        "1.0"
    )
    .server(format!("http://{}:{}/api/v1", config.server.host, config.server.port));
    let ui = api_service.swagger_ui();
    let spec = api_service.spec_endpoint();
    // Build routes: versioned API, Swagger UI, and raw OpenAPI spec.
    let app = Route::new()
        .nest("/api/v1", api_service)
        .nest("/docs", ui)
        .nest("/spec", spec)
        .with(Tracing);
    let addr = format!("{}:{}", config.server.host, config.server.port);
    info!("Starting server on {}", addr);
    info!("API documentation available at http://{}/docs", addr);
    info!("API v1 endpoints available at http://{}/api/v1", addr);
    // Start background cleanup task; it is detached and lives for the
    // lifetime of the process, evicting expired cache and rate-limit entries.
    let cleanup_state = state.clone();
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(300)); // 5 minutes
        loop {
            interval.tick().await;
            cleanup_state.cache.cleanup_expired().await;
            cleanup_state.rate_limiter.cleanup_expired().await;
            info!("Cleaned up expired cache and rate limit entries");
        }
    });
    // Blocks until the server shuts down.
    poem::Server::new(TcpListener::bind(&addr))
        .run(app)
        .await?;
    Ok(())
}

241
backend_api/src/models.rs Normal file
View File

@@ -0,0 +1,241 @@
use poem_openapi::Object;
use serde::{Deserialize, Serialize};
use validator::Validate;
pub use crate::v1::enums::{LlmProvider, SummarizationStyle};
/// Incoming summarization request; validated with `validator` before use.
#[derive(Debug, Clone, Serialize, Deserialize, Object, Validate)]
pub struct SummarizeRequest {
    /// The text to summarize
    #[validate(length(min = 1, max = 100000, message = "Text must be between 1 and 100000 characters"))]
    pub text: String,
    /// LLM provider (anthropic, ollama, or lmstudio)
    pub provider: LlmProvider,
    /// Model name (optional, uses default from config)
    #[validate(length(min = 1, max = 100))]
    pub model: Option<String>,
    /// Temperature (0.0 to 2.0)
    #[validate(range(min = 0.0, max = 2.0))]
    pub temperature: Option<f32>,
    /// Max tokens for the response
    #[validate(range(min = 1, max = 100000))]
    pub max_tokens: Option<u32>,
    /// Top P sampling (0.0 to 1.0)
    #[validate(range(min = 0.0, max = 1.0))]
    pub top_p: Option<f32>,
    /// Predefined summarization style (overrides system_prompt if provided)
    pub style: Option<SummarizationStyle>,
    /// Custom system prompt (optional, ignored if style is set)
    #[validate(length(max = 5000))]
    pub system_prompt: Option<String>,
    /// Discord guild ID
    #[validate(length(min = 1, max = 50))]
    pub guild_id: String,
    /// Discord user ID
    #[validate(length(min = 1, max = 50))]
    pub user_id: String,
    /// Discord channel ID (optional)
    #[validate(length(min = 1, max = 50))]
    pub channel_id: Option<String>,
}
/// Response body returned by the summarization endpoint.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct SummarizeResponse {
    /// The generated summary
    pub summary: String,
    /// Model used
    pub model: String,
    /// Provider used
    pub provider: String,
    /// Whether this was served from cache
    pub from_cache: bool,
    /// Request checksum
    pub checksum: String,
    /// Timestamp
    pub timestamp: String,
    /// Token usage information
    pub token_usage: Option<TokenUsage>,
    /// Processing time in milliseconds
    pub processing_time_ms: Option<u64>,
    /// Summarization style used
    pub style_used: Option<String>,
}
/// Token accounting for a single LLM call.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct TokenUsage {
    /// Input tokens (prompt)
    pub input_tokens: u32,
    /// Output tokens (completion)
    pub output_tokens: u32,
    /// Total tokens used
    pub total_tokens: u32,
    /// Estimated cost in USD (if available)
    pub estimated_cost_usd: Option<f64>,
}
/// Describes one provider/model pair and whether it is currently usable.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct LlmInfo {
    /// Provider name
    pub provider: String,
    /// Model name
    pub model: String,
    /// Model version (if available)
    pub version: Option<String>,
    /// Whether this model is available
    pub available: bool,
}
/// One persisted summarization result, keyed by request checksum.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct CacheEntry {
    /// Cache entry ID
    pub id: i64,
    /// Request checksum
    pub checksum: String,
    /// Original text
    pub text: String,
    /// Generated summary
    pub summary: String,
    /// Provider used
    pub provider: String,
    /// Model used
    pub model: String,
    /// Guild ID
    pub guild_id: String,
    /// User ID
    pub user_id: String,
    /// Channel ID
    pub channel_id: Option<String>,
    /// Created timestamp
    pub created_at: String,
    /// Last accessed timestamp
    pub last_accessed: String,
    /// Access count
    pub access_count: i64,
}
/// Aggregate cache statistics (hits/misses/entry count).
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct CacheStats {
    /// Total cache entries
    pub total_entries: i64,
    /// Total cache hits
    pub total_hits: i64,
    /// Total cache misses
    pub total_misses: i64,
    /// Cache hit rate (percentage)
    pub hit_rate: f64,
    /// Total size in bytes (approximate)
    pub total_size_bytes: i64,
}
/// Snapshot of one rate-limit window for a (guild, user) pair.
#[allow(dead_code)]
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct RateLimitInfo {
    /// Guild ID
    pub guild_id: String,
    /// User ID
    pub user_id: String,
    /// Requests in current window
    pub requests_count: i64,
    /// Window start time
    pub window_start: String,
    /// Whether rate limit is exceeded
    pub is_limited: bool,
}
/// Selects which cache entries to delete. Exactly one field is honored;
/// the delete endpoint checks them in order: id, guild_id, user_id, delete_all.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct DeleteCacheRequest {
    /// Cache entry ID to delete (optional)
    pub id: Option<i64>,
    /// Delete all entries for a guild (optional)
    pub guild_id: Option<String>,
    /// Delete all entries for a user (optional)
    pub user_id: Option<String>,
    /// Delete all entries (use with caution)
    pub delete_all: Option<bool>,
}
/// Generic error envelope returned by several endpoints.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct ErrorResponse {
    /// Always `false` for error responses
    pub success: bool,
    /// Human-readable error description
    pub error: String,
}
/// Generic success envelope returned by several endpoints.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct SuccessResponse {
    /// Always `true` for success responses
    pub success: bool,
    /// Human-readable status message
    pub message: String,
}
/// Request parameters for the OCR + summarize endpoint (image itself is
/// uploaded separately; these fields mirror `SummarizeRequest` minus `text`).
#[derive(Debug, Clone, Serialize, Deserialize, Object, Validate)]
pub struct OcrSummarizeRequest {
    /// LLM provider (anthropic or ollama)
    pub provider: LlmProvider,
    /// Model name (optional, uses default from config)
    #[validate(length(min = 1, max = 100))]
    pub model: Option<String>,
    /// Temperature (0.0 to 2.0)
    #[validate(range(min = 0.0, max = 2.0))]
    pub temperature: Option<f32>,
    /// Max tokens for the response
    #[validate(range(min = 1, max = 100000))]
    pub max_tokens: Option<u32>,
    /// Top P sampling (0.0 to 1.0)
    #[validate(range(min = 0.0, max = 1.0))]
    pub top_p: Option<f32>,
    /// Predefined summarization style
    pub style: Option<SummarizationStyle>,
    /// Custom system prompt (optional)
    #[validate(length(max = 5000))]
    pub system_prompt: Option<String>,
    /// Discord guild ID
    #[validate(length(min = 1, max = 50))]
    pub guild_id: String,
    /// Discord user ID
    #[validate(length(min = 1, max = 50))]
    pub user_id: String,
    /// Discord channel ID (optional)
    #[validate(length(min = 1, max = 50))]
    pub channel_id: Option<String>,
}
/// Response for OCR + summarize: extracted text, summary, and timing split
/// into OCR, summarization, and total phases.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct OcrSummarizeResponse {
    /// Extracted text from OCR
    pub extracted_text: String,
    /// The generated summary
    pub summary: String,
    /// Model used
    pub model: String,
    /// Provider used
    pub provider: String,
    /// Whether summary was served from cache
    pub from_cache: bool,
    /// Request checksum
    pub checksum: String,
    /// Timestamp
    pub timestamp: String,
    /// Token usage information
    pub token_usage: Option<TokenUsage>,
    /// OCR processing time in milliseconds
    pub ocr_time_ms: u64,
    /// Summarization processing time in milliseconds
    pub summarization_time_ms: Option<u64>,
    /// Total processing time in milliseconds
    pub total_time_ms: u64,
    /// Summarization style used
    pub style_used: Option<String>,
    /// Image dimensions
    pub image_info: ImageInfo,
}
/// Basic metadata about the uploaded image.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct ImageInfo {
    /// Image width in pixels
    pub width: u32,
    /// Image height in pixels
    pub height: u32,
    /// Image format (e.g., PNG, JPEG)
    pub format: String,
    /// File size in bytes
    pub size_bytes: usize,
}

69
backend_api/src/ocr.rs Normal file
View File

@@ -0,0 +1,69 @@
use anyhow::{anyhow, Result};
use crate::models::ImageInfo;
#[cfg(feature = "ocr")]
use tesseract::Tesseract;
#[cfg(feature = "ocr")]
use image::ImageFormat;
#[cfg(feature = "ocr")]
use std::io::Cursor;
/// Stateless wrapper around Tesseract OCR; real implementation is compiled
/// only with the `ocr` cargo feature, otherwise the extraction method errors.
pub struct OcrProcessor;
impl OcrProcessor {
    /// Creates a new (zero-sized) processor.
    pub fn new() -> Self {
        Self
    }
    /// Extracts text from an in-memory image.
    ///
    /// Returns `(trimmed_text, image_metadata, elapsed_ms)`. Errors if the
    /// image cannot be decoded, Tesseract fails, or no text is found.
    ///
    /// NOTE(review): the body is fully synchronous (image decode + Tesseract)
    /// despite being `async`, so it blocks the executor thread for the whole
    /// OCR run — consider `tokio::task::spawn_blocking` at the call site.
    #[cfg(feature = "ocr")]
    pub async fn extract_text_from_image(&self, image_data: &[u8]) -> Result<(String, ImageInfo, u64)> {
        let start = std::time::Instant::now();
        let img = image::load_from_memory(image_data)
            .map_err(|e| anyhow!("Failed to load image: {}", e))?;
        // Format is detected from the raw bytes, independent of the decode above.
        let format = image::guess_format(image_data)
            .map_err(|e| anyhow!("Failed to detect image format: {}", e))?;
        let image_info = ImageInfo {
            width: img.width(),
            height: img.height(),
            format: format!("{:?}", format),
            size_bytes: image_data.len(),
        };
        // Grayscale conversion before OCR; Tesseract receives a PNG re-encode.
        let gray_img = img.to_luma8();
        // Save to temporary buffer for Tesseract
        let mut buffer = Vec::new();
        let mut cursor = Cursor::new(&mut buffer);
        gray_img.write_to(&mut cursor, ImageFormat::Png)
            .map_err(|e| anyhow!("Failed to encode image: {}", e))?;
        // None datapath => Tesseract uses its default tessdata location; "eng" only.
        let tesseract = Tesseract::new(None, Some("eng"))
            .map_err(|e| anyhow!("Failed to initialize Tesseract: {}", e))?;
        let text = tesseract
            .set_image_from_mem(&buffer)
            .map_err(|e| anyhow!("Failed to set image: {}", e))?
            .get_text()
            .map_err(|e| anyhow!("Failed to extract text: {}", e))?;
        let ocr_time_ms = start.elapsed().as_millis() as u64;
        let cleaned_text = text.trim().to_string();
        // Empty output is treated as a failure so callers never summarize "".
        if cleaned_text.is_empty() {
            return Err(anyhow!("No text found in image"));
        }
        Ok((cleaned_text, image_info, ocr_time_ms))
    }
    /// Stub used when the `ocr` feature is disabled; always errors.
    #[cfg(not(feature = "ocr"))]
    pub async fn extract_text_from_image(&self, _image_data: &[u8]) -> Result<(String, ImageInfo, u64)> {
        Err(anyhow!("OCR feature is not enabled. Rebuild with --features ocr"))
    }
}

View File

@@ -0,0 +1,100 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use chrono::{DateTime, Utc, Duration};
/// One fixed counting window: how many requests since `window_start`.
#[derive(Clone)]
struct RateLimitEntry {
    count: u32,
    window_start: DateTime<Utc>,
}
/// In-memory fixed-window rate limiter with independent per-guild and
/// per-user counters sharing the same window length.
pub struct RateLimiter {
    guild_limits: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
    user_limits: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
    window_seconds: i64,
}
impl RateLimiter {
pub fn new(window_seconds: i64) -> Self {
Self {
guild_limits: Arc::new(RwLock::new(HashMap::new())),
user_limits: Arc::new(RwLock::new(HashMap::new())),
window_seconds,
}
}
pub async fn check_rate_limit(
&self,
guild_id: &str,
user_id: &str,
guild_limit: u32,
user_limit: u32,
) -> (bool, bool) {
let now = Utc::now();
let window_duration = Duration::seconds(self.window_seconds);
// Check guild rate limit
let mut guild_limits = self.guild_limits.write().await;
let guild_limited = if let Some(entry) = guild_limits.get_mut(guild_id) {
if now - entry.window_start > window_duration {
// Reset window
entry.count = 1;
entry.window_start = now;
false
} else if entry.count >= guild_limit {
true
} else {
entry.count += 1;
false
}
} else {
guild_limits.insert(
guild_id.to_string(),
RateLimitEntry {
count: 1,
window_start: now,
},
);
false
};
// Check user rate limit
let mut user_limits = self.user_limits.write().await;
let user_limited = if let Some(entry) = user_limits.get_mut(user_id) {
if now - entry.window_start > window_duration {
// Reset window
entry.count = 1;
entry.window_start = now;
false
} else if entry.count >= user_limit {
true
} else {
entry.count += 1;
false
}
} else {
user_limits.insert(
user_id.to_string(),
RateLimitEntry {
count: 1,
window_start: now,
},
);
false
};
(guild_limited, user_limited)
}
pub async fn cleanup_expired(&self) {
let now = Utc::now();
let window_duration = Duration::seconds(self.window_seconds);
let mut guild_limits = self.guild_limits.write().await;
guild_limits.retain(|_, entry| now - entry.window_start <= window_duration);
let mut user_limits = self.user_limits.write().await;
user_limits.retain(|_, entry| now - entry.window_start <= window_duration);
}
}

View File

@@ -0,0 +1,68 @@
use poem_openapi::Enum;
use serde::{Deserialize, Serialize};
/// Supported LLM backends; serialized lowercase in the API schema.
#[derive(Debug, Clone, Serialize, Deserialize, Enum)]
#[oai(rename_all = "lowercase")]
pub enum LlmProvider {
    /// Anthropic backend
    Anthropic,
    /// Ollama backend
    Ollama,
    /// LM Studio backend
    Lmstudio,
}
/// Preset summary styles; each maps to a fixed system prompt via
/// `to_system_prompt`. Serialized lowercase in the API schema.
#[derive(Debug, Clone, Serialize, Deserialize, Enum)]
#[oai(rename_all = "lowercase")]
pub enum SummarizationStyle {
    /// Brief, concise summary focusing on key points
    Brief,
    /// Detailed, comprehensive summary with context
    Detailed,
    /// Funny, humorous take on the content
    Funny,
    /// Professional, formal business-style summary
    Professional,
    /// Technical, focusing on technical details and terminology
    Technical,
    /// ELI5 (Explain Like I'm 5) - simple, easy to understand
    Eli5,
    /// Bullet points format
    Bullets,
    /// Academic style with citations and formal language
    Academic,
    /// Roast style
    Roast,
}
impl SummarizationStyle {
pub fn to_system_prompt(&self) -> String {
match self {
SummarizationStyle::Brief => {
"You are a concise summarization assistant. Create brief, to-the-point summaries that capture only the most essential information. Keep it under 3-4 sentences.".to_string()
}
SummarizationStyle::Detailed => {
"You are a comprehensive summarization assistant. Create detailed summaries that include context, nuances, and important details. Organize information logically and ensure nothing significant is missed.".to_string()
}
SummarizationStyle::Funny => {
"You are a humorous summarization assistant. Create entertaining summaries that capture the essence of the content while adding wit, humor, and playful commentary. Make it fun to read while staying accurate.".to_string()
}
SummarizationStyle::Professional => {
"You are a professional business summarization assistant. Create formal, well-structured summaries suitable for corporate environments. Use professional language and focus on actionable insights.".to_string()
}
SummarizationStyle::Technical => {
"You are a technical summarization assistant. Focus on technical details, specifications, methodologies, and precise terminology. Maintain accuracy and include relevant technical context.".to_string()
}
SummarizationStyle::Eli5 => {
"You are an ELI5 (Explain Like I'm 5) summarization assistant. Break down complex topics into simple, easy-to-understand language that anyone can grasp. Use analogies and simple examples.".to_string()
}
SummarizationStyle::Bullets => {
"You are a bullet-point summarization assistant. Create clear, organized summaries using bullet points. Each point should be concise and capture a distinct idea or fact. Use proper formatting with • or - for bullets.".to_string()
}
SummarizationStyle::Academic => {
"You are an academic summarization assistant. Create scholarly summaries with formal language, proper structure, and attention to methodology and evidence. Maintain objectivity and academic rigor.".to_string()
}
SummarizationStyle::Roast => {
"You are a roast summarization assistant. Create humorous summaries that capture the essence of the content while adding wit, dark humor, and harshly criticizing the content. Make it fun to read while staying accurate.".to_string()
}
}
}
}

78
backend_api/src/v1/mod.rs Normal file
View File

@@ -0,0 +1,78 @@
pub mod routes;
pub mod enums;
pub mod responses;
use std::sync::Arc;
use crate::config::Config;
use crate::cache::Cache;
use crate::database::Database;
use crate::rate_limiter::RateLimiter;
use crate::llm::LlmClient;
use crate::ocr::OcrProcessor;
use crate::credits::{CreditSystem, create_credit_tables};
/// Shared application state handed to every API handler via `Arc`.
pub struct AppState {
    // In-memory summary cache (fronting the database cache).
    pub cache: Cache,
    // Persistent storage (SQLite/Postgres via sqlx).
    pub database: Database,
    // Per-guild/per-user fixed-window limiter.
    pub rate_limiter: RateLimiter,
    // Client for the configured LLM providers.
    pub llm_client: LlmClient,
    // OCR wrapper (no-op unless built with the `ocr` feature).
    pub ocr_processor: OcrProcessor,
    // Credit/tier accounting.
    pub credit_system: CreditSystem,
    // Full application configuration (owned copy).
    pub config: Config,
}
impl AppState {
    /// Builds the shared application state from `config`.
    ///
    /// Construction order matters: the database connection must exist before
    /// the credit-system tables can be created in it.
    pub async fn new(config: Config) -> anyhow::Result<Self> {
        let cache = Cache::new(
            config.cache.ttl_seconds.unwrap_or(3600), // default TTL: 1 hour
            config.cache.max_size.unwrap_or(1000),    // default capacity: 1000 entries
        );
        // Initialize database
        let database = Database::new(
            &config.database.url,
            config.database.db_type.clone(),
        ).await?;
        // Create credit system tables
        create_credit_tables(&database.pool()).await?;
        let rate_limiter = RateLimiter::new(
            config.rate_limiting.window_seconds.unwrap_or(60), // default window: 60s
        );
        // Config is cloned here because it is moved into `Self` below.
        let llm_client = LlmClient::new(config.clone());
        let ocr_processor = OcrProcessor::new();
        // Initialize credit system with bypass users
        let credit_system = CreditSystem::new(config.credits.bypass_user_ids.clone());
        Ok(Self {
            cache,
            database,
            rate_limiter,
            llm_client,
            ocr_processor,
            credit_system,
            config,
        })
    }
}
/// Instantiates every v1 API handler over the same shared state.
pub fn create_apis(state: Arc<AppState>) -> (
    routes::SummarizeApi,
    routes::CacheApi,
    routes::ModelsApi,
    routes::OcrApi,
    routes::HealthApi,
    routes::CreditsApi,
) {
    // Each handler holds its own Arc; the last one takes ownership.
    let summarize = routes::SummarizeApi { state: Arc::clone(&state) };
    let cache = routes::CacheApi { state: Arc::clone(&state) };
    let models = routes::ModelsApi { state: Arc::clone(&state) };
    let ocr = routes::OcrApi { state: Arc::clone(&state) };
    let health = routes::HealthApi { state: Arc::clone(&state) };
    let credits = routes::CreditsApi { state };
    (summarize, cache, models, ocr, health, credits)
}

View File

@@ -0,0 +1,42 @@
use poem_openapi::{ApiResponse, payload::Json};
use crate::models::*;
/// 200 with a summary, or 400 with an error envelope.
#[derive(ApiResponse)]
pub enum SummarizeResult {
    #[oai(status = 200)]
    Ok(Json<SummarizeResponse>),
    #[oai(status = 400)]
    Error(Json<ErrorResponse>),
}
/// 200 with cache statistics, or 500 on storage failure.
#[derive(ApiResponse)]
pub enum CacheStatsResult {
    #[oai(status = 200)]
    Ok(Json<CacheStats>),
    #[oai(status = 500)]
    Error(Json<ErrorResponse>),
}
/// 200 with a list of cache entries, or 500 on storage failure.
#[derive(ApiResponse)]
pub enum CacheEntriesResult {
    #[oai(status = 200)]
    Ok(Json<Vec<CacheEntry>>),
    #[oai(status = 500)]
    Error(Json<ErrorResponse>),
}
/// 200 on successful deletion, or 400 for an invalid request.
#[derive(ApiResponse)]
pub enum DeleteResult {
    #[oai(status = 200)]
    Ok(Json<SuccessResponse>),
    #[oai(status = 400)]
    Error(Json<ErrorResponse>),
}
/// 200 with OCR + summary payload, or 400 with an error envelope.
#[derive(ApiResponse)]
pub enum OcrSummarizeResult {
    #[oai(status = 200)]
    Ok(Json<OcrSummarizeResponse>),
    #[oai(status = 400)]
    Error(Json<ErrorResponse>),
}

View File

@@ -0,0 +1,135 @@
use poem_openapi::{payload::Json, param::{Path, Query}, OpenApi};
use std::sync::Arc;
use tracing::warn;
use crate::models::{CacheStats, CacheEntry, DeleteCacheRequest, SuccessResponse};
use super::super::AppState;
/// Handlers for `/cache/*`: stats, listing, and deletion. Reads prefer the
/// database and fall back to the in-memory cache on error.
pub struct CacheApi {
    pub state: Arc<AppState>,
}
#[OpenApi(prefix_path = "/cache")]
impl CacheApi {
    /// Get cache statistics from database (persistent stats)
    #[oai(path = "/stats", method = "get")]
    pub async fn get_cache_stats(&self) -> Json<CacheStats> {
        // Get stats from database for accurate persistent statistics
        match self.state.database.get_cache_stats().await {
            Ok(stats) => Json(stats),
            Err(e) => {
                warn!("Failed to get cache stats from database: {}", e);
                // Fallback to in-memory stats
                let (hits, misses, size) = self.state.cache.get_stats().await;
                Json(CacheStats {
                    total_entries: size as i64,
                    total_hits: hits as i64,
                    total_misses: misses as i64,
                    // Guard against division by zero when no requests were seen yet.
                    hit_rate: if hits + misses > 0 {
                        (hits as f64 / (hits + misses) as f64) * 100.0
                    } else {
                        0.0
                    },
                    // Byte size is unknown for the in-memory fallback.
                    total_size_bytes: 0,
                })
            }
        }
    }
    /// List cache entries with pagination from database
    #[oai(path = "/entries", method = "get")]
    pub async fn list_cache_entries(
        &self,
        limit: Query<Option<i64>>,
        offset: Query<Option<i64>>,
    ) -> Json<Vec<CacheEntry>> {
        // Fetch from database for complete persistent data
        match self.state.database.get_all_cache_entries(limit.0, offset.0).await {
            Ok(entries) => Json(entries),
            Err(e) => {
                warn!("Failed to list cache entries from database: {}", e);
                // Fallback to in-memory cache
                let limit_usize = limit.0.map(|l| l as usize);
                let offset_usize = offset.0.map(|o| o as usize);
                let entries = self.state.cache.get_all(limit_usize, offset_usize).await;
                Json(entries)
            }
        }
    }
    /// Get cache entries for a specific guild from database
    #[oai(path = "/guild/:guild_id", method = "get")]
    pub async fn get_guild_cache(
        &self,
        guild_id: Path<String>,
    ) -> Json<Vec<CacheEntry>> {
        // Fetch from database for complete persistent data
        match self.state.database.get_cache_by_guild(&guild_id.0).await {
            Ok(entries) => Json(entries),
            Err(e) => {
                warn!("Failed to get guild cache from database: {}", e);
                // Fallback to in-memory cache
                let entries = self.state.cache.get_by_guild(&guild_id.0).await;
                Json(entries)
            }
        }
    }
    /// Get cache entries for a specific user from database
    #[oai(path = "/user/:user_id", method = "get")]
    pub async fn get_user_cache(
        &self,
        user_id: Path<String>,
    ) -> Json<Vec<CacheEntry>> {
        // Fetch from database for complete persistent data
        match self.state.database.get_cache_by_user(&user_id.0).await {
            Ok(entries) => Json(entries),
            Err(e) => {
                warn!("Failed to get user cache from database: {}", e);
                // Fallback to in-memory cache
                let entries = self.state.cache.get_by_user(&user_id.0).await;
                Json(entries)
            }
        }
    }
    /// Delete cache entries from both memory and database
    ///
    /// Criteria are applied in priority order: `id`, then `guild_id`, then
    /// `user_id`, then `delete_all`; only the first one present is used.
    /// NOTE(review): a request with no criteria set returns success with
    /// "Deleted 0 cache entries" rather than an error — confirm this is
    /// intended.
    #[oai(path = "/delete", method = "post")]
    pub async fn delete_cache(
        &self,
        request: Json<DeleteCacheRequest>,
    ) -> Json<SuccessResponse> {
        let request = request.0;
        // Delete from both memory cache and database
        // Database errors are swallowed (unwrap_or(0)), so db_deleted may
        // under-report on failure.
        let (memory_deleted, db_deleted) = if let Some(id) = request.id {
            let mem = self.state.cache.delete_by_id(id).await;
            let db = self.state.database.delete_cache_by_id(id).await.unwrap_or(0);
            (mem, db)
        } else if let Some(guild_id) = request.guild_id {
            let mem = self.state.cache.delete_by_guild(&guild_id).await;
            let db = self.state.database.delete_cache_by_guild(&guild_id).await.unwrap_or(0);
            (mem, db)
        } else if let Some(user_id) = request.user_id {
            let mem = self.state.cache.delete_by_user(&user_id).await;
            let db = self.state.database.delete_cache_by_user(&user_id).await.unwrap_or(0);
            (mem, db)
        } else if request.delete_all == Some(true) {
            let mem = self.state.cache.delete_all().await;
            let db = self.state.database.delete_all_cache().await.unwrap_or(0);
            (mem, db)
        } else {
            (0, 0)
        };
        let total_deleted = memory_deleted + db_deleted;
        Json(SuccessResponse {
            success: true,
            message: format!(
                "Deleted {} cache entries (memory: {}, database: {})",
                total_deleted, memory_deleted, db_deleted
            ),
        })
    }
}

View File

@@ -0,0 +1,357 @@
//! Credits API routes for managing user credits and tiers.
use poem_openapi::{param::Path, param::Query, payload::Json, Object, OpenApi, Tags};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::credits::TierLevel;
use crate::v1::AppState;
/// OpenAPI tag grouping for the credits endpoints.
#[derive(Tags)]
enum ApiTags {
    /// Credit management endpoints
    Credits,
}
/// Credits API for managing user credits
pub struct CreditsApi {
    pub state: Arc<AppState>,
}
// ==================== Request/Response Models ====================
/// Balance/tier snapshot for one user.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct UserCreditsResponse {
    pub success: bool,
    pub user_id: String,
    pub credits: i64,
    pub tier_id: i32,
    pub tier_name: String,
    pub lifetime_credits_used: i64,
    // True only for tier 4 (unlimited).
    pub is_unlimited: bool,
}
/// Request to credit a user's balance.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct AddCreditsRequest {
    pub user_id: String,
    pub amount: i64,
    // Defaults to "add" when omitted.
    #[oai(default)]
    pub transaction_type: Option<String>,
    // Defaults to "Credits added" when omitted.
    #[oai(default)]
    pub description: Option<String>,
    #[oai(default)]
    pub guild_id: Option<String>,
}
/// Result of an add-credits operation; `error` is set on failure.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct AddCreditsResponse {
    pub success: bool,
    pub user_id: String,
    pub new_balance: i64,
    pub amount_added: i64,
    #[oai(default)]
    pub error: Option<String>,
}
/// Request to debit a user's balance.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct DeductCreditsRequest {
    pub user_id: String,
    pub amount: i64,
    // Defaults to "Credits deducted" when omitted.
    #[oai(default)]
    pub description: Option<String>,
    #[oai(default)]
    pub guild_id: Option<String>,
}
/// Result of a deduct-credits operation; `error` is set on failure.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct DeductCreditsResponse {
    pub success: bool,
    pub user_id: String,
    pub new_balance: i64,
    pub amount_deducted: i64,
    #[oai(default)]
    pub error: Option<String>,
}
/// Request to move a user onto a tier (1-4).
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct SetTierRequest {
    pub user_id: String,
    pub tier_id: i32,
}
/// Result of a set-tier operation; `error` is set on failure.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct SetTierResponse {
    pub success: bool,
    pub user_id: String,
    pub tier_id: i32,
    pub tier_name: String,
    #[oai(default)]
    pub error: Option<String>,
}
/// Static description of one subscription tier.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct TierInfo {
    pub id: i32,
    pub name: String,
    // -1 denotes unlimited (see the tier table in `get_tiers`).
    pub monthly_credits: i64,
    pub price_usd: f64,
    pub max_messages_per_summary: i32,
}
/// List of all available tiers.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct TiersResponse {
    pub success: bool,
    pub tiers: Vec<TierInfo>,
}
/// One ledger row from the credit transaction history.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct TransactionInfo {
    pub id: i64,
    pub user_id: String,
    pub guild_id: Option<String>,
    pub amount: i64,
    pub transaction_type: String,
    pub description: String,
    pub balance_after: i64,
    pub created_at: String,
}
/// Page of transactions for a user; `error` is set on failure.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct TransactionsResponse {
    pub success: bool,
    pub transactions: Vec<TransactionInfo>,
    #[oai(default)]
    pub error: Option<String>,
}
// ==================== API Implementation ====================
/// Canonical tier-id → tier-name mapping shared by the endpoints below.
/// Returns `None` for ids outside the known 1-4 range so each call site can
/// keep its own fallback label (previously this mapping was duplicated with
/// diverging defaults).
fn tier_name_for(tier_id: i32) -> Option<&'static str> {
    match tier_id {
        1 => Some("free"),
        2 => Some("basic"),
        3 => Some("pro"),
        4 => Some("unlimited"),
        _ => None,
    }
}
#[OpenApi(prefix_path = "/credits", tag = "ApiTags::Credits")]
impl CreditsApi {
    /// Get user credits
    ///
    /// Creates the user on first access. On lookup failure, still answers
    /// 200 with `success: false` and zeroed free-tier defaults.
    #[oai(path = "/:user_id", method = "get")]
    pub async fn get_credits(&self, user_id: Path<String>) -> Json<UserCreditsResponse> {
        let pool = self.state.database.pool();
        match self.state.credit_system.get_or_create_user(pool, &user_id.0).await {
            Ok(user) => {
                // Unknown tier ids fall back to "free" (original behavior).
                let tier_name = tier_name_for(user.tier_id).unwrap_or("free");
                Json(UserCreditsResponse {
                    success: true,
                    user_id: user.user_id,
                    credits: user.credits,
                    tier_id: user.tier_id,
                    tier_name: tier_name.to_string(),
                    lifetime_credits_used: user.lifetime_credits_used,
                    is_unlimited: user.tier_id == 4,
                })
            }
            Err(_e) => Json(UserCreditsResponse {
                success: false,
                user_id: user_id.0.clone(),
                credits: 0,
                tier_id: 1,
                tier_name: "free".to_string(),
                lifetime_credits_used: 0,
                is_unlimited: false,
            }),
        }
    }
    /// Add credits to a user
    #[oai(path = "/add", method = "post")]
    pub async fn add_credits(&self, request: Json<AddCreditsRequest>) -> Json<AddCreditsResponse> {
        let pool = self.state.database.pool();
        // Optional fields get documented defaults.
        let tx_type = request.transaction_type.as_deref().unwrap_or("add");
        let description = request.description.as_deref().unwrap_or("Credits added");
        match self.state.credit_system.add_credits(
            pool,
            &request.user_id,
            request.amount,
            request.guild_id.as_deref(),
            tx_type,
            description,
        ).await {
            Ok(new_balance) => Json(AddCreditsResponse {
                success: true,
                user_id: request.user_id.clone(),
                new_balance,
                amount_added: request.amount,
                error: None,
            }),
            Err(e) => Json(AddCreditsResponse {
                success: false,
                user_id: request.user_id.clone(),
                new_balance: 0,
                amount_added: 0,
                error: Some(format!("{}", e)),
            }),
        }
    }
    /// Deduct credits from a user
    ///
    /// Unlimited-tier and bypass users are never charged (deducted amount 0).
    /// NOTE(review): the balance check below and the deduction are separate
    /// operations, so concurrent requests could both pass the check — confirm
    /// `deduct_credits` re-validates atomically in SQL.
    #[oai(path = "/deduct", method = "post")]
    pub async fn deduct_credits(&self, request: Json<DeductCreditsRequest>) -> Json<DeductCreditsResponse> {
        let pool = self.state.database.pool();
        let description = request.description.as_deref().unwrap_or("Credits deducted");
        let user = match self.state.credit_system.get_or_create_user(pool, &request.user_id).await {
            Ok(u) => u,
            Err(e) => {
                return Json(DeductCreditsResponse {
                    success: false,
                    user_id: request.user_id.clone(),
                    new_balance: 0,
                    amount_deducted: 0,
                    error: Some(format!("Failed to get user: {}", e)),
                });
            }
        };
        // Tier 4 (unlimited) and configured bypass users are exempt.
        if user.tier_id == 4 || self.state.credit_system.is_bypass_user(&request.user_id) {
            return Json(DeductCreditsResponse {
                success: true,
                user_id: request.user_id.clone(),
                new_balance: user.credits,
                amount_deducted: 0,
                error: None,
            });
        }
        if user.credits < request.amount {
            return Json(DeductCreditsResponse {
                success: false,
                user_id: request.user_id.clone(),
                new_balance: user.credits,
                amount_deducted: 0,
                error: Some(format!("Insufficient credits. Current balance: {}", user.credits)),
            });
        }
        match self.state.credit_system.deduct_credits(
            pool,
            &request.user_id,
            request.amount,
            request.guild_id.as_deref(),
            description,
        ).await {
            Ok(new_balance) => Json(DeductCreditsResponse {
                success: true,
                user_id: request.user_id.clone(),
                new_balance,
                amount_deducted: request.amount,
                error: None,
            }),
            Err(e) => Json(DeductCreditsResponse {
                success: false,
                user_id: request.user_id.clone(),
                new_balance: user.credits,
                amount_deducted: 0,
                error: Some(format!("{}", e)),
            }),
        }
    }
    /// Set user tier
    #[oai(path = "/tier", method = "post")]
    pub async fn set_tier(&self, request: Json<SetTierRequest>) -> Json<SetTierResponse> {
        let pool = self.state.database.pool();
        // Validate the tier id first; anything outside 1-4 is rejected.
        let tier = match request.tier_id {
            1 => TierLevel::Free,
            2 => TierLevel::Basic,
            3 => TierLevel::Pro,
            4 => TierLevel::Unlimited,
            _ => {
                return Json(SetTierResponse {
                    success: false,
                    user_id: request.user_id.clone(),
                    tier_id: request.tier_id,
                    tier_name: "invalid".to_string(),
                    error: Some("Invalid tier ID. Must be 1-4.".to_string()),
                });
            }
        };
        // Unreachable fallback kept as "unknown" (original behavior).
        let tier_name = tier_name_for(request.tier_id).unwrap_or("unknown");
        // Ensure the user row exists before updating its tier.
        if let Err(e) = self.state.credit_system.get_or_create_user(pool, &request.user_id).await {
            return Json(SetTierResponse {
                success: false,
                user_id: request.user_id.clone(),
                tier_id: request.tier_id,
                tier_name: tier_name.to_string(),
                error: Some(format!("Failed to get/create user: {}", e)),
            });
        }
        match self.state.credit_system.set_user_tier(pool, &request.user_id, tier).await {
            Ok(_) => Json(SetTierResponse {
                success: true,
                user_id: request.user_id.clone(),
                tier_id: request.tier_id,
                tier_name: tier_name.to_string(),
                error: None,
            }),
            Err(e) => Json(SetTierResponse {
                success: false,
                user_id: request.user_id.clone(),
                tier_id: request.tier_id,
                tier_name: tier_name.to_string(),
                error: Some(format!("{}", e)),
            }),
        }
    }
    /// Get all available tiers
    ///
    /// Static tier table; monthly_credits == -1 denotes unlimited.
    #[oai(path = "/tiers", method = "get")]
    pub async fn get_tiers(&self) -> Json<TiersResponse> {
        let tiers = vec![
            TierInfo { id: 1, name: "free".to_string(), monthly_credits: 50, price_usd: 0.0, max_messages_per_summary: 50 },
            TierInfo { id: 2, name: "basic".to_string(), monthly_credits: 500, price_usd: 4.99, max_messages_per_summary: 100 },
            TierInfo { id: 3, name: "pro".to_string(), monthly_credits: 2000, price_usd: 9.99, max_messages_per_summary: 250 },
            TierInfo { id: 4, name: "unlimited".to_string(), monthly_credits: -1, price_usd: 19.99, max_messages_per_summary: 500 },
        ];
        Json(TiersResponse { success: true, tiers })
    }
    /// Get user transaction history
    ///
    /// `limit` defaults to 20 when omitted.
    #[oai(path = "/:user_id/transactions", method = "get")]
    pub async fn get_transactions(&self, user_id: Path<String>, limit: Query<Option<i64>>) -> Json<TransactionsResponse> {
        let pool = self.state.database.pool();
        let limit_val = limit.0.unwrap_or(20);
        match self.state.credit_system.get_transactions(pool, &user_id.0, limit_val).await {
            Ok(transactions) => {
                let tx_infos: Vec<TransactionInfo> = transactions.into_iter().map(|tx| TransactionInfo {
                    id: tx.id,
                    user_id: tx.user_id,
                    guild_id: tx.guild_id,
                    amount: tx.amount,
                    transaction_type: tx.transaction_type,
                    description: tx.description,
                    balance_after: tx.balance_after,
                    created_at: tx.created_at,
                }).collect();
                Json(TransactionsResponse { success: true, transactions: tx_infos, error: None })
            }
            Err(e) => Json(TransactionsResponse { success: false, transactions: vec![], error: Some(format!("{}", e)) }),
        }
    }
}

View File

@@ -0,0 +1,363 @@
use poem_openapi::{payload::Json, OpenApi, Object, Tags};
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use reqwest::Client;
use crate::models::SuccessResponse;
use crate::claude_models::ClaudeModelType;
use super::super::AppState;
/// OpenAPI tag grouping for the health endpoints.
#[derive(Tags)]
enum ApiTags {
    /// Health check endpoints
    Health,
}
/// Handlers for `/health/*`: a cheap config-only check, a detailed probe,
/// and a trivial liveness endpoint.
pub struct HealthApi {
    pub state: Arc<AppState>,
}
/// Simple health status response for bot health checks
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct SimpleHealthResponse {
    /// Status string ("healthy" or "unhealthy")
    pub status: String,
    /// Version of the API
    pub version: String,
}
/// Detailed health report covering the database and each configured LLM.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct HealthCheckResponse {
    /// Overall health status
    pub healthy: bool,
    /// Health check timestamp
    pub timestamp: String,
    /// Database health
    pub database: ServiceHealth,
    /// Anthropic API health (if configured)
    pub anthropic: Option<ServiceHealth>,
    /// Ollama API health (if configured)
    pub ollama: Option<ServiceHealth>,
    /// LMStudio API health (if configured)
    pub lmstudio: Option<ServiceHealth>,
}
/// Probe result for one dependent service.
#[derive(Debug, Clone, Serialize, Deserialize, Object)]
pub struct ServiceHealth {
    /// Whether the service is available
    pub available: bool,
    /// Status message
    pub message: String,
    /// Response time in milliseconds
    pub response_time_ms: Option<u64>,
    /// Available models (if applicable)
    pub models: Option<Vec<String>>,
}
#[OpenApi(prefix_path = "/health", tag = "ApiTags::Health")]
impl HealthApi {
/// Simple health check - returns status for bot startup checks
#[oai(path = "/", method = "get")]
pub async fn health(&self) -> Json<SimpleHealthResponse> {
    // Healthy as soon as at least one LLM backend is configured;
    // this is a pure config inspection, no network probes.
    let cfg = &self.state.config;
    let any_llm = cfg.anthropic.is_some() || cfg.ollama.is_some() || cfg.lmstudio.is_some();
    let status = if any_llm { "healthy" } else { "unhealthy" };
    Json(SimpleHealthResponse {
        status: status.to_string(),
        version: env!("CARGO_PKG_VERSION").to_string(),
    })
}
/// Comprehensive health check endpoint with detailed service status
///
/// Probes the database and every configured LLM backend; overall health
/// requires the database plus at least one reachable LLM provider.
/// (Removed the unused `_start_time` binding — it was never read.)
#[oai(path = "/detailed", method = "get")]
pub async fn health_detailed(&self) -> Json<HealthCheckResponse> {
    // Check Database
    let database = self.check_database().await;
    // Probe each LLM backend only when it is configured.
    let anthropic = if self.state.config.anthropic.is_some() {
        Some(self.check_anthropic().await)
    } else {
        None
    };
    let ollama = if self.state.config.ollama.is_some() {
        Some(self.check_ollama().await)
    } else {
        None
    };
    let lmstudio = if self.state.config.lmstudio.is_some() {
        Some(self.check_lmstudio().await)
    } else {
        None
    };
    // At least one LLM provider must report available.
    let llm_available = [&anthropic, &ollama, &lmstudio]
        .iter()
        .any(|svc| svc.as_ref().map(|s| s.available).unwrap_or(false));
    let healthy = database.available && llm_available;
    Json(HealthCheckResponse {
        healthy,
        timestamp: chrono::Utc::now().to_rfc3339(),
        database,
        anthropic,
        ollama,
        lmstudio,
    })
}
/// Simple health check endpoint (just returns OK)
#[oai(path = "/simple", method = "get")]
pub async fn health_simple(&self) -> Json<SuccessResponse> {
Json(SuccessResponse {
success: true,
message: "API is healthy".to_string(),
})
}
async fn check_database(&self) -> ServiceHealth {
let start = std::time::Instant::now();
match self.state.database.health_check().await {
Ok(_) => ServiceHealth {
available: true,
message: "Database is healthy".to_string(),
response_time_ms: Some(start.elapsed().as_millis() as u64),
models: None,
},
Err(e) => ServiceHealth {
available: false,
message: format!("Database error: {}", e),
response_time_ms: Some(start.elapsed().as_millis() as u64),
models: None,
},
}
}
async fn check_anthropic(&self) -> ServiceHealth {
let start = std::time::Instant::now();
let config = match &self.state.config.anthropic {
Some(c) => c,
None => return ServiceHealth {
available: false,
message: "Anthropic not configured".to_string(),
response_time_ms: None,
models: None,
},
};
let client = Client::new();
// Try to make a minimal API call to verify the API key
#[derive(serde::Serialize)]
struct TestRequest {
model: String,
max_tokens: u32,
messages: Vec<TestMessage>,
}
#[derive(serde::Serialize)]
struct TestMessage {
role: String,
content: String,
}
let test_request = TestRequest {
model: "claude-haiku-4-5-20251001".to_string(), // Use cheapest model for health check
max_tokens: 10,
messages: vec![TestMessage {
role: "user".to_string(),
content: "Hello".to_string(),
}],
};
let response = client
.post(format!("{}/messages", config.base_url))
.header("x-api-key", &config.api_key)
.header("anthropic-version", "2023-06-01")
.header("content-type", "application/json")
.json(&test_request)
.send()
.await;
let response_time = start.elapsed().as_millis() as u64;
match response {
Ok(resp) if resp.status().is_success() => {
let models = ClaudeModelType::all_available()
.into_iter()
.map(|m| m.id())
.collect();
ServiceHealth {
available: true,
message: "Anthropic API is healthy".to_string(),
response_time_ms: Some(response_time),
models: Some(models),
}
}
Ok(resp) => {
let status = resp.status();
let error_text = resp.text().await.unwrap_or_else(|_| "Unknown error".to_string());
ServiceHealth {
available: false,
message: format!("Anthropic API error ({}): {}", status, error_text),
response_time_ms: Some(response_time),
models: None,
}
}
Err(e) => ServiceHealth {
available: false,
message: format!("Failed to connect to Anthropic API: {}", e),
response_time_ms: Some(response_time),
models: None,
},
}
}
async fn check_ollama(&self) -> ServiceHealth {
let start = std::time::Instant::now();
let config = match &self.state.config.ollama {
Some(c) => c,
None => return ServiceHealth {
available: false,
message: "Ollama not configured".to_string(),
response_time_ms: None,
models: None,
},
};
let client = Client::new();
// Try to list models
let response = client
.get(format!("{}/api/tags", config.base_url))
.send()
.await;
let response_time = start.elapsed().as_millis() as u64;
match response {
Ok(resp) if resp.status().is_success() => {
#[derive(serde::Deserialize)]
struct OllamaListResponse {
models: Vec<OllamaModel>,
}
#[derive(serde::Deserialize)]
struct OllamaModel {
name: String,
}
let models_list = resp.json::<OllamaListResponse>().await.ok()
.map(|m| m.models.into_iter().map(|model| model.name).collect());
ServiceHealth {
available: true,
message: "Ollama API is healthy".to_string(),
response_time_ms: Some(response_time),
models: models_list,
}
}
Ok(resp) => {
let status = resp.status();
ServiceHealth {
available: false,
message: format!("Ollama API error ({})", status),
response_time_ms: Some(response_time),
models: None,
}
}
Err(e) => ServiceHealth {
available: false,
message: format!("Failed to connect to Ollama API: {}", e),
response_time_ms: Some(response_time),
models: None,
},
}
}
async fn check_lmstudio(&self) -> ServiceHealth {
let start = std::time::Instant::now();
let config = match &self.state.config.lmstudio {
Some(c) => c,
None => return ServiceHealth {
available: false,
message: "LMStudio not configured".to_string(),
response_time_ms: None,
models: None,
},
};
let client = Client::new();
// Try to list models (OpenAI-compatible endpoint)
let response = client
.get(format!("{}/models", config.base_url))
.send()
.await;
let response_time = start.elapsed().as_millis() as u64;
match response {
Ok(resp) if resp.status().is_success() => {
#[derive(serde::Deserialize)]
struct LmStudioModelsResponse {
data: Vec<LmStudioModel>,
}
#[derive(serde::Deserialize)]
struct LmStudioModel {
id: String,
}
let models_list = resp.json::<LmStudioModelsResponse>().await.ok()
.map(|m| m.data.into_iter().map(|model| model.id).collect());
ServiceHealth {
available: true,
message: "LMStudio API is healthy".to_string(),
response_time_ms: Some(response_time),
models: models_list,
}
}
Ok(resp) => {
let status = resp.status();
ServiceHealth {
available: false,
message: format!("LMStudio API error ({})", status),
response_time_ms: Some(response_time),
models: None,
}
}
Err(e) => ServiceHealth {
available: false,
message: format!("Failed to connect to LMStudio API: {}", e),
response_time_ms: Some(response_time),
models: None,
},
}
}
}

View File

@@ -0,0 +1,13 @@
pub mod summarize;
pub mod cache;
pub mod models;
pub mod ocr;
pub mod health;
pub mod credits;
pub use summarize::SummarizeApi;
pub use cache::CacheApi;
pub use models::ModelsApi;
pub use ocr::OcrApi;
pub use health::HealthApi;
pub use credits::CreditsApi;

View File

@@ -0,0 +1,40 @@
use poem_openapi::{payload::Json, OpenApi};
use std::sync::Arc;
use crate::models::LlmInfo;
use super::super::AppState;
/// OpenAPI handler group for the `/models` routes.
pub struct ModelsApi {
    /// Shared application state (provider configuration is read from here).
    pub state: Arc<AppState>,
}
#[OpenApi(prefix_path = "/models")]
impl ModelsApi {
    /// List available LLM models
    ///
    /// Returns one entry per configured provider, in a fixed order
    /// (Anthropic first, then Ollama).
    ///
    /// NOTE(review): LMStudio is health-checked elsewhere but never listed
    /// here — confirm whether its configured model should be included too.
    #[oai(path = "/", method = "get")]
    pub async fn list_models(&self) -> Json<Vec<LlmInfo>> {
        let cfg = &self.state.config;
        let entries: Vec<LlmInfo> = [
            cfg.anthropic.as_ref().map(|c| LlmInfo {
                provider: "anthropic".to_string(),
                model: c.model.clone(),
                version: None,
                available: true,
            }),
            cfg.ollama.as_ref().map(|c| LlmInfo {
                provider: "ollama".to_string(),
                model: c.model.clone(),
                version: None,
                available: true,
            }),
        ]
        .into_iter()
        .flatten()
        .collect();
        Json(entries)
    }
}

View File

@@ -0,0 +1,217 @@
use poem_openapi::{payload::{Json, Binary}, param::Query, OpenApi};
use std::sync::Arc;
use std::time::Instant;
use chrono::Utc;
use tracing::info;
use validator::Validate;
use crate::models::*;
use crate::v1::responses::OcrSummarizeResult;
use crate::v1::enums::{LlmProvider, SummarizationStyle};
use crate::checksum;
use super::super::AppState;
/// OpenAPI handler group for the `/ocr` routes.
pub struct OcrApi {
    /// Shared application state (OCR processor, LLM client, cache, rate limiter, config).
    pub state: Arc<AppState>,
}
#[OpenApi(prefix_path = "/ocr")]
impl OcrApi {
/// OCR and summarize image (send image as binary in request body, params as query/form)
#[oai(path = "/summarize", method = "post")]
pub async fn ocr_summarize(
&self,
provider: Query<String>,
guild_id: Query<String>,
user_id: Query<String>,
channel_id: Query<Option<String>>,
model: Query<Option<String>>,
temperature: Query<Option<f32>>,
max_tokens: Query<Option<u32>>,
top_p: Query<Option<f32>>,
style: Query<Option<String>>,
image: Binary<Vec<u8>>,
) -> OcrSummarizeResult {
// Parse provider
let provider_enum = match provider.0.to_lowercase().as_str() {
"anthropic" => LlmProvider::Anthropic,
"ollama" => LlmProvider::Ollama,
_ => {
return OcrSummarizeResult::Error(Json(ErrorResponse {
success: false,
error: "Invalid provider. Must be 'anthropic' or 'ollama'".to_string(),
}));
}
};
// Parse style if provided
let style_enum = if let Some(ref s) = style.0 {
match s.to_lowercase().as_str() {
"brief" => Some(SummarizationStyle::Brief),
"detailed" => Some(SummarizationStyle::Detailed),
"funny" => Some(SummarizationStyle::Funny),
"professional" => Some(SummarizationStyle::Professional),
"technical" => Some(SummarizationStyle::Technical),
"eli5" => Some(SummarizationStyle::Eli5),
"bullets" => Some(SummarizationStyle::Bullets),
"academic" => Some(SummarizationStyle::Academic),
_ => None,
}
} else {
None
};
let params = OcrSummarizeRequest {
provider: provider_enum.clone(),
model: model.0.clone(),
temperature: temperature.0,
max_tokens: max_tokens.0,
top_p: top_p.0,
style: style_enum.clone(),
system_prompt: None,
guild_id: guild_id.0.clone(),
user_id: user_id.0.clone(),
channel_id: channel_id.0.clone(),
};
// Validate request
if let Err(e) = params.validate() {
return OcrSummarizeResult::Error(Json(ErrorResponse {
success: false,
error: format!("Validation error: {}", e),
}));
}
let total_start = Instant::now();
// Check rate limiting if enabled
if self.state.config.rate_limiting.enabled {
let (guild_limited, user_limited) = self.state.rate_limiter.check_rate_limit(
&params.guild_id,
&params.user_id,
self.state.config.rate_limiting.guild_rate_limit,
self.state.config.rate_limiting.user_rate_limit,
).await;
if guild_limited {
return OcrSummarizeResult::Error(Json(ErrorResponse {
success: false,
error: "Guild rate limit exceeded".to_string(),
}));
}
if user_limited {
return OcrSummarizeResult::Error(Json(ErrorResponse {
success: false,
error: "User rate limit exceeded".to_string(),
}));
}
}
// Get image data
let image_data = image.0;
// Perform OCR
let (extracted_text, image_info, ocr_time_ms) = match self.state.ocr_processor.extract_text_from_image(&image_data).await {
Ok(result) => result,
Err(e) => {
return OcrSummarizeResult::Error(Json(ErrorResponse {
success: false,
error: format!("OCR failed: {}", e),
}));
}
};
info!("OCR extracted {} characters in {}ms", extracted_text.len(), ocr_time_ms);
// Create summarize request from extracted text
let summarize_request = SummarizeRequest {
text: extracted_text.clone(),
provider: params.provider.clone(),
model: params.model.clone(),
temperature: params.temperature,
max_tokens: params.max_tokens,
top_p: params.top_p,
style: params.style.clone(),
system_prompt: params.system_prompt.clone(),
guild_id: params.guild_id.clone(),
user_id: params.user_id.clone(),
channel_id: params.channel_id.clone(),
};
// Calculate checksum for caching
let checksum = checksum::calculate_checksum(&summarize_request);
let summarization_start = Instant::now();
// Check cache if enabled
if self.state.config.cache.enabled {
if let Some(cached) = self.state.cache.get(&checksum).await {
info!("Cache hit for OCR summary checksum: {}", checksum);
let total_time = total_start.elapsed().as_millis() as u64;
return OcrSummarizeResult::Ok(Json(OcrSummarizeResponse {
extracted_text,
summary: cached.summary,
model: cached.model,
provider: cached.provider,
from_cache: true,
checksum: cached.checksum,
timestamp: Utc::now().to_rfc3339(),
token_usage: None,
ocr_time_ms,
summarization_time_ms: Some(0),
total_time_ms: total_time,
style_used: params.style.as_ref().map(|s| format!("{:?}", s).to_lowercase()),
image_info,
}));
}
}
// Call LLM for summarization
match self.state.llm_client.summarize(&summarize_request).await {
Ok((summary, model, token_usage)) => {
let provider = format!("{:?}", params.provider).to_lowercase();
let summarization_time_ms = summarization_start.elapsed().as_millis() as u64;
// Store in cache if enabled
if self.state.config.cache.enabled {
let cache_entry = CacheEntry {
id: chrono::Utc::now().timestamp(),
checksum: checksum.clone(),
text: extracted_text.clone(),
summary: summary.clone(),
provider: provider.clone(),
model: model.clone(),
guild_id: params.guild_id.clone(),
user_id: params.user_id.clone(),
channel_id: params.channel_id.clone(),
created_at: Utc::now().to_rfc3339(),
last_accessed: Utc::now().to_rfc3339(),
access_count: 1,
};
self.state.cache.insert(checksum.clone(), cache_entry).await;
}
let total_time = total_start.elapsed().as_millis() as u64;
OcrSummarizeResult::Ok(Json(OcrSummarizeResponse {
extracted_text,
summary,
model,
provider,
from_cache: false,
checksum,
timestamp: Utc::now().to_rfc3339(),
token_usage: Some(token_usage),
ocr_time_ms,
summarization_time_ms: Some(summarization_time_ms),
total_time_ms: total_time,
style_used: params.style.as_ref().map(|s| format!("{:?}", s).to_lowercase()),
image_info,
}))
}
Err(e) => OcrSummarizeResult::Error(Json(ErrorResponse {
success: false,
error: format!("Summarization failed: {}", e),
})),
}
}
}

View File

@@ -0,0 +1,205 @@
use poem_openapi::{payload::Json, OpenApi};
use std::sync::Arc;
use std::time::Instant;
use chrono::Utc;
use tracing::info;
use validator::Validate;
use crate::models::*;
use crate::v1::responses::SummarizeResult;
use crate::checksum;
use super::super::AppState;
/// OpenAPI handler group for the `/summarize` routes.
pub struct SummarizeApi {
    /// Shared application state (LLM client, cache, database, credit system, rate limiter, config).
    pub state: Arc<AppState>,
}
#[OpenApi(prefix_path = "/summarize")]
impl SummarizeApi {
/// Summarize text using specified LLM provider
#[oai(path = "/", method = "post")]
pub async fn summarize(
&self,
request: Json<SummarizeRequest>,
) -> SummarizeResult {
let request = request.0;
// Validate request
if let Err(e) = request.validate() {
return SummarizeResult::Error(Json(ErrorResponse {
success: false,
error: format!("Validation error: {}", e),
}));
}
let start_time = Instant::now();
// Check credits if enabled
let credits_needed = self.state.config.credits.credits_per_summary;
if self.state.config.credits.enabled {
match self.state.credit_system.check_credits(
self.state.database.pool(),
&request.user_id,
credits_needed,
).await {
Ok(result) => {
if !result.allowed {
return SummarizeResult::Error(Json(ErrorResponse {
success: false,
error: result.reason.unwrap_or_else(|| "Insufficient credits".to_string()),
}));
}
}
Err(e) => {
return SummarizeResult::Error(Json(ErrorResponse {
success: false,
error: format!("Credit check failed: {}", e),
}));
}
}
}
// Check rate limiting if enabled
if self.state.config.rate_limiting.enabled {
let (guild_limited, user_limited) = self.state.rate_limiter.check_rate_limit(
&request.guild_id,
&request.user_id,
self.state.config.rate_limiting.guild_rate_limit,
self.state.config.rate_limiting.user_rate_limit,
).await;
if guild_limited {
return SummarizeResult::Error(Json(ErrorResponse {
success: false,
error: "Guild rate limit exceeded".to_string(),
}));
}
if user_limited {
return SummarizeResult::Error(Json(ErrorResponse {
success: false,
error: "User rate limit exceeded".to_string(),
}));
}
}
// Calculate checksum for caching
let checksum = checksum::calculate_checksum(&request);
// Check cache if enabled (hybrid: memory first, then database)
if self.state.config.cache.enabled {
// First check in-memory cache
if let Some(cached) = self.state.cache.get(&checksum).await {
info!("Memory cache hit for checksum: {}", checksum);
let processing_time = start_time.elapsed().as_millis() as u64;
return SummarizeResult::Ok(Json(SummarizeResponse {
summary: cached.summary,
model: cached.model,
provider: cached.provider,
from_cache: true,
checksum: cached.checksum,
timestamp: Utc::now().to_rfc3339(),
token_usage: None,
processing_time_ms: Some(processing_time),
style_used: None,
}));
}
// If not in memory, check database
if let Ok(Some(cached)) = self.state.database.get_cache_by_checksum(&checksum).await {
info!("Database cache hit for checksum: {}", checksum);
// Add to memory cache for faster future access
self.state.cache.insert(checksum.clone(), cached.clone()).await;
let processing_time = start_time.elapsed().as_millis() as u64;
return SummarizeResult::Ok(Json(SummarizeResponse {
summary: cached.summary,
model: cached.model,
provider: cached.provider,
from_cache: true,
checksum: cached.checksum,
timestamp: Utc::now().to_rfc3339(),
token_usage: None,
processing_time_ms: Some(processing_time),
style_used: None,
}));
}
}
// Call LLM
match self.state.llm_client.summarize(&request).await {
Ok((summary, model, token_usage)) => {
let provider = format!("{:?}", request.provider).to_lowercase();
// Deduct credits after successful summarization
if self.state.config.credits.enabled {
if let Err(e) = self.state.credit_system.deduct_credits(
self.state.database.pool(),
&request.user_id,
credits_needed,
Some(&request.guild_id),
&format!("Summarization ({})", provider),
).await {
tracing::warn!("Failed to deduct credits: {}", e);
}
}
// Store in cache if enabled (hybrid: both memory and database)
if self.state.config.cache.enabled {
// Store in database first
let db_result = self.state.database.insert_cache(
&checksum,
&request.text,
&summary,
&provider,
&model,
&request.guild_id,
&request.user_id,
request.channel_id.as_deref(),
request.temperature,
request.max_tokens,
request.top_p,
request.system_prompt.as_deref(),
).await;
let entry_id = db_result.unwrap_or_else(|e| {
tracing::warn!("Failed to store in database: {}", e);
chrono::Utc::now().timestamp()
});
// Also store in memory cache
let cache_entry = CacheEntry {
id: entry_id,
checksum: checksum.clone(),
text: request.text.clone(),
summary: summary.clone(),
provider: provider.clone(),
model: model.clone(),
guild_id: request.guild_id.clone(),
user_id: request.user_id.clone(),
channel_id: request.channel_id.clone(),
created_at: Utc::now().to_rfc3339(),
last_accessed: Utc::now().to_rfc3339(),
access_count: 1,
};
self.state.cache.insert(checksum.clone(), cache_entry).await;
}
let processing_time = start_time.elapsed().as_millis() as u64;
SummarizeResult::Ok(Json(SummarizeResponse {
summary,
model,
provider,
from_cache: false,
checksum,
timestamp: Utc::now().to_rfc3339(),
token_usage: Some(token_usage),
processing_time_ms: Some(processing_time),
style_used: request.style.as_ref().map(|s| format!("{:?}", s).to_lowercase()),
}))
}
Err(e) => SummarizeResult::Error(Json(ErrorResponse {
success: false,
error: format!("LLM error: {}", e),
})),
}
}
}