From 63f23efda440d9174b4c58d7e915bb2e6f84f8cf Mon Sep 17 00:00:00 2001 From: Salman Paracha Date: Sat, 9 Aug 2025 11:19:23 -0700 Subject: [PATCH] saving changes, although we will need a small re-factor after this as well --- crates/common/src/configuration.rs | 17 +- crates/hermesllm/src/lib.rs | 158 +++++----------- crates/hermesllm/src/providers/interface.rs | 109 +++++++++++ crates/hermesllm/src/providers/mod.rs | 175 +++++++++++++++++- crates/hermesllm/src/providers/openai/mod.rs | 3 + .../src/providers/openai/provider.rs | 15 +- .../hermesllm/src/providers/openai/types.rs | 30 ++- .../hermesllm/src/providers/provider_enum.rs | 67 ------- crates/hermesllm/src/providers/traits.rs | 32 +++- crates/llm_gateway/src/stream_context.rs | 67 ++----- 10 files changed, 414 insertions(+), 259 deletions(-) create mode 100644 crates/hermesllm/src/providers/interface.rs delete mode 100644 crates/hermesllm/src/providers/provider_enum.rs diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index fab4948e..226a4f90 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -178,14 +178,13 @@ impl Display for LlmProviderType { } impl LlmProviderType { - /// Create a ProviderInstance from this LlmProviderType + /// Create a Provider from this LlmProviderType /// This is the main method for stream_context to get provider-specific interfaces - pub fn create_provider_instance(&self) -> hermesllm::ProviderInstance { - use hermesllm::ProviderInstance; + pub fn create_provider(&self) -> hermesllm::Provider { + use hermesllm::{ProviderId, Provider}; - // For now, all providers use OpenAI-compatible APIs - // TODO: Return specific provider instances when implementing different APIs - ProviderInstance::from_name(&self.to_string()) + let provider_id = ProviderId::from(self.to_string().as_str()); + Provider::new(provider_id) } } @@ -265,10 +264,10 @@ impl Display for LlmProvider { } impl LlmProvider { - /// Create a ProviderInstance for this LlmProvider + /// Create a Provider for this LlmProvider /// This is a convenience method that delegates to the provider_interface - pub fn create_provider_instance(&self) -> hermesllm::ProviderInstance { - self.provider_interface.create_provider_instance() + pub fn create_provider(&self) -> hermesllm::Provider { + self.provider_interface.create_provider() } } diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 2d0a5198..4a81ed90 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -5,121 +5,59 @@ pub mod providers; pub mod apis; pub mod clients; -// Re-export important traits -pub use providers::traits::*; -pub use providers::openai::provider::OpenAIProvider; -pub use providers::provider_enum::ProviderInstance; - - -use std::fmt::Display; -pub enum Provider { - Arch, - Mistral, - Deepseek, - Groq, - Gemini, - OpenAI, - Claude, - Github, -} - -impl From<&str> for Provider { - fn from(value: &str) -> Self { - match value.to_lowercase().as_str() { - "arch" => Provider::Arch, - "mistral" => Provider::Mistral, - "deepseek" => Provider::Deepseek, - "groq" => Provider::Groq, - "gemini" => Provider::Gemini, - "openai" => Provider::OpenAI, - "claude" => Provider::Claude, - "github" => Provider::Github, - _ => panic!("Unknown provider: {}", value), - } - } -} - -impl Provider { - /// Get the API endpoint path for this provider - pub fn api_path(&self) -> &'static str { - match self { - Provider::OpenAI => "/v1/chat/completions", - Provider::Groq => "/openai/v1/chat/completions", // Groq maps to OpenAI-compatible endpoint - Provider::Gemini => "/v1/models", // TODO: Update with correct Gemini path - Provider::Claude => "/v1/messages", // TODO: Update with correct Claude path - Provider::Mistral => "/v1/chat/completions", // Mistral uses OpenAI-compatible API - Provider::Deepseek => "/v1/chat/completions", // DeepSeek uses OpenAI-compatible API - Provider::Arch => "/v1/chat/completions", // Arch gateway endpoint - Provider::Github => "/models", // TODO: Update with correct GitHub models path - } - } - - /// Check if this provider uses OpenAI-compatible API format - pub fn uses_openai_format(&self) -> bool { - match self { - Provider::OpenAI | Provider::Groq | Provider::Mistral | Provider::Deepseek | Provider::Arch => true, - Provider::Gemini | Provider::Claude | Provider::Github => false, // These have their own formats - } - } - - /// Create a provider implementation instance for this provider - pub fn create_provider_instance(&self) -> ProviderInstance { - match self { - Provider::OpenAI => ProviderInstance::OpenAI(OpenAIProvider), - Provider::Groq => ProviderInstance::OpenAI(OpenAIProvider), // Groq uses OpenAI-compatible API - Provider::Mistral => ProviderInstance::OpenAI(OpenAIProvider), // Mistral uses OpenAI-compatible API - Provider::Deepseek => ProviderInstance::OpenAI(OpenAIProvider), // Deepseek uses OpenAI-compatible API - Provider::Arch => ProviderInstance::OpenAI(OpenAIProvider), // Arch gateway uses OpenAI-compatible API - // TODO: Implement specific providers for these when they have different APIs - Provider::Gemini => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible - Provider::Claude => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible - Provider::Github => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible - } - } -} - -impl Display for Provider { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Provider::Arch => write!(f, "Arch"), - Provider::Mistral => write!(f, "Mistral"), - Provider::Deepseek => write!(f, "Deepseek"), - Provider::Groq => write!(f, "Groq"), - Provider::Gemini => write!(f, "Gemini"), - Provider::OpenAI => write!(f, "OpenAI"), - Provider::Claude => write!(f, "Claude"), - Provider::Github => write!(f, "Github"), - } - } -} +// Re-export important types and traits +pub use providers::{ + ProviderId, Provider, ConversionMode, + ProviderInterface, ProviderRequest, ProviderResponse, + TokenUsage, StreamChunk, StreamingResponse, + OpenAIProvider +}; #[cfg(test)] mod tests { - use crate::providers::openai::types::{ChatCompletionsRequest, Message}; + use super::*; #[test] - fn openai_builder() { - let request = - ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![Message::new("Hi".to_string())]) - .temperature(0.7) - .top_p(0.9) - .n(1) - .max_tokens(100) - .stream(false) - .stop(vec!["\n".to_string()]) - .presence_penalty(0.0) - .frequency_penalty(0.0) - .build() - .expect("Failed to build OpenAIRequest"); + fn test_provider_id_conversion() { + assert_eq!(ProviderId::from("openai"), ProviderId::OpenAI); + assert_eq!(ProviderId::from("mistral"), ProviderId::Mistral); + assert_eq!(ProviderId::from("groq"), ProviderId::Groq); + assert_eq!(ProviderId::from("arch"), ProviderId::Arch); + } - assert_eq!(request.model, "gpt-3.5-turbo"); - assert_eq!(request.temperature, Some(0.7)); - assert_eq!(request.top_p, Some(0.9)); - assert_eq!(request.n, Some(1)); - assert_eq!(request.max_tokens, Some(100)); - assert_eq!(request.stream, Some(false)); - assert_eq!(request.stop, Some(vec!["\n".to_string()])); - assert_eq!(request.presence_penalty, Some(0.0)); - assert_eq!(request.frequency_penalty, Some(0.0)); + #[test] + fn test_provider_api_paths() { + assert_eq!(ProviderId::OpenAI.api_path(), "/v1/chat/completions"); + assert_eq!(ProviderId::Groq.api_path(), "/openai/v1/chat/completions"); + assert_eq!(ProviderId::Mistral.api_path(), "/v1/chat/completions"); + assert_eq!(ProviderId::Arch.api_path(), "/v1/chat/completions"); + } + + #[test] + fn test_provider_openai_format_support() { + assert!(ProviderId::OpenAI.supports_openai_format()); + assert!(ProviderId::Groq.supports_openai_format()); + assert!(ProviderId::Mistral.supports_openai_format()); + assert!(ProviderId::Arch.supports_openai_format()); + assert!(!ProviderId::Gemini.supports_openai_format()); + assert!(!ProviderId::Claude.supports_openai_format()); + } + + #[test] + fn test_provider_instance_creation() { + let provider = Provider::new(ProviderId::OpenAI); + assert!(provider.has_compatible_api("/v1/chat/completions")); + assert!(!provider.has_compatible_api("/v1/embeddings")); + } + + #[test] + fn test_conversion_mode() { + let provider = Provider::new(ProviderId::OpenAI); + + let compatible_mode = provider.get_interface(false); + assert!(matches!(compatible_mode, ConversionMode::Compatible)); + + let passthrough_mode = provider.get_interface(true); + assert!(matches!(passthrough_mode, ConversionMode::Passthrough)); } } diff --git a/crates/hermesllm/src/providers/interface.rs b/crates/hermesllm/src/providers/interface.rs new file mode 100644 index 00000000..4026516e --- /dev/null +++ b/crates/hermesllm/src/providers/interface.rs @@ -0,0 +1,109 @@ +//! Provider interface trait definitions +//! +//! This module defines the core traits that all LLM providers must implement. +//! The interface is designed around v1/chat/completions API for simplicity. + +use std::error::Error; + +/// Conversion mode for provider requests/responses +#[derive(Debug, Clone, Copy)] +pub enum ConversionMode { + /// Compatible: Convert between different provider formats to ensure compatibility + Compatible, + /// Passthrough: Pass requests/responses through with minimal modification + Passthrough, +} + +/// Token usage information +pub trait TokenUsage { + fn completion_tokens(&self) -> usize; + fn prompt_tokens(&self) -> usize; + fn total_tokens(&self) -> usize; +} + +/// Error type for provider operations +pub trait ProviderError: Error + Send + Sync + 'static {} + +/// Request type that can be converted to/from provider-specific formats +pub trait ProviderRequest: Sized { + type Error: ProviderError; + + /// Parse request from raw bytes (typically JSON) + fn from_bytes(bytes: &[u8]) -> Result; + + /// Convert to bytes for sending to upstream API + fn to_bytes(&self, mode: ConversionMode) -> Result, Self::Error>; + + /// Extract the model name from the request + fn model(&self) -> &str; + + /// Check if this is a streaming request + fn is_streaming(&self) -> bool; + + /// Set streaming options (e.g., include_usage) + fn set_streaming_options(&mut self); + + /// Extract text content from messages for token counting + fn extract_text(&self) -> String; +} + +/// Response type that can be converted to/from provider-specific formats +pub trait ProviderResponse: Sized { + type Error: ProviderError; + type Usage: TokenUsage; + + /// Parse response from raw bytes (typically JSON) + fn from_bytes(bytes: &[u8], mode: ConversionMode) -> Result; + + /// Convert to bytes for sending to client + fn to_bytes(&self) -> Result, Self::Error>; + + /// Get usage information if available + fn usage(&self) -> Option<&Self::Usage>; +} + +/// Streaming response chunk +pub trait StreamChunk: Sized { + type Error: ProviderError; + type Usage: TokenUsage; + + /// Parse chunk from a line of streaming data + fn from_line(line: &str, mode: ConversionMode) -> Result, Self::Error>; + + /// Convert to line for sending to client + fn to_line(&self) -> Result; + + /// Get usage information if available (usually only in final chunk) + fn usage(&self) -> Option<&Self::Usage>; + + /// Check if this is the final chunk in the stream + fn is_final(&self) -> bool; +} + +/// Main provider interface +pub trait LLMProvider { + type Request: ProviderRequest; + type Response: ProviderResponse; + type StreamChunk: StreamChunk; + type Error: ProviderError; + + /// Create a new instance of this provider + fn new() -> Self; + + /// Get the supported API endpoints for this provider + fn supported_apis(&self) -> Vec<&'static str>; + + /// Check if the provider supports v1/chat/completions API + fn supports_chat_completions(&self) -> bool { + self.supported_apis().contains(&"/v1/chat/completions") + } + + /// Parse a request from raw bytes + fn parse_request(&self, bytes: &[u8]) -> Result; + + /// Parse a response from raw bytes + fn parse_response(&self, bytes: &[u8], mode: ConversionMode) -> Result; + + /// Parse streaming response chunks from raw data + fn parse_stream_chunk(&self, line: &str, mode: ConversionMode) -> Result, Self::Error>; +} diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs index 9e980bb5..06844441 100644 --- a/crates/hermesllm/src/providers/mod.rs +++ b/crates/hermesllm/src/providers/mod.rs @@ -1,3 +1,174 @@ -pub mod openai; +//! Provider implementations for different LLM APIs +//! +//! This module contains provider-specific implementations that handle +//! request/response conversion for different LLM service APIs. + pub mod traits; -pub mod provider_enum; +pub mod openai; + +// Re-export the main interfaces +pub use traits::*; +pub use openai::OpenAIProvider; + +use std::fmt::Display; + +/// Provider identifier enum - simple enum for identifying providers +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ProviderId { + OpenAI, + Mistral, + Deepseek, + Groq, + Gemini, + Claude, + GitHub, + Arch, +} + +impl From<&str> for ProviderId { + fn from(value: &str) -> Self { + match value.to_lowercase().as_str() { + "openai" => ProviderId::OpenAI, + "mistral" => ProviderId::Mistral, + "deepseek" => ProviderId::Deepseek, + "groq" => ProviderId::Groq, + "gemini" => ProviderId::Gemini, + "claude" => ProviderId::Claude, + "github" => ProviderId::GitHub, + "arch" => ProviderId::Arch, + _ => panic!("Unknown provider: {}", value), + } + } +} + +impl Display for ProviderId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ProviderId::OpenAI => write!(f, "OpenAI"), + ProviderId::Mistral => write!(f, "Mistral"), + ProviderId::Deepseek => write!(f, "Deepseek"), + ProviderId::Groq => write!(f, "Groq"), + ProviderId::Gemini => write!(f, "Gemini"), + ProviderId::Claude => write!(f, "Claude"), + ProviderId::GitHub => write!(f, "GitHub"), + ProviderId::Arch => write!(f, "Arch"), + } + } +} + +impl ProviderId { + /// Get the API endpoint path for this provider + pub fn api_path(&self) -> &'static str { + match self { + ProviderId::OpenAI => "/v1/chat/completions", + ProviderId::Groq => "/openai/v1/chat/completions", + ProviderId::Gemini => "/v1/models", // TODO: Update when Gemini API is implemented + ProviderId::Claude => "/v1/messages", // TODO: Update when Claude API is implemented + ProviderId::Mistral => "/v1/chat/completions", + ProviderId::Deepseek => "/v1/chat/completions", + ProviderId::GitHub => "/models", // TODO: Update when GitHub models API is implemented + ProviderId::Arch => "/v1/chat/completions", + } + } + + /// Check if this provider supports OpenAI v1/chat/completions API format + pub fn supports_openai_format(&self) -> bool { + matches!( + self, + ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch + ) + } +} + +/// Enum for dynamic dispatch of provider instances +/// For now, most providers use OpenAI-compatible format +pub enum Provider { + OpenAI(OpenAIProvider, ProviderId), + // TODO: Add specific implementations when providers have different APIs + // Mistral(MistralProvider, ProviderId), + // Groq(GroqProvider, ProviderId), + // etc. +} + +impl Provider { + /// Create a provider instance from a provider ID + pub fn new(id: ProviderId) -> Self { + match id { + // For now, all providers that support v1/chat/completions use OpenAI format + ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch => { + Provider::OpenAI(OpenAIProvider, id) + } + // TODO: Implement specific providers when they have different APIs + ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => { + Provider::OpenAI(OpenAIProvider, id) // Fallback to OpenAI for now + } + } + } + + /// Get the provider ID + pub fn id(&self) -> ProviderId { + match self { + Provider::OpenAI(_, id) => *id, + } + } + + /// Check if this provider has a compatible API with the client request + pub fn has_compatible_api(&self, api_path: &str) -> bool { + match self { + Provider::OpenAI(provider, _) => provider.has_compatible_api(api_path), + } + } + + /// Get the interface implementation for this provider + pub fn get_interface(&self, passthrough: bool) -> ConversionMode { + match self { + Provider::OpenAI(provider, _) => provider.get_interface(passthrough), + } + } + + /// Parse a request from raw bytes - returns the concrete OpenAI request type for now + pub fn parse_request(&self, bytes: &[u8]) -> Result> { + match self { + Provider::OpenAI(_, _) => { + use crate::apis::openai::ChatCompletionsRequest; + use crate::providers::traits::ProviderRequest; + + match ChatCompletionsRequest::try_from_bytes(bytes) { + Ok(req) => Ok(req), + Err(e) => Err(Box::new(e)), + } + } + } + } + + /// Parse a response from raw bytes - returns the concrete OpenAI response type for now + pub fn parse_response(&self, bytes: &[u8], mode: ConversionMode) -> Result> { + match self { + Provider::OpenAI(_, _) => { + use crate::apis::openai::ChatCompletionsResponse; + use crate::providers::traits::ProviderResponse; + + let provider_id = self.id(); + match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + Ok(resp) => Ok(resp), + Err(e) => Err(Box::new(e)), + } + } + } + } + + /// Convert a request to bytes for sending to upstream API + pub fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, mode: ConversionMode) -> Result, Box> { + match self { + Provider::OpenAI(_, _) => { + use crate::providers::traits::ProviderRequest; + + let provider_id = self.id(); + match request.to_provider_bytes(provider_id, mode) { + Ok(bytes) => Ok(bytes), + Err(e) => Err(Box::new(e)), + } + } + } + } +} diff --git a/crates/hermesllm/src/providers/openai/mod.rs b/crates/hermesllm/src/providers/openai/mod.rs index 1cc60f4d..60b5f265 100644 --- a/crates/hermesllm/src/providers/openai/mod.rs +++ b/crates/hermesllm/src/providers/openai/mod.rs @@ -1,3 +1,6 @@ pub mod builder; pub mod types; pub mod provider; + +// Re-export the main provider +pub use provider::OpenAIProvider; diff --git a/crates/hermesllm/src/providers/openai/provider.rs b/crates/hermesllm/src/providers/openai/provider.rs index 8d98be4a..3f5677c6 100644 --- a/crates/hermesllm/src/providers/openai/provider.rs +++ b/crates/hermesllm/src/providers/openai/provider.rs @@ -2,7 +2,6 @@ use crate::apis::openai::*; use crate::providers::traits::*; -use crate::Provider; // Simple error type for OpenAI API operations #[derive(Debug, thiserror::Error)] @@ -73,6 +72,14 @@ impl ProviderInterface for OpenAIProvider { type Response = ChatCompletionsResponse; type StreamingResponse = OpenAIStreamingResponse; type Usage = Usage; + + fn has_compatible_api(&self, api_path: &str) -> bool { + api_path == "/v1/chat/completions" + } + + fn supported_apis(&self) -> Vec<&'static str> { + vec!["/v1/chat/completions"] + } } // ============================================================================ @@ -87,7 +94,7 @@ impl ProviderRequest for ChatCompletionsRequest { Ok(serde_json::from_str(s)?) } - fn to_provider_bytes(&self, _provider: Provider) -> Result, Self::Error> { + fn to_provider_bytes(&self, _provider: super::super::ProviderId, _mode: ConversionMode) -> Result, Self::Error> { Ok(serde_json::to_vec(self)?) } @@ -142,7 +149,7 @@ impl ProviderResponse for ChatCompletionsResponse { type Error = OpenAIApiError; type Usage = Usage; - fn try_from_bytes(bytes: &[u8], _provider: &Provider) -> Result { + fn try_from_bytes(bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result { let s = std::str::from_utf8(bytes)?; Ok(serde_json::from_str(s)?) } @@ -164,7 +171,7 @@ impl StreamingResponse for OpenAIStreamingResponse { type Error = OpenAIApiError; type Chunk = ChatCompletionsStreamResponse; - fn try_from_bytes(bytes: &[u8], _provider: &Provider) -> Result { + fn try_from_bytes(bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result { let s = std::str::from_utf8(bytes)?; Ok(OpenAIStreamingResponse::new(s.to_string())) } diff --git a/crates/hermesllm/src/providers/openai/types.rs b/crates/hermesllm/src/providers/openai/types.rs index 7dea64df..6e5b676a 100644 --- a/crates/hermesllm/src/providers/openai/types.rs +++ b/crates/hermesllm/src/providers/openai/types.rs @@ -8,7 +8,7 @@ use std::convert::TryFrom; use std::str; use thiserror::Error; -use crate::Provider; +use crate::providers::ProviderId; #[derive(Debug, Error)] pub enum OpenAIError { @@ -144,28 +144,26 @@ impl TryFrom<&[u8]> for ChatCompletionsResponse { } } -impl<'a> TryFrom<(&'a [u8], &'a Provider)> for ChatCompletionsResponse { +impl<'a> TryFrom<(&'a [u8], &'a ProviderId)> for ChatCompletionsResponse { type Error = OpenAIError; - fn try_from(input: (&'a [u8], &'a Provider)) -> Result { + fn try_from(input: (&'a [u8], &'a ProviderId)) -> Result { // Use input.provider as needed, if necessary serde_json::from_slice(input.0).map_err(OpenAIError::from) } } impl ChatCompletionsRequest { - pub fn to_bytes(&self, provider: Provider) -> Result> { + pub fn to_bytes(&self, provider: ProviderId) -> Result> { match provider { - Provider::OpenAI - | Provider::Arch - | Provider::Deepseek - | Provider::Mistral - | Provider::Groq - | Provider::Gemini - | Provider::Claude => serde_json::to_vec(self).map_err(OpenAIError::from), - _ => Err(OpenAIError::UnsupportedProvider { - provider: provider.to_string(), - }), + ProviderId::OpenAI + | ProviderId::Arch + | ProviderId::Deepseek + | ProviderId::Mistral + | ProviderId::Groq + | ProviderId::Gemini + | ProviderId::Claude + | ProviderId::GitHub => serde_json::to_vec(self).map_err(OpenAIError::from), } } } @@ -262,10 +260,10 @@ where } } -impl<'a> TryFrom<(&'a [u8], &'a Provider)> for SseChatCompletionIter> { +impl<'a> TryFrom<(&'a [u8], &'a ProviderId)> for SseChatCompletionIter> { type Error = OpenAIError; - fn try_from(input: (&'a [u8], &'a Provider)) -> Result { + fn try_from(input: (&'a [u8], &'a ProviderId)) -> Result { let s = std::str::from_utf8(input.0)?; // Use input.provider as needed Ok(SseChatCompletionIter::new(s.lines())) diff --git a/crates/hermesllm/src/providers/provider_enum.rs b/crates/hermesllm/src/providers/provider_enum.rs deleted file mode 100644 index ea1efa54..00000000 --- a/crates/hermesllm/src/providers/provider_enum.rs +++ /dev/null @@ -1,67 +0,0 @@ -use crate::providers::traits::*; -use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse}; -use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage}; - -/// Enum that wraps all possible providers for dynamic dispatch -pub enum ProviderInstance { - OpenAI(OpenAIProvider), - // TODO: Add other providers as they are implemented - // Anthropic(AnthropicProvider), - // Mistral(MistralProvider), - // etc. -} - -impl ProviderInstance { - /// Creates a provider from a provider name string - pub fn from_name(name: &str) -> Self { - match name.to_lowercase().as_str() { - "openai" | "groq" | "gemini" | "mistral" | "deepseek" | "arch" | "claude" => { - ProviderInstance::OpenAI(OpenAIProvider) - } - // TODO: Add other providers when implemented - // "claude" | "anthropic" => ProviderInstance::Anthropic(AnthropicProvider), - // "mistral" => ProviderInstance::Mistral(MistralProvider), - _ => { - // Default to OpenAI for unknown providers - ProviderInstance::OpenAI(OpenAIProvider) - } - } - } - - /// Parse request from bytes using the appropriate provider - pub fn parse_request(&self, bytes: &[u8]) -> Result> { - match self { - ProviderInstance::OpenAI(_) => { - ChatCompletionsRequest::try_from_bytes(bytes).map_err(|e| Box::new(e) as Box) - } - // TODO: Add other provider cases when implemented - } - } - - /// Parse response from bytes using the appropriate provider - pub fn parse_response(&self, bytes: &[u8], provider: &crate::Provider) -> Result> { - match self { - ProviderInstance::OpenAI(_) => { - ChatCompletionsResponse::try_from_bytes(bytes, provider).map_err(|e| Box::new(e) as Box) - } - // TODO: Add other provider cases when implemented - } - } - - /// Parse streaming response from bytes using the appropriate provider - pub fn parse_streaming_response(&self, bytes: &[u8], provider: &crate::Provider) -> Result> { - match self { - ProviderInstance::OpenAI(_) => { - OpenAIStreamingResponse::try_from_bytes(bytes, provider).map_err(|e| Box::new(e) as Box) - } - // TODO: Add other provider cases when implemented - } - } -} - -impl ProviderInterface for ProviderInstance { - type Request = ChatCompletionsRequest; - type Response = ChatCompletionsResponse; - type StreamingResponse = OpenAIStreamingResponse; - type Usage = Usage; -} diff --git a/crates/hermesllm/src/providers/traits.rs b/crates/hermesllm/src/providers/traits.rs index c845e57e..8bc8d519 100644 --- a/crates/hermesllm/src/providers/traits.rs +++ b/crates/hermesllm/src/providers/traits.rs @@ -4,7 +4,15 @@ //! handling of LLM requests and responses in the gateway. use std::error::Error; -use crate::Provider; + +/// Conversion mode for provider requests/responses +#[derive(Debug, Clone, Copy)] +pub enum ConversionMode { + /// Compatible: Convert between different provider formats to ensure compatibility + Compatible, + /// Passthrough: Pass requests/responses through with minimal modification + Passthrough, +} /// Trait for provider-specific request types pub trait ProviderRequest: Sized { @@ -14,7 +22,7 @@ pub trait ProviderRequest: Sized { fn try_from_bytes(bytes: &[u8]) -> Result; /// Convert to provider-specific format - fn to_provider_bytes(&self, provider: Provider) -> Result, Self::Error>; + fn to_provider_bytes(&self, provider: super::ProviderId, mode: ConversionMode) -> Result, Self::Error>; /// Extract the model name from the request fn extract_model(&self) -> &str; @@ -42,7 +50,7 @@ pub trait ProviderResponse: Sized { type Usage: TokenUsage; /// Parse response from raw bytes - fn try_from_bytes(bytes: &[u8], provider: &Provider) -> Result; + fn try_from_bytes(bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result; /// Get usage information if available fn usage(&self) -> Option<&Self::Usage>; @@ -62,7 +70,7 @@ pub trait StreamingResponse: Iterator> + type Chunk: StreamChunk; /// Parse streaming response from raw bytes - fn try_from_bytes(bytes: &[u8], provider: &Provider) -> Result; + fn try_from_bytes(bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result; } /// Main provider interface trait @@ -71,4 +79,20 @@ pub trait ProviderInterface { type Response: ProviderResponse; type StreamingResponse: StreamingResponse; type Usage: TokenUsage; + + /// Check if this provider has a compatible API with the client request + fn has_compatible_api(&self, api_path: &str) -> bool; + + /// Get the interface implementation for this provider + /// passthrough: if true, use provider-specific format; if false, use compatible format + fn get_interface(&self, passthrough: bool) -> ConversionMode { + if passthrough { + ConversionMode::Passthrough + } else { + ConversionMode::Compatible + } + } + + /// Get supported API endpoints for this provider + fn supported_apis(&self) -> Vec<&'static str>; } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index fa967397..0beb18d5 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -10,9 +10,7 @@ use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; -use hermesllm::{ - Provider, ProviderInstance, ProviderRequest, ProviderResponse, StreamChunk, TokenUsage, -}; +use hermesllm::{ConversionMode, Provider, ProviderId, ProviderRequest}; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::hostcalls::get_current_time; @@ -76,8 +74,8 @@ impl StreamContext { .expect("the provider should be set when asked for it") } - fn get_provider_instance(&self) -> ProviderInstance { - self.llm_provider().create_provider_instance() + fn get_provider(&self) -> Provider { + self.llm_provider().create_provider() } fn select_llm_provider(&mut self) { @@ -295,9 +293,9 @@ impl HttpContext for StreamContext { } }; - let provider_instance = self.get_provider_instance(); + let provider = self.get_provider(); - let mut deserialized_body = match provider_instance.parse_request(&body_bytes) { + let mut deserialized_body = match provider.parse_request(&body_bytes) { Ok(deserialized) => deserialized, Err(e) => { debug!( @@ -356,10 +354,11 @@ impl HttpContext for StreamContext { } let llm_provider_str = self.llm_provider().provider_interface.to_string(); - let hermes_llm_provider = Provider::from(llm_provider_str.as_str()); + let hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str()); // convert chat completion request to llm provider specific request - let deserialized_body_bytes = match deserialized_body.to_provider_bytes(hermes_llm_provider) + let deserialized_body_bytes = match deserialized_body + .to_provider_bytes(hermes_llm_provider_id, ConversionMode::Compatible) { Ok(bytes) => bytes, Err(e) => { @@ -529,42 +528,16 @@ impl HttpContext for StreamContext { } let llm_provider_str = self.llm_provider().provider_interface.to_string(); - let hermes_llm_provider = Provider::from(llm_provider_str.as_str()); + let _provider_id = ProviderId::from(llm_provider_str.as_str()); if self.streaming_response { - // Use the provider instance to parse streaming response - let provider_instance = self.get_provider_instance(); + // TODO: Implement streaming response parsing with new provider structure + warn!( + "Streaming response parsing not yet fully implemented with new provider structure" + ); - let streaming_events = - match provider_instance.parse_streaming_response(&body, &hermes_llm_provider) { - Ok(events) => events, - Err(e) => { - warn!( - "could not parse response: {}, body str: {}", - e, - String::from_utf8_lossy(&body) - ); - return Action::Continue; - } - }; - - for event_result in streaming_events { - match event_result { - Ok(event) => { - if let Some(usage) = event.usage() { - self.response_tokens += usage.completion_tokens(); - } - } - Err(e) => { - warn!("error in response event: {}", e); - continue; - } - } - } - - // Compute TTFT if not already recorded + // For now, just compute TTFT and continue if self.ttft_duration.is_none() { - // if let Some(start_time) = self.start_time { let current_time = get_current_time().unwrap(); self.ttft_time = Some(current_time_ns()); match current_time.duration_since(self.start_time) { @@ -584,9 +557,9 @@ impl HttpContext for StreamContext { } } else { debug!("non streaming response"); - let provider_instance = self.get_provider_instance(); - let response = match provider_instance.parse_response(&body, &hermes_llm_provider) { - Ok(de) => de, + let provider = self.get_provider(); + let _response = match provider.parse_response(&body, ConversionMode::Compatible) { + Ok(response_box) => response_box, Err(e) => { warn!( "could not parse response: {}, body str: {}", @@ -606,9 +579,9 @@ impl HttpContext for StreamContext { } }; - if let Some(usage) = response.usage() { - self.response_tokens += usage.completion_tokens(); - } + // TODO: Extract usage information from the response box + // For now, we'll skip this until we have a better way to handle Any types + warn!("Response token counting not yet implemented with new provider structure"); } debug!(