Add support for Amazon Bedrock Converse and ConverseStream (#588)

* first commit to get Bedrock Converse API working. Next commit support for streaming and binary frames

* adding translation from BedrockBinaryFrameDecoder to AnthropicMessagesEvent

* Claude Code works with Amazon Bedrock

* added tests for openai streaming from bedrock

* PR comments fixed

* adding support for bedrock in docs as supported provider

* cargo fmt

* revertted to chatgpt models for claude code routing

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-288.local>
Co-authored-by: Adil Hafeez <adil.hafeez@gmail.com>
This commit is contained in:
Salman Paracha 2025-10-22 11:31:21 -07:00 committed by GitHub
parent ba826b1961
commit 9407ae6af7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
35 changed files with 7362 additions and 1493 deletions

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,65 @@
use aws_smithy_eventstream::frame::DecodedFrame;
use aws_smithy_eventstream::frame::MessageFrameDecoder;
use bytes::Buf;
use std::collections::HashSet;
/// AWS Event Stream frame decoder wrapper
pub struct BedrockBinaryFrameDecoder<B>
where
B: Buf,
{
decoder: MessageFrameDecoder,
buffer: B,
content_block_start_indices: HashSet<i32>,
}
impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
/// This is a convenience constructor that creates a BytesMut buffer internally
pub fn from_bytes(bytes: &[u8]) -> Self {
let buffer = bytes::BytesMut::from(bytes);
Self {
decoder: MessageFrameDecoder::new(),
buffer,
content_block_start_indices: std::collections::HashSet::new(),
}
}
}
impl<B> BedrockBinaryFrameDecoder<B>
where
B: Buf,
{
pub fn new(buffer: B) -> Self {
Self {
decoder: MessageFrameDecoder::new(),
buffer,
content_block_start_indices: HashSet::new(),
}
}
pub fn decode_frame(&mut self) -> Option<DecodedFrame> {
match self.decoder.decode_frame(&mut self.buffer) {
Ok(frame) => Some(frame),
Err(_e) => None, // Fatal decode error
}
}
pub fn buffer_mut(&mut self) -> &mut B {
&mut self.buffer
}
/// Check if there are any bytes remaining in the buffer
pub fn has_remaining(&self) -> bool {
self.buffer.has_remaining()
}
/// Check if a content_block_start event has been sent for the given index
pub fn has_content_block_start_been_sent(&self, index: i32) -> bool {
self.content_block_start_indices.contains(&index)
}
/// Mark that a content_block_start event has been sent for the given index
pub fn set_content_block_start_sent(&mut self, index: i32) {
self.content_block_start_indices.insert(index);
}
}

View file

@ -5,9 +5,9 @@ use serde_with::skip_serializing_none;
use std::collections::HashMap;
use super::ApiDefinition;
use crate::clients::transformer::ExtractText;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
use crate::transforms::lib::ExtractText;
use crate::MESSAGES_PATH;
// Enum for all supported Anthropic APIs

View file

@ -1,7 +1,19 @@
pub mod amazon_bedrock;
pub mod amazon_bedrock_binary_frame;
pub mod anthropic;
pub mod openai;
pub use anthropic::*;
pub use openai::*;
pub mod sse;
// Explicit exports to avoid naming conflicts
pub use amazon_bedrock::{AmazonBedrockApi, ConverseRequest, ConverseStreamRequest};
pub use amazon_bedrock::{
Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
};
pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent};
pub use openai::{
ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse, OpenAIApi,
};
pub use openai::{Message as OpenAIMessage, Tool as OpenAITool, ToolChoice as OpenAIToolChoice};
pub trait ApiDefinition {
/// Returns the endpoint path for this API

View file

@ -6,9 +6,9 @@ use std::fmt::Display;
use thiserror::Error;
use super::ApiDefinition;
use crate::clients::transformer::ExtractText;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
use crate::transforms::lib::ExtractText;
use crate::CHAT_COMPLETIONS_PATH;
// ============================================================================

View file

@ -0,0 +1,196 @@
use crate::providers::response::ProviderStreamResponse;
use crate::providers::response::ProviderStreamResponseType;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fmt;
use std::str::FromStr;
// ============================================================================
// SSE EVENT CONTAINER
// ============================================================================
/// Represents a single Server-Sent Event with the complete wire format
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SseEvent {
#[serde(rename = "data")]
pub data: Option<String>, // The JSON payload after "data: "
#[serde(skip_serializing_if = "Option::is_none")]
pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta")
#[serde(skip_serializing, skip_deserializing)]
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
#[serde(skip_serializing, skip_deserializing)]
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
#[serde(skip_serializing, skip_deserializing)]
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
}
impl SseEvent {
/// Check if this event represents the end of the stream
pub fn is_done(&self) -> bool {
self.data == Some("[DONE]".into())
}
/// Check if this event should be skipped during processing
/// This includes ping messages and other provider-specific events that don't contain content
pub fn should_skip(&self) -> bool {
// Skip ping messages (commonly used by providers for connection keep-alive)
self.data == Some(r#"{"type": "ping"}"#.into())
}
/// Check if this is an event-only SSE event (no data payload)
pub fn is_event_only(&self) -> bool {
self.event.is_some() && self.data.is_none()
}
/// Get the parsed provider response if available
pub fn provider_response(&self) -> Result<&dyn ProviderStreamResponse, std::io::Error> {
self.provider_stream_response
.as_ref()
.map(|resp| resp as &dyn ProviderStreamResponse)
.ok_or_else(|| {
std::io::Error::new(std::io::ErrorKind::NotFound, "Provider response not found")
})
}
}
impl FromStr for SseEvent {
type Err = SseParseError;
fn from_str(line: &str) -> Result<Self, Self::Err> {
if line.starts_with("data: ") {
let data: String = line[6..].to_string(); // Remove "data: " prefix
if data.is_empty() {
return Err(SseParseError {
message: "Empty data field is not a valid SSE event".to_string(),
});
}
Ok(SseEvent {
data: Some(data),
event: None,
raw_line: line.to_string(),
sse_transform_buffer: line.to_string(),
provider_stream_response: None,
})
} else if line.starts_with("event: ") {
//used by Anthropic
let event_type = line[7..].to_string();
if event_type.is_empty() {
return Err(SseParseError {
message: "Empty event field is not a valid SSE event".to_string(),
});
}
Ok(SseEvent {
data: None,
event: Some(event_type),
raw_line: line.to_string(),
sse_transform_buffer: line.to_string(),
provider_stream_response: None,
})
} else {
Err(SseParseError {
message: format!("Line does not start with 'data: ' or 'event: ': {}", line),
})
}
}
}
impl fmt::Display for SseEvent {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.sse_transform_buffer)
}
}
// Into implementation to convert SseEvent to bytes for response buffer
impl Into<Vec<u8>> for SseEvent {
fn into(self) -> Vec<u8> {
format!("{}\n\n", self.sse_transform_buffer).into_bytes()
}
}
#[derive(Debug)]
pub struct SseParseError {
pub message: String,
}
impl fmt::Display for SseParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "SSE parse error: {}", self.message)
}
}
impl Error for SseParseError {}
/// Generic SSE (Server-Sent Events) streaming iterator container
/// Parses raw SSE lines into SseEvent objects
pub struct SseStreamIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
pub lines: I,
pub done_seen: bool,
}
impl<I> SseStreamIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
pub fn new(lines: I) -> Self {
Self {
lines,
done_seen: false,
}
}
}
// TryFrom implementation to parse bytes into SseStreamIter
// Handles both text-based SSE and binary AWS Event Stream formats
impl TryFrom<&[u8]> for SseStreamIter<std::vec::IntoIter<String>> {
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
// Parse as text-based SSE format
let s = std::str::from_utf8(bytes)?;
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
Ok(SseStreamIter::new(lines.into_iter()))
}
}
impl<I> Iterator for SseStreamIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
type Item = SseEvent;
fn next(&mut self) -> Option<Self::Item> {
// If we already returned [DONE], terminate the stream
if self.done_seen {
return None;
}
for line in &mut self.lines {
let line_str = line.as_ref();
// Try to parse as either data: or event: line
if let Ok(event) = line_str.parse::<SseEvent>() {
// For data: lines, check if this is the [DONE] marker
if event.data.is_some() && event.is_done() {
self.done_seen = true;
return Some(event); // Return [DONE] event for transformation
}
// For data: lines, skip events that should be filtered at the transport layer
if event.data.is_some() && event.should_skip() {
continue;
}
return Some(event);
}
}
None
}
}

View file

