From 58028bb7ae6b5de0aab4f12b08fca43c563018a5 Mon Sep 17 00:00:00 2001 From: Salman Paracha Date: Sat, 9 Aug 2025 20:44:26 -0700 Subject: [PATCH] more refactoring changes, getting close --- crates/hermesllm/src/clients/transformer.rs | 4 +- crates/hermesllm/src/lib.rs | 8 +- crates/hermesllm/src/providers/arch/mod.rs | 6 + .../hermesllm/src/providers/arch/provider.rs | 40 ++++++ crates/hermesllm/src/providers/claude/mod.rs | 7 ++ .../src/providers/claude/provider.rs | 48 ++++++++ .../hermesllm/src/providers/deepseek/mod.rs | 6 + .../src/providers/deepseek/provider.rs | 40 ++++++ crates/hermesllm/src/providers/gemini/mod.rs | 7 ++ .../src/providers/gemini/provider.rs | 48 ++++++++ crates/hermesllm/src/providers/github/mod.rs | 7 ++ .../src/providers/github/provider.rs | 48 ++++++++ crates/hermesllm/src/providers/groq/mod.rs | 6 + .../hermesllm/src/providers/groq/provider.rs | 43 +++++++ crates/hermesllm/src/providers/interface.rs | 109 ----------------- crates/hermesllm/src/providers/mistral/mod.rs | 6 + .../src/providers/mistral/provider.rs | 40 ++++++ crates/hermesllm/src/providers/mod.rs | 115 +++++++----------- .../src/providers/openai/provider.rs | 32 ++++- crates/hermesllm/src/providers/traits.rs | 71 ++++++++++- crates/llm_gateway/src/stream_context.rs | 68 +++++++---- 21 files changed, 542 insertions(+), 217 deletions(-) create mode 100644 crates/hermesllm/src/providers/arch/mod.rs create mode 100644 crates/hermesllm/src/providers/arch/provider.rs create mode 100644 crates/hermesllm/src/providers/claude/mod.rs create mode 100644 crates/hermesllm/src/providers/claude/provider.rs create mode 100644 crates/hermesllm/src/providers/deepseek/mod.rs create mode 100644 crates/hermesllm/src/providers/deepseek/provider.rs create mode 100644 crates/hermesllm/src/providers/gemini/mod.rs create mode 100644 crates/hermesllm/src/providers/gemini/provider.rs create mode 100644 crates/hermesllm/src/providers/github/mod.rs create mode 100644 crates/hermesllm/src/providers/github/provider.rs create mode 100644 crates/hermesllm/src/providers/groq/mod.rs create mode 100644 crates/hermesllm/src/providers/groq/provider.rs delete mode 100644 crates/hermesllm/src/providers/interface.rs create mode 100644 crates/hermesllm/src/providers/mistral/mod.rs create mode 100644 crates/hermesllm/src/providers/mistral/provider.rs diff --git a/crates/hermesllm/src/clients/transformer.rs b/crates/hermesllm/src/clients/transformer.rs index c6d524f4..23ca26ee 100644 --- a/crates/hermesllm/src/clients/transformer.rs +++ b/crates/hermesllm/src/clients/transformer.rs @@ -13,14 +13,14 @@ //! //! ```rust //! use hermesllm::apis::{ -//! AnthropicMessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage, +//! MessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage, //! MessagesMessageContent, MessagesSystemPrompt, //! }; //! use hermesllm::clients::TransformError; //! use std::convert::TryInto; //! //! // Transform Anthropic to OpenAI -//! let anthropic_req = AnthropicMessagesRequest { +//! let anthropic_req = MessagesRequest { //! model: "claude-3-sonnet".to_string(), //! system: None, //! messages: vec![], diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 4a81ed90..dac18303 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -46,18 +46,18 @@ mod tests { #[test] fn test_provider_instance_creation() { let provider = Provider::new(ProviderId::OpenAI); - assert!(provider.has_compatible_api("/v1/chat/completions")); - assert!(!provider.has_compatible_api("/v1/embeddings")); + assert!(provider.interface().has_compatible_api("/v1/chat/completions")); + assert!(!provider.interface().has_compatible_api("/v1/embeddings")); } #[test] fn test_conversion_mode() { let provider = Provider::new(ProviderId::OpenAI); - let compatible_mode = provider.get_interface(false); + let compatible_mode = provider.interface().get_interface(false); assert!(matches!(compatible_mode, ConversionMode::Compatible)); - let passthrough_mode = provider.get_interface(true); + let passthrough_mode = provider.interface().get_interface(true); assert!(matches!(passthrough_mode, ConversionMode::Passthrough)); } } diff --git a/crates/hermesllm/src/providers/arch/mod.rs b/crates/hermesllm/src/providers/arch/mod.rs new file mode 100644 index 00000000..7673d127 --- /dev/null +++ b/crates/hermesllm/src/providers/arch/mod.rs @@ -0,0 +1,6 @@ +//! Arch provider implementation +//! +//! Arch uses OpenAI-compatible API format + +pub mod provider; +pub use provider::ArchProvider; diff --git a/crates/hermesllm/src/providers/arch/provider.rs b/crates/hermesllm/src/providers/arch/provider.rs new file mode 100644 index 00000000..cb6a3692 --- /dev/null +++ b/crates/hermesllm/src/providers/arch/provider.rs @@ -0,0 +1,40 @@ +//! Arch provider implementation + +use crate::providers::{ProviderInterface, ConversionMode}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse}; +use crate::providers::traits::{ProviderRequest, ProviderResponse}; + +/// Arch provider implementation +#[derive(Debug, Clone)] +pub struct ArchProvider; + +impl ProviderInterface for ArchProvider { + fn has_compatible_api(&self, api_path: &str) -> bool { + matches!(api_path, "/v1/chat/completions") + } + + fn supported_apis(&self) -> Vec<&'static str> { + vec!["/v1/chat/completions"] + } + + fn parse_request(&self, bytes: &[u8]) -> Result> { + match ChatCompletionsRequest::try_from_bytes(bytes) { + Ok(req) => Ok(req), + Err(e) => Err(Box::new(e)), + } + } + + fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { + match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + Ok(resp) => Ok(resp), + Err(e) => Err(Box::new(e)), + } + } + + fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { + match request.to_provider_bytes(provider_id, mode) { + Ok(bytes) => Ok(bytes), + Err(e) => Err(Box::new(e)), + } + } +} diff --git a/crates/hermesllm/src/providers/claude/mod.rs b/crates/hermesllm/src/providers/claude/mod.rs new file mode 100644 index 00000000..1a574112 --- /dev/null +++ b/crates/hermesllm/src/providers/claude/mod.rs @@ -0,0 +1,7 @@ +//! Claude provider implementation +//! +//! Claude will use a different API format in the future (/v1/messages) +//! For now, fallback to OpenAI-compatible format + +pub mod provider; +pub use provider::ClaudeProvider; diff --git a/crates/hermesllm/src/providers/claude/provider.rs b/crates/hermesllm/src/providers/claude/provider.rs new file mode 100644 index 00000000..c9c4982b --- /dev/null +++ b/crates/hermesllm/src/providers/claude/provider.rs @@ -0,0 +1,48 @@ +//! Claude provider implementation +//! +//! TODO: Implement Claude-specific API format (/v1/messages) when needed +//! For now, uses OpenAI-compatible format as fallback + +use crate::providers::{ProviderInterface, ConversionMode}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse}; +use crate::providers::traits::{ProviderRequest, ProviderResponse}; + +/// Claude provider implementation +#[derive(Debug, Clone)] +pub struct ClaudeProvider; + +impl ProviderInterface for ClaudeProvider { + fn has_compatible_api(&self, api_path: &str) -> bool { + // TODO: Update when Claude API is fully implemented + matches!(api_path, "/v1/chat/completions" | "/v1/messages") + } + + fn supported_apis(&self) -> Vec<&'static str> { + // TODO: Update when Claude API is fully implemented + vec!["/v1/messages"] + } + + fn parse_request(&self, bytes: &[u8]) -> Result> { + // TODO: Implement Claude-specific request parsing + match ChatCompletionsRequest::try_from_bytes(bytes) { + Ok(req) => Ok(req), + Err(e) => Err(Box::new(e)), + } + } + + fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { + // TODO: Implement Claude-specific response parsing + match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + Ok(resp) => Ok(resp), + Err(e) => Err(Box::new(e)), + } + } + + fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { + // TODO: Implement Claude-specific request serialization + match request.to_provider_bytes(provider_id, mode) { + Ok(bytes) => Ok(bytes), + Err(e) => Err(Box::new(e)), + } + } +} diff --git a/crates/hermesllm/src/providers/deepseek/mod.rs b/crates/hermesllm/src/providers/deepseek/mod.rs new file mode 100644 index 00000000..d020a57f --- /dev/null +++ b/crates/hermesllm/src/providers/deepseek/mod.rs @@ -0,0 +1,6 @@ +//! Deepseek provider implementation +//! +//! Deepseek uses OpenAI-compatible API format + +pub mod provider; +pub use provider::DeepseekProvider; diff --git a/crates/hermesllm/src/providers/deepseek/provider.rs b/crates/hermesllm/src/providers/deepseek/provider.rs new file mode 100644 index 00000000..92cc0fa4 --- /dev/null +++ b/crates/hermesllm/src/providers/deepseek/provider.rs @@ -0,0 +1,40 @@ +//! Deepseek provider implementation + +use crate::providers::{ProviderInterface, ConversionMode}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse}; +use crate::providers::traits::{ProviderRequest, ProviderResponse}; + +/// Deepseek provider implementation +#[derive(Debug, Clone)] +pub struct DeepseekProvider; + +impl ProviderInterface for DeepseekProvider { + fn has_compatible_api(&self, api_path: &str) -> bool { + matches!(api_path, "/v1/chat/completions") + } + + fn supported_apis(&self) -> Vec<&'static str> { + vec!["/v1/chat/completions"] + } + + fn parse_request(&self, bytes: &[u8]) -> Result> { + match ChatCompletionsRequest::try_from_bytes(bytes) { + Ok(req) => Ok(req), + Err(e) => Err(Box::new(e)), + } + } + + fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { + match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + Ok(resp) => Ok(resp), + Err(e) => Err(Box::new(e)), + } + } + + fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { + match request.to_provider_bytes(provider_id, mode) { + Ok(bytes) => Ok(bytes), + Err(e) => Err(Box::new(e)), + } + } +} diff --git a/crates/hermesllm/src/providers/gemini/mod.rs b/crates/hermesllm/src/providers/gemini/mod.rs new file mode 100644 index 00000000..7634dad7 --- /dev/null +++ b/crates/hermesllm/src/providers/gemini/mod.rs @@ -0,0 +1,7 @@ +//! Gemini provider implementation +//! +//! Gemini will use a different API format in the future +//! For now, fallback to OpenAI-compatible format + +pub mod provider; +pub use provider::GeminiProvider; diff --git a/crates/hermesllm/src/providers/gemini/provider.rs b/crates/hermesllm/src/providers/gemini/provider.rs new file mode 100644 index 00000000..55b0c471 --- /dev/null +++ b/crates/hermesllm/src/providers/gemini/provider.rs @@ -0,0 +1,48 @@ +//! Gemini provider implementation +//! +//! TODO: Implement Gemini-specific API format when needed +//! For now, uses OpenAI-compatible format as fallback + +use crate::providers::{ProviderInterface, ConversionMode}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse}; +use crate::providers::traits::{ProviderRequest, ProviderResponse}; + +/// Gemini provider implementation +#[derive(Debug, Clone)] +pub struct GeminiProvider; + +impl ProviderInterface for GeminiProvider { + fn has_compatible_api(&self, api_path: &str) -> bool { + // TODO: Update when Gemini API is fully implemented + matches!(api_path, "/v1/chat/completions" | "/v1/models") + } + + fn supported_apis(&self) -> Vec<&'static str> { + // TODO: Update when Gemini API is fully implemented + vec!["/v1/models"] + } + + fn parse_request(&self, bytes: &[u8]) -> Result> { + // TODO: Implement Gemini-specific request parsing + match ChatCompletionsRequest::try_from_bytes(bytes) { + Ok(req) => Ok(req), + Err(e) => Err(Box::new(e)), + } + } + + fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { + // TODO: Implement Gemini-specific response parsing + match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + Ok(resp) => Ok(resp), + Err(e) => Err(Box::new(e)), + } + } + + fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { + // TODO: Implement Gemini-specific request serialization + match request.to_provider_bytes(provider_id, mode) { + Ok(bytes) => Ok(bytes), + Err(e) => Err(Box::new(e)), + } + } +} diff --git a/crates/hermesllm/src/providers/github/mod.rs b/crates/hermesllm/src/providers/github/mod.rs new file mode 100644 index 00000000..d0b9ea79 --- /dev/null +++ b/crates/hermesllm/src/providers/github/mod.rs @@ -0,0 +1,7 @@ +//! GitHub provider implementation +//! +//! GitHub will use a different API format in the future (/models) +//! For now, fallback to OpenAI-compatible format + +pub mod provider; +pub use provider::GitHubProvider; diff --git a/crates/hermesllm/src/providers/github/provider.rs b/crates/hermesllm/src/providers/github/provider.rs new file mode 100644 index 00000000..cbd5bb01 --- /dev/null +++ b/crates/hermesllm/src/providers/github/provider.rs @@ -0,0 +1,48 @@ +//! GitHub provider implementation +//! +//! TODO: Implement GitHub-specific API format (/models) when needed +//! For now, uses OpenAI-compatible format as fallback + +use crate::providers::{ProviderInterface, ConversionMode}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse}; +use crate::providers::traits::{ProviderRequest, ProviderResponse}; + +/// GitHub provider implementation +#[derive(Debug, Clone)] +pub struct GitHubProvider; + +impl ProviderInterface for GitHubProvider { + fn has_compatible_api(&self, api_path: &str) -> bool { + // TODO: Update when GitHub API is fully implemented + matches!(api_path, "/v1/chat/completions" | "/models") + } + + fn supported_apis(&self) -> Vec<&'static str> { + // TODO: Update when GitHub API is fully implemented + vec!["/models"] + } + + fn parse_request(&self, bytes: &[u8]) -> Result> { + // TODO: Implement GitHub-specific request parsing + match ChatCompletionsRequest::try_from_bytes(bytes) { + Ok(req) => Ok(req), + Err(e) => Err(Box::new(e)), + } + } + + fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { + // TODO: Implement GitHub-specific response parsing + match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + Ok(resp) => Ok(resp), + Err(e) => Err(Box::new(e)), + } + } + + fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { + // TODO: Implement GitHub-specific request serialization + match request.to_provider_bytes(provider_id, mode) { + Ok(bytes) => Ok(bytes), + Err(e) => Err(Box::new(e)), + } + } +} diff --git a/crates/hermesllm/src/providers/groq/mod.rs b/crates/hermesllm/src/providers/groq/mod.rs new file mode 100644 index 00000000..a3edd84b --- /dev/null +++ b/crates/hermesllm/src/providers/groq/mod.rs @@ -0,0 +1,6 @@ +//! Groq provider implementation +//! +//! Groq uses OpenAI-compatible API format but with different endpoints + +pub mod provider; +pub use provider::GroqProvider; diff --git a/crates/hermesllm/src/providers/groq/provider.rs b/crates/hermesllm/src/providers/groq/provider.rs new file mode 100644 index 00000000..94eb4568 --- /dev/null +++ b/crates/hermesllm/src/providers/groq/provider.rs @@ -0,0 +1,43 @@ +//! Groq provider implementation +//! +//! This module contains the Groq provider that handles Groq API format requests. +//! Groq uses OpenAI-compatible format but may have provider-specific nuances. + +use crate::providers::{ProviderInterface, ConversionMode}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse}; +use crate::providers::traits::{ProviderRequest, ProviderResponse}; + +/// Groq provider implementation +#[derive(Debug, Clone)] +pub struct GroqProvider; + +impl ProviderInterface for GroqProvider { + fn has_compatible_api(&self, api_path: &str) -> bool { + matches!(api_path, "/v1/chat/completions" | "/openai/v1/chat/completions") + } + + fn supported_apis(&self) -> Vec<&'static str> { + vec!["/openai/v1/chat/completions"] + } + + fn parse_request(&self, bytes: &[u8]) -> Result> { + match ChatCompletionsRequest::try_from_bytes(bytes) { + Ok(req) => Ok(req), + Err(e) => Err(Box::new(e)), + } + } + + fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { + match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + Ok(resp) => Ok(resp), + Err(e) => Err(Box::new(e)), + } + } + + fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { + match request.to_provider_bytes(provider_id, mode) { + Ok(bytes) => Ok(bytes), + Err(e) => Err(Box::new(e)), + } + } +} diff --git a/crates/hermesllm/src/providers/interface.rs b/crates/hermesllm/src/providers/interface.rs deleted file mode 100644 index 4026516e..00000000 --- a/crates/hermesllm/src/providers/interface.rs +++ /dev/null @@ -1,109 +0,0 @@ -//! Provider interface trait definitions -//! -//! This module defines the core traits that all LLM providers must implement. -//! The interface is designed around v1/chat/completions API for simplicity. - -use std::error::Error; - -/// Conversion mode for provider requests/responses -#[derive(Debug, Clone, Copy)] -pub enum ConversionMode { - /// Compatible: Convert between different provider formats to ensure compatibility - Compatible, - /// Passthrough: Pass requests/responses through with minimal modification - Passthrough, -} - -/// Token usage information -pub trait TokenUsage { - fn completion_tokens(&self) -> usize; - fn prompt_tokens(&self) -> usize; - fn total_tokens(&self) -> usize; -} - -/// Error type for provider operations -pub trait ProviderError: Error + Send + Sync + 'static {} - -/// Request type that can be converted to/from provider-specific formats -pub trait ProviderRequest: Sized { - type Error: ProviderError; - - /// Parse request from raw bytes (typically JSON) - fn from_bytes(bytes: &[u8]) -> Result; - - /// Convert to bytes for sending to upstream API - fn to_bytes(&self, mode: ConversionMode) -> Result, Self::Error>; - - /// Extract the model name from the request - fn model(&self) -> &str; - - /// Check if this is a streaming request - fn is_streaming(&self) -> bool; - - /// Set streaming options (e.g., include_usage) - fn set_streaming_options(&mut self); - - /// Extract text content from messages for token counting - fn extract_text(&self) -> String; -} - -/// Response type that can be converted to/from provider-specific formats -pub trait ProviderResponse: Sized { - type Error: ProviderError; - type Usage: TokenUsage; - - /// Parse response from raw bytes (typically JSON) - fn from_bytes(bytes: &[u8], mode: ConversionMode) -> Result; - - /// Convert to bytes for sending to client - fn to_bytes(&self) -> Result, Self::Error>; - - /// Get usage information if available - fn usage(&self) -> Option<&Self::Usage>; -} - -/// Streaming response chunk -pub trait StreamChunk: Sized { - type Error: ProviderError; - type Usage: TokenUsage; - - /// Parse chunk from a line of streaming data - fn from_line(line: &str, mode: ConversionMode) -> Result, Self::Error>; - - /// Convert to line for sending to client - fn to_line(&self) -> Result; - - /// Get usage information if available (usually only in final chunk) - fn usage(&self) -> Option<&Self::Usage>; - - /// Check if this is the final chunk in the stream - fn is_final(&self) -> bool; -} - -/// Main provider interface -pub trait LLMProvider { - type Request: ProviderRequest; - type Response: ProviderResponse; - type StreamChunk: StreamChunk; - type Error: ProviderError; - - /// Create a new instance of this provider - fn new() -> Self; - - /// Get the supported API endpoints for this provider - fn supported_apis(&self) -> Vec<&'static str>; - - /// Check if the provider supports v1/chat/completions API - fn supports_chat_completions(&self) -> bool { - self.supported_apis().contains(&"/v1/chat/completions") - } - - /// Parse a request from raw bytes - fn parse_request(&self, bytes: &[u8]) -> Result; - - /// Parse a response from raw bytes - fn parse_response(&self, bytes: &[u8], mode: ConversionMode) -> Result; - - /// Parse streaming response chunks from raw data - fn parse_stream_chunk(&self, line: &str, mode: ConversionMode) -> Result, Self::Error>; -} diff --git a/crates/hermesllm/src/providers/mistral/mod.rs b/crates/hermesllm/src/providers/mistral/mod.rs new file mode 100644 index 00000000..3d491919 --- /dev/null +++ b/crates/hermesllm/src/providers/mistral/mod.rs @@ -0,0 +1,6 @@ +//! Mistral provider implementation +//! +//! Mistral uses OpenAI-compatible API format + +pub mod provider; +pub use provider::MistralProvider; diff --git a/crates/hermesllm/src/providers/mistral/provider.rs b/crates/hermesllm/src/providers/mistral/provider.rs new file mode 100644 index 00000000..a36d6774 --- /dev/null +++ b/crates/hermesllm/src/providers/mistral/provider.rs @@ -0,0 +1,40 @@ +//! Mistral provider implementation + +use crate::providers::{ProviderInterface, ConversionMode}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse}; +use crate::providers::traits::{ProviderRequest, ProviderResponse}; + +/// Mistral provider implementation +#[derive(Debug, Clone)] +pub struct MistralProvider; + +impl ProviderInterface for MistralProvider { + fn has_compatible_api(&self, api_path: &str) -> bool { + matches!(api_path, "/v1/chat/completions") + } + + fn supported_apis(&self) -> Vec<&'static str> { + vec!["/v1/chat/completions"] + } + + fn parse_request(&self, bytes: &[u8]) -> Result> { + match ChatCompletionsRequest::try_from_bytes(bytes) { + Ok(req) => Ok(req), + Err(e) => Err(Box::new(e)), + } + } + + fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { + match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + Ok(resp) => Ok(resp), + Err(e) => Err(Box::new(e)), + } + } + + fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { + match request.to_provider_bytes(provider_id, mode) { + Ok(bytes) => Ok(bytes), + Err(e) => Err(Box::new(e)), + } + } +} diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs index 06844441..eafebe26 100644 --- a/crates/hermesllm/src/providers/mod.rs +++ b/crates/hermesllm/src/providers/mod.rs @@ -5,10 +5,24 @@ pub mod traits; pub mod openai; +pub mod groq; +pub mod mistral; +pub mod deepseek; +pub mod arch; +pub mod gemini; +pub mod claude; +pub mod github; // Re-export the main interfaces pub use traits::*; pub use openai::OpenAIProvider; +pub use groq::GroqProvider; +pub use mistral::MistralProvider; +pub use deepseek::DeepseekProvider; +pub use arch::ArchProvider; +pub use gemini::GeminiProvider; +pub use claude::ClaudeProvider; +pub use github::GitHubProvider; use std::fmt::Display; @@ -81,27 +95,29 @@ impl ProviderId { } /// Enum for dynamic dispatch of provider instances -/// For now, most providers use OpenAI-compatible format pub enum Provider { OpenAI(OpenAIProvider, ProviderId), - // TODO: Add specific implementations when providers have different APIs - // Mistral(MistralProvider, ProviderId), - // Groq(GroqProvider, ProviderId), - // etc. + Groq(GroqProvider, ProviderId), + Mistral(MistralProvider, ProviderId), + Deepseek(DeepseekProvider, ProviderId), + Arch(ArchProvider, ProviderId), + Gemini(GeminiProvider, ProviderId), + Claude(ClaudeProvider, ProviderId), + GitHub(GitHubProvider, ProviderId), } impl Provider { /// Create a provider instance from a provider ID pub fn new(id: ProviderId) -> Self { match id { - // For now, all providers that support v1/chat/completions use OpenAI format - ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch => { - Provider::OpenAI(OpenAIProvider, id) - } - // TODO: Implement specific providers when they have different APIs - ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => { - Provider::OpenAI(OpenAIProvider, id) // Fallback to OpenAI for now - } + ProviderId::OpenAI => Provider::OpenAI(OpenAIProvider, id), + ProviderId::Groq => Provider::Groq(GroqProvider, id), + ProviderId::Mistral => Provider::Mistral(MistralProvider, id), + ProviderId::Deepseek => Provider::Deepseek(DeepseekProvider, id), + ProviderId::Arch => Provider::Arch(ArchProvider, id), + ProviderId::Gemini => Provider::Gemini(GeminiProvider, id), + ProviderId::Claude => Provider::Claude(ClaudeProvider, id), + ProviderId::GitHub => Provider::GitHub(GitHubProvider, id), } } @@ -109,66 +125,27 @@ impl Provider { pub fn id(&self) -> ProviderId { match self { Provider::OpenAI(_, id) => *id, + Provider::Groq(_, id) => *id, + Provider::Mistral(_, id) => *id, + Provider::Deepseek(_, id) => *id, + Provider::Arch(_, id) => *id, + Provider::Gemini(_, id) => *id, + Provider::Claude(_, id) => *id, + Provider::GitHub(_, id) => *id, } } - /// Check if this provider has a compatible API with the client request - pub fn has_compatible_api(&self, api_path: &str) -> bool { + /// Get the provider interface implementation + pub fn interface(&self) -> &dyn ProviderInterface { match self { - Provider::OpenAI(provider, _) => provider.has_compatible_api(api_path), - } - } - - /// Get the interface implementation for this provider - pub fn get_interface(&self, passthrough: bool) -> ConversionMode { - match self { - Provider::OpenAI(provider, _) => provider.get_interface(passthrough), - } - } - - /// Parse a request from raw bytes - returns the concrete OpenAI request type for now - pub fn parse_request(&self, bytes: &[u8]) -> Result> { - match self { - Provider::OpenAI(_, _) => { - use crate::apis::openai::ChatCompletionsRequest; - use crate::providers::traits::ProviderRequest; - - match ChatCompletionsRequest::try_from_bytes(bytes) { - Ok(req) => Ok(req), - Err(e) => Err(Box::new(e)), - } - } - } - } - - /// Parse a response from raw bytes - returns the concrete OpenAI response type for now - pub fn parse_response(&self, bytes: &[u8], mode: ConversionMode) -> Result> { - match self { - Provider::OpenAI(_, _) => { - use crate::apis::openai::ChatCompletionsResponse; - use crate::providers::traits::ProviderResponse; - - let provider_id = self.id(); - match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { - Ok(resp) => Ok(resp), - Err(e) => Err(Box::new(e)), - } - } - } - } - - /// Convert a request to bytes for sending to upstream API - pub fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, mode: ConversionMode) -> Result, Box> { - match self { - Provider::OpenAI(_, _) => { - use crate::providers::traits::ProviderRequest; - - let provider_id = self.id(); - match request.to_provider_bytes(provider_id, mode) { - Ok(bytes) => Ok(bytes), - Err(e) => Err(Box::new(e)), - } - } + Provider::OpenAI(provider, _) => provider, + Provider::Groq(provider, _) => provider, + Provider::Mistral(provider, _) => provider, + Provider::Deepseek(provider, _) => provider, + Provider::Arch(provider, _) => provider, + Provider::Gemini(provider, _) => provider, + Provider::Claude(provider, _) => provider, + Provider::GitHub(provider, _) => provider, } } } diff --git a/crates/hermesllm/src/providers/openai/provider.rs b/crates/hermesllm/src/providers/openai/provider.rs index 3f5677c6..a0f0179d 100644 --- a/crates/hermesllm/src/providers/openai/provider.rs +++ b/crates/hermesllm/src/providers/openai/provider.rs @@ -68,11 +68,6 @@ impl Iterator for OpenAIStreamingResponse { } impl ProviderInterface for OpenAIProvider { - type Request = ChatCompletionsRequest; - type Response = ChatCompletionsResponse; - type StreamingResponse = OpenAIStreamingResponse; - type Usage = Usage; - fn has_compatible_api(&self, api_path: &str) -> bool { api_path == "/v1/chat/completions" } @@ -80,6 +75,30 @@ impl ProviderInterface for OpenAIProvider { fn supported_apis(&self) -> Vec<&'static str> { vec!["/v1/chat/completions"] } + + fn parse_request(&self, bytes: &[u8]) -> Result> { + use crate::providers::traits::ProviderRequest; + match ChatCompletionsRequest::try_from_bytes(bytes) { + Ok(req) => Ok(req), + Err(e) => Err(Box::new(e)), + } + } + + fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { + use crate::providers::traits::ProviderResponse; + match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + Ok(resp) => Ok(resp), + Err(e) => Err(Box::new(e)), + } + } + + fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { + use crate::providers::traits::ProviderRequest; + match request.to_provider_bytes(provider_id, mode) { + Ok(bytes) => Ok(bytes), + Err(e) => Err(Box::new(e)), + } + } } // ============================================================================ @@ -131,6 +150,9 @@ impl ProviderRequest for ChatCompletionsRequest { } } +// Implement the helper trait for stream context integration +impl crate::providers::traits::StreamContextHelpers for ChatCompletionsRequest {} + impl TokenUsage for Usage { fn completion_tokens(&self) -> usize { self.completion_tokens as usize diff --git a/crates/hermesllm/src/providers/traits.rs b/crates/hermesllm/src/providers/traits.rs index 8bc8d519..792c8740 100644 --- a/crates/hermesllm/src/providers/traits.rs +++ b/crates/hermesllm/src/providers/traits.rs @@ -54,6 +54,31 @@ pub trait ProviderResponse: Sized { /// Get usage information if available fn usage(&self) -> Option<&Self::Usage>; + + /// Extract token counts for metrics + fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { + self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens())) + } +} + +/// Helper trait for stream context integration +pub trait StreamContextHelpers: ProviderRequest { + /// Get the model name for routing and metrics + fn get_model_for_routing(&self) -> String { + self.extract_model().to_string() + } + + /// Get text for token counting and rate limiting + fn get_text_for_tokenization(&self) -> String { + self.extract_messages_text() + } + + /// Prepare for streaming by setting appropriate options + fn prepare_for_streaming(&mut self) { + if self.is_streaming() { + self.set_streaming_options(); + } + } } /// Trait for streaming response chunks @@ -75,11 +100,6 @@ pub trait StreamingResponse: Iterator> + /// Main provider interface trait pub trait ProviderInterface { - type Request: ProviderRequest; - type Response: ProviderResponse; - type StreamingResponse: StreamingResponse; - type Usage: TokenUsage; - /// Check if this provider has a compatible API with the client request fn has_compatible_api(&self, api_path: &str) -> bool; @@ -93,6 +113,47 @@ pub trait ProviderInterface { } } + /// Parse a request from raw bytes - returns concrete ChatCompletionsRequest + fn parse_request(&self, bytes: &[u8]) -> Result>; + + /// Parse a response from raw bytes - returns concrete ChatCompletionsResponse + fn parse_response(&self, bytes: &[u8], provider_id: super::ProviderId, mode: ConversionMode) -> Result>; + + /// Convert a request to bytes for sending to upstream API + fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result, Box>; + + /// Extract model name from request for routing (convenience method for stream_context) + fn extract_model_from_request(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String { + use ProviderRequest; + request.extract_model().to_string() + } + + /// Check if request is streaming (convenience method for stream_context) + fn is_request_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool { + use ProviderRequest; + request.is_streaming() + } + + /// Prepare request for streaming (convenience method for stream_context) + fn prepare_request_for_streaming(&self, request: &mut crate::apis::openai::ChatCompletionsRequest) { + use ProviderRequest; + if request.is_streaming() { + request.set_streaming_options(); + } + } + + /// Extract text for tokenization (convenience method for stream_context) + fn extract_text_for_tokenization(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String { + use ProviderRequest; + request.extract_messages_text() + } + + /// Extract usage information from response (convenience method for stream_context) + fn extract_usage_from_response(&self, response: &crate::apis::openai::ChatCompletionsResponse) -> Option<(usize, usize, usize)> { + use ProviderResponse; + response.extract_usage_counts() + } + /// Get supported API endpoints for this provider fn supported_apis(&self) -> Vec<&'static str>; } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 0beb18d5..f8eb051d 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -10,7 +10,7 @@ use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; -use hermesllm::{ConversionMode, Provider, ProviderId, ProviderRequest}; +use hermesllm::{ConversionMode, Provider, ProviderId}; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::hostcalls::get_current_time; @@ -295,7 +295,7 @@ impl HttpContext for StreamContext { let provider = self.get_provider(); - let mut deserialized_body = match provider.parse_request(&body_bytes) { + let mut deserialized_body = match provider.interface().parse_request(&body_bytes) { Ok(deserialized) => deserialized, Err(e) => { debug!( @@ -310,8 +310,8 @@ impl HttpContext for StreamContext { } }; - // TODO: For now, we'll need to handle user_message extraction differently since it's OpenAI-specific - // This could be made generic by adding a trait method later + // TODO: For now, we'll work with the concrete ChatCompletionsRequest type + // In the future, this could be made more generic using trait objects let model_name = match self.llm_provider.as_ref() { Some(llm_provider) => llm_provider.model.as_ref(), @@ -323,9 +323,10 @@ impl HttpContext for StreamContext { None => false, }; - let model_requested = deserialized_body.extract_model().to_string(); - // Note: We can't directly modify the model field through the trait, - // this would need to be handled differently in a full generic implementation + // Use the provider interface methods for cleaner interaction + let model_requested = provider + .interface() + .extract_model_from_request(&deserialized_body); info!( "on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}", @@ -334,15 +335,21 @@ impl HttpContext for StreamContext { model_name.unwrap_or(&"None".to_string()), ); - if deserialized_body.is_streaming() { + // Use provider interface for streaming detection and setup + if provider + .interface() + .is_request_streaming(&deserialized_body) + { self.streaming_response = true; - } - if deserialized_body.is_streaming() { - deserialized_body.set_streaming_options(); + provider + .interface() + .prepare_request_for_streaming(&mut deserialized_body); } - // only use the tokens from the messages, excluding the metadata and json tags - let input_tokens_str = deserialized_body.extract_messages_text(); + // Use provider interface for text extraction + let input_tokens_str = provider + .interface() + .extract_text_for_tokenization(&deserialized_body); // enforce ratelimits on ingress if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) { self.send_server_error( @@ -354,12 +361,14 @@ impl HttpContext for StreamContext { } let llm_provider_str = self.llm_provider().provider_interface.to_string(); - let hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str()); + let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str()); - // convert chat completion request to llm provider specific request - let deserialized_body_bytes = match deserialized_body - .to_provider_bytes(hermes_llm_provider_id, ConversionMode::Compatible) - { + // Convert chat completion request to llm provider specific request using provider interface + let deserialized_body_bytes = match provider.interface().request_to_bytes( + &deserialized_body, + provider.id(), + ConversionMode::Compatible, + ) { Ok(bytes) => bytes, Err(e) => { warn!("Failed to serialize request body: {}", e); @@ -558,8 +567,12 @@ impl HttpContext for StreamContext { } else { debug!("non streaming response"); let provider = self.get_provider(); - let _response = match provider.parse_response(&body, ConversionMode::Compatible) { - Ok(response_box) => response_box, + let response = match provider.interface().parse_response( + &body, + provider.id(), + ConversionMode::Compatible, + ) { + Ok(response) => response, Err(e) => { warn!( "could not parse response: {}, body str: {}", @@ -579,9 +592,18 @@ impl HttpContext for StreamContext { } }; - // TODO: Extract usage information from the response box - // For now, we'll skip this until we have a better way to handle Any types - warn!("Response token counting not yet implemented with new provider structure"); + // Use provider interface to extract usage information + if let Some((prompt_tokens, completion_tokens, total_tokens)) = + provider.interface().extract_usage_from_response(&response) + { + debug!( + "Response usage: prompt={}, completion={}, total={}", + prompt_tokens, completion_tokens, total_tokens + ); + self.response_tokens = completion_tokens; + } else { + warn!("No usage information found in response"); + } } debug!(