mirror of
https://github.com/katanemo/plano.git
synced 2026-05-04 13:23:00 +02:00
Add support for Amazon Bedrock Converse and ConverseStream (#588)
* first commit to get Bedrock Converse API working. Next commit support for streaming and binary frames * adding translation from BedrockBinaryFrameDecoder to AnthropicMessagesEvent * Claude Code works with Amazon Bedrock * added tests for openai streaming from bedrock * PR comments fixed * adding support for bedrock in docs as supported provider * cargo fmt * revertted to chatgpt models for claude code routing --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-288.local> Co-authored-by: Adil Hafeez <adil.hafeez@gmail.com>
This commit is contained in:
parent
ba826b1961
commit
9407ae6af7
35 changed files with 7362 additions and 1493 deletions
|
|
@ -4,12 +4,16 @@
|
|||
pub mod apis;
|
||||
pub mod clients;
|
||||
pub mod providers;
|
||||
pub mod transforms;
|
||||
// Re-export important types and traits
|
||||
pub use apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
|
||||
pub use apis::sse::{SseEvent, SseStreamIter};
|
||||
pub use aws_smithy_eventstream::frame::DecodedFrame;
|
||||
pub use providers::id::ProviderId;
|
||||
pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
|
||||
pub use providers::response::{
|
||||
ProviderResponse, ProviderResponseError, ProviderResponseType, ProviderStreamResponse,
|
||||
ProviderStreamResponseType, SseEvent, SseStreamIter, TokenUsage,
|
||||
ProviderStreamResponseType, TokenUsage,
|
||||
};
|
||||
|
||||
//TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings
|
||||
|
|
@ -18,6 +22,8 @@ pub const MESSAGES_PATH: &str = "/v1/messages";
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::clients::endpoints::SupportedUpstreamAPIs;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
|
|
@ -40,7 +46,7 @@ mod tests {
|
|||
let client_api =
|
||||
SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
let upstream_api =
|
||||
SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test the new simplified architecture - create SseStreamIter directly
|
||||
let sse_iter = SseStreamIter::try_from(sse_data.as_bytes());
|
||||
|
|
@ -77,4 +83,156 @@ mod tests {
|
|||
let final_event = streaming_iter.next();
|
||||
assert!(final_event.is_none()); // Should be None because iterator stops at [DONE]
|
||||
}
|
||||
|
||||
/// Test AWS Event Stream decoding for Bedrock ConverseStream responses.
|
||||
///
|
||||
/// This test demonstrates how to:
|
||||
/// 1. Use MessageFrameDecoder to decode AWS Event Stream frames
|
||||
/// 2. Handle chunked network arrivals with buffering
|
||||
/// 3. Extract event types from message headers
|
||||
/// 4. Parse JSON payloads from decoded messages
|
||||
/// 5. Reconstruct streaming content from contentBlockDelta events
|
||||
///
|
||||
/// The decoder handles frame boundaries automatically - you just keep calling
|
||||
/// decode_frame() until it returns Incomplete, which means you've processed
|
||||
/// all complete frames in the buffer.
|
||||
#[test]
|
||||
fn test_amazon_bedrock_streaming_response() {
|
||||
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
|
||||
use bytes::{Buf, BytesMut};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
// Read the response.hex file from tests/e2e directory
|
||||
// Use absolute path to avoid cargo test working directory issues
|
||||
let test_file =
|
||||
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
|
||||
let response_data = fs::read(&test_file)
|
||||
.unwrap_or_else(|e| panic!("Failed to read {:?}: {}", test_file, e));
|
||||
|
||||
println!("📊 Response data size: {} bytes\n", response_data.len());
|
||||
|
||||
// Create decoder and buffer that implements Buf trait
|
||||
// BytesMut automatically tracks position as decoder advances it!
|
||||
let mut decoder = MessageFrameDecoder::new();
|
||||
let mut simulated_network_buffer = BytesMut::new();
|
||||
let mut frame_count = 0;
|
||||
let mut content_chunks = Vec::new();
|
||||
|
||||
// Simulate chunked network arrivals - process as data comes in
|
||||
let chunk_sizes = vec![50, 100, 75, 200, 150, 300, 500, 1000];
|
||||
let mut offset = 0;
|
||||
let mut chunk_num = 0;
|
||||
|
||||
println!("🔄 Simulating chunked network arrivals...\n");
|
||||
|
||||
// Process chunks as they "arrive" from the network
|
||||
while offset < response_data.len() {
|
||||
// Receive next chunk from network
|
||||
let chunk_size = chunk_sizes[chunk_num % chunk_sizes.len()];
|
||||
let end = (offset + chunk_size).min(response_data.len());
|
||||
let chunk = &response_data[offset..end];
|
||||
|
||||
chunk_num += 1;
|
||||
simulated_network_buffer.extend_from_slice(chunk);
|
||||
offset = end;
|
||||
|
||||
println!(
|
||||
"📦 Chunk {}: Received {} bytes (buffer: {} bytes total, {} bytes remaining)",
|
||||
chunk_num,
|
||||
chunk.len(),
|
||||
simulated_network_buffer.len(),
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
|
||||
// Try to decode all complete frames from buffer
|
||||
// The Buf trait tracks position automatically!
|
||||
loop {
|
||||
let bytes_before = simulated_network_buffer.remaining();
|
||||
match decoder.decode_frame(&mut simulated_network_buffer) {
|
||||
Ok(DecodedFrame::Complete(message)) => {
|
||||
frame_count += 1;
|
||||
let consumed = bytes_before - simulated_network_buffer.remaining();
|
||||
|
||||
println!(
|
||||
" ✅ Frame {}: decoded ({} bytes, {} bytes remaining)",
|
||||
frame_count,
|
||||
consumed,
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
|
||||
// Get event type from headers
|
||||
let event_type = message
|
||||
.headers()
|
||||
.iter()
|
||||
.find(|h| h.name().as_str() == ":event-type")
|
||||
.and_then(|h| {
|
||||
h.value().as_string().ok().map(|s| s.as_str().to_string())
|
||||
});
|
||||
|
||||
if let Some(ref evt) = event_type {
|
||||
println!(" Event: {}", evt);
|
||||
}
|
||||
|
||||
// Parse payload and extract content
|
||||
let payload = message.payload();
|
||||
if !payload.is_empty() {
|
||||
if let Ok(json) = serde_json::from_slice::<serde_json::Value>(payload) {
|
||||
if event_type.as_deref() == Some("contentBlockDelta") {
|
||||
if let Some(delta) = json.get("delta") {
|
||||
if let Some(text) =
|
||||
delta.get("text").and_then(|t| t.as_str())
|
||||
{
|
||||
println!(" 📝 Content: \"{}\"", text);
|
||||
content_chunks.push(text.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // Continue loop to check for more complete frames in buffer
|
||||
}
|
||||
Ok(DecodedFrame::Incomplete) => {
|
||||
// Not enough data for a complete frame - need more chunks
|
||||
println!(
|
||||
" ⏳ Incomplete frame ({} bytes remaining) - waiting for more data\n",
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
break; // Wait for next chunk
|
||||
}
|
||||
Err(e) => {
|
||||
panic!("❌ Frame decode error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("\n━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!("📋 Summary:");
|
||||
println!(" Total chunks received: {}", chunk_num);
|
||||
println!(" Total frames decoded: {}", frame_count);
|
||||
println!(" Total content chunks: {}", content_chunks.len());
|
||||
println!(
|
||||
" Final buffer remaining: {} bytes",
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
|
||||
if !content_chunks.is_empty() {
|
||||
let full_text = content_chunks.join("");
|
||||
println!("\n📄 Full reconstructed content:");
|
||||
println!("{}", full_text);
|
||||
println!("\n Characters: {}", full_text.len());
|
||||
println!(" Estimated tokens: ~{}", full_text.len() / 4);
|
||||
}
|
||||
|
||||
// Ensure we decoded at least one frame
|
||||
assert!(frame_count > 0, "Should decode at least one frame");
|
||||
|
||||
// Ensure all data was consumed - if buffer has remaining bytes, it's a partial frame
|
||||
assert_eq!(
|
||||
simulated_network_buffer.remaining(),
|
||||
0,
|
||||
"All bytes should be consumed, {} bytes remain",
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue