cargo fmt

This commit is contained in:
Adil Hafeez 2025-10-21 16:31:29 -07:00
parent aec052a843
commit d35d068d0d
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
25 changed files with 1978 additions and 1258 deletions

View file

@ -1,17 +1,17 @@
use std::sync::Arc;
use std::collections::HashMap;
use bytes::Bytes; use bytes::Bytes;
use common::configuration::{ModelAlias, ModelUsagePreference}; use common::configuration::{ModelAlias, ModelUsagePreference};
use common::consts::{ARCH_PROVIDER_HINT_HEADER, ARCH_IS_STREAMING_HEADER}; use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
use hermesllm::apis::openai::ChatCompletionsRequest; use hermesllm::apis::openai::ChatCompletionsRequest;
use hermesllm::clients::SupportedAPIs;
use hermesllm::clients::endpoints::SupportedUpstreamAPIs; use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use hermesllm::clients::SupportedAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType}; use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody; use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody}; use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame; use hyper::body::Frame;
use hyper::header::{self}; use hyper::header::{self};
use hyper::{Request, Response, StatusCode}; use hyper::{Request, Response, StatusCode};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
@ -31,14 +31,19 @@ pub async fn chat(
full_qualified_llm_provider_url: String, full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>, model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> { ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string(); let request_path = request.uri().path().to_string();
let mut request_headers = request.headers().clone(); let mut request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes(); let chat_request_bytes = request.collect().await?.to_bytes();
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes)); debug!(
"Received request body (raw utf8): {}",
String::from_utf8_lossy(&chat_request_bytes)
);
let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) { let mut client_request = match ProviderRequestType::try_from((
&chat_request_bytes[..],
&SupportedAPIs::from_endpoint(request_path.as_str()).unwrap(),
)) {
Ok(request) => request, Ok(request) => request,
Err(err) => { Err(err) => {
warn!("Failed to parse request as ProviderRequestType: {}", err); warn!("Failed to parse request as ProviderRequestType: {}", err);
@ -79,18 +84,30 @@ pub async fn chat(
// Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original) // Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original)
let chat_completions_request_for_arch_router: ChatCompletionsRequest = let chat_completions_request_for_arch_router: ChatCompletionsRequest =
match ProviderRequestType::try_from((client_request, &SupportedUpstreamAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) { match ProviderRequestType::try_from((
client_request,
&SupportedUpstreamAPIs::OpenAIChatCompletions(
hermesllm::apis::OpenAIApi::ChatCompletions,
),
)) {
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req, Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
Ok(ProviderRequestType::MessagesRequest(_) | ProviderRequestType::BedrockConverse(_) | ProviderRequestType::BedrockConverseStream(_)) => { Ok(
ProviderRequestType::MessagesRequest(_)
| ProviderRequestType::BedrockConverse(_)
| ProviderRequestType::BedrockConverseStream(_),
) => {
// This should not happen after conversion to OpenAI format // This should not happen after conversion to OpenAI format
warn!("Unexpected: got MessagesRequest after converting to OpenAI format"); warn!("Unexpected: got MessagesRequest after converting to OpenAI format");
let err_msg = "Request conversion failed".to_string(); let err_msg = "Request conversion failed".to_string();
let mut bad_request = Response::new(full(err_msg)); let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST; *bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request); return Ok(bad_request);
}, }
Err(err) => { Err(err) => {
warn!("Failed to convert request to ChatCompletionsRequest: {}", err); warn!(
"Failed to convert request to ChatCompletionsRequest: {}",
err
);
let err_msg = format!("Failed to convert request: {}", err); let err_msg = format!("Failed to convert request: {}", err);
let mut bad_request = Response::new(full(err_msg)); let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST; *bad_request.status_mut() = StatusCode::BAD_REQUEST;
@ -108,28 +125,29 @@ pub async fn chat(
.find(|(ty, _)| ty.as_str() == "traceparent") .find(|(ty, _)| ty.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default().to_string()); .map(|(_, value)| value.to_str().unwrap_or_default().to_string());
let usage_preferences_str: Option<String> = let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
routing_metadata.as_ref().and_then(|metadata| { metadata
metadata .get("archgw_preference_config")
.get("archgw_preference_config") .map(|value| value.to_string())
.map(|value| value.to_string()) });
});
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
.as_ref() .as_ref()
.and_then(|s| serde_yaml::from_str(s).ok()); .and_then(|s| serde_yaml::from_str(s).ok());
let latest_message_for_log = let latest_message_for_log = chat_completions_request_for_arch_router
chat_completions_request_for_arch_router .messages
.messages .last()
.last() .map_or("None".to_string(), |msg| {
.map_or("None".to_string(), |msg| { msg.content.to_string().replace('\n', "\\n")
msg.content.to_string().replace('\n', "\\n") });
});
const MAX_MESSAGE_LENGTH: usize = 50; const MAX_MESSAGE_LENGTH: usize = 50;
let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH { let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
let truncated: String = latest_message_for_log.chars().take(MAX_MESSAGE_LENGTH).collect(); let truncated: String = latest_message_for_log
.chars()
.take(MAX_MESSAGE_LENGTH)
.collect();
format!("{}...", truncated) format!("{}...", truncated)
} else { } else {
latest_message_for_log latest_message_for_log
@ -155,12 +173,11 @@ pub async fn chat(
Ok(route) => match route { Ok(route) => match route {
Some((_, model_name)) => model_name, Some((_, model_name)) => model_name,
None => { None => {
info!( info!(
"No route determined, using default model from request: {}", "No route determined, using default model from request: {}",
chat_completions_request_for_arch_router.model chat_completions_request_for_arch_router.model
); );
chat_completions_request_for_arch_router.model.clone() chat_completions_request_for_arch_router.model.clone()
} }
}, },
Err(err) => { Err(err) => {

View file

@ -68,8 +68,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
&serde_json::to_string(arch_config.as_ref()).unwrap() &serde_json::to_string(arch_config.as_ref()).unwrap()
); );
let llm_provider_url = env::var("LLM_PROVIDER_ENDPOINT") let llm_provider_url =
.unwrap_or_else(|_| "http://localhost:12001".to_string()); env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
info!("llm provider url: {}", llm_provider_url); info!("llm provider url: {}", llm_provider_url);
info!("listening on http://{}", bind_address); info!("listening on http://{}", bind_address);
@ -96,7 +96,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let model_aliases = Arc::new(arch_config.model_aliases.clone()); let model_aliases = Arc::new(arch_config.model_aliases.clone());
loop { loop {
let (stream, _) = listener.accept().await?; let (stream, _) = listener.accept().await?;
let peer_addr = stream.peer_addr()?; let peer_addr = stream.peer_addr()?;
@ -108,7 +107,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone(); let llm_providers = llm_providers.clone();
let service = service_fn(move |req| { let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service); let router_service = Arc::clone(&router_service);
let parent_cx = extract_context_from_request(&req); let parent_cx = extract_context_from_request(&req);
let llm_provider_url = llm_provider_url.clone(); let llm_provider_url = llm_provider_url.clone();
@ -118,7 +116,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
async move { async move {
match (req.method(), req.uri().path()) { match (req.method(), req.uri().path()) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => { (&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path()); let fully_qualified_url =
format!("{}{}", llm_provider_url, req.uri().path());
chat(req, router_service, fully_qualified_url, model_aliases) chat(req, router_service, fully_qualified_url, model_aliases)
.with_context(parent_cx) .with_context(parent_cx)
.await .await

View file

@ -1,9 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use common::{ use common::configuration::{ModelUsagePreference, RoutingPreference};
configuration::{ModelUsagePreference, RoutingPreference}, use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
};
use hermesllm::apis::openai::{ChatCompletionsRequest, MessageContent, Message, Role};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::{debug, warn}; use tracing::{debug, warn};

View file

@ -1,20 +1,27 @@
use std::sync::OnceLock;
use std::fmt; use std::fmt;
use std::sync::OnceLock;
use opentelemetry::global; use opentelemetry::global;
use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider}; use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider};
use opentelemetry_stdout::SpanExporter; use opentelemetry_stdout::SpanExporter;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::fmt::{format, time::FormatTime, FmtContext, FormatEvent, FormatFields};
use tracing::{Event, Subscriber};
use time::macros::format_description; use time::macros::format_description;
use tracing::{Event, Subscriber};
use tracing_subscriber::fmt::{format, time::FormatTime, FmtContext, FormatEvent, FormatFields};
use tracing_subscriber::EnvFilter;
struct BracketedTime; struct BracketedTime;
impl FormatTime for BracketedTime { impl FormatTime for BracketedTime {
fn format_time(&self, w: &mut format::Writer<'_>) -> fmt::Result { fn format_time(&self, w: &mut format::Writer<'_>) -> fmt::Result {
let now = time::OffsetDateTime::now_utc(); let now = time::OffsetDateTime::now_utc();
write!(w, "[{}]", now.format(&format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]")).unwrap()) write!(
w,
"[{}]",
now.format(&format_description!(
"[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]"
))
.unwrap()
)
} }
} }
@ -34,7 +41,11 @@ where
let timer = BracketedTime; let timer = BracketedTime;
timer.format_time(&mut writer)?; timer.format_time(&mut writer)?;
write!(writer, "[{}] ", event.metadata().level().to_string().to_lowercase())?; write!(
writer,
"[{}] ",
event.metadata().level().to_string().to_lowercase()
)?;
ctx.field_format().format_fields(writer.by_ref(), event)?; ctx.field_format().format_fields(writer.by_ref(), event)?;

View file

@ -33,7 +33,6 @@ pub fn get_llm_provider(
return provider; return provider;
} }
if llm_providers.default().is_some() { if llm_providers.default().is_some() {
return llm_providers.default().unwrap(); return llm_providers.default().unwrap();
} }

View file

@ -2,8 +2,8 @@ use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use serde_with::skip_serializing_none; use serde_with::skip_serializing_none;
use thiserror::Error;
use std::collections::HashMap; use std::collections::HashMap;
use thiserror::Error;
use super::ApiDefinition; use super::ApiDefinition;
use crate::providers::request::{ProviderRequest, ProviderRequestError}; use crate::providers::request::{ProviderRequest, ProviderRequestError};
@ -56,10 +56,7 @@ impl ApiDefinition for AmazonBedrockApi {
} }
fn all_variants() -> Vec<Self> { fn all_variants() -> Vec<Self> {
vec![ vec![AmazonBedrockApi::Converse, AmazonBedrockApi::ConverseStream]
AmazonBedrockApi::Converse,
AmazonBedrockApi::ConverseStream,
]
} }
} }
@ -173,7 +170,9 @@ impl ProviderRequest for ConverseRequest {
SystemContentBlock::Text { text } => { SystemContentBlock::Text { text } => {
text_parts.push(text.clone()); text_parts.push(text.clone());
} }
SystemContentBlock::GuardContent { text: Some(guard_text) } => { SystemContentBlock::GuardContent {
text: Some(guard_text),
} => {
text_parts.push(guard_text.text.clone()); text_parts.push(guard_text.text.clone());
} }
SystemContentBlock::GuardContent { text: None } => { SystemContentBlock::GuardContent { text: None } => {
@ -194,11 +193,9 @@ impl ProviderRequest for ConverseRequest {
.find(|msg| msg.role == ConversationRole::User) .find(|msg| msg.role == ConversationRole::User)
.and_then(|msg| { .and_then(|msg| {
// Extract the first text content block from the user message // Extract the first text content block from the user message
msg.content.iter().find_map(|content| { msg.content.iter().find_map(|content| match content {
match content { ContentBlock::Text { text } => Some(text.clone()),
ContentBlock::Text { text } => Some(text.clone()), _ => None,
_ => None,
}
}) })
}) })
} }
@ -294,13 +291,13 @@ pub enum ConversationRole {
#[serde(untagged)] #[serde(untagged)]
pub enum ContentBlock { pub enum ContentBlock {
Text { Text {
text: String text: String,
}, },
Image { Image {
image: ImageBlock image: ImageBlock,
}, },
Document { Document {
document: DocumentBlock document: DocumentBlock,
}, },
ToolUse { ToolUse {
#[serde(rename = "toolUse")] #[serde(rename = "toolUse")]
@ -360,9 +357,7 @@ pub enum SystemContentBlock {
#[serde(rename = "text")] #[serde(rename = "text")]
Text { text: String }, Text { text: String },
#[serde(rename = "guardContent")] #[serde(rename = "guardContent")]
GuardContent { GuardContent { text: Option<GuardContentText> },
text: Option<GuardContentText>,
},
} }
/// Image source for vision capabilities /// Image source for vision capabilities
@ -723,10 +718,12 @@ pub struct ContentBlockDeltaEvent {
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)] #[serde(untagged)]
pub enum ContentBlockDelta { pub enum ContentBlockDelta {
Text { text: String }, Text {
text: String,
},
ToolUse { ToolUse {
#[serde(rename = "toolUse")] #[serde(rename = "toolUse")]
tool_use: ToolUseDelta tool_use: ToolUseDelta,
}, },
} }
@ -1008,7 +1005,9 @@ impl Into<String> for ConverseStreamEvent {
ConverseStreamEvent::Metadata { .. } => "metadata", ConverseStreamEvent::Metadata { .. } => "metadata",
ConverseStreamEvent::InternalServerException { .. } => "internal_server_exception", ConverseStreamEvent::InternalServerException { .. } => "internal_server_exception",
ConverseStreamEvent::ModelStreamErrorException { .. } => "model_stream_error_exception", ConverseStreamEvent::ModelStreamErrorException { .. } => "model_stream_error_exception",
ConverseStreamEvent::ServiceUnavailableException { .. } => "service_unavailable_exception", ConverseStreamEvent::ServiceUnavailableException { .. } => {
"service_unavailable_exception"
}
ConverseStreamEvent::ThrottlingException { .. } => "throttling_exception", ConverseStreamEvent::ThrottlingException { .. } => "throttling_exception",
ConverseStreamEvent::ValidationException { .. } => "validation_exception", ConverseStreamEvent::ValidationException { .. } => "validation_exception",
}; };
@ -1019,17 +1018,14 @@ impl Into<String> for ConverseStreamEvent {
} }
} }
// Implement ProviderStreamResponse for ConverseStreamEvent // Implement ProviderStreamResponse for ConverseStreamEvent
impl ProviderStreamResponse for ConverseStreamEvent { impl ProviderStreamResponse for ConverseStreamEvent {
fn content_delta(&self) -> Option<&str> { fn content_delta(&self) -> Option<&str> {
match self { match self {
ConverseStreamEvent::ContentBlockDelta(event) => { ConverseStreamEvent::ContentBlockDelta(event) => match &event.delta {
match &event.delta { ContentBlockDelta::Text { text } => Some(text),
ContentBlockDelta::Text { text } => Some(text), ContentBlockDelta::ToolUse { .. } => None,
ContentBlockDelta::ToolUse { .. } => None, },
}
}
_ => None, _ => None,
} }
} }
@ -1099,7 +1095,10 @@ mod tests {
}; };
let serialized = serde_json::to_value(&tool).unwrap(); let serialized = serde_json::to_value(&tool).unwrap();
println!("Tool serialization: {}", serde_json::to_string_pretty(&serialized).unwrap()); println!(
"Tool serialization: {}",
serde_json::to_string_pretty(&serialized).unwrap()
);
// Verify the structure matches Bedrock API expectations // Verify the structure matches Bedrock API expectations
assert!(serialized.get("toolSpec").is_some()); assert!(serialized.get("toolSpec").is_some());
@ -1107,16 +1106,24 @@ mod tests {
let tool_spec = serialized.get("toolSpec").unwrap(); let tool_spec = serialized.get("toolSpec").unwrap();
assert_eq!(tool_spec.get("name").unwrap(), "get_weather"); assert_eq!(tool_spec.get("name").unwrap(), "get_weather");
assert_eq!(tool_spec.get("description").unwrap(), "Get the current weather for a specified city"); assert_eq!(
tool_spec.get("description").unwrap(),
"Get the current weather for a specified city"
);
assert!(tool_spec.get("inputSchema").is_some()); assert!(tool_spec.get("inputSchema").is_some());
} }
#[test] #[test]
fn test_tool_choice_serialization_format() { fn test_tool_choice_serialization_format() {
// Test Auto choice // Test Auto choice
let auto_choice = ToolChoice::Auto { auto: AutoChoice {} }; let auto_choice = ToolChoice::Auto {
auto: AutoChoice {},
};
let serialized = serde_json::to_value(&auto_choice).unwrap(); let serialized = serde_json::to_value(&auto_choice).unwrap();
println!("Auto ToolChoice serialization: {}", serde_json::to_string_pretty(&serialized).unwrap()); println!(
"Auto ToolChoice serialization: {}",
serde_json::to_string_pretty(&serialized).unwrap()
);
assert!(serialized.get("auto").is_some()); assert!(serialized.get("auto").is_some());
assert!(serialized.get("type").is_none()); // Should not have a type field assert!(serialized.get("type").is_none()); // Should not have a type field
@ -1124,11 +1131,14 @@ mod tests {
// Test Tool choice // Test Tool choice
let tool_choice = ToolChoice::Tool { let tool_choice = ToolChoice::Tool {
tool: ToolChoiceSpec { tool: ToolChoiceSpec {
name: "get_weather".to_string() name: "get_weather".to_string(),
} },
}; };
let serialized = serde_json::to_value(&tool_choice).unwrap(); let serialized = serde_json::to_value(&tool_choice).unwrap();
println!("Tool ToolChoice serialization: {}", serde_json::to_string_pretty(&serialized).unwrap()); println!(
"Tool ToolChoice serialization: {}",
serde_json::to_string_pretty(&serialized).unwrap()
);
assert!(serialized.get("tool").is_some()); assert!(serialized.get("tool").is_some());
assert!(serialized.get("type").is_none()); // Should not have a type field assert!(serialized.get("type").is_none()); // Should not have a type field

View file

@ -1,7 +1,7 @@
use std::collections::HashSet;
use bytes::Buf;
use aws_smithy_eventstream::frame::DecodedFrame; use aws_smithy_eventstream::frame::DecodedFrame;
use aws_smithy_eventstream::frame::MessageFrameDecoder; use aws_smithy_eventstream::frame::MessageFrameDecoder;
use bytes::Buf;
use std::collections::HashSet;
/// AWS Event Stream frame decoder wrapper /// AWS Event Stream frame decoder wrapper
pub struct BedrockBinaryFrameDecoder<B> pub struct BedrockBinaryFrameDecoder<B>

View file

@ -8,7 +8,7 @@ use super::ApiDefinition;
use crate::providers::request::{ProviderRequest, ProviderRequestError}; use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse}; use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
use crate::transforms::lib::ExtractText; use crate::transforms::lib::ExtractText;
use crate::{MESSAGES_PATH}; use crate::MESSAGES_PATH;
// Enum for all supported Anthropic APIs // Enum for all supported Anthropic APIs
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@ -52,9 +52,7 @@ impl ApiDefinition for AnthropicApi {
} }
fn all_variants() -> Vec<Self> { fn all_variants() -> Vec<Self> {
vec![ vec![AnthropicApi::Messages]
AnthropicApi::Messages,
]
} }
} }
@ -100,7 +98,6 @@ pub struct McpServer {
pub tool_configuration: Option<McpToolConfiguration>, pub tool_configuration: Option<McpToolConfiguration>,
} }
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesRequest { pub struct MessagesRequest {
@ -121,10 +118,8 @@ pub struct MessagesRequest {
pub stop_sequences: Option<Vec<String>>, pub stop_sequences: Option<Vec<String>>,
pub tools: Option<Vec<MessagesTool>>, pub tools: Option<Vec<MessagesTool>>,
pub tool_choice: Option<MessagesToolChoice>, pub tool_choice: Option<MessagesToolChoice>,
} }
// Messages API specific types // Messages API specific types
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
@ -235,34 +230,21 @@ impl ExtractText for Vec<MessagesContentBlock> {
} }
} }
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum MessagesImageSource { pub enum MessagesImageSource {
Base64 { Base64 { media_type: String, data: String },
media_type: String, Url { url: String },
data: String,
},
Url {
url: String,
},
} }
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum MessagesDocumentSource { pub enum MessagesDocumentSource {
Base64 { Base64 { media_type: String, data: String },
media_type: String, Url { url: String },
data: String, File { file_id: String },
},
Url {
url: String,
},
File {
file_id: String,
},
} }
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
@ -276,7 +258,7 @@ impl ExtractText for MessagesMessageContent {
fn extract_text(&self) -> String { fn extract_text(&self) -> String {
match self { match self {
MessagesMessageContent::Single(text) => text.clone(), MessagesMessageContent::Single(text) => text.clone(),
MessagesMessageContent::Blocks(parts) => parts.extract_text() MessagesMessageContent::Blocks(parts) => parts.extract_text(),
} }
} }
} }
@ -320,7 +302,6 @@ pub struct MessagesToolChoice {
pub disable_parallel_tool_use: Option<bool>, pub disable_parallel_tool_use: Option<bool>,
} }
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum MessagesStopReason { pub enum MessagesStopReason {
@ -457,7 +438,11 @@ impl ProviderResponse for MessagesResponse {
Some(self) Some(self)
} }
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
Some((self.usage.input_tokens as usize, self.usage.output_tokens as usize, (self.usage.input_tokens + self.usage.output_tokens) as usize)) Some((
self.usage.input_tokens as usize,
self.usage.output_tokens as usize,
(self.usage.input_tokens + self.usage.output_tokens) as usize,
))
} }
} }
@ -535,7 +520,7 @@ impl ProviderRequest for MessagesRequest {
} }
fn metadata(&self) -> &Option<HashMap<String, Value>> { fn metadata(&self) -> &Option<HashMap<String, Value>> {
return &self.metadata; return &self.metadata;
} }
fn remove_metadata_key(&mut self, key: &str) -> bool { fn remove_metadata_key(&mut self, key: &str) -> bool {
@ -572,13 +557,11 @@ impl MessagesRole {
impl ProviderStreamResponse for MessagesStreamEvent { impl ProviderStreamResponse for MessagesStreamEvent {
fn content_delta(&self) -> Option<&str> { fn content_delta(&self) -> Option<&str> {
match self { match self {
MessagesStreamEvent::ContentBlockDelta { delta, .. } => { MessagesStreamEvent::ContentBlockDelta { delta, .. } => match delta {
match delta { MessagesContentDelta::TextDelta { text } => Some(text),
MessagesContentDelta::TextDelta { text } => Some(text), MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking), _ => None,
_ => None, },
}
}
_ => None, _ => None,
} }
} }
@ -627,7 +610,8 @@ mod tests {
}); });
// Deserialize JSON into MessagesRequest // Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap(); let deserialized_request: MessagesRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields are properly set // Validate required fields are properly set
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229"); assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
@ -687,7 +671,8 @@ mod tests {
}); });
// Deserialize JSON into MessagesRequest // Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap(); let deserialized_request: MessagesRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields // Validate required fields
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229"); assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
@ -730,7 +715,10 @@ mod tests {
assert_eq!(serialized_json["messages"], original_json["messages"]); assert_eq!(serialized_json["messages"], original_json["messages"]);
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]); assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
assert_eq!(serialized_json["system"], original_json["system"]); assert_eq!(serialized_json["system"], original_json["system"]);
assert_eq!(serialized_json["service_tier"], original_json["service_tier"]); assert_eq!(
serialized_json["service_tier"],
original_json["service_tier"]
);
assert_eq!(serialized_json["thinking"], original_json["thinking"]); assert_eq!(serialized_json["thinking"], original_json["thinking"]);
assert_eq!(serialized_json["metadata"], original_json["metadata"]); assert_eq!(serialized_json["metadata"], original_json["metadata"]);
@ -818,7 +806,8 @@ mod tests {
}); });
// Deserialize JSON into MessagesRequest // Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap(); let deserialized_request: MessagesRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate top-level fields // Validate top-level fields
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229"); assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
@ -833,7 +822,10 @@ mod tests {
// Validate text content block // Validate text content block
if let MessagesContentBlock::Text { text, .. } = &content_blocks[0] { if let MessagesContentBlock::Text { text, .. } = &content_blocks[0] {
assert_eq!(text, "What can you see in this image and what's the weather like?"); assert_eq!(
text,
"What can you see in this image and what's the weather like?"
);
} else { } else {
panic!("Expected text content block"); panic!("Expected text content block");
} }
@ -861,20 +853,32 @@ mod tests {
// Validate thinking content block // Validate thinking content block
if let MessagesContentBlock::Thinking { thinking, .. } = &content_blocks[0] { if let MessagesContentBlock::Thinking { thinking, .. } = &content_blocks[0] {
assert_eq!(thinking, "Let me analyze the image and then check the weather..."); assert_eq!(
thinking,
"Let me analyze the image and then check the weather..."
);
} else { } else {
panic!("Expected thinking content block"); panic!("Expected thinking content block");
} }
// Validate text content block // Validate text content block
if let MessagesContentBlock::Text { text, .. } = &content_blocks[1] { if let MessagesContentBlock::Text { text, .. } = &content_blocks[1] {
assert_eq!(text, "I can see the image. Let me check the weather for you."); assert_eq!(
text,
"I can see the image. Let me check the weather for you."
);
} else { } else {
panic!("Expected text content block"); panic!("Expected text content block");
} }
// Validate tool use content block // Validate tool use content block
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = content_blocks[2] { if let MessagesContentBlock::ToolUse {
ref id,
ref name,
ref input,
..
} = content_blocks[2]
{
assert_eq!(id, "toolu_weather123"); assert_eq!(id, "toolu_weather123");
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(input["location"], "San Francisco, CA"); assert_eq!(input["location"], "San Francisco, CA");
@ -892,7 +896,10 @@ mod tests {
let tool = &tools[0]; let tool = &tools[0];
assert_eq!(tool.name, "get_weather"); assert_eq!(tool.name, "get_weather");
assert_eq!(tool.description, Some("Get current weather information for a location".to_string())); assert_eq!(
tool.description,
Some("Get current weather information for a location".to_string())
);
assert_eq!(tool.input_schema["type"], "object"); assert_eq!(tool.input_schema["type"], "object");
assert!(tool.input_schema["properties"]["location"].is_object()); assert!(tool.input_schema["properties"]["location"].is_object());
@ -938,10 +945,16 @@ mod tests {
assert_eq!(deserialized_mcp.name, "test-server"); assert_eq!(deserialized_mcp.name, "test-server");
assert_eq!(deserialized_mcp.server_type, McpServerType::Url); assert_eq!(deserialized_mcp.server_type, McpServerType::Url);
assert_eq!(deserialized_mcp.url, "https://example.com/mcp"); assert_eq!(deserialized_mcp.url, "https://example.com/mcp");
assert_eq!(deserialized_mcp.authorization_token, Some("secret-token".to_string())); assert_eq!(
deserialized_mcp.authorization_token,
Some("secret-token".to_string())
);
if let Some(tool_config) = &deserialized_mcp.tool_configuration { if let Some(tool_config) = &deserialized_mcp.tool_configuration {
assert_eq!(tool_config.allowed_tools, Some(vec!["tool1".to_string(), "tool2".to_string()])); assert_eq!(
tool_config.allowed_tools,
Some(vec!["tool1".to_string(), "tool2".to_string()])
);
assert_eq!(tool_config.enabled, Some(true)); assert_eq!(tool_config.enabled, Some(true));
} else { } else {
panic!("Expected tool configuration"); panic!("Expected tool configuration");
@ -957,7 +970,8 @@ mod tests {
"url": "https://minimal.com/mcp" "url": "https://minimal.com/mcp"
}); });
let deserialized_minimal: McpServer = serde_json::from_value(minimal_mcp_json.clone()).unwrap(); let deserialized_minimal: McpServer =
serde_json::from_value(minimal_mcp_json.clone()).unwrap();
assert_eq!(deserialized_minimal.name, "minimal-server"); assert_eq!(deserialized_minimal.name, "minimal-server");
assert_eq!(deserialized_minimal.server_type, McpServerType::Url); assert_eq!(deserialized_minimal.server_type, McpServerType::Url);
assert_eq!(deserialized_minimal.url, "https://minimal.com/mcp"); assert_eq!(deserialized_minimal.url, "https://minimal.com/mcp");
@ -991,12 +1005,16 @@ mod tests {
} }
}); });
let deserialized_response: MessagesResponse = serde_json::from_value(response_json.clone()).unwrap(); let deserialized_response: MessagesResponse =
serde_json::from_value(response_json.clone()).unwrap();
assert_eq!(deserialized_response.id, "msg_01ABC123"); assert_eq!(deserialized_response.id, "msg_01ABC123");
assert_eq!(deserialized_response.obj_type, "message"); assert_eq!(deserialized_response.obj_type, "message");
assert_eq!(deserialized_response.role, MessagesRole::Assistant); assert_eq!(deserialized_response.role, MessagesRole::Assistant);
assert_eq!(deserialized_response.model, "claude-3-sonnet-20240229"); assert_eq!(deserialized_response.model, "claude-3-sonnet-20240229");
assert_eq!(deserialized_response.stop_reason, MessagesStopReason::EndTurn); assert_eq!(
deserialized_response.stop_reason,
MessagesStopReason::EndTurn
);
assert!(deserialized_response.stop_sequence.is_none()); assert!(deserialized_response.stop_sequence.is_none());
assert!(deserialized_response.container.is_none()); assert!(deserialized_response.container.is_none());
@ -1011,7 +1029,10 @@ mod tests {
// Check usage // Check usage
assert_eq!(deserialized_response.usage.input_tokens, 10); assert_eq!(deserialized_response.usage.input_tokens, 10);
assert_eq!(deserialized_response.usage.output_tokens, 25); assert_eq!(deserialized_response.usage.output_tokens, 25);
assert_eq!(deserialized_response.usage.cache_creation_input_tokens, Some(5)); assert_eq!(
deserialized_response.usage.cache_creation_input_tokens,
Some(5)
);
assert_eq!(deserialized_response.usage.cache_read_input_tokens, Some(3)); assert_eq!(deserialized_response.usage.cache_read_input_tokens, Some(3));
let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap(); let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap();
@ -1027,7 +1048,8 @@ mod tests {
} }
}); });
let deserialized_event: MessagesStreamEvent = serde_json::from_value(stream_event_json.clone()).unwrap(); let deserialized_event: MessagesStreamEvent =
serde_json::from_value(stream_event_json.clone()).unwrap();
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event { if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
assert_eq!(index, 0); assert_eq!(index, 0);
if let MessagesContentDelta::TextDelta { text } = delta { if let MessagesContentDelta::TextDelta { text } = delta {
@ -1055,8 +1077,15 @@ mod tests {
} }
}); });
let deserialized_tool_use: MessagesContentBlock = serde_json::from_value(tool_use_json.clone()).unwrap(); let deserialized_tool_use: MessagesContentBlock =
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = deserialized_tool_use { serde_json::from_value(tool_use_json.clone()).unwrap();
if let MessagesContentBlock::ToolUse {
ref id,
ref name,
ref input,
..
} = deserialized_tool_use
{
assert_eq!(id, "toolu_01ABC123"); assert_eq!(id, "toolu_01ABC123");
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(input["location"], "San Francisco, CA"); assert_eq!(input["location"], "San Francisco, CA");
@ -1079,8 +1108,15 @@ mod tests {
] ]
}); });
let deserialized_tool_result: MessagesContentBlock = serde_json::from_value(tool_result_json.clone()).unwrap(); let deserialized_tool_result: MessagesContentBlock =
if let MessagesContentBlock::ToolResult { ref tool_use_id, ref is_error, ref content, .. } = deserialized_tool_result { serde_json::from_value(tool_result_json.clone()).unwrap();
if let MessagesContentBlock::ToolResult {
ref tool_use_id,
ref is_error,
ref content,
..
} = deserialized_tool_result
{
assert_eq!(tool_use_id, "toolu_01ABC123"); assert_eq!(tool_use_id, "toolu_01ABC123");
assert!(is_error.is_none()); assert!(is_error.is_none());
if let ToolResultContent::Blocks(blocks) = content { if let ToolResultContent::Blocks(blocks) = content {
@ -1229,7 +1265,8 @@ mod tests {
}); });
// Deserialize the complex MessagesRequest // Deserialize the complex MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(complex_request_json.clone()).unwrap(); let deserialized_request: MessagesRequest =
serde_json::from_value(complex_request_json.clone()).unwrap();
// Verify basic fields // Verify basic fields
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514"); assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
@ -1239,8 +1276,15 @@ mod tests {
// Verify system message with cache_control // Verify system message with cache_control
if let Some(MessagesSystemPrompt::Blocks(ref system_blocks)) = deserialized_request.system { if let Some(MessagesSystemPrompt::Blocks(ref system_blocks)) = deserialized_request.system {
assert_eq!(system_blocks.len(), 2); assert_eq!(system_blocks.len(), 2);
if let MessagesContentBlock::Text { text, cache_control } = &system_blocks[0] { if let MessagesContentBlock::Text {
assert_eq!(text, "You are Claude Code, Anthropic's official CLI for Claude."); text,
cache_control,
} = &system_blocks[0]
{
assert_eq!(
text,
"You are Claude Code, Anthropic's official CLI for Claude."
);
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral)); assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
} else { } else {
panic!("Expected text system message with cache_control"); panic!("Expected text system message with cache_control");
@ -1253,7 +1297,13 @@ mod tests {
let assistant_message = &deserialized_request.messages[1]; let assistant_message = &deserialized_request.messages[1];
assert_eq!(assistant_message.role, MessagesRole::Assistant); assert_eq!(assistant_message.role, MessagesRole::Assistant);
if let MessagesMessageContent::Blocks(ref content_blocks) = assistant_message.content { if let MessagesMessageContent::Blocks(ref content_blocks) = assistant_message.content {
if let MessagesContentBlock::ToolUse { id, name, input, cache_control } = &content_blocks[0] { if let MessagesContentBlock::ToolUse {
id,
name,
input,
cache_control,
} = &content_blocks[0]
{
assert_eq!(id, "call_kV50LtJQKHvvzZui5TW56DUl"); assert_eq!(id, "call_kV50LtJQKHvvzZui5TW56DUl");
assert_eq!(name, "TodoWrite"); assert_eq!(name, "TodoWrite");
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral)); assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
@ -1272,7 +1322,12 @@ mod tests {
let user_message = &deserialized_request.messages[2]; let user_message = &deserialized_request.messages[2];
assert_eq!(user_message.role, MessagesRole::User); assert_eq!(user_message.role, MessagesRole::User);
if let MessagesMessageContent::Blocks(ref content_blocks) = user_message.content { if let MessagesMessageContent::Blocks(ref content_blocks) = user_message.content {
if let MessagesContentBlock::ToolResult { tool_use_id, content, .. } = &content_blocks[0] { if let MessagesContentBlock::ToolResult {
tool_use_id,
content,
..
} = &content_blocks[0]
{
assert_eq!(tool_use_id, "call_kV50LtJQKHvvzZui5TW56DUl"); assert_eq!(tool_use_id, "call_kV50LtJQKHvvzZui5TW56DUl");
if let ToolResultContent::Text(text) = content { if let ToolResultContent::Text(text) = content {
assert!(text.contains("Todos have been modified successfully")); assert!(text.contains("Todos have been modified successfully"));
@ -1284,7 +1339,11 @@ mod tests {
} }
// Verify text content with cache_control // Verify text content with cache_control
if let MessagesContentBlock::Text { text, cache_control } = &content_blocks[2] { if let MessagesContentBlock::Text {
text,
cache_control,
} = &content_blocks[2]
{
assert_eq!(text, "try again"); assert_eq!(text, "try again");
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral)); assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
} else { } else {
@ -1296,11 +1355,15 @@ mod tests {
// Test serialization round-trip // Test serialization round-trip
let serialized_request = serde_json::to_value(&deserialized_request).unwrap(); let serialized_request = serde_json::to_value(&deserialized_request).unwrap();
let re_deserialized_request: MessagesRequest = serde_json::from_value(serialized_request).unwrap(); let re_deserialized_request: MessagesRequest =
serde_json::from_value(serialized_request).unwrap();
// Verify round-trip consistency // Verify round-trip consistency
assert_eq!(deserialized_request.model, re_deserialized_request.model); assert_eq!(deserialized_request.model, re_deserialized_request.model);
assert_eq!(deserialized_request.messages.len(), re_deserialized_request.messages.len()); assert_eq!(
deserialized_request.messages.len(),
re_deserialized_request.messages.len()
);
} }
#[test] #[test]
@ -1339,7 +1402,8 @@ mod tests {
} }
}); });
let deserialized_event: MessagesStreamEvent = serde_json::from_value(thinking_delta_json.clone()).unwrap(); let deserialized_event: MessagesStreamEvent =
serde_json::from_value(thinking_delta_json.clone()).unwrap();
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event { if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
assert_eq!(index, 0); assert_eq!(index, 0);
if let MessagesContentDelta::ThinkingDelta { thinking } = delta { if let MessagesContentDelta::ThinkingDelta { thinking } = delta {
@ -1352,7 +1416,10 @@ mod tests {
} }
// Test that thinking delta is returned by content_delta() // Test that thinking delta is returned by content_delta()
assert_eq!(deserialized_event.content_delta(), Some(".\n\nI need to consider:\n1. Current")); assert_eq!(
deserialized_event.content_delta(),
Some(".\n\nI need to consider:\n1. Current")
);
let serialized_event_json = serde_json::to_value(&deserialized_event).unwrap(); let serialized_event_json = serde_json::to_value(&deserialized_event).unwrap();
assert_eq!(thinking_delta_json, serialized_event_json); assert_eq!(thinking_delta_json, serialized_event_json);
@ -1376,7 +1443,8 @@ mod tests {
} }
}); });
let deserialized_request: MessagesRequest = serde_json::from_value(request_json.clone()).unwrap(); let deserialized_request: MessagesRequest =
serde_json::from_value(request_json.clone()).unwrap();
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514"); assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
assert_eq!(deserialized_request.max_tokens, 2048); assert_eq!(deserialized_request.max_tokens, 2048);

View file

@ -1,15 +1,19 @@
pub mod anthropic;
pub mod openai;
pub mod amazon_bedrock; pub mod amazon_bedrock;
pub mod amazon_bedrock_binary_frame; pub mod amazon_bedrock_binary_frame;
pub mod anthropic;
pub mod openai;
pub mod sse; pub mod sse;
// Explicit exports to avoid naming conflicts // Explicit exports to avoid naming conflicts
pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent};
pub use openai::{OpenAIApi, ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse};
pub use openai::{Message as OpenAIMessage, Tool as OpenAITool, ToolChoice as OpenAIToolChoice};
pub use amazon_bedrock::{AmazonBedrockApi, ConverseRequest, ConverseStreamRequest}; pub use amazon_bedrock::{AmazonBedrockApi, ConverseRequest, ConverseStreamRequest};
pub use amazon_bedrock::{Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice}; pub use amazon_bedrock::{
Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
};
pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent};
pub use openai::{
ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse, OpenAIApi,
};
pub use openai::{Message as OpenAIMessage, Tool as OpenAITool, ToolChoice as OpenAIToolChoice};
pub trait ApiDefinition { pub trait ApiDefinition {
/// Returns the endpoint path for this API /// Returns the endpoint path for this API
@ -56,11 +60,7 @@ mod tests {
#[test] #[test]
fn test_api_detection_from_endpoints() { fn test_api_detection_from_endpoints() {
// Test that we can detect APIs from endpoints using the trait // Test that we can detect APIs from endpoints using the trait
let endpoints = vec![ let endpoints = vec![CHAT_COMPLETIONS_PATH, MESSAGES_PATH, "/v1/unknown"];
CHAT_COMPLETIONS_PATH,
MESSAGES_PATH,
"/v1/unknown"
];
let mut detected_apis = Vec::new(); let mut detected_apis = Vec::new();
@ -74,11 +74,14 @@ mod tests {
} }
} }
assert_eq!(detected_apis, vec![ assert_eq!(
"OpenAI: ChatCompletions", detected_apis,
"Anthropic: Messages", vec![
"Unknown API" "OpenAI: ChatCompletions",
]); "Anthropic: Messages",
"Unknown API"
]
);
} }
#[test] #[test]

View file

@ -5,11 +5,11 @@ use std::collections::HashMap;
use std::fmt::Display; use std::fmt::Display;
use thiserror::Error; use thiserror::Error;
use super::ApiDefinition;
use crate::providers::request::{ProviderRequest, ProviderRequestError}; use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage}; use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
use super::ApiDefinition;
use crate::transforms::lib::ExtractText; use crate::transforms::lib::ExtractText;
use crate::{CHAT_COMPLETIONS_PATH}; use crate::CHAT_COMPLETIONS_PATH;
// ============================================================================ // ============================================================================
// OPENAI API ENUMERATION // OPENAI API ENUMERATION
@ -46,7 +46,7 @@ impl ApiDefinition for OpenAIApi {
} }
fn supports_tools(&self) -> bool { fn supports_tools(&self) -> bool {
match self { match self {
OpenAIApi::ChatCompletions => true, OpenAIApi::ChatCompletions => true,
} }
} }
@ -58,9 +58,7 @@ impl ApiDefinition for OpenAIApi {
} }
fn all_variants() -> Vec<Self> { fn all_variants() -> Vec<Self> {
vec![ vec![OpenAIApi::ChatCompletions]
OpenAIApi::ChatCompletions,
]
} }
} }
@ -190,7 +188,9 @@ impl ResponseMessage {
pub fn to_message(&self) -> Message { pub fn to_message(&self) -> Message {
Message { Message {
role: self.role.clone(), role: self.role.clone(),
content: self.content.as_ref() content: self
.content
.as_ref()
.map(|s| MessageContent::Text(s.clone())) .map(|s| MessageContent::Text(s.clone()))
.unwrap_or(MessageContent::Text(String::new())), .unwrap_or(MessageContent::Text(String::new())),
name: None, // Response messages don't have names in the same way request messages do name: None, // Response messages don't have names in the same way request messages do
@ -215,7 +215,7 @@ impl ExtractText for MessageContent {
fn extract_text(&self) -> String { fn extract_text(&self) -> String {
match self { match self {
MessageContent::Text(text) => text.clone(), MessageContent::Text(text) => text.clone(),
MessageContent::Parts(parts) => parts.extract_text() MessageContent::Parts(parts) => parts.extract_text(),
} }
} }
} }
@ -274,7 +274,6 @@ pub struct ImageUrl {
/// A single message in a chat conversation /// A single message in a chat conversation
/// A tool call made by the assistant /// A tool call made by the assistant
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ToolCall { pub struct ToolCall {
@ -374,7 +373,6 @@ pub enum StaticContentType {
Parts(Vec<ContentPart>), Parts(Vec<ContentPart>),
} }
/// Chat completions API response /// Chat completions API response
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
@ -496,7 +494,6 @@ pub struct ChatCompletionsStreamResponse {
pub service_tier: Option<String>, pub service_tier: Option<String>,
} }
/// A choice in a streaming response /// A choice in a streaming response
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
@ -566,7 +563,6 @@ pub struct Models {
pub data: Vec<ModelDetail>, pub data: Vec<ModelDetail>,
} }
// Error type for streaming operations // Error type for streaming operations
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum OpenAIStreamError { pub enum OpenAIStreamError {
@ -597,13 +593,13 @@ pub enum OpenAIError {
/// Trait Implementations /// Trait Implementations
/// =========================================================================== /// ===========================================================================
/// Parameterized conversion for ChatCompletionsRequest /// Parameterized conversion for ChatCompletionsRequest
impl TryFrom<&[u8]> for ChatCompletionsRequest { impl TryFrom<&[u8]> for ChatCompletionsRequest {
type Error = OpenAIStreamError; type Error = OpenAIStreamError;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> { fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
let mut req: ChatCompletionsRequest = serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?; let mut req: ChatCompletionsRequest =
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
// Use the centralized suppression logic // Use the centralized suppression logic
req.suppress_max_tokens_if_o3(); req.suppress_max_tokens_if_o3();
req.fix_temperature_if_gpt5(); req.fix_temperature_if_gpt5();
@ -651,13 +647,18 @@ impl ProviderRequest for ChatCompletionsRequest {
fn extract_messages_text(&self) -> String { fn extract_messages_text(&self) -> String {
self.messages.iter().fold(String::new(), |acc, m| { self.messages.iter().fold(String::new(), |acc, m| {
acc + " " + &match &m.content { acc + " "
MessageContent::Text(text) => text.clone(), + &match &m.content {
MessageContent::Parts(parts) => parts.iter().map(|part| match part { MessageContent::Text(text) => text.clone(),
ContentPart::Text { text } => text.clone(), MessageContent::Parts(parts) => parts
ContentPart::ImageUrl { .. } => "[Image]".to_string(), .iter()
}).collect::<Vec<_>>().join(" ") .map(|part| match part {
} ContentPart::Text { text } => text.clone(),
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
})
.collect::<Vec<_>>()
.join(" "),
}
}) })
} }
@ -721,14 +722,14 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
} }
fn role(&self) -> Option<&str> { fn role(&self) -> Option<&str> {
self.choices self.choices.first().and_then(|choice| {
.first() choice.delta.role.as_ref().map(|r| match r {
.and_then(|choice| choice.delta.role.as_ref().map(|r| match r {
Role::System => "system", Role::System => "system",
Role::User => "user", Role::User => "user",
Role::Assistant => "assistant", Role::Assistant => "assistant",
Role::Tool => "tool", Role::Tool => "tool",
})) })
})
} }
fn event_type(&self) -> Option<&str> { fn event_type(&self) -> Option<&str> {
@ -736,7 +737,6 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -756,7 +756,8 @@ mod tests {
}); });
// Deserialize JSON into ChatCompletionsRequest // Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap(); let deserialized_request: ChatCompletionsRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields are properly set // Validate required fields are properly set
assert_eq!(deserialized_request.model, "gpt-4"); assert_eq!(deserialized_request.model, "gpt-4");
@ -799,7 +800,8 @@ mod tests {
}); });
// Deserialize JSON into ChatCompletionsRequest // Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap(); let deserialized_request: ChatCompletionsRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields // Validate required fields
assert_eq!(deserialized_request.model, "gpt-4"); assert_eq!(deserialized_request.model, "gpt-4");
@ -836,7 +838,10 @@ mod tests {
assert_eq!(serialized_json["messages"], original_json["messages"]); assert_eq!(serialized_json["messages"], original_json["messages"]);
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]); assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
assert_eq!(serialized_json["stream"], original_json["stream"]); assert_eq!(serialized_json["stream"], original_json["stream"]);
assert_eq!(serialized_json["stream_options"], original_json["stream_options"]); assert_eq!(
serialized_json["stream_options"],
original_json["stream_options"]
);
assert_eq!(serialized_json["metadata"], original_json["metadata"]); assert_eq!(serialized_json["metadata"], original_json["metadata"]);
// Handle temperature with floating point tolerance // Handle temperature with floating point tolerance
@ -917,7 +922,8 @@ mod tests {
}); });
// Deserialize JSON into ChatCompletionsRequest // Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap(); let deserialized_request: ChatCompletionsRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate top-level fields // Validate top-level fields
assert_eq!(deserialized_request.model, "gpt-4-vision-preview"); assert_eq!(deserialized_request.model, "gpt-4-vision-preview");
@ -953,7 +959,10 @@ mod tests {
let assistant_message = &deserialized_request.messages[1]; let assistant_message = &deserialized_request.messages[1];
assert_eq!(assistant_message.role, Role::Assistant); assert_eq!(assistant_message.role, Role::Assistant);
if let MessageContent::Text(text) = &assistant_message.content { if let MessageContent::Text(text) = &assistant_message.content {
assert_eq!(text, "I can see a beautiful cityscape. Let me check the weather for you."); assert_eq!(
text,
"I can see a beautiful cityscape. Let me check the weather for you."
);
} else { } else {
panic!("Expected text content for assistant message"); panic!("Expected text content for assistant message");
} }
@ -967,7 +976,10 @@ mod tests {
assert_eq!(tool_call.id, "call_weather123"); assert_eq!(tool_call.id, "call_weather123");
assert_eq!(tool_call.call_type, "function"); assert_eq!(tool_call.call_type, "function");
assert_eq!(tool_call.function.name, "get_weather"); assert_eq!(tool_call.function.name, "get_weather");
assert_eq!(tool_call.function.arguments, "{\"location\": \"New York, NY\"}"); assert_eq!(
tool_call.function.arguments,
"{\"location\": \"New York, NY\"}"
);
// Validate third message (tool response) // Validate third message (tool response)
let tool_message = &deserialized_request.messages[2]; let tool_message = &deserialized_request.messages[2];
@ -977,7 +989,10 @@ mod tests {
} else { } else {
panic!("Expected text content for tool message"); panic!("Expected text content for tool message");
} }
assert_eq!(tool_message.tool_call_id, Some("call_weather123".to_string())); assert_eq!(
tool_message.tool_call_id,
Some("call_weather123".to_string())
);
// Validate tools array // Validate tools array
assert!(deserialized_request.tools.is_some()); assert!(deserialized_request.tools.is_some());
@ -987,7 +1002,10 @@ mod tests {
let tool = &tools[0]; let tool = &tools[0];
assert_eq!(tool.tool_type, "function"); assert_eq!(tool.tool_type, "function");
assert_eq!(tool.function.name, "get_weather"); assert_eq!(tool.function.name, "get_weather");
assert_eq!(tool.function.description, Some("Get current weather information for a location".to_string())); assert_eq!(
tool.function.description,
Some("Get current weather information for a location".to_string())
);
assert_eq!(tool.function.strict, Some(true)); assert_eq!(tool.function.strict, Some(true));
// Validate tool parameters schema // Validate tool parameters schema
@ -1093,7 +1111,8 @@ mod tests {
] ]
}); });
let deserialized_assistant: Message = serde_json::from_value(assistant_json.clone()).unwrap(); let deserialized_assistant: Message =
serde_json::from_value(assistant_json.clone()).unwrap();
assert_eq!(deserialized_assistant.role, Role::Assistant); assert_eq!(deserialized_assistant.role, Role::Assistant);
if let MessageContent::Text(content) = &deserialized_assistant.content { if let MessageContent::Text(content) = &deserialized_assistant.content {
assert_eq!(content, "I'll help with that."); assert_eq!(content, "I'll help with that.");
@ -1142,9 +1161,13 @@ mod tests {
] ]
}); });
let deserialized_response: ResponseMessage = serde_json::from_value(response_json.clone()).unwrap(); let deserialized_response: ResponseMessage =
serde_json::from_value(response_json.clone()).unwrap();
assert_eq!(deserialized_response.role, Role::Assistant); assert_eq!(deserialized_response.role, Role::Assistant);
assert_eq!(deserialized_response.content, Some("Response content".to_string())); assert_eq!(
deserialized_response.content,
Some("Response content".to_string())
);
assert!(deserialized_response.annotations.is_some()); assert!(deserialized_response.annotations.is_some());
assert!(deserialized_response.refusal.is_none()); assert!(deserialized_response.refusal.is_none());
assert!(deserialized_response.function_call.is_none()); assert!(deserialized_response.function_call.is_none());
@ -1186,7 +1209,10 @@ mod tests {
let none_deserialized: ToolChoice = serde_json::from_value(json!("none")).unwrap(); let none_deserialized: ToolChoice = serde_json::from_value(json!("none")).unwrap();
assert_eq!(auto_deserialized, ToolChoice::Type(ToolChoiceType::Auto)); assert_eq!(auto_deserialized, ToolChoice::Type(ToolChoiceType::Auto));
assert_eq!(required_deserialized, ToolChoice::Type(ToolChoiceType::Required)); assert_eq!(
required_deserialized,
ToolChoice::Type(ToolChoiceType::Required)
);
assert_eq!(none_deserialized, ToolChoice::Type(ToolChoiceType::None)); assert_eq!(none_deserialized, ToolChoice::Type(ToolChoiceType::None));
// Test that invalid string values fail deserialization (type safety!) // Test that invalid string values fail deserialization (type safety!)
@ -1237,7 +1263,10 @@ mod tests {
assert_eq!(response.created, 1756574706); assert_eq!(response.created, 1756574706);
assert_eq!(response.model, "gpt-4o-2024-08-06"); assert_eq!(response.model, "gpt-4o-2024-08-06");
assert_eq!(response.service_tier, Some("default".to_string())); assert_eq!(response.service_tier, Some("default".to_string()));
assert_eq!(response.system_fingerprint, Some("fp_f33640a400".to_string())); assert_eq!(
response.system_fingerprint,
Some("fp_f33640a400".to_string())
);
assert_eq!(response.choices.len(), 1); assert_eq!(response.choices.len(), 1);
assert_eq!(response.usage.prompt_tokens, 65); assert_eq!(response.usage.prompt_tokens, 65);
assert_eq!(response.usage.completion_tokens, 184); assert_eq!(response.usage.completion_tokens, 184);

View file

@ -1,9 +1,9 @@
use std::str::FromStr;
use std::fmt;
use std::error::Error;
use serde::{Serialize, Deserialize};
use crate::providers::response::ProviderStreamResponse; use crate::providers::response::ProviderStreamResponse;
use crate::providers::response::ProviderStreamResponseType; use crate::providers::response::ProviderStreamResponseType;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fmt;
use std::str::FromStr;
// ============================================================================ // ============================================================================
// SSE EVENT CONTAINER // SSE EVENT CONTAINER
@ -13,19 +13,19 @@ use crate::providers::response::ProviderStreamResponseType;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SseEvent { pub struct SseEvent {
#[serde(rename = "data")] #[serde(rename = "data")]
pub data: Option<String>, // The JSON payload after "data: " pub data: Option<String>, // The JSON payload after "data: "
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta") pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta")
#[serde(skip_serializing, skip_deserializing)] #[serde(skip_serializing, skip_deserializing)]
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n" pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
#[serde(skip_serializing, skip_deserializing)]
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
#[serde(skip_serializing, skip_deserializing)] #[serde(skip_serializing, skip_deserializing)]
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
#[serde(skip_serializing, skip_deserializing)]
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
} }
impl SseEvent { impl SseEvent {
@ -48,13 +48,13 @@ impl SseEvent {
/// Get the parsed provider response if available /// Get the parsed provider response if available
pub fn provider_response(&self) -> Result<&dyn ProviderStreamResponse, std::io::Error> { pub fn provider_response(&self) -> Result<&dyn ProviderStreamResponse, std::io::Error> {
self.provider_stream_response.as_ref() self.provider_stream_response
.as_ref()
.map(|resp| resp as &dyn ProviderStreamResponse) .map(|resp| resp as &dyn ProviderStreamResponse)
.ok_or_else(|| { .ok_or_else(|| {
std::io::Error::new(std::io::ErrorKind::NotFound, "Provider response not found") std::io::Error::new(std::io::ErrorKind::NotFound, "Provider response not found")
}) })
} }
} }
impl FromStr for SseEvent { impl FromStr for SseEvent {
@ -75,7 +75,8 @@ impl FromStr for SseEvent {
sse_transform_buffer: line.to_string(), sse_transform_buffer: line.to_string(),
provider_stream_response: None, provider_stream_response: None,
}) })
} else if line.starts_with("event: ") { //used by Anthropic } else if line.starts_with("event: ") {
//used by Anthropic
let event_type = line[7..].to_string(); let event_type = line[7..].to_string();
if event_type.is_empty() { if event_type.is_empty() {
return Err(SseParseError { return Err(SseParseError {
@ -123,7 +124,6 @@ impl fmt::Display for SseParseError {
impl Error for SseParseError {} impl Error for SseParseError {}
/// Generic SSE (Server-Sent Events) streaming iterator container /// Generic SSE (Server-Sent Events) streaming iterator container
/// Parses raw SSE lines into SseEvent objects /// Parses raw SSE lines into SseEvent objects
pub struct SseStreamIter<I> pub struct SseStreamIter<I>
@ -141,7 +141,10 @@ where
I::Item: AsRef<str>, I::Item: AsRef<str>,
{ {
pub fn new(lines: I) -> Self { pub fn new(lines: I) -> Self {
Self { lines, done_seen: false } Self {
lines,
done_seen: false,
}
} }
} }
@ -151,14 +154,13 @@ impl TryFrom<&[u8]> for SseStreamIter<std::vec::IntoIter<String>> {
type Error = Box<dyn std::error::Error + Send + Sync>; type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> { fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
// Parse as text-based SSE format // Parse as text-based SSE format
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(bytes)?;
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect(); let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
Ok(SseStreamIter::new(lines.into_iter())) Ok(SseStreamIter::new(lines.into_iter()))
} }
} }
impl<I> Iterator for SseStreamIter<I> impl<I> Iterator for SseStreamIter<I>
where where
I: Iterator, I: Iterator,

View file

@ -1,5 +1,5 @@
use crate::{ProviderId}; use crate::apis::{AmazonBedrockApi, AnthropicApi, ApiDefinition, OpenAIApi};
use crate::apis::{OpenAIApi, AnthropicApi, AmazonBedrockApi, ApiDefinition}; use crate::ProviderId;
use std::fmt; use std::fmt;
/// Unified enum representing all supported API endpoints across providers /// Unified enum representing all supported API endpoints across providers
@ -20,8 +20,12 @@ pub enum SupportedUpstreamAPIs {
impl fmt::Display for SupportedAPIs { impl fmt::Display for SupportedAPIs {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
SupportedAPIs::OpenAIChatCompletions(api) => write!(f, "OpenAI API ({})", api.endpoint()), SupportedAPIs::OpenAIChatCompletions(api) => {
SupportedAPIs::AnthropicMessagesAPI(api) => write!(f, "Anthropic API ({})", api.endpoint()), write!(f, "OpenAI API ({})", api.endpoint())
}
SupportedAPIs::AnthropicMessagesAPI(api) => {
write!(f, "Anthropic API ({})", api.endpoint())
}
} }
} }
} }
@ -48,81 +52,81 @@ impl SupportedAPIs {
} }
} }
pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str, model_id: &str, is_streaming: bool) -> String { pub fn target_endpoint_for_provider(
&self,
provider_id: &ProviderId,
request_path: &str,
model_id: &str,
is_streaming: bool,
) -> String {
let default_endpoint = "/v1/chat/completions".to_string(); let default_endpoint = "/v1/chat/completions".to_string();
match self { match self {
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => { SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
match provider_id { ProviderId::Anthropic => "/v1/messages".to_string(),
ProviderId::Anthropic => "/v1/messages".to_string(), ProviderId::AmazonBedrock => {
ProviderId::AmazonBedrock => { if request_path.starts_with("/v1/") && !is_streaming {
if request_path.starts_with("/v1/") && !is_streaming { format!("/model/{}/converse", model_id)
} else if request_path.starts_with("/v1/") && is_streaming {
format!("/model/{}/converse-stream", model_id)
} else {
default_endpoint
}
}
_ => default_endpoint,
},
_ => match provider_id {
ProviderId::Groq => {
if request_path.starts_with("/v1/") {
format!("/openai{}", request_path)
} else {
default_endpoint
}
}
ProviderId::Zhipu => {
if request_path.starts_with("/v1/") {
"/api/paas/v4/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::Qwen => {
if request_path.starts_with("/v1/") {
"/compatible-mode/v1/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::AzureOpenAI => {
if request_path.starts_with("/v1/") {
format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id)
} else {
default_endpoint
}
}
ProviderId::Gemini => {
if request_path.starts_with("/v1/") {
"/v1beta/openai/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::AmazonBedrock => {
if request_path.starts_with("/v1/") {
if !is_streaming {
format!("/model/{}/converse", model_id) format!("/model/{}/converse", model_id)
} else if request_path.starts_with("/v1/") && is_streaming { } else {
format!("/model/{}/converse-stream", model_id) format!("/model/{}/converse-stream", model_id)
} else {
default_endpoint
} }
} else {
default_endpoint
} }
_ => default_endpoint,
} }
} _ => default_endpoint,
_ => { },
match provider_id {
ProviderId::Groq => {
if request_path.starts_with("/v1/") {
format!("/openai{}", request_path)
} else {
default_endpoint
}
}
ProviderId::Zhipu => {
if request_path.starts_with("/v1/") {
"/api/paas/v4/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::Qwen => {
if request_path.starts_with("/v1/") {
"/compatible-mode/v1/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::AzureOpenAI => {
if request_path.starts_with("/v1/") {
format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id)
} else {
default_endpoint
}
}
ProviderId::Gemini => {
if request_path.starts_with("/v1/") {
"/v1beta/openai/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::AmazonBedrock => {
if request_path.starts_with("/v1/") {
if !is_streaming {
format!("/model/{}/converse", model_id)
} else {
format!("/model/{}/converse-stream", model_id)
}
} else {
default_endpoint
}
}
_ => default_endpoint,
}
}
} }
} }
} }
/// Get all supported endpoint paths /// Get all supported endpoint paths
pub fn supported_endpoints() -> Vec<&'static str> { pub fn supported_endpoints() -> Vec<&'static str> {
let mut endpoints = Vec::new(); let mut endpoints = Vec::new();
@ -164,7 +168,6 @@ mod tests {
// Anthropic endpoints // Anthropic endpoints
assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some()); assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some());
// Unsupported endpoints // Unsupported endpoints
assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some()); assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some());
assert!(!SupportedAPIs::from_endpoint("/v2/chat").is_some()); assert!(!SupportedAPIs::from_endpoint("/v2/chat").is_some());
@ -177,7 +180,6 @@ mod tests {
assert_eq!(endpoints.len(), 2); // We have 2 APIs defined assert_eq!(endpoints.len(), 2); // We have 2 APIs defined
assert!(endpoints.contains(&"/v1/chat/completions")); assert!(endpoints.contains(&"/v1/chat/completions"));
assert!(endpoints.contains(&"/v1/messages")); assert!(endpoints.contains(&"/v1/messages"));
} }
#[test] #[test]
@ -203,14 +205,25 @@ mod tests {
// All OpenAI endpoints should be in the result // All OpenAI endpoints should be in the result
for endpoint in openai_endpoints { for endpoint in openai_endpoints {
assert!(endpoints.contains(&endpoint), "Missing OpenAI endpoint: {}", endpoint); assert!(
endpoints.contains(&endpoint),
"Missing OpenAI endpoint: {}",
endpoint
);
} }
// All Anthropic endpoints should be in the result // All Anthropic endpoints should be in the result
for endpoint in anthropic_endpoints { for endpoint in anthropic_endpoints {
assert!(endpoints.contains(&endpoint), "Missing Anthropic endpoint: {}", endpoint); assert!(
endpoints.contains(&endpoint),
"Missing Anthropic endpoint: {}",
endpoint
);
} }
// Total should match // Total should match
assert_eq!(endpoints.len(), OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len()); assert_eq!(
endpoints.len(),
OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len()
);
} }
} }

