From ecf453ed70e644890a872ba74de493f3a2e94e09 Mon Sep 17 00:00:00 2001 From: Salman Paracha Date: Thu, 4 Sep 2025 15:13:53 -0700 Subject: [PATCH] /v1/messages works with transformations to and from /v1/chat/completions --- crates/common/src/configuration.rs | 6 +- crates/hermesllm/src/apis/anthropic.rs | 53 +-- crates/hermesllm/src/apis/openai.rs | 65 +-- crates/hermesllm/src/clients/endpoints.rs | 2 +- crates/hermesllm/src/lib.rs | 43 +- crates/hermesllm/src/providers/adapters.rs | 2 +- crates/hermesllm/src/providers/id.rs | 10 +- crates/hermesllm/src/providers/mod.rs | 2 +- crates/hermesllm/src/providers/response.rs | 464 +++++++++++++++++---- crates/llm_gateway/src/stream_context.rs | 102 +++-- 10 files changed, 495 insertions(+), 254 deletions(-) diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 20d2623b..93f4fd38 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -149,8 +149,8 @@ pub struct EmbeddingProviver { pub enum LlmProviderType { #[serde(rename = "arch")] Arch, - #[serde(rename = "claude")] - Claude, + #[serde(rename = "anthropic")] + Anthropic, #[serde(rename = "deepseek")] Deepseek, #[serde(rename = "groq")] @@ -167,7 +167,7 @@ impl Display for LlmProviderType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { LlmProviderType::Arch => write!(f, "arch"), - LlmProviderType::Claude => write!(f, "claude"), + LlmProviderType::Anthropic => write!(f, "anthropic"), LlmProviderType::Deepseek => write!(f, "deepseek"), LlmProviderType::Groq => write!(f, "groq"), LlmProviderType::Gemini => write!(f, "gemini"), diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs index c457da41..d83e4968 100644 --- a/crates/hermesllm/src/apis/anthropic.rs +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -3,11 +3,10 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use serde_with::skip_serializing_none; use std::collections::HashMap; -use std::error::Error; use super::ApiDefinition; use crate::providers::request::{ProviderRequest, ProviderRequestError}; -use crate::providers::response::{ProviderStreamResponse, SseStreamIter}; +use crate::providers::response::ProviderStreamResponse; use crate::clients::transformer::ExtractText; // Enum for all supported Anthropic APIs @@ -520,55 +519,6 @@ impl MessagesRole { } } -// Anthropic SSE streaming iterator for MessagesStreamEvent -pub struct AnthropicSseIter -where - I: Iterator, - I::Item: AsRef, -{ - sse_stream: SseStreamIter, -} - -impl AnthropicSseIter -where - I: Iterator, - I::Item: AsRef, -{ - pub fn new(sse_stream: SseStreamIter) -> Self { - Self { sse_stream } - } -} - -impl Iterator for AnthropicSseIter -where - I: Iterator, - I::Item: AsRef, -{ - type Item = Result, Box>; - - fn next(&mut self) -> Option { - for line in &mut self.sse_stream.lines { - let line = line.as_ref(); - if line.is_empty() { - continue; - } - - if line.starts_with("data: ") { - let data = &line[6..]; - if data == "[DONE]" { - return None; - } - // Anthropic-specific parsing of MessagesStreamEvent - match serde_json::from_str::(data) { - Ok(event) => return Some(Ok(Box::new(event))), - Err(e) => return Some(Err(Box::new(e))), - } - } - } - None - } -} - // Implement ProviderStreamResponse for MessagesStreamEvent impl ProviderStreamResponse for MessagesStreamEvent { fn content_delta(&self) -> Option<&str> { @@ -594,6 +544,7 @@ impl ProviderStreamResponse for MessagesStreamEvent { _ => None, } } + } #[cfg(test)] diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index f2534961..287b1cde 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -6,7 +6,7 @@ use std::fmt::Display; use thiserror::Error; use crate::providers::request::{ProviderRequest, ProviderRequestError}; -use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage, SseStreamIter}; +use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage}; use super::ApiDefinition; use crate::clients::transformer::{ExtractText}; @@ -615,68 +615,6 @@ impl ProviderResponse for ChatCompletionsResponse { } } -// ============================================================================ -// OPENAI SSE STREAMING ITERATOR -// ============================================================================ - -/// OpenAI-specific SSE streaming iterator -/// Handles OpenAI's specific SSE format and ChatCompletionsStreamResponse parsing -pub struct OpenAISseIter -where - I: Iterator, - I::Item: AsRef, -{ - sse_stream: SseStreamIter, -} - -impl OpenAISseIter -where - I: Iterator, - I::Item: AsRef, -{ - pub fn new(sse_stream: SseStreamIter) -> Self { - Self { sse_stream } - } -} - -impl Iterator for OpenAISseIter -where - I: Iterator, - I::Item: AsRef, -{ - type Item = Result, Box>; - - fn next(&mut self) -> Option { - for line in &mut self.sse_stream.lines { - let line = line.as_ref(); - if line.is_empty() { - continue; - } - - if line.starts_with("data: ") { - let data = &line[6..]; // Remove "data: " prefix - if data == "[DONE]" { - return None; - } - - // Skip ping messages (usually from other providers, but handle gracefully) - if data == r#"{"type": "ping"}"# { - continue; - } - - // OpenAI-specific parsing of ChatCompletionsStreamResponse - match serde_json::from_str::(data) { - Ok(response) => return Some(Ok(Box::new(response))), - Err(e) => return Some(Err(Box::new( - OpenAIStreamError::InvalidStreamingData(format!("Error parsing OpenAI streaming data: {}, data: {}", e, data)) - ))), - } - } - } - None - } -} - // Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse impl ProviderStreamResponse for ChatCompletionsStreamResponse { fn content_delta(&self) -> Option<&str> { @@ -702,6 +640,7 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse { Role::Tool => "tool", })) } + } diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs index ca751fbd..7f04bd76 100644 --- a/crates/hermesllm/src/clients/endpoints.rs +++ b/crates/hermesllm/src/clients/endpoints.rs @@ -57,7 +57,7 @@ impl SupportedAPIs { match self { SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => { match provider_id { - ProviderId::Claude => "/v1/messages".to_string(), + ProviderId::Anthropic => "/v1/messages".to_string(), _ => default_endpoint, } } diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index d1ad8c7c..26348378 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -7,7 +7,7 @@ pub mod clients; // Re-export important types and traits pub use providers::request::{ProviderRequestType, ProviderRequest, ProviderRequestError}; -pub use providers::response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, ProviderResponseError, TokenUsage}; +pub use providers::response::{ProviderResponseType, ProviderStreamResponseType, ProviderResponse, ProviderStreamResponse, ProviderResponseError, TokenUsage, SseEvent, SseStreamIter}; pub use providers::id::ProviderId; pub use providers::adapters::{has_compatible_api, supported_apis}; @@ -68,29 +68,38 @@ mod tests { // Test streaming response parsing with sample SSE data let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} -data: [DONE] -"#; + data: [DONE] + "#; use crate::clients::endpoints::SupportedAPIs; let api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); - let result = ProviderStreamResponseIter::try_from((sse_data.as_bytes(), &api, &ProviderId::OpenAI)); - assert!(result.is_ok()); - let mut streaming_response = result.unwrap(); + // Test the new simplified architecture - create SseStreamIter directly + let sse_iter = SseStreamIter::try_from(sse_data.as_bytes()); + assert!(sse_iter.is_ok()); - // Test that we can iterate over chunks - it's just an iterator now! - let first_chunk = streaming_response.next(); - assert!(first_chunk.is_some()); + let mut streaming_iter = sse_iter.unwrap(); - let chunk_result = first_chunk.unwrap(); - assert!(chunk_result.is_ok()); + // Test that we can iterate over SseEvents + let first_event = streaming_iter.next(); + assert!(first_event.is_some()); - let chunk = chunk_result.unwrap(); - assert_eq!(chunk.content_delta(), Some("Hello")); - assert!(!chunk.is_final()); + let sse_event = first_event.unwrap(); - // Test that stream ends properly - let final_chunk = streaming_response.next(); - assert!(final_chunk.is_none()); + // Test SseEvent properties + assert!(!sse_event.is_done()); + assert!(sse_event.data.contains("Hello")); + + // Test that we can parse the event into a provider stream response + let provider_response = sse_event.to_provider_stream_response(&api); + assert!(provider_response.is_ok()); + + let stream_response = provider_response.unwrap(); + assert_eq!(stream_response.content_delta(), Some("Hello")); + assert!(!stream_response.is_final()); + + // Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE]) + let final_event = streaming_iter.next(); + assert!(final_event.is_none()); // Should be None because iterator stops at [DONE] } } diff --git a/crates/hermesllm/src/providers/adapters.rs b/crates/hermesllm/src/providers/adapters.rs index 17b57dc5..09bf7108 100644 --- a/crates/hermesllm/src/providers/adapters.rs +++ b/crates/hermesllm/src/providers/adapters.rs @@ -47,7 +47,7 @@ pub fn get_provider_config(provider_id: &ProviderId) -> ProviderConfig { adapter_type: AdapterType::OpenAICompatible, } } - ProviderId::Claude => { + ProviderId::Anthropic => { ProviderConfig { supported_apis: &["/v1/messages", "/v1/chat/completions"], adapter_type: AdapterType::AnthropicCompatible, diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs index e2f64be7..6ffb0e71 100644 --- a/crates/hermesllm/src/providers/id.rs +++ b/crates/hermesllm/src/providers/id.rs @@ -10,7 +10,7 @@ pub enum ProviderId { Deepseek, Groq, Gemini, - Claude, + Anthropic, GitHub, Arch, } @@ -23,7 +23,7 @@ impl From<&str> for ProviderId { "deepseek" => ProviderId::Deepseek, "groq" => ProviderId::Groq, "gemini" => ProviderId::Gemini, - "claude" => ProviderId::Claude, + "anthropic" => ProviderId::Anthropic, "github" => ProviderId::GitHub, "arch" => ProviderId::Arch, _ => panic!("Unknown provider: {}", value), @@ -36,9 +36,9 @@ impl ProviderId { pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs) -> SupportedAPIs { match (self, client_api) { // Claude/Anthropic providers natively support Anthropic APIs - (ProviderId::Claude, SupportedAPIs::AnthropicMessagesAPI(_)) => client_api.clone(), + (ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => client_api.clone(), // Claude/Anthropic providers can also support OpenAI chat completions by mapping to Anthropic Messages - (ProviderId::Claude, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), + (ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), // OpenAI-compatible providers only support OpenAI chat completions (ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), @@ -55,7 +55,7 @@ impl Display for ProviderId { ProviderId::Deepseek => write!(f, "Deepseek"), ProviderId::Groq => write!(f, "Groq"), ProviderId::Gemini => write!(f, "Gemini"), - ProviderId::Claude => write!(f, "Claude"), + ProviderId::Anthropic => write!(f, "Anthropic"), ProviderId::GitHub => write!(f, "GitHub"), ProviderId::Arch => write!(f, "Arch"), } diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs index 4abccc0c..6b9a3e9c 100644 --- a/crates/hermesllm/src/providers/mod.rs +++ b/crates/hermesllm/src/providers/mod.rs @@ -10,5 +10,5 @@ pub mod adapters; pub use id::ProviderId; pub use request::{ProviderRequestType, ProviderRequest, ProviderRequestError} ; -pub use response::{ProviderResponseType, ProviderStreamResponseIter, ProviderResponse, ProviderStreamResponse, TokenUsage }; +pub use response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, TokenUsage }; pub use adapters::*; diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index a153962b..077d54ba 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -1,29 +1,152 @@ use crate::providers::id::ProviderId; -use serde::Serialize; +use serde::{Serialize, Deserialize}; use std::error::Error; use std::fmt; use std::convert::TryFrom; +use std::str::FromStr; use crate::apis::openai::ChatCompletionsResponse; -use crate::apis::OpenAISseIter; use crate::clients::endpoints::SupportedAPIs; -use crate::apis::anthropic::AnthropicSseIter; use crate::apis::anthropic::MessagesResponse; -#[derive(Serialize)] +#[derive(Serialize, Debug, Clone)] #[serde(untagged)] pub enum ProviderResponseType { ChatCompletionsResponse(ChatCompletionsResponse), MessagesResponse(MessagesResponse), } - - -pub enum ProviderStreamResponseIter { - ChatCompletionsStream(OpenAISseIter>), - MessagesStream(AnthropicSseIter>), +#[derive(Serialize, Debug, Clone)] +#[serde(untagged)] +pub enum ProviderStreamResponseType { + ChatCompletionsStreamResponse(crate::apis::openai::ChatCompletionsStreamResponse), + MessagesStreamEvent(crate::apis::anthropic::MessagesStreamEvent), } +pub trait ProviderResponse: Send + Sync { + /// Get usage information if available - returns dynamic trait object + fn usage(&self) -> Option<&dyn TokenUsage>; + + /// Extract token counts for metrics + fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { + self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens())) + } +} + +impl ProviderResponse for ProviderResponseType { + fn usage(&self) -> Option<&dyn TokenUsage> { + match self { + ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(), + ProviderResponseType::MessagesResponse(resp) => resp.usage(), + } + } + + fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { + match self { + ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(), + ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(), + } + } +} + +pub trait ProviderStreamResponse: Send + Sync { + /// Get the content delta for this chunk + fn content_delta(&self) -> Option<&str>; + + /// Check if this is the final chunk in the stream + fn is_final(&self) -> bool; + + /// Get role information if available + fn role(&self) -> Option<&str>; + +} + +// ============================================================================ +// SSE EVENT CONTAINER +// ============================================================================ + +/// Represents a single Server-Sent Event with the complete wire format +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SseEvent { + #[serde(rename = "data")] + pub data: String, // The JSON payload after "data: " + + #[serde(skip_serializing, skip_deserializing)] + pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n" + + #[serde(skip_serializing, skip_deserializing)] + pub provider_stream_response: Option, // Parsed provider stream response object +} + +impl SseEvent { + /// Check if this event represents the end of the stream + pub fn is_done(&self) -> bool { + self.data == "[DONE]" + } + + /// Check if this event should be skipped during processing + /// This includes ping messages and other provider-specific events that don't contain content + pub fn should_skip(&self) -> bool { + // Skip ping messages (commonly used by providers for connection keep-alive) + self.data == r#"{"type": "ping"}"# + } + + /// Get the parsed provider response if available + pub fn provider_response(&self) -> Option<&ProviderStreamResponseType> { + self.provider_stream_response.as_ref() + } + + /// Parse the data field into a ProviderStreamResponse for the given API + pub fn to_provider_stream_response(&self, client_api: &SupportedAPIs) -> Result, Box> { + if self.is_done() { + return Err("Cannot parse [DONE] event as ProviderStreamResponse".into()); + } + + match client_api { + SupportedAPIs::OpenAIChatCompletions(_) => { + let response: crate::apis::openai::ChatCompletionsStreamResponse = + serde_json::from_str(&self.data)?; + Ok(Box::new(response)) + } + SupportedAPIs::AnthropicMessagesAPI(_) => { + let response: crate::apis::anthropic::MessagesStreamEvent = + serde_json::from_str(&self.data)?; + Ok(Box::new(response)) + } + } + } +} + +impl FromStr for SseEvent { + type Err = SseParseError; + + fn from_str(line: &str) -> Result { + if line.starts_with("data: ") { + let data = line[6..].to_string(); // Remove "data: " prefix + if data.is_empty() { + return Err(SseParseError { + message: "Empty data field is not a valid SSE event".to_string(), + }); + } + // [DONE] marker is a valid SSE event that indicates end of stream + Ok(SseEvent { + data, + raw_line: format!("{}\n\n", line), // Store complete SSE format + provider_stream_response: None, // Will be populated later via TryFrom + }) + } else { + Err(SseParseError { + message: format!("Line does not start with 'data: ': {}", line), + }) + } + } +} + +impl fmt::Display for SseEvent { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.raw_line) + } +} // --- Response transformation logic for client API compatibility --- impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { @@ -73,83 +196,141 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { } } -impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderStreamResponseIter { +// Stream response transformation logic for client API compatibility +impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderStreamResponseType { type Error = Box; fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result { let upstream_api = provider_id.compatible_api_for_client(client_api); + + // Step 1: Parse bytes using upstream API format (what the provider actually sent) + // Step 2: Return response type that matches client API format (what client expects) match (&upstream_api, client_api) { + // Upstream sent OpenAI format, client expects OpenAI format - direct pass-through (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - let s = std::str::from_utf8(bytes)?; - let lines: Vec = s.lines().map(|line| line.to_string()).collect(); - let sse_container = crate::providers::response::SseStreamIter::new(lines.into_iter()); - let iter = crate::apis::openai::OpenAISseIter::new(sse_container); - Ok(ProviderStreamResponseIter::ChatCompletionsStream(iter)) + let resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?; + Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(resp)) } + // Upstream sent Anthropic format, client expects Anthropic format - direct pass-through (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { - let s = std::str::from_utf8(bytes)?; - let lines: Vec = s.lines().map(|line| line.to_string()).collect(); - let sse_container = crate::providers::response::SseStreamIter::new(lines.into_iter()); - let iter = crate::apis::anthropic::AnthropicSseIter::new(sse_container); - Ok(ProviderStreamResponseIter::MessagesStream(iter)) + let resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?; + Ok(ProviderStreamResponseType::MessagesStreamEvent(resp)) } + // Upstream sent Anthropic format, client expects OpenAI format - need transformation + (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + // Parse as Anthropic Messages stream event first + let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?; + + // Transform to OpenAI ChatCompletions stream format using the transformer + let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = anthropic_resp.try_into()?; + Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(chat_resp)) + } + // Upstream sent OpenAI format, client expects Anthropic format - need transformation (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { - let s = std::str::from_utf8(bytes)?; - let lines: Vec = s.lines().map(|line| line.to_string()).collect(); - let sse_container = crate::providers::response::SseStreamIter::new(lines.into_iter()); - let iter = crate::apis::anthropic::AnthropicSseIter::new(sse_container); - Ok(ProviderStreamResponseIter::MessagesStream(iter)) + // Parse as OpenAI ChatCompletions stream response first + let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?; + + // Transform to Anthropic Messages stream format using the transformer + let messages_resp: crate::apis::anthropic::MessagesStreamEvent = openai_resp.try_into()?; + Ok(ProviderStreamResponseType::MessagesStreamEvent(messages_resp)) + } + } + } +} + +// TryFrom implementation to convert raw bytes to SseEvent with parsed provider response +impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for SseEvent { + type Error = Box; + + fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result { + // Convert bytes to string + let body_str = std::str::from_utf8(bytes)?; + let mut sse_event: SseEvent = body_str.parse()?; + + // If not [DONE], parse the data as a provider stream response (business logic layer) + if !sse_event.is_done() { + // Use the new ProviderStreamResponseType::try_from to parse the JSON data + let provider_response = ProviderStreamResponseType::try_from((sse_event.data.as_bytes(), client_api, provider_id))?; + sse_event.provider_stream_response = Some(provider_response); + } + + Ok(sse_event) + } +} + +// TryFrom implementation for transforming SseEvent between API formats +impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent { + type Error = Box; + + fn try_from((mut event, upstream_api, client_api): (SseEvent, &SupportedAPIs, &SupportedAPIs)) -> Result { + // If APIs are the same, no transformation needed + if std::mem::discriminant(upstream_api) == std::mem::discriminant(client_api) { + return Ok(event); + } + + // Handle [DONE] events - they don't need transformation + if event.is_done() { + return Ok(event); + } + + // Transform the data field based on API conversion + let transformed_data = match (upstream_api, client_api) { + (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + // Parse OpenAI response and convert to Anthropic + let openai_response: crate::apis::openai::ChatCompletionsStreamResponse = + serde_json::from_str(&event.data)?; + let anthropic_response: crate::apis::anthropic::MessagesStreamEvent = + openai_response.try_into()?; + serde_json::to_string(&anthropic_response)? } (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - let s = std::str::from_utf8(bytes)?; - let lines: Vec = s.lines().map(|line| line.to_string()).collect(); - let sse_container = crate::providers::response::SseStreamIter::new(lines.into_iter()); - let iter = crate::apis::openai::OpenAISseIter::new(sse_container); - Ok(ProviderStreamResponseIter::ChatCompletionsStream(iter)) + // Parse Anthropic response and convert to OpenAI + let anthropic_response: crate::apis::anthropic::MessagesStreamEvent = + serde_json::from_str(&event.data)?; + let openai_response: crate::apis::openai::ChatCompletionsStreamResponse = + anthropic_response.try_into()?; + serde_json::to_string(&openai_response)? } - } + _ => { + return Err(format!("Unsupported API transformation: {:?} -> {:?}", upstream_api, client_api).into()); + } + }; + + // Update the event with transformed data and reconstruct raw_line + event.data = transformed_data; + event.raw_line = format!("data: {}", event.data); + + Ok(event) } } -impl Iterator for ProviderStreamResponseIter { - type Item = Result, Box>; - - fn next(&mut self) -> Option { - match self { - ProviderStreamResponseIter::ChatCompletionsStream(iter) => iter.next(), - ProviderStreamResponseIter::MessagesStream(iter) => iter.next(), - } - } -} -pub trait ProviderResponse: Send + Sync { - /// Get usage information if available - returns dynamic trait object - fn usage(&self) -> Option<&dyn TokenUsage>; - - /// Extract token counts for metrics - fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { - self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens())) +// Into implementation to convert SseEvent to bytes for response buffer +impl Into> for SseEvent { + fn into(self) -> Vec { + format!("{}\n\n", self.raw_line).into_bytes() } } -pub trait ProviderStreamResponse: Send + Sync { - /// Get the content delta for this chunk - fn content_delta(&self) -> Option<&str>; - /// Check if this is the final chunk in the stream - fn is_final(&self) -> bool; - - /// Get role information if available - fn role(&self) -> Option<&str>; +#[derive(Debug)] +pub struct SseParseError { + pub message: String, } +impl fmt::Display for SseParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SSE parse error: {}", self.message) + } +} +impl Error for SseParseError {} // ============================================================================ // GENERIC SSE STREAMING ITERATOR (Container Only) // ============================================================================ /// Generic SSE (Server-Sent Events) streaming iterator container -/// This is just a simple wrapper - actual Iterator implementation is delegated to provider-specific modules +/// Parses raw SSE lines into SseEvent objects pub struct SseStreamIter where I: Iterator, @@ -168,26 +349,41 @@ where } } +// TryFrom implementation to parse bytes into SseStreamIter +impl TryFrom<&[u8]> for SseStreamIter> { + type Error = Box; -impl ProviderResponse for ProviderResponseType { - fn usage(&self) -> Option<&dyn TokenUsage> { - match self { - ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(), - ProviderResponseType::MessagesResponse(resp) => resp.usage(), - } - } - - fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { - match self { - ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(), - ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(), - } + fn try_from(bytes: &[u8]) -> Result { + let s = std::str::from_utf8(bytes)?; + let lines: Vec = s.lines().map(|line| line.to_string()).collect(); + Ok(SseStreamIter::new(lines.into_iter())) } } -// Implement Send + Sync for the enum to match the original trait requirements -unsafe impl Send for ProviderStreamResponseIter {} -unsafe impl Sync for ProviderStreamResponseIter {} +impl Iterator for SseStreamIter +where + I: Iterator, + I::Item: AsRef, +{ + type Item = SseEvent; + + fn next(&mut self) -> Option { + for line in &mut self.lines { + if let Ok(event) = line.as_ref().parse::() { + // Check if this is the [DONE] marker - if so, end the stream + if event.is_done() { + return None; + } + // Skip events that should be filtered at the transport layer + if event.should_skip() { + continue; + } + return Some(event); + } + } + None + } +} /// Trait for token usage information pub trait TokenUsage { @@ -268,7 +464,7 @@ mod tests { "usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 } }); let bytes = serde_json::to_vec(&resp).unwrap(); - let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::Claude)); + let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::Anthropic)); assert!(result.is_ok()); match result.unwrap() { ProviderResponseType::MessagesResponse(r) => { @@ -326,7 +522,7 @@ mod tests { "usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 } }); let bytes = serde_json::to_vec(&resp).unwrap(); - let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::Claude)); + let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::Anthropic)); assert!(result.is_ok()); match result.unwrap() { ProviderResponseType::ChatCompletionsResponse(r) => { @@ -337,4 +533,122 @@ mod tests { _ => panic!("Expected ChatCompletionsResponse variant"), } } + + #[test] + fn test_sse_event_parsing() { + // Test valid SSE data line + let line = r#"data: {"id":"test","object":"chat.completion.chunk"}"#; + let event: Result = line.parse(); + assert!(event.is_ok()); + let event = event.unwrap(); + assert_eq!(event.data, r#"{"id":"test","object":"chat.completion.chunk"}"#); + + // Test conversion back to line using Display trait + let wire_format = event.to_string(); + assert_eq!(wire_format, "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n"); + + // Test [DONE] marker - should be valid SSE event + let done_line = "data: [DONE]"; + let done_result: Result = done_line.parse(); + assert!(done_result.is_ok()); + let done_event = done_result.unwrap(); + assert_eq!(done_event.data, "[DONE]"); + assert!(done_event.is_done()); // Test the helper method + + // Test non-DONE event + assert!(!event.is_done()); + + // Test empty data - should return error + let empty_line = "data: "; + let empty_result: Result = empty_line.parse(); + assert!(empty_result.is_err()); + + // Test non-data line - should return error + let comment_line = ": this is a comment"; + let comment_result: Result = comment_line.parse(); + assert!(comment_result.is_err()); + } + + #[test] + fn test_sse_event_serde() { + // Test serialization and deserialization with serde + let event = SseEvent { + data: r#"{"id":"test","object":"chat.completion.chunk"}"#.to_string(), + raw_line: r#"data: {"id":"test","object":"chat.completion.chunk"} + + "#.to_string(), + provider_stream_response: None, + }; + + // Test JSON serialization - raw_line should be skipped + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("test")); + assert!(json.contains("chat.completion.chunk")); + assert!(!json.contains("raw_line")); // Should be excluded from serialization + + // Test JSON deserialization + let deserialized: SseEvent = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.data, event.data); + assert_eq!(deserialized.raw_line, ""); // Should be empty since it's skipped + + // Test round trip for data field only + assert_eq!(event.data, deserialized.data); + } + + #[test] + fn test_sse_event_should_skip() { + // Test ping message should be skipped + let ping_event = SseEvent { + data: r#"{"type": "ping"}"#.to_string(), + raw_line: r#"data: {"type": "ping"}"#.to_string(), + provider_stream_response: None, + }; + assert!(ping_event.should_skip()); + assert!(!ping_event.is_done()); + + // Test normal event should not be skipped + let normal_event = SseEvent { + data: r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string(), + raw_line: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(), + provider_stream_response: None, + }; + assert!(!normal_event.should_skip()); + assert!(!normal_event.is_done()); + + // Test [DONE] event should not be skipped (but is handled separately) + let done_event = SseEvent { + data: "[DONE]".to_string(), + raw_line: "data: [DONE]".to_string(), + provider_stream_response: None, + }; + assert!(!done_event.should_skip()); + assert!(done_event.is_done()); + } + + #[test] + fn test_sse_stream_iter_filters_ping_messages() { + // Create test data with ping messages mixed in + let test_lines = vec![ + "data: {\"id\": \"msg1\", \"object\": \"chat.completion.chunk\"}".to_string(), + "data: {\"type\": \"ping\"}".to_string(), // This should be filtered out + "data: {\"id\": \"msg2\", \"object\": \"chat.completion.chunk\"}".to_string(), + "data: {\"type\": \"ping\"}".to_string(), // This should be filtered out + "data: [DONE]".to_string(), // This should end the stream + ]; + + let mut iter = SseStreamIter::new(test_lines.into_iter()); + + // First event should be msg1 (ping filtered out) + let event1 = iter.next().unwrap(); + assert!(event1.data.contains("msg1")); + assert!(!event1.should_skip()); + + // Second event should be msg2 (ping filtered out) + let event2 = iter.next().unwrap(); + assert!(event2.data.contains("msg2")); + assert!(!event2.should_skip()); + + // Iterator should end at [DONE] (no more events) + assert!(iter.next().is_none()); + } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 415e89f3..7053848e 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -22,7 +22,7 @@ use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; use hermesllm::clients::endpoints::SupportedAPIs; -use hermesllm::providers::response::{ProviderResponse, ProviderStreamResponseIter}; +use hermesllm::providers::response::{ProviderResponse, SseEvent, SseStreamIter}; use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType}; pub struct StreamContext { @@ -129,9 +129,19 @@ impl StreamContext { ), })?; - let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value); - - self.set_http_request_header("Authorization", Some(&authorization_header_value)); + // Set API-specific headers based on the resolved upstream API + match self.resolved_api.as_ref() { + Some(SupportedAPIs::AnthropicMessagesAPI(_)) => { + // Anthropic API requires x-api-key and anthropic-version headers + self.set_http_request_header("x-api-key", Some(llm_provider_api_key_value)); + self.set_http_request_header("anthropic-version", Some("2023-06-01")); + } + Some(SupportedAPIs::OpenAIChatCompletions(_)) | None => { + // OpenAI and default: use Authorization Bearer token + let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value); + self.set_http_request_header("Authorization", Some(&authorization_header_value)); + } + } Ok(()) } @@ -334,7 +344,7 @@ impl StreamContext { fn debug_log_body(&self, body: &[u8]) { if log::log_enabled!(log::Level::Debug) { debug!( - "response data (converted to utf8): {}", + "raw response data (converted to utf8): {}", String::from_utf8_lossy(body) ); } @@ -348,49 +358,67 @@ impl StreamContext { debug!("processing streaming response"); match self.client_api.as_ref() { Some(client_api) => { - match ProviderStreamResponseIter::try_from((body, client_api, &provider_id)) { - Ok(mut streaming_response) => { - while let Some(chunk_result) = streaming_response.next() { - match chunk_result { - Ok(chunk) => { - self.record_ttft_if_needed(); + let client_api = client_api.clone(); // Clone to avoid borrowing issues + let upstream_api = provider_id.compatible_api_for_client(&client_api); - if chunk.is_final() { - debug!("Received final streaming chunk"); - } - if let Some(content) = chunk.content_delta() { - let estimated_tokens = content.len() / 4; - self.response_tokens += estimated_tokens.max(1); - } + // Parse body into SSE iterator using TryFrom + let sse_iter: SseStreamIter> = + match SseStreamIter::try_from(body) { + Ok(iter) => iter, + Err(e) => { + warn!("Failed to parse body into SSE iterator: {}", e); + return Err(Action::Continue); + } + }; + + let mut response_buffer = Vec::new(); + + // Process each SSE event + for sse_event in sse_iter { + // Transform event if upstream API != client API + let transformed_event: SseEvent = + match SseEvent::try_from((sse_event, &upstream_api, &client_api)) { + Ok(event) => event, + Err(e) => { + warn!("Failed to transform SSE event: {}", e); + return Err(Action::Continue); + } + }; + + // Extract ProviderStreamResponse for processing (token counting, etc.) + if !transformed_event.is_done() { + match transformed_event.to_provider_stream_response(&client_api) { + Ok(provider_response) => { + self.record_ttft_if_needed(); + + if provider_response.is_final() { + debug!("Received final streaming chunk"); } - Err(e) => { - warn!("Error processing streaming chunk: {}", e); - return Err(Action::Continue); + + if let Some(content) = provider_response.content_delta() { + let estimated_tokens = content.len() / 4; + self.response_tokens += estimated_tokens.max(1); } } + Err(e) => { + warn!("Error processing streaming chunk: {}", e); + return Err(Action::Continue); + } } } - Err(e) => { - warn!("Failed to parse streaming response: {}", e); - } + + // Add transformed event to response buffer + let bytes: Vec = transformed_event.into(); + response_buffer.extend_from_slice(&bytes); } + + Ok(response_buffer) } None => { warn!("Missing client_api for non-streaming response"); - return Err(Action::Continue); + Err(Action::Continue) } - }; - // NOTE: - // We currently pass-through the original SSE bytes for streaming responses. - // Non-streaming responses are parsed into ProviderResponseType and re-serialized to - // normalize the payload to the client API. Doing the same for streaming would require - // a streaming serializer that emits normalized SSE events for the target client API. - // That doesn't exist yet in hermesllm; implementing it is a follow-up. - // TODO(salmanap): Add a normalized SSE serializer in hermesllm and use it here so both - // streaming and non-streaming paths perform the same compatibility mapping. - // Until then, we keep behavior unchanged and forward upstream SSE as-is. - // For consistency of the method contract, still return Vec. - Ok(body.to_vec()) + } } fn handle_non_streaming_response(