diff --git a/crates/hermesllm/src/apis/streaming_shapes/amazon_bedrock_binary_frame.rs b/crates/hermesllm/src/apis/streaming_shapes/amazon_bedrock_binary_frame.rs index bacbad62..7f68bb26 100644 --- a/crates/hermesllm/src/apis/streaming_shapes/amazon_bedrock_binary_frame.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/amazon_bedrock_binary_frame.rs @@ -1,7 +1,6 @@ use aws_smithy_eventstream::frame::DecodedFrame; use aws_smithy_eventstream::frame::MessageFrameDecoder; use bytes::Buf; -use std::collections::HashSet; /// AWS Event Stream frame decoder wrapper pub struct BedrockBinaryFrameDecoder @@ -10,7 +9,6 @@ where { decoder: MessageFrameDecoder, buffer: B, - content_block_start_indices: HashSet, } impl BedrockBinaryFrameDecoder { @@ -20,7 +18,6 @@ impl BedrockBinaryFrameDecoder { Self { decoder: MessageFrameDecoder::new(), buffer, - content_block_start_indices: std::collections::HashSet::new(), } } } @@ -33,7 +30,6 @@ where Self { decoder: MessageFrameDecoder::new(), buffer, - content_block_start_indices: HashSet::new(), } } @@ -52,14 +48,4 @@ where pub fn has_remaining(&self) -> bool { self.buffer.has_remaining() } - - /// Check if a content_block_start event has been sent for the given index - pub fn has_content_block_start_been_sent(&self, index: i32) -> bool { - self.content_block_start_indices.contains(&index) - } - - /// Mark that a content_block_start event has been sent for the given index - pub fn set_content_block_start_sent(&mut self, index: i32) { - self.content_block_start_indices.insert(index); - } } diff --git a/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs index 07ff0140..c326a770 100644 --- a/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs @@ -1,6 +1,7 @@ use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait}; use crate::apis::anthropic::MessagesStreamEvent; -use crate::providers::streaming_response::ProviderStreamResponse; +use crate::providers::streaming_response::ProviderStreamResponseType; +use std::collections::HashSet; /// SSE Stream Buffer for Anthropic Messages API streaming. /// @@ -17,8 +18,8 @@ pub struct AnthropicMessagesStreamBuffer { /// Track if we've seen a message_start event message_started: bool, - /// Track if we've seen a content_block_start event - content_block_started: bool, + /// Track content block indices that have received ContentBlockStart events + content_block_start_indices: HashSet, /// Track if we need to inject ContentBlockStop before message_delta needs_content_block_stop: bool, @@ -32,12 +33,22 @@ impl AnthropicMessagesStreamBuffer { Self { buffered_events: Vec::new(), message_started: false, - content_block_started: false, + content_block_start_indices: HashSet::new(), needs_content_block_stop: false, model: None, } } + /// Check if a content_block_start event has been sent for the given index + fn has_content_block_start_been_sent(&self, index: i32) -> bool { + self.content_block_start_indices.contains(&index) + } + + /// Mark that a content_block_start event has been sent for the given index + fn set_content_block_start_sent(&mut self, index: i32) { + self.content_block_start_indices.insert(index); + } + /// Helper to create and format a ContentBlockStart SSE event fn create_content_block_start_event() -> SseEvent { let content_block_start = MessagesStreamEvent::ContentBlockStart { @@ -124,17 +135,19 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { } } - // Check if this event has a provider response to determine its type - if let Some(provider_response) = &event.provider_stream_response { - if let Some(event_type) = provider_response.event_type() { - match event_type { - "message_start" => { + // Match directly on the provider response type to handle event processing + // We match on a reference first to determine the type, then move the event + match &event.provider_stream_response { + Some(ProviderStreamResponseType::MessagesStreamEvent(evt)) => { + match evt { + MessagesStreamEvent::MessageStart { .. } => { // Add the message_start event self.buffered_events.push(event); self.message_started = true; } - "content_block_start" => { - // If we haven't seen message_start yet, inject it first + MessagesStreamEvent::ContentBlockStart { index, .. } => { + let index = *index as i32; + // Inject message_start if needed if !self.message_started { let model = self.model.as_deref().unwrap_or("unknown"); let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model); @@ -144,32 +157,32 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { // Add the content_block_start event (from tool calls or other sources) self.buffered_events.push(event); - self.content_block_started = true; + self.set_content_block_start_sent(index); self.needs_content_block_stop = true; } - "content_block_delta" => { - // If this is the first content delta and we haven't started yet, - // inject message_start and content_block_start first + MessagesStreamEvent::ContentBlockDelta { index, .. } => { + let index = *index as i32; + // Inject message_start if needed if !self.message_started { - // Create and inject message_start event let model = self.model.as_deref().unwrap_or("unknown"); let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model); self.buffered_events.push(message_start); self.message_started = true; } - if !self.content_block_started { - // Inject ContentBlockStart after message_start + // Check if ContentBlockStart was sent for this index + if !self.has_content_block_start_been_sent(index) { + // Inject ContentBlockStart before delta let content_block_start = AnthropicMessagesStreamBuffer::create_content_block_start_event(); self.buffered_events.push(content_block_start); - self.content_block_started = true; + self.set_content_block_start_sent(index); self.needs_content_block_stop = true; } // Content deltas are between ContentBlockStart and ContentBlockStop self.buffered_events.push(event); } - "message_delta" => { + MessagesStreamEvent::MessageDelta { .. } => { // Inject ContentBlockStop before message_delta if self.needs_content_block_stop { let content_block_stop = AnthropicMessagesStreamBuffer::create_content_block_stop_event(); @@ -181,16 +194,16 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { self.buffered_events.push(event); } _ => { - // Other event types, just accumulate the event + // Other Anthropic event types (ContentBlockStop, MessageStop, etc.), just accumulate self.buffered_events.push(event); } } - return; + } + _ => { + // Non-Anthropic events or events without provider_stream_response, just accumulate + self.buffered_events.push(event); } } - - // For events without provider_stream_response or event_type, just accumulate - self.buffered_events.push(event); } fn into_bytes(&mut self) -> Vec { diff --git a/crates/hermesllm/src/apis/streaming_shapes/sse.rs b/crates/hermesllm/src/apis/streaming_shapes/sse.rs index 6a2485e4..17c6873a 100644 --- a/crates/hermesllm/src/apis/streaming_shapes/sse.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/sse.rs @@ -92,6 +92,21 @@ pub struct SseEvent { } impl SseEvent { + /// Create an SseEvent from a ProviderStreamResponseType + /// This is useful for binary frame formats (like Bedrock) that need to be converted to SSE + pub fn from_provider_response(response: ProviderStreamResponseType) -> Self { + // Convert the provider response to SSE format string + let sse_string: String = response.clone().into(); + + SseEvent { + data: None, // Data is embedded in sse_transformed_lines + event: None, // Event type is embedded in sse_transformed_lines + raw_line: sse_string.clone(), + sse_transformed_lines: sse_string, + provider_stream_response: Some(response), + } + } + /// Check if this event represents the end of the stream pub fn is_done(&self) -> bool { self.data == Some("[DONE]".into()) || self.event == Some("message_stop".into()) diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs index 3c964240..344a795f 100644 --- a/crates/hermesllm/src/providers/id.rs +++ b/crates/hermesllm/src/providers/id.rs @@ -130,6 +130,15 @@ impl ProviderId { SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse) } } + (ProviderId::AmazonBedrock, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => { + if is_streaming { + SupportedUpstreamAPIs::AmazonBedrockConverseStream( + AmazonBedrockApi::ConverseStream, + ) + } else { + SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse) + } + } // Non-OpenAI providers: if client requested the Responses API, fall back to Chat Completions (_, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => { diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index b1d88e58..a2494c6d 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -199,6 +199,31 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons })?; Ok(ProviderResponseType::ResponsesAPIResponse(response_api)) } + ( + SupportedUpstreamAPIs::AmazonBedrockConverse(_), + SupportedAPIsFromClient::OpenAIResponsesAPI(_), + ) => { + // Chain transform: Bedrock Converse -> ChatCompletions -> ResponsesAPI + let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + // Transform to ChatCompletions format + let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Bedrock to ChatCompletions transformation error: {}", e), + ) + })?; + + // Transform to ResponsesAPI format + let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("ChatCompletions to ResponsesAPI transformation error: {}", e), + ) + })?; + Ok(ProviderResponseType::ResponsesAPIResponse(response_api)) + } _ => Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "Unsupported API combination for response transformation", diff --git a/crates/hermesllm/src/providers/streaming_response.rs b/crates/hermesllm/src/providers/streaming_response.rs index 7707d88d..55e52f3d 100644 --- a/crates/hermesllm/src/providers/streaming_response.rs +++ b/crates/hermesllm/src/providers/streaming_response.rs @@ -145,6 +145,7 @@ impl ProviderStreamResponse for ProviderStreamResponseType { ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => resp.event_type(), } } + } impl Into for ProviderStreamResponseType { @@ -259,6 +260,19 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for Prov anthropic_resp, )) } + ( + SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), + SupportedAPIsFromClient::OpenAIResponsesAPI(_), + ) => { + // Chain: Bedrock -> ChatCompletions -> ResponsesAPI + let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent = + serde_json::from_slice(bytes)?; + let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = bedrock_resp.try_into()?; + let responses_resp = chat_resp.try_into()?; + Ok(ProviderStreamResponseType::ResponseAPIStreamEvent( + responses_resp, + )) + } _ => Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "Unsupported API combination for response transformation", @@ -400,6 +414,22 @@ impl Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse( openai_event, )) + } + ( + SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), + SupportedAPIsFromClient::OpenAIResponsesAPI(_), + ) => { + // Parse the DecodedFrame into ConverseStreamEvent + let bedrock_event = + crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?; + let openai_chat_completions_event: crate::apis::openai::ChatCompletionsStreamResponse = + bedrock_event.try_into()?; + let openai_responses_api_event: crate::apis::openai_responses::ResponsesAPIStreamEvent = + openai_chat_completions_event.try_into()?; + + Ok(ProviderStreamResponseType::ResponseAPIStreamEvent( + openai_responses_api_event, + )) } _ => Err("Unsupported API combination for event-stream decoding".into()), } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 9541ae45..fbab980c 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -22,13 +22,13 @@ use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; -use hermesllm::apis::anthropic::{MessagesContentBlock, MessagesStreamEvent}; use hermesllm::apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; use hermesllm::apis::streaming_shapes::sse::{ SseEvent, SseStreamBuffer, SseStreamBufferTrait, SseStreamIter, }; use hermesllm::clients::endpoints::SupportedAPIsFromClient; use hermesllm::providers::response::ProviderResponse; +use hermesllm::providers::streaming_response::ProviderStreamResponse; use hermesllm::{ DecodedFrame, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType, ProviderStreamResponseType, @@ -575,83 +575,57 @@ impl StreamContext { self.binary_frame_decoder = Some(BedrockBinaryFrameDecoder::from_bytes(&[])); } - // Add incoming bytes to buffer + // Initialize SSE buffer if not present + if self.sse_buffer.is_none() { + self.sse_buffer = match SseStreamBuffer::try_from((client_api, upstream_api)) { + Ok(buffer) => Some(buffer), + Err(e) => { + warn!( + "[ARCHGW_REQ_ID:{}] BEDROCK_BUFFER_INIT_ERROR: {}", + self.request_identifier(), + e + ); + return Err(Action::Continue); + } + }; + } + + // Add incoming bytes to decoder buffer let decoder = self.binary_frame_decoder.as_mut().unwrap(); decoder.buffer_mut().extend_from_slice(body); - let mut response_buffer = Vec::new(); + // Process all complete frames loop { let decoded_frame = self.binary_frame_decoder.as_mut().unwrap().decode_frame(); match decoded_frame { Some(DecodedFrame::Complete(ref frame_ref)) => { let frame = DecodedFrame::Complete(frame_ref.clone()); + + // Convert frame to provider response type match ProviderStreamResponseType::try_from((&frame, client_api, upstream_api)) { Ok(provider_response) => { self.record_ttft_if_needed(); - // Handle ContentBlockStart and ContentBlockDelta events - match &provider_response { - ProviderStreamResponseType::MessagesStreamEvent(evt) => { - match evt { - MessagesStreamEvent::ContentBlockStart { - index, .. - } => { - // Mark that we've seen ContentBlockStart for this index - self.binary_frame_decoder - .as_mut() - .unwrap() - .set_content_block_start_sent(*index as i32); - debug!( - "[ARCHGW_REQ_ID:{}] BEDROCK_CONTENT_BLOCK_START_TRACKED: index={}", - self.request_identifier(), - *index - ); - } - MessagesStreamEvent::ContentBlockDelta { - index, .. - } => { - // Check if ContentBlockStart was sent for this index - let needs_start = !self - .binary_frame_decoder - .as_ref() - .unwrap() - .has_content_block_start_been_sent(*index as i32); - - if needs_start { - // Emit empty ContentBlockStart before delta - let content_block_start = - MessagesStreamEvent::ContentBlockStart { - index: *index, - content_block: MessagesContentBlock::Text { - text: String::new(), - cache_control: None, - }, - }; - let start_sse: String = content_block_start.into(); - response_buffer - .extend_from_slice(start_sse.as_bytes()); - - // Mark that we've now sent it - self.binary_frame_decoder - .as_mut() - .unwrap() - .set_content_block_start_sent(*index as i32); - - debug!( - "[ARCHGW_REQ_ID:{}] BEDROCK_INJECTED_CONTENT_BLOCK_START: index={}", - self.request_identifier(), - *index - ); - } - } - _ => {} - } - } - _ => {} + // Track token usage + if let Some(content) = provider_response.content_delta() { + let estimated_tokens = content.len() / 4; + self.response_tokens += estimated_tokens.max(1); + debug!( + "[ARCHGW_REQ_ID:{}] BEDROCK_TOKEN_UPDATE: delta_chars={} estimated_tokens={} total_tokens={}", + self.request_identifier(), + content.len(), + estimated_tokens.max(1), + self.response_tokens + ); } - let sse_string: String = provider_response.into(); - response_buffer.extend_from_slice(sse_string.as_bytes()); + // Create SseEvent from provider response + let event = SseEvent::from_provider_response(provider_response); + + // Add to buffer (buffer handles all shim logic including ContentBlockStart injection) + if let Some(buffer) = self.sse_buffer.as_mut() { + buffer.add_transformed_event(event); + } } Err(e) => { warn!( @@ -681,8 +655,17 @@ impl StreamContext { } } - // Return accumulated complete frames (may be empty if all frames incomplete) - Ok(response_buffer) + // Get accumulated bytes from buffer and return + match self.sse_buffer.as_mut() { + Some(buffer) => Ok(buffer.into_bytes()), + None => { + warn!( + "[ARCHGW_REQ_ID:{}] BEDROCK_BUFFER_MISSING", + self.request_identifier() + ); + Err(Action::Continue) + } + } } fn handle_non_streaming_response( diff --git a/tests/e2e/test_openai_responses_api_client.py b/tests/e2e/test_openai_responses_api_client.py index 2d9fa248..800db93d 100644 --- a/tests/e2e/test_openai_responses_api_client.py +++ b/tests/e2e/test_openai_responses_api_client.py @@ -327,6 +327,158 @@ def test_openai_responses_api_streaming_with_tools_upstream_chat_completions(): ), "Expected streamed text or tool call argument deltas from Responses tools stream" +def test_openai_responses_api_non_streaming_upstream_bedrock(): + """Send a v1/responses request using the coding-model alias to verify Bedrock translation/routing""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + resp = client.responses.create( + model="coding-model", + input="Hello, translate this via coding-model alias to Bedrock", + ) + + # Print the response content - handle both responses format and chat completions format + print(f"\n{'='*80}") + print(f"Model: {resp.model}") + print(f"Output: {resp.output_text}") + print(f"{'='*80}\n") + + assert resp is not None + assert resp.id is not None + + +def test_openai_responses_api_with_streaming_upstream_bedrock(): + """Build a v1/responses API streaming request routed to Bedrock via coding-model alias""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + # Simple streaming responses API request using coding-model alias + stream = client.responses.create( + model="coding-model", + input="Write a short haiku about coding", + stream=True, + ) + + # Collect streamed content using the official Responses API streaming shape + text_chunks = [] + final_message = None + + for event in stream: + # The Python SDK surfaces a high-level Responses streaming interface. + # We rely on its typed helpers instead of digging into model_extra. + if getattr(event, "type", None) == "response.output_text.delta" and getattr( + event, "delta", None + ): + # Each delta contains a text fragment + text_chunks.append(event.delta) + + # Track the final response message if provided by the SDK + if getattr(event, "type", None) == "response.completed" and getattr( + event, "response", None + ): + final_message = event.response + + full_content = "".join(text_chunks) + + # Print the streaming response + print(f"\n{'='*80}") + print( + f"Model: {getattr(final_message, 'model', 'unknown') if final_message else 'unknown'}" + ) + print(f"Streamed Output: {full_content}") + print(f"{'='*80}\n") + + assert len(text_chunks) > 0, "Should have received streaming text deltas" + assert len(full_content) > 0, "Should have received content" + + +def test_openai_responses_api_non_streaming_with_tools_upstream_bedrock(): + """Responses API with tools routed to Bedrock via coding-model alias""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + tools = [ + { + "type": "function", + "name": "echo_tool", + "description": "Echo back the provided input", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ] + + resp = client.responses.create( + model="coding-model", + input="Call the echo tool", + tools=tools, + ) + + assert resp.id is not None + + print(f"\n{'='*80}") + print(f"Model: {resp.model}") + print(f"Output: {resp.output_text}") + print(f"{'='*80}\n") + + +def test_openai_responses_api_streaming_with_tools_upstream_bedrock(): + """Responses API with a function/tool definition streaming to Bedrock via coding-model alias""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0) + + tools = [ + { + "type": "function", + "name": "echo_tool", + "description": "Echo back the provided input", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ] + + stream = client.responses.create( + model="coding-model", + input="Call the echo tool", + tools=tools, + stream=True, + ) + + text_chunks = [] + tool_calls = [] + + for event in stream: + etype = getattr(event, "type", None) + + # Collect streamed text output + if etype == "response.output_text.delta" and getattr(event, "delta", None): + text_chunks.append(event.delta) + + # Collect streamed tool call arguments + if etype == "response.function_call_arguments.delta" and getattr( + event, "delta", None + ): + tool_calls.append(event.delta) + + full_text = "".join(text_chunks) + + print(f"\n{'='*80}") + print("Responses tools streaming test (Bedrock)") + print(f"Streamed text: {full_text}") + print(f"Tool call argument chunks: {len(tool_calls)}") + print(f"{'='*80}\n") + + # We expect either streamed text output or streamed tool-call arguments + assert ( + full_text or tool_calls + ), "Expected streamed text or tool call argument deltas from Responses tools stream" + + def test_openai_responses_api_non_streaming_upstream_anthropic(): """Send a v1/responses request using the grok alias to verify translation/routing""" base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")