mirror of
https://github.com/katanemo/plano.git
synced 2026-06-26 15:39:40 +02:00
Claude Code works with Amazon Bedrock
This commit is contained in:
parent
d826de382a
commit
db44602cb8
8 changed files with 441 additions and 32 deletions
|
|
@ -693,16 +693,22 @@ pub struct ContentBlockStartEvent {
|
|||
|
||||
/// Content block start information
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(tag = "type")]
|
||||
#[serde(untagged)]
|
||||
pub enum ContentBlockStart {
|
||||
#[serde(rename = "toolUse")]
|
||||
ToolUse {
|
||||
#[serde(rename = "toolUseId")]
|
||||
tool_use_id: String,
|
||||
name: String,
|
||||
#[serde(rename = "toolUse")]
|
||||
tool_use: ToolUseStart,
|
||||
},
|
||||
}
|
||||
|
||||
/// Tool use start information
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct ToolUseStart {
|
||||
#[serde(rename = "toolUseId")]
|
||||
pub tool_use_id: String,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
/// Content block delta event
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct ContentBlockDeltaEvent {
|
||||
|
|
@ -718,7 +724,15 @@ pub struct ContentBlockDeltaEvent {
|
|||
#[serde(untagged)]
|
||||
pub enum ContentBlockDelta {
|
||||
Text { text: String },
|
||||
ToolUse { input: String },
|
||||
ToolUse {
|
||||
#[serde(rename = "toolUse")]
|
||||
tool_use: ToolUseDelta
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct ToolUseDelta {
|
||||
pub input: String,
|
||||
}
|
||||
|
||||
/// Content block stop event
|
||||
|
|
|
|||
|
|
@ -478,7 +478,7 @@ where
|
|||
{
|
||||
decoder: aws_smithy_eventstream::frame::MessageFrameDecoder,
|
||||
buffer: B,
|
||||
has_content_block_start_been_sent: bool,
|
||||
content_block_start_indices: std::collections::HashSet<i32>,
|
||||
}
|
||||
|
||||
impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
|
||||
|
|
@ -488,7 +488,7 @@ impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
|
|||
Self {
|
||||
decoder: aws_smithy_eventstream::frame::MessageFrameDecoder::new(),
|
||||
buffer,
|
||||
has_content_block_start_been_sent: false,
|
||||
content_block_start_indices: std::collections::HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -501,7 +501,7 @@ where
|
|||
Self {
|
||||
decoder: aws_smithy_eventstream::frame::MessageFrameDecoder::new(),
|
||||
buffer,
|
||||
has_content_block_start_been_sent: false,
|
||||
content_block_start_indices: std::collections::HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -521,14 +521,14 @@ where
|
|||
self.buffer.has_remaining()
|
||||
}
|
||||
|
||||
/// Check if a content_block_start event has been sent
|
||||
pub fn has_content_block_start_been_sent(&self) -> bool {
|
||||
self.has_content_block_start_been_sent
|
||||
/// Check if a content_block_start event has been sent for the given index
|
||||
pub fn has_content_block_start_been_sent(&self, index: i32) -> bool {
|
||||
self.content_block_start_indices.contains(&index)
|
||||
}
|
||||
|
||||
/// Set the content_block_start flag
|
||||
pub fn set_content_block_start_sent(&mut self, sent: bool) {
|
||||
self.has_content_block_start_been_sent = sent;
|
||||
/// Mark that a content_block_start event has been sent for the given index
|
||||
pub fn set_content_block_start_sent(&mut self, index: i32) {
|
||||
self.content_block_start_indices.insert(index);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1122,6 +1122,17 @@ mod tests {
|
|||
test_bedrock_conversion(true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bedrock_decoded_frame_with_tool_use() {
|
||||
test_bedrock_conversion_with_tools(false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // Run with: cargo test -- --ignored --nocapture
|
||||
fn test_bedrock_decoded_frame_with_tool_use_verbose() {
|
||||
test_bedrock_conversion_with_tools(true);
|
||||
}
|
||||
|
||||
fn test_bedrock_conversion(verbose: bool) {
|
||||
use bytes::BytesMut;
|
||||
use std::fs;
|
||||
|
|
@ -1194,6 +1205,93 @@ mod tests {
|
|||
assert!(message_start_seen, "Should have seen MessageStart event");
|
||||
}
|
||||
|
||||
fn test_bedrock_conversion_with_tools(verbose: bool) {
|
||||
use bytes::BytesMut;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
// Read the actual response_with_tools.hex file from tests/e2e directory
|
||||
let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("../../tests/e2e/response_with_tools.hex");
|
||||
|
||||
// Only run this test if the file exists
|
||||
if !test_file.exists() {
|
||||
println!("Skipping test - response_with_tools.hex not found");
|
||||
return;
|
||||
}
|
||||
|
||||
let response_data = fs::read(&test_file).unwrap();
|
||||
let mut buffer = BytesMut::from(&response_data[..]);
|
||||
|
||||
let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
|
||||
|
||||
let client_api = SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream);
|
||||
|
||||
let mut conversion_count = 0;
|
||||
let mut message_start_seen = false;
|
||||
let mut content_block_start_seen = false;
|
||||
let mut content_block_delta_tool_use_seen = false;
|
||||
|
||||
// Decode and convert frames
|
||||
loop {
|
||||
match decoder.decode_frame() {
|
||||
Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => {
|
||||
// Convert DecodedFrame to ProviderStreamResponseType
|
||||
let result = ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api));
|
||||
|
||||
match result {
|
||||
Ok(provider_response) => {
|
||||
conversion_count += 1;
|
||||
|
||||
// Verify we got a MessagesStreamEvent
|
||||
assert!(matches!(provider_response, ProviderStreamResponseType::MessagesStreamEvent(_)));
|
||||
|
||||
if verbose {
|
||||
// Print the SSE string output
|
||||
let sse_string: String = provider_response.clone().into();
|
||||
println!("{}", sse_string);
|
||||
}
|
||||
|
||||
// Check for specific events related to tool use
|
||||
if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = provider_response {
|
||||
match event {
|
||||
crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. } => {
|
||||
message_start_seen = true;
|
||||
}
|
||||
crate::apis::anthropic::MessagesStreamEvent::ContentBlockStart { .. } => {
|
||||
content_block_start_seen = true;
|
||||
}
|
||||
crate::apis::anthropic::MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
|
||||
if matches!(delta, crate::apis::anthropic::MessagesContentDelta::InputJsonDelta { .. }) {
|
||||
content_block_delta_tool_use_seen = true;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Conversion error (frame {}): {}", conversion_count, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(aws_smithy_eventstream::frame::DecodedFrame::Incomplete) => {
|
||||
// End of buffer
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
panic!("Decode error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(conversion_count > 0, "Should have converted at least one frame");
|
||||
assert!(message_start_seen, "Should have seen MessageStart event");
|
||||
assert!(content_block_start_seen, "Should have seen ContentBlockStart event for tool use");
|
||||
assert!(content_block_delta_tool_use_seen, "Should have seen ContentBlockDelta with ToolUseDelta");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_event_transformation_openai_to_anthropic_message_start() {
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
|
|
|
|||
|
|
@ -261,12 +261,12 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
|
|||
// Note: Bedrock sends tool_use_id and name at start, with input coming in subsequent deltas
|
||||
// Anthropic expects the same pattern, so we initialize with an empty input object
|
||||
match start_event.start {
|
||||
crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use_id, name } => {
|
||||
crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => {
|
||||
Ok(MessagesStreamEvent::ContentBlockStart {
|
||||
index: start_event.content_block_index as u32,
|
||||
content_block: MessagesContentBlock::ToolUse {
|
||||
id: tool_use_id,
|
||||
name,
|
||||
id: tool_use.tool_use_id,
|
||||
name: tool_use.name,
|
||||
input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas
|
||||
cache_control: None,
|
||||
},
|
||||
|
|
@ -281,8 +281,8 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
|
|||
ContentBlockDelta::Text { text } => {
|
||||
MessagesContentDelta::TextDelta { text }
|
||||
}
|
||||
ContentBlockDelta::ToolUse { input } => {
|
||||
MessagesContentDelta::InputJsonDelta { partial_json: input }
|
||||
ContentBlockDelta::ToolUse { tool_use } => {
|
||||
MessagesContentDelta::InputJsonDelta { partial_json: tool_use.input }
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -280,7 +280,7 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
|
|||
use crate::apis::amazon_bedrock::ContentBlockStart;
|
||||
|
||||
match start_event.start {
|
||||
ContentBlockStart::ToolUse { tool_use_id, name } => {
|
||||
ContentBlockStart::ToolUse { tool_use } => {
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
|
|
@ -291,10 +291,10 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
|
|||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: start_event.content_block_index as u32,
|
||||
id: Some(tool_use_id),
|
||||
id: Some(tool_use.tool_use_id),
|
||||
call_type: Some("function".to_string()),
|
||||
function: Some(FunctionCallDelta {
|
||||
name: Some(name),
|
||||
name: Some(tool_use.name),
|
||||
arguments: Some("".to_string()),
|
||||
}),
|
||||
}]),
|
||||
|
|
@ -325,7 +325,7 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
|
|||
None,
|
||||
))
|
||||
}
|
||||
ContentBlockDelta::ToolUse { input } => {
|
||||
ContentBlockDelta::ToolUse { tool_use } => {
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
|
|
@ -340,7 +340,7 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
|
|||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(input),
|
||||
arguments: Some(tool_use.input),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
use bytes::Buf;
|
||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
|
|
@ -25,9 +24,11 @@ use common::tracing::{Event, Span, TraceData, Traceparent};
|
|||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::clients::endpoints::SupportedAPIs;
|
||||
use hermesllm::providers::response::{
|
||||
BedrockBinaryFrameDecoder, ProviderResponse, SseEvent, SseStreamIter,
|
||||
BedrockBinaryFrameDecoder, ProviderResponse, ProviderStreamResponse, SseEvent, SseStreamIter,
|
||||
};
|
||||
use hermesllm::{
|
||||
DecodedFrame, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType,
|
||||
};
|
||||
use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType};
|
||||
|
||||
pub struct StreamContext {
|
||||
metrics: Rc<Metrics>,
|
||||
|
|
@ -424,6 +425,14 @@ impl StreamContext {
|
|||
let upstream_api =
|
||||
provider_id.compatible_api_for_client(&client_api, self.streaming_response);
|
||||
|
||||
// Check if this is Bedrock binary stream
|
||||
if matches!(
|
||||
upstream_api,
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)
|
||||
) {
|
||||
return self.handle_bedrock_binary_stream(body, &client_api, &upstream_api);
|
||||
}
|
||||
|
||||
// Parse body into SSE iterator using TryFrom
|
||||
let sse_iter: SseStreamIter<std::vec::IntoIter<String>> =
|
||||
match SseStreamIter::try_from(body) {
|
||||
|
|
@ -499,6 +508,157 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
fn handle_bedrock_binary_stream(
|
||||
&mut self,
|
||||
body: &[u8],
|
||||
client_api: &SupportedAPIs,
|
||||
upstream_api: &SupportedUpstreamAPIs,
|
||||
) -> Result<Vec<u8>, Action> {
|
||||
use hermesllm::providers::response::ProviderStreamResponseType;
|
||||
|
||||
// Initialize decoder if not present
|
||||
if self.binary_frame_decoder.is_none() {
|
||||
self.binary_frame_decoder = Some(BedrockBinaryFrameDecoder::from_bytes(&[]));
|
||||
}
|
||||
|
||||
// Add incoming bytes to buffer
|
||||
if let Some(decoder) = self.binary_frame_decoder.as_mut() {
|
||||
decoder.buffer_mut().extend_from_slice(body);
|
||||
}
|
||||
|
||||
let mut response_buffer = Vec::new();
|
||||
|
||||
// Decode all available complete frames
|
||||
loop {
|
||||
let decoded_frame = self.binary_frame_decoder.as_mut().unwrap().decode_frame();
|
||||
match decoded_frame {
|
||||
Some(DecodedFrame::Complete(ref frame_ref)) => {
|
||||
// Convert frame to ProviderStreamResponseType
|
||||
let frame = DecodedFrame::Complete(frame_ref.clone());
|
||||
match ProviderStreamResponseType::try_from((&frame, client_api, upstream_api)) {
|
||||
Ok(provider_response) => {
|
||||
self.record_ttft_if_needed();
|
||||
|
||||
// Extract index from the event if available
|
||||
let event_index =
|
||||
if let ProviderStreamResponseType::MessagesStreamEvent(ref evt) =
|
||||
provider_response
|
||||
{
|
||||
use hermesllm::apis::anthropic::MessagesStreamEvent;
|
||||
match evt {
|
||||
MessagesStreamEvent::ContentBlockStart {
|
||||
index, ..
|
||||
} => Some(*index as i32),
|
||||
MessagesStreamEvent::ContentBlockDelta {
|
||||
index, ..
|
||||
} => Some(*index as i32),
|
||||
MessagesStreamEvent::ContentBlockStop { index, .. } => {
|
||||
Some(*index as i32)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Check event type to track ContentBlockStart
|
||||
if let Some(event_type) = provider_response.event_type() {
|
||||
match event_type {
|
||||
"content_block_start" => {
|
||||
// Mark that we've seen ContentBlockStart for this index
|
||||
if let (Some(decoder), Some(index)) =
|
||||
(self.binary_frame_decoder.as_mut(), event_index)
|
||||
{
|
||||
decoder.set_content_block_start_sent(index);
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_CONTENT_BLOCK_START_TRACKED: index={}",
|
||||
self.request_identifier(),
|
||||
index
|
||||
);
|
||||
}
|
||||
}
|
||||
"content_block_delta" => {
|
||||
// Check if ContentBlockStart was sent for this index
|
||||
if let Some(index) = event_index {
|
||||
let needs_start = if let Some(decoder) =
|
||||
self.binary_frame_decoder.as_ref()
|
||||
{
|
||||
!decoder.has_content_block_start_been_sent(index)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if needs_start {
|
||||
// Emit empty ContentBlockStart before delta
|
||||
use hermesllm::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesStreamEvent,
|
||||
};
|
||||
let content_block_start =
|
||||
MessagesStreamEvent::ContentBlockStart {
|
||||
index: index as u32,
|
||||
content_block: MessagesContentBlock::Text {
|
||||
text: String::new(),
|
||||
cache_control: None,
|
||||
},
|
||||
};
|
||||
let start_sse: String = content_block_start.into();
|
||||
response_buffer
|
||||
.extend_from_slice(start_sse.as_bytes());
|
||||
|
||||
// Mark that we've now sent it
|
||||
if let Some(decoder) =
|
||||
self.binary_frame_decoder.as_mut()
|
||||
{
|
||||
decoder.set_content_block_start_sent(index);
|
||||
}
|
||||
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_INJECTED_CONTENT_BLOCK_START: index={}",
|
||||
self.request_identifier(),
|
||||
index
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let sse_string: String = provider_response.into();
|
||||
response_buffer.extend_from_slice(sse_string.as_bytes());
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_FRAME_CONVERSION_ERROR: {}",
|
||||
self.request_identifier(),
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(DecodedFrame::Incomplete) => {
|
||||
// Incomplete frame - buffer retains partial data, wait for more bytes
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_INCOMPLETE_FRAME: waiting for more data",
|
||||
self.request_identifier()
|
||||
);
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
// Decode error
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_DECODE_ERROR",
|
||||
self.request_identifier()
|
||||
);
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return accumulated complete frames (may be empty if all frames incomplete)
|
||||
Ok(response_buffer)
|
||||
}
|
||||
|
||||
fn handle_non_streaming_response(
|
||||
&mut self,
|
||||
body: &[u8],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue