diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 186691dc..fab4948e 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -177,6 +177,18 @@ impl Display for LlmProviderType { } } +impl LlmProviderType { + /// Create a ProviderInstance from this LlmProviderType + /// This is the main method for stream_context to get provider-specific interfaces + pub fn create_provider_instance(&self) -> hermesllm::ProviderInstance { + use hermesllm::ProviderInstance; + + // For now, all providers use OpenAI-compatible APIs + // TODO: Return specific provider instances when implementing different APIs + ProviderInstance::from_name(&self.to_string()) + } +} + #[derive(Serialize, Deserialize, Debug)] pub struct ModelUsagePreference { pub model: String, @@ -252,6 +264,14 @@ impl Display for LlmProvider { } } +impl LlmProvider { + /// Create a ProviderInstance for this LlmProvider + /// This is a convenience method that delegates to the provider_interface + pub fn create_provider_instance(&self) -> hermesllm::ProviderInstance { + self.provider_interface.create_provider_instance() + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Endpoint { pub endpoint: Option, diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 169467a1..2d0a5198 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -5,6 +5,11 @@ pub mod providers; pub mod apis; pub mod clients; +// Re-export important traits +pub use providers::traits::*; +pub use providers::openai::provider::OpenAIProvider; +pub use providers::provider_enum::ProviderInstance; + use std::fmt::Display; pub enum Provider { @@ -34,6 +39,45 @@ impl From<&str> for Provider { } } +impl Provider { + /// Get the API endpoint path for this provider + pub fn api_path(&self) -> &'static str { + match self { + Provider::OpenAI => "/v1/chat/completions", + Provider::Groq => "/openai/v1/chat/completions", // Groq maps to OpenAI-compatible endpoint + Provider::Gemini => "/v1/models", // TODO: Update with correct Gemini path + Provider::Claude => "/v1/messages", // TODO: Update with correct Claude path + Provider::Mistral => "/v1/chat/completions", // Mistral uses OpenAI-compatible API + Provider::Deepseek => "/v1/chat/completions", // DeepSeek uses OpenAI-compatible API + Provider::Arch => "/v1/chat/completions", // Arch gateway endpoint + Provider::Github => "/models", // TODO: Update with correct GitHub models path + } + } + + /// Check if this provider uses OpenAI-compatible API format + pub fn uses_openai_format(&self) -> bool { + match self { + Provider::OpenAI | Provider::Groq | Provider::Mistral | Provider::Deepseek | Provider::Arch => true, + Provider::Gemini | Provider::Claude | Provider::Github => false, // These have their own formats + } + } + + /// Create a provider implementation instance for this provider + pub fn create_provider_instance(&self) -> ProviderInstance { + match self { + Provider::OpenAI => ProviderInstance::OpenAI(OpenAIProvider), + Provider::Groq => ProviderInstance::OpenAI(OpenAIProvider), // Groq uses OpenAI-compatible API + Provider::Mistral => ProviderInstance::OpenAI(OpenAIProvider), // Mistral uses OpenAI-compatible API + Provider::Deepseek => ProviderInstance::OpenAI(OpenAIProvider), // Deepseek uses OpenAI-compatible API + Provider::Arch => ProviderInstance::OpenAI(OpenAIProvider), // Arch gateway uses OpenAI-compatible API + // TODO: Implement specific providers for these when they have different APIs + Provider::Gemini => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible + Provider::Claude => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible + Provider::Github => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible + } + } +} + impl Display for Provider { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs index d8c30873..9e980bb5 100644 --- a/crates/hermesllm/src/providers/mod.rs +++ b/crates/hermesllm/src/providers/mod.rs @@ -1 +1,3 @@ pub mod openai; +pub mod traits; +pub mod provider_enum; diff --git a/crates/hermesllm/src/providers/openai/mod.rs b/crates/hermesllm/src/providers/openai/mod.rs index ab228e50..1cc60f4d 100644 --- a/crates/hermesllm/src/providers/openai/mod.rs +++ b/crates/hermesllm/src/providers/openai/mod.rs @@ -1,2 +1,3 @@ pub mod builder; pub mod types; +pub mod provider; diff --git a/crates/hermesllm/src/providers/openai/provider.rs b/crates/hermesllm/src/providers/openai/provider.rs new file mode 100644 index 00000000..8d98be4a --- /dev/null +++ b/crates/hermesllm/src/providers/openai/provider.rs @@ -0,0 +1,171 @@ +//! OpenAI provider interface implementations + +use crate::apis::openai::*; +use crate::providers::traits::*; +use crate::Provider; + +// Simple error type for OpenAI API operations +#[derive(Debug, thiserror::Error)] +pub enum OpenAIApiError { + #[error("JSON parsing error: {0}")] + JsonError(#[from] serde_json::Error), + #[error("UTF-8 parsing error: {0}")] + Utf8Error(#[from] std::str::Utf8Error), + #[error("Invalid streaming data: {0}")] + InvalidStreamingData(String), + #[error("Request conversion error: {0}")] + RequestConversionError(String), +} + +// ============================================================================ +// OpenAI Provider Definition +// ============================================================================ + +pub struct OpenAIProvider; + +// Create a concrete streaming response type to avoid lifetime issues +pub struct OpenAIStreamingResponse { + lines: Vec, + current_index: usize, +} + +impl OpenAIStreamingResponse { + fn new(data: String) -> Self { + let lines: Vec = data.lines().map(|s| s.to_string()).collect(); + Self { + lines, + current_index: 0, + } + } +} + +impl Iterator for OpenAIStreamingResponse { + type Item = Result; + + fn next(&mut self) -> Option { + while self.current_index < self.lines.len() { + let line = &self.lines[self.current_index]; + self.current_index += 1; + + if let Some(data) = line.strip_prefix("data: ") { + let data = data.trim(); + if data == "[DONE]" { + return None; + } + + if data == r#"{"type": "ping"}"# { + continue; // Skip ping messages + } + + return Some( + serde_json::from_str::(data).map_err(|e| { + OpenAIApiError::InvalidStreamingData(format!("Error parsing: {}, data: {}", e, data)) + }), + ); + } + } + None + } +} + +impl ProviderInterface for OpenAIProvider { + type Request = ChatCompletionsRequest; + type Response = ChatCompletionsResponse; + type StreamingResponse = OpenAIStreamingResponse; + type Usage = Usage; +} + +// ============================================================================ +// Trait Implementations for OpenAI Types +// ============================================================================ + +impl ProviderRequest for ChatCompletionsRequest { + type Error = OpenAIApiError; + + fn try_from_bytes(bytes: &[u8]) -> Result { + let s = std::str::from_utf8(bytes)?; + Ok(serde_json::from_str(s)?) + } + + fn to_provider_bytes(&self, _provider: Provider) -> Result, Self::Error> { + Ok(serde_json::to_vec(self)?) + } + + fn extract_model(&self) -> &str { + &self.model + } + + fn is_streaming(&self) -> bool { + self.stream.unwrap_or_default() + } + + fn set_streaming_options(&mut self) { + if self.stream_options.is_none() { + self.stream_options = Some(StreamOptions { + include_usage: Some(true), + }); + } + } + + fn extract_messages_text(&self) -> String { + self.messages + .iter() + .fold(String::new(), |acc, m| { + acc + " " + &match &m.content { + MessageContent::Text(text) => text.clone(), + MessageContent::Parts(parts) => { + parts.iter().map(|part| match part { + ContentPart::Text { text } => text.clone(), + ContentPart::ImageUrl { .. } => "[Image]".to_string(), + }).collect::>().join(" ") + } + } + }) + } +} + +impl TokenUsage for Usage { + fn completion_tokens(&self) -> usize { + self.completion_tokens as usize + } + + fn prompt_tokens(&self) -> usize { + self.prompt_tokens as usize + } + + fn total_tokens(&self) -> usize { + self.total_tokens as usize + } +} + +impl ProviderResponse for ChatCompletionsResponse { + type Error = OpenAIApiError; + type Usage = Usage; + + fn try_from_bytes(bytes: &[u8], _provider: &Provider) -> Result { + let s = std::str::from_utf8(bytes)?; + Ok(serde_json::from_str(s)?) + } + + fn usage(&self) -> Option<&Self::Usage> { + Some(&self.usage) + } +} + +impl StreamChunk for ChatCompletionsStreamResponse { + type Usage = Usage; + + fn usage(&self) -> Option<&Self::Usage> { + self.usage.as_ref() + } +} + +impl StreamingResponse for OpenAIStreamingResponse { + type Error = OpenAIApiError; + type Chunk = ChatCompletionsStreamResponse; + + fn try_from_bytes(bytes: &[u8], _provider: &Provider) -> Result { + let s = std::str::from_utf8(bytes)?; + Ok(OpenAIStreamingResponse::new(s.to_string())) + } +} diff --git a/crates/hermesllm/src/providers/provider_enum.rs b/crates/hermesllm/src/providers/provider_enum.rs new file mode 100644 index 00000000..ea1efa54 --- /dev/null +++ b/crates/hermesllm/src/providers/provider_enum.rs @@ -0,0 +1,67 @@ +use crate::providers::traits::*; +use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage}; + +/// Enum that wraps all possible providers for dynamic dispatch +pub enum ProviderInstance { + OpenAI(OpenAIProvider), + // TODO: Add other providers as they are implemented + // Anthropic(AnthropicProvider), + // Mistral(MistralProvider), + // etc. +} + +impl ProviderInstance { + /// Creates a provider from a provider name string + pub fn from_name(name: &str) -> Self { + match name.to_lowercase().as_str() { + "openai" | "groq" | "gemini" | "mistral" | "deepseek" | "arch" | "claude" => { + ProviderInstance::OpenAI(OpenAIProvider) + } + // TODO: Add other providers when implemented + // "claude" | "anthropic" => ProviderInstance::Anthropic(AnthropicProvider), + // "mistral" => ProviderInstance::Mistral(MistralProvider), + _ => { + // Default to OpenAI for unknown providers + ProviderInstance::OpenAI(OpenAIProvider) + } + } + } + + /// Parse request from bytes using the appropriate provider + pub fn parse_request(&self, bytes: &[u8]) -> Result> { + match self { + ProviderInstance::OpenAI(_) => { + ChatCompletionsRequest::try_from_bytes(bytes).map_err(|e| Box::new(e) as Box) + } + // TODO: Add other provider cases when implemented + } + } + + /// Parse response from bytes using the appropriate provider + pub fn parse_response(&self, bytes: &[u8], provider: &crate::Provider) -> Result> { + match self { + ProviderInstance::OpenAI(_) => { + ChatCompletionsResponse::try_from_bytes(bytes, provider).map_err(|e| Box::new(e) as Box) + } + // TODO: Add other provider cases when implemented + } + } + + /// Parse streaming response from bytes using the appropriate provider + pub fn parse_streaming_response(&self, bytes: &[u8], provider: &crate::Provider) -> Result> { + match self { + ProviderInstance::OpenAI(_) => { + OpenAIStreamingResponse::try_from_bytes(bytes, provider).map_err(|e| Box::new(e) as Box) + } + // TODO: Add other provider cases when implemented + } + } +} + +impl ProviderInterface for ProviderInstance { + type Request = ChatCompletionsRequest; + type Response = ChatCompletionsResponse; + type StreamingResponse = OpenAIStreamingResponse; + type Usage = Usage; +} diff --git a/crates/hermesllm/src/providers/traits.rs b/crates/hermesllm/src/providers/traits.rs new file mode 100644 index 00000000..c845e57e --- /dev/null +++ b/crates/hermesllm/src/providers/traits.rs @@ -0,0 +1,74 @@ +//! Provider traits for generic request/response handling +//! +//! This module defines the core traits that enable provider-agnostic +//! handling of LLM requests and responses in the gateway. + +use std::error::Error; +use crate::Provider; + +/// Trait for provider-specific request types +pub trait ProviderRequest: Sized { + type Error: Error + Send + Sync + 'static; + + /// Parse request from raw bytes + fn try_from_bytes(bytes: &[u8]) -> Result; + + /// Convert to provider-specific format + fn to_provider_bytes(&self, provider: Provider) -> Result, Self::Error>; + + /// Extract the model name from the request + fn extract_model(&self) -> &str; + + /// Check if this is a streaming request + fn is_streaming(&self) -> bool; + + /// Set streaming options (e.g., include_usage) + fn set_streaming_options(&mut self); + + /// Extract text content from messages for token counting + fn extract_messages_text(&self) -> String; +} + +/// Trait for token usage information +pub trait TokenUsage { + fn completion_tokens(&self) -> usize; + fn prompt_tokens(&self) -> usize; + fn total_tokens(&self) -> usize; +} + +/// Trait for provider-specific response types +pub trait ProviderResponse: Sized { + type Error: Error + Send + Sync + 'static; + type Usage: TokenUsage; + + /// Parse response from raw bytes + fn try_from_bytes(bytes: &[u8], provider: &Provider) -> Result; + + /// Get usage information if available + fn usage(&self) -> Option<&Self::Usage>; +} + +/// Trait for streaming response chunks +pub trait StreamChunk { + type Usage: TokenUsage; + + /// Get usage information if available + fn usage(&self) -> Option<&Self::Usage>; +} + +/// Trait for streaming response iterators +pub trait StreamingResponse: Iterator> + Sized { + type Error: Error + Send + Sync + 'static; + type Chunk: StreamChunk; + + /// Parse streaming response from raw bytes + fn try_from_bytes(bytes: &[u8], provider: &Provider) -> Result; +} + +/// Main provider interface trait +pub trait ProviderInterface { + type Request: ProviderRequest; + type Response: ProviderResponse; + type StreamingResponse: StreamingResponse; + type Usage: TokenUsage; +} diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 82b88509..fa967397 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -10,11 +10,9 @@ use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; -use hermesllm::providers::openai::types::{ChatCompletionsRequest, SseChatCompletionIter}; -use hermesllm::providers::openai::types::{ - ChatCompletionsResponse, ContentType, Message, StreamOptions, +use hermesllm::{ + Provider, ProviderInstance, ProviderRequest, ProviderResponse, StreamChunk, TokenUsage, }; -use hermesllm::Provider; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::hostcalls::get_current_time; @@ -41,7 +39,6 @@ pub struct StreamContext { ttft_time: Option, traceparent: Option, request_body_sent_time: Option, - user_message: Option, traces_queue: Arc>>, overrides: Rc>, } @@ -69,7 +66,6 @@ impl StreamContext { ttft_duration: None, traceparent: None, ttft_time: None, - user_message: None, traces_queue, request_body_sent_time: None, } @@ -80,6 +76,10 @@ impl StreamContext { .expect("the provider should be set when asked for it") } + fn get_provider_instance(&self) -> ProviderInstance { + self.llm_provider().create_provider_instance() + } + fn select_llm_provider(&mut self) { let provider_hint = self .get_http_request_header(ARCH_PROVIDER_HINT_HEADER) @@ -295,52 +295,39 @@ impl HttpContext for StreamContext { } }; - let mut deserialized_body = match ChatCompletionsRequest::try_from(body_bytes.as_slice()) { + let provider_instance = self.get_provider_instance(); + + let mut deserialized_body = match provider_instance.parse_request(&body_bytes) { Ok(deserialized) => deserialized, Err(e) => { debug!( "on_http_request_body: request body: {}", String::from_utf8_lossy(&body_bytes) ); - self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST)); + self.send_server_error( + ServerError::LogicError(format!("Request parsing error: {}", e)), + Some(StatusCode::BAD_REQUEST), + ); return Action::Pause; } }; - self.user_message = deserialized_body - .messages - .iter() - .filter(|m| m.role == "user") - .last() - .cloned(); + // TODO: For now, we'll need to handle user_message extraction differently since it's OpenAI-specific + // This could be made generic by adding a trait method later let model_name = match self.llm_provider.as_ref() { Some(llm_provider) => llm_provider.model.as_ref(), None => None, }; - let use_agent_orchestrator = match self.overrides.as_ref() { + let _use_agent_orchestrator = match self.overrides.as_ref() { Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(), None => false, }; - let model_requested = deserialized_body.model.clone(); - deserialized_body.model = match model_name { - Some(model_name) => model_name.clone(), - None => { - if use_agent_orchestrator { - "agent_orchestrator".to_string() - } else { - self.send_server_error( - ServerError::BadRequest { - why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(), - }, - Some(StatusCode::BAD_REQUEST), - ); - return Action::Continue; - } - } - }; + let model_requested = deserialized_body.extract_model().to_string(); + // Note: We can't directly modify the model field through the trait, + // this would need to be handled differently in a full generic implementation info!( "on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}", @@ -349,32 +336,17 @@ impl HttpContext for StreamContext { model_name.unwrap_or(&"None".to_string()), ); - if deserialized_body.stream.unwrap_or_default() { + if deserialized_body.is_streaming() { self.streaming_response = true; } - if deserialized_body.stream.unwrap_or_default() - && deserialized_body.stream_options.is_none() - { - deserialized_body.stream_options = Some(StreamOptions { - include_usage: true, - }); + if deserialized_body.is_streaming() { + deserialized_body.set_streaming_options(); } // only use the tokens from the messages, excluding the metadata and json tags - let input_tokens_str = deserialized_body - .messages - .iter() - .fold(String::new(), |acc, m| { - acc + " " - + m.content - .as_ref() - .unwrap_or(&ContentType::Text(String::new())) - .to_string() - .as_str() - }); + let input_tokens_str = deserialized_body.extract_messages_text(); // enforce ratelimits on ingress - if let Err(e) = self.enforce_ratelimits(&deserialized_body.model, input_tokens_str.as_str()) - { + if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) { self.send_server_error( ServerError::ExceededRatelimit(e), Some(StatusCode::TOO_MANY_REQUESTS), @@ -387,11 +359,15 @@ impl HttpContext for StreamContext { let hermes_llm_provider = Provider::from(llm_provider_str.as_str()); // convert chat completion request to llm provider specific request - let deserialized_body_bytes = match deserialized_body.to_bytes(hermes_llm_provider) { + let deserialized_body_bytes = match deserialized_body.to_provider_bytes(hermes_llm_provider) + { Ok(bytes) => bytes, Err(e) => { warn!("Failed to serialize request body: {}", e); - self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST)); + self.send_server_error( + ServerError::LogicError(format!("Request serialization error: {}", e)), + Some(StatusCode::BAD_REQUEST), + ); return Action::Pause; } }; @@ -484,12 +460,6 @@ impl HttpContext for StreamContext { self.request_body_sent_time.unwrap(), current_time_ns, ); - if let Some(user_message) = self.user_message.as_ref() { - if let Some(prompt) = user_message.content.as_ref() { - llm_span - .add_attribute("user_prompt".to_string(), prompt.to_string()); - } - } llm_span.add_attribute( "model".to_string(), self.llm_provider().name.to_string(), @@ -562,8 +532,11 @@ impl HttpContext for StreamContext { let hermes_llm_provider = Provider::from(llm_provider_str.as_str()); if self.streaming_response { - let chat_completions_chunk_response_events = - match SseChatCompletionIter::try_from((body.as_slice(), &hermes_llm_provider)) { + // Use the provider instance to parse streaming response + let provider_instance = self.get_provider_instance(); + + let streaming_events = + match provider_instance.parse_streaming_response(&body, &hermes_llm_provider) { Ok(events) => events, Err(e) => { warn!( @@ -575,11 +548,11 @@ impl HttpContext for StreamContext { } }; - for event in chat_completions_chunk_response_events { - match event { + for event_result in streaming_events { + match event_result { Ok(event) => { - if let Some(usage) = event.usage.as_ref() { - self.response_tokens += usage.completion_tokens; + if let Some(usage) = event.usage() { + self.response_tokens += usage.completion_tokens(); } } Err(e) => { @@ -611,30 +584,30 @@ impl HttpContext for StreamContext { } } else { debug!("non streaming response"); - let chat_completions_response = - match ChatCompletionsResponse::try_from((body.as_slice(), &hermes_llm_provider)) { - Ok(de) => de, - Err(e) => { - warn!( - "could not parse response: {}, body str: {}", - e, - String::from_utf8_lossy(&body) - ); - debug!( - "on_http_response_body: S[{}], response body: {}", - self.context_id, - String::from_utf8_lossy(&body) - ); - self.send_server_error( - ServerError::OpenAIPError(e), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Continue; - } - }; + let provider_instance = self.get_provider_instance(); + let response = match provider_instance.parse_response(&body, &hermes_llm_provider) { + Ok(de) => de, + Err(e) => { + warn!( + "could not parse response: {}, body str: {}", + e, + String::from_utf8_lossy(&body) + ); + debug!( + "on_http_response_body: S[{}], response body: {}", + self.context_id, + String::from_utf8_lossy(&body) + ); + self.send_server_error( + ServerError::LogicError(format!("Response parsing error: {}", e)), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Continue; + } + }; - if let Some(usage) = chat_completions_response.usage { - self.response_tokens += usage.completion_tokens; + if let Some(usage) = response.usage() { + self.response_tokens += usage.completion_tokens(); } }