View file

@ -1,9 +1,9 @@
pub mod endpoints;
pub mod lib; pub mod lib;
pub mod transformer; pub mod transformer;
pub mod endpoints;
// Re-export the main items for easier access // Re-export the main items for easier access
pub use endpoints::{identify_provider, SupportedAPIs};
pub use lib::*; pub use lib::*;
pub use endpoints::{SupportedAPIs, identify_provider};
// Note: transformer module contains TryFrom trait implementations that are automatically available // Note: transformer module contains TryFrom trait implementations that are automatically available

View file

@ -1,6 +1,5 @@
// Re-export new transformation modules for backward compatibility // Re-export new transformation modules for backward compatibility
//KEEPING THE TESTS TO MAKE SURE ALL THE REFACTORING DIDN'T BREAK ANYTHING //KEEPING THE TESTS TO MAKE SURE ALL THE REFACTORING DIDN'T BREAK ANYTHING
// ============================================================================ // ============================================================================
@ -9,10 +8,10 @@
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serde_json::json;
use crate::apis::anthropic::*; use crate::apis::anthropic::*;
use crate::apis::openai::*; use crate::apis::openai::*;
use crate::transforms::*; use crate::transforms::*;
use serde_json::json;
type AnthropicMessagesRequest = MessagesRequest; type AnthropicMessagesRequest = MessagesRequest;
#[test] #[test]
@ -81,11 +80,20 @@ mod tests {
// Check key fields are preserved // Check key fields are preserved
assert_eq!(original_anthropic.model, roundtrip_anthropic.model); assert_eq!(original_anthropic.model, roundtrip_anthropic.model);
assert_eq!(original_anthropic.max_tokens, roundtrip_anthropic.max_tokens); assert_eq!(
assert_eq!(original_anthropic.temperature, roundtrip_anthropic.temperature); original_anthropic.max_tokens,
roundtrip_anthropic.max_tokens
);
assert_eq!(
original_anthropic.temperature,
roundtrip_anthropic.temperature
);
assert_eq!(original_anthropic.top_p, roundtrip_anthropic.top_p); assert_eq!(original_anthropic.top_p, roundtrip_anthropic.top_p);
assert_eq!(original_anthropic.stream, roundtrip_anthropic.stream); assert_eq!(original_anthropic.stream, roundtrip_anthropic.stream);
assert_eq!(original_anthropic.messages.len(), roundtrip_anthropic.messages.len()); assert_eq!(
original_anthropic.messages.len(),
roundtrip_anthropic.messages.len()
);
} }
#[test] #[test]
@ -229,7 +237,10 @@ mod tests {
let tool_calls = choice.delta.tool_calls.as_ref().unwrap(); let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id, Some("call_123".to_string())); assert_eq!(tool_calls[0].id, Some("call_123".to_string()));
assert_eq!(tool_calls[0].function.as_ref().unwrap().name, Some("get_weather".to_string())); assert_eq!(
tool_calls[0].function.as_ref().unwrap().name,
Some("get_weather".to_string())
);
} }
#[test] #[test]
@ -249,7 +260,10 @@ mod tests {
let tool_calls = choice.delta.tool_calls.as_ref().unwrap(); let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].function.as_ref().unwrap().arguments, Some(r#"{"location": "San Francisco"#.to_string())); assert_eq!(
tool_calls[0].function.as_ref().unwrap().arguments,
Some(r#"{"location": "San Francisco"#.to_string())
);
} }
#[test] #[test]
@ -412,7 +426,10 @@ mod tests {
let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap();
match anthropic_event { match anthropic_event {
MessagesStreamEvent::ContentBlockStart { index, content_block } => { MessagesStreamEvent::ContentBlockStart {
index,
content_block,
} => {
assert_eq!(index, 0); assert_eq!(index, 0);
match content_block { match content_block {
MessagesContentBlock::ToolUse { id, name, .. } => { MessagesContentBlock::ToolUse { id, name, .. } => {
@ -555,16 +572,28 @@ mod tests {
// Verify tool start // Verify tool start
let tool_calls = &openai_start.choices[0].delta.tool_calls.as_ref().unwrap(); let tool_calls = &openai_start.choices[0].delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls[0].id, Some("call_weather".to_string())); assert_eq!(tool_calls[0].id, Some("call_weather".to_string()));
assert_eq!(tool_calls[0].function.as_ref().unwrap().name, Some("get_weather".to_string())); assert_eq!(
tool_calls[0].function.as_ref().unwrap().name,
Some("get_weather".to_string())
);
// Verify argument deltas // Verify argument deltas
let args1 = &openai_delta1.choices[0].delta.tool_calls.as_ref().unwrap()[0] let args1 = &openai_delta1.choices[0].delta.tool_calls.as_ref().unwrap()[0]
.function.as_ref().unwrap().arguments; .function
.as_ref()
.unwrap()
.arguments;
assert_eq!(args1, &Some(r#"{"location": "#.to_string())); assert_eq!(args1, &Some(r#"{"location": "#.to_string()));
let args2 = &openai_delta2.choices[0].delta.tool_calls.as_ref().unwrap()[0] let args2 = &openai_delta2.choices[0].delta.tool_calls.as_ref().unwrap()[0]
.function.as_ref().unwrap().arguments; .function
assert_eq!(args2, &Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string())); .as_ref()
.unwrap()
.arguments;
assert_eq!(
args2,
&Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string())
);
} }
#[test] #[test]
@ -592,14 +621,23 @@ mod tests {
}; };
let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap(); let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap();
assert_eq!(openai_resp.choices[0].finish_reason, Some(expected_openai_reason)); assert_eq!(
openai_resp.choices[0].finish_reason,
Some(expected_openai_reason)
);
// Test reverse conversion // Test reverse conversion
let roundtrip_event: MessagesStreamEvent = openai_resp.try_into().unwrap(); let roundtrip_event: MessagesStreamEvent = openai_resp.try_into().unwrap();
match roundtrip_event { match roundtrip_event {
MessagesStreamEvent::MessageDelta { delta, .. } => { MessagesStreamEvent::MessageDelta { delta, .. } => {
// Note: Some precision may be lost in roundtrip due to mapping differences // Note: Some precision may be lost in roundtrip due to mapping differences
assert!(matches!(delta.stop_reason, MessagesStopReason::EndTurn | MessagesStopReason::MaxTokens | MessagesStopReason::ToolUse | MessagesStopReason::StopSequence)); assert!(matches!(
delta.stop_reason,
MessagesStopReason::EndTurn
| MessagesStopReason::MaxTokens
| MessagesStopReason::ToolUse
| MessagesStopReason::StopSequence
));
} }
_ => panic!("Expected MessageDelta after roundtrip"), _ => panic!("Expected MessageDelta after roundtrip"),
} }
@ -632,7 +670,8 @@ mod tests {
}; };
// Should convert to Ping when no meaningful content // Should convert to Ping when no meaningful content
let anthropic_event: MessagesStreamEvent = openai_resp_with_missing_data.try_into().unwrap(); let anthropic_event: MessagesStreamEvent =
openai_resp_with_missing_data.try_into().unwrap();
assert!(matches!(anthropic_event, MessagesStreamEvent::Ping)); assert!(matches!(anthropic_event, MessagesStreamEvent::Ping));
} }

View file

@ -1,24 +1,25 @@
//! hermesllm: A library for translating LLM API requests and responses //! hermesllm: A library for translating LLM API requests and responses
//! between Mistral, Grok, Gemini, and OpenAI-compliant formats. //! between Mistral, Grok, Gemini, and OpenAI-compliant formats.
pub mod providers;
pub mod apis; pub mod apis;
pub mod clients; pub mod clients;
pub mod providers;
pub mod transforms; pub mod transforms;
// Re-export important types and traits // Re-export important types and traits
pub use providers::request::{ProviderRequestType, ProviderRequest, ProviderRequestError};
pub use apis::sse::{SseEvent, SseStreamIter};
pub use apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; pub use apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
pub use providers::response::{ProviderResponseType, ProviderStreamResponseType, ProviderResponse, ProviderStreamResponse, ProviderResponseError, TokenUsage}; pub use apis::sse::{SseEvent, SseStreamIter};
pub use providers::id::ProviderId;
pub use aws_smithy_eventstream::frame::DecodedFrame; pub use aws_smithy_eventstream::frame::DecodedFrame;
pub use providers::id::ProviderId;
pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
pub use providers::response::{
ProviderResponse, ProviderResponseError, ProviderResponseType, ProviderStreamResponse,
ProviderStreamResponseType, TokenUsage,
};
//TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings //TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
pub const MESSAGES_PATH: &str = "/v1/messages"; pub const MESSAGES_PATH: &str = "/v1/messages";
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::clients::endpoints::SupportedUpstreamAPIs; use crate::clients::endpoints::SupportedUpstreamAPIs;
@ -36,49 +37,51 @@ mod tests {
#[test] #[test]
fn test_provider_streaming_response() { fn test_provider_streaming_response() {
// Test streaming response parsing with sample SSE data // Test streaming response parsing with sample SSE data
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
data: [DONE] data: [DONE]
"#; "#;
use crate::clients::endpoints::SupportedAPIs; use crate::clients::endpoints::SupportedAPIs;
let client_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); let client_api =
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions); SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
let upstream_api =
SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
// Test the new simplified architecture - create SseStreamIter directly // Test the new simplified architecture - create SseStreamIter directly
let sse_iter = SseStreamIter::try_from(sse_data.as_bytes()); let sse_iter = SseStreamIter::try_from(sse_data.as_bytes());
assert!(sse_iter.is_ok()); assert!(sse_iter.is_ok());
let mut streaming_iter = sse_iter.unwrap(); let mut streaming_iter = sse_iter.unwrap();
// Test that we can iterate over SseEvents // Test that we can iterate over SseEvents
let first_event = streaming_iter.next(); let first_event = streaming_iter.next();
assert!(first_event.is_some()); assert!(first_event.is_some());
let sse_event = first_event.unwrap(); let sse_event = first_event.unwrap();
// Test SseEvent properties // Test SseEvent properties
assert!(!sse_event.is_done()); assert!(!sse_event.is_done());
assert!(sse_event.data.as_ref().unwrap().contains("Hello")); assert!(sse_event.data.as_ref().unwrap().contains("Hello"));
// Test that we can parse the event into a provider stream response // Test that we can parse the event into a provider stream response
let transformed_event = SseEvent::try_from((sse_event, &client_api, &upstream_api)); let transformed_event = SseEvent::try_from((sse_event, &client_api, &upstream_api));
if let Err(e) = &transformed_event { if let Err(e) = &transformed_event {
println!("Transform error: {:?}", e); println!("Transform error: {:?}", e);
} }
assert!(transformed_event.is_ok()); assert!(transformed_event.is_ok());
let transformed_event = transformed_event.unwrap(); let transformed_event = transformed_event.unwrap();
let provider_response = transformed_event.provider_response(); let provider_response = transformed_event.provider_response();
assert!(provider_response.is_ok()); assert!(provider_response.is_ok());
let stream_response = provider_response.unwrap(); let stream_response = provider_response.unwrap();
assert_eq!(stream_response.content_delta(), Some("Hello")); assert_eq!(stream_response.content_delta(), Some("Hello"));
assert!(!stream_response.is_final()); assert!(!stream_response.is_final());
// Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE]) // Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE])
let final_event = streaming_iter.next(); let final_event = streaming_iter.next();
assert!(final_event.is_none()); // Should be None because iterator stops at [DONE] assert!(final_event.is_none()); // Should be None because iterator stops at [DONE]
} }
/// Test AWS Event Stream decoding for Bedrock ConverseStream responses. /// Test AWS Event Stream decoding for Bedrock ConverseStream responses.
@ -95,15 +98,15 @@ mod tests {
/// all complete frames in the buffer. /// all complete frames in the buffer.
#[test] #[test]
fn test_amazon_bedrock_streaming_response() { fn test_amazon_bedrock_streaming_response() {
use aws_smithy_eventstream::frame::{MessageFrameDecoder, DecodedFrame}; use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
use bytes::{Buf, BytesMut}; use bytes::{Buf, BytesMut};
use std::fs; use std::fs;
use std::path::PathBuf; use std::path::PathBuf;
// Read the response.hex file from tests/e2e directory // Read the response.hex file from tests/e2e directory
// Use absolute path to avoid cargo test working directory issues // Use absolute path to avoid cargo test working directory issues
let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")) let test_file =
.join("../../tests/e2e/response.hex"); PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
let response_data = fs::read(&test_file) let response_data = fs::read(&test_file)
.unwrap_or_else(|e| panic!("Failed to read {:?}: {}", test_file, e)); .unwrap_or_else(|e| panic!("Failed to read {:?}: {}", test_file, e));
@ -134,8 +137,13 @@ mod tests {
simulated_network_buffer.extend_from_slice(chunk); simulated_network_buffer.extend_from_slice(chunk);
offset = end; offset = end;
println!("📦 Chunk {}: Received {} bytes (buffer: {} bytes total, {} bytes remaining)", println!(
chunk_num, chunk.len(), simulated_network_buffer.len(), simulated_network_buffer.remaining()); "📦 Chunk {}: Received {} bytes (buffer: {} bytes total, {} bytes remaining)",
chunk_num,
chunk.len(),
simulated_network_buffer.len(),
simulated_network_buffer.remaining()
);
// Try to decode all complete frames from buffer // Try to decode all complete frames from buffer
// The Buf trait tracks position automatically! // The Buf trait tracks position automatically!
@ -146,11 +154,16 @@ mod tests {
frame_count += 1; frame_count += 1;
let consumed = bytes_before - simulated_network_buffer.remaining(); let consumed = bytes_before - simulated_network_buffer.remaining();
println!(" ✅ Frame {}: decoded ({} bytes, {} bytes remaining)", println!(
frame_count, consumed, simulated_network_buffer.remaining()); " ✅ Frame {}: decoded ({} bytes, {} bytes remaining)",
frame_count,
consumed,
simulated_network_buffer.remaining()
);
// Get event type from headers // Get event type from headers
let event_type = message.headers() let event_type = message
.headers()
.iter() .iter()
.find(|h| h.name().as_str() == ":event-type") .find(|h| h.name().as_str() == ":event-type")
.and_then(|h| { .and_then(|h| {
@ -167,7 +180,9 @@ mod tests {
if let Ok(json) = serde_json::from_slice::<serde_json::Value>(payload) { if let Ok(json) = serde_json::from_slice::<serde_json::Value>(payload) {
if event_type.as_deref() == Some("contentBlockDelta") { if event_type.as_deref() == Some("contentBlockDelta") {
if let Some(delta) = json.get("delta") { if let Some(delta) = json.get("delta") {
if let Some(text) = delta.get("text").and_then(|t| t.as_str()) { if let Some(text) =
delta.get("text").and_then(|t| t.as_str())
{
println!(" 📝 Content: \"{}\"", text); println!(" 📝 Content: \"{}\"", text);
content_chunks.push(text.to_string()); content_chunks.push(text.to_string());
} }
@ -178,7 +193,10 @@ mod tests {
} }
Ok(DecodedFrame::Incomplete) => { Ok(DecodedFrame::Incomplete) => {
// Not enough data for a complete frame - need more chunks // Not enough data for a complete frame - need more chunks
println!(" ⏳ Incomplete frame ({} bytes remaining) - waiting for more data\n", simulated_network_buffer.remaining()); println!(
" ⏳ Incomplete frame ({} bytes remaining) - waiting for more data\n",
simulated_network_buffer.remaining()
);
break; // Wait for next chunk break; // Wait for next chunk
} }
Err(e) => { Err(e) => {
@ -193,7 +211,10 @@ mod tests {
println!(" Total chunks received: {}", chunk_num); println!(" Total chunks received: {}", chunk_num);
println!(" Total frames decoded: {}", frame_count); println!(" Total frames decoded: {}", frame_count);
println!(" Total content chunks: {}", content_chunks.len()); println!(" Total content chunks: {}", content_chunks.len());
println!(" Final buffer remaining: {} bytes", simulated_network_buffer.remaining()); println!(
" Final buffer remaining: {} bytes",
simulated_network_buffer.remaining()
);
if !content_chunks.is_empty() { if !content_chunks.is_empty() {
let full_text = content_chunks.join(""); let full_text = content_chunks.join("");
@ -207,6 +228,11 @@ mod tests {
assert!(frame_count > 0, "Should decode at least one frame"); assert!(frame_count > 0, "Should decode at least one frame");
// Ensure all data was consumed - if buffer has remaining bytes, it's a partial frame // Ensure all data was consumed - if buffer has remaining bytes, it's a partial frame
assert_eq!(simulated_network_buffer.remaining(), 0, "All bytes should be consumed, {} bytes remain", simulated_network_buffer.remaining()); assert_eq!(
simulated_network_buffer.remaining(),
0,
"All bytes should be consumed, {} bytes remain",
simulated_network_buffer.remaining()
);
} }
} }

View file

@ -1,6 +1,6 @@
use std::fmt::Display; use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi};
use crate::clients::endpoints::{SupportedAPIs, SupportedUpstreamAPIs}; use crate::clients::endpoints::{SupportedAPIs, SupportedUpstreamAPIs};
use crate::apis::{OpenAIApi, AnthropicApi, AmazonBedrockApi}; use std::fmt::Display;
/// Provider identifier enum - simple enum for identifying providers /// Provider identifier enum - simple enum for identifying providers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@ -49,60 +49,77 @@ impl From<&str> for ProviderId {
impl ProviderId { impl ProviderId {
/// Given a client API, return the compatible upstream API for this provider /// Given a client API, return the compatible upstream API for this provider
pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs, is_streaming: bool) -> SupportedUpstreamAPIs { pub fn compatible_api_for_client(
&self,
client_api: &SupportedAPIs,
is_streaming: bool,
) -> SupportedUpstreamAPIs {
match (self, client_api) { match (self, client_api) {
// Claude/Anthropic providers natively support Anthropic APIs // Claude/Anthropic providers natively support Anthropic APIs
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), (ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => {
(ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
}
(
ProviderId::Anthropic,
SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
// OpenAI-compatible providers only support OpenAI chat completions // OpenAI-compatible providers only support OpenAI chat completions
(ProviderId::OpenAI (
| ProviderId::Groq ProviderId::OpenAI
| ProviderId::Mistral | ProviderId::Groq
| ProviderId::Deepseek | ProviderId::Mistral
| ProviderId::Arch | ProviderId::Deepseek
| ProviderId::Gemini | ProviderId::Arch
| ProviderId::GitHub | ProviderId::Gemini
| ProviderId::AzureOpenAI | ProviderId::GitHub
| ProviderId::XAI | ProviderId::AzureOpenAI
| ProviderId::TogetherAI | ProviderId::XAI
| ProviderId::Ollama | ProviderId::TogetherAI
| ProviderId::Moonshotai | ProviderId::Ollama
| ProviderId::Zhipu | ProviderId::Moonshotai
| ProviderId::Qwen, | ProviderId::Zhipu
SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), | ProviderId::Qwen,
SupportedAPIs::AnthropicMessagesAPI(_),
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(ProviderId::OpenAI (
| ProviderId::Groq ProviderId::OpenAI
| ProviderId::Mistral | ProviderId::Groq
| ProviderId::Deepseek | ProviderId::Mistral
| ProviderId::Arch | ProviderId::Deepseek
| ProviderId::Gemini | ProviderId::Arch
| ProviderId::GitHub | ProviderId::Gemini
| ProviderId::AzureOpenAI | ProviderId::GitHub
| ProviderId::XAI | ProviderId::AzureOpenAI
| ProviderId::TogetherAI | ProviderId::XAI
| ProviderId::Ollama | ProviderId::TogetherAI
| ProviderId::Moonshotai | ProviderId::Ollama
| ProviderId::Zhipu | ProviderId::Moonshotai
| ProviderId::Qwen, | ProviderId::Zhipu
SupportedAPIs::OpenAIChatCompletions(_)) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), | ProviderId::Qwen,
SupportedAPIs::OpenAIChatCompletions(_),
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
// Amazon Bedrock natively supports Bedrock APIs // Amazon Bedrock natively supports Bedrock APIs
(ProviderId::AmazonBedrock, SupportedAPIs::OpenAIChatCompletions(_)) => { (ProviderId::AmazonBedrock, SupportedAPIs::OpenAIChatCompletions(_)) => {
if is_streaming { if is_streaming {
SupportedUpstreamAPIs::AmazonBedrockConverseStream(AmazonBedrockApi::ConverseStream) SupportedUpstreamAPIs::AmazonBedrockConverseStream(
AmazonBedrockApi::ConverseStream,
)
} else { } else {
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse) SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
} }
}, }
(ProviderId::AmazonBedrock, SupportedAPIs::AnthropicMessagesAPI(_)) => { (ProviderId::AmazonBedrock, SupportedAPIs::AnthropicMessagesAPI(_)) => {
if is_streaming { if is_streaming {
SupportedUpstreamAPIs::AmazonBedrockConverseStream(AmazonBedrockApi::ConverseStream) SupportedUpstreamAPIs::AmazonBedrockConverseStream(
AmazonBedrockApi::ConverseStream,
)
} else { } else {
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse) SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
} }
}, }
} }
} }
} }

View file

@ -8,5 +8,5 @@ pub mod request;
pub mod response; pub mod response;
pub use id::ProviderId; pub use id::ProviderId;
pub use request::{ProviderRequestType, ProviderRequest, ProviderRequestError} ; pub use request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
pub use response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, TokenUsage }; pub use response::{ProviderResponse, ProviderResponseType, ProviderStreamResponse, TokenUsage};

View file

@ -1,14 +1,14 @@
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::anthropic::MessagesRequest; use crate::apis::anthropic::MessagesRequest;
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest}; use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest};
use crate::clients::endpoints::SupportedAPIs; use crate::clients::endpoints::SupportedAPIs;
use crate::clients::endpoints::SupportedUpstreamAPIs; use crate::clients::endpoints::SupportedUpstreamAPIs;
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
use std::collections::HashMap;
#[derive(Clone)] #[derive(Clone)]
pub enum ProviderRequestType { pub enum ProviderRequestType {
ChatCompletionsRequest(ChatCompletionsRequest), ChatCompletionsRequest(ChatCompletionsRequest),
@ -124,15 +124,18 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
// Use SupportedApi to determine the appropriate request type // Use SupportedApi to determine the appropriate request type
match client_api { match client_api {
SupportedAPIs::OpenAIChatCompletions(_) => { SupportedAPIs::OpenAIChatCompletions(_) => {
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) let chat_completion_request: ChatCompletionsRequest =
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; ChatCompletionsRequest::try_from(bytes)
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
}
SupportedAPIs::AnthropicMessagesAPI(_) => {
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::MessagesRequest(messages_request)) Ok(ProviderRequestType::ChatCompletionsRequest(
} chat_completion_request,
))
}
SupportedAPIs::AnthropicMessagesAPI(_) => {
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::MessagesRequest(messages_request))
}
} }
} }
} }
@ -141,37 +144,57 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestType { impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestType {
type Error = ProviderRequestError; type Error = ProviderRequestError;
fn try_from((client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs)) -> Result<Self, Self::Error> { fn try_from(
(client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs),
) -> Result<Self, Self::Error> {
match (client_request, upstream_api) { match (client_request, upstream_api) {
// Same API - no conversion needed, just clone the reference // Same API - no conversion needed, just clone the reference
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => { (
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) ProviderRequestType::ChatCompletionsRequest(chat_req),
} SupportedUpstreamAPIs::OpenAIChatCompletions(_),
(ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => { ) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)),
Ok(ProviderRequestType::MessagesRequest(messages_req)) (
} ProviderRequestType::MessagesRequest(messages_req),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) => Ok(ProviderRequestType::MessagesRequest(messages_req)),
// Cross-API conversion - cloning is necessary for transformation // Cross-API conversion - cloning is necessary for transformation
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => { (
let messages_req = MessagesRequest::try_from(chat_req) ProviderRequestType::ChatCompletionsRequest(chat_req),
.map_err(|e| ProviderRequestError { SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e), ) => {
source: Some(Box::new(e)) let messages_req =
MessagesRequest::try_from(chat_req).map_err(|e| ProviderRequestError {
message: format!(
"Failed to convert ChatCompletionsRequest to MessagesRequest: {}",
e
),
source: Some(Box::new(e)),
})?; })?;
Ok(ProviderRequestType::MessagesRequest(messages_req)) Ok(ProviderRequestType::MessagesRequest(messages_req))
} }
(ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => { (
let chat_req = ChatCompletionsRequest::try_from(messages_req) ProviderRequestType::MessagesRequest(messages_req),
.map_err(|e| ProviderRequestError { SupportedUpstreamAPIs::OpenAIChatCompletions(_),
message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e), ) => {
source: Some(Box::new(e)) let chat_req = ChatCompletionsRequest::try_from(messages_req).map_err(|e| {
})?; ProviderRequestError {
message: format!(
"Failed to convert MessagesRequest to ChatCompletionsRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
} }
// Cross-API conversions: OpenAI/Anthropic to Amazon Bedrock // Cross-API conversions: OpenAI/Anthropic to Amazon Bedrock
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AmazonBedrockConverse(_)) => { (
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
) => {
let bedrock_req = ConverseRequest::try_from(chat_req) let bedrock_req = ConverseRequest::try_from(chat_req)
.map_err(|e| ProviderRequestError { .map_err(|e| ProviderRequestError {
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e), message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
@ -180,7 +203,10 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
Ok(ProviderRequestType::BedrockConverse(bedrock_req)) Ok(ProviderRequestType::BedrockConverse(bedrock_req))
} }
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => { (
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
) => {
let bedrock_req = ConverseStreamRequest::try_from(chat_req) let bedrock_req = ConverseStreamRequest::try_from(chat_req)
.map_err(|e| ProviderRequestError { .map_err(|e| ProviderRequestError {
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e), message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
@ -188,20 +214,33 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
})?; })?;
Ok(ProviderRequestType::BedrockConverse(bedrock_req)) Ok(ProviderRequestType::BedrockConverse(bedrock_req))
} }
(ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AmazonBedrockConverse(_)) => { (
let bedrock_req = ConverseRequest::try_from(messages_req) ProviderRequestType::MessagesRequest(messages_req),
.map_err(|e| ProviderRequestError { SupportedUpstreamAPIs::AmazonBedrockConverse(_),
message: format!("Failed to convert MessagesRequest to Amazon Bedrock request: {}", e), ) => {
source: Some(Box::new(e)) let bedrock_req =
ConverseRequest::try_from(messages_req).map_err(|e| ProviderRequestError {
message: format!(
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
e
),
source: Some(Box::new(e)),
})?; })?;
Ok(ProviderRequestType::BedrockConverse(bedrock_req)) Ok(ProviderRequestType::BedrockConverse(bedrock_req))
} }
(ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => { (
let bedrock_req = ConverseStreamRequest::try_from(messages_req) ProviderRequestType::MessagesRequest(messages_req),
.map_err(|e| ProviderRequestError { SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
message: format!("Failed to convert MessagesRequest to Amazon Bedrock request: {}", e), ) => {
source: Some(Box::new(e)) let bedrock_req = ConverseStreamRequest::try_from(messages_req).map_err(|e| {
})?; ProviderRequestError {
message: format!(
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::BedrockConverse(bedrock_req)) Ok(ProviderRequestType::BedrockConverse(bedrock_req))
} }
@ -213,13 +252,10 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
(ProviderRequestType::BedrockConverseStream(_), _) => { (ProviderRequestType::BedrockConverseStream(_), _) => {
todo!("Amazon Bedrock Stream to ChatCompletionsRequest conversion not implemented yet") todo!("Amazon Bedrock Stream to ChatCompletionsRequest conversion not implemented yet")
} }
} }
} }
} }
/// Error types for provider operations /// Error types for provider operations
#[derive(Debug)] #[derive(Debug)]
pub struct ProviderRequestError { pub struct ProviderRequestError {
@ -235,19 +271,20 @@ impl fmt::Display for ProviderRequestError {
impl Error for ProviderRequestError { impl Error for ProviderRequestError {
fn source(&self) -> Option<&(dyn Error + 'static)> { fn source(&self) -> Option<&(dyn Error + 'static)> {
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) self.source
.as_ref()
.map(|e| e.as_ref() as &(dyn Error + 'static))
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::clients::endpoints::SupportedAPIs;
use crate::apis::anthropic::AnthropicApi::Messages; use crate::apis::anthropic::AnthropicApi::Messages;
use crate::apis::openai::OpenAIApi::ChatCompletions;
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest; use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
use crate::apis::openai::{ChatCompletionsRequest}; use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::openai::OpenAIApi::ChatCompletions;
use crate::clients::endpoints::SupportedAPIs;
use crate::transforms::lib::ExtractText; use crate::transforms::lib::ExtractText;
use serde_json::json; use serde_json::json;
@ -268,7 +305,7 @@ mod tests {
ProviderRequestType::ChatCompletionsRequest(r) => { ProviderRequestType::ChatCompletionsRequest(r) => {
assert_eq!(r.model, "gpt-4"); assert_eq!(r.model, "gpt-4");
assert_eq!(r.messages.len(), 2); assert_eq!(r.messages.len(), 2);
}, }
_ => panic!("Expected ChatCompletionsRequest variant"), _ => panic!("Expected ChatCompletionsRequest variant"),
} }
} }
@ -291,7 +328,7 @@ mod tests {
ProviderRequestType::MessagesRequest(r) => { ProviderRequestType::MessagesRequest(r) => {
assert_eq!(r.model, "claude-3-sonnet"); assert_eq!(r.model, "claude-3-sonnet");
assert_eq!(r.messages.len(), 1); assert_eq!(r.messages.len(), 1);
}, }
_ => panic!("Expected MessagesRequest variant"), _ => panic!("Expected MessagesRequest variant"),
} }
} }
@ -313,7 +350,7 @@ mod tests {
ProviderRequestType::ChatCompletionsRequest(r) => { ProviderRequestType::ChatCompletionsRequest(r) => {
assert_eq!(r.model, "gpt-4"); assert_eq!(r.model, "gpt-4");
assert_eq!(r.messages.len(), 2); assert_eq!(r.messages.len(), 2);
}, }
_ => panic!("Expected ChatCompletionsRequest variant"), _ => panic!("Expected ChatCompletionsRequest variant"),
} }
} }
@ -337,7 +374,7 @@ mod tests {
ProviderRequestType::ChatCompletionsRequest(r) => { ProviderRequestType::ChatCompletionsRequest(r) => {
assert_eq!(r.model, "claude-3-sonnet"); assert_eq!(r.model, "claude-3-sonnet");
assert_eq!(r.messages.len(), 1); assert_eq!(r.messages.len(), 1);
}, }
_ => panic!("Expected ChatCompletionsRequest variant"), _ => panic!("Expected ChatCompletionsRequest variant"),
} }
} }
@ -346,13 +383,15 @@ mod tests {
fn test_v1_messages_to_v1_chat_completions_roundtrip() { fn test_v1_messages_to_v1_chat_completions_roundtrip() {
let anthropic_req = AnthropicMessagesRequest { let anthropic_req = AnthropicMessagesRequest {
model: "claude-3-sonnet".to_string(), model: "claude-3-sonnet".to_string(),
system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single("You are a helpful assistant".to_string())), system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single(
messages: vec![ "You are a helpful assistant".to_string(),
crate::apis::anthropic::MessagesMessage { )),
role: crate::apis::anthropic::MessagesRole::User, messages: vec![crate::apis::anthropic::MessagesMessage {
content: crate::apis::anthropic::MessagesMessageContent::Single("Hello!".to_string()), role: crate::apis::anthropic::MessagesRole::User,
} content: crate::apis::anthropic::MessagesMessageContent::Single(
], "Hello!".to_string(),
),
}],
max_tokens: 128, max_tokens: 128,
container: None, container: None,
mcp_servers: None, mcp_servers: None,
@ -368,16 +407,27 @@ mod tests {
metadata: None, metadata: None,
}; };
let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone()).expect("Anthropic->OpenAI conversion failed"); let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone())
let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req).expect("OpenAI->Anthropic conversion failed"); .expect("Anthropic->OpenAI conversion failed");
let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req)
.expect("OpenAI->Anthropic conversion failed");
assert_eq!(anthropic_req.model, anthropic_req2.model); assert_eq!(anthropic_req.model, anthropic_req2.model);
// Compare system prompt text if present // Compare system prompt text if present
assert_eq!( assert_eq!(
anthropic_req.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None }), anthropic_req.system.as_ref().and_then(|s| match s {
anthropic_req2.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None }) crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t),
_ => None,
}),
anthropic_req2.system.as_ref().and_then(|s| match s {
crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t),
_ => None,
})
);
assert_eq!(
anthropic_req.messages[0].role,
anthropic_req2.messages[0].role
); );
assert_eq!(anthropic_req.messages[0].role, anthropic_req2.messages[0].role);
// Compare message content text if present // Compare message content text if present
assert_eq!( assert_eq!(
anthropic_req.messages[0].content.extract_text(), anthropic_req.messages[0].content.extract_text(),
@ -386,49 +436,54 @@ mod tests {
assert_eq!(anthropic_req.max_tokens, anthropic_req2.max_tokens); assert_eq!(anthropic_req.max_tokens, anthropic_req2.max_tokens);
} }
#[test] #[test]
fn test_v1_chat_completions_to_v1_messages_roundtrip() { fn test_v1_chat_completions_to_v1_messages_roundtrip() {
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest; use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
use crate::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent}; use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
let openai_req = ChatCompletionsRequest { let openai_req = ChatCompletionsRequest {
model: "gpt-4".to_string(), model: "gpt-4".to_string(),
messages: vec![ messages: vec![
Message { Message {
role: Role::System, role: Role::System,
content: MessageContent::Text("You are a helpful assistant".to_string()), content: MessageContent::Text("You are a helpful assistant".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None, tool_call_id: None,
}, },
Message { Message {
role: Role::User, role: Role::User,
content: MessageContent::Text("Hello!".to_string()), content: MessageContent::Text("Hello!".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None, tool_call_id: None,
} },
], ],
temperature: Some(0.7), temperature: Some(0.7),
top_p: Some(1.0), top_p: Some(1.0),
max_tokens: Some(128), max_tokens: Some(128),
stream: Some(false), stream: Some(false),
stop: Some(vec!["\n".to_string()]), stop: Some(vec!["\n".to_string()]),
tools: None, tools: None,
tool_choice: None, tool_choice: None,
parallel_tool_calls: None, parallel_tool_calls: None,
..Default::default() ..Default::default()
}; };
let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone()).expect("OpenAI->Anthropic conversion failed"); let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone())
let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req).expect("Anthropic->OpenAI conversion failed"); .expect("OpenAI->Anthropic conversion failed");
let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req)
.expect("Anthropic->OpenAI conversion failed");
assert_eq!(openai_req.model, openai_req2.model); assert_eq!(openai_req.model, openai_req2.model);
assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role); assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role);
assert_eq!(openai_req.messages[0].content.extract_text(), openai_req2.messages[0].content.extract_text()); assert_eq!(
// After roundtrip, deprecated max_tokens should be converted to max_completion_tokens openai_req.messages[0].content.extract_text(),
let original_max_tokens = openai_req.max_completion_tokens.or(openai_req.max_tokens); openai_req2.messages[0].content.extract_text()
let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens); );
assert_eq!(original_max_tokens, roundtrip_max_tokens); // After roundtrip, deprecated max_tokens should be converted to max_completion_tokens
} let original_max_tokens = openai_req.max_completion_tokens.or(openai_req.max_tokens);
let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens);
assert_eq!(original_max_tokens, roundtrip_max_tokens);
}
} }

