diff --git a/crates/hermesllm/src/apis/amazon_bedrock.rs b/crates/hermesllm/src/apis/amazon_bedrock.rs index 0c4eb262..096c84c8 100644 --- a/crates/hermesllm/src/apis/amazon_bedrock.rs +++ b/crates/hermesllm/src/apis/amazon_bedrock.rs @@ -693,16 +693,22 @@ pub struct ContentBlockStartEvent { /// Content block start information #[derive(Serialize, Deserialize, Debug, Clone)] -#[serde(tag = "type")] +#[serde(untagged)] pub enum ContentBlockStart { - #[serde(rename = "toolUse")] ToolUse { - #[serde(rename = "toolUseId")] - tool_use_id: String, - name: String, + #[serde(rename = "toolUse")] + tool_use: ToolUseStart, }, } +/// Tool use start information +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ToolUseStart { + #[serde(rename = "toolUseId")] + pub tool_use_id: String, + pub name: String, +} + /// Content block delta event #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ContentBlockDeltaEvent { @@ -718,7 +724,15 @@ pub struct ContentBlockDeltaEvent { #[serde(untagged)] pub enum ContentBlockDelta { Text { text: String }, - ToolUse { input: String }, + ToolUse { + #[serde(rename = "toolUse")] + tool_use: ToolUseDelta + }, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ToolUseDelta { + pub input: String, } /// Content block stop event diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index 3b14d0b8..61b38f4b 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -478,7 +478,7 @@ where { decoder: aws_smithy_eventstream::frame::MessageFrameDecoder, buffer: B, - has_content_block_start_been_sent: bool, + content_block_start_indices: std::collections::HashSet, } impl BedrockBinaryFrameDecoder { @@ -488,7 +488,7 @@ impl BedrockBinaryFrameDecoder { Self { decoder: aws_smithy_eventstream::frame::MessageFrameDecoder::new(), buffer, - has_content_block_start_been_sent: false, + content_block_start_indices: std::collections::HashSet::new(), } } } @@ -501,7 +501,7 @@ where Self { decoder: aws_smithy_eventstream::frame::MessageFrameDecoder::new(), buffer, - has_content_block_start_been_sent: false, + content_block_start_indices: std::collections::HashSet::new(), } } @@ -521,14 +521,14 @@ where self.buffer.has_remaining() } - /// Check if a content_block_start event has been sent - pub fn has_content_block_start_been_sent(&self) -> bool { - self.has_content_block_start_been_sent + /// Check if a content_block_start event has been sent for the given index + pub fn has_content_block_start_been_sent(&self, index: i32) -> bool { + self.content_block_start_indices.contains(&index) } - /// Set the content_block_start flag - pub fn set_content_block_start_sent(&mut self, sent: bool) { - self.has_content_block_start_been_sent = sent; + /// Mark that a content_block_start event has been sent for the given index + pub fn set_content_block_start_sent(&mut self, index: i32) { + self.content_block_start_indices.insert(index); } } @@ -1122,6 +1122,17 @@ mod tests { test_bedrock_conversion(true); } + #[test] + fn test_bedrock_decoded_frame_with_tool_use() { + test_bedrock_conversion_with_tools(false); + } + + #[test] + #[ignore] // Run with: cargo test -- --ignored --nocapture + fn test_bedrock_decoded_frame_with_tool_use_verbose() { + test_bedrock_conversion_with_tools(true); + } + fn test_bedrock_conversion(verbose: bool) { use bytes::BytesMut; use std::fs; @@ -1194,6 +1205,93 @@ mod tests { assert!(message_start_seen, "Should have seen MessageStart event"); } + fn test_bedrock_conversion_with_tools(verbose: bool) { + use bytes::BytesMut; + use std::fs; + use std::path::PathBuf; + + // Read the actual response_with_tools.hex file from tests/e2e directory + let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("../../tests/e2e/response_with_tools.hex"); + + // Only run this test if the file exists + if !test_file.exists() { + println!("Skipping test - response_with_tools.hex not found"); + return; + } + + let response_data = fs::read(&test_file).unwrap(); + let mut buffer = BytesMut::from(&response_data[..]); + + let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); + + let client_api = SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); + let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream); + + let mut conversion_count = 0; + let mut message_start_seen = false; + let mut content_block_start_seen = false; + let mut content_block_delta_tool_use_seen = false; + + // Decode and convert frames + loop { + match decoder.decode_frame() { + Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { + // Convert DecodedFrame to ProviderStreamResponseType + let result = ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api)); + + match result { + Ok(provider_response) => { + conversion_count += 1; + + // Verify we got a MessagesStreamEvent + assert!(matches!(provider_response, ProviderStreamResponseType::MessagesStreamEvent(_))); + + if verbose { + // Print the SSE string output + let sse_string: String = provider_response.clone().into(); + println!("{}", sse_string); + } + + // Check for specific events related to tool use + if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = provider_response { + match event { + crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. } => { + message_start_seen = true; + } + crate::apis::anthropic::MessagesStreamEvent::ContentBlockStart { .. } => { + content_block_start_seen = true; + } + crate::apis::anthropic::MessagesStreamEvent::ContentBlockDelta { delta, .. } => { + if matches!(delta, crate::apis::anthropic::MessagesContentDelta::InputJsonDelta { .. }) { + content_block_delta_tool_use_seen = true; + } + } + _ => {} + } + } + } + Err(e) => { + println!("Conversion error (frame {}): {}", conversion_count, e); + } + } + } + Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => { + // End of buffer + break; + } + None => { + panic!("Decode error"); + } + } + } + + assert!(conversion_count > 0, "Should have converted at least one frame"); + assert!(message_start_seen, "Should have seen MessageStart event"); + assert!(content_block_start_seen, "Should have seen ContentBlockStart event for tool use"); + assert!(content_block_delta_tool_use_seen, "Should have seen ContentBlockDelta with ToolUseDelta"); + } + #[test] fn test_sse_event_transformation_openai_to_anthropic_message_start() { use crate::apis::openai::OpenAIApi; diff --git a/crates/hermesllm/src/transforms/response/to_anthropic.rs b/crates/hermesllm/src/transforms/response/to_anthropic.rs index c19debc6..2076313f 100644 --- a/crates/hermesllm/src/transforms/response/to_anthropic.rs +++ b/crates/hermesllm/src/transforms/response/to_anthropic.rs @@ -261,12 +261,12 @@ impl TryFrom for MessagesStreamEvent { // Note: Bedrock sends tool_use_id and name at start, with input coming in subsequent deltas // Anthropic expects the same pattern, so we initialize with an empty input object match start_event.start { - crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use_id, name } => { + crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => { Ok(MessagesStreamEvent::ContentBlockStart { index: start_event.content_block_index as u32, content_block: MessagesContentBlock::ToolUse { - id: tool_use_id, - name, + id: tool_use.tool_use_id, + name: tool_use.name, input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas cache_control: None, }, @@ -281,8 +281,8 @@ impl TryFrom for MessagesStreamEvent { ContentBlockDelta::Text { text } => { MessagesContentDelta::TextDelta { text } } - ContentBlockDelta::ToolUse { input } => { - MessagesContentDelta::InputJsonDelta { partial_json: input } + ContentBlockDelta::ToolUse { tool_use } => { + MessagesContentDelta::InputJsonDelta { partial_json: tool_use.input } } }; diff --git a/crates/hermesllm/src/transforms/response/to_openai.rs b/crates/hermesllm/src/transforms/response/to_openai.rs index 974a9339..8bf6896b 100644 --- a/crates/hermesllm/src/transforms/response/to_openai.rs +++ b/crates/hermesllm/src/transforms/response/to_openai.rs @@ -280,7 +280,7 @@ impl TryFrom for ChatCompletionsStreamResponse { use crate::apis::amazon_bedrock::ContentBlockStart; match start_event.start { - ContentBlockStart::ToolUse { tool_use_id, name } => { + ContentBlockStart::ToolUse { tool_use } => { Ok(create_openai_chunk( "stream", "unknown", @@ -291,10 +291,10 @@ impl TryFrom for ChatCompletionsStreamResponse { function_call: None, tool_calls: Some(vec![ToolCallDelta { index: start_event.content_block_index as u32, - id: Some(tool_use_id), + id: Some(tool_use.tool_use_id), call_type: Some("function".to_string()), function: Some(FunctionCallDelta { - name: Some(name), + name: Some(tool_use.name), arguments: Some("".to_string()), }), }]), @@ -325,7 +325,7 @@ impl TryFrom for ChatCompletionsStreamResponse { None, )) } - ContentBlockDelta::ToolUse { input } => { + ContentBlockDelta::ToolUse { tool_use } => { Ok(create_openai_chunk( "stream", "unknown", @@ -340,7 +340,7 @@ impl TryFrom for ChatCompletionsStreamResponse { call_type: None, function: Some(FunctionCallDelta { name: None, - arguments: Some(input), + arguments: Some(tool_use.input), }), }]), }, diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index f9a699a0..43151a7f 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,4 +1,3 @@ -use bytes::Buf; use hermesllm::clients::endpoints::SupportedUpstreamAPIs; use http::StatusCode; use log::{debug, info, warn}; @@ -25,9 +24,11 @@ use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; use hermesllm::clients::endpoints::SupportedAPIs; use hermesllm::providers::response::{ - BedrockBinaryFrameDecoder, ProviderResponse, SseEvent, SseStreamIter, + BedrockBinaryFrameDecoder, ProviderResponse, ProviderStreamResponse, SseEvent, SseStreamIter, +}; +use hermesllm::{ + DecodedFrame, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType, }; -use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType}; pub struct StreamContext { metrics: Rc, @@ -424,6 +425,14 @@ impl StreamContext { let upstream_api = provider_id.compatible_api_for_client(&client_api, self.streaming_response); + // Check if this is Bedrock binary stream + if matches!( + upstream_api, + SupportedUpstreamAPIs::AmazonBedrockConverseStream(_) + ) { + return self.handle_bedrock_binary_stream(body, &client_api, &upstream_api); + } + // Parse body into SSE iterator using TryFrom let sse_iter: SseStreamIter> = match SseStreamIter::try_from(body) { @@ -499,6 +508,157 @@ impl StreamContext { } } + fn handle_bedrock_binary_stream( + &mut self, + body: &[u8], + client_api: &SupportedAPIs, + upstream_api: &SupportedUpstreamAPIs, + ) -> Result, Action> { + use hermesllm::providers::response::ProviderStreamResponseType; + + // Initialize decoder if not present + if self.binary_frame_decoder.is_none() { + self.binary_frame_decoder = Some(BedrockBinaryFrameDecoder::from_bytes(&[])); + } + + // Add incoming bytes to buffer + if let Some(decoder) = self.binary_frame_decoder.as_mut() { + decoder.buffer_mut().extend_from_slice(body); + } + + let mut response_buffer = Vec::new(); + + // Decode all available complete frames + loop { + let decoded_frame = self.binary_frame_decoder.as_mut().unwrap().decode_frame(); + match decoded_frame { + Some(DecodedFrame::Complete(ref frame_ref)) => { + // Convert frame to ProviderStreamResponseType + let frame = DecodedFrame::Complete(frame_ref.clone()); + match ProviderStreamResponseType::try_from((&frame, client_api, upstream_api)) { + Ok(provider_response) => { + self.record_ttft_if_needed(); + + // Extract index from the event if available + let event_index = + if let ProviderStreamResponseType::MessagesStreamEvent(ref evt) = + provider_response + { + use hermesllm::apis::anthropic::MessagesStreamEvent; + match evt { + MessagesStreamEvent::ContentBlockStart { + index, .. + } => Some(*index as i32), + MessagesStreamEvent::ContentBlockDelta { + index, .. + } => Some(*index as i32), + MessagesStreamEvent::ContentBlockStop { index, .. } => { + Some(*index as i32) + } + _ => None, + } + } else { + None + }; + + // Check event type to track ContentBlockStart + if let Some(event_type) = provider_response.event_type() { + match event_type { + "content_block_start" => { + // Mark that we've seen ContentBlockStart for this index + if let (Some(decoder), Some(index)) = + (self.binary_frame_decoder.as_mut(), event_index) + { + decoder.set_content_block_start_sent(index); + debug!( + "[ARCHGW_REQ_ID:{}] BEDROCK_CONTENT_BLOCK_START_TRACKED: index={}", + self.request_identifier(), + index + ); + } + } + "content_block_delta" => { + // Check if ContentBlockStart was sent for this index + if let Some(index) = event_index { + let needs_start = if let Some(decoder) = + self.binary_frame_decoder.as_ref() + { + !decoder.has_content_block_start_been_sent(index) + } else { + false + }; + + if needs_start { + // Emit empty ContentBlockStart before delta + use hermesllm::apis::anthropic::{ + MessagesContentBlock, MessagesStreamEvent, + }; + let content_block_start = + MessagesStreamEvent::ContentBlockStart { + index: index as u32, + content_block: MessagesContentBlock::Text { + text: String::new(), + cache_control: None, + }, + }; + let start_sse: String = content_block_start.into(); + response_buffer + .extend_from_slice(start_sse.as_bytes()); + + // Mark that we've now sent it + if let Some(decoder) = + self.binary_frame_decoder.as_mut() + { + decoder.set_content_block_start_sent(index); + } + + debug!( + "[ARCHGW_REQ_ID:{}] BEDROCK_INJECTED_CONTENT_BLOCK_START: index={}", + self.request_identifier(), + index + ); + } + } + } + _ => {} + } + } + + let sse_string: String = provider_response.into(); + response_buffer.extend_from_slice(sse_string.as_bytes()); + } + Err(e) => { + warn!( + "[ARCHGW_REQ_ID:{}] BEDROCK_FRAME_CONVERSION_ERROR: {}", + self.request_identifier(), + e + ); + } + } + } + Some(DecodedFrame::Incomplete) => { + // Incomplete frame - buffer retains partial data, wait for more bytes + debug!( + "[ARCHGW_REQ_ID:{}] BEDROCK_INCOMPLETE_FRAME: waiting for more data", + self.request_identifier() + ); + break; + } + None => { + // Decode error + warn!( + "[ARCHGW_REQ_ID:{}] BEDROCK_DECODE_ERROR", + self.request_identifier() + ); + return Err(Action::Continue); + } + } + } + + // Return accumulated complete frames (may be empty if all frames incomplete) + Ok(response_buffer) + } + fn handle_non_streaming_response( &mut self, body: &[u8], diff --git a/demos/use_cases/claude_code_router/config.yaml b/demos/use_cases/claude_code_router/config.yaml index 11a98c07..2a727c2a 100644 --- a/demos/use_cases/claude_code_router/config.yaml +++ b/demos/use_cases/claude_code_router/config.yaml @@ -9,8 +9,10 @@ listeners: llm_providers: # OpenAI Models - - model: openai/gpt-5-2025-08-07 - access_key: $OPENAI_API_KEY + + - model: amazon_bedrock/us.amazon.nova-premier-v1:0 + access_key: $AWS_BEARER_TOKEN_BEDROCK + base_url: https://bedrock-runtime.us-west-2.amazonaws.com routing_preferences: - name: code generation description: generating new code snippets, functions, or boilerplate based on user prompts or requirements @@ -26,7 +28,7 @@ llm_providers: default: true access_key: $ANTHROPIC_API_KEY - - model: anthropic/claude-3-haiku-20240307 + - model: anthropic/claude-haiku-4-5-20251001 access_key: $ANTHROPIC_API_KEY # Ollama Models @@ -38,4 +40,4 @@ llm_providers: model_aliases: # Alias for a small faster Claude model arch.claude.code.small.fast: - target: claude-3-haiku-20240307 + target: claude-haiku-4-5-20251001 diff --git a/tests/e2e/response_with_tools.hex b/tests/e2e/response_with_tools.hex new file mode 100644 index 00000000..5aa41165 Binary files /dev/null and b/tests/e2e/response_with_tools.hex differ diff --git a/tests/e2e/test_model_alias_routing.py b/tests/e2e/test_model_alias_routing.py index 6539deb6..c285bda8 100644 --- a/tests/e2e/test_model_alias_routing.py +++ b/tests/e2e/test_model_alias_routing.py @@ -499,3 +499,138 @@ def test_anthropic_client_with_coding_model_alias_and_tools(): # Should get either text response or tool use blocks for coding assistance assert text_content or len(tool_use_blocks) > 0 + + +@pytest.mark.flaky(retries=0) # Disable retries to see the actual failure +def test_anthropic_client_with_coding_model_alias_and_tools_streaming(): + """Test Anthropic client using 'coding-model' alias (maps to Bedrock) with coding question and tools - streaming""" + logger.info( + "Testing Anthropic client with 'coding-model' alias -> Bedrock with tools (streaming)" + ) + + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = anthropic.Anthropic(api_key="test-key", base_url=base_url) + + text_chunks = [] + tool_use_blocks = [] + all_events = [] # Capture all events for debugging + + try: + with client.messages.stream( + model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0 + max_tokens=1000, + messages=[ + { + "role": "user", + "content": "I need to write a Python function that calculates the factorial of a number. Can you help me write and run it?", + } + ], + tools=[ + { + "name": "run_python_code", + "description": "Execute Python code and return the result", + "input_schema": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute", + } + }, + "required": ["code"], + }, + } + ], + tool_choice={"type": "auto"}, + ) as stream: + for event in stream: + # Extract index if available + index = getattr(event, "index", None) + + # Log and capture all events for debugging + all_events.append( + {"type": event.type, "index": index, "event": str(event)[:200]} + ) + logger.info(f"Event #{len(all_events)}: {event.type} [index={index}]") + + # Collect text deltas + if event.type == "content_block_delta" and hasattr(event, "delta"): + if event.delta.type == "text_delta": + text_chunks.append(event.delta.text) + + # Collect tool use blocks + if event.type == "content_block_start" and hasattr( + event, "content_block" + ): + if event.content_block.type == "tool_use": + tool_use_blocks.append(event.content_block) + + final_message = stream.get_final_message() + except Exception as e: + logger.error(f"Exception during streaming: {type(e).__name__}: {e}") + logger.error(f"Events received before error: {len(all_events)}") + logger.error(f"Text chunks collected: {len(text_chunks)}") + logger.error(f"Tool use blocks collected: {len(tool_use_blocks)}") + logger.error("\nLast 20 events before crash:") + for evt in all_events[-20:]: + logger.error(f" {evt['type']:30s} index={evt['index']}") + raise + + full_text = "".join(text_chunks) + logger.info(f"Streaming response from coding-model with tools: {full_text}") + logger.info(f"Total events received: {len(all_events)}") + logger.info( + f"Text chunks: {len(text_chunks)}, Tool use blocks: {len(tool_use_blocks)}" + ) + + # Should get either text response or tool use blocks for coding assistance + # Modified assertion to be more lenient and provide better error messages + assert ( + full_text or len(tool_use_blocks) > 0 + ), f"Expected text or tool use. Got text_len={len(full_text)}, tools={len(tool_use_blocks)}, events={len(all_events)}" + + # Verify final message structure + assert final_message is not None, "Final message should not be None" + assert ( + final_message.content and len(final_message.content) > 0 + ), f"Final message should have content. Got: {final_message.content if final_message else 'None'}" + + +def test_anthropic_client_streaming_with_bedrock(): + """Test Anthropic client using 'coding-model' alias (maps to Bedrock) with streaming""" + logger.info( + "Testing Anthropic client with 'coding-model' alias -> Bedrock (streaming)" + ) + + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = anthropic.Anthropic(api_key="test-key", base_url=base_url) + + text_chunks = [] + + with client.messages.stream( + model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0 + max_tokens=500, + messages=[ + { + "role": "user", + "content": "Write a short 4-line sonnet about coding.", + } + ], + ) as stream: + for event in stream: + # Collect text deltas + if event.type == "content_block_delta" and hasattr(event, "delta"): + if event.delta.type == "text_delta": + text_chunks.append(event.delta.text) + + final_message = stream.get_final_message() + + full_text = "".join(text_chunks) + logger.info(f"Response: {full_text}") + + # Should get a text response + assert len(full_text) > 0, "Expected text response from streaming" + + # Verify final message structure + assert final_message is not None + assert final_message.content and len(final_message.content) > 0