From 9c09a18fd06dd782286f9a68b17b665e5f3bb9ce Mon Sep 17 00:00:00 2001 From: Salman Paracha Date: Sat, 9 Aug 2025 21:40:33 -0700 Subject: [PATCH] more refactoring changes to avoid unnecessary re-direction and duplication --- crates/hermesllm/src/lib.rs | 14 +- .../hermesllm/src/providers/arch/provider.rs | 73 ++++++++- .../src/providers/claude/provider.rs | 73 ++++++++- .../src/providers/deepseek/provider.rs | 73 ++++++++- .../src/providers/gemini/provider.rs | 77 +++++++++- .../src/providers/github/provider.rs | 77 +++++++++- .../hermesllm/src/providers/groq/provider.rs | 73 ++++++++- .../src/providers/mistral/provider.rs | 73 ++++++++- crates/hermesllm/src/providers/mod.rs | 142 +++++++++++++++-- .../src/providers/openai/provider.rs | 95 +++++++----- crates/hermesllm/src/providers/traits.rs | 120 ++++----------- crates/llm_gateway/src/stream_context.rs | 144 +++++++++++++----- 12 files changed, 809 insertions(+), 225 deletions(-) diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index dac18303..acc0c431 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -46,18 +46,18 @@ mod tests { #[test] fn test_provider_instance_creation() { let provider = Provider::new(ProviderId::OpenAI); - assert!(provider.interface().has_compatible_api("/v1/chat/completions")); - assert!(!provider.interface().has_compatible_api("/v1/embeddings")); + assert!(provider.has_compatible_api("/v1/chat/completions")); + assert!(!provider.has_compatible_api("/v1/embeddings")); } #[test] - fn test_conversion_mode() { + fn test_provider_supported_apis() { let provider = Provider::new(ProviderId::OpenAI); - let compatible_mode = provider.interface().get_interface(false); - assert!(matches!(compatible_mode, ConversionMode::Compatible)); + let supported_apis = provider.supported_apis(); + assert!(supported_apis.contains(&"/v1/chat/completions")); - let passthrough_mode = provider.interface().get_interface(true); - assert!(matches!(passthrough_mode, 
ConversionMode::Passthrough)); + // Test that provider supports the expected API endpoints + assert!(provider.has_compatible_api("/v1/chat/completions")); } } diff --git a/crates/hermesllm/src/providers/arch/provider.rs b/crates/hermesllm/src/providers/arch/provider.rs index cb6a3692..929c7724 100644 --- a/crates/hermesllm/src/providers/arch/provider.rs +++ b/crates/hermesllm/src/providers/arch/provider.rs @@ -1,13 +1,76 @@ //! Arch provider implementation use crate::providers::{ProviderInterface, ConversionMode}; -use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse}; -use crate::providers::traits::{ProviderRequest, ProviderResponse}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage}; +use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse}; +use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError}; /// Arch provider implementation #[derive(Debug, Clone)] pub struct ArchProvider; +// Trait implementations that delegate to OpenAI +impl ProviderRequest for ArchProvider { + type Error = OpenAIApiError; + + fn try_from_bytes(&self, bytes: &[u8]) -> Result { + let openai_provider = OpenAIProvider; + ProviderRequest::try_from_bytes(&openai_provider, bytes) + } + + fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result, Self::Error> { + let openai_provider = OpenAIProvider; + ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode) + } + + fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str { + &request.model + } + + fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool { + request.stream.unwrap_or_default() + } + + fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) { + let openai_provider = OpenAIProvider; + ProviderRequest::set_streaming_options(&openai_provider, request) + } + + fn 
extract_messages_text(&self, request: &ChatCompletionsRequest) -> String { + let openai_provider = OpenAIProvider; + ProviderRequest::extract_messages_text(&openai_provider, request) + } +} + +impl ProviderResponse for ArchProvider { + type Error = OpenAIApiError; + type Usage = Usage; + + fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result { + let openai_provider = OpenAIProvider; + ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode) + } + + fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> { + Some(&response.usage) + } + + fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> { + let openai_provider = OpenAIProvider; + ProviderResponse::extract_usage_counts(&openai_provider, response) + } +} + +impl StreamingResponse for ArchProvider { + type Error = OpenAIApiError; + type StreamingIter = OpenAIStreamingResponse; + + fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result { + let openai_provider = OpenAIProvider; + StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode) + } +} + impl ProviderInterface for ArchProvider { fn has_compatible_api(&self, api_path: &str) -> bool { matches!(api_path, "/v1/chat/completions") @@ -18,21 +81,21 @@ impl ProviderInterface for ArchProvider { } fn parse_request(&self, bytes: &[u8]) -> Result> { - match ChatCompletionsRequest::try_from_bytes(bytes) { + match ProviderRequest::try_from_bytes(self, bytes) { Ok(req) => Ok(req), Err(e) => Err(Box::new(e)), } } fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { - match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { Ok(resp) => Ok(resp), Err(e) => Err(Box::new(e)), } } fn request_to_bytes(&self, 
request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { - match request.to_provider_bytes(provider_id, mode) { + match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { Ok(bytes) => Ok(bytes), Err(e) => Err(Box::new(e)), } diff --git a/crates/hermesllm/src/providers/claude/provider.rs b/crates/hermesllm/src/providers/claude/provider.rs index c9c4982b..e4463eaf 100644 --- a/crates/hermesllm/src/providers/claude/provider.rs +++ b/crates/hermesllm/src/providers/claude/provider.rs @@ -4,13 +4,76 @@ //! For now, uses OpenAI-compatible format as fallback use crate::providers::{ProviderInterface, ConversionMode}; -use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse}; -use crate::providers::traits::{ProviderRequest, ProviderResponse}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage}; +use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse}; +use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError}; /// Claude provider implementation #[derive(Debug, Clone)] pub struct ClaudeProvider; +// Trait implementations that delegate to OpenAI +impl ProviderRequest for ClaudeProvider { + type Error = OpenAIApiError; + + fn try_from_bytes(&self, bytes: &[u8]) -> Result { + let openai_provider = OpenAIProvider; + ProviderRequest::try_from_bytes(&openai_provider, bytes) + } + + fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result, Self::Error> { + let openai_provider = OpenAIProvider; + ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode) + } + + fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str { + &request.model + } + + fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool { + request.stream.unwrap_or_default() + } + + fn 
set_streaming_options(&self, request: &mut ChatCompletionsRequest) { + let openai_provider = OpenAIProvider; + ProviderRequest::set_streaming_options(&openai_provider, request) + } + + fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String { + let openai_provider = OpenAIProvider; + ProviderRequest::extract_messages_text(&openai_provider, request) + } +} + +impl ProviderResponse for ClaudeProvider { + type Error = OpenAIApiError; + type Usage = Usage; + + fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result { + let openai_provider = OpenAIProvider; + ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode) + } + + fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> { + Some(&response.usage) + } + + fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> { + let openai_provider = OpenAIProvider; + ProviderResponse::extract_usage_counts(&openai_provider, response) + } +} + +impl StreamingResponse for ClaudeProvider { + type Error = OpenAIApiError; + type StreamingIter = OpenAIStreamingResponse; + + fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result { + let openai_provider = OpenAIProvider; + StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode) + } +} + impl ProviderInterface for ClaudeProvider { fn has_compatible_api(&self, api_path: &str) -> bool { // TODO: Update when Claude API is fully implemented @@ -24,7 +87,7 @@ impl ProviderInterface for ClaudeProvider { fn parse_request(&self, bytes: &[u8]) -> Result> { // TODO: Implement Claude-specific request parsing - match ChatCompletionsRequest::try_from_bytes(bytes) { + match ProviderRequest::try_from_bytes(self, bytes) { Ok(req) => Ok(req), Err(e) => Err(Box::new(e)), } @@ -32,7 +95,7 @@ impl ProviderInterface for ClaudeProvider { fn parse_response(&self, bytes: &[u8], 
provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { // TODO: Implement Claude-specific response parsing - match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { Ok(resp) => Ok(resp), Err(e) => Err(Box::new(e)), } @@ -40,7 +103,7 @@ impl ProviderInterface for ClaudeProvider { fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { // TODO: Implement Claude-specific request serialization - match request.to_provider_bytes(provider_id, mode) { + match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { Ok(bytes) => Ok(bytes), Err(e) => Err(Box::new(e)), } diff --git a/crates/hermesllm/src/providers/deepseek/provider.rs b/crates/hermesllm/src/providers/deepseek/provider.rs index 92cc0fa4..3dad7d94 100644 --- a/crates/hermesllm/src/providers/deepseek/provider.rs +++ b/crates/hermesllm/src/providers/deepseek/provider.rs @@ -1,13 +1,76 @@ //! 
Deepseek provider implementation use crate::providers::{ProviderInterface, ConversionMode}; -use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse}; -use crate::providers::traits::{ProviderRequest, ProviderResponse}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage}; +use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse}; +use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError}; /// Deepseek provider implementation #[derive(Debug, Clone)] pub struct DeepseekProvider; +// Trait implementations that delegate to OpenAI +impl ProviderRequest for DeepseekProvider { + type Error = OpenAIApiError; + + fn try_from_bytes(&self, bytes: &[u8]) -> Result { + let openai_provider = OpenAIProvider; + ProviderRequest::try_from_bytes(&openai_provider, bytes) + } + + fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result, Self::Error> { + let openai_provider = OpenAIProvider; + ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode) + } + + fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str { + &request.model + } + + fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool { + request.stream.unwrap_or_default() + } + + fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) { + let openai_provider = OpenAIProvider; + ProviderRequest::set_streaming_options(&openai_provider, request) + } + + fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String { + let openai_provider = OpenAIProvider; + ProviderRequest::extract_messages_text(&openai_provider, request) + } +} + +impl ProviderResponse for DeepseekProvider { + type Error = OpenAIApiError; + type Usage = Usage; + + fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result { + let openai_provider = 
OpenAIProvider; + ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode) + } + + fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> { + Some(&response.usage) + } + + fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> { + let openai_provider = OpenAIProvider; + ProviderResponse::extract_usage_counts(&openai_provider, response) + } +} + +impl StreamingResponse for DeepseekProvider { + type Error = OpenAIApiError; + type StreamingIter = OpenAIStreamingResponse; + + fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result { + let openai_provider = OpenAIProvider; + StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode) + } +} + impl ProviderInterface for DeepseekProvider { fn has_compatible_api(&self, api_path: &str) -> bool { matches!(api_path, "/v1/chat/completions") @@ -18,21 +81,21 @@ impl ProviderInterface for DeepseekProvider { } fn parse_request(&self, bytes: &[u8]) -> Result> { - match ChatCompletionsRequest::try_from_bytes(bytes) { + match ProviderRequest::try_from_bytes(self, bytes) { Ok(req) => Ok(req), Err(e) => Err(Box::new(e)), } } fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { - match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { Ok(resp) => Ok(resp), Err(e) => Err(Box::new(e)), } } fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { - match request.to_provider_bytes(provider_id, mode) { + match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { Ok(bytes) => Ok(bytes), Err(e) => Err(Box::new(e)), } diff --git a/crates/hermesllm/src/providers/gemini/provider.rs b/crates/hermesllm/src/providers/gemini/provider.rs 
index 55b0c471..14de48ad 100644 --- a/crates/hermesllm/src/providers/gemini/provider.rs +++ b/crates/hermesllm/src/providers/gemini/provider.rs @@ -1,16 +1,79 @@ //! Gemini provider implementation //! -//! TODO: Implement Gemini-specific API format when needed -//! For now, uses OpenAI-compatible format as fallback +//! This module contains the Gemini provider, which handles requests to Google's Gemini API +//! in the OpenAI-compatible format. use crate::providers::{ProviderInterface, ConversionMode}; -use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse}; -use crate::providers::traits::{ProviderRequest, ProviderResponse}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage}; +use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse}; +use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError}; /// Gemini provider implementation #[derive(Debug, Clone)] pub struct GeminiProvider; +// Trait implementations that delegate to OpenAI +impl ProviderRequest for GeminiProvider { + type Error = OpenAIApiError; + + fn try_from_bytes(&self, bytes: &[u8]) -> Result { + let openai_provider = OpenAIProvider; + ProviderRequest::try_from_bytes(&openai_provider, bytes) + } + + fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result, Self::Error> { + let openai_provider = OpenAIProvider; + ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode) + } + + fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str { + &request.model + } + + fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool { + request.stream.unwrap_or_default() + } + + fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) { + let openai_provider = OpenAIProvider; + ProviderRequest::set_streaming_options(&openai_provider, request) + } + + fn 
extract_messages_text(&self, request: &ChatCompletionsRequest) -> String { + let openai_provider = OpenAIProvider; + ProviderRequest::extract_messages_text(&openai_provider, request) + } +} + +impl ProviderResponse for GeminiProvider { + type Error = OpenAIApiError; + type Usage = Usage; + + fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result { + let openai_provider = OpenAIProvider; + ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode) + } + + fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> { + Some(&response.usage) + } + + fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> { + let openai_provider = OpenAIProvider; + ProviderResponse::extract_usage_counts(&openai_provider, response) + } +} + +impl StreamingResponse for GeminiProvider { + type Error = OpenAIApiError; + type StreamingIter = OpenAIStreamingResponse; + + fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result { + let openai_provider = OpenAIProvider; + StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode) + } +} + impl ProviderInterface for GeminiProvider { fn has_compatible_api(&self, api_path: &str) -> bool { // TODO: Update when Gemini API is fully implemented @@ -24,7 +87,7 @@ impl ProviderInterface for GeminiProvider { fn parse_request(&self, bytes: &[u8]) -> Result> { // TODO: Implement Gemini-specific request parsing - match ChatCompletionsRequest::try_from_bytes(bytes) { + match ProviderRequest::try_from_bytes(self, bytes) { Ok(req) => Ok(req), Err(e) => Err(Box::new(e)), } @@ -32,7 +95,7 @@ impl ProviderInterface for GeminiProvider { fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { // TODO: Implement Gemini-specific response parsing - match ChatCompletionsResponse::try_from_bytes(bytes, 
&provider_id, mode) { + match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { Ok(resp) => Ok(resp), Err(e) => Err(Box::new(e)), } @@ -40,7 +103,7 @@ impl ProviderInterface for GeminiProvider { fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { // TODO: Implement Gemini-specific request serialization - match request.to_provider_bytes(provider_id, mode) { + match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { Ok(bytes) => Ok(bytes), Err(e) => Err(Box::new(e)), } diff --git a/crates/hermesllm/src/providers/github/provider.rs b/crates/hermesllm/src/providers/github/provider.rs index cbd5bb01..63ef12e4 100644 --- a/crates/hermesllm/src/providers/github/provider.rs +++ b/crates/hermesllm/src/providers/github/provider.rs @@ -1,16 +1,79 @@ //! GitHub provider implementation //! -//! TODO: Implement GitHub-specific API format (/models) when needed -//! For now, uses OpenAI-compatible format as fallback +//! This module contains the GitHub provider, which handles requests to the GitHub API +//! in the OpenAI-compatible format. 
use crate::providers::{ProviderInterface, ConversionMode}; -use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse}; -use crate::providers::traits::{ProviderRequest, ProviderResponse}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage}; +use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse}; +use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError}; /// GitHub provider implementation #[derive(Debug, Clone)] pub struct GitHubProvider; +// Trait implementations that delegate to OpenAI +impl ProviderRequest for GitHubProvider { + type Error = OpenAIApiError; + + fn try_from_bytes(&self, bytes: &[u8]) -> Result { + let openai_provider = OpenAIProvider; + ProviderRequest::try_from_bytes(&openai_provider, bytes) + } + + fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result, Self::Error> { + let openai_provider = OpenAIProvider; + ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode) + } + + fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str { + &request.model + } + + fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool { + request.stream.unwrap_or_default() + } + + fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) { + let openai_provider = OpenAIProvider; + ProviderRequest::set_streaming_options(&openai_provider, request) + } + + fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String { + let openai_provider = OpenAIProvider; + ProviderRequest::extract_messages_text(&openai_provider, request) + } +} + +impl ProviderResponse for GitHubProvider { + type Error = OpenAIApiError; + type Usage = Usage; + + fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result { + let openai_provider = OpenAIProvider; + 
ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode) + } + + fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> { + Some(&response.usage) + } + + fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> { + let openai_provider = OpenAIProvider; + ProviderResponse::extract_usage_counts(&openai_provider, response) + } +} + +impl StreamingResponse for GitHubProvider { + type Error = OpenAIApiError; + type StreamingIter = OpenAIStreamingResponse; + + fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result { + let openai_provider = OpenAIProvider; + StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode) + } +} + impl ProviderInterface for GitHubProvider { fn has_compatible_api(&self, api_path: &str) -> bool { // TODO: Update when GitHub API is fully implemented @@ -24,7 +87,7 @@ impl ProviderInterface for GitHubProvider { fn parse_request(&self, bytes: &[u8]) -> Result> { // TODO: Implement GitHub-specific request parsing - match ChatCompletionsRequest::try_from_bytes(bytes) { + match ProviderRequest::try_from_bytes(self, bytes) { Ok(req) => Ok(req), Err(e) => Err(Box::new(e)), } @@ -32,7 +95,7 @@ impl ProviderInterface for GitHubProvider { fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { // TODO: Implement GitHub-specific response parsing - match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { Ok(resp) => Ok(resp), Err(e) => Err(Box::new(e)), } @@ -40,7 +103,7 @@ impl ProviderInterface for GitHubProvider { fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { // TODO: Implement GitHub-specific request serialization - match 
request.to_provider_bytes(provider_id, mode) { + match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { Ok(bytes) => Ok(bytes), Err(e) => Err(Box::new(e)), } diff --git a/crates/hermesllm/src/providers/groq/provider.rs b/crates/hermesllm/src/providers/groq/provider.rs index 94eb4568..e08a022b 100644 --- a/crates/hermesllm/src/providers/groq/provider.rs +++ b/crates/hermesllm/src/providers/groq/provider.rs @@ -4,13 +4,76 @@ //! Groq uses OpenAI-compatible format but may have provider-specific nuances. use crate::providers::{ProviderInterface, ConversionMode}; -use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse}; -use crate::providers::traits::{ProviderRequest, ProviderResponse}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage}; +use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse}; +use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError}; /// Groq provider implementation #[derive(Debug, Clone)] pub struct GroqProvider; +// Trait implementations that delegate to OpenAI +impl ProviderRequest for GroqProvider { + type Error = OpenAIApiError; + + fn try_from_bytes(&self, bytes: &[u8]) -> Result { + let openai_provider = OpenAIProvider; + ProviderRequest::try_from_bytes(&openai_provider, bytes) + } + + fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result, Self::Error> { + let openai_provider = OpenAIProvider; + ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode) + } + + fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str { + &request.model + } + + fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool { + request.stream.unwrap_or_default() + } + + fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) { + let openai_provider = OpenAIProvider; + 
ProviderRequest::set_streaming_options(&openai_provider, request) + } + + fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String { + let openai_provider = OpenAIProvider; + ProviderRequest::extract_messages_text(&openai_provider, request) + } +} + +impl ProviderResponse for GroqProvider { + type Error = OpenAIApiError; + type Usage = Usage; + + fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result { + let openai_provider = OpenAIProvider; + ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode) + } + + fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> { + Some(&response.usage) + } + + fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> { + let openai_provider = OpenAIProvider; + ProviderResponse::extract_usage_counts(&openai_provider, response) + } +} + +impl StreamingResponse for GroqProvider { + type Error = OpenAIApiError; + type StreamingIter = OpenAIStreamingResponse; + + fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result { + let openai_provider = OpenAIProvider; + StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode) + } +} + impl ProviderInterface for GroqProvider { fn has_compatible_api(&self, api_path: &str) -> bool { matches!(api_path, "/v1/chat/completions" | "/openai/v1/chat/completions") @@ -21,21 +84,21 @@ impl ProviderInterface for GroqProvider { } fn parse_request(&self, bytes: &[u8]) -> Result> { - match ChatCompletionsRequest::try_from_bytes(bytes) { + match ProviderRequest::try_from_bytes(self, bytes) { Ok(req) => Ok(req), Err(e) => Err(Box::new(e)), } } fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { - match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + match ProviderResponse::try_from_bytes(self, bytes, 
&provider_id, mode) { Ok(resp) => Ok(resp), Err(e) => Err(Box::new(e)), } } fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { - match request.to_provider_bytes(provider_id, mode) { + match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { Ok(bytes) => Ok(bytes), Err(e) => Err(Box::new(e)), } diff --git a/crates/hermesllm/src/providers/mistral/provider.rs b/crates/hermesllm/src/providers/mistral/provider.rs index a36d6774..5aa2ab28 100644 --- a/crates/hermesllm/src/providers/mistral/provider.rs +++ b/crates/hermesllm/src/providers/mistral/provider.rs @@ -1,13 +1,76 @@ //! Mistral provider implementation use crate::providers::{ProviderInterface, ConversionMode}; -use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse}; -use crate::providers::traits::{ProviderRequest, ProviderResponse}; +use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage}; +use crate::providers::traits::{ProviderRequest, ProviderResponse, StreamingResponse}; +use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse, OpenAIApiError}; /// Mistral provider implementation #[derive(Debug, Clone)] pub struct MistralProvider; +// Trait implementations that delegate to OpenAI +impl ProviderRequest for MistralProvider { + type Error = OpenAIApiError; + + fn try_from_bytes(&self, bytes: &[u8]) -> Result { + let openai_provider = OpenAIProvider; + ProviderRequest::try_from_bytes(&openai_provider, bytes) + } + + fn to_provider_bytes(&self, request: &ChatCompletionsRequest, provider: super::super::ProviderId, mode: ConversionMode) -> Result, Self::Error> { + let openai_provider = OpenAIProvider; + ProviderRequest::to_provider_bytes(&openai_provider, request, provider, mode) + } + + fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str { + &request.model + } + + fn is_streaming(&self, request: 
&ChatCompletionsRequest) -> bool { + request.stream.unwrap_or_default() + } + + fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) { + let openai_provider = OpenAIProvider; + ProviderRequest::set_streaming_options(&openai_provider, request) + } + + fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String { + let openai_provider = OpenAIProvider; + ProviderRequest::extract_messages_text(&openai_provider, request) + } +} + +impl ProviderResponse for MistralProvider { + type Error = OpenAIApiError; + type Usage = Usage; + + fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result { + let openai_provider = OpenAIProvider; + ProviderResponse::try_from_bytes(&openai_provider, bytes, provider, mode) + } + + fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> { + Some(&response.usage) + } + + fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> { + let openai_provider = OpenAIProvider; + ProviderResponse::extract_usage_counts(&openai_provider, response) + } +} + +impl StreamingResponse for MistralProvider { + type Error = OpenAIApiError; + type StreamingIter = OpenAIStreamingResponse; + + fn try_from_bytes(&self, bytes: &[u8], provider: &super::super::ProviderId, mode: ConversionMode) -> Result { + let openai_provider = OpenAIProvider; + StreamingResponse::try_from_bytes(&openai_provider, bytes, provider, mode) + } +} + impl ProviderInterface for MistralProvider { fn has_compatible_api(&self, api_path: &str) -> bool { matches!(api_path, "/v1/chat/completions") @@ -18,21 +81,21 @@ impl ProviderInterface for MistralProvider { } fn parse_request(&self, bytes: &[u8]) -> Result> { - match ChatCompletionsRequest::try_from_bytes(bytes) { + match ProviderRequest::try_from_bytes(self, bytes) { Ok(req) => Ok(req), Err(e) => Err(Box::new(e)), } } fn parse_response(&self, bytes: &[u8], provider_id: 
super::super::ProviderId, mode: ConversionMode) -> Result> { - match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { Ok(resp) => Ok(resp), Err(e) => Err(Box::new(e)), } } fn request_to_bytes(&self, request: &ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { - match request.to_provider_bytes(provider_id, mode) { + match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { Ok(bytes) => Ok(bytes), Err(e) => Err(Box::new(e)), } diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs index eafebe26..45eb8d6d 100644 --- a/crates/hermesllm/src/providers/mod.rs +++ b/crates/hermesllm/src/providers/mod.rs @@ -134,18 +134,140 @@ impl Provider { Provider::GitHub(_, id) => *id, } } +} - /// Get the provider interface implementation - pub fn interface(&self) -> &dyn ProviderInterface { +// Implement traits directly on the Provider enum +impl ProviderRequest for Provider { + type Error = openai::provider::OpenAIApiError; + + fn try_from_bytes(&self, bytes: &[u8]) -> Result { match self { - Provider::OpenAI(provider, _) => provider, - Provider::Groq(provider, _) => provider, - Provider::Mistral(provider, _) => provider, - Provider::Deepseek(provider, _) => provider, - Provider::Arch(provider, _) => provider, - Provider::Gemini(provider, _) => provider, - Provider::Claude(provider, _) => provider, - Provider::GitHub(provider, _) => provider, + Provider::OpenAI(provider, _) => ProviderRequest::try_from_bytes(provider, bytes), + Provider::Groq(provider, _) => ProviderRequest::try_from_bytes(provider, bytes), + Provider::Mistral(provider, _) => ProviderRequest::try_from_bytes(provider, bytes), + Provider::Deepseek(provider, _) => ProviderRequest::try_from_bytes(provider, bytes), + Provider::Arch(provider, _) => ProviderRequest::try_from_bytes(provider, bytes), + 
Provider::Gemini(provider, _) => ProviderRequest::try_from_bytes(provider, bytes), + Provider::Claude(provider, _) => ProviderRequest::try_from_bytes(provider, bytes), + Provider::GitHub(provider, _) => ProviderRequest::try_from_bytes(provider, bytes), + } + } + + fn to_provider_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result, Self::Error> { + match self { + Provider::OpenAI(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode), + Provider::Groq(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode), + Provider::Mistral(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode), + Provider::Deepseek(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode), + Provider::Arch(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode), + Provider::Gemini(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode), + Provider::Claude(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode), + Provider::GitHub(provider, _) => ProviderRequest::to_provider_bytes(provider, request, provider_id, mode), + } + } + + fn extract_model<'a>(&self, request: &'a crate::apis::openai::ChatCompletionsRequest) -> &'a str { + // Since all providers use the same implementation, just use the first one + &request.model + } + + fn is_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool { + // Since all providers use the same implementation, just use the first one + request.stream.unwrap_or_default() + } + + fn set_streaming_options(&self, request: &mut crate::apis::openai::ChatCompletionsRequest) { + match self { + Provider::OpenAI(provider, _) => ProviderRequest::set_streaming_options(provider, request), + Provider::Groq(provider, _) => 
ProviderRequest::set_streaming_options(provider, request), + Provider::Mistral(provider, _) => ProviderRequest::set_streaming_options(provider, request), + Provider::Deepseek(provider, _) => ProviderRequest::set_streaming_options(provider, request), + Provider::Arch(provider, _) => ProviderRequest::set_streaming_options(provider, request), + Provider::Gemini(provider, _) => ProviderRequest::set_streaming_options(provider, request), + Provider::Claude(provider, _) => ProviderRequest::set_streaming_options(provider, request), + Provider::GitHub(provider, _) => ProviderRequest::set_streaming_options(provider, request), + } + } + + fn extract_messages_text(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String { + match self { + Provider::OpenAI(provider, _) => ProviderRequest::extract_messages_text(provider, request), + Provider::Groq(provider, _) => ProviderRequest::extract_messages_text(provider, request), + Provider::Mistral(provider, _) => ProviderRequest::extract_messages_text(provider, request), + Provider::Deepseek(provider, _) => ProviderRequest::extract_messages_text(provider, request), + Provider::Arch(provider, _) => ProviderRequest::extract_messages_text(provider, request), + Provider::Gemini(provider, _) => ProviderRequest::extract_messages_text(provider, request), + Provider::Claude(provider, _) => ProviderRequest::extract_messages_text(provider, request), + Provider::GitHub(provider, _) => ProviderRequest::extract_messages_text(provider, request), + } + } +} + +impl ProviderResponse for Provider { + type Error = openai::provider::OpenAIApiError; + type Usage = crate::apis::openai::Usage; + + fn try_from_bytes(&self, bytes: &[u8], provider_id: &super::ProviderId, mode: ConversionMode) -> Result { + match self { + Provider::OpenAI(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode), + Provider::Groq(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode), + 
Provider::Mistral(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode), + Provider::Deepseek(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode), + Provider::Arch(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode), + Provider::Gemini(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode), + Provider::Claude(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode), + Provider::GitHub(provider, _) => ProviderResponse::try_from_bytes(provider, bytes, provider_id, mode), + } + } + + fn usage<'a>(&self, response: &'a crate::apis::openai::ChatCompletionsResponse) -> Option<&'a Self::Usage> { + // Since all providers use the same implementation, just use the direct implementation + Some(&response.usage) + } +} + +impl StreamingResponse for Provider { + type Error = openai::provider::OpenAIApiError; + type StreamingIter = openai::provider::OpenAIStreamingResponse; + + fn try_from_bytes(&self, bytes: &[u8], provider_id: &super::ProviderId, mode: ConversionMode) -> Result { + match self { + Provider::OpenAI(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode), + Provider::Groq(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode), + Provider::Mistral(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode), + Provider::Deepseek(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode), + Provider::Arch(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode), + Provider::Gemini(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode), + Provider::Claude(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode), + Provider::GitHub(provider, _) => StreamingResponse::try_from_bytes(provider, bytes, provider_id, mode), + 
} + } +} + +impl ProviderInterface for Provider { + fn has_compatible_api(&self, api_path: &str) -> bool { + match self { + Provider::OpenAI(provider, _) => provider.has_compatible_api(api_path), + Provider::Groq(provider, _) => provider.has_compatible_api(api_path), + Provider::Mistral(provider, _) => provider.has_compatible_api(api_path), + Provider::Deepseek(provider, _) => provider.has_compatible_api(api_path), + Provider::Arch(provider, _) => provider.has_compatible_api(api_path), + Provider::Gemini(provider, _) => provider.has_compatible_api(api_path), + Provider::Claude(provider, _) => provider.has_compatible_api(api_path), + Provider::GitHub(provider, _) => provider.has_compatible_api(api_path), + } + } + + fn supported_apis(&self) -> Vec<&'static str> { + match self { + Provider::OpenAI(provider, _) => provider.supported_apis(), + Provider::Groq(provider, _) => provider.supported_apis(), + Provider::Mistral(provider, _) => provider.supported_apis(), + Provider::Deepseek(provider, _) => provider.supported_apis(), + Provider::Arch(provider, _) => provider.supported_apis(), + Provider::Gemini(provider, _) => provider.supported_apis(), + Provider::Claude(provider, _) => provider.supported_apis(), + Provider::GitHub(provider, _) => provider.supported_apis(), } } } diff --git a/crates/hermesllm/src/providers/openai/provider.rs b/crates/hermesllm/src/providers/openai/provider.rs index a0f0179d..129660e8 100644 --- a/crates/hermesllm/src/providers/openai/provider.rs +++ b/crates/hermesllm/src/providers/openai/provider.rs @@ -77,64 +77,58 @@ impl ProviderInterface for OpenAIProvider { } fn parse_request(&self, bytes: &[u8]) -> Result> { - use crate::providers::traits::ProviderRequest; - match ChatCompletionsRequest::try_from_bytes(bytes) { + match ProviderRequest::try_from_bytes(self, bytes) { Ok(req) => Ok(req), Err(e) => Err(Box::new(e)), } } fn parse_response(&self, bytes: &[u8], provider_id: super::super::ProviderId, mode: ConversionMode) -> Result> { - use 
crate::providers::traits::ProviderResponse; - match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) { + match ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode) { Ok(resp) => Ok(resp), Err(e) => Err(Box::new(e)), } } fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::super::ProviderId, mode: ConversionMode) -> Result, Box> { - use crate::providers::traits::ProviderRequest; - match request.to_provider_bytes(provider_id, mode) { + match ProviderRequest::to_provider_bytes(self, request, provider_id, mode) { Ok(bytes) => Ok(bytes), Err(e) => Err(Box::new(e)), } } } -// ============================================================================ -// Trait Implementations for OpenAI Types -// ============================================================================ - -impl ProviderRequest for ChatCompletionsRequest { +// Direct trait implementations on OpenAIProvider +impl ProviderRequest for OpenAIProvider { type Error = OpenAIApiError; - fn try_from_bytes(bytes: &[u8]) -> Result { + fn try_from_bytes(&self, bytes: &[u8]) -> Result { let s = std::str::from_utf8(bytes)?; Ok(serde_json::from_str(s)?) } - fn to_provider_bytes(&self, _provider: super::super::ProviderId, _mode: ConversionMode) -> Result, Self::Error> { - Ok(serde_json::to_vec(self)?) + fn to_provider_bytes(&self, request: &ChatCompletionsRequest, _provider: super::super::ProviderId, _mode: ConversionMode) -> Result, Self::Error> { + Ok(serde_json::to_vec(request)?) 
} - fn extract_model(&self) -> &str { - &self.model + fn extract_model<'a>(&self, request: &'a ChatCompletionsRequest) -> &'a str { + &request.model } - fn is_streaming(&self) -> bool { - self.stream.unwrap_or_default() + fn is_streaming(&self, request: &ChatCompletionsRequest) -> bool { + request.stream.unwrap_or_default() } - fn set_streaming_options(&mut self) { - if self.stream_options.is_none() { - self.stream_options = Some(StreamOptions { + fn set_streaming_options(&self, request: &mut ChatCompletionsRequest) { + if request.stream_options.is_none() { + request.stream_options = Some(StreamOptions { include_usage: Some(true), }); } } - fn extract_messages_text(&self) -> String { - self.messages + fn extract_messages_text(&self, request: &ChatCompletionsRequest) -> String { + request.messages .iter() .fold(String::new(), |acc, m| { acc + " " + &match &m.content { @@ -150,8 +144,41 @@ impl ProviderRequest for ChatCompletionsRequest { } } -// Implement the helper trait for stream context integration -impl crate::providers::traits::StreamContextHelpers for ChatCompletionsRequest {} +impl ProviderResponse for OpenAIProvider { + type Error = OpenAIApiError; + type Usage = Usage; + + fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result { + let s = std::str::from_utf8(bytes)?; + Ok(serde_json::from_str(s)?) 
+ } + + fn usage<'a>(&self, response: &'a ChatCompletionsResponse) -> Option<&'a Self::Usage> { + Some(&response.usage) + } + + fn extract_usage_counts(&self, response: &ChatCompletionsResponse) -> Option<(usize, usize, usize)> { + Some(( + response.usage.prompt_tokens as usize, + response.usage.completion_tokens as usize, + response.usage.total_tokens as usize, + )) + } +} + +impl StreamingResponse for OpenAIProvider { + type Error = OpenAIApiError; + type StreamingIter = OpenAIStreamingResponse; + + fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result { + let s = std::str::from_utf8(bytes)?; + Ok(OpenAIStreamingResponse::new(s.to_string())) + } +} + +// ============================================================================ +// Trait Implementations for OpenAI Types (Keep for TokenUsage only) +// ============================================================================ impl TokenUsage for Usage { fn completion_tokens(&self) -> usize { @@ -167,20 +194,6 @@ impl TokenUsage for Usage { } } -impl ProviderResponse for ChatCompletionsResponse { - type Error = OpenAIApiError; - type Usage = Usage; - - fn try_from_bytes(bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result { - let s = std::str::from_utf8(bytes)?; - Ok(serde_json::from_str(s)?) 
- } - - fn usage(&self) -> Option<&Self::Usage> { - Some(&self.usage) - } -} - impl StreamChunk for ChatCompletionsStreamResponse { type Usage = Usage; @@ -191,9 +204,9 @@ impl StreamChunk for ChatCompletionsStreamResponse { impl StreamingResponse for OpenAIStreamingResponse { type Error = OpenAIApiError; - type Chunk = ChatCompletionsStreamResponse; + type StreamingIter = OpenAIStreamingResponse; - fn try_from_bytes(bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result { + fn try_from_bytes(&self, bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result { let s = std::str::from_utf8(bytes)?; Ok(OpenAIStreamingResponse::new(s.to_string())) } diff --git a/crates/hermesllm/src/providers/traits.rs b/crates/hermesllm/src/providers/traits.rs index 792c8740..b1d1628b 100644 --- a/crates/hermesllm/src/providers/traits.rs +++ b/crates/hermesllm/src/providers/traits.rs @@ -15,26 +15,26 @@ pub enum ConversionMode { } /// Trait for provider-specific request types -pub trait ProviderRequest: Sized { +pub trait ProviderRequest { type Error: Error + Send + Sync + 'static; /// Parse request from raw bytes - fn try_from_bytes(bytes: &[u8]) -> Result; + fn try_from_bytes(&self, bytes: &[u8]) -> Result; /// Convert to provider-specific format - fn to_provider_bytes(&self, provider: super::ProviderId, mode: ConversionMode) -> Result, Self::Error>; + fn to_provider_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider: super::ProviderId, mode: ConversionMode) -> Result, Self::Error>; /// Extract the model name from the request - fn extract_model(&self) -> &str; + fn extract_model<'a>(&self, request: &'a crate::apis::openai::ChatCompletionsRequest) -> &'a str; /// Check if this is a streaming request - fn is_streaming(&self) -> bool; + fn is_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool; /// Set streaming options (e.g., include_usage) - fn set_streaming_options(&mut 
self); + fn set_streaming_options(&self, request: &mut crate::apis::openai::ChatCompletionsRequest); /// Extract text content from messages for token counting - fn extract_messages_text(&self) -> String; + fn extract_messages_text(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String; } /// Trait for token usage information @@ -45,39 +45,19 @@ pub trait TokenUsage { } /// Trait for provider-specific response types -pub trait ProviderResponse: Sized { +pub trait ProviderResponse { type Error: Error + Send + Sync + 'static; type Usage: TokenUsage; /// Parse response from raw bytes - fn try_from_bytes(bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result; + fn try_from_bytes(&self, bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result; /// Get usage information if available - fn usage(&self) -> Option<&Self::Usage>; + fn usage<'a>(&self, response: &'a crate::apis::openai::ChatCompletionsResponse) -> Option<&'a Self::Usage>; /// Extract token counts for metrics - fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { - self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens())) - } -} - -/// Helper trait for stream context integration -pub trait StreamContextHelpers: ProviderRequest { - /// Get the model name for routing and metrics - fn get_model_for_routing(&self) -> String { - self.extract_model().to_string() - } - - /// Get text for token counting and rate limiting - fn get_text_for_tokenization(&self) -> String { - self.extract_messages_text() - } - - /// Prepare for streaming by setting appropriate options - fn prepare_for_streaming(&mut self) { - if self.is_streaming() { - self.set_streaming_options(); - } + fn extract_usage_counts(&self, response: &crate::apis::openai::ChatCompletionsResponse) -> Option<(usize, usize, usize)> { + self.usage(response).map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens())) } } @@ -90,70 +70,34 @@ pub trait StreamChunk { 
} /// Trait for streaming response iterators -pub trait StreamingResponse: Iterator> + Sized { +pub trait StreamingResponse { type Error: Error + Send + Sync + 'static; - type Chunk: StreamChunk; + type StreamingIter: Iterator>; /// Parse streaming response from raw bytes - fn try_from_bytes(bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result; + fn try_from_bytes(&self, bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result; } -/// Main provider interface trait -pub trait ProviderInterface { +/// Main provider interface trait - simplified to only essential methods +pub trait ProviderInterface: ProviderRequest + ProviderResponse + StreamingResponse { /// Check if this provider has a compatible API with the client request fn has_compatible_api(&self, api_path: &str) -> bool; - /// Get the interface implementation for this provider - /// passthrough: if true, use provider-specific format; if false, use compatible format - fn get_interface(&self, passthrough: bool) -> ConversionMode { - if passthrough { - ConversionMode::Passthrough - } else { - ConversionMode::Compatible - } - } - - /// Parse a request from raw bytes - returns concrete ChatCompletionsRequest - fn parse_request(&self, bytes: &[u8]) -> Result>; - - /// Parse a response from raw bytes - returns concrete ChatCompletionsResponse - fn parse_response(&self, bytes: &[u8], provider_id: super::ProviderId, mode: ConversionMode) -> Result>; - - /// Convert a request to bytes for sending to upstream API - fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result, Box>; - - /// Extract model name from request for routing (convenience method for stream_context) - fn extract_model_from_request(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String { - use ProviderRequest; - request.extract_model().to_string() - } - - /// Check if request is streaming (convenience method 
for stream_context) - fn is_request_streaming(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> bool { - use ProviderRequest; - request.is_streaming() - } - - /// Prepare request for streaming (convenience method for stream_context) - fn prepare_request_for_streaming(&self, request: &mut crate::apis::openai::ChatCompletionsRequest) { - use ProviderRequest; - if request.is_streaming() { - request.set_streaming_options(); - } - } - - /// Extract text for tokenization (convenience method for stream_context) - fn extract_text_for_tokenization(&self, request: &crate::apis::openai::ChatCompletionsRequest) -> String { - use ProviderRequest; - request.extract_messages_text() - } - - /// Extract usage information from response (convenience method for stream_context) - fn extract_usage_from_response(&self, response: &crate::apis::openai::ChatCompletionsResponse) -> Option<(usize, usize, usize)> { - use ProviderResponse; - response.extract_usage_counts() - } - /// Get supported API endpoints for this provider fn supported_apis(&self) -> Vec<&'static str>; + + /// Parse a request from raw bytes - delegates to ProviderRequest + fn parse_request(&self, bytes: &[u8]) -> Result> { + ProviderRequest::try_from_bytes(self, bytes).map_err(|e| Box::new(e) as Box) + } + + /// Parse a response from raw bytes - delegates to ProviderResponse + fn parse_response(&self, bytes: &[u8], provider_id: super::ProviderId, mode: ConversionMode) -> Result> { + ProviderResponse::try_from_bytes(self, bytes, &provider_id, mode).map_err(|e| Box::new(e) as Box) + } + + /// Convert a request to bytes - delegates to ProviderRequest + fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, provider_id: super::ProviderId, mode: ConversionMode) -> Result, Box> { + ProviderRequest::to_provider_bytes(self, request, provider_id, mode).map_err(|e| Box::new(e) as Box) + } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 
f8eb051d..9a483d31 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -10,6 +10,10 @@ use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; +use hermesllm::apis::openai::{ContentPart, MessageContent}; +use hermesllm::providers::traits::{ + ProviderRequest, ProviderResponse, StreamChunk, StreamingResponse, TokenUsage, +}; use hermesllm::{ConversionMode, Provider, ProviderId}; use http::StatusCode; use log::{debug, info, warn}; @@ -39,6 +43,7 @@ pub struct StreamContext { request_body_sent_time: Option, traces_queue: Arc>>, overrides: Rc>, + user_message: Option, } impl StreamContext { @@ -66,6 +71,7 @@ impl StreamContext { ttft_time: None, traces_queue, request_body_sent_time: None, + user_message: None, } } fn llm_provider(&self) -> &LlmProvider { @@ -295,7 +301,7 @@ impl HttpContext for StreamContext { let provider = self.get_provider(); - let mut deserialized_body = match provider.interface().parse_request(&body_bytes) { + let mut deserialized_body = match ProviderRequest::try_from_bytes(&provider, &body_bytes) { Ok(deserialized) => deserialized, Err(e) => { debug!( @@ -324,9 +330,29 @@ impl HttpContext for StreamContext { }; // Use the provider interface methods for cleaner interaction - let model_requested = provider - .interface() - .extract_model_from_request(&deserialized_body); + let model_requested = provider.extract_model(&deserialized_body).to_string(); // Convert to owned string + + // Extract user message for tracing + self.user_message = deserialized_body.messages.last().and_then(|msg| { + match &msg.content { + MessageContent::Text(text) => Some(text.clone()), + MessageContent::Parts(parts) => { + // Extract text from content parts, ignoring images + let text_parts: Vec = parts + .iter() + .filter_map(|part| match part { + ContentPart::Text { text } => 
Some(text.clone()), + ContentPart::ImageUrl { .. } => None, + }) + .collect(); + if text_parts.is_empty() { + None + } else { + Some(text_parts.join(" ")) + } + } + } + }); info!( "on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}", @@ -336,20 +362,15 @@ impl HttpContext for StreamContext { ); // Use provider interface for streaming detection and setup - if provider - .interface() - .is_request_streaming(&deserialized_body) - { - self.streaming_response = true; - provider - .interface() - .prepare_request_for_streaming(&mut deserialized_body); + self.streaming_response = provider.is_streaming(&deserialized_body); + + // Set streaming options if needed + if self.streaming_response { + provider.set_streaming_options(&mut deserialized_body); } - // Use provider interface for text extraction - let input_tokens_str = provider - .interface() - .extract_text_for_tokenization(&deserialized_body); + // Use provider interface for text extraction (after potential mutation) + let input_tokens_str = provider.extract_messages_text(&deserialized_body); // enforce ratelimits on ingress if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) { self.send_server_error( @@ -364,7 +385,7 @@ impl HttpContext for StreamContext { let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str()); // Convert chat completion request to llm provider specific request using provider interface - let deserialized_body_bytes = match provider.interface().request_to_bytes( + let deserialized_body_bytes = match provider.to_provider_bytes( &deserialized_body, provider.id(), ConversionMode::Compatible, @@ -473,6 +494,11 @@ impl HttpContext for StreamContext { self.llm_provider().name.to_string(), ); + if let Some(user_message) = &self.user_message { + llm_span + .add_attribute("user_message".to_string(), user_message.clone()); + } + if self.ttft_time.is_some() { llm_span.add_event(Event::new( "time_to_first_token".to_string(), @@ 
-540,36 +566,74 @@ impl HttpContext for StreamContext { let _provider_id = ProviderId::from(llm_provider_str.as_str()); if self.streaming_response { - // TODO: Implement streaming response parsing with new provider structure - warn!( - "Streaming response parsing not yet fully implemented with new provider structure" - ); + debug!("processing streaming response"); - // For now, just compute TTFT and continue - if self.ttft_duration.is_none() { - let current_time = get_current_time().unwrap(); - self.ttft_time = Some(current_time_ns()); - match current_time.duration_since(self.start_time) { - Ok(duration) => { - let duration_ms = duration.as_millis(); - info!( - "on_http_response_body: time to first token: {}ms", - duration_ms - ); - self.ttft_duration = Some(duration); - self.metrics.time_to_first_token.record(duration_ms as u64); - } - Err(e) => { - warn!("SystemTime error: {:?}", e); + // Parse streaming response using OpenAI-compatible format + // Since all providers use OpenAI-compatible streaming format + let provider = self.get_provider(); + let provider_id = + ProviderId::from(self.llm_provider().provider_interface.to_string().as_str()); + + match StreamingResponse::try_from_bytes( + &provider, + &body, + &provider_id, + ConversionMode::Compatible, + ) { + Ok(mut streaming_response) => { + // Process each streaming chunk + while let Some(chunk_result) = streaming_response.next() { + match chunk_result { + Ok(chunk) => { + // Compute TTFT on first chunk + if self.ttft_duration.is_none() { + let current_time = get_current_time().unwrap(); + self.ttft_time = Some(current_time_ns()); + match current_time.duration_since(self.start_time) { + Ok(duration) => { + let duration_ms = duration.as_millis(); + info!( + "on_http_response_body: time to first token: {}ms", + duration_ms + ); + self.ttft_duration = Some(duration); + self.metrics + .time_to_first_token + .record(duration_ms as u64); + } + Err(e) => { + warn!("SystemTime error: {:?}", e); + } + } + } + + // 
Extract usage information if available + if let Some(usage) = chunk.usage() { + let completion_tokens = usage.completion_tokens(); + self.response_tokens += completion_tokens; + debug!( + "Streaming chunk completion tokens: {}", + completion_tokens + ); + } + } + Err(e) => { + warn!("Error processing streaming chunk: {}", e); + } + } } } + Err(e) => { + warn!("Failed to parse streaming response: {}", e); + } } } else { debug!("non streaming response"); let provider = self.get_provider(); - let response = match provider.interface().parse_response( + let response = match ProviderResponse::try_from_bytes( + &provider, &body, - provider.id(), + &provider.id(), ConversionMode::Compatible, ) { Ok(response) => response, @@ -594,7 +658,7 @@ impl HttpContext for StreamContext { // Use provider interface to extract usage information if let Some((prompt_tokens, completion_tokens, total_tokens)) = - provider.interface().extract_usage_from_response(&response) + provider.extract_usage_counts(&response) { debug!( "Response usage: prompt={}, completion={}, total={}",