@ -1,30 +1,5 @@
//! Supported endpoint registry for LLM APIs
//!
//! This module provides a simple registry to check which API endpoint paths
//! we support across different providers.
//!
//! # Examples
//!
//! ```rust
//! use hermesllm::clients::endpoints::supported_endpoints;
//!
//! // Check if we support an endpoint
//! use hermesllm::clients::endpoints::SupportedAPIs;
//! assert!(SupportedAPIs::from_endpoint("/v1/chat/completions").is_some());
//! assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some());
//! assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some());
//!
//! // Get all supported endpoints
//! let endpoints = supported_endpoints();
//! assert_eq!(endpoints.len(), 2);
//! assert!(endpoints.contains(&"/v1/chat/completions"));
//! assert!(endpoints.contains(&"/v1/messages"));
//! ```
use crate::{
apis::{AnthropicApi, ApiDefinition, OpenAIApi},
ProviderId,
};
use crate::apis::{AmazonBedrockApi, AnthropicApi, ApiDefinition, OpenAIApi};
use crate::ProviderId;
use std::fmt;
/// Unified enum representing all supported API endpoints across providers
@ -34,6 +9,14 @@ pub enum SupportedAPIs {
AnthropicMessagesAPI(AnthropicApi),
}
#[derive(Debug, Clone, PartialEq)]
pub enum SupportedUpstreamAPIs {
OpenAIChatCompletions(OpenAIApi),
AnthropicMessagesAPI(AnthropicApi),
AmazonBedrockConverse(AmazonBedrockApi),
AmazonBedrockConverseStream(AmazonBedrockApi),
}
impl fmt::Display for SupportedAPIs {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
@ -74,11 +57,21 @@ impl SupportedAPIs {
provider_id: &ProviderId,
request_path: &str,
model_id: &str,
is_streaming: bool,
) -> String {
let default_endpoint = "/v1/chat/completions".to_string();
match self {
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
ProviderId::Anthropic => "/v1/messages".to_string(),
ProviderId::AmazonBedrock => {
if request_path.starts_with("/v1/") && !is_streaming {
format!("/model/{}/converse", model_id)
} else if request_path.starts_with("/v1/") && is_streaming {
format!("/model/{}/converse-stream", model_id)
} else {
default_endpoint
}
}
_ => default_endpoint,
},
_ => match provider_id {
@ -117,6 +110,17 @@ impl SupportedAPIs {
default_endpoint
}
}
ProviderId::AmazonBedrock => {
if request_path.starts_with("/v1/") {
if !is_streaming {
format!("/model/{}/converse", model_id)
} else {
format!("/model/{}/converse-stream", model_id)
}
} else {
default_endpoint
}
}
_ => default_endpoint,
},
}
@ -161,7 +165,6 @@ mod tests {
fn test_is_supported_endpoint() {
// OpenAI endpoints
assert!(SupportedAPIs::from_endpoint("/v1/chat/completions").is_some());
// Anthropic endpoints
assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some());
@ -174,7 +177,7 @@ mod tests {
#[test]
fn test_supported_endpoints() {
let endpoints = supported_endpoints();
assert_eq!(endpoints.len(), 2);
assert_eq!(endpoints.len(), 2); // We have 2 APIs defined
assert!(endpoints.contains(&"/v1/chat/completions"));
assert!(endpoints.contains(&"/v1/messages"));
}
@ -217,7 +220,6 @@ mod tests {
endpoint
);
}
// Total should match
assert_eq!(
endpoints.len(),

File diff suppressed because it is too large Load diff

View file

@ -4,12 +4,16 @@
pub mod apis;
pub mod clients;
pub mod providers;
pub mod transforms;
// Re-export important types and traits
pub use apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
pub use apis::sse::{SseEvent, SseStreamIter};
pub use aws_smithy_eventstream::frame::DecodedFrame;
pub use providers::id::ProviderId;
pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
pub use providers::response::{
ProviderResponse, ProviderResponseError, ProviderResponseType, ProviderStreamResponse,
ProviderStreamResponseType, SseEvent, SseStreamIter, TokenUsage,
ProviderStreamResponseType, TokenUsage,
};
//TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings
@ -18,6 +22,8 @@ pub const MESSAGES_PATH: &str = "/v1/messages";
#[cfg(test)]
mod tests {
use crate::clients::endpoints::SupportedUpstreamAPIs;
use super::*;
#[test]
@ -40,7 +46,7 @@ mod tests {
let client_api =
SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
let upstream_api =
SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
// Test the new simplified architecture - create SseStreamIter directly
let sse_iter = SseStreamIter::try_from(sse_data.as_bytes());
@ -77,4 +83,156 @@ mod tests {
let final_event = streaming_iter.next();
assert!(final_event.is_none()); // Should be None because iterator stops at [DONE]
}
/// Test AWS Event Stream decoding for Bedrock ConverseStream responses.
///
/// This test demonstrates how to:
/// 1. Use MessageFrameDecoder to decode AWS Event Stream frames
/// 2. Handle chunked network arrivals with buffering
/// 3. Extract event types from message headers
/// 4. Parse JSON payloads from decoded messages
/// 5. Reconstruct streaming content from contentBlockDelta events
///
/// The decoder handles frame boundaries automatically - you just keep calling
/// decode_frame() until it returns Incomplete, which means you've processed
/// all complete frames in the buffer.
#[test]
fn test_amazon_bedrock_streaming_response() {
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
use bytes::{Buf, BytesMut};
use std::fs;
use std::path::PathBuf;
// Read the response.hex file from tests/e2e directory
// Use absolute path to avoid cargo test working directory issues
let test_file =
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
let response_data = fs::read(&test_file)
.unwrap_or_else(|e| panic!("Failed to read {:?}: {}", test_file, e));
println!("📊 Response data size: {} bytes\n", response_data.len());
// Create decoder and buffer that implements Buf trait
// BytesMut automatically tracks position as decoder advances it!
let mut decoder = MessageFrameDecoder::new();
let mut simulated_network_buffer = BytesMut::new();
let mut frame_count = 0;
let mut content_chunks = Vec::new();
// Simulate chunked network arrivals - process as data comes in
let chunk_sizes = vec![50, 100, 75, 200, 150, 300, 500, 1000];
let mut offset = 0;
let mut chunk_num = 0;
println!("🔄 Simulating chunked network arrivals...\n");
// Process chunks as they "arrive" from the network
while offset < response_data.len() {
// Receive next chunk from network
let chunk_size = chunk_sizes[chunk_num % chunk_sizes.len()];
let end = (offset + chunk_size).min(response_data.len());
let chunk = &response_data[offset..end];
chunk_num += 1;
simulated_network_buffer.extend_from_slice(chunk);
offset = end;
println!(
"📦 Chunk {}: Received {} bytes (buffer: {} bytes total, {} bytes remaining)",
chunk_num,
chunk.len(),
simulated_network_buffer.len(),
simulated_network_buffer.remaining()
);
// Try to decode all complete frames from buffer
// The Buf trait tracks position automatically!
loop {
let bytes_before = simulated_network_buffer.remaining();
match decoder.decode_frame(&mut simulated_network_buffer) {
Ok(DecodedFrame::Complete(message)) => {
frame_count += 1;
let consumed = bytes_before - simulated_network_buffer.remaining();
println!(
" ✅ Frame {}: decoded ({} bytes, {} bytes remaining)",
frame_count,
consumed,
simulated_network_buffer.remaining()
);
// Get event type from headers
let event_type = message
.headers()
.iter()
.find(|h| h.name().as_str() == ":event-type")
.and_then(|h| {
h.value().as_string().ok().map(|s| s.as_str().to_string())
});
if let Some(ref evt) = event_type {
println!(" Event: {}", evt);
}
// Parse payload and extract content
let payload = message.payload();
if !payload.is_empty() {
if let Ok(json) = serde_json::from_slice::<serde_json::Value>(payload) {
if event_type.as_deref() == Some("contentBlockDelta") {
if let Some(delta) = json.get("delta") {
if let Some(text) =
delta.get("text").and_then(|t| t.as_str())
{
println!(" 📝 Content: \"{}\"", text);
content_chunks.push(text.to_string());
}
}
}
}
} // Continue loop to check for more complete frames in buffer
}
Ok(DecodedFrame::Incomplete) => {
// Not enough data for a complete frame - need more chunks
println!(
" ⏳ Incomplete frame ({} bytes remaining) - waiting for more data\n",
simulated_network_buffer.remaining()
);
break; // Wait for next chunk
}
Err(e) => {
panic!("❌ Frame decode error: {}", e);
}
}
}
}
println!("\n━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!("📋 Summary:");
println!(" Total chunks received: {}", chunk_num);
println!(" Total frames decoded: {}", frame_count);
println!(" Total content chunks: {}", content_chunks.len());
println!(
" Final buffer remaining: {} bytes",
simulated_network_buffer.remaining()
);
if !content_chunks.is_empty() {
let full_text = content_chunks.join("");
println!("\n📄 Full reconstructed content:");
println!("{}", full_text);
println!("\n Characters: {}", full_text.len());
println!(" Estimated tokens: ~{}", full_text.len() / 4);
}
// Ensure we decoded at least one frame
assert!(frame_count > 0, "Should decode at least one frame");
// Ensure all data was consumed - if buffer has remaining bytes, it's a partial frame
assert_eq!(
simulated_network_buffer.remaining(),
0,
"All bytes should be consumed, {} bytes remain",
simulated_network_buffer.remaining()
);
}
}

View file

@ -1,5 +1,5 @@
use crate::apis::{AnthropicApi, OpenAIApi};
use crate::clients::endpoints::SupportedAPIs;
use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi};
use crate::clients::endpoints::{SupportedAPIs, SupportedUpstreamAPIs};
use std::fmt::Display;
/// Provider identifier enum - simple enum for identifying providers
@ -19,7 +19,8 @@ pub enum ProviderId {
Ollama,
Moonshotai,
Zhipu,
Qwen, // alias for Qwen
Qwen,
AmazonBedrock,
}
impl From<&str> for ProviderId {
@ -39,7 +40,8 @@ impl From<&str> for ProviderId {
"ollama" => ProviderId::Ollama,
"moonshotai" => ProviderId::Moonshotai,
"zhipu" => ProviderId::Zhipu,
"qwen" => ProviderId::Qwen, // alias for Zhipu
"qwen" => ProviderId::Qwen, // alias for Qwen
"amazon_bedrock" => ProviderId::AmazonBedrock,
_ => panic!("Unknown provider: {}", value),
}
}
@ -47,16 +49,20 @@ impl From<&str> for ProviderId {
impl ProviderId {
/// Given a client API, return the compatible upstream API for this provider
pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs) -> SupportedAPIs {
pub fn compatible_api_for_client(
&self,
client_api: &SupportedAPIs,
is_streaming: bool,
) -> SupportedUpstreamAPIs {
match (self, client_api) {
// Claude/Anthropic providers natively support Anthropic APIs
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => {
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
}
(
ProviderId::Anthropic,
SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
// OpenAI-compatible providers only support OpenAI chat completions
(
@ -75,7 +81,7 @@ impl ProviderId {
| ProviderId::Zhipu
| ProviderId::Qwen,
SupportedAPIs::AnthropicMessagesAPI(_),
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(
ProviderId::OpenAI
@ -93,7 +99,27 @@ impl ProviderId {
| ProviderId::Zhipu
| ProviderId::Qwen,
SupportedAPIs::OpenAIChatCompletions(_),
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
// Amazon Bedrock natively supports Bedrock APIs
(ProviderId::AmazonBedrock, SupportedAPIs::OpenAIChatCompletions(_)) => {
if is_streaming {
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
AmazonBedrockApi::ConverseStream,
)
} else {
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
}
}
(ProviderId::AmazonBedrock, SupportedAPIs::AnthropicMessagesAPI(_)) => {
if is_streaming {
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
AmazonBedrockApi::ConverseStream,
)
} else {
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
}
}
}
}
}
@ -116,6 +142,7 @@ impl Display for ProviderId {
ProviderId::Moonshotai => write!(f, "moonshotai"),
ProviderId::Zhipu => write!(f, "zhipu"),
ProviderId::Qwen => write!(f, "qwen"),
ProviderId::AmazonBedrock => write!(f, "amazon_bedrock"),
}
}
}

View file

@ -1,6 +1,9 @@
use crate::apis::anthropic::MessagesRequest;
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest};
use crate::clients::endpoints::SupportedAPIs;
use crate::clients::endpoints::SupportedUpstreamAPIs;
use serde_json::Value;
use std::collections::HashMap;
@ -10,6 +13,8 @@ use std::fmt;
pub enum ProviderRequestType {
ChatCompletionsRequest(ChatCompletionsRequest),
MessagesRequest(MessagesRequest),
BedrockConverse(ConverseRequest),
BedrockConverseStream(ConverseStreamRequest),
//add more request types here
}
pub trait ProviderRequest: Send + Sync {
@ -42,6 +47,8 @@ impl ProviderRequest for ProviderRequestType {
match self {
Self::ChatCompletionsRequest(r) => r.model(),
Self::MessagesRequest(r) => r.model(),
Self::BedrockConverse(r) => r.model(),
Self::BedrockConverseStream(r) => r.model(),
}
}
@ -49,6 +56,8 @@ impl ProviderRequest for ProviderRequestType {
match self {
Self::ChatCompletionsRequest(r) => r.set_model(model),
Self::MessagesRequest(r) => r.set_model(model),
Self::BedrockConverse(r) => r.set_model(model),
Self::BedrockConverseStream(r) => r.set_model(model),
}
}
@ -56,6 +65,8 @@ impl ProviderRequest for ProviderRequestType {
match self {
Self::ChatCompletionsRequest(r) => r.is_streaming(),
Self::MessagesRequest(r) => r.is_streaming(),
Self::BedrockConverse(_) => false,
Self::BedrockConverseStream(_) => true,
}
}
@ -63,6 +74,8 @@ impl ProviderRequest for ProviderRequestType {
match self {
Self::ChatCompletionsRequest(r) => r.extract_messages_text(),
Self::MessagesRequest(r) => r.extract_messages_text(),
Self::BedrockConverse(r) => r.extract_messages_text(),
Self::BedrockConverseStream(r) => r.extract_messages_text(),
}
}
@ -70,6 +83,8 @@ impl ProviderRequest for ProviderRequestType {
match self {
Self::ChatCompletionsRequest(r) => r.get_recent_user_message(),
Self::MessagesRequest(r) => r.get_recent_user_message(),
Self::BedrockConverse(r) => r.get_recent_user_message(),
Self::BedrockConverseStream(r) => r.get_recent_user_message(),
}
}
@ -77,6 +92,8 @@ impl ProviderRequest for ProviderRequestType {
match self {
Self::ChatCompletionsRequest(r) => r.to_bytes(),
Self::MessagesRequest(r) => r.to_bytes(),
Self::BedrockConverse(r) => r.to_bytes(),
Self::BedrockConverseStream(r) => r.to_bytes(),
}
}
@ -84,6 +101,8 @@ impl ProviderRequest for ProviderRequestType {
match self {
Self::ChatCompletionsRequest(r) => r.metadata(),
Self::MessagesRequest(r) => r.metadata(),
Self::BedrockConverse(r) => r.metadata(),
Self::BedrockConverseStream(r) => r.metadata(),
}
}
@ -91,6 +110,8 @@ impl ProviderRequest for ProviderRequestType {
match self {
Self::ChatCompletionsRequest(r) => r.remove_metadata_key(key),
Self::MessagesRequest(r) => r.remove_metadata_key(key),
Self::BedrockConverse(r) => r.remove_metadata_key(key),
Self::BedrockConverseStream(r) => r.remove_metadata_key(key),
}
}
}
@ -120,27 +141,27 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
}
/// Conversion from one ProviderRequestType to a different ProviderRequestType (SupportedAPIs)
impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestType {
type Error = ProviderRequestError;
fn try_from(
(request, upstream_api): (ProviderRequestType, &SupportedAPIs),
(client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs),
) -> Result<Self, Self::Error> {
match (request, upstream_api) {
match (client_request, upstream_api) {
// Same API - no conversion needed, just clone the reference
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedAPIs::OpenAIChatCompletions(_),
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)),
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedAPIs::AnthropicMessagesAPI(_),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) => Ok(ProviderRequestType::MessagesRequest(messages_req)),
// Cross-API conversion - cloning is necessary for transformation
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedAPIs::AnthropicMessagesAPI(_),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) => {
let messages_req =
MessagesRequest::try_from(chat_req).map_err(|e| ProviderRequestError {
@ -155,7 +176,7 @@ impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedAPIs::OpenAIChatCompletions(_),
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
) => {
let chat_req = ChatCompletionsRequest::try_from(messages_req).map_err(|e| {
ProviderRequestError {
@ -168,6 +189,69 @@ impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
})?;
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
}
// Cross-API conversions: OpenAI/Anthropic to Amazon Bedrock
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
) => {
let bedrock_req = ConverseRequest::try_from(chat_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
source: Some(Box::new(e))
})?;
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
}
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
) => {
let bedrock_req = ConverseStreamRequest::try_from(chat_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
source: Some(Box::new(e))
})?;
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
}
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
) => {
let bedrock_req =
ConverseRequest::try_from(messages_req).map_err(|e| ProviderRequestError {
message: format!(
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
e
),
source: Some(Box::new(e)),
})?;
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
}
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
) => {
let bedrock_req = ConverseStreamRequest::try_from(messages_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
}
// Amazon Bedrock to other APIs conversions
(ProviderRequestType::BedrockConverse(_), _) => {
todo!("Amazon Bedrock to ChatCompletionsRequest conversion not implemented yet")
}
(ProviderRequestType::BedrockConverseStream(_), _) => {
todo!("Amazon Bedrock Stream to ChatCompletionsRequest conversion not implemented yet")
}
}
}
}
@ -201,7 +285,7 @@ mod tests {
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::openai::OpenAIApi::ChatCompletions;
use crate::clients::endpoints::SupportedAPIs;
use crate::clients::transformer::ExtractText;
use crate::transforms::lib::ExtractText;
use serde_json::json;
#[test]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,231 @@
use crate::apis::anthropic::{MessagesContentBlock, MessagesImageSource};
use crate::apis::openai::{ContentPart, FunctionCall, ImageUrl, Message, MessageContent, ToolCall};
use crate::clients::TransformError;
use serde_json::Value;
use std::time::{SystemTime, UNIX_EPOCH};
pub trait ExtractText {
fn extract_text(&self) -> String;
}
/// Trait for utility functions on content collections
pub trait ContentUtils<T> {
fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError>;
fn split_for_openai(
&self,
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
}
/// Helper to create a current unix timestamp
pub fn current_timestamp() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
}
// Content Utilities
impl ContentUtils<ToolCall> for Vec<MessagesContentBlock> {
fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError> {
let mut tool_calls = Vec::new();
for block in self {
match block {
MessagesContentBlock::ToolUse {
id, name, input, ..
}
| MessagesContentBlock::ServerToolUse { id, name, input }
| MessagesContentBlock::McpToolUse { id, name, input } => {
let arguments = serde_json::to_string(&input)?;
tool_calls.push(ToolCall {
id: id.clone(),
call_type: "function".to_string(),
function: FunctionCall {
name: name.clone(),
arguments,
},
});
}
_ => continue,
}
}
Ok(if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
})
}
fn split_for_openai(
&self,
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>
{
let mut content_parts = Vec::new();
let mut tool_calls = Vec::new();
let mut tool_results = Vec::new();
for block in self {
match block {
MessagesContentBlock::Text { text, .. } => {
content_parts.push(ContentPart::Text { text: text.clone() });
}
MessagesContentBlock::Image { source } => {
let url = convert_image_source_to_url(source);
content_parts.push(ContentPart::ImageUrl {
image_url: ImageUrl {
url,
detail: Some("auto".to_string()),
},
});
}
MessagesContentBlock::ToolUse {
id, name, input, ..
}
| MessagesContentBlock::ServerToolUse { id, name, input }
| MessagesContentBlock::McpToolUse { id, name, input } => {
let arguments = serde_json::to_string(&input)?;
tool_calls.push(ToolCall {
id: id.clone(),
call_type: "function".to_string(),
function: FunctionCall {
name: name.clone(),
arguments,
},
});
}
MessagesContentBlock::ToolResult {
tool_use_id,
content,
is_error,
..
} => {
let result_text = content.extract_text();
tool_results.push((
tool_use_id.clone(),
result_text,
is_error.unwrap_or(false),
));
}
MessagesContentBlock::WebSearchToolResult {
tool_use_id,
content,
is_error,
}
| MessagesContentBlock::CodeExecutionToolResult {
tool_use_id,
content,
is_error,
}
| MessagesContentBlock::McpToolResult {
tool_use_id,
content,
is_error,
} => {
let result_text = content.extract_text();
tool_results.push((
tool_use_id.clone(),
result_text,
is_error.unwrap_or(false),
));
}
_ => {
// Skip unsupported content types
continue;
}
}
}
Ok((content_parts, tool_calls, tool_results))
}
}
/// Convert image source to URL
pub fn convert_image_source_to_url(source: &MessagesImageSource) -> String {
match source {
MessagesImageSource::Base64 { media_type, data } => {
format!("data:{};base64,{}", media_type, data)
}
MessagesImageSource::Url { url } => url.clone(),
}
}
/// Convert image URL to Anthropic image source
fn convert_image_url_to_source(image_url: &ImageUrl) -> MessagesImageSource {
if image_url.url.starts_with("data:") {
// Parse data URL
let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
if parts.len() == 2 {
let header = parts[0];
let data = parts[1];
let media_type = header
.strip_prefix("data:")
.and_then(|s| s.split(';').next())
.unwrap_or("image/jpeg")
.to_string();
MessagesImageSource::Base64 {
media_type,
data: data.to_string(),
}
} else {
MessagesImageSource::Url {
url: image_url.url.clone(),
}
}
} else {
MessagesImageSource::Url {
url: image_url.url.clone(),
}
}
}
/// Convert OpenAI message to Anthropic content blocks
pub fn convert_openai_message_to_anthropic_content(
message: &Message,
) -> Result<Vec<MessagesContentBlock>, TransformError> {
let mut blocks = Vec::new();
// Handle regular content
match &message.content {
MessageContent::Text(text) => {
if !text.is_empty() {
blocks.push(MessagesContentBlock::Text {
text: text.clone(),
cache_control: None,
});
}
}
MessageContent::Parts(parts) => {
for part in parts {
match part {
ContentPart::Text { text } => {
blocks.push(MessagesContentBlock::Text {
text: text.clone(),
cache_control: None,
});
}
ContentPart::ImageUrl { image_url } => {
let source = convert_image_url_to_source(image_url);
blocks.push(MessagesContentBlock::Image { source });
}
}
}
}
}
// Handle tool calls
if let Some(tool_calls) = &message.tool_calls {
for tool_call in tool_calls {
let input: Value = serde_json::from_str(&tool_call.function.arguments)?;
blocks.push(MessagesContentBlock::ToolUse {
id: tool_call.id.clone(),
name: tool_call.function.name.clone(),
input,
cache_control: None,
});
}
}
Ok(blocks)
}

View file

@ -0,0 +1,25 @@
//! API transformation modules
//!
//! This module provides organized transformations between the two main LLM API formats:
//! - `/v1/chat/completions` (OpenAI format)
//! - `/v1/messages` (Anthropic format)
//!
//! Provider-specific transformations (Bedrock, Groq, etc.) are handled internally
//! by the gateway, but the external API surface remains these two standard formats.
//! The transformations are split into logical modules for maintainability.
pub mod lib;
pub mod request;
pub mod response;
// Re-export commonly used items for convenience
pub use lib::*;
pub use request::*;
pub use response::*;
// ============================================================================
// CONSTANTS
// ============================================================================
/// Default maximum tokens when converting from OpenAI to Anthropic and no max_tokens is specified
pub const DEFAULT_MAX_TOKENS: u32 = 4096;

View file

@ -0,0 +1,704 @@
use crate::apis::amazon_bedrock::{
AnyChoice, AutoChoice, ContentBlock, ConversationRole, ConverseRequest, ImageBlock,
ImageSource, InferenceConfiguration, Message as BedrockMessage, SystemContentBlock,
Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolChoiceSpec, ToolConfiguration,
ToolInputSchema, ToolResultBlock, ToolResultContentBlock, ToolResultStatus, ToolSpecDefinition,
ToolUseBlock,
};
use crate::apis::anthropic::{
MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole, MessagesStopReason,
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, MessagesUsage,
ToolResultContent,
};
use crate::apis::openai::{
ChatCompletionsRequest, ContentPart, FinishReason, Function, FunctionChoice, Message,
MessageContent, Role, Tool, ToolCall, ToolChoice, ToolChoiceType, Usage,
};
use crate::clients::TransformError;
use crate::transforms::lib::*;
type AnthropicMessagesRequest = MessagesRequest;
// Conversion from Anthropic MessagesRequest to OpenAI ChatCompletionsRequest
impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
type Error = TransformError;
fn try_from(req: AnthropicMessagesRequest) -> Result<Self, Self::Error> {
let mut openai_messages: Vec<Message> = Vec::new();
// Convert system prompt to system message if present
if let Some(system) = req.system {
openai_messages.push(system.into());
}
// Convert messages
for message in req.messages {
let converted_messages: Vec<Message> = message.try_into()?;
openai_messages.extend(converted_messages);
}
// Convert tools and tool choice
let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
let (openai_tool_choice, parallel_tool_calls) =
convert_anthropic_tool_choice(req.tool_choice);
let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest {
model: req.model,
messages: openai_messages,
temperature: req.temperature,
top_p: req.top_p,
max_completion_tokens: Some(req.max_tokens),
stream: req.stream,
stop: req.stop_sequences,
tools: openai_tools,
tool_choice: openai_tool_choice,
parallel_tool_calls,
..Default::default()
};
_chat_completions_req.suppress_max_tokens_if_o3();
_chat_completions_req.fix_temperature_if_gpt5();
Ok(_chat_completions_req)
}
}
// Conversion from Anthropic MessagesRequest to Amazon Bedrock ConverseRequest
impl TryFrom<AnthropicMessagesRequest> for ConverseRequest {
type Error = TransformError;
fn try_from(req: AnthropicMessagesRequest) -> Result<Self, Self::Error> {
// Convert system prompt to SystemContentBlock if present
let system: Option<Vec<SystemContentBlock>> = req.system.map(|system_prompt| {
let text = match system_prompt {
MessagesSystemPrompt::Single(text) => text,
MessagesSystemPrompt::Blocks(blocks) => blocks.extract_text(),
};
vec![SystemContentBlock::Text { text }]
});
// Convert messages to Bedrock format
let messages = if req.messages.is_empty() {
None
} else {
let mut bedrock_messages = Vec::new();
for anthropic_message in req.messages {
let bedrock_message: BedrockMessage = anthropic_message.try_into()?;
bedrock_messages.push(bedrock_message);
}
Some(bedrock_messages)
};
// Build inference configuration
// Anthropic always requires max_tokens, so we should always include inferenceConfig
let inference_config = Some(InferenceConfiguration {
max_tokens: Some(req.max_tokens),
temperature: req.temperature,
top_p: req.top_p,
stop_sequences: req.stop_sequences,
});
// Convert tools and tool choice to ToolConfiguration
let tool_config = if req.tools.is_some() || req.tool_choice.is_some() {
let tools = req.tools.map(|anthropic_tools| {
anthropic_tools
.into_iter()
.map(|tool| BedrockTool::ToolSpec {
tool_spec: ToolSpecDefinition {
name: tool.name,
description: tool.description,
input_schema: ToolInputSchema {
json: tool.input_schema,
},
},
})
.collect()
});
let tool_choice = req.tool_choice.map(|choice| {
match choice.kind {
MessagesToolChoiceType::Auto => BedrockToolChoice::Auto {
auto: AutoChoice {},
},
MessagesToolChoiceType::Any => BedrockToolChoice::Any { any: AnyChoice {} },
MessagesToolChoiceType::None => BedrockToolChoice::Auto {
auto: AutoChoice {},
}, // Bedrock doesn't have explicit "none"
MessagesToolChoiceType::Tool => {
if let Some(name) = choice.name {
BedrockToolChoice::Tool {
tool: ToolChoiceSpec { name },
}
} else {
BedrockToolChoice::Auto {
auto: AutoChoice {},
}
}
}
}
});
Some(ToolConfiguration { tools, tool_choice })
} else {
None
};
Ok(ConverseRequest {
model_id: req.model,
messages,
system,
inference_config,
tool_config,
stream: req.stream.unwrap_or(false),
guardrail_config: None,
additional_model_request_fields: None,
additional_model_response_field_paths: None,
performance_config: None,
prompt_variables: None,
request_metadata: None,
metadata: None,
})
}
}
// Message Conversions
impl TryFrom<MessagesMessage> for Vec<Message> {
type Error = TransformError;
fn try_from(message: MessagesMessage) -> Result<Self, Self::Error> {
let mut result = Vec::new();
match message.content {
MessagesMessageContent::Single(text) => {
result.push(Message {
role: message.role.into(),
content: MessageContent::Text(text),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
MessagesMessageContent::Blocks(blocks) => {
let (content_parts, tool_calls, tool_results) = blocks.split_for_openai()?;
// Add tool result messages
for (tool_use_id, result_text, _is_error) in tool_results {
result.push(Message {
role: Role::Tool,
content: MessageContent::Text(result_text),
name: None,
tool_calls: None,
tool_call_id: Some(tool_use_id),
});
}
// Only create main message if there's actual content or tool calls
// Skip creating empty content messages (e.g., when message only contains tool_result blocks)
if !content_parts.is_empty() || !tool_calls.is_empty() {
let content = build_openai_content(content_parts, &tool_calls);
let main_message = Message {
role: message.role.into(),
content,
name: None,
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
tool_call_id: None,
};
result.push(main_message);
}
}
}
Ok(result)
}
}
// Role Conversions
impl Into<Role> for MessagesRole {
fn into(self) -> Role {
match self {
MessagesRole::User => Role::User,
MessagesRole::Assistant => Role::Assistant,
}
}
}
impl Into<MessagesStopReason> for FinishReason {
fn into(self) -> MessagesStopReason {
match self {
FinishReason::Stop => MessagesStopReason::EndTurn,
FinishReason::Length => MessagesStopReason::MaxTokens,
FinishReason::ToolCalls => MessagesStopReason::ToolUse,
FinishReason::ContentFilter => MessagesStopReason::Refusal,
FinishReason::FunctionCall => MessagesStopReason::ToolUse,
}
}
}
impl Into<MessagesUsage> for Usage {
fn into(self) -> MessagesUsage {
MessagesUsage {
input_tokens: self.prompt_tokens,
output_tokens: self.completion_tokens,
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}
}
}
// System Prompt Conversions
impl Into<Message> for MessagesSystemPrompt {
fn into(self) -> Message {
let system_content = match self {
MessagesSystemPrompt::Single(text) => MessageContent::Text(text),
MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()),
};
Message {
role: Role::System,
content: system_content,
name: None,
tool_calls: None,
tool_call_id: None,
}
}
}
//Utility Functions
/// Convert Anthropic tools to OpenAI format
fn convert_anthropic_tools(tools: Vec<MessagesTool>) -> Vec<Tool> {
tools
.into_iter()
.map(|tool| Tool {
tool_type: "function".to_string(),
function: Function {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
strict: None,
},
})
.collect()
}
/// Convert Anthropic tool choice to OpenAI format
fn convert_anthropic_tool_choice(
tool_choice: Option<MessagesToolChoice>,
) -> (Option<ToolChoice>, Option<bool>) {
match tool_choice {
Some(choice) => {
let openai_choice = match choice.kind {
MessagesToolChoiceType::Auto => ToolChoice::Type(ToolChoiceType::Auto),
MessagesToolChoiceType::Any => ToolChoice::Type(ToolChoiceType::Required),
MessagesToolChoiceType::None => ToolChoice::Type(ToolChoiceType::None),
MessagesToolChoiceType::Tool => {
if let Some(name) = choice.name {
ToolChoice::Function {
choice_type: "function".to_string(),
function: FunctionChoice { name },
}
} else {
ToolChoice::Type(ToolChoiceType::Auto)
}
}
};
let parallel = choice.disable_parallel_tool_use.map(|disable| !disable);
(Some(openai_choice), parallel)
}
None => (None, None),
}
}
/// Build OpenAI message content from parts and tool calls
fn build_openai_content(
content_parts: Vec<ContentPart>,
tool_calls: &[ToolCall],
) -> MessageContent {
if content_parts.len() == 1 && tool_calls.is_empty() {
match &content_parts[0] {
ContentPart::Text { text } => MessageContent::Text(text.clone()),
_ => MessageContent::Parts(content_parts),
}
} else if content_parts.is_empty() {
MessageContent::Text("".to_string())
} else {
MessageContent::Parts(content_parts)
}
}
impl TryFrom<MessagesMessage> for BedrockMessage {
type Error = TransformError;
fn try_from(message: MessagesMessage) -> Result<Self, Self::Error> {
let role = match message.role {
MessagesRole::User => ConversationRole::User,
MessagesRole::Assistant => ConversationRole::Assistant,
};
let mut content_blocks = Vec::new();
// Convert content blocks
match message.content {
MessagesMessageContent::Single(text) => {
if !text.is_empty() {
content_blocks.push(ContentBlock::Text { text });
}
}
MessagesMessageContent::Blocks(blocks) => {
for block in blocks {
match block {
crate::apis::anthropic::MessagesContentBlock::Text { text, .. } => {
if !text.is_empty() {
content_blocks.push(ContentBlock::Text { text });
}
}
crate::apis::anthropic::MessagesContentBlock::ToolUse {
id,
name,
input,
..
} => {
content_blocks.push(ContentBlock::ToolUse {
tool_use: ToolUseBlock {
tool_use_id: id,
name,
input,
},
});
}
crate::apis::anthropic::MessagesContentBlock::ToolResult {
tool_use_id,
is_error,
content,
..
} => {
// Convert Anthropic ToolResultContent to Bedrock ToolResultContentBlock
let tool_result_content = match content {
ToolResultContent::Text(text) => {
vec![ToolResultContentBlock::Text { text }]
}
ToolResultContent::Blocks(blocks) => {
let mut result_blocks = Vec::new();
for result_block in blocks {
match result_block {
crate::apis::anthropic::MessagesContentBlock::Text { text, .. } => {
result_blocks.push(ToolResultContentBlock::Text { text });
}
// For now, skip other content types in tool results
_ => {}
}
}
result_blocks
}
};
// Ensure we have at least one content block
let final_content = if tool_result_content.is_empty() {
vec![ToolResultContentBlock::Text {
text: " ".to_string(),
}]
} else {
tool_result_content
};
let status = if is_error.unwrap_or(false) {
Some(ToolResultStatus::Error)
} else {
Some(ToolResultStatus::Success)
};
content_blocks.push(ContentBlock::ToolResult {
tool_result: ToolResultBlock {
tool_use_id,
content: final_content,
status,
},
});
}
crate::apis::anthropic::MessagesContentBlock::Image { source } => {
// Convert Anthropic image to Bedrock image format
match source {
crate::apis::anthropic::MessagesImageSource::Base64 {
media_type,
data,
} => {
content_blocks.push(ContentBlock::Image {
image: ImageBlock {
source: ImageSource::Base64 { media_type, data },
},
});
}
crate::apis::anthropic::MessagesImageSource::Url { .. } => {
// Bedrock doesn't support URL-based images, skip for now
// Could potentially download and convert to base64, but not implemented
}
}
}
// Skip other content types for now (Thinking, Document, etc.)
_ => {}
}
}
}
}
// Ensure we have at least one content block
if content_blocks.is_empty() {
content_blocks.push(ContentBlock::Text {
text: " ".to_string(),
});
}
Ok(BedrockMessage {
role,
content: content_blocks,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::apis::amazon_bedrock::{
ContentBlock, ConversationRole, ConverseRequest, SystemContentBlock,
ToolChoice as BedrockToolChoice,
};
use crate::apis::anthropic::{
MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType,
};
use serde_json::json;
#[test]
fn test_anthropic_to_bedrock_basic_request() {
let anthropic_request = MessagesRequest {
model: "claude-3-5-sonnet-20241022".to_string(),
messages: vec![MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Single("Hello, how are you?".to_string()),
}],
max_tokens: 1000,
container: None,
mcp_servers: None,
system: Some(MessagesSystemPrompt::Single(
"You are a helpful assistant.".to_string(),
)),
metadata: None,
service_tier: None,
thinking: None,
temperature: Some(0.7),
top_p: Some(0.9),
top_k: None,
stream: Some(false),
stop_sequences: Some(vec!["STOP".to_string()]),
tools: None,
tool_choice: None,
};
let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap();
assert_eq!(bedrock_request.model_id, "claude-3-5-sonnet-20241022");
assert!(bedrock_request.system.is_some());
assert_eq!(bedrock_request.system.as_ref().unwrap().len(), 1);
assert!(bedrock_request.messages.is_some());
let messages = bedrock_request.messages.as_ref().unwrap();
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].role, ConversationRole::User);
if let ContentBlock::Text { text } = &messages[0].content[0] {
assert_eq!(text, "Hello, how are you?");
} else {
panic!("Expected text content block");
}
let inference_config = bedrock_request.inference_config.as_ref().unwrap();
assert_eq!(inference_config.temperature, Some(0.7));
assert_eq!(inference_config.top_p, Some(0.9));
assert_eq!(inference_config.max_tokens, Some(1000));
assert_eq!(
inference_config.stop_sequences,
Some(vec!["STOP".to_string()])
);
}
#[test]
fn test_anthropic_to_bedrock_with_tools() {
let anthropic_request = MessagesRequest {
model: "claude-3-5-sonnet-20241022".to_string(),
messages: vec![MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Single("What's the weather like?".to_string()),
}],
max_tokens: 1000,
container: None,
mcp_servers: None,
system: None,
metadata: None,
service_tier: None,
thinking: None,
temperature: None,
top_p: None,
top_k: None,
stream: None,
stop_sequences: None,
tools: Some(vec![MessagesTool {
name: "get_weather".to_string(),
description: Some("Get current weather information".to_string()),
input_schema: json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name"
}
},
"required": ["location"]
}),
}]),
tool_choice: Some(MessagesToolChoice {
kind: MessagesToolChoiceType::Tool,
name: Some("get_weather".to_string()),
disable_parallel_tool_use: None,
}),
};
let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap();
assert_eq!(bedrock_request.model_id, "claude-3-5-sonnet-20241022");
assert!(bedrock_request.tool_config.is_some());
let tool_config = bedrock_request.tool_config.as_ref().unwrap();
assert!(tool_config.tools.is_some());
let tools = tool_config.tools.as_ref().unwrap();
assert_eq!(tools.len(), 1);
let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0];
assert_eq!(tool_spec.name, "get_weather");
assert_eq!(
tool_spec.description,
Some("Get current weather information".to_string())
);
if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice {
assert_eq!(tool.name, "get_weather");
} else {
panic!("Expected specific tool choice");
}
}
#[test]
fn test_anthropic_to_bedrock_auto_tool_choice() {
let anthropic_request = MessagesRequest {
model: "claude-3-5-sonnet-20241022".to_string(),
messages: vec![MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Single("Help me with something".to_string()),
}],
max_tokens: 500,
container: None,
mcp_servers: None,
system: None,
metadata: None,
service_tier: None,
thinking: None,
temperature: None,
top_p: None,
top_k: None,
stream: None,
stop_sequences: None,
tools: Some(vec![MessagesTool {
name: "help_tool".to_string(),
description: Some("A helpful tool".to_string()),
input_schema: json!({
"type": "object",
"properties": {}
}),
}]),
tool_choice: Some(MessagesToolChoice {
kind: MessagesToolChoiceType::Auto,
name: None,
disable_parallel_tool_use: None,
}),
};
let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap();
assert!(bedrock_request.tool_config.is_some());
let tool_config = bedrock_request.tool_config.as_ref().unwrap();
assert!(matches!(
tool_config.tool_choice,
Some(BedrockToolChoice::Auto { .. })
));
}
#[test]
fn test_anthropic_to_bedrock_multi_message_conversation() {
let anthropic_request = MessagesRequest {
model: "claude-3-5-sonnet-20241022".to_string(),
messages: vec![
MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Single("Hello".to_string()),
},
MessagesMessage {
role: MessagesRole::Assistant,
content: MessagesMessageContent::Single(
"Hi there! How can I help you?".to_string(),
),
},
MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Single("What's 2+2?".to_string()),
},
],
max_tokens: 100,
container: None,
mcp_servers: None,
system: Some(MessagesSystemPrompt::Single("Be concise".to_string())),
metadata: None,
service_tier: None,
thinking: None,
temperature: Some(0.5),
top_p: None,
top_k: None,
stream: None,
stop_sequences: None,
tools: None,
tool_choice: None,
};
let bedrock_request: ConverseRequest = anthropic_request.try_into().unwrap();
assert!(bedrock_request.messages.is_some());
let messages = bedrock_request.messages.as_ref().unwrap();
assert_eq!(messages.len(), 3);
assert_eq!(messages[0].role, ConversationRole::User);
assert_eq!(messages[1].role, ConversationRole::Assistant);
assert_eq!(messages[2].role, ConversationRole::User);
// Check system prompt
assert!(bedrock_request.system.is_some());
if let SystemContentBlock::Text { text } = &bedrock_request.system.as_ref().unwrap()[0] {
assert_eq!(text, "Be concise");
} else {
panic!("Expected system text block");
}
}
#[test]
fn test_anthropic_message_to_bedrock_conversion() {
let anthropic_message = MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Single("Test message".to_string()),
};
let bedrock_message: BedrockMessage = anthropic_message.try_into().unwrap();
assert_eq!(bedrock_message.role, ConversationRole::User);
assert_eq!(bedrock_message.content.len(), 1);
if let ContentBlock::Text { text } = &bedrock_message.content[0] {
assert_eq!(text, "Test message");
} else {
panic!("Expected text content block");
}
}
}