View file

@ -1,18 +1,18 @@
use serde::Serialize; use serde::Serialize;
use std::convert::TryFrom;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
use std::convert::TryFrom;
use crate::clients::endpoints::SupportedUpstreamAPIs;
use crate::clients::endpoints::SupportedAPIs;
use crate::providers::id::ProviderId;
use crate::apis::sse::SseEvent;
use crate::apis::openai::ChatCompletionsResponse;
use crate::apis::openai::ChatCompletionsStreamResponse;
use crate::apis::anthropic::MessagesStreamEvent;
use crate::apis::anthropic::MessagesResponse;
use crate::apis::amazon_bedrock::ConverseResponse; use crate::apis::amazon_bedrock::ConverseResponse;
use crate::apis::amazon_bedrock::ConverseStreamEvent; use crate::apis::amazon_bedrock::ConverseStreamEvent;
use crate::apis::anthropic::MessagesResponse;
use crate::apis::anthropic::MessagesStreamEvent;
use crate::apis::openai::ChatCompletionsResponse;
use crate::apis::openai::ChatCompletionsStreamResponse;
use crate::apis::sse::SseEvent;
use crate::clients::endpoints::SupportedAPIs;
use crate::clients::endpoints::SupportedUpstreamAPIs;
use crate::providers::id::ProviderId;
/// Trait for token usage information /// Trait for token usage information
pub trait TokenUsage { pub trait TokenUsage {
@ -34,7 +34,6 @@ pub enum ProviderStreamResponseType {
ChatCompletionsStreamResponse(ChatCompletionsStreamResponse), ChatCompletionsStreamResponse(ChatCompletionsStreamResponse),
MessagesStreamEvent(MessagesStreamEvent), MessagesStreamEvent(MessagesStreamEvent),
ConverseStreamEvent(ConverseStreamEvent), ConverseStreamEvent(ConverseStreamEvent),
} }
pub trait ProviderResponse: Send + Sync { pub trait ProviderResponse: Send + Sync {
@ -43,7 +42,8 @@ pub trait ProviderResponse: Send + Sync {
/// Extract token counts for metrics /// Extract token counts for metrics
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens())) self.usage()
.map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
} }
} }
@ -74,7 +74,6 @@ pub trait ProviderStreamResponse: Send + Sync {
/// Get event type for SSE streaming (used by Anthropic) /// Get event type for SSE streaming (used by Anthropic)
fn event_type(&self) -> Option<&str>; fn event_type(&self) -> Option<&str>;
} }
impl ProviderStreamResponse for ProviderStreamResponseType { impl ProviderStreamResponse for ProviderStreamResponseType {
@ -135,59 +134,97 @@ impl Into<String> for ProviderStreamResponseType {
impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
type Error = std::io::Error; type Error = std::io::Error;
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result<Self, Self::Error> { fn try_from(
(bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId),
) -> Result<Self, Self::Error> {
let upstream_api = provider_id.compatible_api_for_client(client_api, false); let upstream_api = provider_id.compatible_api_for_client(client_api, false);
match (&upstream_api, client_api) { match (&upstream_api, client_api) {
(SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { (
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
SupportedAPIs::OpenAIChatCompletions(_),
) => {
let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::ChatCompletionsResponse(resp)) Ok(ProviderResponseType::ChatCompletionsResponse(resp))
} }
(SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { (
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
let resp: MessagesResponse = serde_json::from_slice(bytes) let resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::MessagesResponse(resp)) Ok(ProviderResponseType::MessagesResponse(resp))
} }
(SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { (
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
SupportedAPIs::OpenAIChatCompletions(_),
) => {
let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes) let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to OpenAI ChatCompletions format using the transformer // Transform to OpenAI ChatCompletions format using the transformer
let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into() let chat_resp: ChatCompletionsResponse =
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; anthropic_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp)) Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
} }
(SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { (
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to Anthropic Messages format using the transformer // Transform to Anthropic Messages format using the transformer
let messages_resp: MessagesResponse = openai_resp.try_into() let messages_resp: MessagesResponse = openai_resp.try_into().map_err(|e| {
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::MessagesResponse(messages_resp)) Ok(ProviderResponseType::MessagesResponse(messages_resp))
} }
// Amazon Bedrock transformations // Amazon Bedrock transformations
(SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::OpenAIChatCompletions(_)) => { (
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
SupportedAPIs::OpenAIChatCompletions(_),
) => {
let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes) let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to OpenAI ChatCompletions format using the transformer // Transform to OpenAI ChatCompletions format using the transformer
let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into() let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into().map_err(|e| {
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp)) Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
} }
(SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { (
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes) let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to Anthropic Messages format using the transformer // Transform to Anthropic Messages format using the transformer
let messages_resp: MessagesResponse = bedrock_resp.try_into() let messages_resp: MessagesResponse = bedrock_resp.try_into().map_err(|e| {
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::MessagesResponse(messages_resp)) Ok(ProviderResponseType::MessagesResponse(messages_resp))
} }
_ => { _ => Err(std::io::Error::new(
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Unsupported API combination for response transformation")) std::io::ErrorKind::InvalidData,
} "Unsupported API combination for response transformation",
)),
} }
} }
} }
@ -196,55 +233,86 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
impl TryFrom<(&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)> for ProviderStreamResponseType { impl TryFrom<(&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)> for ProviderStreamResponseType {
type Error = Box<dyn std::error::Error + Send + Sync>; type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from((bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result<Self, Self::Error> { fn try_from(
(bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedUpstreamAPIs),
) -> Result<Self, Self::Error> {
// Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion // Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion
if bytes == b"[DONE]" && matches!(client_api, SupportedAPIs::AnthropicMessagesAPI(_)) { if bytes == b"[DONE]" && matches!(client_api, SupportedAPIs::AnthropicMessagesAPI(_)) {
return Ok(ProviderStreamResponseType::MessagesStreamEvent( return Ok(ProviderStreamResponseType::MessagesStreamEvent(
crate::apis::anthropic::MessagesStreamEvent::MessageStop crate::apis::anthropic::MessagesStreamEvent::MessageStop,
)); ));
} }
match (upstream_api, client_api) { match (upstream_api, client_api) {
// OpenAI upstream // OpenAI upstream
(SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { (
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
SupportedAPIs::OpenAIChatCompletions(_),
) => {
let resp = serde_json::from_slice(bytes)?; let resp = serde_json::from_slice(bytes)?;
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(resp)) Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
resp,
))
} }
(SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { (
let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?; SupportedUpstreamAPIs::OpenAIChatCompletions(_),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse =
serde_json::from_slice(bytes)?;
let anthropic_resp = openai_resp.try_into()?; let anthropic_resp = openai_resp.try_into()?;
Ok(ProviderStreamResponseType::MessagesStreamEvent(anthropic_resp)) Ok(ProviderStreamResponseType::MessagesStreamEvent(
anthropic_resp,
))
} }
// Anthropic upstream // Anthropic upstream
(SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { (
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
let resp = serde_json::from_slice(bytes)?; let resp = serde_json::from_slice(bytes)?;
Ok(ProviderStreamResponseType::MessagesStreamEvent(resp)) Ok(ProviderStreamResponseType::MessagesStreamEvent(resp))
} }
(SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { (
let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?; SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
SupportedAPIs::OpenAIChatCompletions(_),
) => {
let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent =
serde_json::from_slice(bytes)?;
let openai_resp = anthropic_resp.try_into()?; let openai_resp = anthropic_resp.try_into()?;
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(openai_resp)) Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
openai_resp,
))
} }
// Amazon Bedrock ConverseStream upstream // Amazon Bedrock ConverseStream upstream
(SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { (
let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent = serde_json::from_slice(bytes)?; SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent =
serde_json::from_slice(bytes)?;
let anthropic_resp = bedrock_resp.try_into()?; let anthropic_resp = bedrock_resp.try_into()?;
Ok(ProviderStreamResponseType::MessagesStreamEvent(anthropic_resp)) Ok(ProviderStreamResponseType::MessagesStreamEvent(
} anthropic_resp,
_ => { ))
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Unsupported API combination for response transformation").into())
} }
_ => Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Unsupported API combination for response transformation",
)
.into()),
} }
} }
} }
// TryFrom implementation to convert raw bytes to SseEvent with parsed provider response // TryFrom implementation to convert raw bytes to SseEvent with parsed provider response
impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent { impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
type Error = Box<dyn std::error::Error + Send + Sync>; type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from((sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result<Self, Self::Error> { fn try_from(
(sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs),
) -> Result<Self, Self::Error> {
// Create a new transformed event based on the original // Create a new transformed event based on the original
let mut transformed_event = sse_event; let mut transformed_event = sse_event;
@ -252,7 +320,8 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
if transformed_event.data.is_some() { if transformed_event.data.is_some() {
let data_str = transformed_event.data.as_ref().unwrap(); let data_str = transformed_event.data.as_ref().unwrap();
let data_bytes = data_str.as_bytes(); let data_bytes = data_str.as_bytes();
let transformed_response: ProviderStreamResponseType = ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?; let transformed_response: ProviderStreamResponseType =
ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?;
// Convert to SSE string explicitly to avoid type ambiguity // Convert to SSE string explicitly to avoid type ambiguity
let sse_string: String = transformed_response.clone().into(); let sse_string: String = transformed_response.clone().into();
@ -261,7 +330,10 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
} }
match (client_api, upstream_api) { match (client_api, upstream_api) {
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => { (
SupportedAPIs::AnthropicMessagesAPI(_),
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
) => {
if let Some(provider_response) = &transformed_event.provider_stream_response { if let Some(provider_response) = &transformed_event.provider_stream_response {
if let Some(event_type) = provider_response.event_type() { if let Some(event_type) = provider_response.event_type() {
// This ensures the required Anthropic sequence: MessageStart → ContentBlockStart → ContentBlockDelta(s) // This ensures the required Anthropic sequence: MessageStart → ContentBlockStart → ContentBlockDelta(s)
@ -280,19 +352,18 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
// The sse_transform_buffer already contains the properly formatted MessageStart // The sse_transform_buffer already contains the properly formatted MessageStart
transformed_event.sse_transform_buffer = format!( transformed_event.sse_transform_buffer = format!(
"{}{}", "{}{}",
transformed_event.sse_transform_buffer, transformed_event.sse_transform_buffer, content_block_start_sse,
content_block_start_sse,
); );
} else if event_type == "message_delta" { } else if event_type == "message_delta" {
// Create ContentBlockStop event and format it using Into<String> // Create ContentBlockStop event and format it using Into<String>
let content_block_stop = MessagesStreamEvent::ContentBlockStop { index: 0 }; let content_block_stop =
MessagesStreamEvent::ContentBlockStop { index: 0 };
let content_block_stop_sse: String = content_block_stop.into(); let content_block_stop_sse: String = content_block_stop.into();
// Format as proper SSE: ContentBlockStop first, then MessageDelta // Format as proper SSE: ContentBlockStop first, then MessageDelta
transformed_event.sse_transform_buffer = format!( transformed_event.sse_transform_buffer = format!(
"{}{}", "{}{}",
content_block_stop_sse, content_block_stop_sse, transformed_event.sse_transform_buffer
transformed_event.sse_transform_buffer
); );
} }
// For other event types, the sse_transform_buffer already has the correct format from Into<String> // For other event types, the sse_transform_buffer already has the correct format from Into<String>
@ -301,7 +372,10 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
// This handles cases where the transformation might not produce a valid event type // This handles cases where the transformation might not produce a valid event type
} }
} }
(SupportedAPIs::OpenAIChatCompletions(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => { (
SupportedAPIs::OpenAIChatCompletions(_),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) => {
if transformed_event.is_event_only() && transformed_event.event.is_some() { if transformed_event.is_event_only() && transformed_event.event.is_some() {
transformed_event.sse_transform_buffer = format!("\n"); // suppress the event upstream for OpenAI transformed_event.sse_transform_buffer = format!("\n"); // suppress the event upstream for OpenAI
} }
@ -316,32 +390,56 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
} }
// TryFrom implementation to convert AWS Event Stream DecodedFrame to ProviderStreamResponseType // TryFrom implementation to convert AWS Event Stream DecodedFrame to ProviderStreamResponseType
impl TryFrom<(&aws_smithy_eventstream::frame::DecodedFrame, &SupportedAPIs, &SupportedUpstreamAPIs)> for ProviderStreamResponseType { impl
TryFrom<(
&aws_smithy_eventstream::frame::DecodedFrame,
&SupportedAPIs,
&SupportedUpstreamAPIs,
)> for ProviderStreamResponseType
{
type Error = Box<dyn std::error::Error + Send + Sync>; type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from((frame, client_api, upstream_api): (&aws_smithy_eventstream::frame::DecodedFrame, &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result<Self, Self::Error> { fn try_from(
(frame, client_api, upstream_api): (
&aws_smithy_eventstream::frame::DecodedFrame,
&SupportedAPIs,
&SupportedUpstreamAPIs,
),
) -> Result<Self, Self::Error> {
use aws_smithy_eventstream::frame::DecodedFrame; use aws_smithy_eventstream::frame::DecodedFrame;
match frame { match frame {
DecodedFrame::Complete(_) => { DecodedFrame::Complete(_) => {
// We have a complete frame - parse it based on upstream API // We have a complete frame - parse it based on upstream API
match (upstream_api, client_api) { match (upstream_api, client_api) {
(SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { (
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
// Parse the DecodedFrame into ConverseStreamEvent // Parse the DecodedFrame into ConverseStreamEvent
let bedrock_event = crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?; let bedrock_event =
let anthropic_event: crate::apis::anthropic::MessagesStreamEvent = bedrock_event.try_into()?; crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
let anthropic_event: crate::apis::anthropic::MessagesStreamEvent =
bedrock_event.try_into()?;
Ok(ProviderStreamResponseType::MessagesStreamEvent(anthropic_event)) Ok(ProviderStreamResponseType::MessagesStreamEvent(
anthropic_event,
))
} }
(SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::OpenAIChatCompletions(_)) => { (
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
SupportedAPIs::OpenAIChatCompletions(_),
) => {
// Parse the DecodedFrame into ConverseStreamEvent // Parse the DecodedFrame into ConverseStreamEvent
let bedrock_event = crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?; let bedrock_event =
let openai_event: crate::apis::openai::ChatCompletionsStreamResponse = bedrock_event.try_into()?; crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(openai_event)) let openai_event: crate::apis::openai::ChatCompletionsStreamResponse =
} bedrock_event.try_into()?;
_ => { Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
Err("Unsupported API combination for event-stream decoding".into()) openai_event,
))
} }
_ => Err("Unsupported API combination for event-stream decoding".into()),
} }
} }
DecodedFrame::Incomplete => { DecodedFrame::Incomplete => {
@ -351,15 +449,12 @@ impl TryFrom<(&aws_smithy_eventstream::frame::DecodedFrame, &SupportedAPIs, &Sup
} }
} }
#[derive(Debug)] #[derive(Debug)]
pub struct ProviderResponseError { pub struct ProviderResponseError {
pub message: String, pub message: String,
pub source: Option<Box<dyn Error + Send + Sync>>, pub source: Option<Box<dyn Error + Send + Sync>>,
} }
impl fmt::Display for ProviderResponseError { impl fmt::Display for ProviderResponseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Provider response error: {}", self.message) write!(f, "Provider response error: {}", self.message)
@ -368,22 +463,23 @@ impl fmt::Display for ProviderResponseError {
impl Error for ProviderResponseError { impl Error for ProviderResponseError {
fn source(&self) -> Option<&(dyn Error + 'static)> { fn source(&self) -> Option<&(dyn Error + 'static)> {
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) self.source
.as_ref()
.map(|e| e.as_ref() as &(dyn Error + 'static))
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::apis::sse::SseStreamIter;
use crate::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; use crate::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
use crate::apis::sse::SseStreamIter;
use crate::clients::endpoints::SupportedAPIs; use crate::clients::endpoints::SupportedAPIs;
use crate::providers::id::ProviderId; use crate::providers::id::ProviderId;
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi;
use serde_json::json; use serde_json::json;
#[test] #[test]
fn test_openai_response_from_bytes() { fn test_openai_response_from_bytes() {
let resp = json!({ let resp = json!({
@ -402,13 +498,17 @@ mod tests {
"system_fingerprint": null "system_fingerprint": null
}); });
let bytes = serde_json::to_vec(&resp).unwrap(); let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::OpenAI)); let result = ProviderResponseType::try_from((
bytes.as_slice(),
&SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
&ProviderId::OpenAI,
));
assert!(result.is_ok()); assert!(result.is_ok());
match result.unwrap() { match result.unwrap() {
ProviderResponseType::ChatCompletionsResponse(r) => { ProviderResponseType::ChatCompletionsResponse(r) => {
assert_eq!(r.model, "gpt-4"); assert_eq!(r.model, "gpt-4");
assert_eq!(r.choices.len(), 1); assert_eq!(r.choices.len(), 1);
}, }
_ => panic!("Expected ChatCompletionsResponse variant"), _ => panic!("Expected ChatCompletionsResponse variant"),
} }
} }
@ -427,13 +527,17 @@ mod tests {
"usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 } "usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 }
}); });
let bytes = serde_json::to_vec(&resp).unwrap(); let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::Anthropic)); let result = ProviderResponseType::try_from((
bytes.as_slice(),
&SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
&ProviderId::Anthropic,
));
assert!(result.is_ok()); assert!(result.is_ok());
match result.unwrap() { match result.unwrap() {
ProviderResponseType::MessagesResponse(r) => { ProviderResponseType::MessagesResponse(r) => {
assert_eq!(r.model, "claude-3-sonnet-20240229"); assert_eq!(r.model, "claude-3-sonnet-20240229");
assert_eq!(r.content.len(), 1); assert_eq!(r.content.len(), 1);
}, }
_ => panic!("Expected MessagesResponse variant"), _ => panic!("Expected MessagesResponse variant"),
} }
} }
@ -457,14 +561,18 @@ mod tests {
"usage": { "prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35 } "usage": { "prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35 }
}); });
let bytes = serde_json::to_vec(&resp).unwrap(); let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI)); let result = ProviderResponseType::try_from((
bytes.as_slice(),
&SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
&ProviderId::OpenAI,
));
assert!(result.is_ok()); assert!(result.is_ok());
match result.unwrap() { match result.unwrap() {
ProviderResponseType::MessagesResponse(r) => { ProviderResponseType::MessagesResponse(r) => {
assert_eq!(r.model, "gpt-4"); assert_eq!(r.model, "gpt-4");
assert_eq!(r.usage.input_tokens, 10); assert_eq!(r.usage.input_tokens, 10);
assert_eq!(r.usage.output_tokens, 25); assert_eq!(r.usage.output_tokens, 25);
}, }
_ => panic!("Expected MessagesResponse variant"), _ => panic!("Expected MessagesResponse variant"),
} }
} }
@ -495,14 +603,18 @@ mod tests {
} }
}); });
let bytes = serde_json::to_vec(&resp).unwrap(); let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::Anthropic)); let result = ProviderResponseType::try_from((
bytes.as_slice(),
&SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
&ProviderId::Anthropic,
));
assert!(result.is_ok()); assert!(result.is_ok());
match result.unwrap() { match result.unwrap() {
ProviderResponseType::ChatCompletionsResponse(r) => { ProviderResponseType::ChatCompletionsResponse(r) => {
assert_eq!(r.model, "claude-3-sonnet-20240229"); assert_eq!(r.model, "claude-3-sonnet-20240229");
assert_eq!(r.usage.prompt_tokens, 10); assert_eq!(r.usage.prompt_tokens, 10);
assert_eq!(r.usage.completion_tokens, 25); assert_eq!(r.usage.completion_tokens, 25);
}, }
_ => panic!("Expected ChatCompletionsResponse variant"), _ => panic!("Expected ChatCompletionsResponse variant"),
} }
} }
@ -514,11 +626,17 @@ mod tests {
let event: Result<SseEvent, _> = line.parse(); let event: Result<SseEvent, _> = line.parse();
assert!(event.is_ok()); assert!(event.is_ok());
let event = event.unwrap(); let event = event.unwrap();
assert_eq!(event.data, Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string())); assert_eq!(
event.data,
Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string())
);
// Test conversion back to line using Display trait // Test conversion back to line using Display trait
let wire_format = event.to_string(); let wire_format = event.to_string();
assert_eq!(wire_format, "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n"); assert_eq!(
wire_format,
"data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n"
);
// Test [DONE] marker - should be valid SSE event // Test [DONE] marker - should be valid SSE event
let done_line = "data: [DONE]"; let done_line = "data: [DONE]";
@ -550,10 +668,12 @@ mod tests {
event: None, event: None,
raw_line: r#"data: {"id":"test","object":"chat.completion.chunk"} raw_line: r#"data: {"id":"test","object":"chat.completion.chunk"}
"#.to_string(), "#
.to_string(),
sse_transform_buffer: r#"data: {"id":"test","object":"chat.completion.chunk"} sse_transform_buffer: r#"data: {"id":"test","object":"chat.completion.chunk"}
"#.to_string(), "#
.to_string(),
provider_stream_response: None, provider_stream_response: None,
}; };
@ -590,7 +710,8 @@ mod tests {
data: Some(r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string()), data: Some(r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string()),
event: Some("content_block_delta".to_string()), event: Some("content_block_delta".to_string()),
raw_line: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(), raw_line: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
sse_transform_buffer: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(), sse_transform_buffer: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#
.to_string(),
provider_stream_response: None, provider_stream_response: None,
}; };
assert!(!normal_event.should_skip()); assert!(!normal_event.should_skip());
@ -616,7 +737,7 @@ mod tests {
"data: {\"type\": \"ping\"}".to_string(), // This should be filtered out "data: {\"type\": \"ping\"}".to_string(), // This should be filtered out
"data: {\"id\": \"msg2\", \"object\": \"chat.completion.chunk\"}".to_string(), "data: {\"id\": \"msg2\", \"object\": \"chat.completion.chunk\"}".to_string(),
"data: {\"type\": \"ping\"}".to_string(), // This should be filtered out "data: {\"type\": \"ping\"}".to_string(), // This should be filtered out
"data: [DONE]".to_string(), // This should end the stream "data: [DONE]".to_string(), // This should end the stream
]; ];
let mut iter = SseStreamIter::new(test_lines.into_iter()); let mut iter = SseStreamIter::new(test_lines.into_iter());
@ -684,13 +805,15 @@ mod tests {
#[test] #[test]
fn test_provider_stream_response_event_type() { fn test_provider_stream_response_event_type() {
use crate::apis::anthropic::{MessagesStreamEvent, MessagesContentDelta}; use crate::apis::anthropic::{MessagesContentDelta, MessagesStreamEvent};
use crate::apis::openai::ChatCompletionsStreamResponse; use crate::apis::openai::ChatCompletionsStreamResponse;
// Test Anthropic event type // Test Anthropic event type
let anthropic_event = MessagesStreamEvent::ContentBlockDelta { let anthropic_event = MessagesStreamEvent::ContentBlockDelta {
index: 0, index: 0,
delta: MessagesContentDelta::TextDelta { text: "Hello".to_string() }, delta: MessagesContentDelta::TextDelta {
text: "Hello".to_string(),
},
}; };
let provider_type = ProviderStreamResponseType::MessagesStreamEvent(anthropic_event); let provider_type = ProviderStreamResponseType::MessagesStreamEvent(anthropic_event);
assert_eq!(provider_type.event_type(), Some("content_block_delta")); assert_eq!(provider_type.event_type(), Some("content_block_delta"));
@ -717,15 +840,24 @@ mod tests {
// Test that [DONE] marker is properly converted to MessageStop in the transformation layer // Test that [DONE] marker is properly converted to MessageStop in the transformation layer
let done_bytes = b"[DONE]"; let done_bytes = b"[DONE]";
let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::openai::OpenAIApi::ChatCompletions); let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(
crate::apis::openai::OpenAIApi::ChatCompletions,
);
let result = ProviderStreamResponseType::try_from((done_bytes.as_slice(), &client_api, &upstream_api)); let result = ProviderStreamResponseType::try_from((
done_bytes.as_slice(),
&client_api,
&upstream_api,
));
assert!(result.is_ok()); assert!(result.is_ok());
if let Ok(ProviderStreamResponseType::MessagesStreamEvent(event)) = result { if let Ok(ProviderStreamResponseType::MessagesStreamEvent(event)) = result {
// Verify it's a MessageStop event // Verify it's a MessageStop event
assert_eq!(event.event_type(), Some("message_stop")); assert_eq!(event.event_type(), Some("message_stop"));
assert!(matches!(event, crate::apis::anthropic::MessagesStreamEvent::MessageStop)); assert!(matches!(
event,
crate::apis::anthropic::MessagesStreamEvent::MessageStop
));
} else { } else {
panic!("Expected MessagesStreamEvent::MessageStop"); panic!("Expected MessagesStreamEvent::MessageStop");
} }
@ -747,7 +879,10 @@ mod tests {
// This signals the caller to wait for more data // This signals the caller to wait for more data
let result = decoder.decode_frame(); let result = decoder.decode_frame();
assert!(result.is_some()); assert!(result.is_some());
assert!(matches!(result.unwrap(), aws_smithy_eventstream::frame::DecodedFrame::Incomplete)); assert!(matches!(
result.unwrap(),
aws_smithy_eventstream::frame::DecodedFrame::Incomplete
));
// Verify we can still access the buffer // Verify we can still access the buffer
assert!(decoder.has_remaining()); assert!(decoder.has_remaining());
@ -760,8 +895,8 @@ mod tests {
use std::path::PathBuf; use std::path::PathBuf;
// Read the actual response.hex file from tests/e2e directory // Read the actual response.hex file from tests/e2e directory
let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")) let test_file =
.join("../../tests/e2e/response.hex"); PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
// Only run this test if the file exists // Only run this test if the file exists
if !test_file.exists() { if !test_file.exists() {
@ -782,7 +917,8 @@ mod tests {
frame_count += 1; frame_count += 1;
// Verify we can access headers // Verify we can access headers
let event_type = message.headers() let event_type = message
.headers()
.iter() .iter()
.find(|h| h.name().as_str() == ":event-type") .find(|h| h.name().as_str() == ":event-type")
.and_then(|h| h.value().as_string().ok()); .and_then(|h| h.value().as_string().ok());
@ -811,8 +947,8 @@ mod tests {
use std::path::PathBuf; use std::path::PathBuf;
// Read the actual response.hex file // Read the actual response.hex file
let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")) let test_file =
.join("../../tests/e2e/response.hex"); PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
if !test_file.exists() { if !test_file.exists() {
println!("Skipping test - response.hex not found"); println!("Skipping test - response.hex not found");
@ -863,7 +999,10 @@ mod tests {
} }
} }
assert!(total_frames > 0, "Should have decoded frames from chunked data"); assert!(
total_frames > 0,
"Should have decoded frames from chunked data"
);
} }
#[test] #[test]
@ -872,7 +1011,7 @@ mod tests {
} }
#[test] #[test]
#[ignore] // Run with: cargo test -- --ignored --nocapture #[ignore] // Run with: cargo test -- --ignored --nocapture
fn test_bedrock_decoded_frame_to_provider_response_verbose() { fn test_bedrock_decoded_frame_to_provider_response_verbose() {
test_bedrock_conversion(true); test_bedrock_conversion(true);
} }
@ -883,7 +1022,7 @@ mod tests {
} }
#[test] #[test]
#[ignore] // Run with: cargo test -- --ignored --nocapture #[ignore] // Run with: cargo test -- --ignored --nocapture
fn test_bedrock_decoded_frame_with_tool_use_verbose() { fn test_bedrock_decoded_frame_with_tool_use_verbose() {
test_bedrock_conversion_with_tools(true); test_bedrock_conversion_with_tools(true);
} }
@ -894,8 +1033,8 @@ mod tests {
use std::path::PathBuf; use std::path::PathBuf;
// Read the actual response.hex file from tests/e2e directory // Read the actual response.hex file from tests/e2e directory
let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")) let test_file =
.join("../../tests/e2e/response.hex"); PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
// Only run this test if the file exists // Only run this test if the file exists
if !test_file.exists() { if !test_file.exists() {
@ -908,8 +1047,11 @@ mod tests {
let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
let client_api = SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); let client_api =
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream); SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
);
let mut conversion_count = 0; let mut conversion_count = 0;
let mut message_start_seen = false; let mut message_start_seen = false;
@ -919,14 +1061,18 @@ mod tests {
match decoder.decode_frame() { match decoder.decode_frame() {
Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => {
// Convert DecodedFrame to ProviderStreamResponseType // Convert DecodedFrame to ProviderStreamResponseType
let result = ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api)); let result =
ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api));
match result { match result {
Ok(provider_response) => { Ok(provider_response) => {
conversion_count += 1; conversion_count += 1;
// Verify we got a MessagesStreamEvent // Verify we got a MessagesStreamEvent
assert!(matches!(provider_response, ProviderStreamResponseType::MessagesStreamEvent(_))); assert!(matches!(
provider_response,
ProviderStreamResponseType::MessagesStreamEvent(_)
));
if verbose { if verbose {
// Print the SSE string output // Print the SSE string output
@ -935,8 +1081,13 @@ mod tests {
} }
// Check for MessageStart event // Check for MessageStart event
if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = provider_response { if let ProviderStreamResponseType::MessagesStreamEvent(ref event) =
if matches!(event, crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. }) { provider_response
{
if matches!(
event,
crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. }
) {
message_start_seen = true; message_start_seen = true;
} }
} }
@ -956,7 +1107,10 @@ mod tests {
} }
} }
assert!(conversion_count > 0, "Should have converted at least one frame"); assert!(
conversion_count > 0,
"Should have converted at least one frame"
);
assert!(message_start_seen, "Should have seen MessageStart event"); assert!(message_start_seen, "Should have seen MessageStart event");
} }
@ -980,8 +1134,11 @@ mod tests {
let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
let client_api = SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); let client_api =
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream); SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
);
let mut conversion_count = 0; let mut conversion_count = 0;
let mut message_start_seen = false; let mut message_start_seen = false;
@ -993,14 +1150,18 @@ mod tests {
match decoder.decode_frame() { match decoder.decode_frame() {
Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => { Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => {
// Convert DecodedFrame to ProviderStreamResponseType // Convert DecodedFrame to ProviderStreamResponseType
let result = ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api)); let result =
ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api));
match result { match result {
Ok(provider_response) => { Ok(provider_response) => {
conversion_count += 1; conversion_count += 1;
// Verify we got a MessagesStreamEvent // Verify we got a MessagesStreamEvent
assert!(matches!(provider_response, ProviderStreamResponseType::MessagesStreamEvent(_))); assert!(matches!(
provider_response,
ProviderStreamResponseType::MessagesStreamEvent(_)
));
if verbose { if verbose {
// Print the SSE string output // Print the SSE string output
@ -1009,7 +1170,9 @@ mod tests {
} }
// Check for specific events related to tool use // Check for specific events related to tool use
if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = provider_response { if let ProviderStreamResponseType::MessagesStreamEvent(ref event) =
provider_response
{
match event { match event {
crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. } => { crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. } => {
message_start_seen = true; message_start_seen = true;
@ -1041,16 +1204,25 @@ mod tests {
} }
} }
assert!(conversion_count > 0, "Should have converted at least one frame"); assert!(
conversion_count > 0,
"Should have converted at least one frame"
);
assert!(message_start_seen, "Should have seen MessageStart event"); assert!(message_start_seen, "Should have seen MessageStart event");
assert!(content_block_start_seen, "Should have seen ContentBlockStart event for tool use"); assert!(
assert!(content_block_delta_tool_use_seen, "Should have seen ContentBlockDelta with ToolUseDelta"); content_block_start_seen,
"Should have seen ContentBlockStart event for tool use"
);
assert!(
content_block_delta_tool_use_seen,
"Should have seen ContentBlockDelta with ToolUseDelta"
);
} }
#[test] #[test]
fn test_sse_event_transformation_openai_to_anthropic_message_start() { fn test_sse_event_transformation_openai_to_anthropic_message_start() {
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi; use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
// Create an OpenAI stream response that represents a role start (which becomes message_start in Anthropic) // Create an OpenAI stream response that represents a role start (which becomes message_start in Anthropic)
let openai_stream_chunk = json!({ let openai_stream_chunk = json!({
@ -1085,8 +1257,14 @@ mod tests {
// Verify the transformation includes both message_start and content_block_start // Verify the transformation includes both message_start and content_block_start
let buffer = transformed.sse_transform_buffer; let buffer = transformed.sse_transform_buffer;
assert!(buffer.contains("event: message_start"), "Should contain message_start event"); assert!(
assert!(buffer.contains("event: content_block_start"), "Should contain content_block_start event"); buffer.contains("event: message_start"),
"Should contain message_start event"
);
assert!(
buffer.contains("event: content_block_start"),
"Should contain content_block_start event"
);
// Verify proper SSE format with event lines before data lines // Verify proper SSE format with event lines before data lines
assert!(buffer.find("event: message_start").unwrap() < buffer.find("data:").unwrap()); assert!(buffer.find("event: message_start").unwrap() < buffer.find("data:").unwrap());
@ -1095,8 +1273,8 @@ mod tests {
#[test] #[test]
fn test_sse_event_transformation_openai_to_anthropic_message_delta() { fn test_sse_event_transformation_openai_to_anthropic_message_delta() {
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi; use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
// Create an OpenAI stream response with finish_reason (which becomes message_delta in Anthropic) // Create an OpenAI stream response with finish_reason (which becomes message_delta in Anthropic)
let openai_stream_chunk = json!({ let openai_stream_chunk = json!({
@ -1136,19 +1314,28 @@ mod tests {
// Verify the transformation includes both content_block_stop and message_delta // Verify the transformation includes both content_block_stop and message_delta
let buffer = transformed.sse_transform_buffer; let buffer = transformed.sse_transform_buffer;
assert!(buffer.contains("event: content_block_stop"), "Should contain content_block_stop event"); assert!(
assert!(buffer.contains("event: message_delta"), "Should contain message_delta event"); buffer.contains("event: content_block_stop"),
"Should contain content_block_stop event"
);
assert!(
buffer.contains("event: message_delta"),
"Should contain message_delta event"
);
// Verify content_block_stop comes before message_delta // Verify content_block_stop comes before message_delta
let stop_pos = buffer.find("content_block_stop").unwrap(); let stop_pos = buffer.find("content_block_stop").unwrap();
let delta_pos = buffer.find("message_delta").unwrap(); let delta_pos = buffer.find("message_delta").unwrap();
assert!(stop_pos < delta_pos, "content_block_stop should come before message_delta"); assert!(
stop_pos < delta_pos,
"content_block_stop should come before message_delta"
);
} }
#[test] #[test]
fn test_sse_event_transformation_openai_to_anthropic_content_delta() { fn test_sse_event_transformation_openai_to_anthropic_content_delta() {
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi; use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
// Create an OpenAI stream response with content (which becomes content_block_delta in Anthropic) // Create an OpenAI stream response with content (which becomes content_block_delta in Anthropic)
let openai_stream_chunk = json!({ let openai_stream_chunk = json!({
@ -1183,9 +1370,18 @@ mod tests {
// Verify the transformation is a content_block_delta (no extra events injected) // Verify the transformation is a content_block_delta (no extra events injected)
let buffer = transformed.sse_transform_buffer; let buffer = transformed.sse_transform_buffer;
assert!(buffer.contains("event: content_block_delta"), "Should contain content_block_delta event"); assert!(
assert!(!buffer.contains("content_block_start"), "Should not inject content_block_start for content delta"); buffer.contains("event: content_block_delta"),
assert!(!buffer.contains("content_block_stop"), "Should not inject content_block_stop for content delta"); "Should contain content_block_delta event"
);
assert!(
!buffer.contains("content_block_start"),
"Should not inject content_block_start for content delta"
);
assert!(
!buffer.contains("content_block_stop"),
"Should not inject content_block_stop for content delta"
);
// Verify the content is preserved // Verify the content is preserved
assert!(buffer.contains("Hello"), "Should preserve the content text"); assert!(buffer.contains("Hello"), "Should preserve the content text");
@ -1193,8 +1389,8 @@ mod tests {
#[test] #[test]
fn test_sse_event_transformation_anthropic_to_openai_suppresses_event_lines() { fn test_sse_event_transformation_anthropic_to_openai_suppresses_event_lines() {
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi; use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
// Create an Anthropic event-only SSE line (no data) // Create an Anthropic event-only SSE line (no data)
let sse_event = SseEvent { let sse_event = SseEvent {
@ -1215,14 +1411,20 @@ mod tests {
let transformed = result.unwrap(); let transformed = result.unwrap();
// Verify the event line is suppressed (replaced with just newline) // Verify the event line is suppressed (replaced with just newline)
assert_eq!(transformed.sse_transform_buffer, "\n", "Event-only lines should be suppressed to newline for OpenAI"); assert_eq!(
assert!(transformed.is_event_only(), "Should still be marked as event-only"); transformed.sse_transform_buffer, "\n",
"Event-only lines should be suppressed to newline for OpenAI"
);
assert!(
transformed.is_event_only(),
"Should still be marked as event-only"
);
} }
#[test] #[test]
fn test_sse_event_transformation_anthropic_to_openai_preserves_data() { fn test_sse_event_transformation_anthropic_to_openai_preserves_data() {
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi; use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
// Create an Anthropic message_start event with data // Create an Anthropic message_start event with data
let anthropic_event = json!({ let anthropic_event = json!({
@ -1258,7 +1460,10 @@ mod tests {
// Verify data is transformed to OpenAI format // Verify data is transformed to OpenAI format
let buffer = transformed.sse_transform_buffer; let buffer = transformed.sse_transform_buffer;
assert!(buffer.starts_with("data: "), "Should have data: prefix"); assert!(buffer.starts_with("data: "), "Should have data: prefix");
assert!(!buffer.contains("event:"), "Should not have event: lines for OpenAI"); assert!(
!buffer.contains("event:"),
"Should not have event: lines for OpenAI"
);
// Verify provider response was parsed // Verify provider response was parsed
assert!(transformed.provider_stream_response.is_some()); assert!(transformed.provider_stream_response.is_some());
@ -1310,8 +1515,8 @@ mod tests {
#[test] #[test]
fn test_sse_event_transformation_preserves_provider_response() { fn test_sse_event_transformation_preserves_provider_response() {
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi; use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
// Create an OpenAI stream response // Create an OpenAI stream response
let openai_stream_chunk = json!({ let openai_stream_chunk = json!({
@ -1344,11 +1549,17 @@ mod tests {
let transformed = result.unwrap(); let transformed = result.unwrap();
// Verify provider_stream_response is populated // Verify provider_stream_response is populated
assert!(transformed.provider_stream_response.is_some(), "Should parse and store provider response"); assert!(
transformed.provider_stream_response.is_some(),
"Should parse and store provider response"
);
// Verify we can access the provider response // Verify we can access the provider response
let provider_response = transformed.provider_response(); let provider_response = transformed.provider_response();
assert!(provider_response.is_ok(), "Should be able to access provider response"); assert!(
provider_response.is_ok(),
"Should be able to access provider response"
);
// Verify the content delta is accessible // Verify the content delta is accessible
let content = provider_response.unwrap().content_delta(); let content = provider_response.unwrap().content_delta();

View file

@ -1,7 +1,7 @@
use serde_json::Value; use crate::apis::anthropic::{MessagesContentBlock, MessagesImageSource};
use crate::apis::anthropic::{MessagesContentBlock,MessagesImageSource};
use crate::apis::openai::{ContentPart, FunctionCall, ImageUrl, Message, MessageContent, ToolCall}; use crate::apis::openai::{ContentPart, FunctionCall, ImageUrl, Message, MessageContent, ToolCall};
use crate::clients::TransformError; use crate::clients::TransformError;
use serde_json::Value;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
pub trait ExtractText { pub trait ExtractText {
@ -11,12 +11,17 @@ pub trait ExtractText {
/// Trait for utility functions on content collections /// Trait for utility functions on content collections
pub trait ContentUtils<T> { pub trait ContentUtils<T> {
fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError>; fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError>;
fn split_for_openai(&self) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>; fn split_for_openai(
&self,
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
} }
/// Helper to create a current unix timestamp /// Helper to create a current unix timestamp
pub fn current_timestamp() -> u64 { pub fn current_timestamp() -> u64 {
SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
} }
// Content Utilities // Content Utilities
@ -26,24 +31,36 @@ impl ContentUtils<ToolCall> for Vec<MessagesContentBlock> {
for block in self { for block in self {
match block { match block {
MessagesContentBlock::ToolUse { id, name, input, .. } | MessagesContentBlock::ToolUse {
MessagesContentBlock::ServerToolUse { id, name, input } | id, name, input, ..
MessagesContentBlock::McpToolUse { id, name, input } => { }
| MessagesContentBlock::ServerToolUse { id, name, input }
| MessagesContentBlock::McpToolUse { id, name, input } => {
let arguments = serde_json::to_string(&input)?; let arguments = serde_json::to_string(&input)?;
tool_calls.push(ToolCall { tool_calls.push(ToolCall {
id: id.clone(), id: id.clone(),
call_type: "function".to_string(), call_type: "function".to_string(),
function: FunctionCall { name: name.clone(), arguments }, function: FunctionCall {
name: name.clone(),
arguments,
},
}); });
} }
_ => continue, _ => continue,
} }
} }
Ok(if tool_calls.is_empty() { None } else { Some(tool_calls) }) Ok(if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
})
} }
fn split_for_openai(&self) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError> { fn split_for_openai(
&self,
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>
{
let mut content_parts = Vec::new(); let mut content_parts = Vec::new();
let mut tool_calls = Vec::new(); let mut tool_calls = Vec::new();
let mut tool_results = Vec::new(); let mut tool_results = Vec::new();
@ -62,25 +79,55 @@ impl ContentUtils<ToolCall> for Vec<MessagesContentBlock> {
}, },
}); });
} }
MessagesContentBlock::ToolUse { id, name, input, .. } | MessagesContentBlock::ToolUse {
MessagesContentBlock::ServerToolUse { id, name, input } | id, name, input, ..
MessagesContentBlock::McpToolUse { id, name, input } => { }
| MessagesContentBlock::ServerToolUse { id, name, input }
| MessagesContentBlock::McpToolUse { id, name, input } => {
let arguments = serde_json::to_string(&input)?; let arguments = serde_json::to_string(&input)?;
tool_calls.push(ToolCall { tool_calls.push(ToolCall {
id: id.clone(), id: id.clone(),
call_type: "function".to_string(), call_type: "function".to_string(),
function: FunctionCall { name: name.clone(), arguments }, function: FunctionCall {
name: name.clone(),
arguments,
},
}); });
} }
MessagesContentBlock::ToolResult { tool_use_id, content, is_error, .. } => { MessagesContentBlock::ToolResult {
tool_use_id,
content,
is_error,
..
} => {
let result_text = content.extract_text(); let result_text = content.extract_text();
tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false))); tool_results.push((
tool_use_id.clone(),
result_text,
is_error.unwrap_or(false),
));
} }
MessagesContentBlock::WebSearchToolResult { tool_use_id, content, is_error } | MessagesContentBlock::WebSearchToolResult {
MessagesContentBlock::CodeExecutionToolResult { tool_use_id, content, is_error } | tool_use_id,
MessagesContentBlock::McpToolResult { tool_use_id, content, is_error } => { content,
is_error,
}
| MessagesContentBlock::CodeExecutionToolResult {
tool_use_id,
content,
is_error,
}
| MessagesContentBlock::McpToolResult {
tool_use_id,
content,
is_error,
} => {
let result_text = content.extract_text(); let result_text = content.extract_text();
tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false))); tool_results.push((
tool_use_id.clone(),
result_text,
is_error.unwrap_or(false),
));
} }
_ => { _ => {
// Skip unsupported content types // Skip unsupported content types
@ -122,29 +169,41 @@ fn convert_image_url_to_source(image_url: &ImageUrl) -> MessagesImageSource {
data: data.to_string(), data: data.to_string(),
} }
} else { } else {
MessagesImageSource::Url { url: image_url.url.clone() } MessagesImageSource::Url {
url: image_url.url.clone(),
}
} }
} else { } else {
MessagesImageSource::Url { url: image_url.url.clone() } MessagesImageSource::Url {
url: image_url.url.clone(),
}
} }
} }
/// Convert OpenAI message to Anthropic content blocks /// Convert OpenAI message to Anthropic content blocks
pub fn convert_openai_message_to_anthropic_content(message: &Message) -> Result<Vec<MessagesContentBlock>, TransformError> { pub fn convert_openai_message_to_anthropic_content(
message: &Message,
) -> Result<Vec<MessagesContentBlock>, TransformError> {
let mut blocks = Vec::new(); let mut blocks = Vec::new();
// Handle regular content // Handle regular content
match &message.content { match &message.content {
MessageContent::Text(text) => { MessageContent::Text(text) => {
if !text.is_empty() { if !text.is_empty() {
blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None }); blocks.push(MessagesContentBlock::Text {
text: text.clone(),
cache_control: None,
});
} }
} }
MessageContent::Parts(parts) => { MessageContent::Parts(parts) => {
for part in parts { for part in parts {
match part { match part {
ContentPart::Text { text } => { ContentPart::Text { text } => {
blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None }); blocks.push(MessagesContentBlock::Text {
text: text.clone(),
cache_control: None,
});
} }
ContentPart::ImageUrl { image_url } => { ContentPart::ImageUrl { image_url } => {
let source = convert_image_url_to_source(image_url); let source = convert_image_url_to_source(image_url);

View file

@ -8,14 +8,14 @@
//! by the gateway, but the external API surface remains these two standard formats. //! by the gateway, but the external API surface remains these two standard formats.
//! The transformations are split into logical modules for maintainability. //! The transformations are split into logical modules for maintainability.
pub mod lib;
pub mod request; pub mod request;
pub mod response; pub mod response;
pub mod lib;
// Re-export commonly used items for convenience // Re-export commonly used items for convenience
pub use lib::*;
pub use request::*; pub use request::*;
pub use response::*; pub use response::*;
pub use lib::*;
// ============================================================================ // ============================================================================
// CONSTANTS // CONSTANTS

View file

@ -1,14 +1,21 @@
use crate::transforms::lib::*;
use crate::clients::TransformError;
use crate::apis::anthropic::{MessagesMessage, MessagesRequest, MessagesMessageContent, MessagesRole, MessagesStopReason, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, MessagesUsage, MessagesSystemPrompt, ToolResultContent};
use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolCall, ToolChoice, ToolChoiceType, Function, FunctionChoice,FinishReason, Usage, ContentPart};
use crate::apis::amazon_bedrock::{ use crate::apis::amazon_bedrock::{
ConverseRequest, SystemContentBlock, InferenceConfiguration, ToolConfiguration, AnyChoice, AutoChoice, ContentBlock, ConversationRole, ConverseRequest, ImageBlock,
Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolInputSchema, ToolSpecDefinition, ImageSource, InferenceConfiguration, Message as BedrockMessage, SystemContentBlock,
AutoChoice, AnyChoice, ToolChoiceSpec, Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolChoiceSpec, ToolConfiguration,
Message as BedrockMessage, ConversationRole, ContentBlock, ToolInputSchema, ToolResultBlock, ToolResultContentBlock, ToolResultStatus, ToolSpecDefinition,
ToolUseBlock, ToolResultBlock, ToolResultContentBlock, ToolResultStatus, ImageBlock, ImageSource ToolUseBlock,
}; };
use crate::apis::anthropic::{
MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole, MessagesStopReason,
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, MessagesUsage,
ToolResultContent,
};
use crate::apis::openai::{
ChatCompletionsRequest, ContentPart, FinishReason, Function, FunctionChoice, Message,
MessageContent, Role, Tool, ToolCall, ToolChoice, ToolChoiceType, Usage,
};
use crate::clients::TransformError;
use crate::transforms::lib::*;
type AnthropicMessagesRequest = MessagesRequest; type AnthropicMessagesRequest = MessagesRequest;
@ -32,7 +39,8 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
// Convert tools and tool choice // Convert tools and tool choice
let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools)); let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
let (openai_tool_choice, parallel_tool_calls) = convert_anthropic_tool_choice(req.tool_choice); let (openai_tool_choice, parallel_tool_calls) =
convert_anthropic_tool_choice(req.tool_choice);
let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest { let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest {
model: req.model, model: req.model,
@ -91,7 +99,8 @@ impl TryFrom<AnthropicMessagesRequest> for ConverseRequest {
// Convert tools and tool choice to ToolConfiguration // Convert tools and tool choice to ToolConfiguration
let tool_config = if req.tools.is_some() || req.tool_choice.is_some() { let tool_config = if req.tools.is_some() || req.tool_choice.is_some() {
let tools = req.tools.map(|anthropic_tools| { let tools = req.tools.map(|anthropic_tools| {
anthropic_tools.into_iter() anthropic_tools
.into_iter()
.map(|tool| BedrockTool::ToolSpec { .map(|tool| BedrockTool::ToolSpec {
tool_spec: ToolSpecDefinition { tool_spec: ToolSpecDefinition {
name: tool.name, name: tool.name,
@ -106,25 +115,28 @@ impl TryFrom<AnthropicMessagesRequest> for ConverseRequest {
let tool_choice = req.tool_choice.map(|choice| { let tool_choice = req.tool_choice.map(|choice| {
match choice.kind { match choice.kind {
MessagesToolChoiceType::Auto => BedrockToolChoice::Auto { auto: AutoChoice {} }, MessagesToolChoiceType::Auto => BedrockToolChoice::Auto {
auto: AutoChoice {},
},
MessagesToolChoiceType::Any => BedrockToolChoice::Any { any: AnyChoice {} }, MessagesToolChoiceType::Any => BedrockToolChoice::Any { any: AnyChoice {} },
MessagesToolChoiceType::None => BedrockToolChoice::Auto { auto: AutoChoice {} }, // Bedrock doesn't have explicit "none" MessagesToolChoiceType::None => BedrockToolChoice::Auto {
auto: AutoChoice {},
}, // Bedrock doesn't have explicit "none"
MessagesToolChoiceType::Tool => { MessagesToolChoiceType::Tool => {
if let Some(name) = choice.name { if let Some(name) = choice.name {
BedrockToolChoice::Tool { BedrockToolChoice::Tool {
tool: ToolChoiceSpec { name } tool: ToolChoiceSpec { name },
} }
} else { } else {
BedrockToolChoice::Auto { auto: AutoChoice {} } BedrockToolChoice::Auto {
auto: AutoChoice {},
}
} }
} }
} }
}); });
Some(ToolConfiguration { Some(ToolConfiguration { tools, tool_choice })
tools,
tool_choice,
})
} else { } else {
None None
}; };
@ -147,7 +159,6 @@ impl TryFrom<AnthropicMessagesRequest> for ConverseRequest {
} }
} }
// Message Conversions // Message Conversions
impl TryFrom<MessagesMessage> for Vec<Message> { impl TryFrom<MessagesMessage> for Vec<Message> {
type Error = TransformError; type Error = TransformError;
@ -186,7 +197,11 @@ impl TryFrom<MessagesMessage> for Vec<Message> {
role: message.role.into(), role: message.role.into(),
content, content,
name: None, name: None,
tool_calls: if tool_calls.is_empty() { None } else { Some(tool_calls) }, tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
tool_call_id: None, tool_call_id: None,
}; };
result.push(main_message); result.push(main_message);
@ -198,7 +213,6 @@ impl TryFrom<MessagesMessage> for Vec<Message> {
} }
} }
// Role Conversions // Role Conversions
impl Into<Role> for MessagesRole { impl Into<Role> for MessagesRole {
fn into(self) -> Role { fn into(self) -> Role {
@ -237,9 +251,7 @@ impl Into<Message> for MessagesSystemPrompt {
fn into(self) -> Message { fn into(self) -> Message {
let system_content = match self { let system_content = match self {
MessagesSystemPrompt::Single(text) => MessageContent::Text(text), MessagesSystemPrompt::Single(text) => MessageContent::Text(text),
MessagesSystemPrompt::Blocks(blocks) => { MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()),
MessageContent::Text(blocks.extract_text())
}
}; };
Message { Message {
@ -255,7 +267,8 @@ impl Into<Message> for MessagesSystemPrompt {
//Utility Functions //Utility Functions
/// Convert Anthropic tools to OpenAI format /// Convert Anthropic tools to OpenAI format
fn convert_anthropic_tools(tools: Vec<MessagesTool>) -> Vec<Tool> { fn convert_anthropic_tools(tools: Vec<MessagesTool>) -> Vec<Tool> {
tools.into_iter() tools
.into_iter()
.map(|tool| Tool { .map(|tool| Tool {
tool_type: "function".to_string(), tool_type: "function".to_string(),
function: Function { function: Function {
@ -269,7 +282,9 @@ fn convert_anthropic_tools(tools: Vec<MessagesTool>) -> Vec<Tool> {
} }
/// Convert Anthropic tool choice to OpenAI format /// Convert Anthropic tool choice to OpenAI format
fn convert_anthropic_tool_choice(tool_choice: Option<MessagesToolChoice>) -> (Option<ToolChoice>, Option<bool>) { fn convert_anthropic_tool_choice(
tool_choice: Option<MessagesToolChoice>,
) -> (Option<ToolChoice>, Option<bool>) {
match tool_choice { match tool_choice {
Some(choice) => { Some(choice) => {
let openai_choice = match choice.kind { let openai_choice = match choice.kind {
@ -290,12 +305,15 @@ fn convert_anthropic_tool_choice(tool_choice: Option<MessagesToolChoice>) -> (Op
let parallel = choice.disable_parallel_tool_use.map(|disable| !disable); let parallel = choice.disable_parallel_tool_use.map(|disable| !disable);
(Some(openai_choice), parallel) (Some(openai_choice), parallel)
} }
None => (None, None) None => (None, None),
} }
} }
/// Build OpenAI message content from parts and tool calls /// Build OpenAI message content from parts and tool calls
fn build_openai_content(content_parts: Vec<ContentPart>, tool_calls: &[ToolCall]) -> MessageContent { fn build_openai_content(
content_parts: Vec<ContentPart>,
tool_calls: &[ToolCall],
) -> MessageContent {
if content_parts.len() == 1 && tool_calls.is_empty() { if content_parts.len() == 1 && tool_calls.is_empty() {
match &content_parts[0] { match &content_parts[0] {
ContentPart::Text { text } => MessageContent::Text(text.clone()), ContentPart::Text { text } => MessageContent::Text(text.clone()),
@ -334,7 +352,12 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
content_blocks.push(ContentBlock::Text { text }); content_blocks.push(ContentBlock::Text { text });
} }
} }
crate::apis::anthropic::MessagesContentBlock::ToolUse { id, name, input, .. } => { crate::apis::anthropic::MessagesContentBlock::ToolUse {
id,
name,
input,
..
} => {
content_blocks.push(ContentBlock::ToolUse { content_blocks.push(ContentBlock::ToolUse {
tool_use: ToolUseBlock { tool_use: ToolUseBlock {
tool_use_id: id, tool_use_id: id,
@ -343,7 +366,12 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
}, },
}); });
} }
crate::apis::anthropic::MessagesContentBlock::ToolResult { tool_use_id, is_error, content, .. } => { crate::apis::anthropic::MessagesContentBlock::ToolResult {
tool_use_id,
is_error,
content,
..
} => {
// Convert Anthropic ToolResultContent to Bedrock ToolResultContentBlock // Convert Anthropic ToolResultContent to Bedrock ToolResultContentBlock
let tool_result_content = match content { let tool_result_content = match content {
ToolResultContent::Text(text) => { ToolResultContent::Text(text) => {
@ -366,7 +394,9 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
// Ensure we have at least one content block // Ensure we have at least one content block
let final_content = if tool_result_content.is_empty() { let final_content = if tool_result_content.is_empty() {
vec![ToolResultContentBlock::Text { text: " ".to_string() }] vec![ToolResultContentBlock::Text {
text: " ".to_string(),
}]
} else { } else {
tool_result_content tool_result_content
}; };
@ -388,13 +418,13 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
crate::apis::anthropic::MessagesContentBlock::Image { source } => { crate::apis::anthropic::MessagesContentBlock::Image { source } => {
// Convert Anthropic image to Bedrock image format // Convert Anthropic image to Bedrock image format
match source { match source {
crate::apis::anthropic::MessagesImageSource::Base64 { media_type, data } => { crate::apis::anthropic::MessagesImageSource::Base64 {
media_type,
data,
} => {
content_blocks.push(ContentBlock::Image { content_blocks.push(ContentBlock::Image {
image: ImageBlock { image: ImageBlock {
source: ImageSource::Base64 { source: ImageSource::Base64 { media_type, data },
media_type,
data,
},
}, },
}); });
} }
@ -413,7 +443,9 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
// Ensure we have at least one content block // Ensure we have at least one content block
if content_blocks.is_empty() { if content_blocks.is_empty() {
content_blocks.push(ContentBlock::Text { text: " ".to_string() }); content_blocks.push(ContentBlock::Text {
text: " ".to_string(),
});
} }
Ok(BedrockMessage { Ok(BedrockMessage {
@ -423,29 +455,33 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::apis::anthropic::{MessagesRequest, MessagesMessage, MessagesMessageContent, MessagesRole, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, MessagesSystemPrompt}; use crate::apis::amazon_bedrock::{
use crate::apis::amazon_bedrock::{ConverseRequest, SystemContentBlock, ToolChoice as BedrockToolChoice, ConversationRole, ContentBlock}; ContentBlock, ConversationRole, ConverseRequest, SystemContentBlock,
ToolChoice as BedrockToolChoice,
};
use crate::apis::anthropic::{
MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType,
};
use serde_json::json; use serde_json::json;
#[test] #[test]
fn test_anthropic_to_bedrock_basic_request() { fn test_anthropic_to_bedrock_basic_request() {
let anthropic_request = MessagesRequest { let anthropic_request = MessagesRequest {
model: "claude-3-5-sonnet-20241022".to_string(), model: "claude-3-5-sonnet-20241022".to_string(),
messages: vec![ messages: vec![MessagesMessage {
MessagesMessage { role: MessagesRole::User,
role: MessagesRole::User, content: MessagesMessageContent::Single("Hello, how are you?".to_string()),
content: MessagesMessageContent::Single("Hello, how are you?".to_string()), }],
}
],
max_tokens: 1000, max_tokens: 1000,
container: None, container: None,
mcp_servers: None, mcp_servers: None,
system: Some(MessagesSystemPrompt::Single("You are a helpful assistant.".to_string())), system: Some(MessagesSystemPrompt::Single(
"You are a helpful assistant.".to_string(),
)),
metadata: None, metadata: None,
service_tier: None, service_tier: None,
thinking: None, thinking: None,
@ -478,19 +514,20 @@ mod tests {
assert_eq!(inference_config.temperature, Some(0.7)); assert_eq!(inference_config.temperature, Some(0.7));
assert_eq!(inference_config.top_p, Some(0.9)); assert_eq!(inference_config.top_p, Some(0.9));
assert_eq!(inference_config.max_tokens, Some(1000)); assert_eq!(inference_config.max_tokens, Some(1000));
assert_eq!(inference_config.stop_sequences, Some(vec!["STOP".to_string()])); assert_eq!(
inference_config.stop_sequences,
Some(vec!["STOP".to_string()])
);
} }
#[test] #[test]
fn test_anthropic_to_bedrock_with_tools() { fn test_anthropic_to_bedrock_with_tools() {
let anthropic_request = MessagesRequest { let anthropic_request = MessagesRequest {
model: "claude-3-5-sonnet-20241022".to_string(), model: "claude-3-5-sonnet-20241022".to_string(),
messages: vec![ messages: vec![MessagesMessage {
MessagesMessage { role: MessagesRole::User,
role: MessagesRole::User, content: MessagesMessageContent::Single("What's the weather like?".to_string()),
content: MessagesMessageContent::Single("What's the weather like?".to_string()), }],
}
],
max_tokens: 1000, max_tokens: 1000,
container: None, container: None,
mcp_servers: None, mcp_servers: None,
@ -503,22 +540,20 @@ mod tests {
top_k: None, top_k: None,
stream: None, stream: None,
stop_sequences: None, stop_sequences: None,
tools: Some(vec![ tools: Some(vec![MessagesTool {
MessagesTool { name: "get_weather".to_string(),
name: "get_weather".to_string(), description: Some("Get current weather information".to_string()),
description: Some("Get current weather information".to_string()), input_schema: json!({
input_schema: json!({ "type": "object",
"type": "object", "properties": {
"properties": { "location": {
"location": { "type": "string",
"type": "string", "description": "The city name"
"description": "The city name" }
} },
}, "required": ["location"]
"required": ["location"] }),
}), }]),
}
]),
tool_choice: Some(MessagesToolChoice { tool_choice: Some(MessagesToolChoice {
kind: MessagesToolChoiceType::Tool, kind: MessagesToolChoiceType::Tool,
name: Some("get_weather".to_string()), name: Some("get_weather".to_string()),
@ -537,7 +572,10 @@ mod tests {
assert_eq!(tools.len(), 1); assert_eq!(tools.len(), 1);
let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0]; let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0];
assert_eq!(tool_spec.name, "get_weather"); assert_eq!(tool_spec.name, "get_weather");
assert_eq!(tool_spec.description, Some("Get current weather information".to_string())); assert_eq!(
tool_spec.description,
Some("Get current weather information".to_string())
);
if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice { if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice {
assert_eq!(tool.name, "get_weather"); assert_eq!(tool.name, "get_weather");
@ -550,12 +588,10 @@ mod tests {
fn test_anthropic_to_bedrock_auto_tool_choice() { fn test_anthropic_to_bedrock_auto_tool_choice() {
let anthropic_request = MessagesRequest { let anthropic_request = MessagesRequest {
model: "claude-3-5-sonnet-20241022".to_string(), model: "claude-3-5-sonnet-20241022".to_string(),
messages: vec![ messages: vec![MessagesMessage {
MessagesMessage { role: MessagesRole::User,
role: MessagesRole::User, content: MessagesMessageContent::Single("Help me with something".to_string()),
content: MessagesMessageContent::Single("Help me with something".to_string()), }],
}
],
max_tokens: 500, max_tokens: 500,
container: None, container: None,
mcp_servers: None, mcp_servers: None,
@ -568,16 +604,14 @@ mod tests {
top_k: None, top_k: None,
stream: None, stream: None,
stop_sequences: None, stop_sequences: None,
tools: Some(vec![ tools: Some(vec![MessagesTool {
MessagesTool { name: "help_tool".to_string(),
name: "help_tool".to_string(), description: Some("A helpful tool".to_string()),
description: Some("A helpful tool".to_string()), input_schema: json!({
input_schema: json!({ "type": "object",
"type": "object", "properties": {}
"properties": {} }),
}), }]),
}
]),
tool_choice: Some(MessagesToolChoice { tool_choice: Some(MessagesToolChoice {
kind: MessagesToolChoiceType::Auto, kind: MessagesToolChoiceType::Auto,
name: None, name: None,
@ -589,7 +623,10 @@ mod tests {
assert!(bedrock_request.tool_config.is_some()); assert!(bedrock_request.tool_config.is_some());
let tool_config = bedrock_request.tool_config.as_ref().unwrap(); let tool_config = bedrock_request.tool_config.as_ref().unwrap();
assert!(matches!(tool_config.tool_choice, Some(BedrockToolChoice::Auto { .. }))); assert!(matches!(
tool_config.tool_choice,
Some(BedrockToolChoice::Auto { .. })
));
} }
#[test] #[test]
@ -603,12 +640,14 @@ mod tests {
}, },
MessagesMessage { MessagesMessage {
role: MessagesRole::Assistant, role: MessagesRole::Assistant,
content: MessagesMessageContent::Single("Hi there! How can I help you?".to_string()), content: MessagesMessageContent::Single(
"Hi there! How can I help you?".to_string(),
),
}, },
MessagesMessage { MessagesMessage {
role: MessagesRole::User, role: MessagesRole::User,
content: MessagesMessageContent::Single("What's 2+2?".to_string()), content: MessagesMessageContent::Single("What's 2+2?".to_string()),
} },
], ],
max_tokens: 100, max_tokens: 100,
container: None, container: None,

View file

@ -1,15 +1,21 @@
use crate::apis::amazon_bedrock::{
AnyChoice, AutoChoice, ContentBlock, ConversationRole, ConverseRequest, InferenceConfiguration,
Message as BedrockMessage, SystemContentBlock, Tool as BedrockTool,
ToolChoice as BedrockToolChoice, ToolChoiceSpec, ToolConfiguration, ToolInputSchema,
ToolSpecDefinition,
};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType,
ToolResultContent,
};
use crate::apis::openai::{
ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType,
};
use crate::clients::TransformError;
use crate::transforms::lib::ExtractText; use crate::transforms::lib::ExtractText;
use crate::transforms::lib::*; use crate::transforms::lib::*;
use crate::clients::TransformError;
use crate::transforms::*; use crate::transforms::*;
use crate::apis::anthropic::{MessagesSystemPrompt, MessagesMessage,MessagesRequest, MessagesMessageContent, MessagesContentBlock, MessagesRole, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, ToolResultContent};
use crate::apis::openai::{ChatCompletionsRequest, Message, Role, Tool, ToolChoice, ToolChoiceType, MessageContent};
use crate::apis::amazon_bedrock::{
ConverseRequest, SystemContentBlock, InferenceConfiguration, ToolConfiguration,
Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolInputSchema, ToolSpecDefinition,
AutoChoice, AnyChoice, ToolChoiceSpec,
Message as BedrockMessage, ConversationRole, ContentBlock
};
type AnthropicMessagesRequest = MessagesRequest; type AnthropicMessagesRequest = MessagesRequest;
@ -21,7 +27,7 @@ impl Into<MessagesSystemPrompt> for Message {
fn into(self) -> MessagesSystemPrompt { fn into(self) -> MessagesSystemPrompt {
let system_text = match self.content { let system_text = match self.content {
MessageContent::Text(text) => text, MessageContent::Text(text) => text,
MessageContent::Parts(parts) => parts.extract_text() MessageContent::Parts(parts) => parts.extract_text(),
}; };
MessagesSystemPrompt::Single(system_text) MessagesSystemPrompt::Single(system_text)
} }
@ -36,8 +42,11 @@ impl TryFrom<Message> for MessagesMessage {
Role::Assistant => MessagesRole::Assistant, Role::Assistant => MessagesRole::Assistant,
Role::Tool => { Role::Tool => {
// Tool messages become user messages with tool results // Tool messages become user messages with tool results
let tool_call_id = message.tool_call_id let tool_call_id = message.tool_call_id.ok_or_else(|| {
.ok_or_else(|| TransformError::MissingField("tool_call_id required for Tool messages".to_string()))?; TransformError::MissingField(
"tool_call_id required for Tool messages".to_string(),
)
})?;
return Ok(MessagesMessage { return Ok(MessagesMessage {
role: MessagesRole::User, role: MessagesRole::User,
@ -55,7 +64,9 @@ impl TryFrom<Message> for MessagesMessage {
}); });
} }
Role::System => { Role::System => {
return Err(TransformError::UnsupportedConversion("System messages should be handled separately".to_string())); return Err(TransformError::UnsupportedConversion(
"System messages should be handled separately".to_string(),
));
} }
}; };
@ -75,7 +86,9 @@ impl TryFrom<Message> for BedrockMessage {
Role::Assistant => ConversationRole::Assistant, Role::Assistant => ConversationRole::Assistant,
Role::Tool => ConversationRole::User, // Tool results become user messages in Bedrock Role::Tool => ConversationRole::User, // Tool results become user messages in Bedrock
Role::System => { Role::System => {
return Err(TransformError::UnsupportedConversion("System messages should be handled separately in Bedrock".to_string())); return Err(TransformError::UnsupportedConversion(
"System messages should be handled separately in Bedrock".to_string(),
));
} }
}; };
@ -103,7 +116,9 @@ impl TryFrom<Message> for BedrockMessage {
crate::apis::openai::ContentPart::ImageUrl { image_url } => { crate::apis::openai::ContentPart::ImageUrl { image_url } => {
// Convert image URL to Bedrock image format // Convert image URL to Bedrock image format
if image_url.url.starts_with("data:") { if image_url.url.starts_with("data:") {
if let Some((media_type, data)) = parse_data_url(&image_url.url) { if let Some((media_type, data)) =
parse_data_url(&image_url.url)
{
content_blocks.push(ContentBlock::Image { content_blocks.push(ContentBlock::Image {
image: crate::apis::amazon_bedrock::ImageBlock { image: crate::apis::amazon_bedrock::ImageBlock {
source: crate::apis::amazon_bedrock::ImageSource::Base64 { source: crate::apis::amazon_bedrock::ImageSource::Base64 {
@ -114,7 +129,10 @@ impl TryFrom<Message> for BedrockMessage {
}); });
} else { } else {
return Err(TransformError::UnsupportedConversion( return Err(TransformError::UnsupportedConversion(
format!("Invalid data URL format: {}", image_url.url) format!(
"Invalid data URL format: {}",
image_url.url
),
)); ));
} }
} else { } else {
@ -130,13 +148,18 @@ impl TryFrom<Message> for BedrockMessage {
// Ensure we have at least one content block // Ensure we have at least one content block
if content_blocks.is_empty() { if content_blocks.is_empty() {
content_blocks.push(ContentBlock::Text { text: " ".to_string() }); content_blocks.push(ContentBlock::Text {
text: " ".to_string(),
});
} }
} }
Role::Assistant => { Role::Assistant => {
// Handle text content - but only add if non-empty OR if we don't have tool calls // Handle text content - but only add if non-empty OR if we don't have tool calls
let text_content = message.content.extract_text(); let text_content = message.content.extract_text();
let has_tool_calls = message.tool_calls.as_ref().map_or(false, |calls| !calls.is_empty()); let has_tool_calls = message
.tool_calls
.as_ref()
.map_or(false, |calls| !calls.is_empty());
// Add text content if it's non-empty, or if we have no tool calls (to avoid empty content) // Add text content if it's non-empty, or if we have no tool calls (to avoid empty content)
if !text_content.is_empty() { if !text_content.is_empty() {
@ -144,17 +167,22 @@ impl TryFrom<Message> for BedrockMessage {
} else if !has_tool_calls { } else if !has_tool_calls {
// If we have empty content and no tool calls, add a minimal placeholder // If we have empty content and no tool calls, add a minimal placeholder
// This prevents the "blank text field" error // This prevents the "blank text field" error
content_blocks.push(ContentBlock::Text { text: " ".to_string() }); content_blocks.push(ContentBlock::Text {
text: " ".to_string(),
});
} }
// Convert tool calls to ToolUse content blocks // Convert tool calls to ToolUse content blocks
if let Some(tool_calls) = message.tool_calls { if let Some(tool_calls) = message.tool_calls {
for tool_call in tool_calls { for tool_call in tool_calls {
// Parse the arguments string as JSON // Parse the arguments string as JSON
let input: serde_json::Value = serde_json::from_str(&tool_call.function.arguments) let input: serde_json::Value =
.map_err(|e| TransformError::UnsupportedConversion( serde_json::from_str(&tool_call.function.arguments).map_err(|e| {
format!("Failed to parse tool arguments as JSON: {}. Arguments: {}", e, tool_call.function.arguments) TransformError::UnsupportedConversion(format!(
))?; "Failed to parse tool arguments as JSON: {}. Arguments: {}",
e, tool_call.function.arguments
))
})?;
content_blocks.push(ContentBlock::ToolUse { content_blocks.push(ContentBlock::ToolUse {
tool_use: crate::apis::amazon_bedrock::ToolUseBlock { tool_use: crate::apis::amazon_bedrock::ToolUseBlock {
@ -168,13 +196,18 @@ impl TryFrom<Message> for BedrockMessage {
// Bedrock requires at least one content block // Bedrock requires at least one content block
if content_blocks.is_empty() { if content_blocks.is_empty() {
content_blocks.push(ContentBlock::Text { text: " ".to_string() }); content_blocks.push(ContentBlock::Text {
text: " ".to_string(),
});
} }
} }
Role::Tool => { Role::Tool => {
// Tool messages become user messages with ToolResult content blocks // Tool messages become user messages with ToolResult content blocks
let tool_call_id = message.tool_call_id let tool_call_id = message.tool_call_id.ok_or_else(|| {
.ok_or_else(|| TransformError::MissingField("tool_call_id required for Tool messages".to_string()))?; TransformError::MissingField(
"tool_call_id required for Tool messages".to_string(),
)
})?;
let tool_content = message.content.extract_text(); let tool_content = message.content.extract_text();
@ -182,11 +215,11 @@ impl TryFrom<Message> for BedrockMessage {
let tool_result_content = if tool_content.is_empty() { let tool_result_content = if tool_content.is_empty() {
// Even for tool results, we need non-empty content // Even for tool results, we need non-empty content
vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text { vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text {
text: " ".to_string() text: " ".to_string(),
}] }]
} else { } else {
vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text { vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text {
text: tool_content text: tool_content,
}] }]
}; };
@ -232,13 +265,15 @@ impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
// Convert tools and tool choice // Convert tools and tool choice
let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools)); let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools));
let anthropic_tool_choice = convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls); let anthropic_tool_choice =
convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
Ok(AnthropicMessagesRequest { Ok(AnthropicMessagesRequest {
model: req.model, model: req.model,
system: system_prompt, system: system_prompt,
messages, messages,
max_tokens: req.max_completion_tokens max_tokens: req
.max_completion_tokens
.or(req.max_tokens) .or(req.max_tokens)
.unwrap_or(DEFAULT_MAX_TOKENS), .unwrap_or(DEFAULT_MAX_TOKENS),
container: None, container: None,
@ -297,8 +332,11 @@ impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
// Build inference configuration // Build inference configuration
let max_tokens = req.max_completion_tokens.or(req.max_tokens); let max_tokens = req.max_completion_tokens.or(req.max_tokens);
let inference_config = if max_tokens.is_some() || req.temperature.is_some() || let inference_config = if max_tokens.is_some()
req.top_p.is_some() || req.stop.is_some() { || req.temperature.is_some()
|| req.top_p.is_some()
|| req.stop.is_some()
{
Some(InferenceConfiguration { Some(InferenceConfiguration {
max_tokens, max_tokens,
temperature: req.temperature, temperature: req.temperature,
@ -312,7 +350,8 @@ impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
// Convert tools and tool choice to ToolConfiguration // Convert tools and tool choice to ToolConfiguration
let tool_config = if req.tools.is_some() || req.tool_choice.is_some() { let tool_config = if req.tools.is_some() || req.tool_choice.is_some() {
let tools = req.tools.map(|openai_tools| { let tools = req.tools.map(|openai_tools| {
openai_tools.into_iter() openai_tools
.into_iter()
.map(|tool| BedrockTool::ToolSpec { .map(|tool| BedrockTool::ToolSpec {
tool_spec: ToolSpecDefinition { tool_spec: ToolSpecDefinition {
name: tool.function.name, name: tool.function.name,
@ -325,34 +364,40 @@ impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
.collect() .collect()
}); });
let tool_choice = req.tool_choice.map(|choice| { let tool_choice = req
match choice { .tool_choice
ToolChoice::Type(tool_type) => match tool_type { .map(|choice| {
ToolChoiceType::Auto => BedrockToolChoice::Auto { auto: AutoChoice {} }, match choice {
ToolChoiceType::Required => BedrockToolChoice::Any { any: AnyChoice {} }, ToolChoice::Type(tool_type) => match tool_type {
ToolChoiceType::None => BedrockToolChoice::Auto { auto: AutoChoice {} }, // Bedrock doesn't have explicit "none" ToolChoiceType::Auto => BedrockToolChoice::Auto {
}, auto: AutoChoice {},
ToolChoice::Function { function, .. } => { },
BedrockToolChoice::Tool { ToolChoiceType::Required => {
tool: ToolChoiceSpec { BedrockToolChoice::Any { any: AnyChoice {} }
name: function.name
} }
} ToolChoiceType::None => BedrockToolChoice::Auto {
auto: AutoChoice {},
}, // Bedrock doesn't have explicit "none"
},
ToolChoice::Function { function, .. } => BedrockToolChoice::Tool {
tool: ToolChoiceSpec {
name: function.name,
},
},
} }
} })
}).or_else(|| { .or_else(|| {
// If tools are present but no tool_choice specified, default to "auto" // If tools are present but no tool_choice specified, default to "auto"
if tools.is_some() { if tools.is_some() {
Some(BedrockToolChoice::Auto { auto: AutoChoice {} }) Some(BedrockToolChoice::Auto {
} else { auto: AutoChoice {},
None })
} } else {
}); None
}
});
Some(ToolConfiguration { Some(ToolConfiguration { tools, tool_choice })
tools,
tool_choice,
})
} else { } else {
None None
}; };
@ -377,7 +422,8 @@ impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
/// Convert OpenAI tools to Anthropic format /// Convert OpenAI tools to Anthropic format
fn convert_openai_tools(tools: Vec<Tool>) -> Vec<MessagesTool> { fn convert_openai_tools(tools: Vec<Tool>) -> Vec<MessagesTool> {
tools.into_iter() tools
.into_iter()
.map(|tool| MessagesTool { .map(|tool| MessagesTool {
name: tool.function.name, name: tool.function.name,
description: tool.function.description, description: tool.function.description,
@ -386,37 +432,34 @@ fn convert_openai_tools(tools: Vec<Tool>) -> Vec<MessagesTool> {
.collect() .collect()
} }
/// Convert OpenAI tool choice to Anthropic format /// Convert OpenAI tool choice to Anthropic format
fn convert_openai_tool_choice( fn convert_openai_tool_choice(
tool_choice: Option<ToolChoice>, tool_choice: Option<ToolChoice>,
parallel_tool_calls: Option<bool> parallel_tool_calls: Option<bool>,
) -> Option<MessagesToolChoice> { ) -> Option<MessagesToolChoice> {
tool_choice.map(|choice| { tool_choice.map(|choice| match choice {
match choice { ToolChoice::Type(tool_type) => match tool_type {
ToolChoice::Type(tool_type) => match tool_type { ToolChoiceType::Auto => MessagesToolChoice {
ToolChoiceType::Auto => MessagesToolChoice { kind: MessagesToolChoiceType::Auto,
kind: MessagesToolChoiceType::Auto, name: None,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
ToolChoiceType::Required => MessagesToolChoice {
kind: MessagesToolChoiceType::Any,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
ToolChoiceType::None => MessagesToolChoice {
kind: MessagesToolChoiceType::None,
name: None,
disable_parallel_tool_use: None,
},
},
ToolChoice::Function { function, .. } => MessagesToolChoice {
kind: MessagesToolChoiceType::Tool,
name: Some(function.name),
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
}, },
} ToolChoiceType::Required => MessagesToolChoice {
kind: MessagesToolChoiceType::Any,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
ToolChoiceType::None => MessagesToolChoice {
kind: MessagesToolChoiceType::None,
name: None,
disable_parallel_tool_use: None,
},
},
ToolChoice::Function { function, .. } => MessagesToolChoice {
kind: MessagesToolChoiceType::Tool,
name: Some(function.name),
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
}) })
} }
@ -434,8 +477,6 @@ fn build_anthropic_content(content_blocks: Vec<MessagesContentBlock>) -> Message
} }
} }
/// Parse a data URL into media type and base64 data /// Parse a data URL into media type and base64 data
/// Supports format: data:image/jpeg;base64,<data> /// Supports format: data:image/jpeg;base64,<data>
fn parse_data_url(url: &str) -> Option<(String, String)> { fn parse_data_url(url: &str) -> Option<(String, String)> {
@ -473,8 +514,14 @@ fn parse_data_url(url: &str) -> Option<(String, String)> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType, Function, FunctionChoice}; use crate::apis::amazon_bedrock::{
use crate::apis::amazon_bedrock::{ConverseRequest, SystemContentBlock, ConversationRole, ContentBlock, ToolChoice as BedrockToolChoice}; ContentBlock, ConversationRole, ConverseRequest, SystemContentBlock,
ToolChoice as BedrockToolChoice,
};
use crate::apis::openai::{
ChatCompletionsRequest, Function, FunctionChoice, Message, MessageContent, Role, Tool,
ToolChoice, ToolChoiceType,
};
use serde_json::json; use serde_json::json;
#[test] #[test]
@ -495,7 +542,7 @@ mod tests {
name: None, name: None,
tool_call_id: None, tool_call_id: None,
tool_calls: None, tool_calls: None,
} },
], ],
temperature: Some(0.7), temperature: Some(0.7),
top_p: Some(0.9), top_p: Some(0.9),
@ -534,50 +581,51 @@ mod tests {
assert_eq!(inference_config.temperature, Some(0.7)); assert_eq!(inference_config.temperature, Some(0.7));
assert_eq!(inference_config.top_p, Some(0.9)); assert_eq!(inference_config.top_p, Some(0.9));
assert_eq!(inference_config.max_tokens, Some(1000)); assert_eq!(inference_config.max_tokens, Some(1000));
assert_eq!(inference_config.stop_sequences, Some(vec!["STOP".to_string()])); assert_eq!(
inference_config.stop_sequences,
Some(vec!["STOP".to_string()])
);
} }
#[test] #[test]
fn test_openai_to_bedrock_with_tools() { fn test_openai_to_bedrock_with_tools() {
let openai_request = ChatCompletionsRequest { let openai_request = ChatCompletionsRequest {
model: "gpt-4".to_string(), model: "gpt-4".to_string(),
messages: vec![ messages: vec![Message {
Message { role: Role::User,
role: Role::User, content: MessageContent::Text("What's the weather like?".to_string()),
content: MessageContent::Text("What's the weather like?".to_string()), name: None,
name: None, tool_call_id: None,
tool_call_id: None, tool_calls: None,
tool_calls: None, }],
}
],
temperature: None, temperature: None,
top_p: None, top_p: None,
max_completion_tokens: Some(1000), max_completion_tokens: Some(1000),
stop: None, stop: None,
stream: None, stream: None,
tools: Some(vec![ tools: Some(vec![Tool {
Tool { tool_type: "function".to_string(),
tool_type: "function".to_string(), function: Function {
function: Function { name: "get_weather".to_string(),
name: "get_weather".to_string(), description: Some("Get current weather information".to_string()),
description: Some("Get current weather information".to_string()), parameters: json!({
parameters: json!({ "type": "object",
"type": "object", "properties": {
"properties": { "location": {
"location": { "type": "string",
"type": "string", "description": "The city name"
"description": "The city name" }
} },
}, "required": ["location"]
"required": ["location"] }),
}), strict: None,
strict: None, },
}, }]),
}
]),
tool_choice: Some(ToolChoice::Function { tool_choice: Some(ToolChoice::Function {
choice_type: "function".to_string(), choice_type: "function".to_string(),
function: FunctionChoice { name: "get_weather".to_string() }, function: FunctionChoice {
name: "get_weather".to_string(),
},
}), }),
..Default::default() ..Default::default()
}; };
@ -594,7 +642,10 @@ mod tests {
let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0]; let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0];
assert_eq!(tool_spec.name, "get_weather"); assert_eq!(tool_spec.name, "get_weather");
assert_eq!(tool_spec.description, Some("Get current weather information".to_string())); assert_eq!(
tool_spec.description,
Some("Get current weather information".to_string())
);
if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice { if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice {
assert_eq!(tool.name, "get_weather"); assert_eq!(tool.name, "get_weather");
@ -607,34 +658,30 @@ mod tests {
fn test_openai_to_bedrock_auto_tool_choice() { fn test_openai_to_bedrock_auto_tool_choice() {
let openai_request = ChatCompletionsRequest { let openai_request = ChatCompletionsRequest {
model: "gpt-4".to_string(), model: "gpt-4".to_string(),
messages: vec![ messages: vec![Message {
Message { role: Role::User,
role: Role::User, content: MessageContent::Text("Help me with something".to_string()),
content: MessageContent::Text("Help me with something".to_string()), name: None,
name: None, tool_call_id: None,
tool_call_id: None, tool_calls: None,
tool_calls: None, }],
}
],
temperature: None, temperature: None,
top_p: None, top_p: None,
max_completion_tokens: Some(500), max_completion_tokens: Some(500),
stop: None, stop: None,
stream: None, stream: None,
tools: Some(vec![ tools: Some(vec![Tool {
Tool { tool_type: "function".to_string(),
tool_type: "function".to_string(), function: Function {
function: Function { name: "help_tool".to_string(),
name: "help_tool".to_string(), description: Some("A helpful tool".to_string()),
description: Some("A helpful tool".to_string()), parameters: json!({
parameters: json!({ "type": "object",
"type": "object", "properties": {}
"properties": {} }),
}), strict: None,
strict: None, },
}, }]),
}
]),
tool_choice: Some(ToolChoice::Type(ToolChoiceType::Auto)), tool_choice: Some(ToolChoice::Type(ToolChoiceType::Auto)),
..Default::default() ..Default::default()
}; };
@ -643,7 +690,10 @@ mod tests {
assert!(bedrock_request.tool_config.is_some()); assert!(bedrock_request.tool_config.is_some());
let tool_config = bedrock_request.tool_config.as_ref().unwrap(); let tool_config = bedrock_request.tool_config.as_ref().unwrap();
assert!(matches!(tool_config.tool_choice, Some(BedrockToolChoice::Auto { .. }))); assert!(matches!(
tool_config.tool_choice,
Some(BedrockToolChoice::Auto { .. })
));
} }
#[test] #[test]
@ -678,7 +728,7 @@ mod tests {
name: None, name: None,
tool_call_id: None, tool_call_id: None,
tool_calls: None, tool_calls: None,
} },
], ],
temperature: Some(0.5), temperature: Some(0.5),
top_p: None, top_p: None,

View file

@ -1,16 +1,16 @@
use serde_json::Value; use crate::apis::amazon_bedrock::{
use crate::transforms::lib::*; ContentBlockDelta, ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason,
use crate::clients::TransformError;
use crate::apis::openai::{
ChatCompletionsResponse, ChatCompletionsStreamResponse, Role, ToolCallDelta
}; };
use crate::apis::anthropic::{ use crate::apis::anthropic::{
MessagesStreamEvent, MessagesStopReason, MessagesMessageDelta, MessagesResponse, MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, MessagesResponse,
MessagesStreamMessage, MessagesUsage, MessagesContentDelta, MessagesRole, MessagesContentBlock MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
}; };
use crate::apis::amazon_bedrock::{ use crate::apis::openai::{
ConverseResponse, ConverseOutput, StopReason, ConverseStreamEvent, ContentBlockDelta ChatCompletionsResponse, ChatCompletionsStreamResponse, Role, ToolCallDelta,
}; };
use crate::clients::TransformError;
use crate::transforms::lib::*;
use serde_json::Value;
// ============================================================================ // ============================================================================
// STANDARD RUST TRAIT IMPLEMENTATIONS - Using Into/TryFrom for convenience // STANDARD RUST TRAIT IMPLEMENTATIONS - Using Into/TryFrom for convenience
@ -20,11 +20,15 @@ impl TryFrom<ChatCompletionsResponse> for MessagesResponse {
type Error = TransformError; type Error = TransformError;
fn try_from(resp: ChatCompletionsResponse) -> Result<Self, Self::Error> { fn try_from(resp: ChatCompletionsResponse) -> Result<Self, Self::Error> {
let choice = resp.choices.into_iter().next() let choice = resp
.choices
.into_iter()
.next()
.ok_or_else(|| TransformError::MissingField("choices".to_string()))?; .ok_or_else(|| TransformError::MissingField("choices".to_string()))?;
let content = convert_openai_message_to_anthropic_content(&choice.message.to_message())?; let content = convert_openai_message_to_anthropic_content(&choice.message.to_message())?;
let stop_reason = choice.finish_reason let stop_reason = choice
.finish_reason
.map(|fr| fr.into()) .map(|fr| fr.into())
.unwrap_or(MessagesStopReason::EndTurn); .unwrap_or(MessagesStopReason::EndTurn);
@ -86,13 +90,17 @@ impl TryFrom<ConverseResponse> for MessagesResponse {
}; };
// Generate a response ID (Bedrock doesn't provide one) // Generate a response ID (Bedrock doesn't provide one)
let id = format!("bedrock-{}", std::time::SystemTime::now() let id = format!(
.duration_since(std::time::UNIX_EPOCH) "bedrock-{}",
.unwrap_or_default() std::time::SystemTime::now()
.as_nanos()); .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos()
);
// Extract model ID from trace information if available, otherwise use fallback // Extract model ID from trace information if available, otherwise use fallback
let model = resp.trace let model = resp
.trace
.as_ref() .as_ref()
.and_then(|trace| trace.prompt_router.as_ref()) .and_then(|trace| trace.prompt_router.as_ref())
.map(|router| router.invoked_model_id.clone()) .map(|router| router.invoked_model_id.clone())
@ -231,15 +239,20 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
ConverseStreamEvent::MessageStart(start_event) => { ConverseStreamEvent::MessageStart(start_event) => {
let role = match start_event.role { let role = match start_event.role {
crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User, crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User,
crate::apis::amazon_bedrock::ConversationRole::Assistant => MessagesRole::Assistant, crate::apis::amazon_bedrock::ConversationRole::Assistant => {
MessagesRole::Assistant
}
}; };
Ok(MessagesStreamEvent::MessageStart { Ok(MessagesStreamEvent::MessageStart {
message: MessagesStreamMessage { message: MessagesStreamMessage {
id: format!("bedrock-stream-{}", std::time::SystemTime::now() id: format!(
.duration_since(std::time::UNIX_EPOCH) "bedrock-stream-{}",
.unwrap_or_default() std::time::SystemTime::now()
.as_nanos()), .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos()
),
obj_type: "message".to_string(), obj_type: "message".to_string(),
role, role,
content: vec![], content: vec![],
@ -278,11 +291,11 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
// ContentBlockDelta - convert to Anthropic ContentBlockDelta // ContentBlockDelta - convert to Anthropic ContentBlockDelta
ConverseStreamEvent::ContentBlockDelta(delta_event) => { ConverseStreamEvent::ContentBlockDelta(delta_event) => {
let delta = match delta_event.delta { let delta = match delta_event.delta {
ContentBlockDelta::Text { text } => { ContentBlockDelta::Text { text } => MessagesContentDelta::TextDelta { text },
MessagesContentDelta::TextDelta { text }
}
ContentBlockDelta::ToolUse { tool_use } => { ContentBlockDelta::ToolUse { tool_use } => {
MessagesContentDelta::InputJsonDelta { partial_json: tool_use.input } MessagesContentDelta::InputJsonDelta {
partial_json: tool_use.input,
}
} }
}; };
@ -342,11 +355,11 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
} }
// Exception events - convert to Ping (could be enhanced to return error events) // Exception events - convert to Ping (could be enhanced to return error events)
ConverseStreamEvent::InternalServerException(_) | ConverseStreamEvent::InternalServerException(_)
ConverseStreamEvent::ModelStreamErrorException(_) | | ConverseStreamEvent::ModelStreamErrorException(_)
ConverseStreamEvent::ServiceUnavailableException(_) | | ConverseStreamEvent::ServiceUnavailableException(_)
ConverseStreamEvent::ThrottlingException(_) | | ConverseStreamEvent::ThrottlingException(_)
ConverseStreamEvent::ValidationException(_) => { | ConverseStreamEvent::ValidationException(_) => {
// TODO: Consider adding proper error handling/events // TODO: Consider adding proper error handling/events
Ok(MessagesStreamEvent::Ping) Ok(MessagesStreamEvent::Ping)
} }
@ -355,7 +368,9 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
} }
/// Convert tool call deltas to Anthropic stream events /// Convert tool call deltas to Anthropic stream events
fn convert_tool_call_deltas(tool_calls: Vec<ToolCallDelta>) -> Result<MessagesStreamEvent, TransformError> { fn convert_tool_call_deltas(
tool_calls: Vec<ToolCallDelta>,
) -> Result<MessagesStreamEvent, TransformError> {
for tool_call in tool_calls { for tool_call in tool_calls {
if let Some(id) = &tool_call.id { if let Some(id) = &tool_call.id {
// Tool call start // Tool call start
@ -403,7 +418,9 @@ fn convert_tool_call_deltas(tool_calls: Vec<ToolCallDelta>) -> Result<MessagesSt
/// ///
/// Note on S3/URL handling: Converting S3 locations or URLs would require async operations /// Note on S3/URL handling: Converting S3 locations or URLs would require async operations
/// to download and convert to base64, which is not implemented in this synchronous function. /// to download and convert to base64, which is not implemented in this synchronous function.
fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_bedrock::Message) -> Result<Vec<MessagesContentBlock>, TransformError> { fn convert_bedrock_message_to_anthropic_content(
message: &crate::apis::amazon_bedrock::Message,
) -> Result<Vec<MessagesContentBlock>, TransformError> {
use crate::apis::amazon_bedrock::ContentBlock; use crate::apis::amazon_bedrock::ContentBlock;
let mut content_blocks = Vec::new(); let mut content_blocks = Vec::new();
@ -438,16 +455,19 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
crate::apis::amazon_bedrock::ToolResultContentBlock::Image { source } => { crate::apis::amazon_bedrock::ToolResultContentBlock::Image { source } => {
// Convert Bedrock ImageSource to Anthropic format // Convert Bedrock ImageSource to Anthropic format
match source { match source {
crate::apis::amazon_bedrock::ImageSource::Base64 { media_type, data } => { crate::apis::amazon_bedrock::ImageSource::Base64 {
media_type,
data,
} => {
tool_result_blocks.push(MessagesContentBlock::Image { tool_result_blocks.push(MessagesContentBlock::Image {
source: crate::apis::anthropic::MessagesImageSource::Base64 { source:
media_type: media_type.clone(), crate::apis::anthropic::MessagesImageSource::Base64 {
data: data.clone(), media_type: media_type.clone(),
}, data: data.clone(),
},
}); });
} } // Note: S3Location is not yet implemented in the current Bedrock API definition
// Note: S3Location is not yet implemented in the current Bedrock API definition // but would need async handling when added
// but would need async handling when added
} }
} }
crate::apis::amazon_bedrock::ToolResultContentBlock::Json { json } => { crate::apis::amazon_bedrock::ToolResultContentBlock::Json { json } => {
@ -463,7 +483,10 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
use crate::apis::anthropic::ToolResultContent; use crate::apis::anthropic::ToolResultContent;
content_blocks.push(MessagesContentBlock::ToolResult { content_blocks.push(MessagesContentBlock::ToolResult {
tool_use_id: tool_result.tool_use_id.clone(), tool_use_id: tool_result.tool_use_id.clone(),
is_error: tool_result.status.as_ref().map(|s| matches!(s, crate::apis::amazon_bedrock::ToolResultStatus::Error)), is_error: tool_result
.status
.as_ref()
.map(|s| matches!(s, crate::apis::amazon_bedrock::ToolResultStatus::Error)),
content: ToolResultContent::Blocks(tool_result_blocks), content: ToolResultContent::Blocks(tool_result_blocks),
cache_control: None, cache_control: None,
}); });
@ -478,8 +501,7 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
data: data.clone(), data: data.clone(),
}, },
}); });
} } // Note: S3Location would require async handling if implemented
// Note: S3Location would require async handling if implemented
} }
} }
ContentBlock::Document { document } => { ContentBlock::Document { document } => {
@ -493,8 +515,7 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
data: data.clone(), data: data.clone(),
}, },
}); });
} } // Note: S3Location would require async handling if implemented
// Note: S3Location would require async handling if implemented
} }
} }
ContentBlock::GuardContent { guard_content } => { ContentBlock::GuardContent { guard_content } => {
@ -516,11 +537,13 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
mod tests { mod tests {
use super::*; use super::*;
use crate::apis::amazon_bedrock::{ use crate::apis::amazon_bedrock::{
ConverseResponse, ConverseOutput, Message as BedrockMessage, ConversationRole, BedrockTokenUsage, ContentBlock, ConversationRole, ConverseOutput, ConverseResponse,
ContentBlock, StopReason, BedrockTokenUsage, ToolResultContentBlock, ToolResultStatus, ConverseTrace, Message as BedrockMessage, PromptRouterTrace, StopReason,
ConverseTrace, PromptRouterTrace ToolResultContentBlock, ToolResultStatus,
};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesResponse, MessagesRole, MessagesStopReason, ToolResultContent,
}; };
use crate::apis::anthropic::{MessagesResponse, MessagesContentBlock, MessagesStopReason, MessagesRole, ToolResultContent};
use serde_json::json; use serde_json::json;
#[test] #[test]
@ -529,11 +552,9 @@ mod tests {
output: ConverseOutput::Message { output: ConverseOutput::Message {
message: BedrockMessage { message: BedrockMessage {
role: ConversationRole::Assistant, role: ConversationRole::Assistant,
content: vec![ content: vec![ContentBlock::Text {
ContentBlock::Text { text: "Hello! How can I help you today?".to_string(),
text: "Hello! How can I help you today?".to_string(), }],
}
],
}, },
}, },
stop_reason: StopReason::EndTurn, stop_reason: StopReason::EndTurn,
@ -592,7 +613,7 @@ mod tests {
"location": "San Francisco" "location": "San Francisco"
}), }),
}, },
} },
], ],
}, },
}, },
@ -624,7 +645,10 @@ mod tests {
} }
// Check tool use content // Check tool use content
if let MessagesContentBlock::ToolUse { id, name, input, .. } = &anthropic_response.content[1] { if let MessagesContentBlock::ToolUse {
id, name, input, ..
} = &anthropic_response.content[1]
{
assert_eq!(id, "tool_12345"); assert_eq!(id, "tool_12345");
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(input["location"], "San Francisco"); assert_eq!(input["location"], "San Francisco");
@ -649,11 +673,9 @@ mod tests {
output: ConverseOutput::Message { output: ConverseOutput::Message {
message: BedrockMessage { message: BedrockMessage {
role: ConversationRole::Assistant, role: ConversationRole::Assistant,
content: vec![ content: vec![ContentBlock::Text {
ContentBlock::Text { text: "Test response".to_string(),
text: "Test response".to_string(), }],
}
],
}, },
}, },
stop_reason: bedrock_stop_reason, stop_reason: bedrock_stop_reason,
@ -670,7 +692,10 @@ mod tests {
}; };
let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap();
assert_eq!(anthropic_response.stop_reason, expected_anthropic_stop_reason); assert_eq!(
anthropic_response.stop_reason,
expected_anthropic_stop_reason
);
} }
} }
@ -680,11 +705,9 @@ mod tests {
output: ConverseOutput::Message { output: ConverseOutput::Message {
message: BedrockMessage { message: BedrockMessage {
role: ConversationRole::Assistant, role: ConversationRole::Assistant,
content: vec![ content: vec![ContentBlock::Text {
ContentBlock::Text { text: "Cached response".to_string(),
text: "Cached response".to_string(), }],
}
],
}, },
}, },
stop_reason: StopReason::EndTurn, stop_reason: StopReason::EndTurn,
@ -706,7 +729,10 @@ mod tests {
assert_eq!(anthropic_response.usage.input_tokens, 100); assert_eq!(anthropic_response.usage.input_tokens, 100);
assert_eq!(anthropic_response.usage.output_tokens, 50); assert_eq!(anthropic_response.usage.output_tokens, 50);
assert_eq!(anthropic_response.usage.cache_creation_input_tokens, Some(20)); assert_eq!(
anthropic_response.usage.cache_creation_input_tokens,
Some(20)
);
assert_eq!(anthropic_response.usage.cache_read_input_tokens, Some(10)); assert_eq!(anthropic_response.usage.cache_read_input_tokens, Some(10));
} }
@ -723,14 +749,12 @@ mod tests {
ContentBlock::ToolResult { ContentBlock::ToolResult {
tool_result: crate::apis::amazon_bedrock::ToolResultBlock { tool_result: crate::apis::amazon_bedrock::ToolResultBlock {
tool_use_id: "tool_67890".to_string(), tool_use_id: "tool_67890".to_string(),
content: vec![ content: vec![ToolResultContentBlock::Text {
ToolResultContentBlock::Text { text: "Temperature: 72°F, Sunny".to_string(),
text: "Temperature: 72°F, Sunny".to_string(), }],
}
],
status: Some(ToolResultStatus::Success), status: Some(ToolResultStatus::Success),
}, },
} },
], ],
}, },
}, },
@ -761,7 +785,12 @@ mod tests {
} }
// Check tool result content // Check tool result content
if let MessagesContentBlock::ToolResult { tool_use_id, content, .. } = &anthropic_response.content[1] { if let MessagesContentBlock::ToolResult {
tool_use_id,
content,
..
} = &anthropic_response.content[1]
{
assert_eq!(tool_use_id, "tool_67890"); assert_eq!(tool_use_id, "tool_67890");
if let ToolResultContent::Blocks(blocks) = content { if let ToolResultContent::Blocks(blocks) = content {
assert_eq!(blocks.len(), 1); assert_eq!(blocks.len(), 1);
@ -804,7 +833,7 @@ mod tests {
name: "lookup".to_string(), name: "lookup".to_string(),
input: json!({"id": "12345"}), input: json!({"id": "12345"}),
}, },
} },
], ],
}, },
}, },
@ -870,11 +899,12 @@ mod tests {
name: "test_function".to_string(), name: "test_function".to_string(),
input: json!({"param": "value"}), input: json!({"param": "value"}),
}, },
} },
], ],
}; };
let content_blocks = convert_bedrock_message_to_anthropic_content(&bedrock_message).unwrap(); let content_blocks =
convert_bedrock_message_to_anthropic_content(&bedrock_message).unwrap();
assert_eq!(content_blocks.len(), 2); assert_eq!(content_blocks.len(), 2);
@ -884,7 +914,10 @@ mod tests {
panic!("Expected text content block"); panic!("Expected text content block");
} }
if let MessagesContentBlock::ToolUse { id, name, input, .. } = &content_blocks[1] { if let MessagesContentBlock::ToolUse {
id, name, input, ..
} = &content_blocks[1]
{
assert_eq!(id, "test_tool"); assert_eq!(id, "test_tool");
assert_eq!(name, "test_function"); assert_eq!(name, "test_function");
assert_eq!(input["param"], "value"); assert_eq!(input["param"], "value");
@ -900,11 +933,9 @@ mod tests {
output: ConverseOutput::Message { output: ConverseOutput::Message {
message: BedrockMessage { message: BedrockMessage {
role: ConversationRole::Assistant, role: ConversationRole::Assistant,
content: vec![ content: vec![ContentBlock::Text {
ContentBlock::Text { text: "I am an assistant".to_string(),
text: "I am an assistant".to_string(), }],
}
],
}, },
}, },
stop_reason: StopReason::EndTurn, stop_reason: StopReason::EndTurn,
@ -928,11 +959,9 @@ mod tests {
output: ConverseOutput::Message { output: ConverseOutput::Message {
message: BedrockMessage { message: BedrockMessage {
role: ConversationRole::User, role: ConversationRole::User,
content: vec![ content: vec![ContentBlock::Text {
ContentBlock::Text { text: "I am a user".to_string(),
text: "I am a user".to_string(), }],
}
],
}, },
}, },
stop_reason: StopReason::EndTurn, stop_reason: StopReason::EndTurn,
@ -985,7 +1014,10 @@ mod tests {
let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap(); let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap();
// Should extract model ID from trace // Should extract model ID from trace
assert_eq!(anthropic_response.model, "anthropic.claude-3-sonnet-20240229-v1:0"); assert_eq!(
anthropic_response.model,
"anthropic.claude-3-sonnet-20240229-v1:0"
);
// Test fallback when no trace information is available // Test fallback when no trace information is available
let bedrock_response_no_trace = ConverseResponse { let bedrock_response_no_trace = ConverseResponse {
@ -1010,7 +1042,8 @@ mod tests {
performance_config: None, performance_config: None,
}; };
let anthropic_response_fallback: MessagesResponse = bedrock_response_no_trace.try_into().unwrap(); let anthropic_response_fallback: MessagesResponse =
bedrock_response_no_trace.try_into().unwrap();
// Should use fallback model name // Should use fallback model name
assert_eq!(anthropic_response_fallback.model, "bedrock-model"); assert_eq!(anthropic_response_fallback.model, "bedrock-model");

View file

@ -1,10 +1,18 @@
use crate::apis::openai::{ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason, ResponseMessage, Role, ToolCallDelta, FunctionCallDelta, Usage, StreamChoice, MessageDelta, MessageContent}; use crate::apis::amazon_bedrock::{
use crate::apis::anthropic::{MessagesResponse, MessagesStreamEvent, MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesUsage}; ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason,
use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason}; };
use crate::apis::anthropic::{
MessagesContentBlock, MessagesContentDelta, MessagesResponse, MessagesStopReason,
MessagesStreamEvent, MessagesUsage,
};
use crate::apis::openai::{
ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason,
FunctionCallDelta, MessageContent, MessageDelta, ResponseMessage, Role, StreamChoice,
ToolCallDelta, Usage,
};
use crate::clients::TransformError; use crate::clients::TransformError;
use crate::transforms::lib::*; use crate::transforms::lib::*;
// ============================================================================ // ============================================================================
// MAIN RESPONSE TRANSFORMATIONS // MAIN RESPONSE TRANSFORMATIONS
// ============================================================================ // ============================================================================
@ -35,7 +43,11 @@ impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
MessageContent::Text(text) => Some(text), MessageContent::Text(text) => Some(text),
MessageContent::Parts(parts) => { MessageContent::Parts(parts) => {
let text = parts.extract_text(); let text = parts.extract_text();
if text.is_empty() { None } else { Some(text) } if text.is_empty() {
None
} else {
Some(text)
}
} }
}; };
@ -105,7 +117,6 @@ impl TryFrom<ConverseResponse> for ChatCompletionsResponse {
StopReason::ContentFiltered => FinishReason::ContentFilter, StopReason::ContentFiltered => FinishReason::ContentFilter,
}; };
// Create response message // Create response message
let response_message = ResponseMessage { let response_message = ResponseMessage {
role, role,
@ -135,13 +146,17 @@ impl TryFrom<ConverseResponse> for ChatCompletionsResponse {
}; };
// Generate a response ID (using timestamp since Bedrock doesn't provide one) // Generate a response ID (using timestamp since Bedrock doesn't provide one)
let id = format!("bedrock-{}", std::time::SystemTime::now() let id = format!(
.duration_since(std::time::UNIX_EPOCH) "bedrock-{}",
.unwrap_or_default() std::time::SystemTime::now()
.as_nanos()); .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos()
);
// Extract model ID from trace information if available, otherwise use fallback // Extract model ID from trace information if available, otherwise use fallback
let model = resp.trace let model = resp
.trace
.as_ref() .as_ref()
.and_then(|trace| trace.prompt_router.as_ref()) .and_then(|trace| trace.prompt_router.as_ref())
.map(|router| router.invoked_model_id.clone()) .map(|router| router.invoked_model_id.clone())
@ -160,7 +175,6 @@ impl TryFrom<ConverseResponse> for ChatCompletionsResponse {
} }
} }
// ============================================================================ // ============================================================================
// STREAMING TRANSFORMATIONS // STREAMING TRANSFORMATIONS
// ============================================================================ // ============================================================================
@ -170,33 +184,27 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
fn try_from(event: MessagesStreamEvent) -> Result<Self, Self::Error> { fn try_from(event: MessagesStreamEvent) -> Result<Self, Self::Error> {
match event { match event {
MessagesStreamEvent::MessageStart { message } => { MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk(
Ok(create_openai_chunk( &message.id,
&message.id, &message.model,
&message.model, MessageDelta {
MessageDelta { role: Some(Role::Assistant),
role: Some(Role::Assistant), content: None,
content: None, refusal: None,
refusal: None, function_call: None,
function_call: None, tool_calls: None,
tool_calls: None, },
}, None,
None, None,
None, )),
))
}
MessagesStreamEvent::ContentBlockStart { content_block, .. } => { MessagesStreamEvent::ContentBlockStart { content_block, .. } => {
convert_content_block_start(content_block) convert_content_block_start(content_block)
} }
MessagesStreamEvent::ContentBlockDelta { delta, .. } => { MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta),
convert_content_delta(delta)
}
MessagesStreamEvent::ContentBlockStop { .. } => { MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
Ok(create_empty_openai_chunk())
}
MessagesStreamEvent::MessageDelta { delta, usage } => { MessagesStreamEvent::MessageDelta { delta, usage } => {
let finish_reason: Option<FinishReason> = Some(delta.stop_reason.into()); let finish_reason: Option<FinishReason> = Some(delta.stop_reason.into());
@ -217,39 +225,34 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
)) ))
} }
MessagesStreamEvent::MessageStop => { MessagesStreamEvent::MessageStop => Ok(create_openai_chunk(
Ok(create_openai_chunk( "stream",
"stream", "unknown",
"unknown", MessageDelta {
MessageDelta { role: None,
role: None, content: None,
content: None, refusal: None,
refusal: None, function_call: None,
function_call: None, tool_calls: None,
tool_calls: None, },
}, Some(FinishReason::Stop),
Some(FinishReason::Stop), None,
None, )),
))
}
MessagesStreamEvent::Ping => { MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse {
Ok(ChatCompletionsStreamResponse { id: "stream".to_string(),
id: "stream".to_string(), object: Some("chat.completion.chunk".to_string()),
object: Some("chat.completion.chunk".to_string()), created: current_timestamp(),
created: current_timestamp(), model: "unknown".to_string(),
model: "unknown".to_string(), choices: vec![],
choices: vec![], usage: None,
usage: None, system_fingerprint: None,
system_fingerprint: None, service_tier: None,
service_tier: None, }),
})
}
} }
} }
} }
impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse { impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
type Error = TransformError; type Error = TransformError;
@ -280,29 +283,27 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
use crate::apis::amazon_bedrock::ContentBlockStart; use crate::apis::amazon_bedrock::ContentBlockStart;
match start_event.start { match start_event.start {
ContentBlockStart::ToolUse { tool_use } => { ContentBlockStart::ToolUse { tool_use } => Ok(create_openai_chunk(
Ok(create_openai_chunk( "stream",
"stream", "unknown",
"unknown", MessageDelta {
MessageDelta { role: None,
role: None, content: None,
content: None, refusal: None,
refusal: None, function_call: None,
function_call: None, tool_calls: Some(vec![ToolCallDelta {
tool_calls: Some(vec![ToolCallDelta { index: start_event.content_block_index as u32,
index: start_event.content_block_index as u32, id: Some(tool_use.tool_use_id),
id: Some(tool_use.tool_use_id), call_type: Some("function".to_string()),
call_type: Some("function".to_string()), function: Some(FunctionCallDelta {
function: Some(FunctionCallDelta { name: Some(tool_use.name),
name: Some(tool_use.name), arguments: Some("".to_string()),
arguments: Some("".to_string()), }),
}), }]),
}]), },
}, None,
None, None,
None, )),
))
}
} }
} }
@ -310,50 +311,44 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
use crate::apis::amazon_bedrock::ContentBlockDelta; use crate::apis::amazon_bedrock::ContentBlockDelta;
match delta_event.delta { match delta_event.delta {
ContentBlockDelta::Text { text } => { ContentBlockDelta::Text { text } => Ok(create_openai_chunk(
Ok(create_openai_chunk( "stream",
"stream", "unknown",
"unknown", MessageDelta {
MessageDelta { role: None,
role: None, content: Some(text),
content: Some(text), refusal: None,
refusal: None, function_call: None,
function_call: None, tool_calls: None,
tool_calls: None, },
}, None,
None, None,
None, )),
)) ContentBlockDelta::ToolUse { tool_use } => Ok(create_openai_chunk(
} "stream",
ContentBlockDelta::ToolUse { tool_use } => { "unknown",
Ok(create_openai_chunk( MessageDelta {
"stream", role: None,
"unknown", content: None,
MessageDelta { refusal: None,
role: None, function_call: None,
content: None, tool_calls: Some(vec![ToolCallDelta {
refusal: None, index: delta_event.content_block_index as u32,
function_call: None, id: None,
tool_calls: Some(vec![ToolCallDelta { call_type: None,
index: delta_event.content_block_index as u32, function: Some(FunctionCallDelta {
id: None, name: None,
call_type: None, arguments: Some(tool_use.input),
function: Some(FunctionCallDelta { }),
name: None, }]),
arguments: Some(tool_use.input), },
}), None,
}]), None,
}, )),
None,
None,
))
}
} }
} }
ConverseStreamEvent::ContentBlockStop(_) => { ConverseStreamEvent::ContentBlockStop(_) => Ok(create_empty_openai_chunk()),
Ok(create_empty_openai_chunk())
}
ConverseStreamEvent::MessageStop(stop_event) => { ConverseStreamEvent::MessageStop(stop_event) => {
let finish_reason = match stop_event.stop_reason { let finish_reason = match stop_event.stop_reason {
@ -405,27 +400,27 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
} }
// Error events - convert to empty chunks (errors should be handled elsewhere) // Error events - convert to empty chunks (errors should be handled elsewhere)
ConverseStreamEvent::InternalServerException(_) | ConverseStreamEvent::InternalServerException(_)
ConverseStreamEvent::ModelStreamErrorException(_) | | ConverseStreamEvent::ModelStreamErrorException(_)
ConverseStreamEvent::ServiceUnavailableException(_) | | ConverseStreamEvent::ServiceUnavailableException(_)
ConverseStreamEvent::ThrottlingException(_) | | ConverseStreamEvent::ThrottlingException(_)
ConverseStreamEvent::ValidationException(_) => { | ConverseStreamEvent::ValidationException(_) => Ok(create_empty_openai_chunk()),
Ok(create_empty_openai_chunk())
}
} }
} }
} }
/// Convert content block start to OpenAI chunk /// Convert content block start to OpenAI chunk
fn convert_content_block_start(content_block: MessagesContentBlock) -> Result<ChatCompletionsStreamResponse, TransformError> { fn convert_content_block_start(
content_block: MessagesContentBlock,
) -> Result<ChatCompletionsStreamResponse, TransformError> {
match content_block { match content_block {
MessagesContentBlock::Text { .. } => { MessagesContentBlock::Text { .. } => {
// No immediate output for text block start // No immediate output for text block start
Ok(create_empty_openai_chunk()) Ok(create_empty_openai_chunk())
} }
MessagesContentBlock::ToolUse { id, name, .. } | MessagesContentBlock::ToolUse { id, name, .. }
MessagesContentBlock::ServerToolUse { id, name, .. } | | MessagesContentBlock::ServerToolUse { id, name, .. }
MessagesContentBlock::McpToolUse { id, name, .. } => { | MessagesContentBlock::McpToolUse { id, name, .. } => {
// Tool use start → OpenAI chunk with tool_calls // Tool use start → OpenAI chunk with tool_calls
Ok(create_openai_chunk( Ok(create_openai_chunk(
"stream", "stream",
@ -449,66 +444,64 @@ fn convert_content_block_start(content_block: MessagesContentBlock) -> Result<Ch
None, None,
)) ))
} }
_ => Err(TransformError::UnsupportedContent("Unsupported content block type in stream start".to_string())), _ => Err(TransformError::UnsupportedContent(
"Unsupported content block type in stream start".to_string(),
)),
} }
} }
/// Convert content delta to OpenAI chunk /// Convert content delta to OpenAI chunk
fn convert_content_delta(delta: MessagesContentDelta) -> Result<ChatCompletionsStreamResponse, TransformError> { fn convert_content_delta(
delta: MessagesContentDelta,
) -> Result<ChatCompletionsStreamResponse, TransformError> {
match delta { match delta {
MessagesContentDelta::TextDelta { text } => { MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk(
Ok(create_openai_chunk( "stream",
"stream", "unknown",
"unknown", MessageDelta {
MessageDelta { role: None,
role: None, content: Some(text),
content: Some(text), refusal: None,
refusal: None, function_call: None,
function_call: None, tool_calls: None,
tool_calls: None, },
}, None,
None, None,
None, )),
)) MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk(
} "stream",
MessagesContentDelta::ThinkingDelta { thinking } => { "unknown",
Ok(create_openai_chunk( MessageDelta {
"stream", role: None,
"unknown", content: Some(format!("thinking: {}", thinking)),
MessageDelta { refusal: None,
role: None, function_call: None,
content: Some(format!("thinking: {}", thinking)), tool_calls: None,
refusal: None, },
function_call: None, None,
tool_calls: None, None,
}, )),
None, MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk(
None, "stream",
)) "unknown",
} MessageDelta {
MessagesContentDelta::InputJsonDelta { partial_json } => { role: None,
Ok(create_openai_chunk( content: None,
"stream", refusal: None,
"unknown", function_call: None,
MessageDelta { tool_calls: Some(vec![ToolCallDelta {
role: None, index: 0,
content: None, id: None,
refusal: None, call_type: None,
function_call: None, function: Some(FunctionCallDelta {
tool_calls: Some(vec![ToolCallDelta { name: None,
index: 0, arguments: Some(partial_json),
id: None, }),
call_type: None, }]),
function: Some(FunctionCallDelta { },
name: None, None,
arguments: Some(partial_json), None,
}), )),
}]),
},
None,
None,
))
}
} }
} }
@ -518,7 +511,7 @@ fn create_openai_chunk(
model: &str, model: &str,
delta: MessageDelta, delta: MessageDelta,
finish_reason: Option<FinishReason>, finish_reason: Option<FinishReason>,
usage: Option<Usage> usage: Option<Usage>,
) -> ChatCompletionsStreamResponse { ) -> ChatCompletionsStreamResponse {
ChatCompletionsStreamResponse { ChatCompletionsStreamResponse {
id: id.to_string(), id: id.to_string(),
@ -555,7 +548,9 @@ fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse {
} }
/// Convert Anthropic content blocks to OpenAI message content /// Convert Anthropic content blocks to OpenAI message content
fn convert_anthropic_content_to_openai(content: &[MessagesContentBlock]) -> Result<MessageContent, TransformError> { fn convert_anthropic_content_to_openai(
content: &[MessagesContentBlock],
) -> Result<MessageContent, TransformError> {
let mut text_parts = Vec::new(); let mut text_parts = Vec::new();
for block in content { for block in content {
@ -592,9 +587,11 @@ impl Into<FinishReason> for MessagesStopReason {
/// Convert Bedrock Message to OpenAI content and tool calls /// Convert Bedrock Message to OpenAI content and tool calls
/// This function extracts text content and tool calls from a Bedrock message /// This function extracts text content and tool calls from a Bedrock message
fn convert_bedrock_message_to_openai(message: &crate::apis::amazon_bedrock::Message) -> Result<(Option<String>, Option<Vec<crate::apis::openai::ToolCall>>), TransformError> { fn convert_bedrock_message_to_openai(
message: &crate::apis::amazon_bedrock::Message,
) -> Result<(Option<String>, Option<Vec<crate::apis::openai::ToolCall>>), TransformError> {
use crate::apis::amazon_bedrock::ContentBlock; use crate::apis::amazon_bedrock::ContentBlock;
use crate::apis::openai::{ToolCall, FunctionCall}; use crate::apis::openai::{FunctionCall, ToolCall};
let mut text_content = String::new(); let mut text_content = String::new();
let mut tool_calls = Vec::new(); let mut tool_calls = Vec::new();
@ -614,12 +611,20 @@ fn convert_bedrock_message_to_openai(message: &crate::apis::amazon_bedrock::Mess
}, },
}); });
} }
_ => continue, _ => continue,
} }
} }
let content = if text_content.is_empty() { None } else { Some(text_content) }; let content = if text_content.is_empty() {
let tool_calls = if tool_calls.is_empty() { None } else { Some(tool_calls) }; None
} else {
Some(text_content)
};
let tool_calls = if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
};
Ok((content, tool_calls)) Ok((content, tool_calls))
} }
@ -628,8 +633,8 @@ fn convert_bedrock_message_to_openai(message: &crate::apis::amazon_bedrock::Mess
mod tests { mod tests {
use super::*; use super::*;
use crate::apis::amazon_bedrock::{ use crate::apis::amazon_bedrock::{
ConverseResponse, ConverseOutput, Message as BedrockMessage, ConversationRole, BedrockTokenUsage, ContentBlock, ConversationRole, ConverseOutput, ConverseResponse,
ContentBlock, StopReason, BedrockTokenUsage, ConverseTrace, PromptRouterTrace ConverseTrace, Message as BedrockMessage, PromptRouterTrace, StopReason,
}; };
use crate::apis::openai::{ChatCompletionsResponse, FinishReason, Role}; use crate::apis::openai::{ChatCompletionsResponse, FinishReason, Role};
use serde_json::json; use serde_json::json;
@ -640,11 +645,9 @@ mod tests {
output: ConverseOutput::Message { output: ConverseOutput::Message {
message: BedrockMessage { message: BedrockMessage {
role: ConversationRole::Assistant, role: ConversationRole::Assistant,
content: vec![ content: vec![ContentBlock::Text {
ContentBlock::Text { text: "Hello! How can I help you today?".to_string(),
text: "Hello! How can I help you today?".to_string(), }],
}
],
}, },
}, },
stop_reason: StopReason::EndTurn, stop_reason: StopReason::EndTurn,
@ -671,7 +674,10 @@ mod tests {
let choice = &openai_response.choices[0]; let choice = &openai_response.choices[0];
assert_eq!(choice.index, 0); assert_eq!(choice.index, 0);
assert_eq!(choice.message.role, Role::Assistant); assert_eq!(choice.message.role, Role::Assistant);
assert_eq!(choice.message.content, Some("Hello! How can I help you today?".to_string())); assert_eq!(
choice.message.content,
Some("Hello! How can I help you today?".to_string())
);
assert_eq!(choice.finish_reason, Some(FinishReason::Stop)); assert_eq!(choice.finish_reason, Some(FinishReason::Stop));
assert!(choice.message.tool_calls.is_none()); assert!(choice.message.tool_calls.is_none());
@ -699,7 +705,7 @@ mod tests {
"location": "San Francisco" "location": "San Francisco"
}), }),
}, },
} },
], ],
}, },
}, },
@ -718,11 +724,21 @@ mod tests {
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
assert_eq!(openai_response.choices[0].finish_reason, Some(FinishReason::ToolCalls)); assert_eq!(
assert_eq!(openai_response.choices[0].message.content, Some("I'll help you check the weather.".to_string())); openai_response.choices[0].finish_reason,
Some(FinishReason::ToolCalls)
);
assert_eq!(
openai_response.choices[0].message.content,
Some("I'll help you check the weather.".to_string())
);
// Check tool calls // Check tool calls
let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap(); let tool_calls = openai_response.choices[0]
.message
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls.len(), 1);
let tool_call = &tool_calls[0]; let tool_call = &tool_calls[0];
@ -750,11 +766,9 @@ mod tests {
output: ConverseOutput::Message { output: ConverseOutput::Message {
message: BedrockMessage { message: BedrockMessage {
role: ConversationRole::Assistant, role: ConversationRole::Assistant,
content: vec![ content: vec![ContentBlock::Text {
ContentBlock::Text { text: "Test response".to_string(),
text: "Test response".to_string(), }],
}
],
}, },
}, },
stop_reason: bedrock_stop_reason, stop_reason: bedrock_stop_reason,
@ -771,7 +785,10 @@ mod tests {
}; };
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
assert_eq!(openai_response.choices[0].finish_reason, Some(expected_openai_finish_reason)); assert_eq!(
openai_response.choices[0].finish_reason,
Some(expected_openai_finish_reason)
);
} }
} }
@ -798,7 +815,7 @@ mod tests {
name: "lookup".to_string(), name: "lookup".to_string(),
input: json!({"id": "12345"}), input: json!({"id": "12345"}),
}, },
} },
], ],
}, },
}, },
@ -817,23 +834,35 @@ mod tests {
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
assert_eq!(openai_response.choices[0].finish_reason, Some(FinishReason::ToolCalls)); assert_eq!(
assert_eq!(openai_response.choices[0].message.content, Some("I'll help with multiple tasks.".to_string())); openai_response.choices[0].finish_reason,
Some(FinishReason::ToolCalls)
);
assert_eq!(
openai_response.choices[0].message.content,
Some("I'll help with multiple tasks.".to_string())
);
// Check multiple tool calls // Check multiple tool calls
let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap(); let tool_calls = openai_response.choices[0]
.message
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tool_calls.len(), 2); assert_eq!(tool_calls.len(), 2);
// First tool call // First tool call
assert_eq!(tool_calls[0].id, "tool_1"); assert_eq!(tool_calls[0].id, "tool_1");
assert_eq!(tool_calls[0].function.name, "search"); assert_eq!(tool_calls[0].function.name, "search");
let args1: serde_json::Value = serde_json::from_str(&tool_calls[0].function.arguments).unwrap(); let args1: serde_json::Value =
serde_json::from_str(&tool_calls[0].function.arguments).unwrap();
assert_eq!(args1["query"], "weather"); assert_eq!(args1["query"], "weather");
// Second tool call // Second tool call
assert_eq!(tool_calls[1].id, "tool_2"); assert_eq!(tool_calls[1].id, "tool_2");
assert_eq!(tool_calls[1].function.name, "lookup"); assert_eq!(tool_calls[1].function.name, "lookup");
let args2: serde_json::Value = serde_json::from_str(&tool_calls[1].function.arguments).unwrap(); let args2: serde_json::Value =
serde_json::from_str(&tool_calls[1].function.arguments).unwrap();
assert_eq!(args2["id"], "12345"); assert_eq!(args2["id"], "12345");
} }
@ -856,7 +885,7 @@ mod tests {
}, },
ContentBlock::Text { ContentBlock::Text {
text: "Second part.".to_string(), text: "Second part.".to_string(),
} },
], ],
}, },
}, },
@ -876,10 +905,17 @@ mod tests {
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
// Content should be combined text parts (no separator added) // Content should be combined text parts (no separator added)
assert_eq!(openai_response.choices[0].message.content, Some("First part. Second part.".to_string())); assert_eq!(
openai_response.choices[0].message.content,
Some("First part. Second part.".to_string())
);
// Should have one tool call // Should have one tool call
let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap(); let tool_calls = openai_response.choices[0]
.message
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id, "tool_mid"); assert_eq!(tool_calls[0].id, "tool_mid");
assert_eq!(tool_calls[0].function.name, "calculate"); assert_eq!(tool_calls[0].function.name, "calculate");
@ -891,15 +927,13 @@ mod tests {
output: ConverseOutput::Message { output: ConverseOutput::Message {
message: BedrockMessage { message: BedrockMessage {
role: ConversationRole::Assistant, role: ConversationRole::Assistant,
content: vec![ content: vec![ContentBlock::ToolUse {
ContentBlock::ToolUse { tool_use: crate::apis::amazon_bedrock::ToolUseBlock {
tool_use: crate::apis::amazon_bedrock::ToolUseBlock { tool_use_id: "tool_only".to_string(),
tool_use_id: "tool_only".to_string(), name: "action".to_string(),
name: "action".to_string(), input: json!({}),
input: json!({}), },
}, }],
}
],
}, },
}, },
stop_reason: StopReason::ToolUse, stop_reason: StopReason::ToolUse,
@ -921,7 +955,11 @@ mod tests {
assert_eq!(openai_response.choices[0].message.content, None); assert_eq!(openai_response.choices[0].message.content, None);
// Should have tool call // Should have tool call
let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap(); let tool_calls = openai_response.choices[0]
.message
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id, "tool_only"); assert_eq!(tool_calls[0].id, "tool_only");
} }
@ -940,7 +978,7 @@ mod tests {
name: "test_function".to_string(), name: "test_function".to_string(),
input: json!({"param": "value"}), input: json!({"param": "value"}),
}, },
} },
], ],
}; };
@ -953,7 +991,8 @@ mod tests {
assert_eq!(tool_calls[0].id, "test_tool"); assert_eq!(tool_calls[0].id, "test_tool");
assert_eq!(tool_calls[0].function.name, "test_function"); assert_eq!(tool_calls[0].function.name, "test_function");
let args: serde_json::Value = serde_json::from_str(&tool_calls[0].function.arguments).unwrap(); let args: serde_json::Value =
serde_json::from_str(&tool_calls[0].function.arguments).unwrap();
assert_eq!(args["param"], "value"); assert_eq!(args["param"], "value");
} }
@ -964,11 +1003,9 @@ mod tests {
output: ConverseOutput::Message { output: ConverseOutput::Message {
message: BedrockMessage { message: BedrockMessage {
role: ConversationRole::Assistant, role: ConversationRole::Assistant,
content: vec![ content: vec![ContentBlock::Text {
ContentBlock::Text { text: "I am an assistant".to_string(),
text: "I am an assistant".to_string(), }],
}
],
}, },
}, },
stop_reason: StopReason::EndTurn, stop_reason: StopReason::EndTurn,
@ -992,11 +1029,9 @@ mod tests {
output: ConverseOutput::Message { output: ConverseOutput::Message {
message: BedrockMessage { message: BedrockMessage {
role: ConversationRole::User, role: ConversationRole::User,
content: vec![ content: vec![ContentBlock::Text {
ContentBlock::Text { text: "I am a user".to_string(),
text: "I am a user".to_string(), }],
}
],
}, },
}, },
stop_reason: StopReason::EndTurn, stop_reason: StopReason::EndTurn,
@ -1049,7 +1084,10 @@ mod tests {
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
// Should extract model ID from trace // Should extract model ID from trace
assert_eq!(openai_response.model, "anthropic.claude-3-sonnet-20240229-v1:0"); assert_eq!(
openai_response.model,
"anthropic.claude-3-sonnet-20240229-v1:0"
);
// Test fallback when no trace information is available // Test fallback when no trace information is available
let bedrock_response_no_trace = ConverseResponse { let bedrock_response_no_trace = ConverseResponse {
@ -1074,7 +1112,8 @@ mod tests {
performance_config: None, performance_config: None,
}; };
let openai_response_fallback: ChatCompletionsResponse = bedrock_response_no_trace.try_into().unwrap(); let openai_response_fallback: ChatCompletionsResponse =
bedrock_response_no_trace.try_into().unwrap();
// Should use fallback model name // Should use fallback model name
assert_eq!(openai_response_fallback.model, "bedrock-model"); assert_eq!(openai_response_fallback.model, "bedrock-model");
@ -1082,7 +1121,7 @@ mod tests {
#[test] #[test]
fn test_bedrock_to_openai_with_multimedia_content() { fn test_bedrock_to_openai_with_multimedia_content() {
use crate::apis::amazon_bedrock::{ImageSource}; use crate::apis::amazon_bedrock::ImageSource;
let bedrock_response = ConverseResponse { let bedrock_response = ConverseResponse {
output: ConverseOutput::Message { output: ConverseOutput::Message {
@ -1118,7 +1157,10 @@ mod tests {
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap(); let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
assert_eq!(openai_response.choices[0].finish_reason, Some(FinishReason::Stop)); assert_eq!(
openai_response.choices[0].finish_reason,
Some(FinishReason::Stop)
);
let content = openai_response.choices[0].message.content.as_ref().unwrap(); let content = openai_response.choices[0].message.content.as_ref().unwrap();