Claude Code works with Amazon Bedrock

This commit is contained in:
Salman Paracha 2025-10-20 13:48:26 -07:00
parent d826de382a
commit db44602cb8
8 changed files with 441 additions and 32 deletions

View file

@ -693,16 +693,22 @@ pub struct ContentBlockStartEvent {
/// Content block start information
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
#[serde(untagged)]
pub enum ContentBlockStart {
#[serde(rename = "toolUse")]
ToolUse {
#[serde(rename = "toolUseId")]
tool_use_id: String,
name: String,
#[serde(rename = "toolUse")]
tool_use: ToolUseStart,
},
}
/// Tool use start information
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ToolUseStart {
#[serde(rename = "toolUseId")]
pub tool_use_id: String,
pub name: String,
}
/// Content block delta event
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ContentBlockDeltaEvent {
@ -718,7 +724,15 @@ pub struct ContentBlockDeltaEvent {
#[serde(untagged)]
pub enum ContentBlockDelta {
Text { text: String },
ToolUse { input: String },
ToolUse {
#[serde(rename = "toolUse")]
tool_use: ToolUseDelta
},
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ToolUseDelta {
pub input: String,
}
/// Content block stop event

View file

@ -478,7 +478,7 @@ where
{
decoder: aws_smithy_eventstream::frame::MessageFrameDecoder,
buffer: B,
has_content_block_start_been_sent: bool,
content_block_start_indices: std::collections::HashSet<i32>,
}
impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
@ -488,7 +488,7 @@ impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
Self {
decoder: aws_smithy_eventstream::frame::MessageFrameDecoder::new(),
buffer,
has_content_block_start_been_sent: false,
content_block_start_indices: std::collections::HashSet::new(),
}
}
}
@ -501,7 +501,7 @@ where
Self {
decoder: aws_smithy_eventstream::frame::MessageFrameDecoder::new(),
buffer,
has_content_block_start_been_sent: false,
content_block_start_indices: std::collections::HashSet::new(),
}
}
@ -521,14 +521,14 @@ where
self.buffer.has_remaining()
}
/// Check if a content_block_start event has been sent
pub fn has_content_block_start_been_sent(&self) -> bool {
self.has_content_block_start_been_sent
/// Check if a content_block_start event has been sent for the given index
pub fn has_content_block_start_been_sent(&self, index: i32) -> bool {
self.content_block_start_indices.contains(&index)
}
/// Set the content_block_start flag
pub fn set_content_block_start_sent(&mut self, sent: bool) {
self.has_content_block_start_been_sent = sent;
/// Mark that a content_block_start event has been sent for the given index
pub fn set_content_block_start_sent(&mut self, index: i32) {
self.content_block_start_indices.insert(index);
}
}
@ -1122,6 +1122,17 @@ mod tests {
test_bedrock_conversion(true);
}
#[test]
fn test_bedrock_decoded_frame_with_tool_use() {
test_bedrock_conversion_with_tools(false);
}
#[test]
#[ignore] // Run with: cargo test -- --ignored --nocapture
fn test_bedrock_decoded_frame_with_tool_use_verbose() {
test_bedrock_conversion_with_tools(true);
}
fn test_bedrock_conversion(verbose: bool) {
use bytes::BytesMut;
use std::fs;
@ -1194,6 +1205,93 @@ mod tests {
assert!(message_start_seen, "Should have seen MessageStart event");
}
fn test_bedrock_conversion_with_tools(verbose: bool) {
use bytes::BytesMut;
use std::fs;
use std::path::PathBuf;
// Read the actual response_with_tools.hex file from tests/e2e directory
let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("../../tests/e2e/response_with_tools.hex");
// Only run this test if the file exists
if !test_file.exists() {
println!("Skipping test - response_with_tools.hex not found");
return;
}
let response_data = fs::read(&test_file).unwrap();
let mut buffer = BytesMut::from(&response_data[..]);
let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
let client_api = SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream);
let mut conversion_count = 0;
let mut message_start_seen = false;
let mut content_block_start_seen = false;
let mut content_block_delta_tool_use_seen = false;
// Decode and convert frames
loop {
match decoder.decode_frame() {
Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => {
// Convert DecodedFrame to ProviderStreamResponseType
let result = ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api));
match result {
Ok(provider_response) => {
conversion_count += 1;
// Verify we got a MessagesStreamEvent
assert!(matches!(provider_response, ProviderStreamResponseType::MessagesStreamEvent(_)));
if verbose {
// Print the SSE string output
let sse_string: String = provider_response.clone().into();
println!("{}", sse_string);
}
// Check for specific events related to tool use
if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = provider_response {
match event {
crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. } => {
message_start_seen = true;
}
crate::apis::anthropic::MessagesStreamEvent::ContentBlockStart { .. } => {
content_block_start_seen = true;
}
crate::apis::anthropic::MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
if matches!(delta, crate::apis::anthropic::MessagesContentDelta::InputJsonDelta { .. }) {
content_block_delta_tool_use_seen = true;
}
}
_ => {}
}
}
}
Err(e) => {
println!("Conversion error (frame {}): {}", conversion_count, e);
}
}
}
Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => {
// End of buffer
break;
}
None => {
panic!("Decode error");
}
}
}
assert!(conversion_count > 0, "Should have converted at least one frame");
assert!(message_start_seen, "Should have seen MessageStart event");
assert!(content_block_start_seen, "Should have seen ContentBlockStart event for tool use");
assert!(content_block_delta_tool_use_seen, "Should have seen ContentBlockDelta with ToolUseDelta");
}
#[test]
fn test_sse_event_transformation_openai_to_anthropic_message_start() {
use crate::apis::openai::OpenAIApi;

View file

@ -261,12 +261,12 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
// Note: Bedrock sends tool_use_id and name at start, with input coming in subsequent deltas
// Anthropic expects the same pattern, so we initialize with an empty input object
match start_event.start {
crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use_id, name } => {
crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => {
Ok(MessagesStreamEvent::ContentBlockStart {
index: start_event.content_block_index as u32,
content_block: MessagesContentBlock::ToolUse {
id: tool_use_id,
name,
id: tool_use.tool_use_id,
name: tool_use.name,
input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas
cache_control: None,
},
@ -281,8 +281,8 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
ContentBlockDelta::Text { text } => {
MessagesContentDelta::TextDelta { text }
}
ContentBlockDelta::ToolUse { input } => {
MessagesContentDelta::InputJsonDelta { partial_json: input }
ContentBlockDelta::ToolUse { tool_use } => {
MessagesContentDelta::InputJsonDelta { partial_json: tool_use.input }
}
};

View file

@ -280,7 +280,7 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
use crate::apis::amazon_bedrock::ContentBlockStart;
match start_event.start {
ContentBlockStart::ToolUse { tool_use_id, name } => {
ContentBlockStart::ToolUse { tool_use } => {
Ok(create_openai_chunk(
"stream",
"unknown",
@ -291,10 +291,10 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: start_event.content_block_index as u32,
id: Some(tool_use_id),
id: Some(tool_use.tool_use_id),
call_type: Some("function".to_string()),
function: Some(FunctionCallDelta {
name: Some(name),
name: Some(tool_use.name),
arguments: Some("".to_string()),
}),
}]),
@ -325,7 +325,7 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
None,
))
}
ContentBlockDelta::ToolUse { input } => {
ContentBlockDelta::ToolUse { tool_use } => {
Ok(create_openai_chunk(
"stream",
"unknown",
@ -340,7 +340,7 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
call_type: None,
function: Some(FunctionCallDelta {
name: None,
arguments: Some(input),
arguments: Some(tool_use.input),
}),
}]),
},

View file

@ -1,4 +1,3 @@
use bytes::Buf;
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use http::StatusCode;
use log::{debug, info, warn};
@ -25,9 +24,11 @@ use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::clients::endpoints::SupportedAPIs;
use hermesllm::providers::response::{
BedrockBinaryFrameDecoder, ProviderResponse, SseEvent, SseStreamIter,
BedrockBinaryFrameDecoder, ProviderResponse, ProviderStreamResponse, SseEvent, SseStreamIter,
};
use hermesllm::{
DecodedFrame, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType,
};
use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType};
pub struct StreamContext {
metrics: Rc<Metrics>,
@ -424,6 +425,14 @@ impl StreamContext {
let upstream_api =
provider_id.compatible_api_for_client(&client_api, self.streaming_response);
// Check if this is Bedrock binary stream
if matches!(
upstream_api,
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)
) {
return self.handle_bedrock_binary_stream(body, &client_api, &upstream_api);
}
// Parse body into SSE iterator using TryFrom
let sse_iter: SseStreamIter<std::vec::IntoIter<String>> =
match SseStreamIter::try_from(body) {
@ -499,6 +508,157 @@ impl StreamContext {
}
}
fn handle_bedrock_binary_stream(
&mut self,
body: &[u8],
client_api: &SupportedAPIs,
upstream_api: &SupportedUpstreamAPIs,
) -> Result<Vec<u8>, Action> {
use hermesllm::providers::response::ProviderStreamResponseType;
// Initialize decoder if not present
if self.binary_frame_decoder.is_none() {
self.binary_frame_decoder = Some(BedrockBinaryFrameDecoder::from_bytes(&[]));
}
// Add incoming bytes to buffer
if let Some(decoder) = self.binary_frame_decoder.as_mut() {
decoder.buffer_mut().extend_from_slice(body);
}
let mut response_buffer = Vec::new();
// Decode all available complete frames
loop {
let decoded_frame = self.binary_frame_decoder.as_mut().unwrap().decode_frame();
match decoded_frame {
Some(DecodedFrame::Complete(ref frame_ref)) => {
// Convert frame to ProviderStreamResponseType
let frame = DecodedFrame::Complete(frame_ref.clone());
match ProviderStreamResponseType::try_from((&frame, client_api, upstream_api)) {
Ok(provider_response) => {
self.record_ttft_if_needed();
// Extract index from the event if available
let event_index =
if let ProviderStreamResponseType::MessagesStreamEvent(ref evt) =
provider_response
{
use hermesllm::apis::anthropic::MessagesStreamEvent;
match evt {
MessagesStreamEvent::ContentBlockStart {
index, ..
} => Some(*index as i32),
MessagesStreamEvent::ContentBlockDelta {
index, ..
} => Some(*index as i32),
MessagesStreamEvent::ContentBlockStop { index, .. } => {
Some(*index as i32)
}
_ => None,
}
} else {
None
};
// Check event type to track ContentBlockStart
if let Some(event_type) = provider_response.event_type() {
match event_type {
"content_block_start" => {
// Mark that we've seen ContentBlockStart for this index
if let (Some(decoder), Some(index)) =
(self.binary_frame_decoder.as_mut(), event_index)
{
decoder.set_content_block_start_sent(index);
debug!(
"[ARCHGW_REQ_ID:{}] BEDROCK_CONTENT_BLOCK_START_TRACKED: index={}",
self.request_identifier(),
index
);
}
}
"content_block_delta" => {
// Check if ContentBlockStart was sent for this index
if let Some(index) = event_index {
let needs_start = if let Some(decoder) =
self.binary_frame_decoder.as_ref()
{
!decoder.has_content_block_start_been_sent(index)
} else {
false
};
if needs_start {
// Emit empty ContentBlockStart before delta
use hermesllm::apis::anthropic::{
MessagesContentBlock, MessagesStreamEvent,
};
let content_block_start =
MessagesStreamEvent::ContentBlockStart {
index: index as u32,
content_block: MessagesContentBlock::Text {
text: String::new(),
cache_control: None,
},
};
let start_sse: String = content_block_start.into();
response_buffer
.extend_from_slice(start_sse.as_bytes());
// Mark that we've now sent it
if let Some(decoder) =
self.binary_frame_decoder.as_mut()
{
decoder.set_content_block_start_sent(index);
}
debug!(
"[ARCHGW_REQ_ID:{}] BEDROCK_INJECTED_CONTENT_BLOCK_START: index={}",
self.request_identifier(),
index
);
}
}
}
_ => {}
}
}
let sse_string: String = provider_response.into();
response_buffer.extend_from_slice(sse_string.as_bytes());
}
Err(e) => {
warn!(
"[ARCHGW_REQ_ID:{}] BEDROCK_FRAME_CONVERSION_ERROR: {}",
self.request_identifier(),
e
);
}
}
}
Some(DecodedFrame::Incomplete) => {
// Incomplete frame - buffer retains partial data, wait for more bytes
debug!(
"[ARCHGW_REQ_ID:{}] BEDROCK_INCOMPLETE_FRAME: waiting for more data",
self.request_identifier()
);
break;
}
None => {
// Decode error
warn!(
"[ARCHGW_REQ_ID:{}] BEDROCK_DECODE_ERROR",
self.request_identifier()
);
return Err(Action::Continue);
}
}
}
// Return accumulated complete frames (may be empty if all frames incomplete)
Ok(response_buffer)
}
fn handle_non_streaming_response(
&mut self,
body: &[u8],

View file

@ -9,8 +9,10 @@ listeners:
llm_providers:
# OpenAI Models
- model: openai/gpt-5-2025-08-07
access_key: $OPENAI_API_KEY
- model: amazon_bedrock/us.amazon.nova-premier-v1:0
access_key: $AWS_BEARER_TOKEN_BEDROCK
base_url: https://bedrock-runtime.us-west-2.amazonaws.com
routing_preferences:
- name: code generation
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
@ -26,7 +28,7 @@ llm_providers:
default: true
access_key: $ANTHROPIC_API_KEY
- model: anthropic/claude-3-haiku-20240307
- model: anthropic/claude-haiku-4-5-20251001
access_key: $ANTHROPIC_API_KEY
# Ollama Models
@ -38,4 +40,4 @@ llm_providers:
model_aliases:
# Alias for a small faster Claude model
arch.claude.code.small.fast:
target: claude-3-haiku-20240307
target: claude-haiku-4-5-20251001

Binary file not shown.

View file

@ -499,3 +499,138 @@ def test_anthropic_client_with_coding_model_alias_and_tools():
# Should get either text response or tool use blocks for coding assistance
assert text_content or len(tool_use_blocks) > 0
@pytest.mark.flaky(retries=0) # Disable retries to see the actual failure
def test_anthropic_client_with_coding_model_alias_and_tools_streaming():
"""Test Anthropic client using 'coding-model' alias (maps to Bedrock) with coding question and tools - streaming"""
logger.info(
"Testing Anthropic client with 'coding-model' alias -> Bedrock with tools (streaming)"
)
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
client = anthropic.Anthropic(api_key="test-key", base_url=base_url)
text_chunks = []
tool_use_blocks = []
all_events = [] # Capture all events for debugging
try:
with client.messages.stream(
model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0
max_tokens=1000,
messages=[
{
"role": "user",
"content": "I need to write a Python function that calculates the factorial of a number. Can you help me write and run it?",
}
],
tools=[
{
"name": "run_python_code",
"description": "Execute Python code and return the result",
"input_schema": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "Python code to execute",
}
},
"required": ["code"],
},
}
],
tool_choice={"type": "auto"},
) as stream:
for event in stream:
# Extract index if available
index = getattr(event, "index", None)
# Log and capture all events for debugging
all_events.append(
{"type": event.type, "index": index, "event": str(event)[:200]}
)
logger.info(f"Event #{len(all_events)}: {event.type} [index={index}]")
# Collect text deltas
if event.type == "content_block_delta" and hasattr(event, "delta"):
if event.delta.type == "text_delta":
text_chunks.append(event.delta.text)
# Collect tool use blocks
if event.type == "content_block_start" and hasattr(
event, "content_block"
):
if event.content_block.type == "tool_use":
tool_use_blocks.append(event.content_block)
final_message = stream.get_final_message()
except Exception as e:
logger.error(f"Exception during streaming: {type(e).__name__}: {e}")
logger.error(f"Events received before error: {len(all_events)}")
logger.error(f"Text chunks collected: {len(text_chunks)}")
logger.error(f"Tool use blocks collected: {len(tool_use_blocks)}")
logger.error("\nLast 20 events before crash:")
for evt in all_events[-20:]:
logger.error(f" {evt['type']:30s} index={evt['index']}")
raise
full_text = "".join(text_chunks)
logger.info(f"Streaming response from coding-model with tools: {full_text}")
logger.info(f"Total events received: {len(all_events)}")
logger.info(
f"Text chunks: {len(text_chunks)}, Tool use blocks: {len(tool_use_blocks)}"
)
# Should get either text response or tool use blocks for coding assistance
# Modified assertion to be more lenient and provide better error messages
assert (
full_text or len(tool_use_blocks) > 0
), f"Expected text or tool use. Got text_len={len(full_text)}, tools={len(tool_use_blocks)}, events={len(all_events)}"
# Verify final message structure
assert final_message is not None, "Final message should not be None"
assert (
final_message.content and len(final_message.content) > 0
), f"Final message should have content. Got: {final_message.content if final_message else 'None'}"
def test_anthropic_client_streaming_with_bedrock():
"""Test Anthropic client using 'coding-model' alias (maps to Bedrock) with streaming"""
logger.info(
"Testing Anthropic client with 'coding-model' alias -> Bedrock (streaming)"
)
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
client = anthropic.Anthropic(api_key="test-key", base_url=base_url)
text_chunks = []
with client.messages.stream(
model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0
max_tokens=500,
messages=[
{
"role": "user",
"content": "Write a short 4-line sonnet about coding.",
}
],
) as stream:
for event in stream:
# Collect text deltas
if event.type == "content_block_delta" and hasattr(event, "delta"):
if event.delta.type == "text_delta":
text_chunks.append(event.delta.text)
final_message = stream.get_final_message()
full_text = "".join(text_chunks)
logger.info(f"Response: {full_text}")
# Should get a text response
assert len(full_text) > 0, "Expected text response from streaming"
# Verify final message structure
assert final_message is not None
assert final_message.content and len(final_message.content) > 0