diff --git a/crates/Cargo.lock b/crates/Cargo.lock index fae883a5..dba20614 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -777,6 +777,7 @@ name = "hermesllm" version = "0.1.0" dependencies = [ "aws-smithy-eventstream", + "bytes", "serde", "serde_json", "serde_with", @@ -1230,6 +1231,7 @@ name = "llm_gateway" version = "0.1.0" dependencies = [ "acap", + "bytes", "common", "derivative", "governor", @@ -2111,9 +2113,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.14.0" +version = "2.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" dependencies = [ "core-foundation-sys", "libc", diff --git a/crates/hermesllm/Cargo.toml b/crates/hermesllm/Cargo.toml index 7ba48ea3..ab2390bf 100644 --- a/crates/hermesllm/Cargo.toml +++ b/crates/hermesllm/Cargo.toml @@ -9,3 +9,4 @@ serde_json = "1.0.140" serde_with = {version = "3.12.0", features = ["base64"]} thiserror = "2.0.12" aws-smithy-eventstream = "0.60" +bytes = "1.10" diff --git a/crates/hermesllm/src/apis/amazon_bedrock.rs b/crates/hermesllm/src/apis/amazon_bedrock.rs index 64bf6e74..0c4eb262 100644 --- a/crates/hermesllm/src/apis/amazon_bedrock.rs +++ b/crates/hermesllm/src/apis/amazon_bedrock.rs @@ -7,6 +7,7 @@ use std::collections::HashMap; use super::ApiDefinition; use crate::providers::request::{ProviderRequest, ProviderRequestError}; +use crate::providers::response::ProviderStreamResponse; // ============================================================================ // AMAZON BEDROCK CONVERSE API ENUMERATION @@ -685,7 +686,7 @@ pub struct MessageStartEvent { pub struct ContentBlockStartEvent { /// Content block index #[serde(rename = "contentBlockIndex")] - pub content_block_index: u32, + pub content_block_index: i32, /// Start information pub start: ContentBlockStart, } @@ -707,18 +708,16 @@ pub enum ContentBlockStart { pub struct ContentBlockDeltaEvent { /// Content block index #[serde(rename = "contentBlockIndex")] - pub content_block_index: u32, + pub content_block_index: i32, /// Delta information pub delta: ContentBlockDelta, } /// Content block delta information #[derive(Serialize, Deserialize, Debug, Clone)] -#[serde(tag = "type")] +#[serde(untagged)] pub enum ContentBlockDelta { - #[serde(rename = "text")] Text { text: String }, - #[serde(rename = "toolUse")] ToolUse { input: String }, } @@ -727,7 +726,7 @@ pub enum ContentBlockDelta { pub struct ContentBlockStopEvent { /// Content block index #[serde(rename = "contentBlockIndex")] - pub content_block_index: u32, + pub content_block_index: i32, } /// Message stop event @@ -867,6 +866,198 @@ impl crate::providers::response::TokenUsage for BedrockTokenUsage { } } +// ============================================================================ +// EVENT STREAM PARSING +// ============================================================================ + +/// Convert from aws-smithy-eventstream DecodedFrame to ConverseStreamEvent +impl TryFrom<&aws_smithy_eventstream::frame::DecodedFrame> for ConverseStreamEvent { + type Error = BedrockError; + + fn try_from(frame: &aws_smithy_eventstream::frame::DecodedFrame) -> Result { + // Only process Complete frames, skip Incomplete + let message = match frame { + aws_smithy_eventstream::frame::DecodedFrame::Complete(msg) => msg, + aws_smithy_eventstream::frame::DecodedFrame::Incomplete => { + return Err(BedrockError::Validation { + message: "Expected Complete frame, got Incomplete".to_string(), + }) + } + }; + + // Extract the :event-type and :message-type headers + let event_type = message + .headers() + .iter() + .find(|h| h.name().as_str() == ":event-type") + .and_then(|h| h.value().as_string().ok()) + .ok_or_else(|| BedrockError::Validation { + message: "Missing :event-type header".to_string(), + })? + .as_str(); + + let message_type = message + .headers() + .iter() + .find(|h| h.name().as_str() == ":message-type") + .and_then(|h| h.value().as_string().ok()) + .ok_or_else(|| BedrockError::Validation { + message: "Missing :message-type header".to_string(), + })? + .as_str(); + + let payload = message.payload(); + + // Parse the event based on message type and event type + match message_type { + "event" => match event_type { + "messageStart" => { + let event: MessageStartEvent = + serde_json::from_slice(payload).map_err(BedrockError::Serialization)?; + Ok(ConverseStreamEvent::MessageStart(event)) + } + "contentBlockStart" => { + let event: ContentBlockStartEvent = + serde_json::from_slice(payload).map_err(BedrockError::Serialization)?; + Ok(ConverseStreamEvent::ContentBlockStart(event)) + } + "contentBlockDelta" => { + let event: ContentBlockDeltaEvent = + serde_json::from_slice(payload).map_err(BedrockError::Serialization)?; + Ok(ConverseStreamEvent::ContentBlockDelta(event)) + } + "contentBlockStop" => { + let event: ContentBlockStopEvent = + serde_json::from_slice(payload).map_err(BedrockError::Serialization)?; + Ok(ConverseStreamEvent::ContentBlockStop(event)) + } + "messageStop" => { + let event: MessageStopEvent = + serde_json::from_slice(payload).map_err(BedrockError::Serialization)?; + Ok(ConverseStreamEvent::MessageStop(event)) + } + "metadata" => { + let event: ConverseStreamMetadataEvent = + serde_json::from_slice(payload).map_err(BedrockError::Serialization)?; + Ok(ConverseStreamEvent::Metadata(event)) + } + unknown => Err(BedrockError::Validation { + message: format!("Unknown event type: {}", unknown), + }), + }, + "exception" => match event_type { + "internalServerException" => { + let exception: BedrockException = + serde_json::from_slice(payload).map_err(BedrockError::Serialization)?; + Ok(ConverseStreamEvent::InternalServerException(exception)) + } + "modelStreamErrorException" => { + let exception: BedrockException = + serde_json::from_slice(payload).map_err(BedrockError::Serialization)?; + Ok(ConverseStreamEvent::ModelStreamErrorException(exception)) + } + "serviceUnavailableException" => { + let exception: BedrockException = + serde_json::from_slice(payload).map_err(BedrockError::Serialization)?; + Ok(ConverseStreamEvent::ServiceUnavailableException(exception)) + } + "throttlingException" => { + let exception: BedrockException = + serde_json::from_slice(payload).map_err(BedrockError::Serialization)?; + Ok(ConverseStreamEvent::ThrottlingException(exception)) + } + "validationException" => { + let exception: BedrockException = + serde_json::from_slice(payload).map_err(BedrockError::Serialization)?; + Ok(ConverseStreamEvent::ValidationException(exception)) + } + unknown => Err(BedrockError::Validation { + message: format!("Unknown exception type: {}", unknown), + }), + }, + unknown => Err(BedrockError::Validation { + message: format!("Unknown message type: {}", unknown), + }), + } + } +} + +impl Into for ConverseStreamEvent { + fn into(self) -> String { + let transformed_json = serde_json::to_string(&self).unwrap_or_default(); + let event_type = match &self { + ConverseStreamEvent::MessageStart { .. } => "message_start", + ConverseStreamEvent::ContentBlockStart { .. } => "content_block_start", + ConverseStreamEvent::ContentBlockDelta { .. } => "content_block_delta", + ConverseStreamEvent::ContentBlockStop { .. } => "content_block_stop", + ConverseStreamEvent::MessageStop { .. } => "message_stop", + ConverseStreamEvent::Metadata { .. } => "metadata", + ConverseStreamEvent::InternalServerException { .. } => "internal_server_exception", + ConverseStreamEvent::ModelStreamErrorException { .. } => "model_stream_error_exception", + ConverseStreamEvent::ServiceUnavailableException { .. } => "service_unavailable_exception", + ConverseStreamEvent::ThrottlingException { .. } => "throttling_exception", + ConverseStreamEvent::ValidationException { .. } => "validation_exception", + }; + + let event = format!("event: {}\n", event_type); + let data = format!("data: {}\n\n", transformed_json); + event + &data + } +} + + +// Implement ProviderStreamResponse for ConverseStreamEvent +impl ProviderStreamResponse for ConverseStreamEvent { + fn content_delta(&self) -> Option<&str> { + match self { + ConverseStreamEvent::ContentBlockDelta(event) => { + match &event.delta { + ContentBlockDelta::Text { text } => Some(text), + ContentBlockDelta::ToolUse { .. } => None, + } + } + _ => None, + } + } + + fn is_final(&self) -> bool { + matches!(self, ConverseStreamEvent::MessageStop(_)) + } + + fn role(&self) -> Option<&str> { + match self { + ConverseStreamEvent::MessageStart(event) => Some(event.role.as_str()), + _ => None, + } + } + + fn event_type(&self) -> Option<&str> { + Some(match self { + ConverseStreamEvent::MessageStart(_) => "messageStart", + ConverseStreamEvent::ContentBlockStart(_) => "contentBlockStart", + ConverseStreamEvent::ContentBlockDelta(_) => "contentBlockDelta", + ConverseStreamEvent::ContentBlockStop(_) => "contentBlockStop", + ConverseStreamEvent::MessageStop(_) => "messageStop", + ConverseStreamEvent::Metadata(_) => "metadata", + ConverseStreamEvent::InternalServerException(_) => "internalServerException", + ConverseStreamEvent::ModelStreamErrorException(_) => "modelStreamErrorException", + ConverseStreamEvent::ServiceUnavailableException(_) => "serviceUnavailableException", + ConverseStreamEvent::ThrottlingException(_) => "throttlingException", + ConverseStreamEvent::ValidationException(_) => "validationException", + }) + } +} + +// Add as_str helper for ConversationRole +impl ConversationRole { + pub fn as_str(&self) -> &'static str { + match self { + ConversationRole::User => "user", + ConversationRole::Assistant => "assistant", + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 308cf1b7..2ae58b67 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -7,8 +7,9 @@ pub mod clients; pub mod transforms; // Re-export important types and traits pub use providers::request::{ProviderRequestType, ProviderRequest, ProviderRequestError}; -pub use providers::response::{ProviderResponseType, ProviderStreamResponseType, ProviderResponse, ProviderStreamResponse, ProviderResponseError, TokenUsage, SseEvent, SseStreamIter}; +pub use providers::response::{ProviderResponseType, ProviderStreamResponseType, ProviderResponse, ProviderStreamResponse, ProviderResponseError, TokenUsage, SseEvent, SseStreamIter, BedrockBinaryFrameDecoder}; pub use providers::id::ProviderId; +pub use aws_smithy_eventstream::frame::DecodedFrame; //TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings @@ -77,4 +78,133 @@ mod tests { let final_event = streaming_iter.next(); assert!(final_event.is_none()); // Should be None because iterator stops at [DONE] } + + /// Test AWS Event Stream decoding for Bedrock ConverseStream responses. + /// + /// This test demonstrates how to: + /// 1. Use MessageFrameDecoder to decode AWS Event Stream frames + /// 2. Handle chunked network arrivals with buffering + /// 3. Extract event types from message headers + /// 4. Parse JSON payloads from decoded messages + /// 5. Reconstruct streaming content from contentBlockDelta events + /// + /// The decoder handles frame boundaries automatically - you just keep calling + /// decode_frame() until it returns Incomplete, which means you've processed + /// all complete frames in the buffer. + #[test] + fn test_amazon_bedrock_streaming_response() { + use aws_smithy_eventstream::frame::{MessageFrameDecoder, DecodedFrame}; + use bytes::{Buf, BytesMut}; + use std::fs; + use std::path::PathBuf; + + // Read the response.hex file from tests/e2e directory + // Use absolute path to avoid cargo test working directory issues + let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("../../tests/e2e/response.hex"); + let response_data = fs::read(&test_file) + .unwrap_or_else(|e| panic!("Failed to read {:?}: {}", test_file, e)); + + println!("šŸ“Š Response data size: {} bytes\n", response_data.len()); + + // Create decoder and buffer that implements Buf trait + // BytesMut automatically tracks position as decoder advances it! + let mut decoder = MessageFrameDecoder::new(); + let mut simulated_network_buffer = BytesMut::new(); + let mut frame_count = 0; + let mut content_chunks = Vec::new(); + + // Simulate chunked network arrivals - process as data comes in + let chunk_sizes = vec![50, 100, 75, 200, 150, 300, 500, 1000]; + let mut offset = 0; + let mut chunk_num = 0; + + println!("šŸ”„ Simulating chunked network arrivals...\n"); + + // Process chunks as they "arrive" from the network + while offset < response_data.len() { + // Receive next chunk from network + let chunk_size = chunk_sizes[chunk_num % chunk_sizes.len()]; + let end = (offset + chunk_size).min(response_data.len()); + let chunk = &response_data[offset..end]; + + chunk_num += 1; + simulated_network_buffer.extend_from_slice(chunk); + offset = end; + + println!("šŸ“¦ Chunk {}: Received {} bytes (buffer: {} bytes total, {} bytes remaining)", + chunk_num, chunk.len(), simulated_network_buffer.len(), simulated_network_buffer.remaining()); + + // Try to decode all complete frames from buffer + // The Buf trait tracks position automatically! + loop { + let bytes_before = simulated_network_buffer.remaining(); + match decoder.decode_frame(&mut simulated_network_buffer) { + Ok(DecodedFrame::Complete(message)) => { + frame_count += 1; + let consumed = bytes_before - simulated_network_buffer.remaining(); + + println!(" āœ… Frame {}: decoded ({} bytes, {} bytes remaining)", + frame_count, consumed, simulated_network_buffer.remaining()); + + // Get event type from headers + let event_type = message.headers() + .iter() + .find(|h| h.name().as_str() == ":event-type") + .and_then(|h| { + h.value().as_string().ok().map(|s| s.as_str().to_string()) + }); + + if let Some(ref evt) = event_type { + println!(" Event: {}", evt); + } + + // Parse payload and extract content + let payload = message.payload(); + if !payload.is_empty() { + if let Ok(json) = serde_json::from_slice::(payload) { + if event_type.as_deref() == Some("contentBlockDelta") { + if let Some(delta) = json.get("delta") { + if let Some(text) = delta.get("text").and_then(|t| t.as_str()) { + println!(" šŸ“ Content: \"{}\"", text); + content_chunks.push(text.to_string()); + } + } + } + } + } // Continue loop to check for more complete frames in buffer + } + Ok(DecodedFrame::Incomplete) => { + // Not enough data for a complete frame - need more chunks + println!(" ā³ Incomplete frame ({} bytes remaining) - waiting for more data\n", simulated_network_buffer.remaining()); + break; // Wait for next chunk + } + Err(e) => { + panic!("āŒ Frame decode error: {}", e); + } + } + } + } + + println!("\n━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"); + println!("šŸ“‹ Summary:"); + println!(" Total chunks received: {}", chunk_num); + println!(" Total frames decoded: {}", frame_count); + println!(" Total content chunks: {}", content_chunks.len()); + println!(" Final buffer remaining: {} bytes", simulated_network_buffer.remaining()); + + if !content_chunks.is_empty() { + let full_text = content_chunks.join(""); + println!("\nšŸ“„ Full reconstructed content:"); + println!("{}", full_text); + println!("\n Characters: {}", full_text.len()); + println!(" Estimated tokens: ~{}", full_text.len() / 4); + } + + // Ensure we decoded at least one frame + assert!(frame_count > 0, "Should decode at least one frame"); + + // Ensure all data was consumed - if buffer has remaining bytes, it's a partial frame + assert_eq!(simulated_network_buffer.remaining(), 0, "All bytes should be consumed, {} bytes remain", simulated_network_buffer.remaining()); + } } diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index 87c4aee8..0af3982c 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -180,8 +180,13 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT Ok(ProviderRequestType::BedrockConverse(bedrock_req)) } - (ProviderRequestType::ChatCompletionsRequest(_), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => { - todo!("ChatCompletionsRequest to Amazon Bedrock Stream conversion not implemented yet") + (ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => { + let bedrock_req = ConverseStreamRequest::try_from(chat_req) + .map_err(|e| ProviderRequestError { + message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e), + source: Some(Box::new(e)) + })?; + Ok(ProviderRequestType::BedrockConverse(bedrock_req)) } (ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AmazonBedrockConverse(_)) => { let bedrock_req = ConverseRequest::try_from(messages_req) @@ -191,8 +196,13 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT })?; Ok(ProviderRequestType::BedrockConverse(bedrock_req)) } - (ProviderRequestType::MessagesRequest(_), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => { - todo!("MessagesRequest to Amazon Bedrock Stream conversion not implemented yet") + (ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => { + let bedrock_req = ConverseStreamRequest::try_from(messages_req) + .map_err(|e| ProviderRequestError { + message: format!("Failed to convert MessagesRequest to Amazon Bedrock request: {}", e), + source: Some(Box::new(e)) + })?; + Ok(ProviderRequestType::BedrockConverse(bedrock_req)) } // Amazon Bedrock to other APIs conversions diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index 831636a8..3b14d0b8 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -1,5 +1,6 @@ use crate::clients::endpoints::SupportedUpstreamAPIs; use crate::providers::id::ProviderId; +use bytes::Buf; use serde::{Serialize, Deserialize}; use std::error::Error; use std::fmt; @@ -12,6 +13,7 @@ use crate::apis::anthropic::MessagesStreamEvent; use crate::clients::endpoints::SupportedAPIs; use crate::apis::anthropic::MessagesResponse; use crate::apis::amazon_bedrock::ConverseResponse; +use crate::apis::amazon_bedrock::ConverseStreamEvent; /// Trait for token usage information pub trait TokenUsage { @@ -32,6 +34,7 @@ pub enum ProviderResponseType { pub enum ProviderStreamResponseType { ChatCompletionsStreamResponse(ChatCompletionsStreamResponse), MessagesStreamEvent(MessagesStreamEvent), + ConverseStreamEvent(ConverseStreamEvent), } @@ -60,7 +63,6 @@ impl ProviderResponse for ProviderResponseType { } } } - pub trait ProviderStreamResponse: Send + Sync { /// Get the content delta for this chunk fn content_delta(&self) -> Option<&str>; @@ -73,6 +75,7 @@ pub trait ProviderStreamResponse: Send + Sync { /// Get event type for SSE streaming (used by Anthropic) fn event_type(&self) -> Option<&str>; + } impl ProviderStreamResponse for ProviderStreamResponseType { @@ -80,6 +83,7 @@ impl ProviderStreamResponse for ProviderStreamResponseType { match self { ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.content_delta(), ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.content_delta(), + ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.content_delta(), } } @@ -87,6 +91,7 @@ impl ProviderStreamResponse for ProviderStreamResponseType { match self { ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.is_final(), ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.is_final(), + ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.is_final(), } } @@ -94,6 +99,7 @@ impl ProviderStreamResponse for ProviderStreamResponseType { match self { ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.role(), ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.role(), + ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.role(), } } @@ -101,10 +107,140 @@ impl ProviderStreamResponse for ProviderStreamResponseType { match self { ProviderStreamResponseType::ChatCompletionsStreamResponse(_resp) => None, // OpenAI doesn't use event types ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.event_type(), + ProviderStreamResponseType::ConverseStreamEvent(resp) => resp.event_type(), // Bedrock doesn't use event types } } } +impl Into for ProviderStreamResponseType { + fn into(self) -> String { + match self { + ProviderStreamResponseType::MessagesStreamEvent(event) => { + // Use the Into implementation for proper SSE formatting with event lines + event.into() + } + ProviderStreamResponseType::ConverseStreamEvent(event) => { + // Use the Into implementation for proper SSE formatting with event lines + event.into() + } + ProviderStreamResponseType::ChatCompletionsStreamResponse(_) => { + // For OpenAI, use simple data line format + let json = serde_json::to_string(&self).unwrap_or_default(); + format!("data: {}\n\n", json) + } + } + } +} + +// --- Response transformation logic for client API compatibility --- +impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { + type Error = std::io::Error; + + fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result { + let upstream_api = provider_id.compatible_api_for_client(client_api, false); + match (&upstream_api, client_api) { + (SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderResponseType::ChatCompletionsResponse(resp)) + } + (SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + let resp: MessagesResponse = serde_json::from_slice(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderResponseType::MessagesResponse(resp)) + } + (SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + // Transform to OpenAI ChatCompletions format using the transformer + let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; + Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp)) + } + (SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + // Transform to Anthropic Messages format using the transformer + let messages_resp: MessagesResponse = openai_resp.try_into() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; + Ok(ProviderResponseType::MessagesResponse(messages_resp)) + } + // Amazon Bedrock transformations + (SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + // Transform to OpenAI ChatCompletions format using the transformer + let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; + Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp)) + } + (SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + // Transform to Anthropic Messages format using the transformer + let messages_resp: MessagesResponse = bedrock_resp.try_into() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; + Ok(ProviderResponseType::MessagesResponse(messages_resp)) + } + _ => { + Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Unsupported API combination for response transformation")) + } + } + } +} + +// Stream response transformation logic for client API compatibility +impl TryFrom<(&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)> for ProviderStreamResponseType { + type Error = Box; + + fn try_from((bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result { + // Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion + if bytes == b"[DONE]" && matches!(client_api, SupportedAPIs::AnthropicMessagesAPI(_)) { + return Ok(ProviderStreamResponseType::MessagesStreamEvent( + crate::apis::anthropic::MessagesStreamEvent::MessageStop + )); + } + match (upstream_api, client_api) { + // OpenAI upstream + (SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + let resp = serde_json::from_slice(bytes)?; + Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(resp)) + } + (SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?; + let anthropic_resp = openai_resp.try_into()?; + Ok(ProviderStreamResponseType::MessagesStreamEvent(anthropic_resp)) + } + + // Anthropic upstream + (SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + let resp = serde_json::from_slice(bytes)?; + Ok(ProviderStreamResponseType::MessagesStreamEvent(resp)) + } + (SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?; + let openai_resp = anthropic_resp.try_into()?; + Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(openai_resp)) + } + + // Amazon Bedrock ConverseStream upstream + (SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent = serde_json::from_slice(bytes)?; + let anthropic_resp = bedrock_resp.try_into()?; + Ok(ProviderStreamResponseType::MessagesStreamEvent(anthropic_resp)) + } + _ => { + Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Unsupported API combination for response transformation").into()) + } + } + } +} + + // ============================================================================ // SSE EVENT CONTAINER // ============================================================================ @@ -210,215 +346,6 @@ impl Into> for SseEvent { } } - -// --- Response transformation logic for client API compatibility --- -impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { - type Error = std::io::Error; - - fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result { - let upstream_api = provider_id.compatible_api_for_client(client_api, false); - match (&upstream_api, client_api) { - (SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderResponseType::ChatCompletionsResponse(resp)) - } - (SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { - let resp: MessagesResponse = serde_json::from_slice(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderResponseType::MessagesResponse(resp)) - } - (SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - - // Transform to OpenAI ChatCompletions format using the transformer - let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into() - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; - Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp)) - } - (SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { - let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - - // Transform to Anthropic Messages format using the transformer - let messages_resp: MessagesResponse = openai_resp.try_into() - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; - Ok(ProviderResponseType::MessagesResponse(messages_resp)) - } - // Amazon Bedrock transformations - (SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - - // Transform to OpenAI ChatCompletions format using the transformer - let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into() - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; - Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp)) - } - (SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { - let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - - // Transform to Anthropic Messages format using the transformer - let messages_resp: MessagesResponse = bedrock_resp.try_into() - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; - Ok(ProviderResponseType::MessagesResponse(messages_resp)) - } - (SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), _) => { - todo!("Amazon Bedrock streaming response transformation not implemented yet") - } - } - } -} - -// Stream response transformation logic for client API compatibility -impl TryFrom<(&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)> for ProviderStreamResponseType { - type Error = Box; - - fn try_from((bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result { - match (upstream_api, client_api) { - (SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - let resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?; - Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(resp)) - } - (SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { - // Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion - if bytes == b"[DONE]" { - return Ok(ProviderStreamResponseType::MessagesStreamEvent( - crate::apis::anthropic::MessagesStreamEvent::MessageStop - )); - } - - let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?; - let messages_resp: crate::apis::anthropic::MessagesStreamEvent = openai_resp.try_into()?; - Ok(ProviderStreamResponseType::MessagesStreamEvent(messages_resp)) - } - (SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { - let resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?; - Ok(ProviderStreamResponseType::MessagesStreamEvent(resp)) - } - (SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?; - - // Transform to OpenAI ChatCompletions stream format using the transformer - let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = anthropic_resp.try_into()?; - Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(chat_resp)) - } - - (SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - todo!("Amazon Bedrock to OpenAI streaming transformation not implemented yet") - } - - (SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { - todo!("Anthropic to Amazon Bedrock streaming transformation not implemented yet") - } - - (SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - todo!("Amazon Bedrock streaming response transformation not implemented yet") - } - - (SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { - todo!("Amazon Bedrock streaming response transformation not implemented yet") - } - } - } -} - -// TryFrom implementation to convert raw bytes to SseEvent with parsed provider response -impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent { - type Error = Box; - - fn try_from((sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result { - // Create a new transformed event based on the original - let mut transformed_event = sse_event; - - // If has data, parse the data as a provider stream response (business logic layer) - if transformed_event.data.is_some() { - let data_str = transformed_event.data.as_ref().unwrap(); - let data_bytes = data_str.as_bytes(); - let transformed_response = ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?; - let transformed_json = serde_json::to_string(&transformed_response)?; - transformed_event.sse_transform_buffer = format!("data: {}\n\n", transformed_json); - transformed_event.provider_stream_response = Some(transformed_response); - } - - match (client_api, upstream_api) { - (SupportedAPIs::OpenAIChatCompletions(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => { - // No transformation needed - } - (SupportedAPIs::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => { - // No transformation needed - } - - (SupportedAPIs::OpenAIChatCompletions(_), SupportedUpstreamAPIs::AmazonBedrockConverse(_)) => { - // This should never get called since we are in the streaming path - - } - - (SupportedAPIs::OpenAIChatCompletions(_), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => { - // TODO: Implement OpenAI to Amazon Bedrock SSE transformation - } - - (SupportedAPIs::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => { - // TODO: Implement Anthropic to Amazon Bedrock SSE transformation - } - - (SupportedAPIs::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AmazonBedrockConverse(_)) => { - // TODO: Implement Anthropic to Amazon Bedrock SSE transformation - } - - (SupportedAPIs::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => { - if let Some(provider_response) = &transformed_event.provider_stream_response { - if let Some(event_type) = provider_response.event_type() { - // This ensures the required Anthropic sequence: MessageStart → ContentBlockStart → ContentBlockDelta(s) - if event_type == "message_start" { - let content_block_start_json = serde_json::json!({ - "type": "content_block_start", - "index": 0, - "content_block": { - "type": "text", - "text": "" - } - }); - // Format as proper SSE: MessageStart first, then ContentBlockStart - transformed_event.sse_transform_buffer = format!( - "event: {}\n{}\nevent: content_block_start\ndata: {}\n\n", - event_type, - transformed_event.sse_transform_buffer, - content_block_start_json, - ); - } else if event_type == "message_delta" { - let content_block_stop_json = serde_json::json!({ - "type": "content_block_stop", - "index": 0 - }); - // Format as proper SSE: ContentBlockStop first, then MessageDelta - transformed_event.sse_transform_buffer = format!( - "event: content_block_stop\ndata: {}\n\nevent: {}\n{}", - content_block_stop_json, - event_type, - transformed_event.sse_transform_buffer - ); - } else { - transformed_event.sse_transform_buffer = format!("event: {}\n{}", event_type, transformed_event.sse_transform_buffer); - } - } - // If event_type is None, we just keep the data line as-is without an event line - // This handles cases where the transformation might not produce a valid event type - } - } - (SupportedAPIs::OpenAIChatCompletions(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => { - if transformed_event.is_event_only() && transformed_event.event.is_some() { - transformed_event.sse_transform_buffer = format!("\n"); // suppress the event upstream for OpenAI - } - } - } - - Ok(transformed_event) - } -} - #[derive(Debug)] pub struct SseParseError { pub message: String, @@ -432,9 +359,178 @@ impl fmt::Display for SseParseError { impl Error for SseParseError {} -// ============================================================================ -// GENERIC SSE STREAMING ITERATOR (Container Only) -// ============================================================================ +// TryFrom implementation to convert raw bytes to SseEvent with parsed provider response +impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent { + type Error = Box; + + fn try_from((sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result { + // Create a new transformed event based on the original + let mut transformed_event = sse_event; + + // If has data, parse the data as a provider stream response (business logic layer) + if transformed_event.data.is_some() { + let data_str = transformed_event.data.as_ref().unwrap(); + let data_bytes = data_str.as_bytes(); + let transformed_response: ProviderStreamResponseType = ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?; + + // Convert to SSE string explicitly to avoid type ambiguity + let sse_string: String = transformed_response.clone().into(); + transformed_event.sse_transform_buffer = sse_string; + transformed_event.provider_stream_response = Some(transformed_response); + } + + match (client_api, upstream_api) { + (SupportedAPIs::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => { + if let Some(provider_response) = &transformed_event.provider_stream_response { + if let Some(event_type) = provider_response.event_type() { + // This ensures the required Anthropic sequence: MessageStart → ContentBlockStart → ContentBlockDelta(s) + if event_type == "message_start" { + // Create ContentBlockStart event and format it using Into + let content_block_start = MessagesStreamEvent::ContentBlockStart { + index: 0, + content_block: crate::apis::anthropic::MessagesContentBlock::Text { + text: String::new(), + cache_control: None, + }, + }; + let content_block_start_sse: String = content_block_start.into(); + + // Format as proper SSE: MessageStart first, then ContentBlockStart + // The sse_transform_buffer already contains the properly formatted MessageStart + transformed_event.sse_transform_buffer = format!( + "{}{}", + transformed_event.sse_transform_buffer, + content_block_start_sse, + ); + } else if event_type == "message_delta" { + // Create ContentBlockStop event and format it using Into + let content_block_stop = MessagesStreamEvent::ContentBlockStop { index: 0 }; + let content_block_stop_sse: String = content_block_stop.into(); + + // Format as proper SSE: ContentBlockStop first, then MessageDelta + transformed_event.sse_transform_buffer = format!( + "{}{}", + content_block_stop_sse, + transformed_event.sse_transform_buffer + ); + } + // For other event types, the sse_transform_buffer already has the correct format from Into + } + // If event_type is None, we just keep the data line as-is without an event line + // This handles cases where the transformation might not produce a valid event type + } + } + (SupportedAPIs::OpenAIChatCompletions(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => { + if transformed_event.is_event_only() && transformed_event.event.is_some() { + transformed_event.sse_transform_buffer = format!("\n"); // suppress the event upstream for OpenAI + } + } + _ => { + // Other combinations can be handled here as needed + } + } + + Ok(transformed_event) + } +} + +// TryFrom implementation to convert AWS Event Stream DecodedFrame to ProviderStreamResponseType +impl TryFrom<(&aws_smithy_eventstream::frame::DecodedFrame, &SupportedAPIs, &SupportedUpstreamAPIs)> for ProviderStreamResponseType { + type Error = Box; + + fn try_from((frame, client_api, upstream_api): (&aws_smithy_eventstream::frame::DecodedFrame, &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result { + use aws_smithy_eventstream::frame::DecodedFrame; + + match frame { + DecodedFrame::Complete(_) => { + // We have a complete frame - parse it based on upstream API + match (upstream_api, client_api) { + (SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + // Parse the DecodedFrame into ConverseStreamEvent + let bedrock_event = crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?; + let anthropic_event: crate::apis::anthropic::MessagesStreamEvent = bedrock_event.try_into()?; + + Ok(ProviderStreamResponseType::MessagesStreamEvent(anthropic_event)) + } + (SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::OpenAIChatCompletions(_)) => { + // Parse the DecodedFrame into ConverseStreamEvent + let bedrock_event = crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?; + let openai_event: crate::apis::openai::ChatCompletionsStreamResponse = bedrock_event.try_into()?; + Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(openai_event)) + } + _ => { + Err("Unsupported API combination for event-stream decoding".into()) + } + } + } + DecodedFrame::Incomplete => { + Err("Cannot convert incomplete frame to provider response".into()) + } + } + } +} + + +/// AWS Event Stream frame decoder wrapper +pub struct BedrockBinaryFrameDecoder +where + B: Buf, +{ + decoder: aws_smithy_eventstream::frame::MessageFrameDecoder, + buffer: B, + has_content_block_start_been_sent: bool, +} + +impl BedrockBinaryFrameDecoder { + /// This is a convenience constructor that creates a BytesMut buffer internally + pub fn from_bytes(bytes: &[u8]) -> Self { + let buffer = bytes::BytesMut::from(bytes); + Self { + decoder: aws_smithy_eventstream::frame::MessageFrameDecoder::new(), + buffer, + has_content_block_start_been_sent: false, + } + } +} + +impl BedrockBinaryFrameDecoder +where + B: Buf, +{ + pub fn new(buffer: B) -> Self { + Self { + decoder: aws_smithy_eventstream::frame::MessageFrameDecoder::new(), + buffer, + has_content_block_start_been_sent: false, + } + } + + pub fn decode_frame(&mut self) -> Option { + match self.decoder.decode_frame(&mut self.buffer) { + Ok(frame) => Some(frame), + Err(_e) => None, // Fatal decode error + } + } + + pub fn buffer_mut(&mut self) -> &mut B { + &mut self.buffer + } + + /// Check if there are any bytes remaining in the buffer + pub fn has_remaining(&self) -> bool { + self.buffer.has_remaining() + } + + /// Check if a content_block_start event has been sent + pub fn has_content_block_start_been_sent(&self) -> bool { + self.has_content_block_start_been_sent + } + + /// Set the content_block_start flag + pub fn set_content_block_start_sent(&mut self, sent: bool) { + self.has_content_block_start_been_sent = sent; + } +} /// Generic SSE (Server-Sent Events) streaming iterator container /// Parses raw SSE lines into SseEvent objects @@ -470,6 +566,7 @@ impl TryFrom<&[u8]> for SseStreamIter> { } } + impl Iterator for SseStreamIter where I: Iterator, @@ -878,4 +975,530 @@ mod tests { panic!("Expected MessagesStreamEvent::MessageStop"); } } + + #[test] + fn test_bedrock_event_stream_decoder_basic() { + use bytes::BytesMut; + + // Create a simple test with minimal data + let mut buffer = BytesMut::new(); + + // Add some arbitrary bytes (not a real event-stream frame, just for testing the decoder) + buffer.extend_from_slice(b"test data"); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + // The decoder should return Incomplete for incomplete/invalid data + // This signals the caller to wait for more data + let result = decoder.decode_frame(); + assert!(result.is_some()); + assert!(matches!(result.unwrap(), aws_smithy_eventstream::frame::DecodedFrame::Incomplete)); + + // Verify we can still access the buffer + assert!(decoder.has_remaining()); + } + + #[test] + fn test_bedrock_event_stream_decoder_with_real_frames() { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response.hex file from tests/e2e directory + let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("../../tests/e2e/response.hex"); + + // Only run this test if the file exists + if !test_file.exists() { + println!("Skipping test - response.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + let mut buffer = BytesMut::from(&response_data[..]); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + let mut frame_count = 0; + + // Decode all frames + loop { + match decoder.decode_frame() { + Some(aws_smithy_eventstream::frame::DecodedFrame::Complete(message)) => { + frame_count += 1; + + // Verify we can access headers + let event_type = message.headers() + .iter() + .find(|h| h.name().as_str() == ":event-type") + .and_then(|h| h.value().as_string().ok()); + + assert!(event_type.is_some(), "Frame should have :event-type header"); + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // End of buffer, no more complete frames available + break; + } + None => { + // Decode error + panic!("Decode error encountered"); + } + } + } + + // We should have decoded multiple frames + assert!(frame_count > 0, "Should have decoded at least one frame"); + } + + #[test] + fn test_bedrock_event_stream_decoder_chunked_data() { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response.hex file + let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("../../tests/e2e/response.hex"); + + if !test_file.exists() { + println!("Skipping test - response.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + + // Simulate chunked network arrivals with realistic chunk sizes + // Using varying chunk sizes to test partial frame handling + let mut buffer = BytesMut::new(); + let chunk_size_pattern = vec![500, 1000, 750, 1200, 800, 1500]; + let mut offset = 0; + let mut total_frames = 0; + let mut chunk_num = 0; + + // CRITICAL: Create ONE decoder and reuse it across chunks + // The MessageFrameDecoder maintains state about partial frames + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + // Process all data in chunks + while offset < response_data.len() { + let chunk_size = chunk_size_pattern[chunk_num % chunk_size_pattern.len()]; + chunk_num += 1; + + let end = (offset + chunk_size).min(response_data.len()); + let chunk = &response_data[offset..end]; + + // Add new data to the buffer (accessing via buffer_mut()) + decoder.buffer_mut().extend_from_slice(chunk); + offset = end; + + // Process all available complete frames from this chunk + loop { + match decoder.decode_frame() { + Some(aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { + total_frames += 1; + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // Need more data - wait for next chunk + break; + } + None => { + // Decode error + panic!("Decode error in chunked test"); + } + } + } + } + + assert!(total_frames > 0, "Should have decoded frames from chunked data"); + } + + #[test] + fn test_bedrock_decoded_frame_to_provider_response() { + test_bedrock_conversion(false); + } + + #[test] + #[ignore] // Run with: cargo test -- --ignored --nocapture + fn test_bedrock_decoded_frame_to_provider_response_verbose() { + test_bedrock_conversion(true); + } + + fn test_bedrock_conversion(verbose: bool) { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response.hex file from tests/e2e directory + let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("../../tests/e2e/response.hex"); + + // Only run this test if the file exists + if !test_file.exists() { + println!("Skipping test - response.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + let mut buffer = BytesMut::from(&response_data[..]); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + let client_api = SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream); + + let mut conversion_count = 0; + let mut message_start_seen = false; + + // Decode and convert frames + loop { + match decoder.decode_frame() { + Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { + // Convert DecodedFrame to ProviderStreamResponseType + let result = ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api)); + + match result { + Ok(provider_response) => { + conversion_count += 1; + + // Verify we got a MessagesStreamEvent + assert!(matches!(provider_response, ProviderStreamResponseType::MessagesStreamEvent(_))); + + if verbose { + // Print the SSE string output + let sse_string: String = provider_response.clone().into(); + println!("{}", sse_string); + } + + // Check for MessageStart event + if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = provider_response { + if matches!(event, crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. }) { + message_start_seen = true; + } + } + } + Err(e) => { + println!("Conversion error (frame {}): {}", conversion_count, e); + } + } + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // End of buffer + break; + } + None => { + panic!("Decode error"); + } + } + } + + assert!(conversion_count > 0, "Should have converted at least one frame"); + assert!(message_start_seen, "Should have seen MessageStart event"); + } + + #[test] + fn test_sse_event_transformation_openai_to_anthropic_message_start() { + use crate::apis::openai::OpenAIApi; + use crate::apis::anthropic::AnthropicApi; + + // Create an OpenAI stream response that represents a role start (which becomes message_start in Anthropic) + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {"role": "assistant"}, + "finish_reason": null + }] + }); + + // Create SSE event with this data + let sse_event = SseEvent { + data: Some(openai_stream_chunk.to_string()), + event: None, + raw_line: format!("data: {}", openai_stream_chunk.to_string()), + sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify the transformation includes both message_start and content_block_start + let buffer = transformed.sse_transform_buffer; + assert!(buffer.contains("event: message_start"), "Should contain message_start event"); + assert!(buffer.contains("event: content_block_start"), "Should contain content_block_start event"); + + // Verify proper SSE format with event lines before data lines + assert!(buffer.find("event: message_start").unwrap() < buffer.find("data:").unwrap()); + assert!(buffer.find("content_block_start").is_some()); + } + + #[test] + fn test_sse_event_transformation_openai_to_anthropic_message_delta() { + use crate::apis::openai::OpenAIApi; + use crate::apis::anthropic::AnthropicApi; + + // Create an OpenAI stream response with finish_reason (which becomes message_delta in Anthropic) + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 25, + "total_tokens": 35 + } + }); + + // Create SSE event with this data + let sse_event = SseEvent { + data: Some(openai_stream_chunk.to_string()), + event: None, + raw_line: format!("data: {}", openai_stream_chunk.to_string()), + sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify the transformation includes both content_block_stop and message_delta + let buffer = transformed.sse_transform_buffer; + assert!(buffer.contains("event: content_block_stop"), "Should contain content_block_stop event"); + assert!(buffer.contains("event: message_delta"), "Should contain message_delta event"); + + // Verify content_block_stop comes before message_delta + let stop_pos = buffer.find("content_block_stop").unwrap(); + let delta_pos = buffer.find("message_delta").unwrap(); + assert!(stop_pos < delta_pos, "content_block_stop should come before message_delta"); + } + + #[test] + fn test_sse_event_transformation_openai_to_anthropic_content_delta() { + use crate::apis::openai::OpenAIApi; + use crate::apis::anthropic::AnthropicApi; + + // Create an OpenAI stream response with content (which becomes content_block_delta in Anthropic) + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {"content": "Hello"}, + "finish_reason": null + }] + }); + + // Create SSE event with this data + let sse_event = SseEvent { + data: Some(openai_stream_chunk.to_string()), + event: None, + raw_line: format!("data: {}", openai_stream_chunk.to_string()), + sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify the transformation is a content_block_delta (no extra events injected) + let buffer = transformed.sse_transform_buffer; + assert!(buffer.contains("event: content_block_delta"), "Should contain content_block_delta event"); + assert!(!buffer.contains("content_block_start"), "Should not inject content_block_start for content delta"); + assert!(!buffer.contains("content_block_stop"), "Should not inject content_block_stop for content delta"); + + // Verify the content is preserved + assert!(buffer.contains("Hello"), "Should preserve the content text"); + } + + #[test] + fn test_sse_event_transformation_anthropic_to_openai_suppresses_event_lines() { + use crate::apis::openai::OpenAIApi; + use crate::apis::anthropic::AnthropicApi; + + // Create an Anthropic event-only SSE line (no data) + let sse_event = SseEvent { + data: None, + event: Some("message_start".to_string()), + raw_line: "event: message_start".to_string(), + sse_transform_buffer: "event: message_start".to_string(), + provider_stream_response: None, + }; + + let client_api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify the event line is suppressed (replaced with just newline) + assert_eq!(transformed.sse_transform_buffer, "\n", "Event-only lines should be suppressed to newline for OpenAI"); + assert!(transformed.is_event_only(), "Should still be marked as event-only"); + } + + #[test] + fn test_sse_event_transformation_anthropic_to_openai_preserves_data() { + use crate::apis::openai::OpenAIApi; + use crate::apis::anthropic::AnthropicApi; + + // Create an Anthropic message_start event with data + let anthropic_event = json!({ + "type": "message_start", + "message": { + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-3-sonnet", + "stop_reason": null, + "usage": {"input_tokens": 10, "output_tokens": 0} + } + }); + + let sse_event = SseEvent { + data: Some(anthropic_event.to_string()), + event: None, + raw_line: format!("data: {}", anthropic_event.to_string()), + sse_transform_buffer: format!("data: {}", anthropic_event.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify data is transformed to OpenAI format + let buffer = transformed.sse_transform_buffer; + assert!(buffer.starts_with("data: "), "Should have data: prefix"); + assert!(!buffer.contains("event:"), "Should not have event: lines for OpenAI"); + + // Verify provider response was parsed + assert!(transformed.provider_stream_response.is_some()); + } + + #[test] + fn test_sse_event_transformation_no_change_for_matching_apis() { + use crate::apis::openai::OpenAIApi; + + // Create an OpenAI stream response + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {"content": "Hello"}, + "finish_reason": null + }] + }); + + let original_data = openai_stream_chunk.to_string(); + let sse_event = SseEvent { + data: Some(original_data.clone()), + event: None, + raw_line: format!("data: {}", original_data), + sse_transform_buffer: format!("data: {}\n\n", original_data), + provider_stream_response: None, + }; + + let client_api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify minimal transformation - just SSE formatting, no API conversion + let buffer = transformed.sse_transform_buffer; + assert!(buffer.starts_with("data: "), "Should preserve data: prefix"); + assert!(!buffer.contains("event:"), "Should not add event: lines"); + + // Verify provider response was parsed + assert!(transformed.provider_stream_response.is_some()); + } + + #[test] + fn test_sse_event_transformation_preserves_provider_response() { + use crate::apis::openai::OpenAIApi; + use crate::apis::anthropic::AnthropicApi; + + // Create an OpenAI stream response + let openai_stream_chunk = json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "delta": {"content": "Test"}, + "finish_reason": null + }] + }); + + let sse_event = SseEvent { + data: Some(openai_stream_chunk.to_string()), + event: None, + raw_line: format!("data: {}", openai_stream_chunk.to_string()), + sse_transform_buffer: format!("data: {}", openai_stream_chunk.to_string()), + provider_stream_response: None, + }; + + let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + + // Transform the event + let result = SseEvent::try_from((sse_event, &client_api, &upstream_api)); + assert!(result.is_ok()); + + let transformed = result.unwrap(); + + // Verify provider_stream_response is populated + assert!(transformed.provider_stream_response.is_some(), "Should parse and store provider response"); + + // Verify we can access the provider response + let provider_response = transformed.provider_response(); + assert!(provider_response.is_ok(), "Should be able to access provider response"); + + // Verify the content delta is accessible + let content = provider_response.unwrap().content_delta(); + assert_eq!(content, Some("Test"), "Should preserve content delta"); + } } diff --git a/crates/hermesllm/src/transforms/response/to_anthropic.rs b/crates/hermesllm/src/transforms/response/to_anthropic.rs index f6dc07d5..c19debc6 100644 --- a/crates/hermesllm/src/transforms/response/to_anthropic.rs +++ b/crates/hermesllm/src/transforms/response/to_anthropic.rs @@ -8,7 +8,9 @@ use crate::apis::anthropic::{ MessagesStreamEvent, MessagesStopReason, MessagesMessageDelta, MessagesResponse, MessagesStreamMessage, MessagesUsage, MessagesContentDelta, MessagesRole, MessagesContentBlock }; -use crate::apis::amazon_bedrock::{ConverseResponse, ConverseOutput, StopReason}; +use crate::apis::amazon_bedrock::{ + ConverseResponse, ConverseOutput, StopReason, ConverseStreamEvent, ContentBlockDelta +}; // ============================================================================ // STANDARD RUST TRAIT IMPLEMENTATIONS - Using Into/TryFrom for convenience @@ -47,7 +49,6 @@ impl TryFrom for MessagesResponse { } } - impl TryFrom for MessagesResponse { type Error = TransformError; @@ -202,6 +203,157 @@ impl TryFrom for MessagesStreamEvent { } } +impl Into for MessagesStreamEvent { + fn into(self) -> String { + let transformed_json = serde_json::to_string(&self).unwrap_or_default(); + let event_type = match &self { + MessagesStreamEvent::MessageStart { .. } => "message_start", + MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start", + MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta", + MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop", + MessagesStreamEvent::MessageDelta { .. } => "message_delta", + MessagesStreamEvent::MessageStop => "message_stop", + MessagesStreamEvent::Ping => "ping", + }; + + let event = format!("event: {}\n", event_type); + let data = format!("data: {}\n\n", transformed_json); + event + &data + } +} + +impl TryFrom for MessagesStreamEvent { + type Error = TransformError; + + fn try_from(event: ConverseStreamEvent) -> Result { + match event { + // MessageStart - convert to Anthropic MessageStart + ConverseStreamEvent::MessageStart(start_event) => { + let role = match start_event.role { + crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User, + crate::apis::amazon_bedrock::ConversationRole::Assistant => MessagesRole::Assistant, + }; + + Ok(MessagesStreamEvent::MessageStart { + message: MessagesStreamMessage { + id: format!("bedrock-stream-{}", std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos()), + obj_type: "message".to_string(), + role, + content: vec![], + model: "bedrock-model".to_string(), + stop_reason: None, + stop_sequence: None, + usage: MessagesUsage { + input_tokens: 0, + output_tokens: 0, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }, + }, + }) + } + + // ContentBlockStart - convert to Anthropic ContentBlockStart + ConverseStreamEvent::ContentBlockStart(start_event) => { + // Note: Bedrock sends tool_use_id and name at start, with input coming in subsequent deltas + // Anthropic expects the same pattern, so we initialize with an empty input object + match start_event.start { + crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use_id, name } => { + Ok(MessagesStreamEvent::ContentBlockStart { + index: start_event.content_block_index as u32, + content_block: MessagesContentBlock::ToolUse { + id: tool_use_id, + name, + input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas + cache_control: None, + }, + }) + } + } + } + + // ContentBlockDelta - convert to Anthropic ContentBlockDelta + ConverseStreamEvent::ContentBlockDelta(delta_event) => { + let delta = match delta_event.delta { + ContentBlockDelta::Text { text } => { + MessagesContentDelta::TextDelta { text } + } + ContentBlockDelta::ToolUse { input } => { + MessagesContentDelta::InputJsonDelta { partial_json: input } + } + }; + + Ok(MessagesStreamEvent::ContentBlockDelta { + index: delta_event.content_block_index as u32, + delta, + }) + } + + // ContentBlockStop - convert to Anthropic ContentBlockStop + ConverseStreamEvent::ContentBlockStop(stop_event) => { + Ok(MessagesStreamEvent::ContentBlockStop { + index: stop_event.content_block_index as u32, + }) + } + + // MessageStop - convert to Anthropic MessageDelta with stop reason + MessageStop + ConverseStreamEvent::MessageStop(stop_event) => { + let anthropic_stop_reason = match stop_event.stop_reason { + StopReason::EndTurn => MessagesStopReason::EndTurn, + StopReason::ToolUse => MessagesStopReason::ToolUse, + StopReason::MaxTokens => MessagesStopReason::MaxTokens, + StopReason::StopSequence => MessagesStopReason::EndTurn, + StopReason::GuardrailIntervened => MessagesStopReason::Refusal, + StopReason::ContentFiltered => MessagesStopReason::Refusal, + }; + + // Return MessageDelta (MessageStop will be sent separately by the streaming handler) + Ok(MessagesStreamEvent::MessageDelta { + delta: MessagesMessageDelta { + stop_reason: anthropic_stop_reason, + stop_sequence: None, + }, + usage: MessagesUsage { + input_tokens: 0, + output_tokens: 0, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }, + }) + } + + // Metadata - convert usage information to MessageDelta + ConverseStreamEvent::Metadata(metadata_event) => { + Ok(MessagesStreamEvent::MessageDelta { + delta: MessagesMessageDelta { + stop_reason: MessagesStopReason::EndTurn, + stop_sequence: None, + }, + usage: MessagesUsage { + input_tokens: metadata_event.usage.input_tokens, + output_tokens: metadata_event.usage.output_tokens, + cache_creation_input_tokens: metadata_event.usage.cache_write_input_tokens, + cache_read_input_tokens: metadata_event.usage.cache_read_input_tokens, + }, + }) + } + + // Exception events - convert to Ping (could be enhanced to return error events) + ConverseStreamEvent::InternalServerException(_) | + ConverseStreamEvent::ModelStreamErrorException(_) | + ConverseStreamEvent::ServiceUnavailableException(_) | + ConverseStreamEvent::ThrottlingException(_) | + ConverseStreamEvent::ValidationException(_) => { + // TODO: Consider adding proper error handling/events + Ok(MessagesStreamEvent::Ping) + } + } + } +} + /// Convert tool call deltas to Anthropic stream events fn convert_tool_call_deltas(tool_calls: Vec) -> Result { for tool_call in tool_calls { diff --git a/crates/hermesllm/src/transforms/response/to_openai.rs b/crates/hermesllm/src/transforms/response/to_openai.rs index 53820708..974a9339 100644 --- a/crates/hermesllm/src/transforms/response/to_openai.rs +++ b/crates/hermesllm/src/transforms/response/to_openai.rs @@ -1,6 +1,6 @@ use crate::apis::openai::{ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason, ResponseMessage, Role, ToolCallDelta, FunctionCallDelta, Usage, StreamChoice, MessageDelta, MessageContent}; use crate::apis::anthropic::{MessagesResponse, MessagesStreamEvent, MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesUsage}; -use crate::apis::amazon_bedrock::{ConverseResponse, ConverseOutput, StopReason}; +use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason}; use crate::clients::TransformError; use crate::transforms::lib::*; @@ -250,6 +250,172 @@ impl TryFrom for ChatCompletionsStreamResponse { } +impl TryFrom for ChatCompletionsStreamResponse { + type Error = TransformError; + + fn try_from(event: ConverseStreamEvent) -> Result { + match event { + ConverseStreamEvent::MessageStart(start_event) => { + let role = match start_event.role { + crate::apis::amazon_bedrock::ConversationRole::User => Role::User, + crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant, + }; + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: Some(role), + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )) + } + + ConverseStreamEvent::ContentBlockStart(start_event) => { + use crate::apis::amazon_bedrock::ContentBlockStart; + + match start_event.start { + ContentBlockStart::ToolUse { tool_use_id, name } => { + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: start_event.content_block_index as u32, + id: Some(tool_use_id), + call_type: Some("function".to_string()), + function: Some(FunctionCallDelta { + name: Some(name), + arguments: Some("".to_string()), + }), + }]), + }, + None, + None, + )) + } + } + } + + ConverseStreamEvent::ContentBlockDelta(delta_event) => { + use crate::apis::amazon_bedrock::ContentBlockDelta; + + match delta_event.delta { + ContentBlockDelta::Text { text } => { + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: Some(text), + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + None, + )) + } + ContentBlockDelta::ToolUse { input } => { + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: Some(vec![ToolCallDelta { + index: delta_event.content_block_index as u32, + id: None, + call_type: None, + function: Some(FunctionCallDelta { + name: None, + arguments: Some(input), + }), + }]), + }, + None, + None, + )) + } + } + } + + ConverseStreamEvent::ContentBlockStop(_) => { + Ok(create_empty_openai_chunk()) + } + + ConverseStreamEvent::MessageStop(stop_event) => { + let finish_reason = match stop_event.stop_reason { + StopReason::EndTurn => FinishReason::Stop, + StopReason::ToolUse => FinishReason::ToolCalls, + StopReason::MaxTokens => FinishReason::Length, + StopReason::StopSequence => FinishReason::Stop, + StopReason::GuardrailIntervened => FinishReason::ContentFilter, + StopReason::ContentFiltered => FinishReason::ContentFilter, + }; + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + Some(finish_reason), + None, + )) + } + + ConverseStreamEvent::Metadata(metadata_event) => { + let usage = Usage { + prompt_tokens: metadata_event.usage.input_tokens, + completion_tokens: metadata_event.usage.output_tokens, + total_tokens: metadata_event.usage.total_tokens, + prompt_tokens_details: None, + completion_tokens_details: None, + }; + + Ok(create_openai_chunk( + "stream", + "unknown", + MessageDelta { + role: None, + content: None, + refusal: None, + function_call: None, + tool_calls: None, + }, + None, + Some(usage), + )) + } + + // Error events - convert to empty chunks (errors should be handled elsewhere) + ConverseStreamEvent::InternalServerException(_) | + ConverseStreamEvent::ModelStreamErrorException(_) | + ConverseStreamEvent::ServiceUnavailableException(_) | + ConverseStreamEvent::ThrottlingException(_) | + ConverseStreamEvent::ValidationException(_) => { + Ok(create_empty_openai_chunk()) + } + } + } +} + /// Convert content block start to OpenAI chunk fn convert_content_block_start(content_block: MessagesContentBlock) -> Result { match content_block { diff --git a/crates/llm_gateway/Cargo.toml b/crates/llm_gateway/Cargo.toml index b2557477..281e05be 100644 --- a/crates/llm_gateway/Cargo.toml +++ b/crates/llm_gateway/Cargo.toml @@ -23,6 +23,7 @@ thiserror = "1.0.64" derivative = "2.2.0" sha2 = "0.10.8" hermesllm = { version = "0.1.0", path = "../hermesllm" } +bytes = "1.10" [dev-dependencies] serial_test = "3.1.1" diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 8ade853a..f9a699a0 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,3 +1,4 @@ +use bytes::Buf; use hermesllm::clients::endpoints::SupportedUpstreamAPIs; use http::StatusCode; use log::{debug, info, warn}; @@ -23,7 +24,9 @@ use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; use hermesllm::clients::endpoints::SupportedAPIs; -use hermesllm::providers::response::{ProviderResponse, SseEvent, SseStreamIter}; +use hermesllm::providers::response::{ + BedrockBinaryFrameDecoder, ProviderResponse, SseEvent, SseStreamIter, +}; use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType}; pub struct StreamContext { @@ -46,8 +49,8 @@ pub struct StreamContext { traces_queue: Arc>>, overrides: Rc>, user_message: Option, - /// Store upstream response status code to handle error responses gracefully upstream_status_code: Option, + binary_frame_decoder: Option>, } impl StreamContext { @@ -76,6 +79,7 @@ impl StreamContext { request_body_sent_time: None, user_message: None, upstream_status_code: None, + binary_frame_decoder: None, } } diff --git a/tests/e2e/response.hex b/tests/e2e/response.hex new file mode 100644 index 00000000..c96504e2 Binary files /dev/null and b/tests/e2e/response.hex differ