View file

@ -0,0 +1,782 @@
use crate::apis::amazon_bedrock::{
AnyChoice, AutoChoice, ContentBlock, ConversationRole, ConverseRequest, InferenceConfiguration,
Message as BedrockMessage, SystemContentBlock, Tool as BedrockTool,
ToolChoice as BedrockToolChoice, ToolChoiceSpec, ToolConfiguration, ToolInputSchema,
ToolSpecDefinition,
};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType,
ToolResultContent,
};
use crate::apis::openai::{
ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType,
};
use crate::clients::TransformError;
use crate::transforms::lib::ExtractText;
use crate::transforms::lib::*;
use crate::transforms::*;
type AnthropicMessagesRequest = MessagesRequest;
// ============================================================================
// MAIN REQUEST TRANSFORMATIONS
// ============================================================================
impl Into<MessagesSystemPrompt> for Message {
fn into(self) -> MessagesSystemPrompt {
let system_text = match self.content {
MessageContent::Text(text) => text,
MessageContent::Parts(parts) => parts.extract_text(),
};
MessagesSystemPrompt::Single(system_text)
}
}
impl TryFrom<Message> for MessagesMessage {
type Error = TransformError;
fn try_from(message: Message) -> Result<Self, Self::Error> {
let role = match message.role {
Role::User => MessagesRole::User,
Role::Assistant => MessagesRole::Assistant,
Role::Tool => {
// Tool messages become user messages with tool results
let tool_call_id = message.tool_call_id.ok_or_else(|| {
TransformError::MissingField(
"tool_call_id required for Tool messages".to_string(),
)
})?;
return Ok(MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Blocks(vec![
MessagesContentBlock::ToolResult {
tool_use_id: tool_call_id,
is_error: None,
content: ToolResultContent::Blocks(vec![MessagesContentBlock::Text {
text: message.content.extract_text(),
cache_control: None,
}]),
cache_control: None,
},
]),
});
}
Role::System => {
return Err(TransformError::UnsupportedConversion(
"System messages should be handled separately".to_string(),
));
}
};
let content_blocks = convert_openai_message_to_anthropic_content(&message)?;
let content = build_anthropic_content(content_blocks);
Ok(MessagesMessage { role, content })
}
}
impl TryFrom<Message> for BedrockMessage {
type Error = TransformError;
fn try_from(message: Message) -> Result<Self, Self::Error> {
let role = match message.role {
Role::User => ConversationRole::User,
Role::Assistant => ConversationRole::Assistant,
Role::Tool => ConversationRole::User, // Tool results become user messages in Bedrock
Role::System => {
return Err(TransformError::UnsupportedConversion(
"System messages should be handled separately in Bedrock".to_string(),
));
}
};
let mut content_blocks = Vec::new();
// Handle different message types
match message.role {
Role::User => {
// Convert user message content to content blocks
match message.content {
MessageContent::Text(text) => {
if !text.is_empty() {
content_blocks.push(ContentBlock::Text { text });
}
}
MessageContent::Parts(parts) => {
// Convert OpenAI content parts to Bedrock ContentBlocks
for part in parts {
match part {
crate::apis::openai::ContentPart::Text { text } => {
if !text.is_empty() {
content_blocks.push(ContentBlock::Text { text });
}
}
crate::apis::openai::ContentPart::ImageUrl { image_url } => {
// Convert image URL to Bedrock image format
if image_url.url.starts_with("data:") {
if let Some((media_type, data)) =
parse_data_url(&image_url.url)
{
content_blocks.push(ContentBlock::Image {
image: crate::apis::amazon_bedrock::ImageBlock {
source: crate::apis::amazon_bedrock::ImageSource::Base64 {
media_type,
data,
},
},
});
} else {
return Err(TransformError::UnsupportedConversion(
format!(
"Invalid data URL format: {}",
image_url.url
),
));
}
} else {
return Err(TransformError::UnsupportedConversion(
"Only base64 data URLs are supported for images in Bedrock".to_string()
));
}
}
}
}
}
}
// Ensure we have at least one content block
if content_blocks.is_empty() {
content_blocks.push(ContentBlock::Text {
text: " ".to_string(),
});
}
}
Role::Assistant => {
// Handle text content - but only add if non-empty OR if we don't have tool calls
let text_content = message.content.extract_text();
let has_tool_calls = message
.tool_calls
.as_ref()
.map_or(false, |calls| !calls.is_empty());
// Add text content if it's non-empty, or if we have no tool calls (to avoid empty content)
if !text_content.is_empty() {
content_blocks.push(ContentBlock::Text { text: text_content });
} else if !has_tool_calls {
// If we have empty content and no tool calls, add a minimal placeholder
// This prevents the "blank text field" error
content_blocks.push(ContentBlock::Text {
text: " ".to_string(),
});
}
// Convert tool calls to ToolUse content blocks
if let Some(tool_calls) = message.tool_calls {
for tool_call in tool_calls {
// Parse the arguments string as JSON
let input: serde_json::Value =
serde_json::from_str(&tool_call.function.arguments).map_err(|e| {
TransformError::UnsupportedConversion(format!(
"Failed to parse tool arguments as JSON: {}. Arguments: {}",
e, tool_call.function.arguments
))
})?;
content_blocks.push(ContentBlock::ToolUse {
tool_use: crate::apis::amazon_bedrock::ToolUseBlock {
tool_use_id: tool_call.id,
name: tool_call.function.name,
input,
},
});
}
}
// Bedrock requires at least one content block
if content_blocks.is_empty() {
content_blocks.push(ContentBlock::Text {
text: " ".to_string(),
});
}
}
Role::Tool => {
// Tool messages become user messages with ToolResult content blocks
let tool_call_id = message.tool_call_id.ok_or_else(|| {
TransformError::MissingField(
"tool_call_id required for Tool messages".to_string(),
)
})?;
let tool_content = message.content.extract_text();
// Create ToolResult content block
let tool_result_content = if tool_content.is_empty() {
// Even for tool results, we need non-empty content
vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text {
text: " ".to_string(),
}]
} else {
vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text {
text: tool_content,
}]
};
content_blocks.push(ContentBlock::ToolResult {
tool_result: crate::apis::amazon_bedrock::ToolResultBlock {
tool_use_id: tool_call_id,
content: tool_result_content,
status: Some(crate::apis::amazon_bedrock::ToolResultStatus::Success), // Default to success
},
});
}
Role::System => {
// Already handled above with early return
unreachable!()
}
}
Ok(BedrockMessage {
role,
content: content_blocks,
})
}
}
impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
type Error = TransformError;
fn try_from(req: ChatCompletionsRequest) -> Result<Self, Self::Error> {
let mut system_prompt = None;
let mut messages = Vec::new();
for message in req.messages {
match message.role {
Role::System => {
system_prompt = Some(message.into());
}
_ => {
let anthropic_message: MessagesMessage = message.try_into()?;
messages.push(anthropic_message);
}
}
}
// Convert tools and tool choice
let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools));
let anthropic_tool_choice =
convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
Ok(AnthropicMessagesRequest {
model: req.model,
system: system_prompt,
messages,
max_tokens: req
.max_completion_tokens
.or(req.max_tokens)
.unwrap_or(DEFAULT_MAX_TOKENS),
container: None,
mcp_servers: None,
service_tier: None,
thinking: None,
temperature: req.temperature,
top_p: req.top_p,
top_k: None, // OpenAI doesn't have top_k
stream: req.stream,
stop_sequences: req.stop,
tools: anthropic_tools,
tool_choice: anthropic_tool_choice,
metadata: None,
})
}
}
impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
type Error = TransformError;
fn try_from(req: ChatCompletionsRequest) -> Result<Self, Self::Error> {
// Separate system messages from user/assistant messages
let mut system_messages = Vec::new();
let mut conversation_messages = Vec::new();
for message in req.messages {
match message.role {
Role::System => {
let system_text = match message.content {
MessageContent::Text(text) => text,
MessageContent::Parts(parts) => parts.extract_text(),
};
system_messages.push(SystemContentBlock::Text { text: system_text });
}
_ => {
let bedrock_message: BedrockMessage = message.try_into()?;
conversation_messages.push(bedrock_message);
}
}
}
// Convert system messages
let system = if system_messages.is_empty() {
None
} else {
Some(system_messages)
};
// Convert conversation messages
let messages = if conversation_messages.is_empty() {
None
} else {
Some(conversation_messages)
};
// Build inference configuration
let max_tokens = req.max_completion_tokens.or(req.max_tokens);
let inference_config = if max_tokens.is_some()
|| req.temperature.is_some()
|| req.top_p.is_some()
|| req.stop.is_some()
{
Some(InferenceConfiguration {
max_tokens,
temperature: req.temperature,
top_p: req.top_p,
stop_sequences: req.stop,
})
} else {
None
};
// Convert tools and tool choice to ToolConfiguration
let tool_config = if req.tools.is_some() || req.tool_choice.is_some() {
let tools = req.tools.map(|openai_tools| {
openai_tools
.into_iter()
.map(|tool| BedrockTool::ToolSpec {
tool_spec: ToolSpecDefinition {
name: tool.function.name,
description: tool.function.description,
input_schema: ToolInputSchema {
json: tool.function.parameters,
},
},
})
.collect()
});
let tool_choice = req
.tool_choice
.map(|choice| {
match choice {
ToolChoice::Type(tool_type) => match tool_type {
ToolChoiceType::Auto => BedrockToolChoice::Auto {
auto: AutoChoice {},
},
ToolChoiceType::Required => {
BedrockToolChoice::Any { any: AnyChoice {} }
}
ToolChoiceType::None => BedrockToolChoice::Auto {
auto: AutoChoice {},
}, // Bedrock doesn't have explicit "none"
},
ToolChoice::Function { function, .. } => BedrockToolChoice::Tool {
tool: ToolChoiceSpec {
name: function.name,
},
},
}
})
.or_else(|| {
// If tools are present but no tool_choice specified, default to "auto"
if tools.is_some() {
Some(BedrockToolChoice::Auto {
auto: AutoChoice {},
})
} else {
None
}
});
Some(ToolConfiguration { tools, tool_choice })
} else {
None
};
Ok(ConverseRequest {
model_id: req.model,
messages,
system,
inference_config,
tool_config,
stream: req.stream.unwrap_or(false),
guardrail_config: None,
additional_model_request_fields: None,
additional_model_response_field_paths: None,
performance_config: None,
prompt_variables: None,
request_metadata: None,
metadata: None,
})
}
}
/// Convert OpenAI tools to Anthropic format
fn convert_openai_tools(tools: Vec<Tool>) -> Vec<MessagesTool> {
tools
.into_iter()
.map(|tool| MessagesTool {
name: tool.function.name,
description: tool.function.description,
input_schema: tool.function.parameters,
})
.collect()
}
/// Convert OpenAI tool choice to Anthropic format
fn convert_openai_tool_choice(
tool_choice: Option<ToolChoice>,
parallel_tool_calls: Option<bool>,
) -> Option<MessagesToolChoice> {
tool_choice.map(|choice| match choice {
ToolChoice::Type(tool_type) => match tool_type {
ToolChoiceType::Auto => MessagesToolChoice {
kind: MessagesToolChoiceType::Auto,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
ToolChoiceType::Required => MessagesToolChoice {
kind: MessagesToolChoiceType::Any,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
ToolChoiceType::None => MessagesToolChoice {
kind: MessagesToolChoiceType::None,
name: None,
disable_parallel_tool_use: None,
},
},
ToolChoice::Function { function, .. } => MessagesToolChoice {
kind: MessagesToolChoiceType::Tool,
name: Some(function.name),
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
})
}
/// Build Anthropic message content from content blocks
fn build_anthropic_content(content_blocks: Vec<MessagesContentBlock>) -> MessagesMessageContent {
if content_blocks.len() == 1 {
match &content_blocks[0] {
MessagesContentBlock::Text { text, .. } => MessagesMessageContent::Single(text.clone()),
_ => MessagesMessageContent::Blocks(content_blocks),
}
} else if content_blocks.is_empty() {
MessagesMessageContent::Single("".to_string())
} else {
MessagesMessageContent::Blocks(content_blocks)
}
}
/// Parse a data URL into media type and base64 data
/// Supports format: data:image/jpeg;base64,<data>
fn parse_data_url(url: &str) -> Option<(String, String)> {
if !url.starts_with("data:") {
return None;
}
let without_prefix = &url[5..]; // Remove "data:" prefix
let parts: Vec<&str> = without_prefix.splitn(2, ',').collect();
if parts.len() != 2 {
return None;
}
let header = parts[0];
let data = parts[1];
// Parse header: "image/jpeg;base64" or just "image/jpeg"
let header_parts: Vec<&str> = header.split(';').collect();
if header_parts.is_empty() {
return None;
}
let media_type = header_parts[0].to_string();
// Check if it's base64 encoded
if header_parts.len() > 1 && header_parts[1] == "base64" {
Some((media_type, data.to_string()))
} else {
// For now, only support base64 encoding
None
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::apis::amazon_bedrock::{
ContentBlock, ConversationRole, ConverseRequest, SystemContentBlock,
ToolChoice as BedrockToolChoice,
};
use crate::apis::openai::{
ChatCompletionsRequest, Function, FunctionChoice, Message, MessageContent, Role, Tool,
ToolChoice, ToolChoiceType,
};
use serde_json::json;
#[test]
fn test_openai_to_bedrock_basic_request() {
let openai_request = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![
Message {
role: Role::System,
content: MessageContent::Text("You are a helpful assistant.".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
},
Message {
role: Role::User,
content: MessageContent::Text("Hello, how are you?".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
},
],
temperature: Some(0.7),
top_p: Some(0.9),
max_completion_tokens: Some(1000),
stop: Some(vec!["STOP".to_string()]),
stream: Some(false),
tools: None,
tool_choice: None,
..Default::default()
};
let bedrock_request: ConverseRequest = openai_request.try_into().unwrap();
assert_eq!(bedrock_request.model_id, "gpt-4");
assert!(bedrock_request.system.is_some());
assert_eq!(bedrock_request.system.as_ref().unwrap().len(), 1);
if let SystemContentBlock::Text { text } = &bedrock_request.system.as_ref().unwrap()[0] {
assert_eq!(text, "You are a helpful assistant.");
} else {
panic!("Expected system text block");
}
assert!(bedrock_request.messages.is_some());
let messages = bedrock_request.messages.as_ref().unwrap();
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].role, ConversationRole::User);
if let ContentBlock::Text { text } = &messages[0].content[0] {
assert_eq!(text, "Hello, how are you?");
} else {
panic!("Expected text content block");
}
let inference_config = bedrock_request.inference_config.as_ref().unwrap();
assert_eq!(inference_config.temperature, Some(0.7));
assert_eq!(inference_config.top_p, Some(0.9));
assert_eq!(inference_config.max_tokens, Some(1000));
assert_eq!(
inference_config.stop_sequences,
Some(vec!["STOP".to_string()])
);
}
#[test]
fn test_openai_to_bedrock_with_tools() {
let openai_request = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![Message {
role: Role::User,
content: MessageContent::Text("What's the weather like?".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
}],
temperature: None,
top_p: None,
max_completion_tokens: Some(1000),
stop: None,
stream: None,
tools: Some(vec![Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_weather".to_string(),
description: Some("Get current weather information".to_string()),
parameters: json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name"
}
},
"required": ["location"]
}),
strict: None,
},
}]),
tool_choice: Some(ToolChoice::Function {
choice_type: "function".to_string(),
function: FunctionChoice {
name: "get_weather".to_string(),
},
}),
..Default::default()
};
let bedrock_request: ConverseRequest = openai_request.try_into().unwrap();
assert_eq!(bedrock_request.model_id, "gpt-4");
assert!(bedrock_request.tool_config.is_some());
let tool_config = bedrock_request.tool_config.as_ref().unwrap();
assert!(tool_config.tools.is_some());
let tools = tool_config.tools.as_ref().unwrap();
assert_eq!(tools.len(), 1);
let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0];
assert_eq!(tool_spec.name, "get_weather");
assert_eq!(
tool_spec.description,
Some("Get current weather information".to_string())
);
if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice {
assert_eq!(tool.name, "get_weather");
} else {
panic!("Expected specific tool choice");
}
}
#[test]
fn test_openai_to_bedrock_auto_tool_choice() {
let openai_request = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![Message {
role: Role::User,
content: MessageContent::Text("Help me with something".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
}],
temperature: None,
top_p: None,
max_completion_tokens: Some(500),
stop: None,
stream: None,
tools: Some(vec![Tool {
tool_type: "function".to_string(),
function: Function {
name: "help_tool".to_string(),
description: Some("A helpful tool".to_string()),
parameters: json!({
"type": "object",
"properties": {}
}),
strict: None,
},
}]),
tool_choice: Some(ToolChoice::Type(ToolChoiceType::Auto)),
..Default::default()
};
let bedrock_request: ConverseRequest = openai_request.try_into().unwrap();
assert!(bedrock_request.tool_config.is_some());
let tool_config = bedrock_request.tool_config.as_ref().unwrap();
assert!(matches!(
tool_config.tool_choice,
Some(BedrockToolChoice::Auto { .. })
));
}
#[test]
fn test_openai_to_bedrock_multi_message_conversation() {
let openai_request = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![
Message {
role: Role::System,
content: MessageContent::Text("Be concise".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
},
Message {
role: Role::User,
content: MessageContent::Text("Hello".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
},
Message {
role: Role::Assistant,
content: MessageContent::Text("Hi there! How can I help you?".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
},
Message {
role: Role::User,
content: MessageContent::Text("What's 2+2?".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
},
],
temperature: Some(0.5),
top_p: None,
max_completion_tokens: Some(100),
stop: None,
stream: None,
tools: None,
tool_choice: None,
..Default::default()
};
let bedrock_request: ConverseRequest = openai_request.try_into().unwrap();
assert!(bedrock_request.messages.is_some());
let messages = bedrock_request.messages.as_ref().unwrap();
assert_eq!(messages.len(), 3); // System message is separate
assert_eq!(messages[0].role, ConversationRole::User);
assert_eq!(messages[1].role, ConversationRole::Assistant);
assert_eq!(messages[2].role, ConversationRole::User);
// Check system prompt
assert!(bedrock_request.system.is_some());
if let SystemContentBlock::Text { text } = &bedrock_request.system.as_ref().unwrap()[0] {
assert_eq!(text, "Be concise");
} else {
panic!("Expected system text block");
}
}
#[test]
fn test_openai_message_to_bedrock_conversion() {
let openai_message = Message {
role: Role::User,
content: MessageContent::Text("Test message".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
};
let bedrock_message: BedrockMessage = openai_message.try_into().unwrap();
assert_eq!(bedrock_message.role, ConversationRole::User);
assert_eq!(bedrock_message.content.len(), 1);
if let ContentBlock::Text { text } = &bedrock_message.content[0] {
assert_eq!(text, "Test message");
} else {
panic!("Expected text content block");
}
}
}

View file

@ -0,0 +1,4 @@
//! Request transformation modules
pub mod from_anthropic;
pub mod from_openai;

View file

@ -0,0 +1,3 @@
//! Response transformation modules
pub mod to_anthropic;
pub mod to_openai;

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff