diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 15c77bed..8815a369 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -5,7 +5,10 @@ use std::collections::HashMap; use std::fmt::Display; use thiserror::Error; -use crate::{providers::ProviderRequestError, ConversionMode, ProviderRequest}; + + +use crate::providers::request::{ProviderRequest, ProviderRequestError}; +use crate::providers::response::{ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, TokenUsage}; use super::ApiDefinition; // ============================================================================ @@ -127,8 +130,6 @@ pub struct Message { pub tool_call_id: Option, } - - #[skip_serializing_none] #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ResponseMessage { @@ -449,9 +450,92 @@ pub struct StreamOptions { pub include_usage: Option, } -/// ============================================================================ -/// OpenAI Provider Request Wrapper -/// ============================================================================ +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelDetail { + pub id: String, + pub object: String, + pub created: usize, + pub owned_by: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ModelObject { + #[serde(rename = "list")] + List, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Models { + pub object: ModelObject, + pub data: Vec, +} + + +// Error type for streaming operations +#[derive(Debug, thiserror::Error)] +pub enum OpenAIStreamError { + #[error("JSON parsing error: {0}")] + JsonError(#[from] serde_json::Error), + #[error("UTF-8 parsing error: {0}")] + Utf8Error(#[from] std::str::Utf8Error), + #[error("Invalid streaming data: {0}")] + InvalidStreamingData(String), +} + +#[derive(Debug, Error)] +pub enum OpenAIError { + #[error("json error: {0}")] + JsonParseError(#[from] serde_json::Error), + #[error("utf8 parsing error: {0}")] + Utf8Error(#[from] std::str::Utf8Error), + #[error("invalid streaming data err {source}, data: {data}")] + InvalidStreamingData { + source: serde_json::Error, + data: String, + }, + #[error("unsupported provider: {provider}")] + UnsupportedProvider { provider: String }, +} + +// ============================================================================ +/// Trait Implementations +/// =========================================================================== + + +/// Parameterized conversion for ChatCompletionsRequest +impl TryFrom<&[u8]> for ChatCompletionsRequest { + type Error = OpenAIStreamError; + + fn try_from(bytes: &[u8]) -> Result { + serde_json::from_slice(bytes).map_err(OpenAIStreamError::from) + } +} + +/// Parameterized conversion for ChatCompletionsResponse +impl TryFrom<&[u8]> for ChatCompletionsResponse { + type Error = OpenAIStreamError; + + fn try_from(bytes: &[u8]) -> Result { + serde_json::from_slice(bytes).map_err(OpenAIStreamError::from) + } +} + +/// Implementation of TokenUsage for OpenAI Usage type +impl TokenUsage for Usage { + fn completion_tokens(&self) -> usize { + self.completion_tokens as usize + } + + fn prompt_tokens(&self) -> usize { + self.prompt_tokens as usize + } + + fn total_tokens(&self) -> usize { + self.total_tokens as usize + } +} + +/// Implementation of ProviderRequest for ChatCompletionsRequest impl ProviderRequest for ChatCompletionsRequest { fn model(&self) -> &str { &self.model @@ -493,144 +577,29 @@ impl ProviderRequest for ChatCompletionsRequest { }) } - fn to_provider_bytes(&self, mode: ConversionMode) -> Result, ProviderRequestError> { - match mode { - ConversionMode::Compatible | ConversionMode::Passthrough => { - serde_json::to_vec(&self).map_err(|e| ProviderRequestError { - message: format!("Failed to serialize OpenAI request: {}", e), - source: Some(Box::new(e)), - }) - } - } + fn to_bytes(&self) -> Result, ProviderRequestError> { + serde_json::to_vec(&self).map_err(|e| ProviderRequestError { + message: format!("Failed to serialize OpenAI request: {}", e), + source: Some(Box::new(e)), + }) } } -// ============================================================================ -// STREAMING SUPPORT -// ============================================================================ - -use crate::providers::traits::{ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, TokenUsage}; - -// Direct implementation of ProviderResponse on ChatCompletionsResponse +/// Implementation of ProviderResponse for ChatCompletionsResponse impl ProviderResponse for ChatCompletionsResponse { fn usage(&self) -> Option<&dyn TokenUsage> { Some(&self.usage) } -} -// ============================================================================ -// PARAMETERIZED CONVERSIONS FOR PROVIDER FUNCTIONS -// ============================================================================ - -use crate::providers::ProviderId; - -/// Parameterized conversion for ChatCompletionsRequest -impl TryFrom<(&[u8], &ProviderId)> for ChatCompletionsRequest { - type Error = OpenAIStreamError; - - fn try_from((bytes, _provider_id): (&[u8], &ProviderId)) -> Result { - serde_json::from_slice(bytes).map_err(OpenAIStreamError::from) + fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { + Some(( + self.usage.prompt_tokens(), + self.usage.completion_tokens(), + self.usage.total_tokens(), + )) } } -/// Parameterized conversion for ChatCompletionsResponse -impl TryFrom<(&[u8], &ProviderId)> for ChatCompletionsResponse { - type Error = OpenAIStreamError; - - fn try_from((bytes, _provider_id): (&[u8], &ProviderId)) -> Result { - serde_json::from_slice(bytes).map_err(OpenAIStreamError::from) - } -} - -// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse -impl ProviderStreamResponse for ChatCompletionsStreamResponse { - fn content_delta(&self) -> Option<&str> { - self.choices - .first() - .and_then(|choice| choice.delta.content.as_deref()) - } - - fn is_final(&self) -> bool { - self.choices - .first() - .map(|choice| choice.finish_reason.is_some()) - .unwrap_or(false) - } - - fn role(&self) -> Option<&str> { - self.choices - .first() - .and_then(|choice| choice.delta.role.as_ref().map(|r| match r { - Role::System => "system", - Role::User => "user", - Role::Assistant => "assistant", - Role::Tool => "tool", - })) - } -} - -// Implementation of TokenUsage for OpenAI Usage type -impl TokenUsage for Usage { - fn completion_tokens(&self) -> usize { - self.completion_tokens as usize - } - - fn prompt_tokens(&self) -> usize { - self.prompt_tokens as usize - } - - fn total_tokens(&self) -> usize { - self.total_tokens as usize - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelDetail { - pub id: String, - pub object: String, - pub created: usize, - pub owned_by: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum ModelObject { - #[serde(rename = "list")] - List, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Models { - pub object: ModelObject, - pub data: Vec, -} - -// Error type for streaming operations -#[derive(Debug, thiserror::Error)] -pub enum OpenAIStreamError { - #[error("JSON parsing error: {0}")] - JsonError(#[from] serde_json::Error), - #[error("UTF-8 parsing error: {0}")] - Utf8Error(#[from] std::str::Utf8Error), - #[error("Invalid streaming data: {0}")] - InvalidStreamingData(String), -} - -#[derive(Debug, Error)] -pub enum OpenAIError { - #[error("json error: {0}")] - JsonParseError(#[from] serde_json::Error), - #[error("utf8 parsing error: {0}")] - Utf8Error(#[from] std::str::Utf8Error), - #[error("invalid streaming data err {source}, data: {data}")] - InvalidStreamingData { - source: serde_json::Error, - data: String, - }, - #[error("unsupported provider: {provider}")] - UnsupportedProvider { provider: String }, -} - - /// SSE-based streaming iterator for OpenAI chat completions /// Implements ProviderStreamResponseIter directly pub struct SseChatCompletionIter @@ -696,6 +665,34 @@ where } +// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse +impl ProviderStreamResponse for ChatCompletionsStreamResponse { + fn content_delta(&self) -> Option<&str> { + self.choices + .first() + .and_then(|choice| choice.delta.content.as_deref()) + } + + fn is_final(&self) -> bool { + self.choices + .first() + .map(|choice| choice.finish_reason.is_some()) + .unwrap_or(false) + } + + fn role(&self) -> Option<&str> { + self.choices + .first() + .and_then(|choice| choice.delta.role.as_ref().map(|r| match r { + Role::System => "system", + Role::User => "user", + Role::Assistant => "assistant", + Role::Tool => "tool", + })) + } +} + + #[cfg(test)] mod tests { use super::*; diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index feace88a..ad9b3e33 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -6,13 +6,10 @@ pub mod apis; pub mod clients; // Re-export important types and traits -pub use providers::{ - ProviderId, ConversionMode, - ProviderRequest, ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, - TokenUsage, - try_request_from_bytes, try_response_from_bytes, try_streaming_from_bytes, - has_compatible_api, supported_apis -}; +pub use providers::request::{ProviderRequestType, ProviderRequest, ProviderRequestError}; +pub use providers::response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, ProviderResponseError, TokenUsage, try_streaming_from_bytes}; +pub use providers::id::ProviderId; +pub use providers::adapters::{has_compatible_api, supported_apis}; #[cfg(test)] mod tests { @@ -58,7 +55,7 @@ mod tests { ] }"#; - let result = try_request_from_bytes(json_request.as_bytes(), &ProviderId::OpenAI); + let result: Result = ProviderRequestType::try_from(json_request.as_bytes()); assert!(result.is_ok()); let request = result.unwrap(); @@ -74,7 +71,7 @@ mod tests { data: [DONE] "#; - let result = try_streaming_from_bytes(sse_data.as_bytes(), &ProviderId::OpenAI, ConversionMode::Passthrough); + let result = try_streaming_from_bytes(sse_data.as_bytes(), &ProviderId::OpenAI); assert!(result.is_ok()); let mut streaming_response = result.unwrap(); diff --git a/crates/hermesllm/src/providers/adapters.rs b/crates/hermesllm/src/providers/adapters.rs new file mode 100644 index 00000000..a001cf09 --- /dev/null +++ b/crates/hermesllm/src/providers/adapters.rs @@ -0,0 +1,39 @@ +use crate::providers::id::ProviderId; + +#[derive(Debug, Clone)] +pub enum AdapterType { + OpenAICompatible, + // Future: Claude, Gemini, etc. +} + +/// Provider adapter configuration +#[derive(Debug, Clone)] +pub struct ProviderConfig { + pub supported_apis: &'static [&'static str], + pub adapter_type: AdapterType, +} + +/// Check if provider has compatible API +pub fn has_compatible_api(provider_id: &ProviderId, api_path: &str) -> bool { + let config = get_provider_config(provider_id); + config.supported_apis.iter().any(|&supported| supported == api_path) +} + +/// Get supported APIs for provider +pub fn supported_apis(provider_id: &ProviderId) -> Vec<&'static str> { + let config = get_provider_config(provider_id); + config.supported_apis.to_vec() +} + +/// Get provider configuration +pub fn get_provider_config(provider_id: &ProviderId) -> ProviderConfig { + match provider_id { + ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek + | ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => { + ProviderConfig { + supported_apis: &["/v1/chat/completions"], + adapter_type: AdapterType::OpenAICompatible, + } + } + } +} diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs new file mode 100644 index 00000000..2c0c494e --- /dev/null +++ b/crates/hermesllm/src/providers/id.rs @@ -0,0 +1,45 @@ +use std::fmt::Display; + +/// Provider identifier enum - simple enum for identifying providers +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ProviderId { + OpenAI, + Mistral, + Deepseek, + Groq, + Gemini, + Claude, + GitHub, + Arch, +} + +impl From<&str> for ProviderId { + fn from(value: &str) -> Self { + match value.to_lowercase().as_str() { + "openai" => ProviderId::OpenAI, + "mistral" => ProviderId::Mistral, + "deepseek" => ProviderId::Deepseek, + "groq" => ProviderId::Groq, + "gemini" => ProviderId::Gemini, + "claude" => ProviderId::Claude, + "github" => ProviderId::GitHub, + "arch" => ProviderId::Arch, + _ => panic!("Unknown provider: {}", value), + } + } +} + +impl Display for ProviderId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ProviderId::OpenAI => write!(f, "OpenAI"), + ProviderId::Mistral => write!(f, "Mistral"), + ProviderId::Deepseek => write!(f, "Deepseek"), + ProviderId::Groq => write!(f, "Groq"), + ProviderId::Gemini => write!(f, "Gemini"), + ProviderId::Claude => write!(f, "Claude"), + ProviderId::GitHub => write!(f, "GitHub"), + ProviderId::Arch => write!(f, "Arch"), + } + } +} diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs index 1b69c777..0f0574c3 100644 --- a/crates/hermesllm/src/providers/mod.rs +++ b/crates/hermesllm/src/providers/mod.rs @@ -2,57 +2,13 @@ //! //! This module contains provider-specific implementations that handle //! request/response conversion for different LLM service APIs. +//! +pub mod id; +pub mod request; +pub mod response; +pub mod adapters; -pub mod traits; -pub mod openai; - -// Re-export the main interfaces -pub use traits::*; -// Note: OpenAIProvider has been deprecated in favor of function-based approach -// OpenAI functionality is accessed through openai::builder and openai::types modules - -use std::fmt::Display; - -/// Provider identifier enum - simple enum for identifying providers -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum ProviderId { - OpenAI, - Mistral, - Deepseek, - Groq, - Gemini, - Claude, - GitHub, - Arch, -} - -impl From<&str> for ProviderId { - fn from(value: &str) -> Self { - match value.to_lowercase().as_str() { - "openai" => ProviderId::OpenAI, - "mistral" => ProviderId::Mistral, - "deepseek" => ProviderId::Deepseek, - "groq" => ProviderId::Groq, - "gemini" => ProviderId::Gemini, - "claude" => ProviderId::Claude, - "github" => ProviderId::GitHub, - "arch" => ProviderId::Arch, - _ => panic!("Unknown provider: {}", value), - } - } -} - -impl Display for ProviderId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ProviderId::OpenAI => write!(f, "OpenAI"), - ProviderId::Mistral => write!(f, "Mistral"), - ProviderId::Deepseek => write!(f, "Deepseek"), - ProviderId::Groq => write!(f, "Groq"), - ProviderId::Gemini => write!(f, "Gemini"), - ProviderId::Claude => write!(f, "Claude"), - ProviderId::GitHub => write!(f, "GitHub"), - ProviderId::Arch => write!(f, "Arch"), - } - } -} +pub use id::ProviderId; +pub use request::{ProviderRequestType, ProviderRequest, ProviderRequestError} ; +pub use response::{ProviderResponseType, ProviderStreamResponseType, ProviderResponse, ProviderStreamResponse, TokenUsage }; +pub use adapters::*; diff --git a/crates/hermesllm/src/providers/openai/mod.rs b/crates/hermesllm/src/providers/openai/mod.rs deleted file mode 100644 index d82d5ab0..00000000 --- a/crates/hermesllm/src/providers/openai/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -// Re-export the main types and builder functionality -pub use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse}; - -// Note: The OpenAIProvider struct has been deprecated in favor of the function-based approach in traits.rs -// All provider functionality is now accessed through try_request_from_bytes, try_response_from_bytes, etc. diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs new file mode 100644 index 00000000..577be2b7 --- /dev/null +++ b/crates/hermesllm/src/providers/request.rs @@ -0,0 +1,124 @@ + +use crate::apis::openai::ChatCompletionsRequest; +use super::{ProviderId, get_provider_config, AdapterType}; +use std::error::Error; +use std::fmt; +pub enum ProviderRequestType { + ChatCompletionsRequest(ChatCompletionsRequest), + //MessagesRequest(MessagesRequest), + //add more request types here +} + +impl TryFrom<&[u8]> for ProviderRequestType { + type Error = std::io::Error; + + // if passing bytes without provider id we assume the request is in OpenAI format + fn try_from(bytes: &[u8]) -> Result { + let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) + } +} + +impl TryFrom<(&[u8], &ProviderId)> for ProviderRequestType { + type Error = std::io::Error; + + fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result { + let config = get_provider_config(provider_id); + match config.adapter_type { + AdapterType::OpenAICompatible => { + let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) + } + // Future: handle other adapter types like Claude + } + } +} + +pub trait ProviderRequest: Send + Sync { + /// Extract the model name from the request + fn model(&self) -> &str; + + /// Set the model name for the request + fn set_model(&mut self, model: String); + + /// Check if this is a streaming request + fn is_streaming(&self) -> bool; + + /// Set streaming options (e.g., include_usage) + fn set_streaming_options(&mut self); + + /// Extract text content from messages for token counting + fn extract_messages_text(&self) -> String; + + /// Extract the user message for tracing/logging purposes + fn extract_user_message(&self) -> Option; + + /// Convert the request to bytes for transmission + fn to_bytes(&self) -> Result, ProviderRequestError>; +} + +impl ProviderRequest for ProviderRequestType { + fn model(&self) -> &str { + match self { + Self::ChatCompletionsRequest(r) => r.model(), + } + } + + fn set_model(&mut self, model: String) { + match self { + Self::ChatCompletionsRequest(r) => r.set_model(model), + } + } + + fn is_streaming(&self) -> bool { + match self { + Self::ChatCompletionsRequest(r) => r.is_streaming(), + } + } + + fn set_streaming_options(&mut self) { + match self { + Self::ChatCompletionsRequest(r) => r.set_streaming_options(), + } + } + + fn extract_messages_text(&self) -> String { + match self { + Self::ChatCompletionsRequest(r) => r.extract_messages_text(), + } + } + + fn extract_user_message(&self) -> Option { + match self { + Self::ChatCompletionsRequest(r) => r.extract_user_message(), + } + } + + fn to_bytes(&self) -> Result, ProviderRequestError> { + match self { + Self::ChatCompletionsRequest(r) => r.to_bytes(), + } + } +} + + +/// Error types for provider operations +#[derive(Debug)] +pub struct ProviderRequestError { + pub message: String, + pub source: Option>, +} + +impl fmt::Display for ProviderRequestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Provider request error: {}", self.message) + } +} + +impl Error for ProviderRequestError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) + } +} diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs new file mode 100644 index 00000000..64e57e63 --- /dev/null +++ b/crates/hermesllm/src/providers/response.rs @@ -0,0 +1,142 @@ +use std::error::Error; +use std::fmt; + +use crate::apis::openai::ChatCompletionsResponse; +use crate::apis::openai::ChatCompletionsStreamResponse; +use crate::providers::id::ProviderId; +use crate::providers::adapters::{get_provider_config, AdapterType}; + +pub enum ProviderResponseType { + ChatCompletionsResponse(ChatCompletionsResponse), + //MessagesResponse(MessagesResponse), +} + +pub enum ProviderStreamResponseType { + ChatCompletionsStreamResponse(ChatCompletionsStreamResponse), + //MessagesStreamResponse(MessagesStreamMessage), +} + +impl TryFrom<(&[u8], ProviderId)> for ProviderResponseType { + type Error = std::io::Error; + + fn try_from((bytes, provider_id): (&[u8], ProviderId)) -> Result { + let config = get_provider_config(&provider_id); + match config.adapter_type { + AdapterType::OpenAICompatible => { + let chat_completions_response: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderResponseType::ChatCompletionsResponse(chat_completions_response)) + } + // Future: handle other adapter types like Claude + } + } +} +pub trait ProviderResponse: Send + Sync { + /// Get usage information if available - returns dynamic trait object + fn usage(&self) -> Option<&dyn TokenUsage>; + + /// Extract token counts for metrics + fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { + self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens())) + } +} + +pub trait ProviderStreamResponse: Send + Sync { + /// Get the content delta for this chunk + fn content_delta(&self) -> Option<&str>; + + /// Check if this is the final chunk in the stream + fn is_final(&self) -> bool; + + /// Get role information if available + fn role(&self) -> Option<&str>; +} + +/// Trait for streaming response iterators +pub trait ProviderStreamResponseIter: Iterator, Box>> + Send + Sync { + +} + + +impl ProviderResponse for ProviderResponseType { + fn usage(&self) -> Option<&dyn TokenUsage> { + match self { + ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(), + // Future: ProviderResponseType::MessagesResponse(resp) => resp.usage(), + } + } + + fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { + match self { + ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(), + // Future: ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(), + } + } +} + +impl ProviderStreamResponse for ProviderStreamResponseType { + fn content_delta(&self) -> Option<&str> { + match self { + ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.content_delta(), + // Future: ProviderStreamResponseType::MessagesStreamResponse(resp) => resp.content_delta(), + } + } + + fn is_final(&self) -> bool { + match self { + ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.is_final(), + // Future: ProviderStreamResponseType::MessagesStreamResponse(resp) => resp.is_final(), + } + } + + fn role(&self) -> Option<&str> { + match self { + ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.role(), + // Future: ProviderStreamResponseType::MessagesStreamResponse(resp) => resp.role(), + } + } +} + +/// Trait for token usage information +pub trait TokenUsage { + fn completion_tokens(&self) -> usize; + fn prompt_tokens(&self) -> usize; + fn total_tokens(&self) -> usize; +} + + +#[derive(Debug)] +pub struct ProviderResponseError { + pub message: String, + pub source: Option>, +} + + +impl fmt::Display for ProviderResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Provider response error: {}", self.message) + } +} + +impl Error for ProviderResponseError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) + } +} + +/// Create streaming response using provider ID - returns clean ProviderStreamResponseIter trait object +pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result, Box> { + let config = get_provider_config(provider_id); + + match config.adapter_type { + AdapterType::OpenAICompatible => { + // Parse SSE (Server-Sent Events) streaming data + let s = std::str::from_utf8(bytes)?; + let lines: Vec = s.lines().map(|line| line.to_string()).collect(); + let iter = crate::apis::openai::SseChatCompletionIter::new(lines.into_iter()); + + // Return the iterator directly - it implements ProviderStreamResponseIter + Ok(Box::new(iter)) + } + } +} diff --git a/crates/hermesllm/src/providers/traits.rs b/crates/hermesllm/src/providers/traits.rs deleted file mode 100644 index 5148ed7a..00000000 --- a/crates/hermesllm/src/providers/traits.rs +++ /dev/null @@ -1,247 +0,0 @@ -//! Provider traits for generic request/response handling -//! -//! This module defines the core traits that enable provider-agnostic -//! handling of LLM requests and responses in the gateway. - -use std::error::Error; -use std::fmt; - -/// Trait for provider-specific request types -pub trait ProviderRequest: Send + Sync { - /// Extract the model name from the request - fn model(&self) -> &str; - - /// Set the model name for the request - fn set_model(&mut self, model: String); - - /// Check if this is a streaming request - fn is_streaming(&self) -> bool; - - /// Set streaming options (e.g., include_usage) - fn set_streaming_options(&mut self); - - /// Extract text content from messages for token counting - fn extract_messages_text(&self) -> String; - - /// Extract the user message for tracing/logging purposes - fn extract_user_message(&self) -> Option; - - /// Convert to provider-specific format - fn to_provider_bytes(&self, mode: ConversionMode) -> Result, ProviderRequestError>; -} - -/// Trait for provider-specific response types -pub trait ProviderResponse: Send + Sync { - /// Get usage information if available - returns dynamic trait object - fn usage(&self) -> Option<&dyn TokenUsage>; - - /// Extract token counts for metrics - fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { - self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens())) - } -} - -/// Trait for provider-specific streaming response types -pub trait ProviderStreamResponse: Send + Sync { - /// Get the content delta for this chunk - fn content_delta(&self) -> Option<&str>; - - /// Check if this is the final chunk in the stream - fn is_final(&self) -> bool; - - /// Get role information if available - fn role(&self) -> Option<&str>; -} - -/// Trait for streaming response iterators -pub trait ProviderStreamResponseIter: Iterator, Box>> + Send + Sync { - // No additional methods needed - just the Iterator constraint with proper bounds -} - -/// Conversion mode for provider requests/responses -#[derive(Debug, Clone, Copy)] -pub enum ConversionMode { - /// Compatible: Convert between different provider formats to ensure compatibility - Compatible, - /// Passthrough: Pass requests/responses through with minimal modification - Passthrough, -} - -/// Trait for token usage information -pub trait TokenUsage { - fn completion_tokens(&self) -> usize; - fn prompt_tokens(&self) -> usize; - fn total_tokens(&self) -> usize; -} - -// ============================================================================ -// PROVIDER FUNCTIONS - NO TRAITS, JUST PARAMETERIZED CONVERSION -// ============================================================================ -// -// ARCHITECTURAL DECISION: Function-based Provider API -// -// We chose this function-based approach over the original ProviderInterface trait -// for several critical reasons: -// -// 1. TRAIT OBJECT LIMITATION: -// - The original ProviderInterface had associated types (Request, Response, etc.) -// - Traits with associated types cannot be used as trait objects (Box) -// - This prevented dynamic provider selection at runtime based on request headers -// - Error: "the trait `ProviderInterface` cannot be made into an object" -// -// 2. DYNAMIC PROVIDER SELECTION REQUIREMENT: -// - The gateway needs to select providers dynamically based on incoming headers -// - Cannot know provider type at compile time - must dispatch at runtime -// - Need ability to return generic trait objects that work polymorphically -// -// 3. WRAPPER TYPE ELIMINATION: -// - Original design required wrapper types like OpenAIRequestWrapper, OpenAIResponseWrapper -// - User wanted to implement traits directly on concrete types (ChatCompletionsRequest, etc.) -// - Function-based approach allows direct trait implementations without wrappers -// -// 4. PARAMETERIZED CONVERSION PATTERN: -// - Follows existing codebase pattern: TryFrom<(&[u8], &ProviderId)> -// - Enables runtime provider selection while maintaining type safety -// - Single implementation can handle multiple OpenAI-compatible providers -// -// 5. TYPE ERASURE FOR GENERIC INTERFACE: -// - Functions return Box - works as trait objects -// - stream_context.rs can work with generic interfaces without knowing concrete types -// - Maintains polymorphism while enabling dynamic dispatch -// ============================================================================ - -use crate::ProviderId; - -// ============================================================================ -// PROVIDER ADAPTER REGISTRY (Organizational Enhancement) -// ============================================================================ - -/// Provider adapter configuration -#[derive(Debug, Clone)] -pub struct ProviderConfig { - pub supported_apis: &'static [&'static str], - pub adapter_type: AdapterType, -} - -#[derive(Debug, Clone)] -pub enum AdapterType { - OpenAICompatible, - // Future: Claude, Gemini, etc. -} - -/// Get provider configuration -pub fn get_provider_config(provider_id: &ProviderId) -> ProviderConfig { - match provider_id { - ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek - | ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => { - ProviderConfig { - supported_apis: &["/v1/chat/completions"], - adapter_type: AdapterType::OpenAICompatible, - } - } - } -} - -/// Parse request from bytes using provider ID - returns generic ProviderRequest trait object -pub fn try_request_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result, ProviderRequestError> { - let config = get_provider_config(provider_id); - - match config.adapter_type { - AdapterType::OpenAICompatible => { - let request = crate::apis::openai::ChatCompletionsRequest::try_from((bytes, provider_id)) - .map_err(|e| ProviderRequestError { - message: format!("Failed to parse request: {}", e), - source: Some(Box::new(e)), - })?; - - // Return as trait object - this enables polymorphic usage - // ChatCompletionsRequest implements ProviderRequest directly (no wrapper needed) - Ok(Box::new(request) as Box) - } - } -} - -/// Parse response from bytes using provider ID - returns generic ProviderResponse trait object -pub fn try_response_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result, ProviderResponseError> { - let config = get_provider_config(provider_id); - - match config.adapter_type { - AdapterType::OpenAICompatible => { - // Parameterized conversion allows provider-specific response parsing - let response = crate::apis::openai::ChatCompletionsResponse::try_from((bytes, provider_id)) - .map_err(|e| ProviderResponseError { - message: format!("Failed to parse response: {}", e), - source: Some(Box::new(e)), - })?; - - // ChatCompletionsResponse implements ProviderResponse directly - no wrapper needed! - Ok(Box::new(response) as Box) - } - } -} - -/// Create streaming response using provider ID - returns clean ProviderStreamResponseIter trait object -pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result, Box> { - let config = get_provider_config(provider_id); - - match config.adapter_type { - AdapterType::OpenAICompatible => { - // Parse SSE (Server-Sent Events) streaming data - let s = std::str::from_utf8(bytes)?; - let lines: Vec = s.lines().map(|line| line.to_string()).collect(); - let iter = crate::apis::openai::SseChatCompletionIter::new(lines.into_iter()); - - // Return the iterator directly - it implements ProviderStreamResponseIter - Ok(Box::new(iter)) - } - } -} - -/// Check if provider has compatible API -pub fn has_compatible_api(provider_id: &ProviderId, api_path: &str) -> bool { - let config = get_provider_config(provider_id); - config.supported_apis.iter().any(|&supported| supported == api_path) -} - -/// Get supported APIs for provider -pub fn supported_apis(provider_id: &ProviderId) -> Vec<&'static str> { - let config = get_provider_config(provider_id); - config.supported_apis.to_vec() -} - -/// Error types for provider operations -#[derive(Debug)] -pub struct ProviderRequestError { - pub message: String, - pub source: Option>, -} - -#[derive(Debug)] -pub struct ProviderResponseError { - pub message: String, - pub source: Option>, -} - -impl fmt::Display for ProviderRequestError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Provider request error: {}", self.message) - } -} - -impl fmt::Display for ProviderResponseError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Provider response error: {}", self.message) - } -} - -impl Error for ProviderRequestError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) - } -} - -impl Error for ProviderResponseError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) - } -} diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 0d981297..77b37a46 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -11,8 +11,8 @@ use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; use hermesllm::{ - try_request_from_bytes, try_response_from_bytes, try_streaming_from_bytes, ConversionMode, - ProviderId, + try_streaming_from_bytes, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponse, + ProviderResponseType, }; use http::StatusCode; use log::{debug, info, warn}; @@ -300,20 +300,21 @@ impl HttpContext for StreamContext { let provider_id = self.get_provider_id(); - let mut deserialized_body = match try_request_from_bytes(&body_bytes, &provider_id) { - Ok(deserialized) => deserialized, - Err(e) => { - debug!( - "on_http_request_body: request body: {}", - String::from_utf8_lossy(&body_bytes) - ); - self.send_server_error( - ServerError::LogicError(format!("Request parsing error: {}", e)), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Pause; - } - }; + let mut deserialized_body = + match ProviderRequestType::try_from((&body_bytes[..], &provider_id)) { + Ok(deserialized) => deserialized, + Err(e) => { + debug!( + "on_http_request_body: request body: {}", + String::from_utf8_lossy(&body_bytes) + ); + self.send_server_error( + ServerError::LogicError(format!("Request parsing error: {}", e)), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + }; let model_name = match self.llm_provider.as_ref() { Some(llm_provider) => llm_provider.model.as_ref(), @@ -388,18 +389,17 @@ impl HttpContext for StreamContext { let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str()); // Convert chat completion request to llm provider specific request using provider interface - let deserialized_body_bytes = - match deserialized_body.to_provider_bytes(ConversionMode::Compatible) { - Ok(bytes) => bytes, - Err(e) => { - warn!("Failed to serialize request body: {}", e); - self.send_server_error( - ServerError::LogicError(format!("Request serialization error: {}", e)), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Pause; - } - }; + let deserialized_body_bytes = match deserialized_body.to_bytes() { + Ok(bytes) => bytes, + Err(e) => { + warn!("Failed to serialize request body: {}", e); + self.send_server_error( + ServerError::LogicError(format!("Request serialization error: {}", e)), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + }; self.set_http_request_body(0, body_size, &deserialized_body_bytes); @@ -572,7 +572,7 @@ impl HttpContext for StreamContext { // Since all providers use OpenAI-compatible streaming format let provider_id = self.get_provider_id(); - match try_streaming_from_bytes(&body, &provider_id, ConversionMode::Compatible) { + match try_streaming_from_bytes(&body, &provider_id) { Ok(mut streaming_response) => { // Process each streaming chunk while let Some(chunk_result) = streaming_response.next() { @@ -630,8 +630,8 @@ impl HttpContext for StreamContext { } else { debug!("non streaming response"); let provider_id = self.get_provider_id(); - let response = - match try_response_from_bytes(&body, &provider_id, ConversionMode::Compatible) { + let response: ProviderResponseType = + match ProviderResponseType::try_from((&body[..], provider_id)) { Ok(response) => response, Err(e) => { warn!(