diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 5797d5a2..b243c365 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -912,10 +912,12 @@ version = "0.1.0" dependencies = [ "aws-smithy-eventstream", "bytes", + "log", "serde", "serde_json", "serde_with", "thiserror 2.0.12", + "uuid", ] [[package]] diff --git a/crates/brightstaff/src/handlers/router.rs b/crates/brightstaff/src/handlers/router.rs index 3d01a13b..c369729a 100644 --- a/crates/brightstaff/src/handlers/router.rs +++ b/crates/brightstaff/src/handlers/router.rs @@ -3,7 +3,7 @@ use common::configuration::{ModelAlias, ModelUsagePreference}; use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER}; use hermesllm::apis::openai::ChatCompletionsRequest; use hermesllm::clients::endpoints::SupportedUpstreamAPIs; -use hermesllm::clients::SupportedAPIsFromClients; +use hermesllm::clients::SupportedAPIsFromClient; use hermesllm::{ProviderRequest, ProviderRequestType}; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full}; @@ -39,7 +39,7 @@ pub async fn router_chat( let mut client_request = match ProviderRequestType::try_from(( &chat_request_bytes[..], - &SupportedAPIsFromClients::from_endpoint(request_path.as_str()).unwrap(), + &SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(), )) { Ok(request) => request, Err(err) => { @@ -58,7 +58,7 @@ pub async fn router_chat( let resolved_model = if let Some(model_aliases) = model_aliases.as_ref() { if let Some(model_alias) = model_aliases.get(&model_from_request) { debug!( - "Model Alias: 'From {}' -> 'To{}'", + "Model Alias: 'From {}' -> 'To {}'", model_from_request, model_alias.target ); model_alias.target.clone() diff --git a/crates/hermesllm/Cargo.toml b/crates/hermesllm/Cargo.toml index ab2390bf..d877fc00 100644 --- a/crates/hermesllm/Cargo.toml +++ b/crates/hermesllm/Cargo.toml @@ -10,3 +10,5 @@ serde_with = {version = "3.12.0", features = ["base64"]} thiserror = "2.0.12" aws-smithy-eventstream = 
"0.60" bytes = "1.10" +uuid = { version = "1.11", features = ["v4"] } +log = "0.4" diff --git a/crates/hermesllm/src/apis/amazon_bedrock.rs b/crates/hermesllm/src/apis/amazon_bedrock.rs index eb1f3ddf..252bd0f1 100644 --- a/crates/hermesllm/src/apis/amazon_bedrock.rs +++ b/crates/hermesllm/src/apis/amazon_bedrock.rs @@ -7,7 +7,7 @@ use thiserror::Error; use super::ApiDefinition; use crate::providers::request::{ProviderRequest, ProviderRequestError}; -use crate::providers::response::ProviderStreamResponse; +use crate::providers::streaming_response::ProviderStreamResponse; // ============================================================================ // AMAZON BEDROCK CONVERSE API ENUMERATION diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs index f91b381c..7e1951e4 100644 --- a/crates/hermesllm/src/apis/anthropic.rs +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -6,7 +6,8 @@ use std::collections::HashMap; use super::ApiDefinition; use crate::providers::request::{ProviderRequest, ProviderRequestError}; -use crate::providers::response::{ProviderResponse, ProviderStreamResponse}; +use crate::providers::response::ProviderResponse; +use crate::providers::streaming_response::ProviderStreamResponse; use crate::transforms::lib::ExtractText; use crate::MESSAGES_PATH; diff --git a/crates/hermesllm/src/apis/mod.rs b/crates/hermesllm/src/apis/mod.rs index b3573621..ea056392 100644 --- a/crates/hermesllm/src/apis/mod.rs +++ b/crates/hermesllm/src/apis/mod.rs @@ -1,9 +1,8 @@ pub mod amazon_bedrock; -pub mod amazon_bedrock_binary_frame; pub mod anthropic; pub mod openai; pub mod openai_responses; -pub mod sse; +pub mod streaming_shapes; // Explicit exports to avoid naming conflicts pub use amazon_bedrock::{AmazonBedrockApi, ConverseRequest, ConverseStreamRequest}; diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 7e7036c8..d7f7a07d 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ 
b/crates/hermesllm/src/apis/openai.rs @@ -7,7 +7,8 @@ use thiserror::Error; use super::ApiDefinition; use crate::providers::request::{ProviderRequest, ProviderRequestError}; -use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage}; +use crate::providers::response::{ProviderResponse, TokenUsage}; +use crate::providers::streaming_response::ProviderStreamResponse; use crate::transforms::lib::ExtractText; use crate::{CHAT_COMPLETIONS_PATH, OPENAI_RESPONSES_API_PATH}; diff --git a/crates/hermesllm/src/apis/openai_responses.rs b/crates/hermesllm/src/apis/openai_responses.rs index 78728c75..4f0cf663 100644 --- a/crates/hermesllm/src/apis/openai_responses.rs +++ b/crates/hermesllm/src/apis/openai_responses.rs @@ -252,9 +252,12 @@ pub struct ConversationParam { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum Tool { - /// Function tool + /// Function tool - flat structure in Responses API Function { - function: FunctionDefinition, + name: String, + description: Option, + parameters: Option, + strict: Option, }, /// File search tool FileSearch { @@ -279,20 +282,6 @@ pub enum Tool { }, } -/// Function definition for function calling -#[skip_serializing_none] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FunctionDefinition { - /// Function name - pub name: String, - /// Function description - pub description: Option, - /// JSON schema for function parameters - pub parameters: Option, - /// Whether the function is strict - pub strict: Option, -} - /// Ranking options for file search #[skip_serializing_none] #[derive(Debug, Clone, Serialize, Deserialize)] @@ -433,8 +422,8 @@ pub struct ResponsesAPIResponse { } /// Response status -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] pub enum ResponseStatus { Completed, Failed, @@ -573,7 +562,7 @@ pub 
enum OutputItem { /// Output item status #[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] +#[serde(rename_all = "snake_case")] pub enum OutputItemStatus { InProgress, Completed, @@ -731,7 +720,7 @@ pub struct Conversation { #[skip_serializing_none] #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] -pub enum ResponseAPIStreamEvent { +pub enum ResponsesAPIStreamEvent { /// Response created #[serde(rename = "response.created")] ResponseCreated { @@ -739,6 +728,13 @@ pub enum ResponseAPIStreamEvent { sequence_number: i32, }, + /// Response in progress + #[serde(rename = "response.in_progress")] + ResponseInProgress { + response: ResponsesAPIResponse, + sequence_number: i32, + }, + /// Response completed #[serde(rename = "response.completed")] ResponseCompleted { @@ -851,6 +847,8 @@ pub enum ResponseAPIStreamEvent { item_id: String, delta: String, sequence_number: i32, + call_id: Option, + name: Option, }, /// Function call arguments done @@ -1089,57 +1087,58 @@ impl ProviderRequest for ResponsesAPIRequest { // Into Implementation for SSE Formatting // ============================================================================ -impl Into for ResponseAPIStreamEvent { +impl Into for ResponsesAPIStreamEvent { fn into(self) -> String { let transformed_json = serde_json::to_string(&self).unwrap_or_default(); let event_type = match &self { - ResponseAPIStreamEvent::ResponseCreated { .. } => "response.created", - ResponseAPIStreamEvent::ResponseCompleted { .. } => "response.completed", - ResponseAPIStreamEvent::ResponseOutputItemAdded { .. } => "response.output_item.added", - ResponseAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done", - ResponseAPIStreamEvent::ResponseContentPartAdded { .. } => { + ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created", + ResponsesAPIStreamEvent::ResponseInProgress { .. 
} => "response.in_progress", + ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed", + ResponsesAPIStreamEvent::ResponseOutputItemAdded { .. } => "response.output_item.added", + ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done", + ResponsesAPIStreamEvent::ResponseContentPartAdded { .. } => { "response.content_part.added" } - ResponseAPIStreamEvent::ResponseContentPartDone { .. } => "response.content_part.done", - ResponseAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta", - ResponseAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done", - ResponseAPIStreamEvent::ResponseAudioDelta { .. } => "response.audio.delta", - ResponseAPIStreamEvent::ResponseAudioDone { .. } => "response.audio.done", - ResponseAPIStreamEvent::ResponseAudioTranscriptDelta { .. } => { + ResponsesAPIStreamEvent::ResponseContentPartDone { .. } => "response.content_part.done", + ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta", + ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done", + ResponsesAPIStreamEvent::ResponseAudioDelta { .. } => "response.audio.delta", + ResponsesAPIStreamEvent::ResponseAudioDone { .. } => "response.audio.done", + ResponsesAPIStreamEvent::ResponseAudioTranscriptDelta { .. } => { "response.audio_transcript.delta" } - ResponseAPIStreamEvent::ResponseAudioTranscriptDone { .. } => { + ResponsesAPIStreamEvent::ResponseAudioTranscriptDone { .. } => { "response.audio_transcript.done" } - ResponseAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => { + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => { "response.function_call_arguments.delta" } - ResponseAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => { + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. 
} => { "response.function_call_arguments.done" } - ResponseAPIStreamEvent::ResponseCodeInterpreterCallCodeDelta { .. } => { + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallCodeDelta { .. } => { "response.code_interpreter_call.code.delta" } - ResponseAPIStreamEvent::ResponseCodeInterpreterCallCodeDone { .. } => { + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallCodeDone { .. } => { "response.code_interpreter_call.code.done" } - ResponseAPIStreamEvent::ResponseCodeInterpreterCallInProgress { .. } => { + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallInProgress { .. } => { "response.code_interpreter_call.in_progress" } - ResponseAPIStreamEvent::ResponseCodeInterpreterCallInterpreting { .. } => { + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallInterpreting { .. } => { "response.code_interpreter_call.interpreting" } - ResponseAPIStreamEvent::ResponseCodeInterpreterCallCompleted { .. } => { + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallCompleted { .. } => { "response.code_interpreter_call.completed" } - ResponseAPIStreamEvent::ResponseCustomToolCallInputDelta { .. } => { + ResponsesAPIStreamEvent::ResponseCustomToolCallInputDelta { .. } => { "response.custom_tool_call.input.delta" } - ResponseAPIStreamEvent::ResponseCustomToolCallInputDone { .. } => { + ResponsesAPIStreamEvent::ResponseCustomToolCallInputDone { .. } => { "response.custom_tool_call.input.done" } - ResponseAPIStreamEvent::Error { .. } => "error", - ResponseAPIStreamEvent::Done { .. } => "done", + ResponsesAPIStreamEvent::Error { .. } => "error", + ResponsesAPIStreamEvent::Done { .. 
} => "done", }; let event = format!("event: {}\n", event_type); @@ -1152,19 +1151,19 @@ impl Into for ResponseAPIStreamEvent { // ProviderStreamResponse Implementation // ============================================================================ -impl crate::providers::response::ProviderStreamResponse for ResponseAPIStreamEvent { +impl crate::providers::streaming_response::ProviderStreamResponse for ResponsesAPIStreamEvent { fn content_delta(&self) -> Option<&str> { match self { - ResponseAPIStreamEvent::ResponseOutputTextDelta { delta, .. } => Some(delta), - ResponseAPIStreamEvent::ResponseAudioDelta { delta, .. } => Some(delta), - ResponseAPIStreamEvent::ResponseAudioTranscriptDelta { delta, .. } => Some(delta), - ResponseAPIStreamEvent::ResponseFunctionCallArgumentsDelta { delta, .. } => { + ResponsesAPIStreamEvent::ResponseOutputTextDelta { delta, .. } => Some(delta), + ResponsesAPIStreamEvent::ResponseAudioDelta { delta, .. } => Some(delta), + ResponsesAPIStreamEvent::ResponseAudioTranscriptDelta { delta, .. } => Some(delta), + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { delta, .. } => { Some(delta) } - ResponseAPIStreamEvent::ResponseCodeInterpreterCallCodeDelta { delta, .. } => { + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallCodeDelta { delta, .. } => { Some(delta) } - ResponseAPIStreamEvent::ResponseCustomToolCallInputDelta { delta, .. } => Some(delta), + ResponsesAPIStreamEvent::ResponseCustomToolCallInputDelta { delta, .. } => Some(delta), _ => None, } } @@ -1172,14 +1171,14 @@ impl crate::providers::response::ProviderStreamResponse for ResponseAPIStreamEve fn is_final(&self) -> bool { matches!( self, - ResponseAPIStreamEvent::ResponseCompleted { .. } - | ResponseAPIStreamEvent::Done { .. } + ResponsesAPIStreamEvent::ResponseCompleted { .. } + | ResponsesAPIStreamEvent::Done { .. } ) } fn role(&self) -> Option<&str> { match self { - ResponseAPIStreamEvent::ResponseOutputItemDone { item, .. 
} => match item { + ResponsesAPIStreamEvent::ResponseOutputItemDone { item, .. } => match item { OutputItem::Message { role, .. } => Some(role.as_str()), _ => None, }, @@ -1189,53 +1188,54 @@ impl crate::providers::response::ProviderStreamResponse for ResponseAPIStreamEve fn event_type(&self) -> Option<&str> { Some(match self { - ResponseAPIStreamEvent::ResponseCreated { .. } => "response.created", - ResponseAPIStreamEvent::ResponseCompleted { .. } => "response.completed", - ResponseAPIStreamEvent::ResponseOutputItemAdded { .. } => "response.output_item.added", - ResponseAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done", - ResponseAPIStreamEvent::ResponseContentPartAdded { .. } => { + ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created", + ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress", + ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed", + ResponsesAPIStreamEvent::ResponseOutputItemAdded { .. } => "response.output_item.added", + ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done", + ResponsesAPIStreamEvent::ResponseContentPartAdded { .. } => { "response.content_part.added" } - ResponseAPIStreamEvent::ResponseContentPartDone { .. } => "response.content_part.done", - ResponseAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta", - ResponseAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done", - ResponseAPIStreamEvent::ResponseAudioDelta { .. } => "response.audio.delta", - ResponseAPIStreamEvent::ResponseAudioDone { .. } => "response.audio.done", - ResponseAPIStreamEvent::ResponseAudioTranscriptDelta { .. } => { + ResponsesAPIStreamEvent::ResponseContentPartDone { .. } => "response.content_part.done", + ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta", + ResponsesAPIStreamEvent::ResponseOutputTextDone { .. 
} => "response.output_text.done", + ResponsesAPIStreamEvent::ResponseAudioDelta { .. } => "response.audio.delta", + ResponsesAPIStreamEvent::ResponseAudioDone { .. } => "response.audio.done", + ResponsesAPIStreamEvent::ResponseAudioTranscriptDelta { .. } => { "response.audio_transcript.delta" } - ResponseAPIStreamEvent::ResponseAudioTranscriptDone { .. } => { + ResponsesAPIStreamEvent::ResponseAudioTranscriptDone { .. } => { "response.audio_transcript.done" } - ResponseAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => { + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => { "response.function_call_arguments.delta" } - ResponseAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => { + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => { "response.function_call_arguments.done" } - ResponseAPIStreamEvent::ResponseCodeInterpreterCallCodeDelta { .. } => { + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallCodeDelta { .. } => { "response.code_interpreter_call.code.delta" } - ResponseAPIStreamEvent::ResponseCodeInterpreterCallCodeDone { .. } => { + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallCodeDone { .. } => { "response.code_interpreter_call.code.done" } - ResponseAPIStreamEvent::ResponseCodeInterpreterCallInProgress { .. } => { + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallInProgress { .. } => { "response.code_interpreter_call.in_progress" } - ResponseAPIStreamEvent::ResponseCodeInterpreterCallInterpreting { .. } => { + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallInterpreting { .. } => { "response.code_interpreter_call.interpreting" } - ResponseAPIStreamEvent::ResponseCodeInterpreterCallCompleted { .. } => { + ResponsesAPIStreamEvent::ResponseCodeInterpreterCallCompleted { .. } => { "response.code_interpreter_call.completed" } - ResponseAPIStreamEvent::ResponseCustomToolCallInputDelta { .. } => { + ResponsesAPIStreamEvent::ResponseCustomToolCallInputDelta { .. 
} => { "response.custom_tool_call.input.delta" } - ResponseAPIStreamEvent::ResponseCustomToolCallInputDone { .. } => { + ResponsesAPIStreamEvent::ResponseCustomToolCallInputDone { .. } => { "response.custom_tool_call.input.done" } - ResponseAPIStreamEvent::Error { .. } => "error", - ResponseAPIStreamEvent::Done { .. } => "done", + ResponsesAPIStreamEvent::Error { .. } => "error", + ResponsesAPIStreamEvent::Done { .. } => "done", }) } } @@ -1257,11 +1257,11 @@ mod tests { "obfuscation":"sRhca4PA06" }"#; - let event: ResponseAPIStreamEvent = + let event: ResponsesAPIStreamEvent = serde_json::from_str(json).expect("Failed to deserialize"); match event { - ResponseAPIStreamEvent::ResponseOutputTextDelta { + ResponsesAPIStreamEvent::ResponseOutputTextDelta { item_id, output_index, content_index, @@ -1297,11 +1297,11 @@ mod tests { "logprobs":[] }"#; - let event: ResponseAPIStreamEvent = + let event: ResponsesAPIStreamEvent = serde_json::from_str(json).expect("Failed to deserialize"); match event { - ResponseAPIStreamEvent::ResponseOutputTextDone { + ResponsesAPIStreamEvent::ResponseOutputTextDone { item_id, output_index, content_index, @@ -1368,11 +1368,11 @@ mod tests { } }"#; - let event: ResponseAPIStreamEvent = + let event: ResponsesAPIStreamEvent = serde_json::from_str(json).expect("Failed to deserialize"); match event { - ResponseAPIStreamEvent::ResponseCompleted { + ResponsesAPIStreamEvent::ResponseCompleted { response, sequence_number, } => { diff --git a/crates/hermesllm/src/apis/amazon_bedrock_binary_frame.rs b/crates/hermesllm/src/apis/streaming_shapes/amazon_bedrock_binary_frame.rs similarity index 100% rename from crates/hermesllm/src/apis/amazon_bedrock_binary_frame.rs rename to crates/hermesllm/src/apis/streaming_shapes/amazon_bedrock_binary_frame.rs diff --git a/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs new file mode 100644 index 00000000..5879712b 
--- /dev/null +++ b/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs @@ -0,0 +1,447 @@ +use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait}; +use crate::apis::anthropic::MessagesStreamEvent; +use crate::providers::streaming_response::ProviderStreamResponse; + +/// SSE Stream Buffer for Anthropic Messages API streaming. +/// +/// This buffer manages the wire format for Anthropic Messages API streaming, +/// handling the specific event sequencing requirements: +/// - MessageStart → ContentBlockStart → ContentBlockDelta(s) → ContentBlockStop → MessageDelta → MessageStop +/// +/// When converting from OpenAI to Anthropic format, this buffer injects the required +/// ContentBlockStart and ContentBlockStop events to maintain proper Anthropic protocol. +pub struct AnthropicMessagesStreamBuffer { + /// Buffered SSE events ready to be written to wire + buffered_events: Vec, + + /// Track if we've seen a message_start event + message_started: bool, + + /// Track if we've seen a content_block_start event + content_block_started: bool, + + /// Track if we need to inject ContentBlockStop before message_delta + needs_content_block_stop: bool, + + /// Model name to use when generating message_start events + model: String, +} + +impl AnthropicMessagesStreamBuffer { + pub fn new() -> Self { + Self { + buffered_events: Vec::new(), + message_started: false, + content_block_started: false, + needs_content_block_stop: false, + model: "unknown".to_string(), + } + } + + /// Helper to create and format a ContentBlockStart SSE event + fn create_content_block_start_event(&self) -> SseEvent { + let content_block_start = MessagesStreamEvent::ContentBlockStart { + index: 0, + content_block: crate::apis::anthropic::MessagesContentBlock::Text { + text: String::new(), + cache_control: None, + }, + }; + let sse_string: String = content_block_start.into(); + + SseEvent { + data: None, + event: Some("content_block_start".to_string()), + raw_line: 
sse_string.clone(), + sse_transformed_lines: sse_string, + provider_stream_response: None, + } + } + + /// Helper to create and format a MessageStart SSE event + fn create_message_start_event(&self) -> SseEvent { + let message_start = MessagesStreamEvent::MessageStart { + message: crate::apis::anthropic::MessagesStreamMessage { + id: format!("msg_{}", uuid::Uuid::new_v4().to_string().replace("-", "")), + obj_type: "message".to_string(), + role: crate::apis::anthropic::MessagesRole::Assistant, + content: vec![], + model: self.model.clone(), + stop_reason: None, + stop_sequence: None, + usage: crate::apis::anthropic::MessagesUsage { + input_tokens: 0, + output_tokens: 0, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }, + }, + }; + let sse_string: String = message_start.into(); + + SseEvent { + data: None, + event: Some("message_start".to_string()), + raw_line: sse_string.clone(), + sse_transformed_lines: sse_string, + provider_stream_response: None, + } + } + + /// Helper to create and format a ContentBlockStop SSE event + fn create_content_block_stop_event(&self) -> SseEvent { + let content_block_stop = MessagesStreamEvent::ContentBlockStop { index: 0 }; + let sse_string: String = content_block_stop.into(); + + SseEvent { + data: None, + event: Some("content_block_stop".to_string()), + raw_line: sse_string.clone(), + sse_transformed_lines: sse_string, + provider_stream_response: None, + } + } +} + +impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { + fn add_transformed_event(&mut self, event: SseEvent) { + // Skip ping messages + if event.should_skip() { + return; + } + + // FIRST: Try to extract model name from the raw event data before transformation + // The provider_stream_response has already been transformed to Anthropic format, + // so we need to extract the model from the original raw data if available + if self.model == "unknown" { + if let Some(data) = &event.data { + // Try to parse as JSON and extract model field + if 
let Ok(json) = serde_json::from_str::(data) { + if let Some(model) = json.get("model").and_then(|m| m.as_str()) { + self.model = model.to_string(); + } + } + } + } + + // Check if this event has a provider response to determine its type + if let Some(provider_response) = &event.provider_stream_response { + if let Some(event_type) = provider_response.event_type() { + match event_type { + "message_start" => { + // Add the message_start event + self.buffered_events.push(event); + self.message_started = true; + return; + } + "content_block_start" => { + // If we haven't seen message_start yet, inject it first + if !self.message_started { + let message_start = self.create_message_start_event(); + self.buffered_events.push(message_start); + self.message_started = true; + } + + // Add the content_block_start event (from tool calls or other sources) + self.buffered_events.push(event); + self.content_block_started = true; + self.needs_content_block_stop = true; + return; + } + "content_block_delta" => { + // If this is the first content delta and we haven't started yet, + // inject message_start and content_block_start first + if !self.message_started { + // Create and inject message_start event + let message_start = self.create_message_start_event(); + self.buffered_events.push(message_start); + self.message_started = true; + } + + if !self.content_block_started { + // Inject ContentBlockStart after message_start + let content_block_start = self.create_content_block_start_event(); + self.buffered_events.push(content_block_start); + self.content_block_started = true; + self.needs_content_block_stop = true; + } + + // Content deltas are between ContentBlockStart and ContentBlockStop + self.buffered_events.push(event); + return; + } + "message_delta" => { + // Inject ContentBlockStop before message_delta + if self.needs_content_block_stop { + let content_block_stop = self.create_content_block_stop_event(); + self.buffered_events.push(content_block_stop); + 
self.needs_content_block_stop = false; + } + + // Add the message_delta event + self.buffered_events.push(event); + return; + } + _ => { + // Other event types, just accumulate + self.buffered_events.push(event); + return; + } + } + } + } + + // For events without provider_stream_response or event_type, just accumulate + self.buffered_events.push(event); + } + + fn into_bytes(&mut self) -> Vec { + // Inject ContentBlockStop if needed before flushing + if self.needs_content_block_stop { + let content_block_stop = self.create_content_block_stop_event(); + self.buffered_events.push(content_block_stop); + self.needs_content_block_stop = false; + } + + // Convert all accumulated events to bytes and clear buffer + let mut buffer = Vec::new(); + for event in self.buffered_events.drain(..) { + let event_bytes: Vec = event.into(); + buffer.extend_from_slice(&event_bytes); + } + buffer + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + use crate::apis::streaming_shapes::sse::SseStreamIter; + + #[test] + fn test_openai_to_anthropic_complete_transformation() { + // OpenAI ChatCompletions input that will be transformed to Anthropic Messages API + let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} + +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]} + +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]} + +data: [DONE]"#; + + println!("\n{}", "=".repeat(80)); + println!("TEST 1: OpenAI → Anthropic Messages API Complete Transformation"); + 
println!("{}", "=".repeat(80)); + println!("\nRAW INPUT (OpenAI ChatCompletions):"); + println!("{}", "-".repeat(80)); + println!("{}", raw_input); + + // Setup API configuration for transformation (client wants Anthropic, upstream is OpenAI) + let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Parse events and apply transformation + let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap(); + let mut buffer = AnthropicMessagesStreamBuffer::new(); + + for raw_event in stream_iter { + let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); + buffer.add_transformed_event(transformed_event); + } + + let output_bytes = buffer.into_bytes(); + let output = String::from_utf8_lossy(&output_bytes); + + println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):"); + println!("{}", "-".repeat(80)); + println!("{}", output); + + // Assertions + assert!(!output_bytes.is_empty(), "Should have output"); + assert!(output.contains("event: message_start"), "Should have message_start"); + assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)"); + + let delta_count = output.matches("event: content_block_delta").count(); + assert_eq!(delta_count, 2, "Should have exactly 2 content_block_delta events"); + + // Verify both pieces of content are present + assert!(output.contains("\"text\":\"Hello\""), "Should have first content delta 'Hello'"); + assert!(output.contains("\"text\":\" world\""), "Should have second content delta ' world'"); + + assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)"); + assert!(output.contains("event: message_delta"), "Should have message_delta"); + assert!(output.contains("event: message_stop"), "Should have message_stop"); + + println!("\nVALIDATION SUMMARY:"); + println!("{}", 
"-".repeat(80)); + println!("✓ Complete transformation: OpenAI ChatCompletions → Anthropic Messages API"); + println!("✓ Injected lifecycle events: message_start, content_block_start, content_block_stop"); + println!("✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)", delta_count); + println!("✓ Complete stream with message_stop"); + println!("✓ Proper Anthropic protocol sequencing\n"); + } + + #[test] + fn test_openai_to_anthropic_partial_transformation() { + // Partial OpenAI ChatCompletions stream - no [DONE] + let raw_input = r#"data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"The weather"},"finish_reason":null}]} + +data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" in San Francisco"},"finish_reason":null}]} + +data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" is"},"finish_reason":null}]}"#; + + println!("\n{}", "=".repeat(80)); + println!("TEST 2: OpenAI → Anthropic Partial Transformation (NO [DONE])"); + println!("{}", "=".repeat(80)); + println!("\nRAW INPUT (OpenAI ChatCompletions - NO [DONE]):"); + println!("{}", "-".repeat(80)); + println!("{}", raw_input); + + // Setup API configuration for transformation + let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Parse and transform events + let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap(); + let mut buffer = AnthropicMessagesStreamBuffer::new(); + + for raw_event in stream_iter { + let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); + buffer.add_transformed_event(transformed_event); + } + + let 
output_bytes = buffer.into_bytes(); + let output = String::from_utf8_lossy(&output_bytes); + + println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):"); + println!("{}", "-".repeat(80)); + println!("{}", output); + + // Assertions + assert!(!output_bytes.is_empty(), "Should have output"); + assert!(output.contains("event: message_start"), "Should have message_start"); + assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)"); + + let delta_count = output.matches("event: content_block_delta").count(); + assert_eq!(delta_count, 3, "Should have exactly 3 content_block_delta events"); + + // Verify all three pieces of content are present + assert!(output.contains("\"text\":\"The weather\""), "Should have first content delta"); + assert!(output.contains("\"text\":\" in San Francisco\""), "Should have second content delta"); + assert!(output.contains("\"text\":\" is\""), "Should have third content delta"); + + // For partial streams, the buffer will inject content_block_stop in into_bytes() + // because needs_content_block_stop is true. This is expected behavior to maintain + // proper Anthropic protocol even for incomplete streams. 
/// Feeds a complete OpenAI ChatCompletions tool-calling stream (ending in
/// `[DONE]`) through the OpenAI → Anthropic transformation and checks the
/// tool_use events the Anthropic buffer writes out.
#[test]
fn test_openai_tool_calling_to_anthropic_transformation() {
    // OpenAI ChatCompletions tool calling stream
    let raw_input = r#"data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_2Uzw0AEZQeOex2CP2TKjcLKc","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"obfuscation":"uSpCcO"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"24WSqt08jtf"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"6CleV8twTxkKYg"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"San"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Francisco"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"1XLz89l3v"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":","}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"sh"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" CA"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"obfuscation":"I"}

data: [DONE]"#;

    // Separator rules used by the human-readable trace below.
    let heavy_rule = "=".repeat(80);
    let light_rule = "-".repeat(80);

    println!("\n{}", heavy_rule);
    println!("TEST 3: OpenAI Tool Calling → Anthropic Messages API Transformation");
    println!("{}", heavy_rule);
    println!("\nRAW INPUT (OpenAI ChatCompletions with Tool Calls):");
    println!("{}", light_rule);
    println!("{}", raw_input);

    // Client speaks Anthropic Messages, upstream speaks OpenAI ChatCompletions.
    let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
    let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);

    // Transform every parsed event and hand it to the buffer.
    let mut buffer = AnthropicMessagesStreamBuffer::new();
    for upstream_event in SseStreamIter::try_from(raw_input.as_bytes()).unwrap() {
        let client_event =
            SseEvent::try_from((upstream_event, &client_api, &upstream_api)).unwrap();
        buffer.add_transformed_event(client_event);
    }

    let output_bytes = buffer.into_bytes();
    let output = String::from_utf8_lossy(&output_bytes);

    println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
    println!("{}", light_rule);
    println!("{}", output);

    // Assertions for tool calling transformation
    assert!(!output_bytes.is_empty(), "Should have output");

    // Lifecycle events first, then the tool_use block metadata.
    let lifecycle_and_tool_metadata = [
        ("event: message_start", "Should have message_start (injected)"),
        ("event: content_block_start", "Should have content_block_start"),
        ("event: content_block_stop", "Should have content_block_stop (injected)"),
        ("event: message_delta", "Should have message_delta"),
        ("event: message_stop", "Should have message_stop"),
        ("\"type\":\"tool_use\"", "Should have tool_use type"),
        ("\"name\":\"get_weather\"", "Should have correct function name"),
        ("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\"", "Should have correct tool call ID"),
    ];
    for &(needle, why) in lifecycle_and_tool_metadata.iter() {
        assert!(output.contains(needle), "{}", why);
    }

    // Count input_json_delta events - a lower bound on the argument chunks
    let delta_count = output.matches("event: content_block_delta").count();
    assert!(delta_count >= 8, "Should have at least 8 input_json_delta events");

    // Argument deltas, the argument text itself, and the stop reason.
    let argument_and_stop_fragments = [
        ("\"type\":\"input_json_delta\"", "Should have input_json_delta type"),
        ("\"partial_json\":", "Should have partial_json field"),
        ("San", "Arguments should contain 'San'"),
        ("Francisco", "Arguments should contain 'Francisco'"),
        ("CA", "Arguments should contain 'CA'"),
        ("\"stop_reason\":\"tool_use\"", "Should have stop_reason as tool_use"),
    ];
    for &(needle, why) in argument_and_stop_fragments.iter() {
        assert!(output.contains(needle), "{}", why);
    }

    println!("\nVALIDATION SUMMARY:");
    println!("{}", light_rule);
    println!("✓ Complete tool calling transformation: OpenAI → Anthropic Messages API");
    println!("✓ Injected lifecycle: message_start, content_block_stop");
    println!("✓ Tool metadata: name='get_weather', id='call_2Uzw0AEZQeOex2CP2TKjcLKc'");
    println!("✓ Argument deltas: {} events", delta_count);
    println!("✓ Complete JSON arguments: '{{\"location\":\"San Francisco, CA\"}}'");
    println!("✓ Stop reason: tool_use");
    println!("✓ Proper Anthropic tool_use protocol\n");
}
"-".repeat(80)); + println!("✓ Complete tool calling transformation: OpenAI → Anthropic Messages API"); + println!("✓ Injected lifecycle: message_start, content_block_stop"); + println!("✓ Tool metadata: name='get_weather', id='call_2Uzw0AEZQeOex2CP2TKjcLKc'"); + println!("✓ Argument deltas: {} events", delta_count); + println!("✓ Complete JSON arguments: '{{\"location\":\"San Francisco, CA\"}}'"); + println!("✓ Stop reason: tool_use"); + println!("✓ Proper Anthropic tool_use protocol\n"); + } +} diff --git a/crates/hermesllm/src/apis/streaming_shapes/chat_completions_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/chat_completions_streaming_buffer.rs new file mode 100644 index 00000000..0243a5cd --- /dev/null +++ b/crates/hermesllm/src/apis/streaming_shapes/chat_completions_streaming_buffer.rs @@ -0,0 +1,39 @@ +use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait}; + +/// OpenAI Chat Completions SSE Stream Buffer for when client and upstream APIs match. +pub struct OpenAIChatCompletionsStreamBuffer { + /// Buffered SSE events ready to be written to wire + buffered_events: Vec, +} + +impl OpenAIChatCompletionsStreamBuffer { + pub fn new() -> Self { + Self { + buffered_events: Vec::new(), + } + } +} + +impl SseStreamBufferTrait for OpenAIChatCompletionsStreamBuffer { + fn add_transformed_event(&mut self, event: SseEvent) { + // Skip ping messages + if event.should_skip() { + return; + } + + // For OpenAI Chat Completions, events are already properly transformed + // Just accumulate them for later wire transmission + self.buffered_events.push(event); + } + + fn into_bytes(&mut self) -> Vec { + // No finalization needed for OpenAI Chat Completions + // The [DONE] marker is already handled by the transformation layer + let mut buffer = Vec::new(); + for event in self.buffered_events.drain(..) 
{ + let event_bytes: Vec = event.into(); + buffer.extend_from_slice(&event_bytes); + } + buffer + } +} diff --git a/crates/hermesllm/src/apis/streaming_shapes/mod.rs b/crates/hermesllm/src/apis/streaming_shapes/mod.rs new file mode 100644 index 00000000..1ef7acc7 --- /dev/null +++ b/crates/hermesllm/src/apis/streaming_shapes/mod.rs @@ -0,0 +1,6 @@ +pub mod sse; +pub mod amazon_bedrock_binary_frame; +pub mod anthropic_streaming_buffer; +pub mod chat_completions_streaming_buffer; +pub mod passthrough_streaming_buffer; +pub mod responses_api_streaming_buffer; diff --git a/crates/hermesllm/src/apis/streaming_shapes/passthrough_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/passthrough_streaming_buffer.rs new file mode 100644 index 00000000..2ac2a688 --- /dev/null +++ b/crates/hermesllm/src/apis/streaming_shapes/passthrough_streaming_buffer.rs @@ -0,0 +1,95 @@ +use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait}; + +/// Passthrough SSE Stream Buffer for when client and upstream APIs match. +pub struct PassthroughStreamBuffer { + /// Buffered SSE events ready to be written to wire + buffered_events: Vec, +} + +impl PassthroughStreamBuffer { + pub fn new() -> Self { + Self { + buffered_events: Vec::new(), + } + } +} + +impl SseStreamBufferTrait for PassthroughStreamBuffer { + fn add_transformed_event(&mut self, event: SseEvent) { + // Skip ping messages + if event.should_skip() { + return; + } + + // Skip events with empty transformed lines (e.g., suppressed event-only lines) + if event.sse_transformed_lines.is_empty() { + return; + } + + // Just accumulate events as-is + self.buffered_events.push(event); + } + + fn into_bytes(&mut self) -> Vec { + // No finalization needed for passthrough - just convert accumulated events to bytes + let mut buffer = Vec::new(); + for event in self.buffered_events.drain(..) 
{ + let event_bytes: Vec = event.into(); + buffer.extend_from_slice(&event_bytes); + } + buffer + } +} + +#[cfg(test)] +mod tests { + use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer; + use crate::apis::streaming_shapes::sse::{SseStreamIter, SseStreamBufferTrait}; + + #[test] + fn test_chat_completions_passthrough_buffer() { + let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]} + + data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]} + + data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}]} + + data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]} + + data: [DONE]"#; + + println!("\n{}", "=".repeat(80)); + println!("TEST 1: ChatCompletions Passthrough Buffer"); + println!("{}", "=".repeat(80)); + println!("\nRAW INPUT (ChatCompletions):"); + println!("{}", "-".repeat(80)); + println!("{}", raw_input); + + // Parse and process through buffer + let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap(); + let mut buffer = PassthroughStreamBuffer::new(); + + for event in stream_iter { + buffer.add_transformed_event(event); + } + + let output_bytes = buffer.into_bytes(); + let output = String::from_utf8_lossy(&output_bytes); + + println!("\nTRANSFORMED 
OUTPUT (ChatCompletions - Passthrough):"); + println!("{}", "-".repeat(80)); + println!("{}", output); + + // Assertions + assert!(!output_bytes.is_empty()); + assert!(output.contains("chatcmpl-123")); + assert!(output.contains("[DONE]")); + assert_eq!(raw_input.trim(), output.trim(), "Passthrough should preserve input"); + + println!("\nVALIDATION SUMMARY:"); + println!("{}", "-".repeat(80)); + println!("✓ Passthrough buffer: input = output (no transformation)"); + println!("✓ All events preserved including [DONE]"); + println!("✓ Function calling events preserved\n"); + } +} diff --git a/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs new file mode 100644 index 00000000..27e6a199 --- /dev/null +++ b/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs @@ -0,0 +1,590 @@ +use std::collections::HashMap; +use crate::apis::openai_responses::{ + ResponsesAPIStreamEvent, ResponsesAPIResponse, OutputItem, OutputItemStatus, + ResponseStatus, TextConfig, TextFormat, Reasoning, +}; +use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait}; + +/// Helper to convert ResponseAPIStreamEvent to SseEvent +fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent { + let event_type = match &event { + ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created", + ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress", + ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed", + ResponsesAPIStreamEvent::ResponseOutputItemAdded { .. } => "response.output_item.added", + ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done", + ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta", + ResponsesAPIStreamEvent::ResponseOutputTextDone { .. 
} => "response.output_text.done", + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => "response.function_call_arguments.delta", + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => "response.function_call_arguments.done", + _ => "unknown", + }; + + let json_data = serde_json::to_string(&event).unwrap_or_default(); + let wire_format: String = event.into(); + + SseEvent { + data: Some(json_data), + event: Some(event_type.to_string()), + raw_line: wire_format.clone(), + sse_transformed_lines: wire_format, + provider_stream_response: None, + } +} + +/// SSE Stream Buffer for ResponsesAPIStreamEvent with full lifecycle management. +/// +/// This buffer manages the wire format for v1/responses streaming, handling +/// delta events and emitting complete lifecycle events. +/// +pub struct ResponsesAPIStreamBuffer { + /// Sequence number for events + sequence_number: i32, + + /// Track item IDs by output index + item_ids: HashMap, + + /// Response metadata + response_id: Option, + model: Option, + created_at: Option, + + /// Lifecycle state flags + created_emitted: bool, + in_progress_emitted: bool, + + /// Track which output items we've added + output_items_added: HashMap, // output_index -> item_id + + /// Accumulated content by item_id + text_content: HashMap, + function_arguments: HashMap, + + /// Tool call metadata by output_index + tool_call_metadata: HashMap, // output_index -> (call_id, name) + + /// Final completed response (for logging/tracing/persistence) + completed_response: Option, + + /// Buffered SSE events ready to be written to wire + buffered_events: Vec, +} + +impl ResponsesAPIStreamBuffer { + pub fn new() -> Self { + Self { + sequence_number: 0, + item_ids: HashMap::new(), + response_id: None, + model: None, + created_at: None, + created_emitted: false, + in_progress_emitted: false, + output_items_added: HashMap::new(), + text_content: HashMap::new(), + function_arguments: HashMap::new(), + tool_call_metadata: 
HashMap::new(), + completed_response: None, + buffered_events: Vec::new(), + } + } + + fn next_sequence_number(&mut self) -> i32 { + let seq = self.sequence_number; + self.sequence_number += 1; + seq + } + + fn generate_item_id(&self, prefix: &str) -> String { + format!("{}_{}", prefix, uuid::Uuid::new_v4().to_string().replace("-", "")) + } + + fn get_or_create_item_id(&mut self, output_index: i32, prefix: &str) -> String { + if let Some(id) = self.item_ids.get(&output_index) { + return id.clone(); + } + let id = self.generate_item_id(prefix); + self.item_ids.insert(output_index, id.clone()); + id + } + + /// Create response.created event + fn create_response_created_event(&mut self) -> SseEvent { + let response = self.build_response(ResponseStatus::InProgress); + let event = ResponsesAPIStreamEvent::ResponseCreated { + response, + sequence_number: self.next_sequence_number(), + }; + event_to_sse(event) + } + + /// Create response.in_progress event + fn create_response_in_progress_event(&mut self) -> SseEvent { + let response = self.build_response(ResponseStatus::InProgress); + let event = ResponsesAPIStreamEvent::ResponseInProgress { + response, + sequence_number: self.next_sequence_number(), + }; + event_to_sse(event) + } + + /// Create output_item.added event for text + fn create_output_item_added_event(&mut self, output_index: i32, item_id: &str) -> SseEvent { + let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded { + output_index, + item: OutputItem::Message { + id: item_id.to_string(), + status: OutputItemStatus::InProgress, + role: "assistant".to_string(), + content: vec![], + }, + sequence_number: self.next_sequence_number(), + }; + event_to_sse(event) + } + + /// Create output_item.added event for tool call + fn create_tool_call_added_event(&mut self, output_index: i32, item_id: &str, call_id: &str, name: &str) -> SseEvent { + let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded { + output_index, + item: OutputItem::FunctionCall { + id: 
item_id.to_string(), + status: OutputItemStatus::InProgress, + call_id: call_id.to_string(), + name: Some(name.to_string()), + arguments: Some(String::new()), + }, + sequence_number: self.next_sequence_number(), + }; + event_to_sse(event) + } + + /// Build the base response object with current state + fn build_response(&self, status: ResponseStatus) -> ResponsesAPIResponse { + ResponsesAPIResponse { + id: self.response_id.clone().unwrap_or_default(), + object: "response".to_string(), + created_at: self.created_at.unwrap_or(0), + status, + error: None, + incomplete_details: None, + instructions: None, + model: self.model.clone().unwrap_or_else(|| "unknown".to_string()), + output: vec![], + usage: None, + parallel_tool_calls: true, + conversation: None, + previous_response_id: None, + tools: vec![], + tool_choice: "auto".to_string(), + temperature: 1.0, + top_p: 1.0, + metadata: HashMap::new(), + truncation: Some("disabled".to_string()), + max_output_tokens: None, + reasoning: Some(Reasoning { + effort: None, + summary: None, + }), + store: Some(true), + text: Some(TextConfig { + format: TextFormat::Text, + }), + audio: None, + modalities: None, + service_tier: Some("auto".to_string()), + background: Some(false), + top_logprobs: Some(0), + max_tool_calls: None, + } + } + + /// Get the completed response after finalization (for logging/tracing/persistence) + pub fn get_completed_response(&self) -> Option<&ResponsesAPIResponse> { + self.completed_response.as_ref() + } + + /// Finalize the response by emitting all *.done events and response.completed. + /// Call this when the stream is complete (after seeing [DONE] or end_of_stream). 
+ pub fn finalize(&mut self) { + let mut events = Vec::new(); + + // Emit done events for all accumulated content + + // Text content done events + let text_items: Vec<_> = self.text_content.iter().map(|(id, content)| (id.clone(), content.clone())).collect(); + for (item_id, content) in text_items { + let output_index = self.output_items_added.iter() + .find(|(_, id)| **id == item_id) + .map(|(idx, _)| *idx) + .unwrap_or(0); + + let seq1 = self.next_sequence_number(); + let text_done_event = ResponsesAPIStreamEvent::ResponseOutputTextDone { + item_id: item_id.clone(), + output_index, + content_index: 0, + text: content.clone(), + logprobs: vec![], + sequence_number: seq1, + }; + events.push(event_to_sse(text_done_event)); + + let seq2 = self.next_sequence_number(); + let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone { + output_index, + item: OutputItem::Message { + id: item_id.clone(), + status: OutputItemStatus::Completed, + role: "assistant".to_string(), + content: vec![], + }, + sequence_number: seq2, + }; + events.push(event_to_sse(item_done_event)); + } + + // Function call done events + let func_items: Vec<_> = self.function_arguments.iter().map(|(id, args)| (id.clone(), args.clone())).collect(); + for (item_id, arguments) in func_items { + let output_index = self.output_items_added.iter() + .find(|(_, id)| **id == item_id) + .map(|(idx, _)| *idx) + .unwrap_or(0); + + let seq1 = self.next_sequence_number(); + let args_done_event = ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { + output_index, + item_id: item_id.clone(), + arguments: arguments.clone(), + sequence_number: seq1, + }; + events.push(event_to_sse(args_done_event)); + + let (call_id, name) = self.tool_call_metadata.get(&output_index) + .cloned() + .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string())); + + let seq2 = self.next_sequence_number(); + let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone { + output_index, + 
item: OutputItem::FunctionCall { + id: item_id.clone(), + status: OutputItemStatus::Completed, + call_id, + name: Some(name), + arguments: Some(arguments.clone()), + }, + sequence_number: seq2, + }; + events.push(event_to_sse(item_done_event)); + } + + // Build final response + let mut output_items = Vec::new(); + + // Add tool calls to output + for (item_id, arguments) in &self.function_arguments { + let output_index = self.output_items_added.iter() + .find(|(_, id)| *id == item_id) + .map(|(idx, _)| *idx) + .unwrap_or(0); + + let (call_id, name) = self.tool_call_metadata.get(&output_index) + .cloned() + .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string())); + + output_items.push(OutputItem::FunctionCall { + id: item_id.clone(), + status: OutputItemStatus::Completed, + call_id, + name: Some(name), + arguments: Some(arguments.clone()), + }); + } + + let mut final_response = self.build_response(ResponseStatus::Completed); + final_response.output = output_items; + + // Store completed response + self.completed_response = Some(final_response.clone()); + + // Emit response.completed + let seq_final = self.next_sequence_number(); + let completed_event = ResponsesAPIStreamEvent::ResponseCompleted { + response: final_response, + sequence_number: seq_final, + }; + events.push(event_to_sse(completed_event)); + + // Add all finalization events to the buffer + self.buffered_events.extend(events); + } +} + +impl SseStreamBufferTrait for ResponsesAPIStreamBuffer { + fn add_transformed_event(&mut self, event: SseEvent) { + // Skip ping messages + if event.should_skip() { + return; + } + + // Handle [DONE] marker - trigger finalization + if event.is_done() { + self.finalize(); + return; + } + + // Extract the ResponseAPIStreamEvent from the SseEvent's provider_stream_response + let provider_response = match event.provider_stream_response.as_ref() { + Some(response) => response, + None => { + eprintln!("Warning: Event missing 
provider_stream_response"); + return; + } + }; + + // Extract ResponseAPIStreamEvent from the enum + let stream_event = match provider_response { + crate::providers::streaming_response::ProviderStreamResponseType::ResponseAPIStreamEvent(evt) => evt, + _ => { + eprintln!("Warning: Expected ResponseAPIStreamEvent in provider_stream_response"); + return; + } + }; + + let mut events = Vec::new(); + + // Emit lifecycle events if not yet emitted + if !self.created_emitted { + // Initialize metadata from first event if needed + if self.response_id.is_none() { + self.response_id = Some(format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", ""))); + self.created_at = Some(std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64); + self.model = Some("unknown".to_string()); // Will be set by caller if available + } + + events.push(self.create_response_created_event()); + self.created_emitted = true; + } + + if !self.in_progress_emitted { + events.push(self.create_response_in_progress_event()); + self.in_progress_emitted = true; + } + + // Process the delta event + match stream_event { + ResponsesAPIStreamEvent::ResponseOutputTextDelta { output_index, delta, .. } => { + let item_id = self.get_or_create_item_id(*output_index, "msg"); + + // Emit output_item.added if this is the first time we see this output index + if !self.output_items_added.contains_key(output_index) { + self.output_items_added.insert(*output_index, item_id.clone()); + events.push(self.create_output_item_added_event(*output_index, &item_id)); + } + + // Accumulate text content + self.text_content.entry(item_id.clone()) + .and_modify(|content| content.push_str(delta)) + .or_insert_with(|| delta.clone()); + + // Emit text delta with filled-in item_id and sequence_number + let mut delta_event = stream_event.clone(); + if let ResponsesAPIStreamEvent::ResponseOutputTextDelta { item_id: ref mut id, sequence_number: ref mut seq, .. 
} = delta_event { + *id = item_id; + *seq = self.next_sequence_number(); + } + events.push(event_to_sse(delta_event)); + } + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { output_index, delta, call_id, name, .. } => { + let item_id = self.get_or_create_item_id(*output_index, "fc"); + + // Store metadata if provided (from initial tool call event) + if let (Some(cid), Some(n)) = (call_id, name) { + self.tool_call_metadata.insert(*output_index, (cid.clone(), n.clone())); + } + + // Emit output_item.added if this is the first time we see this tool call + if !self.output_items_added.contains_key(output_index) { + self.output_items_added.insert(*output_index, item_id.clone()); + + // For tool calls, we need call_id and name from metadata + // These should now be populated from the event itself + let (call_id, name) = self.tool_call_metadata.get(output_index) + .cloned() + .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string())); + + events.push(self.create_tool_call_added_event(*output_index, &item_id, &call_id, &name)); + } + + // Accumulate function arguments + self.function_arguments.entry(item_id.clone()) + .and_modify(|args| args.push_str(delta)) + .or_insert_with(|| delta.clone()); + + // Emit function call arguments delta with filled-in item_id and sequence_number + let mut delta_event = stream_event.clone(); + if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { item_id: ref mut id, sequence_number: ref mut seq, .. 
} = delta_event { + *id = item_id; + *seq = self.next_sequence_number(); + } + events.push(event_to_sse(delta_event)); + } + _ => { + // For other event types, just pass through with sequence number + let other_event = stream_event.clone(); + // TODO: Add sequence number to other event types if needed + events.push(event_to_sse(other_event)); + } + } + + // Store all generated events in the buffer + self.buffered_events.extend(events); + } + + + fn into_bytes(&mut self) -> Vec { + // For Responses API, we need special handling: + // - Most events are already in buffered_events from add_transformed_event + // - We should NOT finalize here - finalization happens when we detect [DONE] or end of stream + // - Just flush the accumulated events and clear the buffer + + // Convert all accumulated events to bytes and clear buffer + let mut buffer = Vec::new(); + for event in self.buffered_events.drain(..) { + let event_bytes: Vec = event.into(); + buffer.extend_from_slice(&event_bytes); + } + buffer + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; + use crate::apis::openai::OpenAIApi; + use crate::apis::streaming_shapes::sse::SseStreamIter; + + #[test] + fn test_chat_completions_to_responses_api_transformation() { + // ChatCompletions input that will be transformed to ResponsesAPI + let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} + + data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]} + + data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]} + + data: [DONE]"#; + + println!("\n{}", "=".repeat(80)); + println!("TEST 2: 
ChatCompletions → ResponsesAPI Transformation (with [DONE])"); + println!("{}", "=".repeat(80)); + println!("\nRAW INPUT (ChatCompletions):"); + println!("{}", "-".repeat(80)); + println!("{}", raw_input); + + // Setup API configuration for transformation + let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Parse events and apply transformation + let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap(); + let mut buffer = ResponsesAPIStreamBuffer::new(); + + for raw_event in stream_iter { + // Transform the event using the client/upstream APIs + let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); + buffer.add_transformed_event(transformed_event); + } + + let output_bytes = buffer.into_bytes(); + let output = String::from_utf8_lossy(&output_bytes); + + println!("\nTRANSFORMED OUTPUT (ResponsesAPI):"); + println!("{}", "-".repeat(80)); + println!("{}", output); + + // Assertions + assert!(!output_bytes.is_empty(), "Should have output"); + assert!(output.contains("response.created"), "Should have response.created"); + assert!(output.contains("response.in_progress"), "Should have response.in_progress"); + assert!(output.contains("response.output_item.added"), "Should have output_item.added"); + assert!(output.contains("response.output_text.delta"), "Should have text deltas"); + assert!(output.contains("response.output_text.done"), "Should have text.done"); + assert!(output.contains("response.output_item.done"), "Should have output_item.done"); + assert!(output.contains("response.completed"), "Should have response.completed"); + + println!("\nVALIDATION SUMMARY:"); + println!("{}", "-".repeat(80)); + println!("✓ Lifecycle events: response.created, response.in_progress, response.completed"); + println!("✓ Output item lifecycle: output_item.added, output_item.done"); + println!("✓ 
Text streaming: output_text.delta (2 deltas), output_text.done"); + println!("✓ Complete transformation with finalization ([DONE] processed)\n"); + } + + #[test] + fn test_partial_streaming_incremental_output() { + let raw_input = r#"data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_mD5ggLKk3SMKGPFqFdcpKg6q","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"obfuscation":"PCFrpy"} + + data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""} + + data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"TC58A3QEIx8"} + + data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"PK4oFzlVlGTUP5"}"#; + + println!("\n{}", "=".repeat(80)); + println!("TEST 3: Partial Streaming - Function Calling (NO [DONE])"); + println!("{}", "=".repeat(80)); + println!("\nRAW INPUT (ChatCompletions - NO [DONE]):"); + 
println!("{}", "-".repeat(80)); + println!("{}", raw_input); + + // Setup API configuration for transformation + let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform all events + let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap(); + let mut buffer = ResponsesAPIStreamBuffer::new(); + + for raw_event in stream_iter { + let transformed = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); + buffer.add_transformed_event(transformed); + } + + let output_bytes = buffer.into_bytes(); + let output = String::from_utf8_lossy(&output_bytes); + + println!("\nTRANSFORMED OUTPUT (ResponsesAPI):"); + println!("{}", "-".repeat(80)); + println!("{}", output); + + // Assertions + assert!(output.contains("response.created"), "Should have response.created"); + assert!(output.contains("response.in_progress"), "Should have response.in_progress"); + assert!(output.contains("response.output_item.added"), "Should have output_item.added"); + assert!(output.contains("\"type\":\"function_call\""), "Should be function_call type"); + assert!(output.contains("\"name\":\"get_weather\""), "Should have function name"); + assert!(output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""), "Should have correct call_id"); + + let delta_count = output.matches("event: response.function_call_arguments.delta").count(); + assert_eq!(delta_count, 4, "Should have 4 delta events"); + + assert!(!output.contains("response.function_call_arguments.done"), "Should NOT have arguments.done"); + assert!(!output.contains("response.output_item.done"), "Should NOT have output_item.done"); + assert!(!output.contains("response.completed"), "Should NOT have response.completed"); + + println!("\nVALIDATION SUMMARY:"); + println!("{}", "-".repeat(80)); + println!("✓ Lifecycle events: response.created, response.in_progress"); + println!("✓ 
Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'"); + println!("✓ Incremental deltas: 4 events (1 initial + 3 argument chunks)"); + println!("✓ NO completion events (partial stream, no [DONE])"); + println!("✓ Arguments accumulated: '{{\"location\":\"'\n"); + } +} diff --git a/crates/hermesllm/src/apis/sse.rs b/crates/hermesllm/src/apis/streaming_shapes/sse.rs similarity index 55% rename from crates/hermesllm/src/apis/sse.rs rename to crates/hermesllm/src/apis/streaming_shapes/sse.rs index b8a9b492..6a2485e4 100644 --- a/crates/hermesllm/src/apis/sse.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/sse.rs @@ -1,10 +1,73 @@ -use crate::providers::response::ProviderStreamResponse; -use crate::providers::response::ProviderStreamResponseType; +use crate::providers::streaming_response::ProviderStreamResponse; +use crate::providers::streaming_response::ProviderStreamResponseType; +use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer; +use crate::apis::streaming_shapes::anthropic_streaming_buffer::AnthropicMessagesStreamBuffer; +use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer; +use crate::apis::streaming_shapes::responses_api_streaming_buffer::ResponsesAPIStreamBuffer; use serde::{Deserialize, Serialize}; use std::error::Error; use std::fmt; use std::str::FromStr; +/// Trait defining the interface for SSE stream buffers. +/// +/// This trait is implemented by both the enum `SseStreamBuffer` (for zero-cost dispatch) +/// and individual buffer implementations (for direct use). +/// +pub trait SseStreamBufferTrait: Send + Sync { + /// Add a transformed SSE event to the buffer. + /// + /// The buffer may inject additional events as needed based on internal state. + /// For example, Anthropic buffers inject ContentBlockStart before the first ContentBlockDelta. 
+ /// + /// All events (original + injected) are accumulated internally for the next `into_bytes()` call. + /// + /// # Arguments + /// * `event` - A transformed SSE event to accumulate + fn add_transformed_event(&mut self, event: SseEvent); + + /// Get bytes for all accumulated events since the last call. + /// + /// This method: + /// - Converts all buffered events to wire format bytes + /// - Clears the internal event buffer + /// - Preserves state for subsequent `add_transformed_event()` calls + /// + /// Call this after processing each chunk of upstream events to get bytes for immediate transmission. + /// + /// # Returns + /// Bytes ready for wire transmission (may be empty if no events were accumulated) + fn into_bytes(&mut self) -> Vec; +} + +/// Unified SSE Stream Buffer enum that provides a zero-cost abstraction +pub enum SseStreamBuffer { + Passthrough(PassthroughStreamBuffer), + OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer), + AnthropicMessages(AnthropicMessagesStreamBuffer), + OpenAIResponses(ResponsesAPIStreamBuffer), +} + +impl SseStreamBufferTrait for SseStreamBuffer { + fn add_transformed_event(&mut self, event: SseEvent) { + match self { + Self::Passthrough(buffer) => buffer.add_transformed_event(event), + Self::OpenAIChatCompletions(buffer) => buffer.add_transformed_event(event), + Self::AnthropicMessages(buffer) => buffer.add_transformed_event(event), + Self::OpenAIResponses(buffer) => buffer.add_transformed_event(event), + } + } + + fn into_bytes(&mut self) -> Vec { + match self { + Self::Passthrough(buffer) => buffer.into_bytes(), + Self::OpenAIChatCompletions(buffer) => buffer.into_bytes(), + Self::AnthropicMessages(buffer) => buffer.into_bytes(), + Self::OpenAIResponses(buffer) => buffer.into_bytes(), + } + } +} + // ============================================================================ // SSE EVENT CONTAINER // ============================================================================ @@ -22,7 +85,7 @@ pub struct SseEvent 
{ pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n" #[serde(skip_serializing, skip_deserializing)] - pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n" + pub sse_transformed_lines: String, // The complete line as received including "data: " prefix and "\n\n" #[serde(skip_serializing, skip_deserializing)] pub provider_stream_response: Option, // Parsed provider stream response object @@ -31,7 +94,7 @@ pub struct SseEvent { impl SseEvent { /// Check if this event represents the end of the stream pub fn is_done(&self) -> bool { - self.data == Some("[DONE]".into()) + self.data == Some("[DONE]".into()) || self.event == Some("message_stop".into()) } /// Check if this event should be skipped during processing @@ -61,23 +124,35 @@ impl FromStr for SseEvent { type Err = SseParseError; fn from_str(line: &str) -> Result { - if line.starts_with("data: ") { - let data: String = line[6..].to_string(); // Remove "data: " prefix - if data.is_empty() { + // Trim leading/trailing whitespace for parsing + let trimmed_line = line.trim(); + + // Skip empty or whitespace-only lines (SSE event separators) + if trimmed_line.is_empty() { + return Err(SseParseError { + message: "Empty line (SSE event separator)".to_string(), + }); + } + + if trimmed_line.starts_with("data: ") { + let data: String = trimmed_line[6..].to_string(); // Remove "data: " prefix + // Allow empty data content after "data: " prefix + // This handles cases like "data: " followed by newline + if data.trim().is_empty() { return Err(SseParseError { - message: "Empty data field is not a valid SSE event".to_string(), + message: "Empty data field after 'data: ' prefix".to_string(), }); } Ok(SseEvent { data: Some(data), event: None, raw_line: line.to_string(), - sse_transform_buffer: line.to_string(), + // Preserve original line format for passthrough, use trimmed for transformations + sse_transformed_lines: line.to_string(), 
provider_stream_response: None, }) - } else if line.starts_with("event: ") { - //used by Anthropic - let event_type = line[7..].to_string(); + } else if trimmed_line.starts_with("event: ") { + let event_type = trimmed_line[7..].to_string(); if event_type.is_empty() { return Err(SseParseError { message: "Empty event field is not a valid SSE event".to_string(), @@ -87,12 +162,13 @@ impl FromStr for SseEvent { data: None, event: Some(event_type), raw_line: line.to_string(), - sse_transform_buffer: line.to_string(), + // Preserve original line format for passthrough, use trimmed for transformations + sse_transformed_lines: line.to_string(), provider_stream_response: None, }) } else { Err(SseParseError { - message: format!("Line does not start with 'data: ' or 'event: ': {}", line), + message: format!("Line does not start with 'data: ' or 'event: ': {}", trimmed_line), }) } } @@ -100,14 +176,14 @@ impl FromStr for SseEvent { impl fmt::Display for SseEvent { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.sse_transform_buffer) + write!(f, "{}", self.sse_transformed_lines) } } // Into implementation to convert SseEvent to bytes for response buffer impl Into> for SseEvent { fn into(self) -> Vec { - format!("{}\n\n", self.sse_transform_buffer).into_bytes() + format!("{}\n\n", self.sse_transformed_lines).into_bytes() } } diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs index ea159b75..09ab262d 100644 --- a/crates/hermesllm/src/clients/endpoints.rs +++ b/crates/hermesllm/src/clients/endpoints.rs @@ -4,7 +4,7 @@ use std::fmt; /// Unified enum representing all supported API endpoints across providers #[derive(Debug, Clone, PartialEq)] -pub enum SupportedAPIsFromClients { +pub enum SupportedAPIsFromClient { OpenAIChatCompletions(OpenAIApi), AnthropicMessagesAPI(AnthropicApi), OpenAIResponsesAPI(OpenAIApi), @@ -19,16 +19,16 @@ pub enum SupportedUpstreamAPIs { OpenAIResponsesAPI(OpenAIApi), } -impl 
fmt::Display for SupportedAPIsFromClients { +impl fmt::Display for SupportedAPIsFromClient { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - SupportedAPIsFromClients::OpenAIChatCompletions(api) => { + SupportedAPIsFromClient::OpenAIChatCompletions(api) => { write!(f, "OpenAI ({})", api.endpoint()) } - SupportedAPIsFromClients::AnthropicMessagesAPI(api) => { + SupportedAPIsFromClient::AnthropicMessagesAPI(api) => { write!(f, "Anthropic AI ({})", api.endpoint()) } - SupportedAPIsFromClients::OpenAIResponsesAPI(api) => { + SupportedAPIsFromClient::OpenAIResponsesAPI(api) => { write!(f, "OpenAI Responses ({})", api.endpoint()) } } @@ -57,15 +57,20 @@ impl fmt::Display for SupportedUpstreamAPIs { } } -impl SupportedAPIsFromClients { +impl SupportedAPIsFromClient { /// Create a SupportedApi from an endpoint path pub fn from_endpoint(endpoint: &str) -> Option { if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) { - return Some(SupportedAPIsFromClients::OpenAIChatCompletions(openai_api)); + // Check if this is the Responses API endpoint + if openai_api == OpenAIApi::Responses { + return Some(SupportedAPIsFromClient::OpenAIResponsesAPI(openai_api)); + } + // Otherwise it's ChatCompletions + return Some(SupportedAPIsFromClient::OpenAIChatCompletions(openai_api)); } if let Some(anthropic_api) = AnthropicApi::from_endpoint(endpoint) { - return Some(SupportedAPIsFromClients::AnthropicMessagesAPI(anthropic_api)); + return Some(SupportedAPIsFromClient::AnthropicMessagesAPI(anthropic_api)); } None @@ -74,9 +79,9 @@ impl SupportedAPIsFromClients { /// Get the endpoint path for this API pub fn endpoint(&self) -> &'static str { match self { - SupportedAPIsFromClients::OpenAIChatCompletions(api) => api.endpoint(), - SupportedAPIsFromClients::AnthropicMessagesAPI(api) => api.endpoint(), - SupportedAPIsFromClients::OpenAIResponsesAPI(api) => api.endpoint(), + SupportedAPIsFromClient::OpenAIChatCompletions(api) => api.endpoint(), + 
SupportedAPIsFromClient::AnthropicMessagesAPI(api) => api.endpoint(), + SupportedAPIsFromClient::OpenAIResponsesAPI(api) => api.endpoint(), } } @@ -103,8 +108,62 @@ impl SupportedAPIsFromClients { } }; + // Helper function to route based on provider with a specific endpoint suffix + let route_by_provider = |endpoint_suffix: &str| -> String { + match provider_id { + ProviderId::Groq => { + if request_path.starts_with("/v1/") { + build_endpoint("/openai", request_path) + } else { + build_endpoint("/v1", endpoint_suffix) + } + } + ProviderId::Zhipu => { + if request_path.starts_with("/v1/") { + build_endpoint("/api/paas/v4", endpoint_suffix) + } else { + build_endpoint("/v1", endpoint_suffix) + } + } + ProviderId::Qwen => { + if request_path.starts_with("/v1/") { + build_endpoint("/compatible-mode/v1", endpoint_suffix) + } else { + build_endpoint("/v1", endpoint_suffix) + } + } + ProviderId::AzureOpenAI => { + if request_path.starts_with("/v1/") { + let suffix = endpoint_suffix.trim_start_matches('/'); + build_endpoint("/openai/deployments", &format!("/{}/{}?api-version=2025-01-01-preview", model_id, suffix)) + } else { + build_endpoint("/v1", endpoint_suffix) + } + } + ProviderId::Gemini => { + if request_path.starts_with("/v1/") { + build_endpoint("/v1beta/openai", endpoint_suffix) + } else { + build_endpoint("/v1", endpoint_suffix) + } + } + ProviderId::AmazonBedrock => { + if request_path.starts_with("/v1/") { + if !is_streaming { + build_endpoint("", &format!("/model/{}/converse", model_id)) + } else { + build_endpoint("", &format!("/model/{}/converse-stream", model_id)) + } + } else { + build_endpoint("/v1", endpoint_suffix) + } + } + _ => build_endpoint("/v1", endpoint_suffix), + } + }; + match self { - SupportedAPIsFromClients::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id { + SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id { ProviderId::Anthropic => build_endpoint("/v1", "/messages"), 
ProviderId::AmazonBedrock => { if request_path.starts_with("/v1/") && !is_streaming { @@ -117,55 +176,19 @@ impl SupportedAPIsFromClients { } _ => build_endpoint("/v1", "/chat/completions"), }, - _ => match provider_id { - ProviderId::Groq => { - if request_path.starts_with("/v1/") { - build_endpoint("/openai", request_path) - } else { - build_endpoint("/v1", "/chat/completions") - } + SupportedAPIsFromClient::OpenAIResponsesAPI(_) => { + // For Responses API, check if provider supports it, otherwise translate to chat/completions + match provider_id { + // OpenAI and compatible providers that support /v1/responses + ProviderId::OpenAI => route_by_provider("/responses"), + // All other providers: translate to /chat/completions + _ => route_by_provider("/chat/completions"), } - ProviderId::Zhipu => { - if request_path.starts_with("/v1/") { - build_endpoint("/api/paas/v4", "/chat/completions") - } else { - build_endpoint("/v1", "/chat/completions") - } - } - ProviderId::Qwen => { - if request_path.starts_with("/v1/") { - build_endpoint("/compatible-mode/v1", "/chat/completions") - } else { - build_endpoint("/v1", "/chat/completions") - } - } - ProviderId::AzureOpenAI => { - if request_path.starts_with("/v1/") { - build_endpoint("/openai/deployments", &format!("/{}/chat/completions?api-version=2025-01-01-preview", model_id)) - } else { - build_endpoint("/v1", "/chat/completions") - } - } - ProviderId::Gemini => { - if request_path.starts_with("/v1/") { - build_endpoint("/v1beta/openai", "/chat/completions") - } else { - build_endpoint("/v1", "/chat/completions") - } - } - ProviderId::AmazonBedrock => { - if request_path.starts_with("/v1/") { - if !is_streaming { - build_endpoint("", &format!("/model/{}/converse", model_id)) - } else { - build_endpoint("", &format!("/model/{}/converse-stream", model_id)) - } - } else { - build_endpoint("/v1", "/chat/completions") - } - } - _ => build_endpoint("/v1", "/chat/completions"), - }, + } + 
SupportedAPIsFromClient::OpenAIChatCompletions(_) => { + // For Chat Completions API, use the standard chat/completions path + route_by_provider("/chat/completions") + } } } } @@ -207,14 +230,14 @@ mod tests { #[test] fn test_is_supported_endpoint() { // OpenAI endpoints - assert!(SupportedAPIsFromClients::from_endpoint("/v1/chat/completions").is_some()); + assert!(SupportedAPIsFromClient::from_endpoint("/v1/chat/completions").is_some()); // Anthropic endpoints - assert!(SupportedAPIsFromClients::from_endpoint("/v1/messages").is_some()); + assert!(SupportedAPIsFromClient::from_endpoint("/v1/messages").is_some()); // Unsupported endpoints - assert!(!SupportedAPIsFromClients::from_endpoint("/v1/unknown").is_some()); - assert!(!SupportedAPIsFromClients::from_endpoint("/v2/chat").is_some()); - assert!(!SupportedAPIsFromClients::from_endpoint("").is_some()); + assert!(!SupportedAPIsFromClient::from_endpoint("/v1/unknown").is_some()); + assert!(!SupportedAPIsFromClient::from_endpoint("/v2/chat").is_some()); + assert!(!SupportedAPIsFromClient::from_endpoint("").is_some()); } #[test] @@ -273,7 +296,7 @@ mod tests { #[test] fn test_target_endpoint_without_base_url_prefix() { - let api = SupportedAPIsFromClients::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); // Test default OpenAI provider assert_eq!( @@ -350,7 +373,7 @@ mod tests { #[test] fn test_target_endpoint_with_base_url_prefix() { - let api = SupportedAPIsFromClients::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); // Test Zhipu with custom base_url_path_prefix assert_eq!( @@ -415,7 +438,7 @@ mod tests { #[test] fn test_target_endpoint_with_empty_base_url_prefix() { - let api = SupportedAPIsFromClients::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let api = 
SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); // Test with just slashes - trims to empty, uses provider default assert_eq!( @@ -444,7 +467,7 @@ mod tests { #[test] fn test_amazon_bedrock_endpoints() { - let api = SupportedAPIsFromClients::AnthropicMessagesAPI(AnthropicApi::Messages); + let api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); // Test Bedrock non-streaming without prefix assert_eq!( @@ -497,7 +520,7 @@ mod tests { #[test] fn test_anthropic_messages_endpoint() { - let api = SupportedAPIsFromClients::AnthropicMessagesAPI(AnthropicApi::Messages); + let api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); // Test Anthropic without prefix assert_eq!( @@ -526,7 +549,7 @@ mod tests { #[test] fn test_non_v1_request_paths() { - let api = SupportedAPIsFromClients::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); // Test Groq with non-v1 path (should use default) assert_eq!( @@ -567,7 +590,7 @@ mod tests { #[test] fn test_azure_openai_with_query_params() { - let api = SupportedAPIsFromClients::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); // Test Azure without prefix - should include query params assert_eq!( diff --git a/crates/hermesllm/src/clients/mod.rs b/crates/hermesllm/src/clients/mod.rs index e461b3b9..6841804f 100644 --- a/crates/hermesllm/src/clients/mod.rs +++ b/crates/hermesllm/src/clients/mod.rs @@ -1,9 +1,8 @@ pub mod endpoints; pub mod lib; -pub mod transformer; // Re-export the main items for easier access -pub use endpoints::{identify_provider, SupportedAPIsFromClients}; +pub use endpoints::*; pub use lib::*; // Note: transformer module contains TryFrom trait implementations that are automatically available diff --git a/crates/hermesllm/src/clients/transformer.rs 
b/crates/hermesllm/src/clients/transformer.rs deleted file mode 100644 index cbb9bbe7..00000000 --- a/crates/hermesllm/src/clients/transformer.rs +++ /dev/null @@ -1,694 +0,0 @@ -// Re-export new transformation modules for backward compatibility - -//KEEPING THE TESTS TO MAKE SURE ALL THE REFACTORING DIDN'T BREAK ANYTHING - -// ============================================================================ -// TESTS -// ============================================================================ - -#[cfg(test)] -mod tests { - use crate::apis::anthropic::*; - use crate::apis::openai::*; - use crate::transforms::*; - use serde_json::json; - type AnthropicMessagesRequest = MessagesRequest; - - #[test] - fn test_anthropic_to_openai_basic_request() { - let anthropic_req = AnthropicMessagesRequest { - model: "claude-3-sonnet-20240229".to_string(), - system: Some(MessagesSystemPrompt::Single("You are helpful".to_string())), - messages: vec![MessagesMessage { - role: MessagesRole::User, - content: MessagesMessageContent::Single("Hello, world!".to_string()), - }], - max_tokens: 1024, - container: None, - mcp_servers: None, - service_tier: None, - thinking: None, - temperature: Some(0.7), - top_p: Some(0.9), - top_k: Some(50), - stream: Some(false), - stop_sequences: Some(vec!["STOP".to_string()]), - tools: None, - tool_choice: None, - metadata: None, - }; - - let openai_req: ChatCompletionsRequest = anthropic_req.try_into().unwrap(); - - assert_eq!(openai_req.model, "claude-3-sonnet-20240229"); - assert_eq!(openai_req.messages.len(), 2); // system + user message - assert_eq!(openai_req.max_completion_tokens, Some(1024)); - assert_eq!(openai_req.temperature, Some(0.7)); - assert_eq!(openai_req.top_p, Some(0.9)); - assert_eq!(openai_req.stream, Some(false)); - assert_eq!(openai_req.stop, Some(vec!["STOP".to_string()])); - } - - #[test] - fn test_roundtrip_consistency() { - // Test that converting back and forth maintains consistency - let original_anthropic = 
AnthropicMessagesRequest { - model: "claude-3-sonnet".to_string(), - system: Some(MessagesSystemPrompt::Single("System prompt".to_string())), - messages: vec![MessagesMessage { - role: MessagesRole::User, - content: MessagesMessageContent::Single("User message".to_string()), - }], - max_tokens: 1000, - container: None, - mcp_servers: None, - service_tier: None, - thinking: None, - temperature: Some(0.5), - top_p: Some(1.0), - top_k: None, - stream: Some(false), - stop_sequences: None, - tools: None, - tool_choice: None, - metadata: None, - }; - - // Convert to OpenAI and back - let openai_req: ChatCompletionsRequest = original_anthropic.clone().try_into().unwrap(); - let roundtrip_anthropic: AnthropicMessagesRequest = openai_req.try_into().unwrap(); - - // Check key fields are preserved - assert_eq!(original_anthropic.model, roundtrip_anthropic.model); - assert_eq!( - original_anthropic.max_tokens, - roundtrip_anthropic.max_tokens - ); - assert_eq!( - original_anthropic.temperature, - roundtrip_anthropic.temperature - ); - assert_eq!(original_anthropic.top_p, roundtrip_anthropic.top_p); - assert_eq!(original_anthropic.stream, roundtrip_anthropic.stream); - assert_eq!( - original_anthropic.messages.len(), - roundtrip_anthropic.messages.len() - ); - } - - #[test] - fn test_tool_choice_auto() { - let anthropic_req = AnthropicMessagesRequest { - model: "claude-3".to_string(), - system: None, - messages: vec![], - max_tokens: 100, - container: None, - mcp_servers: None, - service_tier: None, - thinking: None, - temperature: None, - top_p: None, - top_k: None, - stream: None, - stop_sequences: None, - tools: Some(vec![MessagesTool { - name: "test_tool".to_string(), - description: Some("A test tool".to_string()), - input_schema: json!({"type": "object"}), - }]), - tool_choice: Some(MessagesToolChoice { - kind: MessagesToolChoiceType::Auto, - name: None, - disable_parallel_tool_use: Some(true), - }), - metadata: None, - }; - - let openai_req: ChatCompletionsRequest = 
anthropic_req.try_into().unwrap(); - - assert!(openai_req.tools.is_some()); - assert_eq!(openai_req.tools.as_ref().unwrap().len(), 1); - - if let Some(ToolChoice::Type(choice)) = openai_req.tool_choice { - assert_eq!(choice, ToolChoiceType::Auto); - } else { - panic!("Expected auto tool choice"); - } - - assert_eq!(openai_req.parallel_tool_calls, Some(false)); - } - - #[test] - fn test_default_max_tokens_used_when_openai_has_none() { - // Test that DEFAULT_MAX_TOKENS is used when OpenAI request has no max_tokens - let openai_req = ChatCompletionsRequest { - model: "gpt-4".to_string(), - messages: vec![Message { - role: Role::User, - content: MessageContent::Text("Hello".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, - }], - max_tokens: None, // No max_tokens specified - ..Default::default() - }; - - let anthropic_req: AnthropicMessagesRequest = openai_req.try_into().unwrap(); - - assert_eq!(anthropic_req.max_tokens, DEFAULT_MAX_TOKENS); - } - - #[test] - fn test_anthropic_message_start_streaming() { - let event = MessagesStreamEvent::MessageStart { - message: MessagesStreamMessage { - id: "msg_stream_123".to_string(), - obj_type: "message".to_string(), - role: MessagesRole::Assistant, - content: vec![], - model: "claude-3".to_string(), - stop_reason: None, - stop_sequence: None, - usage: MessagesUsage { - input_tokens: 5, - output_tokens: 0, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }, - }; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - - assert_eq!(openai_resp.id, "msg_stream_123"); - assert_eq!(openai_resp.object.as_deref(), Some("chat.completion.chunk")); - assert_eq!(openai_resp.model, "claude-3"); - assert_eq!(openai_resp.choices.len(), 1); - - let choice = &openai_resp.choices[0]; - assert_eq!(choice.index, 0); - assert_eq!(choice.delta.role, Some(Role::Assistant)); - assert_eq!(choice.delta.content, None); - assert_eq!(choice.finish_reason, None); - } - - #[test] - 
fn test_anthropic_content_block_delta_streaming() { - let event = MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: MessagesContentDelta::TextDelta { - text: "Hello, world!".to_string(), - }, - }; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - - assert_eq!(openai_resp.object.as_deref(), Some("chat.completion.chunk")); - assert_eq!(openai_resp.choices.len(), 1); - - let choice = &openai_resp.choices[0]; - assert_eq!(choice.index, 0); - assert_eq!(choice.delta.content, Some("Hello, world!".to_string())); - assert_eq!(choice.delta.role, None); - assert_eq!(choice.finish_reason, None); - } - - #[test] - fn test_anthropic_tool_use_streaming() { - // Test tool use start - let tool_start = MessagesStreamEvent::ContentBlockStart { - index: 0, - content_block: MessagesContentBlock::ToolUse { - id: "call_123".to_string(), - name: "get_weather".to_string(), - input: json!({}), - cache_control: None, - }, - }; - - let openai_resp: ChatCompletionsStreamResponse = tool_start.try_into().unwrap(); - - assert_eq!(openai_resp.choices.len(), 1); - let choice = &openai_resp.choices[0]; - assert!(choice.delta.tool_calls.is_some()); - - let tool_calls = choice.delta.tool_calls.as_ref().unwrap(); - assert_eq!(tool_calls.len(), 1); - assert_eq!(tool_calls[0].id, Some("call_123".to_string())); - assert_eq!( - tool_calls[0].function.as_ref().unwrap().name, - Some("get_weather".to_string()) - ); - } - - #[test] - fn test_anthropic_tool_input_delta_streaming() { - let event = MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: MessagesContentDelta::InputJsonDelta { - partial_json: r#"{"location": "San Francisco"#.to_string(), - }, - }; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - - assert_eq!(openai_resp.choices.len(), 1); - let choice = &openai_resp.choices[0]; - assert!(choice.delta.tool_calls.is_some()); - - let tool_calls = choice.delta.tool_calls.as_ref().unwrap(); - assert_eq!(tool_calls.len(), 
1); - assert_eq!( - tool_calls[0].function.as_ref().unwrap().arguments, - Some(r#"{"location": "San Francisco"#.to_string()) - ); - } - - #[test] - fn test_anthropic_message_delta_with_usage() { - let event = MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: MessagesStopReason::EndTurn, - stop_sequence: None, - }, - usage: MessagesUsage { - input_tokens: 10, - output_tokens: 25, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - - assert_eq!(openai_resp.choices.len(), 1); - let choice = &openai_resp.choices[0]; - assert_eq!(choice.finish_reason, Some(FinishReason::Stop)); - - assert!(openai_resp.usage.is_some()); - let usage = openai_resp.usage.unwrap(); - assert_eq!(usage.prompt_tokens, 10); - assert_eq!(usage.completion_tokens, 25); - assert_eq!(usage.total_tokens, 35); - } - - #[test] - fn test_anthropic_message_stop_streaming() { - let event = MessagesStreamEvent::MessageStop; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - - assert_eq!(openai_resp.choices.len(), 1); - let choice = &openai_resp.choices[0]; - assert_eq!(choice.finish_reason, Some(FinishReason::Stop)); - } - - #[test] - fn test_anthropic_ping_streaming() { - let event = MessagesStreamEvent::Ping; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - - assert_eq!(openai_resp.object.as_deref(), Some("chat.completion.chunk")); - assert_eq!(openai_resp.choices.len(), 0); // Ping has no choices - } - - #[test] - fn test_openai_to_anthropic_streaming_role_start() { - let openai_resp = ChatCompletionsStreamResponse { - id: "chatcmpl-123".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: 1234567890, - model: "gpt-4".to_string(), - choices: vec![StreamChoice { - index: 0, - delta: MessageDelta { - role: Some(Role::Assistant), - content: None, - refusal: None, - function_call: 
None, - tool_calls: None, - }, - finish_reason: None, - logprobs: None, - }], - usage: None, - system_fingerprint: None, - service_tier: None, - }; - - let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); - - match anthropic_event { - MessagesStreamEvent::MessageStart { message } => { - assert_eq!(message.id, "chatcmpl-123"); - assert_eq!(message.role, MessagesRole::Assistant); - assert_eq!(message.model, "gpt-4"); - } - _ => panic!("Expected MessageStart event"), - } - } - - #[test] - fn test_openai_to_anthropic_streaming_content_delta() { - let openai_resp = ChatCompletionsStreamResponse { - id: "chatcmpl-123".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: 1234567890, - model: "gpt-4".to_string(), - choices: vec![StreamChoice { - index: 0, - delta: MessageDelta { - role: None, - content: Some("Hello there!".to_string()), - refusal: None, - function_call: None, - tool_calls: None, - }, - finish_reason: None, - logprobs: None, - }], - usage: None, - system_fingerprint: None, - service_tier: None, - }; - - let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); - - match anthropic_event { - MessagesStreamEvent::ContentBlockDelta { index, delta } => { - assert_eq!(index, 0); - match delta { - MessagesContentDelta::TextDelta { text } => { - assert_eq!(text, "Hello there!"); - } - _ => panic!("Expected TextDelta"), - } - } - _ => panic!("Expected ContentBlockDelta event"), - } - } - - #[test] - fn test_openai_to_anthropic_streaming_tool_calls() { - let openai_resp = ChatCompletionsStreamResponse { - id: "chatcmpl-123".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: 1234567890, - model: "gpt-4".to_string(), - choices: vec![StreamChoice { - index: 0, - delta: MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: 0, - id: Some("call_abc123".to_string()), - call_type: 
Some("function".to_string()), - function: Some(FunctionCallDelta { - name: Some("get_current_weather".to_string()), - arguments: Some("".to_string()), - }), - }]), - }, - finish_reason: None, - logprobs: None, - }], - usage: None, - system_fingerprint: None, - service_tier: None, - }; - - let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); - - match anthropic_event { - MessagesStreamEvent::ContentBlockStart { - index, - content_block, - } => { - assert_eq!(index, 0); - match content_block { - MessagesContentBlock::ToolUse { id, name, .. } => { - assert_eq!(id, "call_abc123"); - assert_eq!(name, "get_current_weather"); - } - _ => panic!("Expected ToolUse content block"), - } - } - _ => panic!("Expected ContentBlockStart event"), - } - } - - #[test] - fn test_openai_to_anthropic_streaming_final_usage() { - let openai_resp = ChatCompletionsStreamResponse { - id: "chatcmpl-123".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: 1234567890, - model: "gpt-4".to_string(), - choices: vec![StreamChoice { - index: 0, - delta: MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - finish_reason: Some(FinishReason::Stop), - logprobs: None, - }], - usage: Some(Usage { - prompt_tokens: 15, - completion_tokens: 30, - total_tokens: 45, - prompt_tokens_details: None, - completion_tokens_details: None, - }), - system_fingerprint: None, - service_tier: None, - }; - - let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); - - match anthropic_event { - MessagesStreamEvent::MessageDelta { delta, usage } => { - assert_eq!(delta.stop_reason, MessagesStopReason::EndTurn); - assert_eq!(usage.input_tokens, 15); - assert_eq!(usage.output_tokens, 30); - } - _ => panic!("Expected MessageDelta event"), - } - } - - #[test] - fn test_openai_empty_choices_to_anthropic_ping() { - let openai_resp = ChatCompletionsStreamResponse { - id: "chatcmpl-123".to_string(), - object: 
Some("chat.completion.chunk".to_string()), - created: 1234567890, - model: "gpt-4".to_string(), - choices: vec![], // Empty choices - usage: None, - system_fingerprint: None, - service_tier: None, - }; - - let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); - - match anthropic_event { - MessagesStreamEvent::Ping => { - // Expected behavior - } - _ => panic!("Expected Ping event for empty choices"), - } - } - - #[test] - fn test_streaming_roundtrip_consistency() { - // Test that streaming events can roundtrip through conversions - let original_event = MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: MessagesContentDelta::TextDelta { - text: "Test message".to_string(), - }, - }; - - // Convert to OpenAI and back - let openai_resp: ChatCompletionsStreamResponse = original_event.try_into().unwrap(); - let roundtrip_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); - - // Verify the roundtrip maintains the essential information - match roundtrip_event { - MessagesStreamEvent::ContentBlockDelta { index, delta } => { - assert_eq!(index, 0); - match delta { - MessagesContentDelta::TextDelta { text } => { - assert_eq!(text, "Test message"); - } - _ => panic!("Expected TextDelta after roundtrip"), - } - } - _ => panic!("Expected ContentBlockDelta after roundtrip"), - } - } - - #[test] - fn test_streaming_tool_argument_accumulation() { - // Test multiple tool argument deltas that should accumulate - let tool_start = MessagesStreamEvent::ContentBlockStart { - index: 0, - content_block: MessagesContentBlock::ToolUse { - id: "call_weather".to_string(), - name: "get_weather".to_string(), - input: json!({}), - cache_control: None, - }, - }; - - let arg_delta1 = MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: MessagesContentDelta::InputJsonDelta { - partial_json: r#"{"location": "#.to_string(), - }, - }; - - let arg_delta2 = MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: 
MessagesContentDelta::InputJsonDelta { - partial_json: r#"San Francisco", "unit": "fahrenheit"}"#.to_string(), - }, - }; - - // Test that each delta converts properly to OpenAI format - let openai_start: ChatCompletionsStreamResponse = tool_start.try_into().unwrap(); - let openai_delta1: ChatCompletionsStreamResponse = arg_delta1.try_into().unwrap(); - let openai_delta2: ChatCompletionsStreamResponse = arg_delta2.try_into().unwrap(); - - // Verify tool start - let tool_calls = &openai_start.choices[0].delta.tool_calls.as_ref().unwrap(); - assert_eq!(tool_calls[0].id, Some("call_weather".to_string())); - assert_eq!( - tool_calls[0].function.as_ref().unwrap().name, - Some("get_weather".to_string()) - ); - - // Verify argument deltas - let args1 = &openai_delta1.choices[0].delta.tool_calls.as_ref().unwrap()[0] - .function - .as_ref() - .unwrap() - .arguments; - assert_eq!(args1, &Some(r#"{"location": "#.to_string())); - - let args2 = &openai_delta2.choices[0].delta.tool_calls.as_ref().unwrap()[0] - .function - .as_ref() - .unwrap() - .arguments; - assert_eq!( - args2, - &Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string()) - ); - } - - #[test] - fn test_streaming_multiple_finish_reasons() { - // Test different finish reasons in streaming - let test_cases = vec![ - (MessagesStopReason::EndTurn, FinishReason::Stop), - (MessagesStopReason::MaxTokens, FinishReason::Length), - (MessagesStopReason::ToolUse, FinishReason::ToolCalls), - (MessagesStopReason::StopSequence, FinishReason::Stop), - ]; - - for (anthropic_reason, expected_openai_reason) in test_cases { - let event = MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: anthropic_reason.clone(), - stop_sequence: None, - }, - usage: MessagesUsage { - input_tokens: 10, - output_tokens: 20, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - assert_eq!( - 
openai_resp.choices[0].finish_reason, - Some(expected_openai_reason) - ); - - // Test reverse conversion - let roundtrip_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); - match roundtrip_event { - MessagesStreamEvent::MessageDelta { delta, .. } => { - // Note: Some precision may be lost in roundtrip due to mapping differences - assert!(matches!( - delta.stop_reason, - MessagesStopReason::EndTurn - | MessagesStopReason::MaxTokens - | MessagesStopReason::ToolUse - | MessagesStopReason::StopSequence - )); - } - _ => panic!("Expected MessageDelta after roundtrip"), - } - } - } - - #[test] - fn test_streaming_error_handling() { - // Test that malformed streaming events are handled gracefully - let openai_resp_with_missing_data = ChatCompletionsStreamResponse { - id: "test".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: 1234567890, - model: "test".to_string(), - choices: vec![StreamChoice { - index: 0, - delta: MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - finish_reason: None, - logprobs: None, - }], - usage: None, - system_fingerprint: None, - service_tier: None, - }; - - // Should convert to Ping when no meaningful content - let anthropic_event: MessagesStreamEvent = - openai_resp_with_missing_data.try_into().unwrap(); - assert!(matches!(anthropic_event, MessagesStreamEvent::Ping)); - } - - #[test] - fn test_streaming_content_block_stop() { - let event = MessagesStreamEvent::ContentBlockStop { index: 0 }; - - let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); - - // ContentBlockStop should produce an empty chunk - assert_eq!(openai_resp.object.as_deref(), Some("chat.completion.chunk")); - assert_eq!(openai_resp.choices.len(), 1); - - let choice = &openai_resp.choices[0]; - assert_eq!(choice.delta.role, None); - assert_eq!(choice.delta.content, None); - assert_eq!(choice.delta.tool_calls, None); - assert_eq!(choice.finish_reason, None); 
- } -} diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 81b29b32..918fd4e9 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -6,14 +6,16 @@ pub mod clients; pub mod providers; pub mod transforms; // Re-export important types and traits -pub use apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; -pub use apis::sse::{SseEvent, SseStreamIter}; +pub use apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; +pub use apis::streaming_shapes::sse::{SseEvent, SseStreamIter}; pub use aws_smithy_eventstream::frame::DecodedFrame; pub use providers::id::ProviderId; pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType}; pub use providers::response::{ - ProviderResponse, ProviderResponseError, ProviderResponseType, ProviderStreamResponse, - ProviderStreamResponseType, TokenUsage, + ProviderResponse, ProviderResponseType, TokenUsage, ProviderResponseError +}; +pub use providers::streaming_response::{ + ProviderStreamResponse, ProviderStreamResponseType }; //TODO: Refactor such that commons doesn't depend on Hermes. 
For now this will clean up strings @@ -43,9 +45,9 @@ mod tests { data: [DONE] "#; - use crate::clients::endpoints::SupportedAPIsFromClients; + use crate::clients::endpoints::SupportedAPIsFromClient; let client_api = - SupportedAPIsFromClients::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); + SupportedAPIsFromClient::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); @@ -80,9 +82,16 @@ mod tests { assert_eq!(stream_response.content_delta(), Some("Hello")); assert!(!stream_response.is_final()); - // Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE]) + // Test that stream ends properly with [DONE] + // The iterator should return the [DONE] event, then None + let done_event = streaming_iter.next(); + assert!(done_event.is_some(), "Should get [DONE] event"); + let done_event = done_event.unwrap(); + assert!(done_event.is_done(), "[DONE] event should be marked as done"); + + // After [DONE], iterator should return None let final_event = streaming_iter.next(); - assert!(final_event.is_none()); // Should be None because iterator stops at [DONE] + assert!(final_event.is_none(), "Iterator should return None after [DONE]"); } /// Test AWS Event Stream decoding for Bedrock ConverseStream responses. 
diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs index 49620832..69455eaf 100644 --- a/crates/hermesllm/src/providers/id.rs +++ b/crates/hermesllm/src/providers/id.rs @@ -1,5 +1,5 @@ use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi}; -use crate::clients::endpoints::{SupportedAPIsFromClients, SupportedUpstreamAPIs}; +use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; use std::fmt::Display; /// Provider identifier enum - simple enum for identifying providers @@ -51,21 +51,21 @@ impl ProviderId { /// Given a client API, return the compatible upstream API for this provider pub fn compatible_api_for_client( &self, - client_api: &SupportedAPIsFromClients, + client_api: &SupportedAPIsFromClient, is_streaming: bool, ) -> SupportedUpstreamAPIs { match (self, client_api) { // Claude/Anthropic providers natively support Anthropic APIs - (ProviderId::Anthropic, SupportedAPIsFromClients::AnthropicMessagesAPI(_)) => { + (ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => { SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) } ( ProviderId::Anthropic, - SupportedAPIsFromClients::OpenAIChatCompletions(_), + SupportedAPIsFromClient::OpenAIChatCompletions(_), ) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), // Anthropic doesn't support Responses API, fall back to chat completions - (ProviderId::Anthropic, SupportedAPIsFromClients::OpenAIResponsesAPI(_)) => { + (ProviderId::Anthropic, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => { SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions) } @@ -85,7 +85,7 @@ impl ProviderId { | ProviderId::Moonshotai | ProviderId::Zhipu | ProviderId::Qwen, - SupportedAPIsFromClients::AnthropicMessagesAPI(_), + SupportedAPIsFromClient::AnthropicMessagesAPI(_), ) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), ( @@ -103,21 +103,21 @@ impl ProviderId { | 
ProviderId::Moonshotai | ProviderId::Zhipu | ProviderId::Qwen, - SupportedAPIsFromClients::OpenAIChatCompletions(_), + SupportedAPIsFromClient::OpenAIChatCompletions(_), ) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), // OpenAI Responses API - only OpenAI supports this - (ProviderId::OpenAI, SupportedAPIsFromClients::OpenAIResponsesAPI(_)) => { + (ProviderId::OpenAI, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => { SupportedUpstreamAPIs::OpenAIResponsesAPI(OpenAIApi::Responses) } // Non-OpenAI providers: if client requested the Responses API, fall back to Chat Completions - (_, SupportedAPIsFromClients::OpenAIResponsesAPI(_)) => { + (_, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => { SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions) } // Amazon Bedrock natively supports Bedrock APIs - (ProviderId::AmazonBedrock, SupportedAPIsFromClients::OpenAIChatCompletions(_)) => { + (ProviderId::AmazonBedrock, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => { if is_streaming { SupportedUpstreamAPIs::AmazonBedrockConverseStream( AmazonBedrockApi::ConverseStream, @@ -126,7 +126,7 @@ impl ProviderId { SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse) } } - (ProviderId::AmazonBedrock, SupportedAPIsFromClients::AnthropicMessagesAPI(_)) => { + (ProviderId::AmazonBedrock, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => { if is_streaming { SupportedUpstreamAPIs::AmazonBedrockConverseStream( AmazonBedrockApi::ConverseStream, diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs index 97b14285..4343022f 100644 --- a/crates/hermesllm/src/providers/mod.rs +++ b/crates/hermesllm/src/providers/mod.rs @@ -6,7 +6,9 @@ pub mod id; pub mod request; pub mod response; +pub mod streaming_response; pub use id::ProviderId; pub use request::{ProviderRequest, ProviderRequestError, ProviderRequestType}; -pub use response::{ProviderResponse, ProviderResponseType, 
ProviderStreamResponse, TokenUsage}; +pub use response::{ProviderResponse, ProviderResponseType, TokenUsage}; +pub use streaming_response::{ProviderStreamResponse, ProviderStreamResponseType}; diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index 75bbb5e2..daeebe70 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -3,7 +3,7 @@ use crate::apis::openai::ChatCompletionsRequest; use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest}; use crate::apis::openai_responses::ResponsesAPIRequest; -use crate::clients::endpoints::SupportedAPIsFromClients; +use crate::clients::endpoints::SupportedAPIsFromClient; use crate::clients::endpoints::SupportedUpstreamAPIs; use serde_json::Value; @@ -127,13 +127,13 @@ impl ProviderRequest for ProviderRequestType { } /// Parse the client API from a byte slice. -impl TryFrom<(&[u8], &SupportedAPIsFromClients)> for ProviderRequestType { +impl TryFrom<(&[u8], &SupportedAPIsFromClient)> for ProviderRequestType { type Error = std::io::Error; - fn try_from((bytes, client_api): (&[u8], &SupportedAPIsFromClients)) -> Result { + fn try_from((bytes, client_api): (&[u8], &SupportedAPIsFromClient)) -> Result { // Use SupportedApi to determine the appropriate request type match client_api { - SupportedAPIsFromClients::OpenAIChatCompletions(_) => { + SupportedAPIsFromClient::OpenAIChatCompletions(_) => { let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -141,13 +141,13 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClients)> for ProviderRequestType { chat_completion_request, )) } - SupportedAPIsFromClients::AnthropicMessagesAPI(_) => { + SupportedAPIsFromClient::AnthropicMessagesAPI(_) => { let messages_request: MessagesRequest = MessagesRequest::try_from(bytes) .map_err(|e| 
std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; Ok(ProviderRequestType::MessagesRequest(messages_request)) } - SupportedAPIsFromClients::OpenAIResponsesAPI(_) => { + SupportedAPIsFromClient::OpenAIResponsesAPI(_) => { let responses_apirequest: ResponsesAPIRequest = ResponsesAPIRequest::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -170,14 +170,10 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT // ============================================================================ // ChatCompletionsRequest conversions // ============================================================================ - - // ChatCompletions -> ChatCompletions (pass-through) ( ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::OpenAIChatCompletions(_), ) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)), - - // ChatCompletions -> Anthropic Messages ( ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AnthropicMessagesAPI(_), @@ -192,8 +188,6 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT })?; Ok(ProviderRequestType::MessagesRequest(messages_req)) } - - // ChatCompletions -> Bedrock Converse ( ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AmazonBedrockConverse(_), @@ -205,8 +199,6 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT })?; Ok(ProviderRequestType::BedrockConverse(bedrock_req)) } - - // ChatCompletions -> Bedrock Converse Stream ( ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), @@ -218,8 +210,6 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT })?; Ok(ProviderRequestType::BedrockConverseStream(bedrock_req)) } - - // ChatCompletions -> ResponsesAPI (not supported) ( ProviderRequestType::ChatCompletionsRequest(_), 
SupportedUpstreamAPIs::OpenAIResponsesAPI(_), @@ -233,14 +223,10 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT // ============================================================================ // MessagesRequest conversions // ============================================================================ - - // MessagesRequest -> MessagesRequest (pass-through) ( ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AnthropicMessagesAPI(_), ) => Ok(ProviderRequestType::MessagesRequest(messages_req)), - - // MessagesRequest -> ChatCompletions ( ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::OpenAIChatCompletions(_), @@ -256,8 +242,6 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT })?; Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) } - - // MessagesRequest -> Bedrock Converse ( ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AmazonBedrockConverse(_), @@ -272,8 +256,6 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT })?; Ok(ProviderRequestType::BedrockConverse(bedrock_req)) } - - // MessagesRequest -> Bedrock Converse Stream ( ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), @@ -289,8 +271,6 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT })?; Ok(ProviderRequestType::BedrockConverseStream(bedrock_req)) } - - // MessagesRequest -> ResponsesAPI (not supported) ( ProviderRequestType::MessagesRequest(_), SupportedUpstreamAPIs::OpenAIResponsesAPI(_), @@ -304,8 +284,6 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT // ============================================================================ // ResponsesAPIRequest conversions (only converts TO other formats) // ============================================================================ - - // ResponsesAPI 
-> ResponsesAPI (pass-through, OpenAI only) ( ProviderRequestType::ResponsesAPIRequest(responses_req), SupportedUpstreamAPIs::OpenAIResponsesAPI(_), @@ -461,7 +439,7 @@ mod tests { use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest; use crate::apis::openai::ChatCompletionsRequest; use crate::apis::openai::OpenAIApi::ChatCompletions; - use crate::clients::endpoints::SupportedAPIsFromClients; + use crate::clients::endpoints::SupportedAPIsFromClient; use crate::transforms::lib::ExtractText; use serde_json::json; @@ -475,7 +453,7 @@ mod tests { ] }); let bytes = serde_json::to_vec(&req).unwrap(); - let api = SupportedAPIsFromClients::OpenAIChatCompletions(ChatCompletions); + let api = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions); let result = ProviderRequestType::try_from((bytes.as_slice(), &api)); assert!(result.is_ok()); match result.unwrap() { @@ -498,7 +476,7 @@ mod tests { ] }); let bytes = serde_json::to_vec(&req).unwrap(); - let endpoint = SupportedAPIsFromClients::AnthropicMessagesAPI(Messages); + let endpoint = SupportedAPIsFromClient::AnthropicMessagesAPI(Messages); let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint)); assert!(result.is_ok()); match result.unwrap() { @@ -520,7 +498,7 @@ mod tests { ] }); let bytes = serde_json::to_vec(&req).unwrap(); - let endpoint = SupportedAPIsFromClients::OpenAIChatCompletions(ChatCompletions); + let endpoint = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions); let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint)); assert!(result.is_ok()); match result.unwrap() { @@ -543,7 +521,7 @@ mod tests { }); let bytes = serde_json::to_vec(&req).unwrap(); // Intentionally use OpenAI endpoint for Anthropic payload - let endpoint = SupportedAPIsFromClients::OpenAIChatCompletions(ChatCompletions); + let endpoint = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions); let result = ProviderRequestType::try_from((bytes.as_slice(), 
&endpoint)); // Should parse as ChatCompletionsRequest, not error assert!(result.is_ok()); @@ -673,7 +651,7 @@ mod tests { "input": "Hello, how are you?" }); let bytes = serde_json::to_vec(&req).unwrap(); - let api = SupportedAPIsFromClients::OpenAIResponsesAPI(Responses); + let api = SupportedAPIsFromClient::OpenAIResponsesAPI(Responses); let result = ProviderRequestType::try_from((bytes.as_slice(), &api)); assert!(result.is_ok()); match result.unwrap() { diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index 2e6f8214..b1d88e58 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -2,26 +2,14 @@ use serde::Serialize; use std::convert::TryFrom; use std::error::Error; use std::fmt; - use crate::apis::amazon_bedrock::ConverseResponse; -use crate::apis::amazon_bedrock::ConverseStreamEvent; use crate::apis::anthropic::MessagesResponse; -use crate::apis::anthropic::MessagesStreamEvent; use crate::apis::openai::ChatCompletionsResponse; -use crate::apis::openai::ChatCompletionsStreamResponse; use crate::apis::openai_responses::ResponsesAPIResponse; -use crate::apis::openai_responses::ResponseAPIStreamEvent; -use crate::apis::sse::SseEvent; -use crate::clients::endpoints::SupportedAPIsFromClients; +use crate::clients::endpoints::SupportedAPIsFromClient; use crate::clients::endpoints::SupportedUpstreamAPIs; use crate::providers::id::ProviderId; -/// Trait for token usage information -pub trait TokenUsage { - fn completion_tokens(&self) -> usize; - fn prompt_tokens(&self) -> usize; - fn total_tokens(&self) -> usize; -} #[derive(Serialize, Debug, Clone)] #[serde(untagged)] @@ -31,13 +19,11 @@ pub enum ProviderResponseType { ResponsesAPIResponse(ResponsesAPIResponse), } -#[derive(Serialize, Debug, Clone)] -#[serde(untagged)] -pub enum ProviderStreamResponseType { - ChatCompletionsStreamResponse(ChatCompletionsStreamResponse), - MessagesStreamEvent(MessagesStreamEvent), - 
ConverseStreamEvent(ConverseStreamEvent), - ResponseAPIStreamEvent(ResponseAPIStreamEvent) +/// Trait for token usage information +pub trait TokenUsage { + fn completion_tokens(&self) -> usize; + fn prompt_tokens(&self) -> usize; + fn total_tokens(&self) -> usize; } pub trait ProviderResponse: Send + Sync { @@ -72,94 +58,19 @@ impl ProviderResponse for ProviderResponseType { } } } -pub trait ProviderStreamResponse: Send + Sync { - /// Get the content delta for this chunk - fn content_delta(&self) -> Option<&str>; - - /// Check if this is the final chunk in the stream - fn is_final(&self) -> bool; - - /// Get role information if available - fn role(&self) -> Option<&str>; - - /// Get event type for SSE streaming (used by Anthropic) - fn event_type(&self) -> Option<&str>; -} - -impl ProviderStreamResponse for ProviderStreamResponseType { - fn content_delta(&self) -> Option<&str> { - match self { - ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.content_delta(), - ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.content_delta(), - ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.content_delta(), - ProviderStreamResponseType::ResponseAPIStreamEvent(_resp) => None, // ResponsesAPI does not have content deltas - } - } - - fn is_final(&self) -> bool { - match self { - ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.is_final(), - ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.is_final(), - ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.is_final(), - ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => resp.is_final(), - } - } - - fn role(&self) -> Option<&str> { - match self { - ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.role(), - ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.role(), - ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.role(), - ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => 
resp.role(), - } - } - - fn event_type(&self) -> Option<&str> { - match self { - ProviderStreamResponseType::ChatCompletionsStreamResponse(_resp) => None, // OpenAI doesn't use event types - ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.event_type(), - ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.event_type(), // Bedrock doesn't use event types - ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => resp.event_type(), - } - } -} - -impl Into for ProviderStreamResponseType { - fn into(self) -> String { - match self { - ProviderStreamResponseType::MessagesStreamEvent(event) => { - // Use the Into implementation for proper SSE formatting with event lines - event.into() - } - ProviderStreamResponseType::ConverseStreamEvent(event) => { - // Use the Into implementation for proper SSE formatting with event lines - event.into() - } - ProviderStreamResponseType::ResponseAPIStreamEvent(event) => { - // Use the Into implementation for proper SSE formatting with event lines - event.into() - } - ProviderStreamResponseType::ChatCompletionsStreamResponse(_) => { - // For OpenAI, use simple data line format - let json = serde_json::to_string(&self).unwrap_or_default(); - format!("data: {}\n\n", json) - } - } - } -} // --- Response transformation logic for client API compatibility --- -impl TryFrom<(&[u8], &SupportedAPIsFromClients, &ProviderId)> for ProviderResponseType { +impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderResponseType { type Error = std::io::Error; fn try_from( - (bytes, client_api, provider_id): (&[u8], &SupportedAPIsFromClients, &ProviderId), + (bytes, client_api, provider_id): (&[u8], &SupportedAPIsFromClient, &ProviderId), ) -> Result { let upstream_api = provider_id.compatible_api_for_client(client_api, false); match (&upstream_api, client_api) { ( SupportedUpstreamAPIs::OpenAIChatCompletions(_), - SupportedAPIsFromClients::OpenAIChatCompletions(_), + 
SupportedAPIsFromClient::OpenAIChatCompletions(_), ) => { let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -167,7 +78,7 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClients, &ProviderId)> for ProviderRespon } ( SupportedUpstreamAPIs::AnthropicMessagesAPI(_), - SupportedAPIsFromClients::AnthropicMessagesAPI(_), + SupportedAPIsFromClient::AnthropicMessagesAPI(_), ) => { let resp: MessagesResponse = serde_json::from_slice(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -175,7 +86,7 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClients, &ProviderId)> for ProviderRespon } ( SupportedUpstreamAPIs::AnthropicMessagesAPI(_), - SupportedAPIsFromClients::OpenAIChatCompletions(_), + SupportedAPIsFromClient::OpenAIChatCompletions(_), ) => { let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -192,7 +103,7 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClients, &ProviderId)> for ProviderRespon } ( SupportedUpstreamAPIs::OpenAIChatCompletions(_), - SupportedAPIsFromClients::AnthropicMessagesAPI(_), + SupportedAPIsFromClient::AnthropicMessagesAPI(_), ) => { let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -209,7 +120,7 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClients, &ProviderId)> for ProviderRespon // Amazon Bedrock transformations ( SupportedUpstreamAPIs::AmazonBedrockConverse(_), - SupportedAPIsFromClients::OpenAIChatCompletions(_), + SupportedAPIsFromClient::OpenAIChatCompletions(_), ) => { let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -225,7 +136,7 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClients, &ProviderId)> for ProviderRespon } ( 
SupportedUpstreamAPIs::AmazonBedrockConverse(_), - SupportedAPIsFromClients::AnthropicMessagesAPI(_), + SupportedAPIsFromClient::AnthropicMessagesAPI(_), ) => { let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -241,7 +152,7 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClients, &ProviderId)> for ProviderRespon } ( SupportedUpstreamAPIs::OpenAIResponsesAPI(_), - SupportedAPIsFromClients::OpenAIResponsesAPI(_), + SupportedAPIsFromClient::OpenAIResponsesAPI(_), ) => { let resp: ResponsesAPIResponse = ResponsesAPIResponse::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -249,7 +160,7 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClients, &ProviderId)> for ProviderRespon } ( SupportedUpstreamAPIs::OpenAIChatCompletions(_), - SupportedAPIsFromClients::OpenAIResponsesAPI(_), + SupportedAPIsFromClient::OpenAIResponsesAPI(_), ) => { let chat_completions_response: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; @@ -263,6 +174,31 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClients, &ProviderId)> for ProviderRespon })?; Ok(ProviderResponseType::ResponsesAPIResponse(responses_resp)) } + ( + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), + SupportedAPIsFromClient::OpenAIResponsesAPI(_), + ) => { + + //Chain transform: Anthropic Messages -> OpenAI ChatCompletions -> ResponsesAPI + let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + // Transform to ChatCompletions format using the transformer + let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; + + let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| { + 
std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; + Ok(ProviderResponseType::ResponsesAPIResponse(response_api)) + } _ => Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "Unsupported API combination for response transformation", @@ -271,247 +207,6 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClients, &ProviderId)> for ProviderRespon } } -// Stream response transformation logic for client API compatibility -impl TryFrom<(&[u8], &SupportedAPIsFromClients, &SupportedUpstreamAPIs)> for ProviderStreamResponseType { - type Error = Box; - - fn try_from( - (bytes, client_api, upstream_api): (&[u8], &SupportedAPIsFromClients, &SupportedUpstreamAPIs), - ) -> Result { - // Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion - if bytes == b"[DONE]" && matches!(client_api, SupportedAPIsFromClients::AnthropicMessagesAPI(_)) { - return Ok(ProviderStreamResponseType::MessagesStreamEvent( - crate::apis::anthropic::MessagesStreamEvent::MessageStop, - )); - } - match (upstream_api, client_api) { - // OpenAI upstream - ( - SupportedUpstreamAPIs::OpenAIChatCompletions(_), - SupportedAPIsFromClients::OpenAIChatCompletions(_), - ) => { - let resp = serde_json::from_slice(bytes)?; - Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse( - resp, - )) - } - ( - SupportedUpstreamAPIs::OpenAIChatCompletions(_), - SupportedAPIsFromClients::AnthropicMessagesAPI(_), - ) => { - let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = - serde_json::from_slice(bytes)?; - let anthropic_resp = openai_resp.try_into()?; - Ok(ProviderStreamResponseType::MessagesStreamEvent( - anthropic_resp, - )) - } - - // Anthropic upstream - ( - SupportedUpstreamAPIs::AnthropicMessagesAPI(_), - SupportedAPIsFromClients::AnthropicMessagesAPI(_), - ) => { - let resp = serde_json::from_slice(bytes)?; - Ok(ProviderStreamResponseType::MessagesStreamEvent(resp)) - } - ( - SupportedUpstreamAPIs::AnthropicMessagesAPI(_), - 
SupportedAPIsFromClients::OpenAIChatCompletions(_), - ) => { - let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = - serde_json::from_slice(bytes)?; - let openai_resp = anthropic_resp.try_into()?; - Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse( - openai_resp, - )) - } - - // Amazon Bedrock ConverseStream upstream - ( - SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), - SupportedAPIsFromClients::AnthropicMessagesAPI(_), - ) => { - let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent = - serde_json::from_slice(bytes)?; - let anthropic_resp = bedrock_resp.try_into()?; - Ok(ProviderStreamResponseType::MessagesStreamEvent( - anthropic_resp, - )) - } - _ => Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "Unsupported API combination for response transformation", - ) - .into()), - } - } -} - -// TryFrom implementation to convert raw bytes to SseEvent with parsed provider response -impl TryFrom<(SseEvent, &SupportedAPIsFromClients, &SupportedUpstreamAPIs)> for SseEvent { - type Error = Box; - - fn try_from( - (sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIsFromClients, &SupportedUpstreamAPIs), - ) -> Result { - // Create a new transformed event based on the original - let mut transformed_event = sse_event; - - // Handle [DONE] marker early - don't try to parse as JSON - if transformed_event.is_done() { - // For OpenAI client API, keep [DONE] as-is - // For Anthropic client API, it will be transformed via ProviderStreamResponseType - if matches!(client_api, SupportedAPIsFromClients::OpenAIChatCompletions(_)) { - // Keep the [DONE] marker as-is for OpenAI clients - transformed_event.sse_transform_buffer = "data: [DONE]".to_string(); - return Ok(transformed_event); - } - } - - // If has data, parse the data as a provider stream response (business logic layer) - if transformed_event.data.is_some() { - let data_str = transformed_event.data.as_ref().unwrap(); - let data_bytes = data_str.as_bytes(); - 
let transformed_response: ProviderStreamResponseType = - ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?; - - // Convert to SSE string explicitly to avoid type ambiguity - let sse_string: String = transformed_response.clone().into(); - transformed_event.sse_transform_buffer = sse_string; - transformed_event.provider_stream_response = Some(transformed_response); - } - - match (client_api, upstream_api) { - ( - SupportedAPIsFromClients::AnthropicMessagesAPI(_), - SupportedUpstreamAPIs::OpenAIChatCompletions(_), - ) => { - if let Some(provider_response) = &transformed_event.provider_stream_response { - if let Some(event_type) = provider_response.event_type() { - // This ensures the required Anthropic sequence: MessageStart → ContentBlockStart → ContentBlockDelta(s) - if event_type == "message_start" { - // Create ContentBlockStart event and format it using Into - let content_block_start = MessagesStreamEvent::ContentBlockStart { - index: 0, - content_block: crate::apis::anthropic::MessagesContentBlock::Text { - text: String::new(), - cache_control: None, - }, - }; - let content_block_start_sse: String = content_block_start.into(); - - // Format as proper SSE: MessageStart first, then ContentBlockStart - // The sse_transform_buffer already contains the properly formatted MessageStart - transformed_event.sse_transform_buffer = format!( - "{}{}", - transformed_event.sse_transform_buffer, content_block_start_sse, - ); - } else if event_type == "message_delta" { - // Create ContentBlockStop event and format it using Into - let content_block_stop = - MessagesStreamEvent::ContentBlockStop { index: 0 }; - let content_block_stop_sse: String = content_block_stop.into(); - - // Format as proper SSE: ContentBlockStop first, then MessageDelta - transformed_event.sse_transform_buffer = format!( - "{}{}", - content_block_stop_sse, transformed_event.sse_transform_buffer - ); - } - // For other event types, the sse_transform_buffer already has the correct 
format from Into - } - // If event_type is None, we just keep the data line as-is without an event line - // This handles cases where the transformation might not produce a valid event type - } - } - ( - SupportedAPIsFromClients::OpenAIChatCompletions(_), - SupportedUpstreamAPIs::AnthropicMessagesAPI(_), - ) => { - if transformed_event.is_event_only() && transformed_event.event.is_some() { - transformed_event.sse_transform_buffer = format!("\n"); // suppress the event upstream for OpenAI - } - } - ( - SupportedAPIsFromClients::AnthropicMessagesAPI(_), - SupportedUpstreamAPIs::AnthropicMessagesAPI(_), - ) => { - // When both client and upstream are Anthropic, suppress event-only lines - // because the data line transformation already includes the event line - if transformed_event.is_event_only() && transformed_event.event.is_some() { - transformed_event.sse_transform_buffer = String::new(); // suppress duplicate event line - } - } - _ => { - // Other combinations can be handled here as needed - } - } - - Ok(transformed_event) - } -} - -// TryFrom implementation to convert AWS Event Stream DecodedFrame to ProviderStreamResponseType -impl - TryFrom<( - &aws_smithy_eventstream::frame::DecodedFrame, - &SupportedAPIsFromClients, - &SupportedUpstreamAPIs, - )> for ProviderStreamResponseType -{ - type Error = Box; - - fn try_from( - (frame, client_api, upstream_api): ( - &aws_smithy_eventstream::frame::DecodedFrame, - &SupportedAPIsFromClients, - &SupportedUpstreamAPIs, - ), - ) -> Result { - use aws_smithy_eventstream::frame::DecodedFrame; - - match frame { - DecodedFrame::Complete(_) => { - // We have a complete frame - parse it based on upstream API - match (upstream_api, client_api) { - ( - SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), - SupportedAPIsFromClients::AnthropicMessagesAPI(_), - ) => { - // Parse the DecodedFrame into ConverseStreamEvent - let bedrock_event = - crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?; - let 
anthropic_event: crate::apis::anthropic::MessagesStreamEvent = - bedrock_event.try_into()?; - - Ok(ProviderStreamResponseType::MessagesStreamEvent( - anthropic_event, - )) - } - ( - SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), - SupportedAPIsFromClients::OpenAIChatCompletions(_), - ) => { - // Parse the DecodedFrame into ConverseStreamEvent - let bedrock_event = - crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?; - let openai_event: crate::apis::openai::ChatCompletionsStreamResponse = - bedrock_event.try_into()?; - Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse( - openai_event, - )) - } - _ => Err("Unsupported API combination for event-stream decoding".into()), - } - } - DecodedFrame::Incomplete => { - Err("Cannot convert incomplete frame to provider response".into()) - } - } - } -} - #[derive(Debug)] pub struct ProviderResponseError { pub message: String, @@ -535,11 +230,9 @@ impl Error for ProviderResponseError { #[cfg(test)] mod tests { use super::*; - use crate::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; - use crate::apis::anthropic::AnthropicApi; use crate::apis::openai::OpenAIApi; - use crate::apis::sse::SseStreamIter; - use crate::clients::endpoints::SupportedAPIsFromClients; + use crate::apis::anthropic::AnthropicApi; + use crate::clients::endpoints::SupportedAPIsFromClient; use crate::providers::id::ProviderId; use serde_json::json; @@ -563,7 +256,7 @@ mod tests { let bytes = serde_json::to_vec(&resp).unwrap(); let result = ProviderResponseType::try_from(( bytes.as_slice(), - &SupportedAPIsFromClients::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + &SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::OpenAI, )); assert!(result.is_ok()); @@ -592,7 +285,7 @@ mod tests { let bytes = serde_json::to_vec(&resp).unwrap(); let result = ProviderResponseType::try_from(( bytes.as_slice(), - &SupportedAPIsFromClients::AnthropicMessagesAPI(AnthropicApi::Messages), + 
&SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::Anthropic, )); assert!(result.is_ok()); @@ -626,7 +319,7 @@ mod tests { let bytes = serde_json::to_vec(&resp).unwrap(); let result = ProviderResponseType::try_from(( bytes.as_slice(), - &SupportedAPIsFromClients::AnthropicMessagesAPI(AnthropicApi::Messages), + &SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI, )); assert!(result.is_ok()); @@ -668,7 +361,7 @@ mod tests { let bytes = serde_json::to_vec(&resp).unwrap(); let result = ProviderResponseType::try_from(( bytes.as_slice(), - &SupportedAPIsFromClients::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + &SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::Anthropic, )); assert!(result.is_ok()); @@ -681,951 +374,4 @@ mod tests { _ => panic!("Expected ChatCompletionsResponse variant"), } } - - #[test] - fn test_sse_event_parsing() { - // Test valid SSE data line - let line = "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n"; - let event: Result = line.parse(); - assert!(event.is_ok()); - let event = event.unwrap(); - assert_eq!( - event.data, - Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string()) - ); - - // Test conversion back to line using Display trait - let wire_format = event.to_string(); - assert_eq!( - wire_format, - "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n" - ); - - // Test [DONE] marker - should be valid SSE event - let done_line = "data: [DONE]"; - let done_result: Result = done_line.parse(); - assert!(done_result.is_ok()); - let done_event = done_result.unwrap(); - assert_eq!(done_event.data, Some("[DONE]".to_string())); - assert!(done_event.is_done()); // Test the helper method - - // Test non-DONE event - assert!(!event.is_done()); - - // Test empty data - should return error - let empty_line = "data: "; - let empty_result: Result = empty_line.parse(); - 
assert!(empty_result.is_err()); - - // Test non-data line - should return error - let comment_line = ": this is a comment"; - let comment_result: Result = comment_line.parse(); - assert!(comment_result.is_err()); - } - - #[test] - fn test_sse_event_serde() { - // Test serialization and deserialization with serde - let event = SseEvent { - data: Some(r#"{"id":"test","object":"chat.completion.chunk"}"#.to_string()), - event: None, - raw_line: r#"data: {"id":"test","object":"chat.completion.chunk"} - - "# - .to_string(), - sse_transform_buffer: r#"data: {"id":"test","object":"chat.completion.chunk"} - - "# - .to_string(), - provider_stream_response: None, - }; - - // Test JSON serialization - raw_line should be skipped - let json = serde_json::to_string(&event).unwrap(); - assert!(json.contains("test")); - assert!(json.contains("chat.completion.chunk")); - assert!(!json.contains("raw_line")); // Should be excluded from serialization - - // Test JSON deserialization - let deserialized: SseEvent = serde_json::from_str(&json).unwrap(); - assert_eq!(deserialized.data, event.data); - assert_eq!(deserialized.raw_line, ""); // Should be empty since it's skipped - - // Test round trip for data field only - assert_eq!(event.data, deserialized.data); - } - - #[test] - fn test_sse_event_should_skip() { - // Test ping message should be skipped - let ping_event = SseEvent { - data: Some(r#"{"type": "ping"}"#.to_string()), - event: None, - raw_line: r#"data: {"type": "ping"}"#.to_string(), - sse_transform_buffer: r#"data: {"type": "ping"}"#.to_string(), - provider_stream_response: None, - }; - assert!(ping_event.should_skip()); - assert!(!ping_event.is_done()); - - // Test normal event should not be skipped - let normal_event = SseEvent { - data: Some(r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string()), - event: Some("content_block_delta".to_string()), - raw_line: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(), - sse_transform_buffer: 
r#"data: {"id": "test", "object": "chat.completion.chunk"}"# - .to_string(), - provider_stream_response: None, - }; - assert!(!normal_event.should_skip()); - assert!(!normal_event.is_done()); - - // Test [DONE] event should not be skipped (but is handled separately) - let done_event = SseEvent { - data: Some("[DONE]".to_string()), - event: None, - raw_line: "data: [DONE]".to_string(), - sse_transform_buffer: "data: [DONE]".to_string(), - provider_stream_response: None, - }; - assert!(!done_event.should_skip()); - assert!(done_event.is_done()); - } - - #[test] - fn test_sse_stream_iter_filters_ping_messages() { - // Create test data with ping messages mixed in - let test_lines = vec![ - "data: {\"id\": \"msg1\", \"object\": \"chat.completion.chunk\"}".to_string(), - "data: {\"type\": \"ping\"}".to_string(), // This should be filtered out - "data: {\"id\": \"msg2\", \"object\": \"chat.completion.chunk\"}".to_string(), - "data: {\"type\": \"ping\"}".to_string(), // This should be filtered out - "data: [DONE]".to_string(), // This should end the stream - ]; - - let mut iter = SseStreamIter::new(test_lines.into_iter()); - - // First event should be msg1 (ping filtered out) - let event1 = iter.next().unwrap(); - assert!(event1.data.as_ref().unwrap().contains("msg1")); - assert!(!event1.should_skip()); - - // Second event should be msg2 (ping filtered out) - let event2 = iter.next().unwrap(); - assert!(event2.data.as_ref().unwrap().contains("msg2")); - assert!(!event2.should_skip()); - - // Third event should be [DONE] - let done_event = iter.next().unwrap(); - assert!(done_event.is_done()); - - // Iterator should end after [DONE] - assert!(iter.next().is_none()); - } - - #[test] - fn test_sse_stream_iter_handles_anthropic_events() { - // Create test data with Anthropic-style event/data pairs - let test_lines = vec![ - "event: message_start".to_string(), - "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\"}}".to_string(), - "event: 
content_block_delta".to_string(), - "data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\"Hello\"}}".to_string(), - "data: [DONE]".to_string(), - ]; - - let mut iter = SseStreamIter::new(test_lines.into_iter()); - - // First event should be the event: line - let event1 = iter.next().unwrap(); - assert!(event1.is_event_only()); - assert_eq!(event1.event, Some("message_start".to_string())); - assert_eq!(event1.data, None); - - // Second event should be the data: line - let event2 = iter.next().unwrap(); - assert!(!event2.is_event_only()); - assert_eq!(event2.event, None); - assert!(event2.data.as_ref().unwrap().contains("message_start")); - - // Third event should be another event: line - let event3 = iter.next().unwrap(); - assert!(event3.is_event_only()); - assert_eq!(event3.event, Some("content_block_delta".to_string())); - - // Fourth event should be the content delta data - let event4 = iter.next().unwrap(); - assert!(!event4.is_event_only()); - assert!(event4.data.as_ref().unwrap().contains("Hello")); - - // Fifth event should be [DONE] - let done_event = iter.next().unwrap(); - assert!(done_event.is_done()); - - // Iterator should end after [DONE] - assert!(iter.next().is_none()); - } - - #[test] - fn test_provider_stream_response_event_type() { - use crate::apis::anthropic::{MessagesContentDelta, MessagesStreamEvent}; - use crate::apis::openai::ChatCompletionsStreamResponse; - - // Test Anthropic event type - let anthropic_event = MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: MessagesContentDelta::TextDelta { - text: "Hello".to_string(), - }, - }; - let provider_type = ProviderStreamResponseType::MessagesStreamEvent(anthropic_event); - assert_eq!(provider_type.event_type(), Some("content_block_delta")); - - // Test OpenAI event type (should be None) - let openai_event = ChatCompletionsStreamResponse { - id: "test".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: 123456789, - model: "gpt-4".to_string(), - 
choices: vec![], - usage: None, - system_fingerprint: None, - service_tier: None, - }; - let provider_type = ProviderStreamResponseType::ChatCompletionsStreamResponse(openai_event); - assert_eq!(provider_type.event_type(), None); - } - - #[test] - fn test_done_marker_handled_in_stream_response_transformation() { - use crate::apis::anthropic::AnthropicApi; - - // Test that [DONE] marker is properly converted to MessageStop in the transformation layer - let done_bytes = b"[DONE]"; - let client_api = SupportedAPIsFromClients::AnthropicMessagesAPI(AnthropicApi::Messages); - let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions( - crate::apis::openai::OpenAIApi::ChatCompletions, - ); - - let result = ProviderStreamResponseType::try_from(( - done_bytes.as_slice(), - &client_api, - &upstream_api, - )); - assert!(result.is_ok()); - - if let Ok(ProviderStreamResponseType::MessagesStreamEvent(event)) = result { - // Verify it's a MessageStop event - assert_eq!(event.event_type(), Some("message_stop")); - assert!(matches!( - event, - crate::apis::anthropic::MessagesStreamEvent::MessageStop - )); - } else { - panic!("Expected MessagesStreamEvent::MessageStop"); - } - } - - #[test] - fn test_bedrock_event_stream_decoder_basic() { - use bytes::BytesMut; - - // Create a simple test with minimal data - let mut buffer = BytesMut::new(); - - // Add some arbitrary bytes (not a real event-stream frame, just for testing the decoder) - buffer.extend_from_slice(b"test data"); - - let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); - - // The decoder should return Incomplete for incomplete/invalid data - // This signals the caller to wait for more data - let result = decoder.decode_frame(); - assert!(result.is_some()); - assert!(matches!( - result.unwrap(), - aws_smithy_eventstream::frame::DecodedFrame::Incomplete - )); - - // Verify we can still access the buffer - assert!(decoder.has_remaining()); - } - - #[test] - fn 
test_bedrock_event_stream_decoder_with_real_frames() { - use bytes::BytesMut; - use std::fs; - use std::path::PathBuf; - - // Read the actual response.hex file from tests/e2e directory - let test_file = - PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex"); - - // Only run this test if the file exists - if !test_file.exists() { - println!("Skipping test - response.hex not found"); - return; - } - - let response_data = fs::read(&test_file).unwrap(); - let mut buffer = BytesMut::from(&response_data[..]); - - let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); - let mut frame_count = 0; - - // Decode all frames - loop { - match decoder.decode_frame() { - Some(aws_smithy_eventstream::frame::DecodedFrame::Complete(message)) => { - frame_count += 1; - - // Verify we can access headers - let event_type = message - .headers() - .iter() - .find(|h| h.name().as_str() == ":event-type") - .and_then(|h| h.value().as_string().ok()); - - assert!(event_type.is_some(), "Frame should have :event-type header"); - } - Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { - // End of buffer, no more complete frames available - break; - } - None => { - // Decode error - panic!("Decode error encountered"); - } - } - } - - // We should have decoded multiple frames - assert!(frame_count > 0, "Should have decoded at least one frame"); - } - - #[test] - fn test_bedrock_event_stream_decoder_chunked_data() { - use bytes::BytesMut; - use std::fs; - use std::path::PathBuf; - - // Read the actual response.hex file - let test_file = - PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex"); - - if !test_file.exists() { - println!("Skipping test - response.hex not found"); - return; - } - - let response_data = fs::read(&test_file).unwrap(); - - // Simulate chunked network arrivals with realistic chunk sizes - // Using varying chunk sizes to test partial frame handling - let mut buffer = BytesMut::new(); - let 
chunk_size_pattern = vec![500, 1000, 750, 1200, 800, 1500]; - let mut offset = 0; - let mut total_frames = 0; - let mut chunk_num = 0; - - // CRITICAL: Create ONE decoder and reuse it across chunks - // The MessageFrameDecoder maintains state about partial frames - let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); - - // Process all data in chunks - while offset < response_data.len() { - let chunk_size = chunk_size_pattern[chunk_num % chunk_size_pattern.len()]; - chunk_num += 1; - - let end = (offset + chunk_size).min(response_data.len()); - let chunk = &response_data[offset..end]; - - // Add new data to the buffer (accessing via buffer_mut()) - decoder.buffer_mut().extend_from_slice(chunk); - offset = end; - - // Process all available complete frames from this chunk - loop { - match decoder.decode_frame() { - Some(aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { - total_frames += 1; - } - Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { - // Need more data - wait for next chunk - break; - } - None => { - // Decode error - panic!("Decode error in chunked test"); - } - } - } - } - - assert!( - total_frames > 0, - "Should have decoded frames from chunked data" - ); - } - - #[test] - fn test_bedrock_decoded_frame_to_provider_response() { - test_bedrock_conversion(false); - } - - #[test] - #[ignore] // Run with: cargo test -- --ignored --nocapture - fn test_bedrock_decoded_frame_to_provider_response_verbose() { - test_bedrock_conversion(true); - } - - #[test] - fn test_bedrock_decoded_frame_with_tool_use() { - test_bedrock_conversion_with_tools(false); - } - - #[test] - #[ignore] // Run with: cargo test -- --ignored --nocapture - fn test_bedrock_decoded_frame_with_tool_use_verbose() { - test_bedrock_conversion_with_tools(true); - } - - fn test_bedrock_conversion(verbose: bool) { - use bytes::BytesMut; - use std::fs; - use std::path::PathBuf; - - // Read the actual response.hex file from tests/e2e directory - let test_file = 
- PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex"); - - // Only run this test if the file exists - if !test_file.exists() { - println!("Skipping test - response.hex not found"); - return; - } - - let response_data = fs::read(&test_file).unwrap(); - let mut buffer = BytesMut::from(&response_data[..]); - - let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); - - let client_api = - SupportedAPIsFromClients::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); - let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream( - crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream, - ); - - let mut conversion_count = 0; - let mut message_start_seen = false; - - // Decode and convert frames - loop { - match decoder.decode_frame() { - Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { - // Convert DecodedFrame to ProviderStreamResponseType - let result = - ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api)); - - match result { - Ok(provider_response) => { - conversion_count += 1; - - // Verify we got a MessagesStreamEvent - assert!(matches!( - provider_response, - ProviderStreamResponseType::MessagesStreamEvent(_) - )); - - if verbose { - // Print the SSE string output - let sse_string: String = provider_response.clone().into(); - println!("{}", sse_string); - } - - // Check for MessageStart event - if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = - provider_response - { - if matches!( - event, - crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. 
} - ) { - message_start_seen = true; - } - } - } - Err(e) => { - println!("Conversion error (frame {}): {}", conversion_count, e); - } - } - } - Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { - // End of buffer - break; - } - None => { - panic!("Decode error"); - } - } - } - - assert!( - conversion_count > 0, - "Should have converted at least one frame" - ); - assert!(message_start_seen, "Should have seen MessageStart event"); - } - - fn test_bedrock_conversion_with_tools(verbose: bool) { - use bytes::BytesMut; - use std::fs; - use std::path::PathBuf; - - // Read the actual response_with_tools.hex file from tests/e2e directory - let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .join("../../tests/e2e/response_with_tools.hex"); - - // Only run this test if the file exists - if !test_file.exists() { - println!("Skipping test - response_with_tools.hex not found"); - return; - } - - let response_data = fs::read(&test_file).unwrap(); - let mut buffer = BytesMut::from(&response_data[..]); - - let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); - - let client_api = - SupportedAPIsFromClients::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); - let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream( - crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream, - ); - - let mut conversion_count = 0; - let mut message_start_seen = false; - let mut content_block_start_seen = false; - let mut content_block_delta_tool_use_seen = false; - - // Decode and convert frames - loop { - match decoder.decode_frame() { - Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { - // Convert DecodedFrame to ProviderStreamResponseType - let result = - ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api)); - - match result { - Ok(provider_response) => { - conversion_count += 1; - - // Verify we got a MessagesStreamEvent - assert!(matches!( - provider_response, - 
ProviderStreamResponseType::MessagesStreamEvent(_) - )); - - if verbose { - // Print the SSE string output - let sse_string: String = provider_response.clone().into(); - println!("{}", sse_string); - } - - // Check for specific events related to tool use - if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = - provider_response - { - match event { - crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. } => { - message_start_seen = true; - } - crate::apis::anthropic::MessagesStreamEvent::ContentBlockStart { .. } => { - content_block_start_seen = true; - } - crate::apis::anthropic::MessagesStreamEvent::ContentBlockDelta { delta, .. } => { - if matches!(delta, crate::apis::anthropic::MessagesContentDelta::InputJsonDelta { .. }) { - content_block_delta_tool_use_seen = true; - } - } - _ => {} - } - } - } - Err(e) => { - println!("Conversion error (frame {}): {}", conversion_count, e); - } - } - } - Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { - // End of buffer - break; - } - None => { - panic!("Decode error"); - } - } - } - - assert!( - conversion_count > 0, - "Should have converted at least one frame" - ); - assert!(message_start_seen, "Should have seen MessageStart event"); - assert!( - content_block_start_seen, - "Should have seen ContentBlockStart event for tool use" - ); - assert!( - content_block_delta_tool_use_seen, - "Should have seen ContentBlockDelta with ToolUseDelta" - ); - } - - #[test] - fn test_sse_event_transformation_openai_to_anthropic_message_start() { - use crate::apis::anthropic::AnthropicApi; - use crate::apis::openai::OpenAIApi; - - // Create an OpenAI stream response that represents a role start (which becomes message_start in Anthropic) - let openai_stream_chunk = json!({ - "id": "chatcmpl-123", - "object": "chat.completion.chunk", - "created": 1234567890, - "model": "gpt-4", - "choices": [{ - "index": 0, - "delta": {"role": "assistant"}, - "finish_reason": null - }] - }); - - // Create SSE event 
with this data - let sse_event = SseEvent { - data: Some(openai_stream_chunk.to_string()), - event: None, - raw_line: format!("data: {}", openai_stream_chunk.to_string()), - sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()), - provider_stream_response: None, - }; - - let client_api = SupportedAPIsFromClients::AnthropicMessagesAPI(AnthropicApi::Messages); - let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); - - // Transform the event - let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); - assert!(result.is_ok()); - - let transformed = result.unwrap(); - - // Verify the transformation includes both message_start and content_block_start - let buffer = transformed.sse_transform_buffer; - assert!( - buffer.contains("event: message_start"), - "Should contain message_start event" - ); - assert!( - buffer.contains("event: content_block_start"), - "Should contain content_block_start event" - ); - - // Verify proper SSE format with event lines before data lines - assert!(buffer.find("event: message_start").unwrap() < buffer.find("data:").unwrap()); - assert!(buffer.find("content_block_start").is_some()); - } - - #[test] - fn test_sse_event_transformation_openai_to_anthropic_message_delta() { - use crate::apis::anthropic::AnthropicApi; - use crate::apis::openai::OpenAIApi; - - // Create an OpenAI stream response with finish_reason (which becomes message_delta in Anthropic) - let openai_stream_chunk = json!({ - "id": "chatcmpl-123", - "object": "chat.completion.chunk", - "created": 1234567890, - "model": "gpt-4", - "choices": [{ - "index": 0, - "delta": {}, - "finish_reason": "stop" - }], - "usage": { - "prompt_tokens": 10, - "completion_tokens": 25, - "total_tokens": 35 - } - }); - - // Create SSE event with this data - let sse_event = SseEvent { - data: Some(openai_stream_chunk.to_string()), - event: None, - raw_line: format!("data: {}", openai_stream_chunk.to_string()), - 
sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()), - provider_stream_response: None, - }; - - let client_api = SupportedAPIsFromClients::AnthropicMessagesAPI(AnthropicApi::Messages); - let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); - - // Transform the event - let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); - assert!(result.is_ok()); - - let transformed = result.unwrap(); - - // Verify the transformation includes both content_block_stop and message_delta - let buffer = transformed.sse_transform_buffer; - assert!( - buffer.contains("event: content_block_stop"), - "Should contain content_block_stop event" - ); - assert!( - buffer.contains("event: message_delta"), - "Should contain message_delta event" - ); - - // Verify content_block_stop comes before message_delta - let stop_pos = buffer.find("content_block_stop").unwrap(); - let delta_pos = buffer.find("message_delta").unwrap(); - assert!( - stop_pos < delta_pos, - "content_block_stop should come before message_delta" - ); - } - - #[test] - fn test_sse_event_transformation_openai_to_anthropic_content_delta() { - use crate::apis::anthropic::AnthropicApi; - use crate::apis::openai::OpenAIApi; - - // Create an OpenAI stream response with content (which becomes content_block_delta in Anthropic) - let openai_stream_chunk = json!({ - "id": "chatcmpl-123", - "object": "chat.completion.chunk", - "created": 1234567890, - "model": "gpt-4", - "choices": [{ - "index": 0, - "delta": {"content": "Hello"}, - "finish_reason": null - }] - }); - - // Create SSE event with this data - let sse_event = SseEvent { - data: Some(openai_stream_chunk.to_string()), - event: None, - raw_line: format!("data: {}", openai_stream_chunk.to_string()), - sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()), - provider_stream_response: None, - }; - - let client_api = 
SupportedAPIsFromClients::AnthropicMessagesAPI(AnthropicApi::Messages); - let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); - - // Transform the event - let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); - assert!(result.is_ok()); - - let transformed = result.unwrap(); - - // Verify the transformation is a content_block_delta (no extra events injected) - let buffer = transformed.sse_transform_buffer; - assert!( - buffer.contains("event: content_block_delta"), - "Should contain content_block_delta event" - ); - assert!( - !buffer.contains("content_block_start"), - "Should not inject content_block_start for content delta" - ); - assert!( - !buffer.contains("content_block_stop"), - "Should not inject content_block_stop for content delta" - ); - - // Verify the content is preserved - assert!(buffer.contains("Hello"), "Should preserve the content text"); - } - - #[test] - fn test_sse_event_transformation_anthropic_to_openai_suppresses_event_lines() { - use crate::apis::anthropic::AnthropicApi; - use crate::apis::openai::OpenAIApi; - - // Create an Anthropic event-only SSE line (no data) - let sse_event = SseEvent { - data: None, - event: Some("message_start".to_string()), - raw_line: "event: message_start".to_string(), - sse_transform_buffer: "event: message_start".to_string(), - provider_stream_response: None, - }; - - let client_api = SupportedAPIsFromClients::OpenAIChatCompletions(OpenAIApi::ChatCompletions); - let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); - - // Transform the event - let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); - assert!(result.is_ok()); - - let transformed = result.unwrap(); - - // Verify the event line is suppressed (replaced with just newline) - assert_eq!( - transformed.sse_transform_buffer, "\n", - "Event-only lines should be suppressed to newline for OpenAI" - ); - assert!( - transformed.is_event_only(), - 
"Should still be marked as event-only" - ); - } - - #[test] - fn test_sse_event_transformation_anthropic_to_openai_preserves_data() { - use crate::apis::anthropic::AnthropicApi; - use crate::apis::openai::OpenAIApi; - - // Create an Anthropic message_start event with data - let anthropic_event = json!({ - "type": "message_start", - "message": { - "id": "msg_123", - "type": "message", - "role": "assistant", - "content": [], - "model": "claude-3-sonnet", - "stop_reason": null, - "usage": {"input_tokens": 10, "output_tokens": 0} - } - }); - - let sse_event = SseEvent { - data: Some(anthropic_event.to_string()), - event: None, - raw_line: format!("data: {}", anthropic_event.to_string()), - sse_transform_buffer: format!("data: {}", anthropic_event.to_string()), - provider_stream_response: None, - }; - - let client_api = SupportedAPIsFromClients::OpenAIChatCompletions(OpenAIApi::ChatCompletions); - let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); - - // Transform the event - let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); - assert!(result.is_ok()); - - let transformed = result.unwrap(); - - // Verify data is transformed to OpenAI format - let buffer = transformed.sse_transform_buffer; - assert!(buffer.starts_with("data: "), "Should have data: prefix"); - assert!( - !buffer.contains("event:"), - "Should not have event: lines for OpenAI" - ); - - // Verify provider response was parsed - assert!(transformed.provider_stream_response.is_some()); - } - - #[test] - fn test_sse_event_transformation_no_change_for_matching_apis() { - use crate::apis::openai::OpenAIApi; - - // Create an OpenAI stream response - let openai_stream_chunk = json!({ - "id": "chatcmpl-123", - "object": "chat.completion.chunk", - "created": 1234567890, - "model": "gpt-4", - "choices": [{ - "index": 0, - "delta": {"content": "Hello"}, - "finish_reason": null - }] - }); - - let original_data = openai_stream_chunk.to_string(); - let sse_event = 
SseEvent { - data: Some(original_data.clone()), - event: None, - raw_line: format!("data: {}", original_data), - sse_transform_buffer: format!("data: {}\n\n", original_data), - provider_stream_response: None, - }; - - let client_api = SupportedAPIsFromClients::OpenAIChatCompletions(OpenAIApi::ChatCompletions); - let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); - - // Transform the event - let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); - assert!(result.is_ok()); - - let transformed = result.unwrap(); - - // Verify minimal transformation - just SSE formatting, no API conversion - let buffer = transformed.sse_transform_buffer; - assert!(buffer.starts_with("data: "), "Should preserve data: prefix"); - assert!(!buffer.contains("event:"), "Should not add event: lines"); - - // Verify provider response was parsed - assert!(transformed.provider_stream_response.is_some()); - } - - #[test] - fn test_sse_event_transformation_preserves_provider_response() { - use crate::apis::anthropic::AnthropicApi; - use crate::apis::openai::OpenAIApi; - - // Create an OpenAI stream response - let openai_stream_chunk = json!({ - "id": "chatcmpl-123", - "object": "chat.completion.chunk", - "created": 1234567890, - "model": "gpt-4", - "choices": [{ - "index": 0, - "delta": {"content": "Test"}, - "finish_reason": null - }] - }); - - let sse_event = SseEvent { - data: Some(openai_stream_chunk.to_string()), - event: None, - raw_line: format!("data: {}", openai_stream_chunk.to_string()), - sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()), - provider_stream_response: None, - }; - - let client_api = SupportedAPIsFromClients::AnthropicMessagesAPI(AnthropicApi::Messages); - let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); - - // Transform the event - let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); - assert!(result.is_ok()); - - let 
transformed = result.unwrap(); - - // Verify provider_stream_response is populated - assert!( - transformed.provider_stream_response.is_some(), - "Should parse and store provider response" - ); - - // Verify we can access the provider response - let provider_response = transformed.provider_response(); - assert!( - provider_response.is_ok(), - "Should be able to access provider response" - ); - - // Verify the content delta is accessible - let content = provider_response.unwrap().content_delta(); - assert_eq!(content, Some("Test"), "Should preserve content delta"); - } } diff --git a/crates/hermesllm/src/providers/streaming_response.rs b/crates/hermesllm/src/providers/streaming_response.rs new file mode 100644 index 00000000..7707d88d --- /dev/null +++ b/crates/hermesllm/src/providers/streaming_response.rs @@ -0,0 +1,1318 @@ +use serde::Serialize; +use std::convert::TryFrom; + +use crate::apis::openai::ChatCompletionsStreamResponse; +use crate::apis::openai_responses::ResponsesAPIStreamEvent; +use crate::apis::streaming_shapes::sse::SseEvent; +use crate::apis::amazon_bedrock::ConverseStreamEvent; +use crate::apis::anthropic::MessagesStreamEvent; +use crate::apis::streaming_shapes::sse::SseStreamBuffer; +use crate::apis::streaming_shapes::{ + anthropic_streaming_buffer::AnthropicMessagesStreamBuffer, + chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer, + passthrough_streaming_buffer::PassthroughStreamBuffer, + responses_api_streaming_buffer::ResponsesAPIStreamBuffer, + }; + +use crate::clients::endpoints::SupportedAPIsFromClient; +use crate::clients::endpoints::SupportedUpstreamAPIs; + +// ============================================================================ +// SSE STREAM BUFFER FACTORY +// ============================================================================ + +/// Check if streaming buffering is needed based on client and upstream API combination. 
+pub fn needs_buffering( + client_api: &SupportedAPIsFromClient, + upstream_api: &SupportedUpstreamAPIs, +) -> bool { + match (client_api, upstream_api) { + // Same APIs - no buffering needed + (SupportedAPIsFromClient::OpenAIChatCompletions(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => false, + (SupportedAPIsFromClient::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => false, + (SupportedAPIsFromClient::OpenAIResponsesAPI(_), SupportedUpstreamAPIs::OpenAIResponsesAPI(_)) => false, + + // Different APIs - buffering needed + _ => true, + } +} + +/// Factory pattern for creating SSE stream buffers based on client and upstream API combination. +/// # Example +/// ```ignore +/// use hermesllm::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; +/// use hermesllm::apis::streaming_shapes::sse::SseStreamBuffer; +/// +/// // Transformation needed: OpenAI upstream -> Anthropic client +/// let mut buffer = SseStreamBuffer::try_from((&client_api, &upstream_api))?; +/// +/// // Add transformed events +/// let transformed = SseEvent::try_from((raw_event, &client_api, &upstream_api))?; +/// buffer.add_transformed_event(transformed); +/// +/// // Flush to wire +/// let bytes = buffer.into_bytes(); +/// ``` +impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)> + for SseStreamBuffer +{ + type Error = Box; + + fn try_from( + (client_api, upstream_api): (&SupportedAPIsFromClient, &SupportedUpstreamAPIs), + ) -> Result { + + // If APIs match, use passthrough - no buffering/transformation needed + if !needs_buffering(client_api, upstream_api) { + return Ok(SseStreamBuffer::Passthrough(PassthroughStreamBuffer::new())); + } + + // APIs differ - use appropriate buffer for client API + match client_api { + SupportedAPIsFromClient::OpenAIChatCompletions(_) => { + Ok(SseStreamBuffer::OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer::new())) + } + SupportedAPIsFromClient::AnthropicMessagesAPI(_) => { + 
Ok(SseStreamBuffer::AnthropicMessages(AnthropicMessagesStreamBuffer::new())) + } + SupportedAPIsFromClient::OpenAIResponsesAPI(_) => { + Ok(SseStreamBuffer::OpenAIResponses(ResponsesAPIStreamBuffer::new())) + } + } + } +} + +// ============================================================================ +// PROVIDER STREAM RESPONSE TYPES +// ============================================================================ + +#[derive(Serialize, Debug, Clone)] +#[serde(untagged)] +pub enum ProviderStreamResponseType { + ChatCompletionsStreamResponse(ChatCompletionsStreamResponse), + MessagesStreamEvent(MessagesStreamEvent), + ConverseStreamEvent(ConverseStreamEvent), + ResponseAPIStreamEvent(ResponsesAPIStreamEvent) +} + +pub trait ProviderStreamResponse: Send + Sync { + /// Get the content delta for this chunk + fn content_delta(&self) -> Option<&str>; + + /// Check if this is the final chunk in the stream + fn is_final(&self) -> bool; + + /// Get role information if available + fn role(&self) -> Option<&str>; + + /// Get event type for SSE streaming (used by Anthropic) + fn event_type(&self) -> Option<&str>; +} + +impl ProviderStreamResponse for ProviderStreamResponseType { + fn content_delta(&self) -> Option<&str> { + match self { + ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.content_delta(), + ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.content_delta(), + ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.content_delta(), + ProviderStreamResponseType::ResponseAPIStreamEvent(_resp) => None, // ResponsesAPI does not have content deltas + } + } + + fn is_final(&self) -> bool { + match self { + ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.is_final(), + ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.is_final(), + ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.is_final(), + ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => resp.is_final(), + } + } + + 
fn role(&self) -> Option<&str> { + match self { + ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.role(), + ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.role(), + ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.role(), + ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => resp.role(), + } + } + + fn event_type(&self) -> Option<&str> { + match self { + ProviderStreamResponseType::ChatCompletionsStreamResponse(_resp) => None, // OpenAI doesn't use event types + ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.event_type(), + ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.event_type(), // Bedrock doesn't use event types + ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => resp.event_type(), + } + } +} + +impl Into for ProviderStreamResponseType { + fn into(self) -> String { + match self { + ProviderStreamResponseType::MessagesStreamEvent(event) => { + // Use the Into implementation for proper SSE formatting with event lines + event.into() + } + ProviderStreamResponseType::ConverseStreamEvent(event) => { + // Use the Into implementation for proper SSE formatting with event lines + event.into() + } + ProviderStreamResponseType::ResponseAPIStreamEvent(event) => { + // Use the Into implementation for proper SSE formatting with event lines + event.into() + } + ProviderStreamResponseType::ChatCompletionsStreamResponse(_) => { + // For OpenAI, use simple data line format + let json = serde_json::to_string(&self).unwrap_or_default(); + format!("data: {}\n\n", json) + } + } + } +} + + +// Stream response transformation logic for client API compatibility +impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for ProviderStreamResponseType { + type Error = Box; + + fn try_from( + (bytes, client_api, upstream_api): (&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs), + ) -> Result { + // Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion + 
if bytes == b"[DONE]" && matches!(client_api, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) { + return Ok(ProviderStreamResponseType::MessagesStreamEvent( + crate::apis::anthropic::MessagesStreamEvent::MessageStop, + )); + } + match (upstream_api, client_api) { + // OpenAI upstream + ( + SupportedUpstreamAPIs::OpenAIChatCompletions(_), + SupportedAPIsFromClient::OpenAIChatCompletions(_), + ) => { + let resp = serde_json::from_slice(bytes)?; + Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse( + resp, + )) + } + ( + SupportedUpstreamAPIs::OpenAIChatCompletions(_), + SupportedAPIsFromClient::AnthropicMessagesAPI(_), + ) => { + let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = + serde_json::from_slice(bytes)?; + let anthropic_resp = openai_resp.try_into()?; + Ok(ProviderStreamResponseType::MessagesStreamEvent( + anthropic_resp, + )) + } + ( + SupportedUpstreamAPIs::OpenAIChatCompletions(_), + SupportedAPIsFromClient::OpenAIResponsesAPI(_), + ) => { + let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = + serde_json::from_slice(bytes)?; + let responses_resp = openai_resp.try_into()?; + Ok(ProviderStreamResponseType::ResponseAPIStreamEvent( + responses_resp, + )) + } + + // OpenAI ResponsesAPI upstream + ( + SupportedUpstreamAPIs::OpenAIResponsesAPI(_), + SupportedAPIsFromClient::OpenAIResponsesAPI(_), + ) => { + let resp = serde_json::from_slice(bytes)?; + Ok(ProviderStreamResponseType::ResponseAPIStreamEvent(resp)) + } + // Anthropic upstream + ( + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), + SupportedAPIsFromClient::AnthropicMessagesAPI(_), + ) => { + let resp = serde_json::from_slice(bytes)?; + Ok(ProviderStreamResponseType::MessagesStreamEvent(resp)) + } + ( + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), + SupportedAPIsFromClient::OpenAIChatCompletions(_), + ) => { + let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = + serde_json::from_slice(bytes)?; + let openai_resp = 
anthropic_resp.try_into()?; + Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse( + openai_resp, + )) + } + + // Amazon Bedrock ConverseStream upstream + ( + SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), + SupportedAPIsFromClient::AnthropicMessagesAPI(_), + ) => { + let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent = + serde_json::from_slice(bytes)?; + let anthropic_resp = bedrock_resp.try_into()?; + Ok(ProviderStreamResponseType::MessagesStreamEvent( + anthropic_resp, + )) + } + _ => Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Unsupported API combination for response transformation", + ) + .into()), + } + } +} + +// TryFrom implementation to convert raw bytes to SseEvent with parsed provider response +impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for SseEvent { + type Error = Box; + + fn try_from( + (sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs), + ) -> Result { + // Create a new transformed event based on the original + let mut transformed_event = sse_event; + + // Handle [DONE] marker early - don't try to parse as JSON + if transformed_event.is_done() { + // For OpenAI client APIs (ChatCompletions and ResponsesAPI), keep [DONE] as-is + // For Anthropic client API, it will be transformed via ProviderStreamResponseType + if matches!(client_api, SupportedAPIsFromClient::OpenAIChatCompletions(_) | SupportedAPIsFromClient::OpenAIResponsesAPI(_)) { + // Keep the [DONE] marker as-is for OpenAI clients + transformed_event.sse_transformed_lines = "data: [DONE]".to_string(); + return Ok(transformed_event); + } + } + + // If has data, parse the data as a provider stream response (business logic layer) + if transformed_event.data.is_some() { + let data_str = transformed_event.data.as_ref().unwrap(); + let data_bytes = data_str.as_bytes(); + let transformed_response: ProviderStreamResponseType = + 
ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?; + + // Convert to SSE string explicitly to avoid type ambiguity + let sse_string: String = transformed_response.clone().into(); + transformed_event.sse_transformed_lines = sse_string; + transformed_event.provider_stream_response = Some(transformed_response); + } + + // Apply wire format adjustments for cross-API transformations + // Note: When APIs match (passthrough mode), these adjustments are skipped + // since PassthroughStreamBuffer will handle events as-is + if needs_buffering(client_api, upstream_api) { + match (client_api, upstream_api) { + ( + SupportedAPIsFromClient::OpenAIChatCompletions(_), + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), + ) => { + // OpenAI clients don't expect separate event: lines + // Suppress upstream Anthropic event-only lines + if transformed_event.is_event_only() && transformed_event.event.is_some() { + transformed_event.sse_transformed_lines = format!("\n"); + } + } + _ => { + // Other cross-API combinations can be handled here as needed + } + } + } else { + // Passthrough mode: APIs match, no transformation needed + // For Anthropic and ResponsesAPI SSE formats, event-only lines are redundant because + // the Into implementation for MessagesStreamEvent and ResponsesAPIStreamEvent + // couples event and data lines together. We suppress event-only events to + // avoid duplicate event: lines in the output. 
+ match (client_api, upstream_api) { + ( + SupportedAPIsFromClient::AnthropicMessagesAPI(_), + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), + ) | ( + SupportedAPIsFromClient::OpenAIResponsesAPI(_), + SupportedUpstreamAPIs::OpenAIResponsesAPI(_), + ) => { + if transformed_event.is_event_only() && transformed_event.event.is_some() { + // Mark as should-skip by clearing sse_transformed_lines + // The event line is already included when the data line is transformed + transformed_event.sse_transformed_lines = String::new(); + } + } + _ => { + // Other passthrough combinations (OpenAI ChatCompletions, etc.) don't have this issue + } + } + } + + Ok(transformed_event) + } +} + +// TryFrom implementation to convert AWS Event Stream DecodedFrame to ProviderStreamResponseType +impl + TryFrom<( + &aws_smithy_eventstream::frame::DecodedFrame, + &SupportedAPIsFromClient, + &SupportedUpstreamAPIs, + )> for ProviderStreamResponseType +{ + type Error = Box; + + fn try_from( + (frame, client_api, upstream_api): ( + &aws_smithy_eventstream::frame::DecodedFrame, + &SupportedAPIsFromClient, + &SupportedUpstreamAPIs, + ), + ) -> Result { + use aws_smithy_eventstream::frame::DecodedFrame; + + match frame { + DecodedFrame::Complete(_) => { + // We have a complete frame - parse it based on upstream API + match (upstream_api, client_api) { + ( + SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), + SupportedAPIsFromClient::AnthropicMessagesAPI(_), + ) => { + // Parse the DecodedFrame into ConverseStreamEvent + let bedrock_event = + crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?; + let anthropic_event: crate::apis::anthropic::MessagesStreamEvent = + bedrock_event.try_into()?; + + Ok(ProviderStreamResponseType::MessagesStreamEvent( + anthropic_event, + )) + } + ( + SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), + SupportedAPIsFromClient::OpenAIChatCompletions(_), + ) => { + // Parse the DecodedFrame into ConverseStreamEvent + let bedrock_event = + 
crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?; + let openai_event: crate::apis::openai::ChatCompletionsStreamResponse = + bedrock_event.try_into()?; + Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse( + openai_event, + )) + } + _ => Err("Unsupported API combination for event-stream decoding".into()), + } + } + DecodedFrame::Incomplete => { + Err("Cannot convert incomplete frame to provider response".into()) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; + use crate::clients::endpoints::SupportedAPIsFromClient; + use crate::apis::streaming_shapes::sse::SseStreamIter; + use serde_json::json; + + #[test] + fn test_sse_event_parsing() { + // Test valid SSE data line + let line = "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n"; + let event: Result = line.parse(); + assert!(event.is_ok()); + let event = event.unwrap(); + // The data field should contain only the JSON content, not the trailing newlines + assert_eq!( + event.data, + Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}".to_string()) + ); + + // Test conversion back to line using Display trait + // The sse_transformed_lines preserves the original format including trailing newlines + let wire_format = event.to_string(); + assert_eq!( + wire_format, + "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n" + ); + + // Test [DONE] marker - should be valid SSE event + let done_line = "data: [DONE]"; + let done_result: Result = done_line.parse(); + assert!(done_result.is_ok()); + let done_event = done_result.unwrap(); + assert_eq!(done_event.data, Some("[DONE]".to_string())); + assert!(done_event.is_done()); // Test the helper method + + // Test non-DONE event + assert!(!event.is_done()); + + // Test empty data - should return error + let empty_line = "data: "; + let empty_result: Result = empty_line.parse(); + 
assert!(empty_result.is_err()); + + // Test non-data line - should return error + let comment_line = ": this is a comment"; + let comment_result: Result = comment_line.parse(); + assert!(comment_result.is_err()); + } + + #[test] + fn test_sse_event_serde() { + // Test serialization and deserialization with serde + let event = SseEvent { + data: Some(r#"{"id":"test","object":"chat.completion.chunk"}"#.to_string()), + event: None, + raw_line: r#"data: {"id":"test","object":"chat.completion.chunk"} + + "# + .to_string(), + sse_transformed_lines: r#"data: {"id":"test","object":"chat.completion.chunk"} + + "# + .to_string(), + provider_stream_response: None, + }; + + // Test JSON serialization - raw_line should be skipped + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("test")); + assert!(json.contains("chat.completion.chunk")); + assert!(!json.contains("raw_line")); // Should be excluded from serialization + + // Test JSON deserialization + let deserialized: SseEvent = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.data, event.data); + assert_eq!(deserialized.raw_line, ""); // Should be empty since it's skipped + + // Test round trip for data field only + assert_eq!(event.data, deserialized.data); + } + + #[test] + fn test_sse_event_should_skip() { + // Test ping message should be skipped + let ping_event = SseEvent { + data: Some(r#"{"type": "ping"}"#.to_string()), + event: None, + raw_line: r#"data: {"type": "ping"}"#.to_string(), + sse_transformed_lines: r#"data: {"type": "ping"}"#.to_string(), + provider_stream_response: None, + }; + assert!(ping_event.should_skip()); + assert!(!ping_event.is_done()); + + // Test normal event should not be skipped + let normal_event = SseEvent { + data: Some(r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string()), + event: Some("content_block_delta".to_string()), + raw_line: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(), + sse_transformed_lines: 
r#"data: {"id": "test", "object": "chat.completion.chunk"}"# + .to_string(), + provider_stream_response: None, + }; + assert!(!normal_event.should_skip()); + assert!(!normal_event.is_done()); + + // Test [DONE] event should not be skipped (but is handled separately) + let done_event = SseEvent { + data: Some("[DONE]".to_string()), + event: None, + raw_line: "data: [DONE]".to_string(), + sse_transformed_lines: "data: [DONE]".to_string(), + provider_stream_response: None, + }; + assert!(!done_event.should_skip()); + assert!(done_event.is_done()); + } + + #[test] + fn test_sse_stream_iter_filters_ping_messages() { + // Create test data with ping messages mixed in + let test_lines = vec![ + "data: {\"id\": \"msg1\", \"object\": \"chat.completion.chunk\"}".to_string(), + "data: {\"type\": \"ping\"}".to_string(), // This should be filtered out + "data: {\"id\": \"msg2\", \"object\": \"chat.completion.chunk\"}".to_string(), + "data: {\"type\": \"ping\"}".to_string(), // This should be filtered out + "data: [DONE]".to_string(), // This should end the stream + ]; + + let mut iter = SseStreamIter::new(test_lines.into_iter()); + + // First event should be msg1 (ping filtered out) + let event1 = iter.next().unwrap(); + assert!(event1.data.as_ref().unwrap().contains("msg1")); + assert!(!event1.should_skip()); + + // Second event should be msg2 (ping filtered out) + let event2 = iter.next().unwrap(); + assert!(event2.data.as_ref().unwrap().contains("msg2")); + assert!(!event2.should_skip()); + + // Third event should be [DONE] + let done_event = iter.next().unwrap(); + assert!(done_event.is_done()); + + // Iterator should end after [DONE] + assert!(iter.next().is_none()); + } + + #[test] + fn test_sse_stream_iter_handles_anthropic_events() { + // Create test data with Anthropic-style event/data pairs + let test_lines = vec![ + "event: message_start".to_string(), + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\"}}".to_string(), + "event: 
content_block_delta".to_string(), + "data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\"Hello\"}}".to_string(), + "data: [DONE]".to_string(), + ]; + + let mut iter = SseStreamIter::new(test_lines.into_iter()); + + // First event should be the event: line + let event1 = iter.next().unwrap(); + assert!(event1.is_event_only()); + assert_eq!(event1.event, Some("message_start".to_string())); + assert_eq!(event1.data, None); + + // Second event should be the data: line + let event2 = iter.next().unwrap(); + assert!(!event2.is_event_only()); + assert_eq!(event2.event, None); + assert!(event2.data.as_ref().unwrap().contains("message_start")); + + // Third event should be another event: line + let event3 = iter.next().unwrap(); + assert!(event3.is_event_only()); + assert_eq!(event3.event, Some("content_block_delta".to_string())); + + // Fourth event should be the content delta data + let event4 = iter.next().unwrap(); + assert!(!event4.is_event_only()); + assert!(event4.data.as_ref().unwrap().contains("Hello")); + + // Fifth event should be [DONE] + let done_event = iter.next().unwrap(); + assert!(done_event.is_done()); + + // Iterator should end after [DONE] + assert!(iter.next().is_none()); + } + + #[test] + fn test_provider_stream_response_event_type() { + use crate::apis::anthropic::{MessagesContentDelta, MessagesStreamEvent}; + use crate::apis::openai::ChatCompletionsStreamResponse; + + // Test Anthropic event type + let anthropic_event = MessagesStreamEvent::ContentBlockDelta { + index: 0, + delta: MessagesContentDelta::TextDelta { + text: "Hello".to_string(), + }, + }; + let provider_type = ProviderStreamResponseType::MessagesStreamEvent(anthropic_event); + assert_eq!(provider_type.event_type(), Some("content_block_delta")); + + // Test OpenAI event type (should be None) + let openai_event = ChatCompletionsStreamResponse { + id: "test".to_string(), + object: Some("chat.completion.chunk".to_string()), + created: 123456789, + model: "gpt-4".to_string(), + 
choices: vec![], + usage: None, + system_fingerprint: None, + service_tier: None, + }; + let provider_type = ProviderStreamResponseType::ChatCompletionsStreamResponse(openai_event); + assert_eq!(provider_type.event_type(), None); + } + + #[test] + fn test_done_marker_handled_in_stream_response_transformation() { + use crate::apis::anthropic::AnthropicApi; + + // Test that [DONE] marker is properly converted to MessageStop in the transformation layer + let done_bytes = b"[DONE]"; + let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions( + crate::apis::openai::OpenAIApi::ChatCompletions, + ); + + let result = ProviderStreamResponseType::try_from(( + done_bytes.as_slice(), + &client_api, + &upstream_api, + )); + assert!(result.is_ok()); + + if let Ok(ProviderStreamResponseType::MessagesStreamEvent(event)) = result { + // Verify it's a MessageStop event + assert_eq!(event.event_type(), Some("message_stop")); + assert!(matches!( + event, + crate::apis::anthropic::MessagesStreamEvent::MessageStop + )); + } else { + panic!("Expected MessagesStreamEvent::MessageStop"); + } + } + + #[test] + fn test_bedrock_event_stream_decoder_basic() { + use bytes::BytesMut; + + // Create a simple test with minimal data + let mut buffer = BytesMut::new(); + + // Add some arbitrary bytes (not a real event-stream frame, just for testing the decoder) + buffer.extend_from_slice(b"test data"); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + // The decoder should return Incomplete for incomplete/invalid data + // This signals the caller to wait for more data + let result = decoder.decode_frame(); + assert!(result.is_some()); + assert!(matches!( + result.unwrap(), + aws_smithy_eventstream::frame::DecodedFrame::Incomplete + )); + + // Verify we can still access the buffer + assert!(decoder.has_remaining()); + } + + #[test] + fn 
test_bedrock_event_stream_decoder_with_real_frames() { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response.hex file from tests/e2e directory + let test_file = + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex"); + + // Only run this test if the file exists + if !test_file.exists() { + println!("Skipping test - response.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + let mut buffer = BytesMut::from(&response_data[..]); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + let mut frame_count = 0; + + // Decode all frames + loop { + match decoder.decode_frame() { + Some(aws_smithy_eventstream::frame::DecodedFrame::Complete(message)) => { + frame_count += 1; + + // Verify we can access headers + let event_type = message + .headers() + .iter() + .find(|h| h.name().as_str() == ":event-type") + .and_then(|h| h.value().as_string().ok()); + + assert!(event_type.is_some(), "Frame should have :event-type header"); + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // End of buffer, no more complete frames available + break; + } + None => { + // Decode error + panic!("Decode error encountered"); + } + } + } + + // We should have decoded multiple frames + assert!(frame_count > 0, "Should have decoded at least one frame"); + } + + #[test] + fn test_bedrock_event_stream_decoder_chunked_data() { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response.hex file + let test_file = + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex"); + + if !test_file.exists() { + println!("Skipping test - response.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + + // Simulate chunked network arrivals with realistic chunk sizes + // Using varying chunk sizes to test partial frame handling + let mut buffer = BytesMut::new(); + let 
chunk_size_pattern = vec![500, 1000, 750, 1200, 800, 1500]; + let mut offset = 0; + let mut total_frames = 0; + let mut chunk_num = 0; + + // CRITICAL: Create ONE decoder and reuse it across chunks + // The MessageFrameDecoder maintains state about partial frames + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + // Process all data in chunks + while offset < response_data.len() { + let chunk_size = chunk_size_pattern[chunk_num % chunk_size_pattern.len()]; + chunk_num += 1; + + let end = (offset + chunk_size).min(response_data.len()); + let chunk = &response_data[offset..end]; + + // Add new data to the buffer (accessing via buffer_mut()) + decoder.buffer_mut().extend_from_slice(chunk); + offset = end; + + // Process all available complete frames from this chunk + loop { + match decoder.decode_frame() { + Some(aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { + total_frames += 1; + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // Need more data - wait for next chunk + break; + } + None => { + // Decode error + panic!("Decode error in chunked test"); + } + } + } + } + + assert!( + total_frames > 0, + "Should have decoded frames from chunked data" + ); + } + + #[test] + fn test_bedrock_decoded_frame_to_provider_response() { + test_bedrock_conversion(false); + } + + #[test] + #[ignore] // Run with: cargo test -- --ignored --nocapture + fn test_bedrock_decoded_frame_to_provider_response_verbose() { + test_bedrock_conversion(true); + } + + #[test] + fn test_bedrock_decoded_frame_with_tool_use() { + test_bedrock_conversion_with_tools(false); + } + + #[test] + #[ignore] // Run with: cargo test -- --ignored --nocapture + fn test_bedrock_decoded_frame_with_tool_use_verbose() { + test_bedrock_conversion_with_tools(true); + } + + fn test_bedrock_conversion(verbose: bool) { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response.hex file from tests/e2e directory + let test_file = 
+ PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex"); + + // Only run this test if the file exists + if !test_file.exists() { + println!("Skipping test - response.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + let mut buffer = BytesMut::from(&response_data[..]); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + let client_api = + SupportedAPIsFromClient::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream( + crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream, + ); + + let mut conversion_count = 0; + let mut message_start_seen = false; + + // Decode and convert frames + loop { + match decoder.decode_frame() { + Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { + // Convert DecodedFrame to ProviderStreamResponseType + let result = + ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api)); + + match result { + Ok(provider_response) => { + conversion_count += 1; + + // Verify we got a MessagesStreamEvent + assert!(matches!( + provider_response, + ProviderStreamResponseType::MessagesStreamEvent(_) + )); + + if verbose { + // Print the SSE string output + let sse_string: String = provider_response.clone().into(); + println!("{}", sse_string); + } + + // Check for MessageStart event + if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = + provider_response + { + if matches!( + event, + crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. 
} + ) { + message_start_seen = true; + } + } + } + Err(e) => { + println!("Conversion error (frame {}): {}", conversion_count, e); + } + } + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // End of buffer + break; + } + None => { + panic!("Decode error"); + } + } + } + + assert!( + conversion_count > 0, + "Should have converted at least one frame" + ); + assert!(message_start_seen, "Should have seen MessageStart event"); + } + + fn test_bedrock_conversion_with_tools(verbose: bool) { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response_with_tools.hex file from tests/e2e directory + let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("../../tests/e2e/response_with_tools.hex"); + + // Only run this test if the file exists + if !test_file.exists() { + println!("Skipping test - response_with_tools.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + let mut buffer = BytesMut::from(&response_data[..]); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + let client_api = + SupportedAPIsFromClient::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream( + crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream, + ); + + let mut conversion_count = 0; + let mut message_start_seen = false; + let mut content_block_start_seen = false; + let mut content_block_delta_tool_use_seen = false; + + // Decode and convert frames + loop { + match decoder.decode_frame() { + Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { + // Convert DecodedFrame to ProviderStreamResponseType + let result = + ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api)); + + match result { + Ok(provider_response) => { + conversion_count += 1; + + // Verify we got a MessagesStreamEvent + assert!(matches!( + provider_response, + 
ProviderStreamResponseType::MessagesStreamEvent(_) + )); + + if verbose { + // Print the SSE string output + let sse_string: String = provider_response.clone().into(); + println!("{}", sse_string); + } + + // Check for specific events related to tool use + if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = + provider_response + { + match event { + crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. } => { + message_start_seen = true; + } + crate::apis::anthropic::MessagesStreamEvent::ContentBlockStart { .. } => { + content_block_start_seen = true; + } + crate::apis::anthropic::MessagesStreamEvent::ContentBlockDelta { delta, .. } => { + if matches!(delta, crate::apis::anthropic::MessagesContentDelta::InputJsonDelta { .. }) { + content_block_delta_tool_use_seen = true; + } + } + _ => {} + } + } + } + Err(e) => { + println!("Conversion error (frame {}): {}", conversion_count, e); + } + } + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // End of buffer + break; + } + None => { + panic!("Decode error"); + } + } + } + + assert!( + conversion_count > 0, + "Should have converted at least one frame" + ); + assert!(message_start_seen, "Should have seen MessageStart event"); + assert!( + content_block_start_seen, + "Should have seen ContentBlockStart event for tool use" + ); + assert!( + content_block_delta_tool_use_seen, + "Should have seen ContentBlockDelta with ToolUseDelta" + ); + } + + + #[test] + fn test_sse_event_transformation_openai_to_anthropic_message_delta() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an OpenAI stream response with finish_reason (which becomes message_delta in Anthropic) + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 10, + 
"completion_tokens": 25, + "total_tokens": 35 + } + }); + + // Create SSE event with this data + let sse_event = SseEvent { + data: Some(openai_stream_chunk.to_string()), + event: None, + raw_line: format!("data: {}", openai_stream_chunk.to_string()), + sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // NOTE: This test now verifies single-event transformation only. + // Multi-event injection (content_block_stop + message_delta) is now handled + // by AnthropicMessagesStreamBuffer, not by TryFrom transformation. + let buffer = transformed.sse_transformed_lines; + + // Verify the event was transformed to Anthropic format + // This should contain message_delta with stop_reason and usage + assert!( + buffer.contains("event: message_delta") || buffer.contains("\"type\":\"message_delta\""), + "Should contain message_delta in transformed event" + ); + + // Verify usage information is present + assert!( + buffer.contains("\"usage\""), + "Should contain usage information" + ); + } + + #[test] + fn test_sse_event_transformation_openai_to_anthropic_content_delta() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an OpenAI stream response with content (which becomes content_block_delta in Anthropic) + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {"content": "Hello"}, + "finish_reason": null + }] + }); + + // Create SSE event with this data + let sse_event = SseEvent { + data: 
Some(openai_stream_chunk.to_string()), + event: None, + raw_line: format!("data: {}", openai_stream_chunk.to_string()), + sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify the transformation is a content_block_delta (no extra events injected) + let buffer = transformed.sse_transformed_lines; + assert!( + buffer.contains("event: content_block_delta"), + "Should contain content_block_delta event" + ); + assert!( + !buffer.contains("content_block_start"), + "Should not inject content_block_start for content delta" + ); + assert!( + !buffer.contains("content_block_stop"), + "Should not inject content_block_stop for content delta" + ); + + // Verify the content is preserved + assert!(buffer.contains("Hello"), "Should preserve the content text"); + } + + #[test] + fn test_sse_event_transformation_anthropic_to_openai_suppresses_event_lines() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an Anthropic event-only SSE line (no data) + let sse_event = SseEvent { + data: None, + event: Some("message_start".to_string()), + raw_line: "event: message_start".to_string(), + sse_transformed_lines: "event: message_start".to_string(), + provider_stream_response: None, + }; + + let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = 
result.unwrap(); + + // Verify the event line is suppressed (replaced with just newline) + assert_eq!( + transformed.sse_transformed_lines, "\n", + "Event-only lines should be suppressed to newline for OpenAI" + ); + assert!( + transformed.is_event_only(), + "Should still be marked as event-only" + ); + } + + #[test] + fn test_sse_event_transformation_anthropic_to_openai_preserves_data() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an Anthropic message_start event with data + let anthropic_event = json!({ + "type": "message_start", + "message": { + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-3-sonnet", + "stop_reason": null, + "usage": {"input_tokens": 10, "output_tokens": 0} + } + }); + + let sse_event = SseEvent { + data: Some(anthropic_event.to_string()), + event: None, + raw_line: format!("data: {}", anthropic_event.to_string()), + sse_transformed_lines: format!("data: {}", anthropic_event.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify data is transformed to OpenAI format + let buffer = transformed.sse_transformed_lines; + assert!(buffer.starts_with("data: "), "Should have data: prefix"); + assert!( + !buffer.contains("event:"), + "Should not have event: lines for OpenAI" + ); + + // Verify provider response was parsed + assert!(transformed.provider_stream_response.is_some()); + } + + #[test] + fn test_sse_event_transformation_no_change_for_matching_apis() { + use crate::apis::openai::OpenAIApi; + + // Create an OpenAI stream response + let openai_stream_chunk = json!({ + "id": 
"chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {"content": "Hello"}, + "finish_reason": null + }] + }); + + let original_data = openai_stream_chunk.to_string(); + let sse_event = SseEvent { + data: Some(original_data.clone()), + event: None, + raw_line: format!("data: {}", original_data), + sse_transformed_lines: format!("data: {}\n\n", original_data), + provider_stream_response: None, + }; + + let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify minimal transformation - just SSE formatting, no API conversion + let buffer = transformed.sse_transformed_lines; + assert!(buffer.starts_with("data: "), "Should preserve data: prefix"); + assert!(!buffer.contains("event:"), "Should not add event: lines"); + + // Verify provider response was parsed + assert!(transformed.provider_stream_response.is_some()); + } + + #[test] + fn test_sse_event_transformation_preserves_provider_response() { + use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; + + // Create an OpenAI stream response + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {"content": "Test"}, + "finish_reason": null + }] + }); + + let sse_event = SseEvent { + data: Some(openai_stream_chunk.to_string()), + event: None, + raw_line: format!("data: {}", openai_stream_chunk.to_string()), + sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()), + provider_stream_response: None, + }; + + let client_api = 
SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify provider_stream_response is populated + assert!( + transformed.provider_stream_response.is_some(), + "Should parse and store provider response" + ); + + // Verify we can access the provider response + let provider_response = transformed.provider_response(); + assert!( + provider_response.is_ok(), + "Should be able to access provider response" + ); + + // Verify the content delta is accessible + let content = provider_response.unwrap().content_delta(); + assert_eq!(content, Some("Test"), "Should preserve content delta"); + } +} diff --git a/crates/hermesllm/src/transforms/mod.rs b/crates/hermesllm/src/transforms/mod.rs index 3fb4e397..ebb4bf20 100644 --- a/crates/hermesllm/src/transforms/mod.rs +++ b/crates/hermesllm/src/transforms/mod.rs @@ -11,11 +11,13 @@ pub mod lib; pub mod request; pub mod response; +pub mod response_streaming; // Re-export commonly used items for convenience pub use lib::*; pub use request::*; pub use response::*; +pub use response_streaming::*; // ============================================================================ // CONSTANTS diff --git a/crates/hermesllm/src/transforms/request/from_openai.rs b/crates/hermesllm/src/transforms/request/from_openai.rs index 097b6960..83f13fe8 100644 --- a/crates/hermesllm/src/transforms/request/from_openai.rs +++ b/crates/hermesllm/src/transforms/request/from_openai.rs @@ -395,16 +395,16 @@ impl TryFrom for ChatCompletionsRequest { // Only convert Function tools - other types are not supported in ChatCompletions match tool { - ResponsesTool::Function { function } => Ok(Tool { + ResponsesTool::Function { name, description, parameters, strict } => Ok(Tool { 
tool_type: "function".to_string(), function: crate::apis::openai::Function { - name: function.name, - description: function.description, - parameters: function.parameters.unwrap_or_else(|| serde_json::json!({ + name, + description, + parameters: parameters.unwrap_or_else(|| serde_json::json!({ "type": "object", "properties": {} })), - strict: function.strict, + strict, } }), ResponsesTool::FileSearch { .. } => Err(TransformError::UnsupportedConversion( diff --git a/crates/hermesllm/src/transforms/response/to_anthropic.rs b/crates/hermesllm/src/transforms/response/to_anthropic.rs index 1c6ce238..0326fdb3 100644 --- a/crates/hermesllm/src/transforms/response/to_anthropic.rs +++ b/crates/hermesllm/src/transforms/response/to_anthropic.rs @@ -1,16 +1,11 @@ -use crate::apis::amazon_bedrock::{ - ContentBlockDelta, ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason, -}; +use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, StopReason}; use crate::apis::anthropic::{ - MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, MessagesResponse, - MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage, -}; -use crate::apis::openai::{ - ChatCompletionsResponse, ChatCompletionsStreamResponse, Role, ToolCallDelta, + MessagesContentBlock, MessagesResponse, + MessagesRole, MessagesStopReason, MessagesUsage, }; +use crate::apis::openai::ChatCompletionsResponse; use crate::clients::TransformError; use crate::transforms::lib::*; -use serde_json::Value; // ============================================================================ // STANDARD RUST TRAIT IMPLEMENTATIONS - Using Into/TryFrom for convenience @@ -120,289 +115,6 @@ impl TryFrom for MessagesResponse { } } -impl TryFrom for MessagesStreamEvent { - type Error = TransformError; - - fn try_from(resp: ChatCompletionsStreamResponse) -> Result { - if resp.choices.is_empty() { - return Ok(MessagesStreamEvent::Ping); - } - - let choice = &resp.choices[0]; - - 
// Handle final chunk with usage - let has_usage = resp.usage.is_some(); - if let Some(usage) = resp.usage { - if let Some(finish_reason) = &choice.finish_reason { - let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into(); - return Ok(MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: anthropic_stop_reason, - stop_sequence: None, - }, - usage: usage.into(), - }); - } - } - - // Handle role start - if let Some(Role::Assistant) = choice.delta.role { - return Ok(MessagesStreamEvent::MessageStart { - message: MessagesStreamMessage { - id: resp.id, - obj_type: "message".to_string(), - role: MessagesRole::Assistant, - content: vec![], - model: resp.model, - stop_reason: None, - stop_sequence: None, - usage: MessagesUsage { - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }, - }); - } - - // Handle content delta - if let Some(content) = &choice.delta.content { - if !content.is_empty() { - return Ok(MessagesStreamEvent::ContentBlockDelta { - index: 0, - delta: MessagesContentDelta::TextDelta { - text: content.clone(), - }, - }); - } - } - - // Handle tool calls - if let Some(tool_calls) = &choice.delta.tool_calls { - return convert_tool_call_deltas(tool_calls.clone()); - } - - // Handle finish reason - generate MessageDelta only (MessageStop comes later) - if let Some(finish_reason) = &choice.finish_reason { - // If we have usage data, it was already handled above - // If not, we need to generate MessageDelta with default usage - if !has_usage { - let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into(); - return Ok(MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: anthropic_stop_reason, - stop_sequence: None, - }, - usage: MessagesUsage { - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }); - } - // If usage was already handled above, 
we don't need to do anything more here - // MessageStop will be handled when [DONE] is encountered - } - - // Default to ping for unhandled cases - Ok(MessagesStreamEvent::Ping) - } -} - -impl Into for MessagesStreamEvent { - fn into(self) -> String { - let transformed_json = serde_json::to_string(&self).unwrap_or_default(); - let event_type = match &self { - MessagesStreamEvent::MessageStart { .. } => "message_start", - MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start", - MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta", - MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop", - MessagesStreamEvent::MessageDelta { .. } => "message_delta", - MessagesStreamEvent::MessageStop => "message_stop", - MessagesStreamEvent::Ping => "ping", - }; - - let event = format!("event: {}\n", event_type); - let data = format!("data: {}\n\n", transformed_json); - event + &data - } -} - -impl TryFrom for MessagesStreamEvent { - type Error = TransformError; - - fn try_from(event: ConverseStreamEvent) -> Result { - match event { - // MessageStart - convert to Anthropic MessageStart - ConverseStreamEvent::MessageStart(start_event) => { - let role = match start_event.role { - crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User, - crate::apis::amazon_bedrock::ConversationRole::Assistant => { - MessagesRole::Assistant - } - }; - - Ok(MessagesStreamEvent::MessageStart { - message: MessagesStreamMessage { - id: format!( - "bedrock-stream-{}", - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_nanos() - ), - obj_type: "message".to_string(), - role, - content: vec![], - model: "bedrock-model".to_string(), - stop_reason: None, - stop_sequence: None, - usage: MessagesUsage { - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }, - }) - } - - // ContentBlockStart - convert to Anthropic ContentBlockStart - 
ConverseStreamEvent::ContentBlockStart(start_event) => { - // Note: Bedrock sends tool_use_id and name at start, with input coming in subsequent deltas - // Anthropic expects the same pattern, so we initialize with an empty input object - match start_event.start { - crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => { - Ok(MessagesStreamEvent::ContentBlockStart { - index: start_event.content_block_index as u32, - content_block: MessagesContentBlock::ToolUse { - id: tool_use.tool_use_id, - name: tool_use.name, - input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas - cache_control: None, - }, - }) - } - } - } - - // ContentBlockDelta - convert to Anthropic ContentBlockDelta - ConverseStreamEvent::ContentBlockDelta(delta_event) => { - let delta = match delta_event.delta { - ContentBlockDelta::Text { text } => MessagesContentDelta::TextDelta { text }, - ContentBlockDelta::ToolUse { tool_use } => { - MessagesContentDelta::InputJsonDelta { - partial_json: tool_use.input, - } - } - }; - - Ok(MessagesStreamEvent::ContentBlockDelta { - index: delta_event.content_block_index as u32, - delta, - }) - } - - // ContentBlockStop - convert to Anthropic ContentBlockStop - ConverseStreamEvent::ContentBlockStop(stop_event) => { - Ok(MessagesStreamEvent::ContentBlockStop { - index: stop_event.content_block_index as u32, - }) - } - - // MessageStop - convert to Anthropic MessageDelta with stop reason + MessageStop - ConverseStreamEvent::MessageStop(stop_event) => { - let anthropic_stop_reason = match stop_event.stop_reason { - StopReason::EndTurn => MessagesStopReason::EndTurn, - StopReason::ToolUse => MessagesStopReason::ToolUse, - StopReason::MaxTokens => MessagesStopReason::MaxTokens, - StopReason::StopSequence => MessagesStopReason::EndTurn, - StopReason::GuardrailIntervened => MessagesStopReason::Refusal, - StopReason::ContentFiltered => MessagesStopReason::Refusal, - }; - - // Return MessageDelta (MessageStop will be sent 
separately by the streaming handler) - Ok(MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: anthropic_stop_reason, - stop_sequence: None, - }, - usage: MessagesUsage { - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: None, - cache_read_input_tokens: None, - }, - }) - } - - // Metadata - convert usage information to MessageDelta - ConverseStreamEvent::Metadata(metadata_event) => { - Ok(MessagesStreamEvent::MessageDelta { - delta: MessagesMessageDelta { - stop_reason: MessagesStopReason::EndTurn, - stop_sequence: None, - }, - usage: MessagesUsage { - input_tokens: metadata_event.usage.input_tokens, - output_tokens: metadata_event.usage.output_tokens, - cache_creation_input_tokens: metadata_event.usage.cache_write_input_tokens, - cache_read_input_tokens: metadata_event.usage.cache_read_input_tokens, - }, - }) - } - - // Exception events - convert to Ping (could be enhanced to return error events) - ConverseStreamEvent::InternalServerException(_) - | ConverseStreamEvent::ModelStreamErrorException(_) - | ConverseStreamEvent::ServiceUnavailableException(_) - | ConverseStreamEvent::ThrottlingException(_) - | ConverseStreamEvent::ValidationException(_) => { - // TODO: Consider adding proper error handling/events - Ok(MessagesStreamEvent::Ping) - } - } - } -} - -/// Convert tool call deltas to Anthropic stream events -fn convert_tool_call_deltas( - tool_calls: Vec, -) -> Result { - for tool_call in tool_calls { - if let Some(id) = &tool_call.id { - // Tool call start - if let Some(function) = &tool_call.function { - if let Some(name) = &function.name { - return Ok(MessagesStreamEvent::ContentBlockStart { - index: tool_call.index, - content_block: MessagesContentBlock::ToolUse { - id: id.clone(), - name: name.clone(), - input: Value::Object(serde_json::Map::new()), - cache_control: None, - }, - }); - } - } - } else if let Some(function) = &tool_call.function { - if let Some(arguments) = &function.arguments { - // Tool 
arguments delta - return Ok(MessagesStreamEvent::ContentBlockDelta { - index: tool_call.index, - delta: MessagesContentDelta::InputJsonDelta { - partial_json: arguments.clone(), - }, - }); - } - } - } - - // Fallback to ping if no valid tool call found - Ok(MessagesStreamEvent::Ping) -} /// Convert Bedrock Message to Anthropic content blocks /// diff --git a/crates/hermesllm/src/transforms/response/to_openai.rs b/crates/hermesllm/src/transforms/response/to_openai.rs index 1c8a57c3..e26cc3b4 100644 --- a/crates/hermesllm/src/transforms/response/to_openai.rs +++ b/crates/hermesllm/src/transforms/response/to_openai.rs @@ -1,14 +1,11 @@ use crate::apis::amazon_bedrock::{ - ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason, + ConverseOutput, ConverseResponse, StopReason, }; use crate::apis::anthropic::{ - MessagesContentBlock, MessagesContentDelta, MessagesResponse, MessagesStopReason, - MessagesStreamEvent, MessagesUsage, + MessagesContentBlock, MessagesResponse, MessagesUsage, }; use crate::apis::openai::{ - ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason, - FunctionCallDelta, MessageContent, MessageDelta, ResponseMessage, Role, StreamChoice, - ToolCallDelta, Usage, + ChatCompletionsResponse, Choice, FinishReason, MessageContent, ResponseMessage, Role, Usage, }; use crate::apis::openai_responses::ResponsesAPIResponse; use crate::clients::TransformError; @@ -331,416 +328,6 @@ impl TryFrom for ChatCompletionsResponse { } } -// ============================================================================ -// STREAMING TRANSFORMATIONS -// ============================================================================ - -impl TryFrom for ChatCompletionsStreamResponse { - type Error = TransformError; - - fn try_from(event: MessagesStreamEvent) -> Result { - match event { - MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk( - &message.id, - &message.model, - MessageDelta { - role: Some(Role::Assistant), - 
content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )), - - MessagesStreamEvent::ContentBlockStart { content_block, .. } => { - convert_content_block_start(content_block) - } - - MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta), - - MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()), - - MessagesStreamEvent::MessageDelta { delta, usage } => { - let finish_reason: Option = Some(delta.stop_reason.into()); - let openai_usage: Option = Some(usage.into()); - - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - finish_reason, - openai_usage, - )) - } - - MessagesStreamEvent::MessageStop => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - Some(FinishReason::Stop), - None, - )), - - MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse { - id: "stream".to_string(), - object: Some("chat.completion.chunk".to_string()), - created: current_timestamp(), - model: "unknown".to_string(), - choices: vec![], - usage: None, - system_fingerprint: None, - service_tier: None, - }), - } - } -} - -impl TryFrom for ChatCompletionsStreamResponse { - type Error = TransformError; - - fn try_from(event: ConverseStreamEvent) -> Result { - match event { - ConverseStreamEvent::MessageStart(start_event) => { - let role = match start_event.role { - crate::apis::amazon_bedrock::ConversationRole::User => Role::User, - crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant, - }; - - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: Some(role), - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )) - } - - ConverseStreamEvent::ContentBlockStart(start_event) => { - use 
crate::apis::amazon_bedrock::ContentBlockStart; - - match start_event.start { - ContentBlockStart::ToolUse { tool_use } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: start_event.content_block_index as u32, - id: Some(tool_use.tool_use_id), - call_type: Some("function".to_string()), - function: Some(FunctionCallDelta { - name: Some(tool_use.name), - arguments: Some("".to_string()), - }), - }]), - }, - None, - None, - )), - } - } - - ConverseStreamEvent::ContentBlockDelta(delta_event) => { - use crate::apis::amazon_bedrock::ContentBlockDelta; - - match delta_event.delta { - ContentBlockDelta::Text { text } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: Some(text), - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )), - ContentBlockDelta::ToolUse { tool_use } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: delta_event.content_block_index as u32, - id: None, - call_type: None, - function: Some(FunctionCallDelta { - name: None, - arguments: Some(tool_use.input), - }), - }]), - }, - None, - None, - )), - } - } - - ConverseStreamEvent::ContentBlockStop(_) => Ok(create_empty_openai_chunk()), - - ConverseStreamEvent::MessageStop(stop_event) => { - let finish_reason = match stop_event.stop_reason { - StopReason::EndTurn => FinishReason::Stop, - StopReason::ToolUse => FinishReason::ToolCalls, - StopReason::MaxTokens => FinishReason::Length, - StopReason::StopSequence => FinishReason::Stop, - StopReason::GuardrailIntervened => FinishReason::ContentFilter, - StopReason::ContentFiltered => FinishReason::ContentFilter, - }; - - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, 
- refusal: None, - function_call: None, - tool_calls: None, - }, - Some(finish_reason), - None, - )) - } - - ConverseStreamEvent::Metadata(metadata_event) => { - let usage = Usage { - prompt_tokens: metadata_event.usage.input_tokens, - completion_tokens: metadata_event.usage.output_tokens, - total_tokens: metadata_event.usage.total_tokens, - prompt_tokens_details: None, - completion_tokens_details: None, - }; - - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - Some(usage), - )) - } - - // Error events - convert to empty chunks (errors should be handled elsewhere) - ConverseStreamEvent::InternalServerException(_) - | ConverseStreamEvent::ModelStreamErrorException(_) - | ConverseStreamEvent::ServiceUnavailableException(_) - | ConverseStreamEvent::ThrottlingException(_) - | ConverseStreamEvent::ValidationException(_) => Ok(create_empty_openai_chunk()), - } - } -} - -/// Convert content block start to OpenAI chunk -fn convert_content_block_start( - content_block: MessagesContentBlock, -) -> Result { - match content_block { - MessagesContentBlock::Text { .. } => { - // No immediate output for text block start - Ok(create_empty_openai_chunk()) - } - MessagesContentBlock::ToolUse { id, name, .. } - | MessagesContentBlock::ServerToolUse { id, name, .. } - | MessagesContentBlock::McpToolUse { id, name, .. 
} => { - // Tool use start → OpenAI chunk with tool_calls - Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: 0, - id: Some(id), - call_type: Some("function".to_string()), - function: Some(FunctionCallDelta { - name: Some(name), - arguments: Some("".to_string()), - }), - }]), - }, - None, - None, - )) - } - _ => Err(TransformError::UnsupportedContent( - "Unsupported content block type in stream start".to_string(), - )), - } -} - -/// Convert content delta to OpenAI chunk -fn convert_content_delta( - delta: MessagesContentDelta, -) -> Result { - match delta { - MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: Some(text), - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )), - MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: Some(format!("thinking: {}", thinking)), - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - )), - MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: Some(vec![ToolCallDelta { - index: 0, - id: None, - call_type: None, - function: Some(FunctionCallDelta { - name: None, - arguments: Some(partial_json), - }), - }]), - }, - None, - None, - )), - } -} - -/// Helper to create OpenAI streaming chunk -fn create_openai_chunk( - id: &str, - model: &str, - delta: MessageDelta, - finish_reason: Option, - usage: Option, -) -> ChatCompletionsStreamResponse { - ChatCompletionsStreamResponse { - id: id.to_string(), - object: Some("chat.completion.chunk".to_string()), - created: current_timestamp(), - model: model.to_string(), - 
choices: vec![StreamChoice { - index: 0, - delta, - finish_reason, - logprobs: None, - }], - usage, - system_fingerprint: None, - service_tier: None, - } -} - -/// Helper to create empty OpenAI streaming chunk -fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse { - create_openai_chunk( - "stream", - "unknown", - MessageDelta { - role: None, - content: None, - refusal: None, - function_call: None, - tool_calls: None, - }, - None, - None, - ) -} - -/// Convert Anthropic content blocks to OpenAI message content -fn convert_anthropic_content_to_openai( - content: &[MessagesContentBlock], -) -> Result { - let mut text_parts = Vec::new(); - - for block in content { - match block { - MessagesContentBlock::Text { text, .. } => { - text_parts.push(text.clone()); - } - MessagesContentBlock::Thinking { thinking, .. } => { - text_parts.push(format!("thinking: {}", thinking)); - } - _ => { - // Skip other content types for basic text conversion - continue; - } - } - } - - Ok(MessageContent::Text(text_parts.join("\n"))) -} - -// Stop Reason Conversions -impl Into for MessagesStopReason { - fn into(self) -> FinishReason { - match self { - MessagesStopReason::EndTurn => FinishReason::Stop, - MessagesStopReason::MaxTokens => FinishReason::Length, - MessagesStopReason::StopSequence => FinishReason::Stop, - MessagesStopReason::ToolUse => FinishReason::ToolCalls, - MessagesStopReason::PauseTurn => FinishReason::Stop, - MessagesStopReason::Refusal => FinishReason::ContentFilter, - } - } -} - /// Convert Bedrock Message to OpenAI content and tool calls /// This function extracts text content and tool calls from a Bedrock message fn convert_bedrock_message_to_openai( @@ -785,6 +372,31 @@ fn convert_bedrock_message_to_openai( Ok((content, tool_calls)) } +/// Convert Anthropic content blocks to OpenAI message content +fn convert_anthropic_content_to_openai( + content: &[MessagesContentBlock], +) -> Result { + let mut text_parts = Vec::new(); + + for block in content { + match 
block { + MessagesContentBlock::Text { text, .. } => { + text_parts.push(text.clone()); + } + MessagesContentBlock::Thinking { thinking, .. } => { + text_parts.push(format!("thinking: {}", thinking)); + } + _ => { + // Skip other content types for basic text conversion + continue; + } + } + } + + Ok(MessageContent::Text(text_parts.join("\n"))) +} + + #[cfg(test)] mod tests { use super::*; diff --git a/crates/hermesllm/src/transforms/response_streaming/mod.rs b/crates/hermesllm/src/transforms/response_streaming/mod.rs new file mode 100644 index 00000000..fb06cce3 --- /dev/null +++ b/crates/hermesllm/src/transforms/response_streaming/mod.rs @@ -0,0 +1,2 @@ +pub mod to_anthropic_streaming; +pub mod to_openai_streaming; diff --git a/crates/hermesllm/src/transforms/response_streaming/to_anthropic_streaming.rs b/crates/hermesllm/src/transforms/response_streaming/to_anthropic_streaming.rs new file mode 100644 index 00000000..61939dd7 --- /dev/null +++ b/crates/hermesllm/src/transforms/response_streaming/to_anthropic_streaming.rs @@ -0,0 +1,280 @@ +use crate::apis::amazon_bedrock::{ + ContentBlockDelta, ConverseStreamEvent, StopReason, +}; +use crate::apis::anthropic::{ + MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, + MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage, +}; +use crate::apis::openai::{ ChatCompletionsStreamResponse, ToolCallDelta, +}; +use crate::clients::TransformError; +use serde_json::Value; + +impl TryFrom for MessagesStreamEvent { + type Error = TransformError; + + fn try_from(resp: ChatCompletionsStreamResponse) -> Result { + if resp.choices.is_empty() { + return Ok(MessagesStreamEvent::Ping); + } + + let choice = &resp.choices[0]; + + // Handle final chunk with usage + let has_usage = resp.usage.is_some(); + if let Some(usage) = resp.usage { + if let Some(finish_reason) = &choice.finish_reason { + let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into(); + return 
Ok(MessagesStreamEvent::MessageDelta { + delta: MessagesMessageDelta { + stop_reason: anthropic_stop_reason, + stop_sequence: None, + }, + usage: usage.into(), + }); + } + } + + // NOTE: We do NOT emit MessageStart here anymore! + // The AnthropicMessagesStreamBuffer will inject message_start and content_block_start + // when it sees the first content_block_delta. This solves the problem where OpenAI + // sends both role and content in the same chunk - we can only return one event here, + // so we prioritize the content and let the buffer handle lifecycle events. + + // Handle content delta (even if role is present in the same chunk) + if let Some(content) = &choice.delta.content { + if !content.is_empty() { + return Ok(MessagesStreamEvent::ContentBlockDelta { + index: 0, + delta: MessagesContentDelta::TextDelta { + text: content.clone(), + }, + }); + } + } + + // Handle tool calls + if let Some(tool_calls) = &choice.delta.tool_calls { + return convert_tool_call_deltas(tool_calls.clone()); + } + + // Handle finish reason - generate MessageDelta only (MessageStop comes later) + if let Some(finish_reason) = &choice.finish_reason { + // If we have usage data, it was already handled above + // If not, we need to generate MessageDelta with default usage + if !has_usage { + let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into(); + return Ok(MessagesStreamEvent::MessageDelta { + delta: MessagesMessageDelta { + stop_reason: anthropic_stop_reason, + stop_sequence: None, + }, + usage: MessagesUsage { + input_tokens: 0, + output_tokens: 0, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }, + }); + } + // If usage was already handled above, we don't need to do anything more here + // MessageStop will be handled when [DONE] is encountered + } + + // Default to ping for unhandled cases + Ok(MessagesStreamEvent::Ping) + } +} + +impl Into for MessagesStreamEvent { + fn into(self) -> String { + let transformed_json = 
serde_json::to_string(&self).unwrap_or_default(); + let event_type = match &self { + MessagesStreamEvent::MessageStart { .. } => "message_start", + MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start", + MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta", + MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop", + MessagesStreamEvent::MessageDelta { .. } => "message_delta", + MessagesStreamEvent::MessageStop => "message_stop", + MessagesStreamEvent::Ping => "ping", + }; + + let event = format!("event: {}\n", event_type); + let data = format!("data: {}\n\n", transformed_json); + event + &data + } +} + +impl TryFrom for MessagesStreamEvent { + type Error = TransformError; + + fn try_from(event: ConverseStreamEvent) -> Result { + match event { + // MessageStart - convert to Anthropic MessageStart + ConverseStreamEvent::MessageStart(start_event) => { + let role = match start_event.role { + crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User, + crate::apis::amazon_bedrock::ConversationRole::Assistant => { + MessagesRole::Assistant + } + }; + + Ok(MessagesStreamEvent::MessageStart { + message: MessagesStreamMessage { + id: format!( + "bedrock-stream-{}", + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos() + ), + obj_type: "message".to_string(), + role, + content: vec![], + model: "bedrock-model".to_string(), + stop_reason: None, + stop_sequence: None, + usage: MessagesUsage { + input_tokens: 0, + output_tokens: 0, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }, + }, + }) + } + + // ContentBlockStart - convert to Anthropic ContentBlockStart + ConverseStreamEvent::ContentBlockStart(start_event) => { + // Note: Bedrock sends tool_use_id and name at start, with input coming in subsequent deltas + // Anthropic expects the same pattern, so we initialize with an empty input object + match start_event.start { + 
crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => { + Ok(MessagesStreamEvent::ContentBlockStart { + index: start_event.content_block_index as u32, + content_block: MessagesContentBlock::ToolUse { + id: tool_use.tool_use_id, + name: tool_use.name, + input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas + cache_control: None, + }, + }) + } + } + } + + // ContentBlockDelta - convert to Anthropic ContentBlockDelta + ConverseStreamEvent::ContentBlockDelta(delta_event) => { + let delta = match delta_event.delta { + ContentBlockDelta::Text { text } => MessagesContentDelta::TextDelta { text }, + ContentBlockDelta::ToolUse { tool_use } => { + MessagesContentDelta::InputJsonDelta { + partial_json: tool_use.input, + } + } + }; + + Ok(MessagesStreamEvent::ContentBlockDelta { + index: delta_event.content_block_index as u32, + delta, + }) + } + + // ContentBlockStop - convert to Anthropic ContentBlockStop + ConverseStreamEvent::ContentBlockStop(stop_event) => { + Ok(MessagesStreamEvent::ContentBlockStop { + index: stop_event.content_block_index as u32, + }) + } + + // MessageStop - convert to Anthropic MessageDelta with stop reason + MessageStop + ConverseStreamEvent::MessageStop(stop_event) => { + let anthropic_stop_reason = match stop_event.stop_reason { + StopReason::EndTurn => MessagesStopReason::EndTurn, + StopReason::ToolUse => MessagesStopReason::ToolUse, + StopReason::MaxTokens => MessagesStopReason::MaxTokens, + StopReason::StopSequence => MessagesStopReason::EndTurn, + StopReason::GuardrailIntervened => MessagesStopReason::Refusal, + StopReason::ContentFiltered => MessagesStopReason::Refusal, + }; + + // Return MessageDelta (MessageStop will be sent separately by the streaming handler) + Ok(MessagesStreamEvent::MessageDelta { + delta: MessagesMessageDelta { + stop_reason: anthropic_stop_reason, + stop_sequence: None, + }, + usage: MessagesUsage { + input_tokens: 0, + output_tokens: 0, + cache_creation_input_tokens: 
None, + cache_read_input_tokens: None, + }, + }) + } + + // Metadata - convert usage information to MessageDelta + ConverseStreamEvent::Metadata(metadata_event) => { + Ok(MessagesStreamEvent::MessageDelta { + delta: MessagesMessageDelta { + stop_reason: MessagesStopReason::EndTurn, + stop_sequence: None, + }, + usage: MessagesUsage { + input_tokens: metadata_event.usage.input_tokens, + output_tokens: metadata_event.usage.output_tokens, + cache_creation_input_tokens: metadata_event.usage.cache_write_input_tokens, + cache_read_input_tokens: metadata_event.usage.cache_read_input_tokens, + }, + }) + } + + // Exception events - convert to Ping (could be enhanced to return error events) + ConverseStreamEvent::InternalServerException(_) + | ConverseStreamEvent::ModelStreamErrorException(_) + | ConverseStreamEvent::ServiceUnavailableException(_) + | ConverseStreamEvent::ThrottlingException(_) + | ConverseStreamEvent::ValidationException(_) => { + // TODO: Consider adding proper error handling/events + Ok(MessagesStreamEvent::Ping) + } + } + } +} + +/// Convert tool call deltas to Anthropic stream events +fn convert_tool_call_deltas( + tool_calls: Vec, +) -> Result { + for tool_call in tool_calls { + if let Some(id) = &tool_call.id { + // Tool call start + if let Some(function) = &tool_call.function { + if let Some(name) = &function.name { + return Ok(MessagesStreamEvent::ContentBlockStart { + index: tool_call.index, + content_block: MessagesContentBlock::ToolUse { + id: id.clone(), + name: name.clone(), + input: Value::Object(serde_json::Map::new()), + cache_control: None, + }, + }); + } + } + } else if let Some(function) = &tool_call.function { + if let Some(arguments) = &function.arguments { + // Tool arguments delta + return Ok(MessagesStreamEvent::ContentBlockDelta { + index: tool_call.index, + delta: MessagesContentDelta::InputJsonDelta { + partial_json: arguments.clone(), + }, + }); + } + } + } + + // Fallback to ping if no valid tool call found + 
Ok(MessagesStreamEvent::Ping) +} diff --git a/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs b/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs new file mode 100644 index 00000000..9e2f083e --- /dev/null +++ b/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs @@ -0,0 +1,527 @@ +use crate::apis::amazon_bedrock::{ ConverseStreamEvent, StopReason}; +use crate::apis::anthropic::{ + MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesStreamEvent}; +use crate::apis::openai::{ ChatCompletionsStreamResponse,FinishReason, + FunctionCallDelta, MessageDelta, Role, StreamChoice, ToolCallDelta, Usage, +}; +use crate::apis::openai_responses::ResponsesAPIStreamEvent; + +use crate::clients::TransformError; +use crate::transforms::lib::*; + +// ============================================================================ +// PROVIDER STREAMING TRANSFORMATIONS TO OPENAI FORMAT +// ============================================================================ +// +// This module handles business logic for converting streaming events from +// various providers (Anthropic, Bedrock, etc.) into OpenAI's ChatCompletions format. 
+// +// # Architecture Separation +// +// **Provider Transformations** (this module): +// - Business logic for converting between provider formats +// - Uses Rust traits (TryFrom, Into) for type-safe conversions +// - Stateless event-by-event transformation +// - Example: MessagesStreamEvent → ChatCompletionsStreamResponse +// +// **Wire Format Buffering** (`apis/streaming_shapes/`): +// - SSE protocol handling (data:, event: lines) +// - State accumulation and lifecycle management +// - Buffering for stateful APIs (v1/responses) +// - Example: ChatCompletionsToResponsesTransformer +// +// # Flow +// +// ```text +// Anthropic Event → [Provider Transform] → OpenAI Event → [Wire Buffer] → SSE Wire Format +// (business) (this module) (protocol) (streaming_shapes) (network) +// ``` +// +// ============================================================================ + +impl TryFrom for ChatCompletionsStreamResponse { + type Error = TransformError; + + fn try_from(event: MessagesStreamEvent) -> Result { + match event { + MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk( + &message.id, + &message.model, + MessageDelta { + role: Some(Role::Assistant), + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), + + MessagesStreamEvent::ContentBlockStart { content_block, .. } => { + convert_content_block_start(content_block) + } + + MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta), + + MessagesStreamEvent::ContentBlockStop { .. 
} => Ok(create_empty_openai_chunk()), + + MessagesStreamEvent::MessageDelta { delta, usage } => { + let finish_reason: Option = Some(delta.stop_reason.into()); + let openai_usage: Option = Some(usage.into()); + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + finish_reason, + openai_usage, + )) + } + + MessagesStreamEvent::MessageStop => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + Some(FinishReason::Stop), + None, + )), + + MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse { + id: "stream".to_string(), + object: Some("chat.completion.chunk".to_string()), + created: current_timestamp(), + model: "unknown".to_string(), + choices: vec![], + usage: None, + system_fingerprint: None, + service_tier: None, + }), + } + } +} + +impl TryFrom for ChatCompletionsStreamResponse { + type Error = TransformError; + + fn try_from(event: ConverseStreamEvent) -> Result { + match event { + ConverseStreamEvent::MessageStart(start_event) => { + let role = match start_event.role { + crate::apis::amazon_bedrock::ConversationRole::User => Role::User, + crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant, + }; + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: Some(role), + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )) + } + + ConverseStreamEvent::ContentBlockStart(start_event) => { + use crate::apis::amazon_bedrock::ContentBlockStart; + + match start_event.start { + ContentBlockStart::ToolUse { tool_use } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: start_event.content_block_index as u32, + id: 
Some(tool_use.tool_use_id), + call_type: Some("function".to_string()), + function: Some(FunctionCallDelta { + name: Some(tool_use.name), + arguments: Some("".to_string()), + }), + }]), + }, + None, + None, + )), + } + } + + ConverseStreamEvent::ContentBlockDelta(delta_event) => { + use crate::apis::amazon_bedrock::ContentBlockDelta; + + match delta_event.delta { + ContentBlockDelta::Text { text } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: Some(text), + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), + ContentBlockDelta::ToolUse { tool_use } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: delta_event.content_block_index as u32, + id: None, + call_type: None, + function: Some(FunctionCallDelta { + name: None, + arguments: Some(tool_use.input), + }), + }]), + }, + None, + None, + )), + } + } + + ConverseStreamEvent::ContentBlockStop(_) => Ok(create_empty_openai_chunk()), + + ConverseStreamEvent::MessageStop(stop_event) => { + let finish_reason = match stop_event.stop_reason { + StopReason::EndTurn => FinishReason::Stop, + StopReason::ToolUse => FinishReason::ToolCalls, + StopReason::MaxTokens => FinishReason::Length, + StopReason::StopSequence => FinishReason::Stop, + StopReason::GuardrailIntervened => FinishReason::ContentFilter, + StopReason::ContentFiltered => FinishReason::ContentFilter, + }; + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + Some(finish_reason), + None, + )) + } + + ConverseStreamEvent::Metadata(metadata_event) => { + let usage = Usage { + prompt_tokens: metadata_event.usage.input_tokens, + completion_tokens: metadata_event.usage.output_tokens, + total_tokens: metadata_event.usage.total_tokens, + 
prompt_tokens_details: None, + completion_tokens_details: None, + }; + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + Some(usage), + )) + } + + // Error events - convert to empty chunks (errors should be handled elsewhere) + ConverseStreamEvent::InternalServerException(_) + | ConverseStreamEvent::ModelStreamErrorException(_) + | ConverseStreamEvent::ServiceUnavailableException(_) + | ConverseStreamEvent::ThrottlingException(_) + | ConverseStreamEvent::ValidationException(_) => Ok(create_empty_openai_chunk()), + } + } +} + +/// Convert content block start to OpenAI chunk +fn convert_content_block_start( + content_block: MessagesContentBlock, +) -> Result { + match content_block { + MessagesContentBlock::Text { .. } => { + // No immediate output for text block start + Ok(create_empty_openai_chunk()) + } + MessagesContentBlock::ToolUse { id, name, .. } + | MessagesContentBlock::ServerToolUse { id, name, .. } + | MessagesContentBlock::McpToolUse { id, name, .. 
} => { + // Tool use start → OpenAI chunk with tool_calls + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: 0, + id: Some(id), + call_type: Some("function".to_string()), + function: Some(FunctionCallDelta { + name: Some(name), + arguments: Some("".to_string()), + }), + }]), + }, + None, + None, + )) + } + _ => Err(TransformError::UnsupportedContent( + "Unsupported content block type in stream start".to_string(), + )), + } +} + +/// Convert content delta to OpenAI chunk +fn convert_content_delta( + delta: MessagesContentDelta, +) -> Result { + match delta { + MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: Some(text), + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), + MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: Some(format!("thinking: {}", thinking)), + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )), + MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: 0, + id: None, + call_type: None, + function: Some(FunctionCallDelta { + name: None, + arguments: Some(partial_json), + }), + }]), + }, + None, + None, + )), + } +} + +/// Helper to create OpenAI streaming chunk +fn create_openai_chunk( + id: &str, + model: &str, + delta: MessageDelta, + finish_reason: Option, + usage: Option, +) -> ChatCompletionsStreamResponse { + ChatCompletionsStreamResponse { + id: id.to_string(), + object: Some("chat.completion.chunk".to_string()), + created: current_timestamp(), + model: model.to_string(), + 
choices: vec![StreamChoice { + index: 0, + delta, + finish_reason, + logprobs: None, + }], + usage, + system_fingerprint: None, + service_tier: None, + } +} + +/// Helper to create empty OpenAI streaming chunk +fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse { + create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + ) +} + +// Stop Reason Conversions +impl Into for MessagesStopReason { + fn into(self) -> FinishReason { + match self { + MessagesStopReason::EndTurn => FinishReason::Stop, + MessagesStopReason::MaxTokens => FinishReason::Length, + MessagesStopReason::StopSequence => FinishReason::Stop, + MessagesStopReason::ToolUse => FinishReason::ToolCalls, + MessagesStopReason::PauseTurn => FinishReason::Stop, + MessagesStopReason::Refusal => FinishReason::ContentFilter, + } + } +} + +impl TryFrom for ResponsesAPIStreamEvent { + type Error = TransformError; + + fn try_from(chunk: ChatCompletionsStreamResponse) -> Result { + // Stateless conversion - just extract the delta information + // The buffer will manage state, item IDs, and sequence numbers + + // Extract first choice if available + if let Some(choice) = chunk.choices.first() { + let delta = &choice.delta; + + // Tool call with function name and/or arguments + if let Some(tool_calls) = &delta.tool_calls { + if let Some(tool_call) = tool_calls.first() { + // Extract call_id and name if available (metadata from initial event) + let call_id = tool_call.id.clone(); + let function_name = tool_call.function.as_ref() + .and_then(|f| f.name.clone()); + + // Check if we have function metadata (name, id) + if let Some(function) = &tool_call.function { + // If we have arguments delta, return that + if let Some(args) = &function.arguments { + return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { + output_index: choice.index as i32, + item_id: "".to_string(), // 
Buffer will fill this + delta: args.clone(), + sequence_number: 0, // Buffer will fill this + call_id, + name: function_name, + }); + } + + // If we have function name but no arguments yet (initial tool call event) + // Return an empty arguments delta so the buffer knows to create the item + if function.name.is_some() { + return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { + output_index: choice.index as i32, + item_id: "".to_string(), // Buffer will fill this + delta: "".to_string(), // Empty delta signals this is the initial event + sequence_number: 0, // Buffer will fill this + call_id, + name: function_name, + }); + } + } + } + } + + // Text content delta + if let Some(content) = &delta.content { + if !content.is_empty() { + return Ok(ResponsesAPIStreamEvent::ResponseOutputTextDelta { + item_id: "".to_string(), // Buffer will fill this + output_index: choice.index as i32, + content_index: 0, + delta: content.clone(), + logprobs: vec![], + obfuscation: None, + sequence_number: 0, // Buffer will fill this + }); + } + } + + // Handle finish_reason - this is a completion signal + // Return an empty delta that the buffer can use to detect completion + if choice.finish_reason.is_some() { + // Return a minimal text delta to signal completion + // The buffer will handle the finish_reason and generate response.completed + return Ok(ResponsesAPIStreamEvent::ResponseOutputTextDelta { + item_id: "".to_string(), // Buffer will fill this + output_index: choice.index as i32, + content_index: 0, + delta: "".to_string(), // Empty delta signals completion + logprobs: vec![], + obfuscation: None, + sequence_number: 0, // Buffer will fill this + }); + } + + // Empty delta with role only (common at stream start) + if delta.role.is_some() { + // This is typically the first chunk establishing the assistant role + // Return an empty text delta that the buffer can use to initialize state + return Ok(ResponsesAPIStreamEvent::ResponseOutputTextDelta { + item_id: 
"".to_string(), + output_index: choice.index as i32, + content_index: 0, + delta: "".to_string(), + logprobs: vec![], + obfuscation: None, + sequence_number: 0, + }); + } + } + + // Empty chunk or no convertible content (e.g., keep-alive chunks with delta: {}) + // These are valid in OpenAI streaming and should be silently ignored + // Return error so the caller can skip these chunks without warnings + Err(TransformError::UnsupportedConversion( + "Empty or keep-alive chunk with no convertible content".to_string(), + )) + } +} diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index e1594cc7..13fbcde4 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -22,10 +22,12 @@ use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; -use hermesllm::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; use hermesllm::apis::anthropic::{MessagesContentBlock, MessagesStreamEvent}; -use hermesllm::apis::sse::{SseEvent, SseStreamIter}; -use hermesllm::clients::endpoints::SupportedAPIsFromClients; +use hermesllm::apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; +use hermesllm::apis::streaming_shapes::sse::{ + SseEvent, SseStreamBuffer, SseStreamBufferTrait, SseStreamIter, +}; +use hermesllm::clients::endpoints::SupportedAPIsFromClient; use hermesllm::providers::response::ProviderResponse; use hermesllm::{ DecodedFrame, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType, @@ -38,7 +40,7 @@ pub struct StreamContext { streaming_response: bool, response_tokens: usize, /// The API that is requested by the client (before compatibility mapping) - client_api: Option, + client_api: Option, /// The API that should be used for the upstream provider (after compatibility mapping) resolved_api: Option, 
llm_providers: Rc, @@ -56,6 +58,7 @@ pub struct StreamContext { binary_frame_decoder: Option>, http_method: Option, http_protocol: Option, + sse_buffer: Option, } impl StreamContext { @@ -87,6 +90,7 @@ impl StreamContext { binary_frame_decoder: None, http_method: None, http_protocol: None, + sse_buffer: None, } } @@ -477,7 +481,17 @@ impl StreamContext { } }; - let mut response_buffer = Vec::new(); + // Initialize SSE buffer if not present + if self.sse_buffer.is_none() { + self.sse_buffer = match SseStreamBuffer::try_from((&client_api, &upstream_api)) + { + Ok(buffer) => Some(buffer), + Err(e) => { + warn!("Failed to create SSE buffer: {}", e); + return Err(Action::Continue); + } + }; + } // Process each SSE event for sse_event in sse_iter { @@ -528,12 +542,15 @@ impl StreamContext { } } - // Add transformed event to response buffer - let bytes: Vec = transformed_event.into(); - response_buffer.extend_from_slice(&bytes); + // Add transformed event to buffer (buffer may inject lifecycle events) + self.sse_buffer + .as_mut() + .unwrap() + .add_transformed_event(transformed_event); } - Ok(response_buffer) + // Get accumulated bytes from buffer and return + Ok(self.sse_buffer.as_mut().unwrap().into_bytes()) } None => { warn!("Missing client_api for non-streaming response"); @@ -545,7 +562,7 @@ impl StreamContext { fn handle_bedrock_binary_stream( &mut self, body: &[u8], - client_api: &SupportedAPIsFromClients, + client_api: &SupportedAPIsFromClient, upstream_api: &SupportedUpstreamAPIs, ) -> Result, Action> { // Initialize decoder if not present @@ -783,14 +800,14 @@ impl HttpContext for StreamContext { self.select_llm_provider(); // Check if this is a supported API endpoint - if SupportedAPIsFromClients::from_endpoint(&request_path).is_none() { + if SupportedAPIsFromClient::from_endpoint(&request_path).is_none() { self.send_http_response(404, vec![], Some(b"Unsupported endpoint")); return Action::Continue; } // Get the SupportedApi for routing decisions - let 
supported_api: Option = - SupportedAPIsFromClients::from_endpoint(&request_path); + let supported_api: Option = + SupportedAPIsFromClient::from_endpoint(&request_path); self.client_api = supported_api; // Debug: log provider, client API, resolved API, and request path @@ -1133,8 +1150,9 @@ impl HttpContext for StreamContext { } match self.client_api { - Some(SupportedAPIsFromClients::OpenAIChatCompletions(_)) => {} - Some(SupportedAPIsFromClients::AnthropicMessagesAPI(_)) => {} + Some(SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {} + Some(SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {} + Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {} _ => { let api_info = match &self.client_api { Some(api) => format!("{}", api), diff --git a/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml b/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml index b9dcab81..0aaaa537 100644 --- a/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml +++ b/demos/use_cases/model_alias_routing/arch_config_with_aliases.yaml @@ -47,6 +47,9 @@ llm_providers: - model: ollama/llama3.1 base_url: http://host.docker.internal:11434 + # Grok (xAI) Models + - model: xai/grok-4-0709 + access_key: $GROK_API_KEY # Model aliases - friendly names that map to actual provider names model_aliases: @@ -83,5 +86,9 @@ model_aliases: coding-model: target: us.amazon.nova-premier-v1:0 + # Alias for grok testing + arch.grok.v1: + target: grok-4-0709 + tracing: random_sampling: 100 diff --git a/tests/e2e/run_e2e_tests.sh b/tests/e2e/run_e2e_tests.sh index 856b16ab..f60f79bc 100644 --- a/tests/e2e/run_e2e_tests.sh +++ b/tests/e2e/run_e2e_tests.sh @@ -65,6 +65,10 @@ log running e2e tests for model alias routing log ======================================== poetry run pytest test_model_alias_routing.py +log running e2e tests for openai responses api client +log ======================================== +poetry run pytest test_openai_responses_api_client.py + log 
shutting down the weather_forecast demo log ======================================= cd ../../demos/samples_python/weather_forecast diff --git a/tests/e2e/test_openai_responses_api_client.py b/tests/e2e/test_openai_responses_api_client.py new file mode 100644 index 00000000..c01c6454 --- /dev/null +++ b/tests/e2e/test_openai_responses_api_client.py @@ -0,0 +1,327 @@ +import openai +import pytest +import os +import logging +import sys + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +logger = logging.getLogger(__name__) + +LLM_GATEWAY_ENDPOINT = os.getenv( + "LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1/chat/completions" +) + + +# ----------------------- +# v1/responses API tests +# ----------------------- +def test_openai_responses_api_non_streaming_passthrough(): + """Build a v1/responses API request (pass-through) and ensure gateway accepts it""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + # Simple responses API request using a direct model (pass-through) + resp = client.responses.create( + model="gpt-4o", input="Hello via responses passthrough" + ) + + # Print the response content - handle both responses format and chat completions format + print(f"\n{'='*80}") + print(f"Model: {resp.model}") + print(f"Output: {resp.output_text}") + print(f"{'='*80}\n") + + # Minimal sanity checks + assert resp is not None + assert ( + getattr(resp, "id", None) is not None + or getattr(resp, "output", None) is not None + ) + + +def test_openai_responses_api_with_streaming_passthrough(): + """Build a v1/responses API streaming request (pass-through) and ensure gateway accepts it""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + # Simple streaming 
responses API request using a direct model (pass-through) + stream = client.responses.create( + model="gpt-4o", + input="Write a short haiku about coding", + stream=True, + ) + + # Collect streamed content using the official Responses API streaming shape + text_chunks = [] + final_message = None + + for event in stream: + # The Python SDK surfaces a high-level Responses streaming interface. + # We rely on its typed helpers instead of digging into model_extra. + if getattr(event, "type", None) == "response.output_text.delta" and getattr( + event, "delta", None + ): + # Each delta contains a text fragment + text_chunks.append(event.delta) + + # Track the final response message if provided by the SDK + if getattr(event, "type", None) == "response.completed" and getattr( + event, "response", None + ): + final_message = event.response + + full_content = "".join(text_chunks) + + # Print the streaming response + print(f"\n{'='*80}") + print( + f"Model: {getattr(final_message, 'model', 'unknown') if final_message else 'unknown'}" + ) + print(f"Streamed Output: {full_content}") + print(f"{'='*80}\n") + + assert len(text_chunks) > 0, "Should have received streaming text deltas" + assert len(full_content) > 0, "Should have received content" + + +def test_openai_responses_api_non_streaming_with_tools_passthrough(): + """Responses API with a function/tool definition (pass-through)""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0) + + # Define a simple tool/function for the Responses API + tools = [ + { + "type": "function", + "name": "echo_tool", + "description": "Echo back the provided input", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ] + + resp = client.responses.create( + model="gpt-5", + input="Call the echo tool", + tools=tools, + ) + + assert resp is not None + assert ( + getattr(resp, 
"id", None) is not None + or getattr(resp, "output", None) is not None + ) + + +def test_openai_responses_api_with_streaming_with_tools_passthrough(): + """Responses API with a function/tool definition (streaming, pass-through)""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0) + + tools = [ + { + "type": "function", + "name": "echo_tool", + "description": "Echo back the provided input", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ] + + stream = client.responses.create( + model="gpt-5", + input="Call the echo tool", + tools=tools, + stream=True, + ) + + text_chunks = [] + tool_calls = [] + + for event in stream: + etype = getattr(event, "type", None) + + # Collect streamed text output + if etype == "response.output_text.delta" and getattr(event, "delta", None): + text_chunks.append(event.delta) + + # Collect streamed tool call arguments + if etype == "response.function_call_arguments.delta" and getattr( + event, "delta", None + ): + tool_calls.append(event.delta) + + full_text = "".join(text_chunks) + + print(f"\n{'='*80}") + print("Responses tools streaming test") + print(f"Streamed text: {full_text}") + print(f"Tool call argument chunks: {len(tool_calls)}") + print(f"{'='*80}\n") + + # We expect either streamed text output or streamed tool-call arguments + assert ( + full_text or tool_calls + ), "Expected streamed text or tool call argument deltas from Responses tools stream" + + +def test_openai_responses_api_non_streaming_upstream_chat_completions(): + """Send a v1/responses request using the grok alias to verify translation/routing""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + resp = client.responses.create( + model="arch.grok.v1", input="Hello, translate this via grok alias" + 
) + + # Print the response content - handle both responses format and chat completions format + print(f"\n{'='*80}") + print(f"Model: {resp.model}") + print(f"Output: {resp.output_text}") + print(f"{'='*80}\n") + + assert resp is not None + assert resp.id is not None + + +def test_openai_responses_api_with_streaming_upstream_chat_completions(): + """Build a v1/responses API streaming request (pass-through) and ensure gateway accepts it""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + # Simple streaming responses API request using a direct model (pass-through) + stream = client.responses.create( + model="arch.grok.v1", + input="Write a short haiku about coding", + stream=True, + ) + + # Collect streamed content using the official Responses API streaming shape + text_chunks = [] + final_message = None + + for event in stream: + # The Python SDK surfaces a high-level Responses streaming interface. + # We rely on its typed helpers instead of digging into model_extra. 
+ if getattr(event, "type", None) == "response.output_text.delta" and getattr( + event, "delta", None + ): + # Each delta contains a text fragment + text_chunks.append(event.delta) + + # Track the final response message if provided by the SDK + if getattr(event, "type", None) == "response.completed" and getattr( + event, "response", None + ): + final_message = event.response + + full_content = "".join(text_chunks) + + # Print the streaming response + print(f"\n{'='*80}") + print( + f"Model: {getattr(final_message, 'model', 'unknown') if final_message else 'unknown'}" + ) + print(f"Streamed Output: {full_content}") + print(f"{'='*80}\n") + + assert len(text_chunks) > 0, "Should have received streaming text deltas" + assert len(full_content) > 0, "Should have received content" + + +def test_openai_responses_api_non_streaming_with_tools_upstream_chat_completions(): + """Responses API wioutputling routed to grok via alias""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + tools = [ + { + "type": "function", + "name": "echo_tool", + "description": "Echo back the provided input", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ] + + resp = client.responses.create( + model="arch.grok.v1", + input="Call the echo tool", + tools=tools, + ) + + assert resp.id is not None + + print(f"\n{'='*80}") + print(f"Model: {resp.model}") + print(f"Output: {resp.output_text}") + print(f"{'='*80}\n") + + +def test_openai_responses_api_streaming_with_tools_upstream_chat_completions(): + """Responses API with a function/tool definition (streaming, pass-through)""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0) + + tools = [ + { + "type": "function", + "name": "echo_tool", + "description": "Echo back the provided 
input", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ] + + stream = client.responses.create( + model="arch.grok.v1", + input="Call the echo tool", + tools=tools, + stream=True, + ) + + text_chunks = [] + tool_calls = [] + + for event in stream: + etype = getattr(event, "type", None) + + # Collect streamed text output + if etype == "response.output_text.delta" and getattr(event, "delta", None): + text_chunks.append(event.delta) + + # Collect streamed tool call arguments + if etype == "response.function_call_arguments.delta" and getattr( + event, "delta", None + ): + tool_calls.append(event.delta) + + full_text = "".join(text_chunks) + + print(f"\n{'='*80}") + print("Responses tools streaming test") + print(f"Streamed text: {full_text}") + print(f"Tool call argument chunks: {len(tool_calls)}") + print(f"{'='*80}\n") + + # We expect either streamed text output or streamed tool-call arguments + assert ( + full_text or tool_calls + ), "Expected streamed text or tool call argument deltas from Responses tools stream"