From c4148a3d528eaf913a5020fac6a045586c016196 Mon Sep 17 00:00:00 2001 From: Salman Paracha Date: Sat, 9 Aug 2025 21:52:31 -0700 Subject: [PATCH] more clean up --- crates/hermesllm/src/lib.rs | 32 ++++++++++++++ .../hermesllm/src/providers/arch/provider.rs | 26 +++-------- .../src/providers/claude/provider.rs | 29 +++--------- .../src/providers/deepseek/provider.rs | 26 +++-------- .../src/providers/gemini/provider.rs | 29 +++--------- .../src/providers/github/provider.rs | 29 +++--------- .../hermesllm/src/providers/groq/provider.rs | 26 +++-------- .../src/providers/mistral/provider.rs | 26 +++-------- crates/hermesllm/src/providers/mod.rs | 13 ++++++ .../src/providers/openai/provider.rs | 44 ++++++++++--------- crates/hermesllm/src/providers/traits.rs | 18 ++------ crates/llm_gateway/src/stream_context.rs | 22 +--------- 12 files changed, 107 insertions(+), 213 deletions(-) diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index acc0c431..43e5c350 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -60,4 +60,36 @@ mod tests { // Test that provider supports the expected API endpoints assert!(provider.has_compatible_api("/v1/chat/completions")); } + + #[test] + fn test_provider_extract_user_message() { + use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent}; + + let provider = Provider::new(ProviderId::OpenAI); + + // Test with text message + let request = ChatCompletionsRequest { + model: "gpt-4".to_string(), + messages: vec![ + Message { + role: crate::apis::openai::Role::System, + content: MessageContent::Text("You are a helpful assistant".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }, + Message { + role: crate::apis::openai::Role::User, + content: MessageContent::Text("Hello, world!".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }, + ], + ..Default::default() + }; + + let user_message = provider.extract_user_message(&request); + assert_eq!(user_message, Some("Hello, world!".to_string())); + } } diff --git a/crates/hermesllm/src/providers/arch/provider.rs b/crates/hermesllm/src/providers/arch/provider.rs index 929c7724..8100af83 100644 --- a/crates/hermesllm/src/providers/arch/provider.rs +++ b/crates/hermesllm/src/providers/arch/provider.rs @@ -40,6 +40,11 @@ impl ProviderRequest for ArchProvider { let openai_provider = OpenAIProvider; ProviderRequest::extract_messages_text(&openai_provider, request) } + + fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option { + let openai_provider = OpenAIProvider; + ProviderRequest::extract_user_message(&openai_provider, request) + } } impl ProviderResponse for ArchProvider { @@ -79,25 +84,4 @@ impl ProviderInterface for ArchProvider { fn supported_apis(&self) -> Vec<&'static str> { vec!["/v1/chat/completions"] } - - fn parse_request(&self, bytes: &[u8]) -> Result> { - match ProviderRequest::try_from_bytes(self, bytes) { - Ok(req) => Ok(req), - Err(e) => Err(Box::new(e)), - } - } - - fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { - match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { - Ok(resp) => Ok(resp), - Err(e) => Err(Box::new(e)), - } - } - - fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { - match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { - Ok(bytes) => Ok(bytes), - Err(e) => Err(Box::new(e)), - } - } } diff --git a/crates/hermesllm/src/providers/claude/provider.rs b/crates/hermesllm/src/providers/claude/provider.rs index e4463eaf..0febda79 100644 --- a/crates/hermesllm/src/providers/claude/provider.rs +++ b/crates/hermesllm/src/providers/claude/provider.rs @@ -43,6 +43,11 @@ impl ProviderRequest for ClaudeProvider { let openai_provider = OpenAIProvider; ProviderRequest::extract_messages_text(&openai_provider, request) } + + fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option { + let openai_provider = OpenAIProvider; + ProviderRequest::extract_user_message(&openai_provider, request) + } } impl ProviderResponse for ClaudeProvider { @@ -84,28 +89,4 @@ impl ProviderInterface for ClaudeProvider { // TODO: Update when Claude API is fully implemented vec!["/v1/messages"] } - - fn parse_request(&self, bytes: &[u8]) -> Result> { - // TODO: Implement Claude-specific request parsing - match ProviderRequest::try_from_bytes(self, bytes) { - Ok(req) => Ok(req), - Err(e) => Err(Box::new(e)), - } - } - - fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { - // TODO: Implement Claude-specific response parsing - match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { - Ok(resp) => Ok(resp), - Err(e) => Err(Box::new(e)), - } - } - - fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { - // TODO: Implement Claude-specific request serialization - match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { - Ok(bytes) => Ok(bytes), - Err(e) => Err(Box::new(e)), - } - } } diff --git a/crates/hermesllm/src/providers/deepseek/provider.rs b/crates/hermesllm/src/providers/deepseek/provider.rs index 3dad7d94..a2cee721 100644 --- a/crates/hermesllm/src/providers/deepseek/provider.rs +++ b/crates/hermesllm/src/providers/deepseek/provider.rs @@ -40,6 +40,11 @@ impl ProviderRequest for DeepseekProvider { let openai_provider = OpenAIProvider; ProviderRequest::extract_messages_text(&openai_provider, request) } + + fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option { + let openai_provider = OpenAIProvider; + ProviderRequest::extract_user_message(&openai_provider, request) + } } impl ProviderResponse for DeepseekProvider { @@ -79,25 +84,4 @@ impl ProviderInterface for DeepseekProvider { fn supported_apis(&self) -> Vec<&'static str> { vec!["/v1/chat/completions"] } - - fn parse_request(&self, bytes: &[u8]) -> Result> { - match ProviderRequest::try_from_bytes(self, bytes) { - Ok(req) => Ok(req), - Err(e) => Err(Box::new(e)), - } - } - - fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { - match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { - Ok(resp) => Ok(resp), - Err(e) => Err(Box::new(e)), - } - } - - fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { - match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { - Ok(bytes) => Ok(bytes), - Err(e) => Err(Box::new(e)), - } - } } diff --git a/crates/hermesllm/src/providers/gemini/provider.rs b/crates/hermesllm/src/providers/gemini/provider.rs index 14de48ad..1a222b7e 100644 --- a/crates/hermesllm/src/providers/gemini/provider.rs +++ b/crates/hermesllm/src/providers/gemini/provider.rs @@ -43,6 +43,11 @@ impl ProviderRequest for GeminiProvider { let openai_provider = OpenAIProvider; ProviderRequest::extract_messages_text(&openai_provider, request) } + + fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option { + let openai_provider = OpenAIProvider; + ProviderRequest::extract_user_message(&openai_provider, request) + } } impl ProviderResponse for GeminiProvider { @@ -84,28 +89,4 @@ impl ProviderInterface for GeminiProvider { // TODO: Update when Gemini API is fully implemented vec!["/v1/models"] } - - fn parse_request(&self, bytes: &[u8]) -> Result> { - // TODO: Implement Gemini-specific request parsing - match ProviderRequest::try_from_bytes(self, bytes) { - Ok(req) => Ok(req), - Err(e) => Err(Box::new(e)), - } - } - - fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { - // TODO: Implement Gemini-specific response parsing - match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { - Ok(resp) => Ok(resp), - Err(e) => Err(Box::new(e)), - } - } - - fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { - // TODO: Implement Gemini-specific request serialization - match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { - Ok(bytes) => Ok(bytes), - Err(e) => Err(Box::new(e)), - } - } } diff --git a/crates/hermesllm/src/providers/github/provider.rs b/crates/hermesllm/src/providers/github/provider.rs index 63ef12e4..6dd51542 100644 --- a/crates/hermesllm/src/providers/github/provider.rs +++ b/crates/hermesllm/src/providers/github/provider.rs @@ -43,6 +43,11 @@ impl ProviderRequest for GitHubProvider { let openai_provider = OpenAIProvider; ProviderRequest::extract_messages_text(&openai_provider, request) } + + fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option { + let openai_provider = OpenAIProvider; + ProviderRequest::extract_user_message(&openai_provider, request) + } } impl ProviderResponse for GitHubProvider { @@ -84,28 +89,4 @@ impl ProviderInterface for GitHubProvider { // TODO: Update when GitHub API is fully implemented vec!["/models"] } - - fn parse_request(&self, bytes: &[u8]) -> Result> { - // TODO: Implement GitHub-specific request parsing - match ProviderRequest::try_from_bytes(self, bytes) { - Ok(req) => Ok(req), - Err(e) => Err(Box::new(e)), - } - } - - fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { - // TODO: Implement GitHub-specific response parsing - match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { - Ok(resp) => Ok(resp), - Err(e) => Err(Box::new(e)), - } - } - - fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { - // TODO: Implement GitHub-specific request serialization - match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { - Ok(bytes) => Ok(bytes), - Err(e) => Err(Box::new(e)), - } - } } diff --git a/crates/hermesllm/src/providers/groq/provider.rs b/crates/hermesllm/src/providers/groq/provider.rs index e08a022b..73a8148b 100644 --- a/crates/hermesllm/src/providers/groq/provider.rs +++ b/crates/hermesllm/src/providers/groq/provider.rs @@ -43,6 +43,11 @@ impl ProviderRequest for GroqProvider { let openai_provider = OpenAIProvider; ProviderRequest::extract_messages_text(&openai_provider, request) } + + fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option { + let openai_provider = OpenAIProvider; + ProviderRequest::extract_user_message(&openai_provider, request) + } } impl ProviderResponse for GroqProvider { @@ -82,25 +87,4 @@ impl ProviderInterface for GroqProvider { fn supported_apis(&self) -> Vec<&'static str> { vec!["/openai/v1/chat/completions"] } - - fn parse_request(&self, bytes: &[u8]) -> Result> { - match ProviderRequest::try_from_bytes(self, bytes) { - Ok(req) => Ok(req), - Err(e) => Err(Box::new(e)), - } - } - - fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { - match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { - Ok(resp) => Ok(resp), - Err(e) => Err(Box::new(e)), - } - } - - fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { - match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { - Ok(bytes) => Ok(bytes), - Err(e) => Err(Box::new(e)), - } - } } diff --git a/crates/hermesllm/src/providers/mistral/provider.rs b/crates/hermesllm/src/providers/mistral/provider.rs index 5aa2ab28..31aae8db 100644 --- a/crates/hermesllm/src/providers/mistral/provider.rs +++ b/crates/hermesllm/src/providers/mistral/provider.rs @@ -40,6 +40,11 @@ impl ProviderRequest for MistralProvider { let openai_provider = OpenAIProvider; ProviderRequest::extract_messages_text(&openai_provider, request) } + + fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option { + let openai_provider = OpenAIProvider; + ProviderRequest::extract_user_message(&openai_provider, request) + } } impl ProviderResponse for MistralProvider { @@ -79,25 +84,4 @@ impl ProviderInterface for MistralProvider { fn supported_apis(&self) -> Vec<&'static str> { vec!["/v1/chat/completions"] } - - fn parse_request(&self, bytes: &[u8]) -> Result> { - match ProviderRequest::try_from_bytes(self, bytes) { - Ok(req) => Ok(req), - Err(e) => Err(Box::new(e)), - } - } - - fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { - match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { - Ok(resp) => Ok(resp), - Err(e) => Err(Box::new(e)), - } - } - - fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { - match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { - Ok(bytes) => Ok(bytes), - Err(e) => Err(Box::new(e)), - } - } } diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs index 45eb8d6d..fa49dfd9 100644 --- a/crates/hermesllm/src/providers/mod.rs +++ b/crates/hermesllm/src/providers/mod.rs @@ -201,6 +201,19 @@ impl ProviderRequest for Provider { Provider::GitHub(provider, _) => ProviderRequest::extract_messages_text(provider, request), } } + + fn extract_user_message(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> Option { + match self { + Provider::OpenAI(provider, _) => ProviderRequest::extract_user_message(provider, request), + Provider::Groq(provider, _) => ProviderRequest::extract_user_message(provider, request), + Provider::Mistral(provider, _) => ProviderRequest::extract_user_message(provider, request), + Provider::Deepseek(provider, _) => ProviderRequest::extract_user_message(provider, request), + Provider::Arch(provider, _) => ProviderRequest::extract_user_message(provider, request), + Provider::Gemini(provider, _) => ProviderRequest::extract_user_message(provider, request), + Provider::Claude(provider, _) => ProviderRequest::extract_user_message(provider, request), + Provider::GitHub(provider, _) => ProviderRequest::extract_user_message(provider, request), + } + } } impl ProviderResponse for Provider { diff --git a/crates/hermesllm/src/providers/openai/provider.rs b/crates/hermesllm/src/providers/openai/provider.rs index 129660e8..3cf85d6a 100644 --- a/crates/hermesllm/src/providers/openai/provider.rs +++ b/crates/hermesllm/src/providers/openai/provider.rs @@ -75,27 +75,6 @@ impl ProviderInterface for OpenAIProvider { fn supported_apis(&self) -> Vec<&'static str> { vec!["/v1/chat/completions"] } - - fn parse_request(&self, bytes: &[u8]) -> Result> { - match ProviderRequest::try_from_bytes(self, bytes) { - Ok(req) => Ok(req), - Err(e) => Err(Box::new(e)), - } - } - - fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { - match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { - Ok(resp) => Ok(resp), - Err(e) => Err(Box::new(e)), - } - } - - fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { - match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { - Ok(bytes) => Ok(bytes), - Err(e) => Err(Box::new(e)), - } - } } // Direct trait implementations on OpenAIProvider @@ -142,6 +121,29 @@ impl ProviderRequest for OpenAIProvider { } }) } + + fn extract_user_message(&self, request: &ChatCompletionsRequest) -> Option { + request.messages.last().and_then(|msg| { + match &msg.content { + MessageContent::Text(text) => Some(text.clone()), + MessageContent::Parts(parts) => { + // Extract text from content parts, ignoring images + let text_parts: Vec = parts + .iter() + .filter_map(|part| match part { + ContentPart::Text { text } => Some(text.clone()), + ContentPart::ImageUrl { .. } => None, + }) + .collect(); + if text_parts.is_empty() { + None + } else { + Some(text_parts.join(" ")) + } + } + } + }) + } } impl ProviderResponse for OpenAIProvider { diff --git a/crates/hermesllm/src/providers/traits.rs b/crates/hermesllm/src/providers/traits.rs index b1d1628b..97a0eded 100644 --- a/crates/hermesllm/src/providers/traits.rs +++ b/crates/hermesllm/src/providers/traits.rs @@ -35,6 +35,9 @@ pub trait ProviderRequest { /// Extract text content from messages for token counting fn extract_messages_text(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String; + + /// Extract the user message for tracing/logging purposes + fn extract_user_message(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> Option; } /// Trait for token usage information @@ -85,19 +88,4 @@ pub trait ProviderInterface: ProviderRequest + ProviderResponse + StreamingRespo /// Get supported API endpoints for this provider fn supported_apis(&self) -> Vec<&'static str>; - - /// Parse a request from raw bytes - delegates to ProviderRequest - fn parse_request(&self, bytes: &[u8]) -> Result> { - ProviderRequest::try_from_bytes(self, bytes).map_err(|e| Box::new(e) as Box) - } - - /// Parse a response from raw bytes - delegates to ProviderResponse - fn parse_response(&self, bytes: &[u8], provider_id: super::ProviderId, mode: ConversionMode) -> Result> { - ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode).map_err(|e| Box::new(e) as Box) - } - - /// Convert a request to bytes - delegates to ProviderRequest - fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result, Box> { - ProviderRequest::to_provider_bytes(self, request, provider_id, mode).map_err(|e| Box::new(e) as Box) - } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 9a483d31..26e6c320 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -10,7 +10,6 @@ use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; -use hermesllm::apis::openai::{ContentPart, MessageContent}; use hermesllm::providers::traits::{ ProviderRequest, ProviderResponse, StreamChunk, StreamingResponse, TokenUsage, }; @@ -333,26 +332,7 @@ impl HttpContext for StreamContext { let model_requested = provider.extract_model(&deserialized_body).to_string(); // Convert to owned string // Extract user message for tracing - self.user_message = deserialized_body.messages.last().and_then(|msg| { - match &msg.content { - MessageContent::Text(text) => Some(text.clone()), - MessageContent::Parts(parts) => { - // Extract text from content parts, ignoring images - let text_parts: Vec = parts - .iter() - .filter_map(|part| match part { - ContentPart::Text { text } => Some(text.clone()), - ContentPart::ImageUrl { .. } => None, - }) - .collect(); - if text_parts.is_empty() { - None - } else { - Some(text_parts.join(" ")) - } - } - } - }); + self.user_message = provider.extract_user_message(&deserialized_body); info!( "on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",