Claude Code works with Amazon Bedrock

This commit is contained in:
Salman Paracha 2025-10-20 13:48:26 -07:00
parent d826de382a
commit db44602cb8
8 changed files with 441 additions and 32 deletions

View file

@ -693,16 +693,22 @@ pub struct ContentBlockStartEvent {
/// Content block start information
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
#[serde(untagged)]
pub enum ContentBlockStart {
#[serde(rename = "toolUse")]
ToolUse {
#[serde(rename = "toolUseId")]
tool_use_id: String,
name: String,
#[serde(rename = "toolUse")]
tool_use: ToolUseStart,
},
}
/// Payload nested under the `toolUse` key of a tool-use content block start.
///
/// It carries only the identifying data for the tool call; the call's input
/// arrives afterwards in content block delta events.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct ToolUseStart {
    /// Identifier of this tool invocation (`toolUseId` on the wire).
    pub tool_use_id: String,
    /// Tool name, forwarded as the tool/function name by the stream converters.
    pub name: String,
}
/// Content block delta event
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ContentBlockDeltaEvent {
@ -718,7 +724,15 @@ pub struct ContentBlockDeltaEvent {
#[serde(untagged)]
pub enum ContentBlockDelta {
Text { text: String },
ToolUse { input: String },
ToolUse {
#[serde(rename = "toolUse")]
tool_use: ToolUseDelta
},
}
/// Tool use delta information
///
/// Payload nested under the `toolUse` key of a `ContentBlockDelta::ToolUse`
/// event.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ToolUseDelta {
/// Fragment of the tool input. The stream converters forward it unchanged as
/// Anthropic `partial_json` / OpenAI function-call `arguments`.
pub input: String,
}
/// Content block stop event

View file

@ -478,7 +478,7 @@ where
{
decoder: aws_smithy_eventstream::frame::MessageFrameDecoder,
buffer: B,
has_content_block_start_been_sent: bool,
content_block_start_indices: std::collections::HashSet<i32>,
}
impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
@ -488,7 +488,7 @@ impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
Self {
decoder: aws_smithy_eventstream::frame::MessageFrameDecoder::new(),
buffer,
has_content_block_start_been_sent: false,
content_block_start_indices: std::collections::HashSet::new(),
}
}
}
@ -501,7 +501,7 @@ where
Self {
decoder: aws_smithy_eventstream::frame::MessageFrameDecoder::new(),
buffer,
has_content_block_start_been_sent: false,
content_block_start_indices: std::collections::HashSet::new(),
}
}
@ -521,14 +521,14 @@ where
self.buffer.has_remaining()
}
/// Check if a content_block_start event has been sent
pub fn has_content_block_start_been_sent(&self) -> bool {
self.has_content_block_start_been_sent
/// Check if a content_block_start event has been sent for the given index.
///
/// Returns `true` when `index` is present in `content_block_start_indices`,
/// i.e. it was previously recorded via `set_content_block_start_sent`.
pub fn has_content_block_start_been_sent(&self, index: i32) -> bool {
self.content_block_start_indices.contains(&index)
}
/// Set the content_block_start flag
pub fn set_content_block_start_sent(&mut self, sent: bool) {
self.has_content_block_start_been_sent = sent;
/// Mark that a content_block_start event has been sent for the given index.
///
/// Inserts `index` into `content_block_start_indices`; recording the same
/// index again leaves the set unchanged.
pub fn set_content_block_start_sent(&mut self, index: i32) {
self.content_block_start_indices.insert(index);
}
}
@ -1122,6 +1122,17 @@ mod tests {
test_bedrock_conversion(true);
}
/// Runs the tool-use Bedrock frame conversion check without printing the
/// converted SSE output.
#[test]
fn test_bedrock_decoded_frame_with_tool_use() {
test_bedrock_conversion_with_tools(false);
}
/// Same check as the non-verbose tool-use test, but prints every converted
/// SSE event. Ignored by default so regular test runs stay quiet.
#[test]
#[ignore] // Run with: cargo test -- --ignored --nocapture
fn test_bedrock_decoded_frame_with_tool_use_verbose() {
test_bedrock_conversion_with_tools(true);
}
fn test_bedrock_conversion(verbose: bool) {
use bytes::BytesMut;
use std::fs;
@ -1194,6 +1205,93 @@ mod tests {
assert!(message_start_seen, "Should have seen MessageStart event");
}
/// Decodes the recorded Bedrock tool-use response fixture frame by frame,
/// converts each complete frame to an Anthropic Messages stream event, and
/// asserts that the message start, content block start and tool-input delta
/// events all appear.
///
/// When `verbose` is true, every converted event is printed as its SSE string.
/// The check is skipped (returns early) when the fixture file is absent.
fn test_bedrock_conversion_with_tools(verbose: bool) {
    use crate::apis::anthropic::{MessagesContentDelta, MessagesStreamEvent};
    use aws_smithy_eventstream::frame::DecodedFrame;
    use bytes::BytesMut;
    use std::fs;
    use std::path::PathBuf;

    // Fixture lives in the repository-level tests/e2e directory.
    let fixture = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
        .join("../../tests/e2e/response_with_tools.hex");
    if !fixture.exists() {
        println!("Skipping test - response_with_tools.hex not found");
        return;
    }

    let raw = fs::read(&fixture).unwrap();
    let mut buffer = BytesMut::from(&raw[..]);
    let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);

    let client_api =
        SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
    let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
        crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
    );

    let mut converted = 0;
    let mut saw_message_start = false;
    let mut saw_block_start = false;
    let mut saw_tool_input_delta = false;

    loop {
        // Pull the next frame; stop at the end of the buffer, fail on a decode error.
        let frame = match decoder.decode_frame() {
            Some(frame @ DecodedFrame::Complete(_)) => frame,
            Some(DecodedFrame::Incomplete) => break,
            None => panic!("Decode error"),
        };

        // Conversion failures are reported but do not fail the test on their own.
        let provider_response =
            match ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api)) {
                Ok(response) => response,
                Err(e) => {
                    println!("Conversion error (frame {}): {}", converted, e);
                    continue;
                }
            };
        converted += 1;

        assert!(matches!(
            provider_response,
            ProviderStreamResponseType::MessagesStreamEvent(_)
        ));

        if verbose {
            let sse_string: String = provider_response.clone().into();
            println!("{}", sse_string);
        }

        // Record which of the tool-use related events have shown up so far.
        if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = provider_response {
            match event {
                MessagesStreamEvent::MessageStart { .. } => saw_message_start = true,
                MessagesStreamEvent::ContentBlockStart { .. } => saw_block_start = true,
                MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
                    saw_tool_input_delta |=
                        matches!(delta, MessagesContentDelta::InputJsonDelta { .. });
                }
                _ => {}
            }
        }
    }

    assert!(converted > 0, "Should have converted at least one frame");
    assert!(saw_message_start, "Should have seen MessageStart event");
    assert!(saw_block_start, "Should have seen ContentBlockStart event for tool use");
    assert!(saw_tool_input_delta, "Should have seen ContentBlockDelta with ToolUseDelta");
}
#[test]
fn test_sse_event_transformation_openai_to_anthropic_message_start() {
use crate::apis::openai::OpenAIApi;

View file

@ -261,12 +261,12 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
// Note: Bedrock sends tool_use_id and name at start, with input coming in subsequent deltas
// Anthropic expects the same pattern, so we initialize with an empty input object
match start_event.start {
crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use_id, name } => {
crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => {
Ok(MessagesStreamEvent::ContentBlockStart {
index: start_event.content_block_index as u32,
content_block: MessagesContentBlock::ToolUse {
id: tool_use_id,
name,
id: tool_use.tool_use_id,
name: tool_use.name,
input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas
cache_control: None,
},
@ -281,8 +281,8 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
ContentBlockDelta::Text { text } => {
MessagesContentDelta::TextDelta { text }
}
ContentBlockDelta::ToolUse { input } => {
MessagesContentDelta::InputJsonDelta { partial_json: input }
ContentBlockDelta::ToolUse { tool_use } => {
MessagesContentDelta::InputJsonDelta { partial_json: tool_use.input }
}
};

View file

@ -280,7 +280,7 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
use crate::apis::amazon_bedrock::ContentBlockStart;
match start_event.start {
ContentBlockStart::ToolUse { tool_use_id, name } => {
ContentBlockStart::ToolUse { tool_use } => {
Ok(create_openai_chunk(
"stream",
"unknown",
@ -291,10 +291,10 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: start_event.content_block_index as u32,
id: Some(tool_use_id),
id: Some(tool_use.tool_use_id),
call_type: Some("function".to_string()),
function: Some(FunctionCallDelta {
name: Some(name),
name: Some(tool_use.name),
arguments: Some("".to_string()),
}),
}]),
@ -325,7 +325,7 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
None,
))
}
ContentBlockDelta::ToolUse { input } => {
ContentBlockDelta::ToolUse { tool_use } => {
Ok(create_openai_chunk(
"stream",
"unknown",
@ -340,7 +340,7 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
call_type: None,
function: Some(FunctionCallDelta {
name: None,
arguments: Some(input),
arguments: Some(tool_use.input),
}),
}]),
},

View file

@ -1,4 +1,3 @@
use bytes::Buf;
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use http::StatusCode;
use log::{debug, info, warn};
@ -25,9 +24,11 @@ use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::clients::endpoints::SupportedAPIs;
use hermesllm::providers::response::{
BedrockBinaryFrameDecoder, ProviderResponse, SseEvent, SseStreamIter,
BedrockBinaryFrameDecoder, ProviderResponse, ProviderStreamResponse, SseEvent, SseStreamIter,
};
use hermesllm::{
DecodedFrame, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType,
};
use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType};
pub struct StreamContext {
metrics: Rc<Metrics>,
@ -424,6 +425,14 @@ impl StreamContext {
let upstream_api =
provider_id.compatible_api_for_client(&client_api, self.streaming_response);
// Check if this is Bedrock binary stream
if matches!(
upstream_api,
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)
) {
return self.handle_bedrock_binary_stream(body, &client_api, &upstream_api);
}
// Parse body into SSE iterator using TryFrom
let sse_iter: SseStreamIter<std::vec::IntoIter<String>> =
match SseStreamIter::try_from(body) {
@ -499,6 +508,157 @@ impl StreamContext {
}
}
/// Decode one chunk of a Bedrock ConverseStream binary event stream and
/// re-encode every complete frame as SSE text for the client.
///
/// `body` is appended to the persistent frame decoder's buffer, so a frame
/// split across chunks is completed by a later call. For each complete frame
/// the converted event is serialised to its SSE string and appended to the
/// returned buffer. If a `content_block_delta` arrives for an index that has
/// had no `content_block_start` yet, an empty text `content_block_start` is
/// emitted first and the index is recorded as started.
///
/// # Returns
/// * `Ok(bytes)` - SSE bytes for all frames completed by this chunk; may be
///   empty when only a partial frame is buffered.
/// * `Err(Action::Continue)` - the decoder reported a decode error.
///
/// Frames that fail conversion are logged and skipped.
fn handle_bedrock_binary_stream(
    &mut self,
    body: &[u8],
    client_api: &SupportedAPIs,
    upstream_api: &SupportedUpstreamAPIs,
) -> Result<Vec<u8>, Action> {
    use hermesllm::apis::anthropic::{MessagesContentBlock, MessagesStreamEvent};
    use hermesllm::providers::response::ProviderStreamResponseType;

    // Create the decoder on the first chunk, then queue the new bytes behind
    // whatever partial frame is still buffered from earlier chunks.
    self.binary_frame_decoder
        .get_or_insert_with(|| BedrockBinaryFrameDecoder::from_bytes(&[]))
        .buffer_mut()
        .extend_from_slice(body);

    let mut response_buffer = Vec::new();

    // Drain every complete frame currently available.
    loop {
        let decoded_frame = self
            .binary_frame_decoder
            .as_mut()
            .expect("binary_frame_decoder is initialised at the top of this function")
            .decode_frame();

        // Take ownership of the complete frame; no clone is needed because the
        // decoded value is not used again.
        let frame = match decoded_frame {
            Some(frame @ DecodedFrame::Complete(_)) => frame,
            Some(DecodedFrame::Incomplete) => {
                // Incomplete frame - buffer retains partial data, wait for more bytes
                debug!(
                    "[ARCHGW_REQ_ID:{}] BEDROCK_INCOMPLETE_FRAME: waiting for more data",
                    self.request_identifier()
                );
                break;
            }
            None => {
                // NOTE(review): events already accumulated in `response_buffer`
                // during this call are discarded here - confirm that is intended.
                warn!(
                    "[ARCHGW_REQ_ID:{}] BEDROCK_DECODE_ERROR",
                    self.request_identifier()
                );
                return Err(Action::Continue);
            }
        };

        let provider_response =
            match ProviderStreamResponseType::try_from((&frame, client_api, upstream_api)) {
                Ok(response) => response,
                Err(e) => {
                    warn!(
                        "[ARCHGW_REQ_ID:{}] BEDROCK_FRAME_CONVERSION_ERROR: {}",
                        self.request_identifier(),
                        e
                    );
                    continue;
                }
            };

        self.record_ttft_if_needed();

        // Content block index carried by the event, when it has one.
        let event_index =
            if let ProviderStreamResponseType::MessagesStreamEvent(ref evt) = provider_response {
                match evt {
                    MessagesStreamEvent::ContentBlockStart { index, .. } => Some(*index as i32),
                    MessagesStreamEvent::ContentBlockDelta { index, .. } => Some(*index as i32),
                    MessagesStreamEvent::ContentBlockStop { index, .. } => Some(*index as i32),
                    _ => None,
                }
            } else {
                None
            };

        match (provider_response.event_type(), event_index) {
            (Some("content_block_start"), Some(index)) => {
                // Upstream supplied the start itself; remember it for this index.
                if let Some(decoder) = self.binary_frame_decoder.as_mut() {
                    decoder.set_content_block_start_sent(index);
                    debug!(
                        "[ARCHGW_REQ_ID:{}] BEDROCK_CONTENT_BLOCK_START_TRACKED: index={}",
                        self.request_identifier(),
                        index
                    );
                }
            }
            (Some("content_block_delta"), Some(index)) => {
                let needs_start = self
                    .binary_frame_decoder
                    .as_ref()
                    .map_or(false, |decoder| {
                        !decoder.has_content_block_start_been_sent(index)
                    });

                if needs_start {
                    // The injected start is always an empty text block.
                    // NOTE(review): presumably only text blocks reach the client
                    // without an upstream start event - confirm.
                    let content_block_start = MessagesStreamEvent::ContentBlockStart {
                        index: index as u32,
                        content_block: MessagesContentBlock::Text {
                            text: String::new(),
                            cache_control: None,
                        },
                    };
                    let start_sse: String = content_block_start.into();
                    response_buffer.extend_from_slice(start_sse.as_bytes());

                    // Mark that we've now sent it
                    if let Some(decoder) = self.binary_frame_decoder.as_mut() {
                        decoder.set_content_block_start_sent(index);
                    }
                    debug!(
                        "[ARCHGW_REQ_ID:{}] BEDROCK_INJECTED_CONTENT_BLOCK_START: index={}",
                        self.request_identifier(),
                        index
                    );
                }
            }
            _ => {}
        }

        let sse_string: String = provider_response.into();
        response_buffer.extend_from_slice(sse_string.as_bytes());
    }

    // Return accumulated complete frames (may be empty if all frames incomplete)
    Ok(response_buffer)
}
fn handle_non_streaming_response(
&mut self,
body: &[u8],