cargo fmt

This commit is contained in:
Adil Hafeez 2025-10-21 16:31:29 -07:00
parent aec052a843
commit d35d068d0d
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
25 changed files with 1978 additions and 1258 deletions

View file

@ -1,17 +1,17 @@
use std::sync::Arc;
use std::collections::HashMap;
use bytes::Bytes;
use common::configuration::{ModelAlias, ModelUsagePreference};
use common::consts::{ARCH_PROVIDER_HINT_HEADER, ARCH_IS_STREAMING_HEADER};
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
use hermesllm::apis::openai::ChatCompletionsRequest;
use hermesllm::clients::SupportedAPIs;
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use hermesllm::clients::SupportedAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
@ -31,14 +31,19 @@ pub async fn chat(
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let mut request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
debug!(
"Received request body (raw utf8): {}",
String::from_utf8_lossy(&chat_request_bytes)
);
let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
let mut client_request = match ProviderRequestType::try_from((
&chat_request_bytes[..],
&SupportedAPIs::from_endpoint(request_path.as_str()).unwrap(),
)) {
Ok(request) => request,
Err(err) => {
warn!("Failed to parse request as ProviderRequestType: {}", err);
@ -79,18 +84,30 @@ pub async fn chat(
// Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original)
let chat_completions_request_for_arch_router: ChatCompletionsRequest =
match ProviderRequestType::try_from((client_request, &SupportedUpstreamAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) {
match ProviderRequestType::try_from((
client_request,
&SupportedUpstreamAPIs::OpenAIChatCompletions(
hermesllm::apis::OpenAIApi::ChatCompletions,
),
)) {
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
Ok(ProviderRequestType::MessagesRequest(_) | ProviderRequestType::BedrockConverse(_) | ProviderRequestType::BedrockConverseStream(_)) => {
Ok(
ProviderRequestType::MessagesRequest(_)
| ProviderRequestType::BedrockConverse(_)
| ProviderRequestType::BedrockConverseStream(_),
) => {
// This should not happen after conversion to OpenAI format
warn!("Unexpected: got MessagesRequest after converting to OpenAI format");
let err_msg = "Request conversion failed".to_string();
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
},
}
Err(err) => {
warn!("Failed to convert request to ChatCompletionsRequest: {}", err);
warn!(
"Failed to convert request to ChatCompletionsRequest: {}",
err
);
let err_msg = format!("Failed to convert request: {}", err);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
@ -108,28 +125,29 @@ pub async fn chat(
.find(|(ty, _)| ty.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
let usage_preferences_str: Option<String> =
routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("archgw_preference_config")
.map(|value| value.to_string())
});
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("archgw_preference_config")
.map(|value| value.to_string())
});
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok());
let latest_message_for_log =
chat_completions_request_for_arch_router
.messages
.last()
.map_or("None".to_string(), |msg| {
msg.content.to_string().replace('\n', "\\n")
});
let latest_message_for_log = chat_completions_request_for_arch_router
.messages
.last()
.map_or("None".to_string(), |msg| {
msg.content.to_string().replace('\n', "\\n")
});
const MAX_MESSAGE_LENGTH: usize = 50;
let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
let truncated: String = latest_message_for_log.chars().take(MAX_MESSAGE_LENGTH).collect();
let truncated: String = latest_message_for_log
.chars()
.take(MAX_MESSAGE_LENGTH)
.collect();
format!("{}...", truncated)
} else {
latest_message_for_log
@ -155,12 +173,11 @@ pub async fn chat(
Ok(route) => match route {
Some((_, model_name)) => model_name,
None => {
info!(
info!(
"No route determined, using default model from request: {}",
chat_completions_request_for_arch_router.model
);
chat_completions_request_for_arch_router.model.clone()
}
},
Err(err) => {

View file

@ -68,8 +68,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
&serde_json::to_string(arch_config.as_ref()).unwrap()
);
let llm_provider_url = env::var("LLM_PROVIDER_ENDPOINT")
.unwrap_or_else(|_| "http://localhost:12001".to_string());
let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
info!("llm provider url: {}", llm_provider_url);
info!("listening on http://{}", bind_address);
@ -96,7 +96,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let model_aliases = Arc::new(arch_config.model_aliases.clone());
loop {
let (stream, _) = listener.accept().await?;
let peer_addr = stream.peer_addr()?;
@ -108,7 +107,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone();
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
let parent_cx = extract_context_from_request(&req);
let llm_provider_url = llm_provider_url.clone();
@ -118,7 +116,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
async move {
match (req.method(), req.uri().path()) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
let fully_qualified_url =
format!("{}{}", llm_provider_url, req.uri().path());
chat(req, router_service, fully_qualified_url, model_aliases)
.with_context(parent_cx)
.await

View file

@ -1,9 +1,7 @@
use std::collections::HashMap;
use common::{
configuration::{ModelUsagePreference, RoutingPreference},
};
use hermesllm::apis::openai::{ChatCompletionsRequest, MessageContent, Message, Role};
use common::configuration::{ModelUsagePreference, RoutingPreference};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};

View file

@ -1,20 +1,27 @@
use std::sync::OnceLock;
use std::fmt;
use std::sync::OnceLock;
use opentelemetry::global;
use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider};
use opentelemetry_stdout::SpanExporter;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::fmt::{format, time::FormatTime, FmtContext, FormatEvent, FormatFields};
use tracing::{Event, Subscriber};
use time::macros::format_description;
use tracing::{Event, Subscriber};
use tracing_subscriber::fmt::{format, time::FormatTime, FmtContext, FormatEvent, FormatFields};
use tracing_subscriber::EnvFilter;
struct BracketedTime;
impl FormatTime for BracketedTime {
fn format_time(&self, w: &mut format::Writer<'_>) -> fmt::Result {
let now = time::OffsetDateTime::now_utc();
write!(w, "[{}]", now.format(&format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]")).unwrap())
write!(
w,
"[{}]",
now.format(&format_description!(
"[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]"
))
.unwrap()
)
}
}
@ -34,7 +41,11 @@ where
let timer = BracketedTime;
timer.format_time(&mut writer)?;
write!(writer, "[{}] ", event.metadata().level().to_string().to_lowercase())?;
write!(
writer,
"[{}] ",
event.metadata().level().to_string().to_lowercase()
)?;
ctx.field_format().format_fields(writer.by_ref(), event)?;

View file

@ -33,7 +33,6 @@ pub fn get_llm_provider(
return provider;
}
if llm_providers.default().is_some() {
return llm_providers.default().unwrap();
}

View file

@ -2,8 +2,8 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::skip_serializing_none;
use thiserror::Error;
use std::collections::HashMap;
use thiserror::Error;
use super::ApiDefinition;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
@ -56,10 +56,7 @@ impl ApiDefinition for AmazonBedrockApi {
}
fn all_variants() -> Vec<Self> {
vec![
AmazonBedrockApi::Converse,
AmazonBedrockApi::ConverseStream,
]
vec![AmazonBedrockApi::Converse, AmazonBedrockApi::ConverseStream]
}
}
@ -173,7 +170,9 @@ impl ProviderRequest for ConverseRequest {
SystemContentBlock::Text { text } => {
text_parts.push(text.clone());
}
SystemContentBlock::GuardContent { text: Some(guard_text) } => {
SystemContentBlock::GuardContent {
text: Some(guard_text),
} => {
text_parts.push(guard_text.text.clone());
}
SystemContentBlock::GuardContent { text: None } => {
@ -194,11 +193,9 @@ impl ProviderRequest for ConverseRequest {
.find(|msg| msg.role == ConversationRole::User)
.and_then(|msg| {
// Extract the first text content block from the user message
msg.content.iter().find_map(|content| {
match content {
ContentBlock::Text { text } => Some(text.clone()),
_ => None,
}
msg.content.iter().find_map(|content| match content {
ContentBlock::Text { text } => Some(text.clone()),
_ => None,
})
})
}
@ -294,13 +291,13 @@ pub enum ConversationRole {
#[serde(untagged)]
pub enum ContentBlock {
Text {
text: String
text: String,
},
Image {
image: ImageBlock
image: ImageBlock,
},
Document {
document: DocumentBlock
document: DocumentBlock,
},
ToolUse {
#[serde(rename = "toolUse")]
@ -360,9 +357,7 @@ pub enum SystemContentBlock {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "guardContent")]
GuardContent {
text: Option<GuardContentText>,
},
GuardContent { text: Option<GuardContentText> },
}
/// Image source for vision capabilities
@ -723,10 +718,12 @@ pub struct ContentBlockDeltaEvent {
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum ContentBlockDelta {
Text { text: String },
Text {
text: String,
},
ToolUse {
#[serde(rename = "toolUse")]
tool_use: ToolUseDelta
tool_use: ToolUseDelta,
},
}
@ -1008,7 +1005,9 @@ impl Into<String> for ConverseStreamEvent {
ConverseStreamEvent::Metadata { .. } => "metadata",
ConverseStreamEvent::InternalServerException { .. } => "internal_server_exception",
ConverseStreamEvent::ModelStreamErrorException { .. } => "model_stream_error_exception",
ConverseStreamEvent::ServiceUnavailableException { .. } => "service_unavailable_exception",
ConverseStreamEvent::ServiceUnavailableException { .. } => {
"service_unavailable_exception"
}
ConverseStreamEvent::ThrottlingException { .. } => "throttling_exception",
ConverseStreamEvent::ValidationException { .. } => "validation_exception",
};
@ -1019,17 +1018,14 @@ impl Into<String> for ConverseStreamEvent {
}
}
// Implement ProviderStreamResponse for ConverseStreamEvent
impl ProviderStreamResponse for ConverseStreamEvent {
fn content_delta(&self) -> Option<&str> {
match self {
ConverseStreamEvent::ContentBlockDelta(event) => {
match &event.delta {
ContentBlockDelta::Text { text } => Some(text),
ContentBlockDelta::ToolUse { .. } => None,
}
}
ConverseStreamEvent::ContentBlockDelta(event) => match &event.delta {
ContentBlockDelta::Text { text } => Some(text),
ContentBlockDelta::ToolUse { .. } => None,
},
_ => None,
}
}
@ -1099,7 +1095,10 @@ mod tests {
};
let serialized = serde_json::to_value(&tool).unwrap();
println!("Tool serialization: {}", serde_json::to_string_pretty(&serialized).unwrap());
println!(
"Tool serialization: {}",
serde_json::to_string_pretty(&serialized).unwrap()
);
// Verify the structure matches Bedrock API expectations
assert!(serialized.get("toolSpec").is_some());
@ -1107,16 +1106,24 @@ mod tests {
let tool_spec = serialized.get("toolSpec").unwrap();
assert_eq!(tool_spec.get("name").unwrap(), "get_weather");
assert_eq!(tool_spec.get("description").unwrap(), "Get the current weather for a specified city");
assert_eq!(
tool_spec.get("description").unwrap(),
"Get the current weather for a specified city"
);
assert!(tool_spec.get("inputSchema").is_some());
}
#[test]
fn test_tool_choice_serialization_format() {
// Test Auto choice
let auto_choice = ToolChoice::Auto { auto: AutoChoice {} };
let auto_choice = ToolChoice::Auto {
auto: AutoChoice {},
};
let serialized = serde_json::to_value(&auto_choice).unwrap();
println!("Auto ToolChoice serialization: {}", serde_json::to_string_pretty(&serialized).unwrap());
println!(
"Auto ToolChoice serialization: {}",
serde_json::to_string_pretty(&serialized).unwrap()
);
assert!(serialized.get("auto").is_some());
assert!(serialized.get("type").is_none()); // Should not have a type field
@ -1124,11 +1131,14 @@ mod tests {
// Test Tool choice
let tool_choice = ToolChoice::Tool {
tool: ToolChoiceSpec {
name: "get_weather".to_string()
}
name: "get_weather".to_string(),
},
};
let serialized = serde_json::to_value(&tool_choice).unwrap();
println!("Tool ToolChoice serialization: {}", serde_json::to_string_pretty(&serialized).unwrap());
println!(
"Tool ToolChoice serialization: {}",
serde_json::to_string_pretty(&serialized).unwrap()
);
assert!(serialized.get("tool").is_some());
assert!(serialized.get("type").is_none()); // Should not have a type field

View file

@ -1,7 +1,7 @@
use std::collections::HashSet;
use bytes::Buf;
use aws_smithy_eventstream::frame::DecodedFrame;
use aws_smithy_eventstream::frame::MessageFrameDecoder;
use bytes::Buf;
use std::collections::HashSet;
/// AWS Event Stream frame decoder wrapper
pub struct BedrockBinaryFrameDecoder<B>

View file

@ -8,7 +8,7 @@ use super::ApiDefinition;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
use crate::transforms::lib::ExtractText;
use crate::{MESSAGES_PATH};
use crate::MESSAGES_PATH;
// Enum for all supported Anthropic APIs
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@ -52,9 +52,7 @@ impl ApiDefinition for AnthropicApi {
}
fn all_variants() -> Vec<Self> {
vec![
AnthropicApi::Messages,
]
vec![AnthropicApi::Messages]
}
}
@ -100,7 +98,6 @@ pub struct McpServer {
pub tool_configuration: Option<McpToolConfiguration>,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesRequest {
@ -121,10 +118,8 @@ pub struct MessagesRequest {
pub stop_sequences: Option<Vec<String>>,
pub tools: Option<Vec<MessagesTool>>,
pub tool_choice: Option<MessagesToolChoice>,
}
// Messages API specific types
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
@ -235,34 +230,21 @@ impl ExtractText for Vec<MessagesContentBlock> {
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum MessagesImageSource {
Base64 {
media_type: String,
data: String,
},
Url {
url: String,
},
Base64 { media_type: String, data: String },
Url { url: String },
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum MessagesDocumentSource {
Base64 {
media_type: String,
data: String,
},
Url {
url: String,
},
File {
file_id: String,
},
Base64 { media_type: String, data: String },
Url { url: String },
File { file_id: String },
}
#[derive(Serialize, Deserialize, Debug, Clone)]
@ -276,7 +258,7 @@ impl ExtractText for MessagesMessageContent {
fn extract_text(&self) -> String {
match self {
MessagesMessageContent::Single(text) => text.clone(),
MessagesMessageContent::Blocks(parts) => parts.extract_text()
MessagesMessageContent::Blocks(parts) => parts.extract_text(),
}
}
}
@ -320,7 +302,6 @@ pub struct MessagesToolChoice {
pub disable_parallel_tool_use: Option<bool>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum MessagesStopReason {
@ -457,7 +438,11 @@ impl ProviderResponse for MessagesResponse {
Some(self)
}
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
Some((self.usage.input_tokens as usize, self.usage.output_tokens as usize, (self.usage.input_tokens + self.usage.output_tokens) as usize))
Some((
self.usage.input_tokens as usize,
self.usage.output_tokens as usize,
(self.usage.input_tokens + self.usage.output_tokens) as usize,
))
}
}
@ -535,7 +520,7 @@ impl ProviderRequest for MessagesRequest {
}
fn metadata(&self) -> &Option<HashMap<String, Value>> {
return &self.metadata;
return &self.metadata;
}
fn remove_metadata_key(&mut self, key: &str) -> bool {
@ -572,13 +557,11 @@ impl MessagesRole {
impl ProviderStreamResponse for MessagesStreamEvent {
fn content_delta(&self) -> Option<&str> {
match self {
MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
match delta {
MessagesContentDelta::TextDelta { text } => Some(text),
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
_ => None,
}
}
MessagesStreamEvent::ContentBlockDelta { delta, .. } => match delta {
MessagesContentDelta::TextDelta { text } => Some(text),
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
_ => None,
},
_ => None,
}
}
@ -627,7 +610,8 @@ mod tests {
});
// Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields are properly set
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
@ -687,7 +671,8 @@ mod tests {
});
// Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
@ -730,7 +715,10 @@ mod tests {
assert_eq!(serialized_json["messages"], original_json["messages"]);
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
assert_eq!(serialized_json["system"], original_json["system"]);
assert_eq!(serialized_json["service_tier"], original_json["service_tier"]);
assert_eq!(
serialized_json["service_tier"],
original_json["service_tier"]
);
assert_eq!(serialized_json["thinking"], original_json["thinking"]);
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
@ -818,7 +806,8 @@ mod tests {
});
// Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate top-level fields
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
@ -833,7 +822,10 @@ mod tests {
// Validate text content block
if let MessagesContentBlock::Text { text, .. } = &content_blocks[0] {
assert_eq!(text, "What can you see in this image and what's the weather like?");
assert_eq!(
text,
"What can you see in this image and what's the weather like?"
);
} else {
panic!("Expected text content block");
}
@ -861,20 +853,32 @@ mod tests {
// Validate thinking content block
if let MessagesContentBlock::Thinking { thinking, .. } = &content_blocks[0] {
assert_eq!(thinking, "Let me analyze the image and then check the weather...");
assert_eq!(
thinking,
"Let me analyze the image and then check the weather..."
);
} else {
panic!("Expected thinking content block");
}
// Validate text content block
if let MessagesContentBlock::Text { text, .. } = &content_blocks[1] {
assert_eq!(text, "I can see the image. Let me check the weather for you.");
assert_eq!(
text,
"I can see the image. Let me check the weather for you."
);
} else {
panic!("Expected text content block");
}
// Validate tool use content block
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = content_blocks[2] {
if let MessagesContentBlock::ToolUse {
ref id,
ref name,
ref input,
..
} = content_blocks[2]
{
assert_eq!(id, "toolu_weather123");
assert_eq!(name, "get_weather");
assert_eq!(input["location"], "San Francisco, CA");
@ -892,7 +896,10 @@ mod tests {
let tool = &tools[0];
assert_eq!(tool.name, "get_weather");
assert_eq!(tool.description, Some("Get current weather information for a location".to_string()));
assert_eq!(
tool.description,
Some("Get current weather information for a location".to_string())
);
assert_eq!(tool.input_schema["type"], "object");
assert!(tool.input_schema["properties"]["location"].is_object());
@ -938,10 +945,16 @@ mod tests {
assert_eq!(deserialized_mcp.name, "test-server");
assert_eq!(deserialized_mcp.server_type, McpServerType::Url);
assert_eq!(deserialized_mcp.url, "https://example.com/mcp");
assert_eq!(deserialized_mcp.authorization_token, Some("secret-token".to_string()));
assert_eq!(
deserialized_mcp.authorization_token,
Some("secret-token".to_string())
);
if let Some(tool_config) = &deserialized_mcp.tool_configuration {
assert_eq!(tool_config.allowed_tools, Some(vec!["tool1".to_string(), "tool2".to_string()]));
assert_eq!(
tool_config.allowed_tools,
Some(vec!["tool1".to_string(), "tool2".to_string()])
);
assert_eq!(tool_config.enabled, Some(true));
} else {
panic!("Expected tool configuration");
@ -957,7 +970,8 @@ mod tests {
"url": "https://minimal.com/mcp"
});
let deserialized_minimal: McpServer = serde_json::from_value(minimal_mcp_json.clone()).unwrap();
let deserialized_minimal: McpServer =
serde_json::from_value(minimal_mcp_json.clone()).unwrap();
assert_eq!(deserialized_minimal.name, "minimal-server");
assert_eq!(deserialized_minimal.server_type, McpServerType::Url);
assert_eq!(deserialized_minimal.url, "https://minimal.com/mcp");
@ -991,12 +1005,16 @@ mod tests {
}
});
let deserialized_response: MessagesResponse = serde_json::from_value(response_json.clone()).unwrap();
let deserialized_response: MessagesResponse =
serde_json::from_value(response_json.clone()).unwrap();
assert_eq!(deserialized_response.id, "msg_01ABC123");
assert_eq!(deserialized_response.obj_type, "message");
assert_eq!(deserialized_response.role, MessagesRole::Assistant);
assert_eq!(deserialized_response.model, "claude-3-sonnet-20240229");
assert_eq!(deserialized_response.stop_reason, MessagesStopReason::EndTurn);
assert_eq!(
deserialized_response.stop_reason,
MessagesStopReason::EndTurn
);
assert!(deserialized_response.stop_sequence.is_none());
assert!(deserialized_response.container.is_none());
@ -1011,7 +1029,10 @@ mod tests {
// Check usage
assert_eq!(deserialized_response.usage.input_tokens, 10);
assert_eq!(deserialized_response.usage.output_tokens, 25);
assert_eq!(deserialized_response.usage.cache_creation_input_tokens, Some(5));
assert_eq!(
deserialized_response.usage.cache_creation_input_tokens,
Some(5)
);
assert_eq!(deserialized_response.usage.cache_read_input_tokens, Some(3));
let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap();
@ -1027,7 +1048,8 @@ mod tests {
}
});
let deserialized_event: MessagesStreamEvent = serde_json::from_value(stream_event_json.clone()).unwrap();
let deserialized_event: MessagesStreamEvent =
serde_json::from_value(stream_event_json.clone()).unwrap();
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
assert_eq!(index, 0);
if let MessagesContentDelta::TextDelta { text } = delta {
@ -1055,8 +1077,15 @@ mod tests {
}
});
let deserialized_tool_use: MessagesContentBlock = serde_json::from_value(tool_use_json.clone()).unwrap();
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = deserialized_tool_use {
let deserialized_tool_use: MessagesContentBlock =
serde_json::from_value(tool_use_json.clone()).unwrap();
if let MessagesContentBlock::ToolUse {
ref id,
ref name,
ref input,
..
} = deserialized_tool_use
{
assert_eq!(id, "toolu_01ABC123");
assert_eq!(name, "get_weather");
assert_eq!(input["location"], "San Francisco, CA");
@ -1079,8 +1108,15 @@ mod tests {
]
});
let deserialized_tool_result: MessagesContentBlock = serde_json::from_value(tool_result_json.clone()).unwrap();
if let MessagesContentBlock::ToolResult { ref tool_use_id, ref is_error, ref content, .. } = deserialized_tool_result {
let deserialized_tool_result: MessagesContentBlock =
serde_json::from_value(tool_result_json.clone()).unwrap();
if let MessagesContentBlock::ToolResult {
ref tool_use_id,
ref is_error,
ref content,
..
} = deserialized_tool_result
{
assert_eq!(tool_use_id, "toolu_01ABC123");
assert!(is_error.is_none());
if let ToolResultContent::Blocks(blocks) = content {
@ -1229,7 +1265,8 @@ mod tests {
});
// Deserialize the complex MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(complex_request_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(complex_request_json.clone()).unwrap();
// Verify basic fields
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
@ -1239,8 +1276,15 @@ mod tests {
// Verify system message with cache_control
if let Some(MessagesSystemPrompt::Blocks(ref system_blocks)) = deserialized_request.system {
assert_eq!(system_blocks.len(), 2);
if let MessagesContentBlock::Text { text, cache_control } = &system_blocks[0] {
assert_eq!(text, "You are Claude Code, Anthropic's official CLI for Claude.");
if let MessagesContentBlock::Text {
text,
cache_control,
} = &system_blocks[0]
{
assert_eq!(
text,
"You are Claude Code, Anthropic's official CLI for Claude."
);
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
} else {
panic!("Expected text system message with cache_control");
@ -1253,7 +1297,13 @@ mod tests {
let assistant_message = &deserialized_request.messages[1];
assert_eq!(assistant_message.role, MessagesRole::Assistant);
if let MessagesMessageContent::Blocks(ref content_blocks) = assistant_message.content {
if let MessagesContentBlock::ToolUse { id, name, input, cache_control } = &content_blocks[0] {
if let MessagesContentBlock::ToolUse {
id,
name,
input,
cache_control,
} = &content_blocks[0]
{
assert_eq!(id, "call_kV50LtJQKHvvzZui5TW56DUl");
assert_eq!(name, "TodoWrite");
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
@ -1272,7 +1322,12 @@ mod tests {
let user_message = &deserialized_request.messages[2];
assert_eq!(user_message.role, MessagesRole::User);
if let MessagesMessageContent::Blocks(ref content_blocks) = user_message.content {
if let MessagesContentBlock::ToolResult { tool_use_id, content, .. } = &content_blocks[0] {
if let MessagesContentBlock::ToolResult {
tool_use_id,
content,
..
} = &content_blocks[0]
{
assert_eq!(tool_use_id, "call_kV50LtJQKHvvzZui5TW56DUl");
if let ToolResultContent::Text(text) = content {
assert!(text.contains("Todos have been modified successfully"));
@ -1284,7 +1339,11 @@ mod tests {
}
// Verify text content with cache_control
if let MessagesContentBlock::Text { text, cache_control } = &content_blocks[2] {
if let MessagesContentBlock::Text {
text,
cache_control,
} = &content_blocks[2]
{
assert_eq!(text, "try again");
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
} else {
@ -1296,11 +1355,15 @@ mod tests {
// Test serialization round-trip
let serialized_request = serde_json::to_value(&deserialized_request).unwrap();
let re_deserialized_request: MessagesRequest = serde_json::from_value(serialized_request).unwrap();
let re_deserialized_request: MessagesRequest =
serde_json::from_value(serialized_request).unwrap();
// Verify round-trip consistency
assert_eq!(deserialized_request.model, re_deserialized_request.model);
assert_eq!(deserialized_request.messages.len(), re_deserialized_request.messages.len());
assert_eq!(
deserialized_request.messages.len(),
re_deserialized_request.messages.len()
);
}
#[test]
@ -1339,7 +1402,8 @@ mod tests {
}
});
let deserialized_event: MessagesStreamEvent = serde_json::from_value(thinking_delta_json.clone()).unwrap();
let deserialized_event: MessagesStreamEvent =
serde_json::from_value(thinking_delta_json.clone()).unwrap();
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
assert_eq!(index, 0);
if let MessagesContentDelta::ThinkingDelta { thinking } = delta {
@ -1352,7 +1416,10 @@ mod tests {
}
// Test that thinking delta is returned by content_delta()
assert_eq!(deserialized_event.content_delta(), Some(".\n\nI need to consider:\n1. Current"));
assert_eq!(
deserialized_event.content_delta(),
Some(".\n\nI need to consider:\n1. Current")
);
let serialized_event_json = serde_json::to_value(&deserialized_event).unwrap();
assert_eq!(thinking_delta_json, serialized_event_json);
@ -1376,7 +1443,8 @@ mod tests {
}
});
let deserialized_request: MessagesRequest = serde_json::from_value(request_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(request_json.clone()).unwrap();
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
assert_eq!(deserialized_request.max_tokens, 2048);

View file

@ -1,15 +1,19 @@
pub mod anthropic;
pub mod openai;
pub mod amazon_bedrock;
pub mod amazon_bedrock_binary_frame;
pub mod anthropic;
pub mod openai;
pub mod sse;
// Explicit exports to avoid naming conflicts
pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent};
pub use openai::{OpenAIApi, ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse};
pub use openai::{Message as OpenAIMessage, Tool as OpenAITool, ToolChoice as OpenAIToolChoice};
pub use amazon_bedrock::{AmazonBedrockApi, ConverseRequest, ConverseStreamRequest};
pub use amazon_bedrock::{Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice};
pub use amazon_bedrock::{
Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
};
pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent};
pub use openai::{
ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse, OpenAIApi,
};
pub use openai::{Message as OpenAIMessage, Tool as OpenAITool, ToolChoice as OpenAIToolChoice};
pub trait ApiDefinition {
/// Returns the endpoint path for this API
@ -56,11 +60,7 @@ mod tests {
#[test]
fn test_api_detection_from_endpoints() {
// Test that we can detect APIs from endpoints using the trait
let endpoints = vec![
CHAT_COMPLETIONS_PATH,
MESSAGES_PATH,
"/v1/unknown"
];
let endpoints = vec![CHAT_COMPLETIONS_PATH, MESSAGES_PATH, "/v1/unknown"];
let mut detected_apis = Vec::new();
@ -74,11 +74,14 @@ mod tests {
}
}
assert_eq!(detected_apis, vec![
"OpenAI: ChatCompletions",
"Anthropic: Messages",
"Unknown API"
]);
assert_eq!(
detected_apis,
vec![
"OpenAI: ChatCompletions",
"Anthropic: Messages",
"Unknown API"
]
);
}
#[test]

View file

@ -5,11 +5,11 @@ use std::collections::HashMap;
use std::fmt::Display;
use thiserror::Error;
use super::ApiDefinition;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
use super::ApiDefinition;
use crate::transforms::lib::ExtractText;
use crate::{CHAT_COMPLETIONS_PATH};
use crate::CHAT_COMPLETIONS_PATH;
// ============================================================================
// OPENAI API ENUMERATION
@ -46,7 +46,7 @@ impl ApiDefinition for OpenAIApi {
}
fn supports_tools(&self) -> bool {
match self {
match self {
OpenAIApi::ChatCompletions => true,
}
}
@ -58,9 +58,7 @@ impl ApiDefinition for OpenAIApi {
}
fn all_variants() -> Vec<Self> {
vec![
OpenAIApi::ChatCompletions,
]
vec![OpenAIApi::ChatCompletions]
}
}
@ -190,7 +188,9 @@ impl ResponseMessage {
pub fn to_message(&self) -> Message {
Message {
role: self.role.clone(),
content: self.content.as_ref()
content: self
.content
.as_ref()
.map(|s| MessageContent::Text(s.clone()))
.unwrap_or(MessageContent::Text(String::new())),
name: None, // Response messages don't have names in the same way request messages do
@ -215,7 +215,7 @@ impl ExtractText for MessageContent {
fn extract_text(&self) -> String {
match self {
MessageContent::Text(text) => text.clone(),
MessageContent::Parts(parts) => parts.extract_text()
MessageContent::Parts(parts) => parts.extract_text(),
}
}
}
@ -274,7 +274,6 @@ pub struct ImageUrl {
/// A single message in a chat conversation
/// A tool call made by the assistant
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ToolCall {
@ -374,7 +373,6 @@ pub enum StaticContentType {
Parts(Vec<ContentPart>),
}
/// Chat completions API response
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
@ -496,7 +494,6 @@ pub struct ChatCompletionsStreamResponse {
pub service_tier: Option<String>,
}
/// A choice in a streaming response
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
@ -566,7 +563,6 @@ pub struct Models {
pub data: Vec<ModelDetail>,
}
// Error type for streaming operations
#[derive(Debug, thiserror::Error)]
pub enum OpenAIStreamError {
@ -597,13 +593,13 @@ pub enum OpenAIError {
/// Trait Implementations
/// ===========================================================================
/// Parameterized conversion for ChatCompletionsRequest
impl TryFrom<&[u8]> for ChatCompletionsRequest {
type Error = OpenAIStreamError;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
let mut req: ChatCompletionsRequest = serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
let mut req: ChatCompletionsRequest =
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
// Use the centralized suppression logic
req.suppress_max_tokens_if_o3();
req.fix_temperature_if_gpt5();
@ -651,13 +647,18 @@ impl ProviderRequest for ChatCompletionsRequest {
fn extract_messages_text(&self) -> String {
self.messages.iter().fold(String::new(), |acc, m| {
acc + " " + &match &m.content {
MessageContent::Text(text) => text.clone(),
MessageContent::Parts(parts) => parts.iter().map(|part| match part {
ContentPart::Text { text } => text.clone(),
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
}).collect::<Vec<_>>().join(" ")
}
acc + " "
+ &match &m.content {
MessageContent::Text(text) => text.clone(),
MessageContent::Parts(parts) => parts
.iter()
.map(|part| match part {
ContentPart::Text { text } => text.clone(),
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
})
.collect::<Vec<_>>()
.join(" "),
}
})
}
@ -721,14 +722,14 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
}
fn role(&self) -> Option<&str> {
self.choices
.first()
.and_then(|choice| choice.delta.role.as_ref().map(|r| match r {
self.choices.first().and_then(|choice| {
choice.delta.role.as_ref().map(|r| match r {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
}))
})
})
}
fn event_type(&self) -> Option<&str> {
@ -736,7 +737,6 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -756,7 +756,8 @@ mod tests {
});
// Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: ChatCompletionsRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields are properly set
assert_eq!(deserialized_request.model, "gpt-4");
@ -799,7 +800,8 @@ mod tests {
});
// Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: ChatCompletionsRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields
assert_eq!(deserialized_request.model, "gpt-4");
@ -836,7 +838,10 @@ mod tests {
assert_eq!(serialized_json["messages"], original_json["messages"]);
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
assert_eq!(serialized_json["stream"], original_json["stream"]);
assert_eq!(serialized_json["stream_options"], original_json["stream_options"]);
assert_eq!(
serialized_json["stream_options"],
original_json["stream_options"]
);
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
// Handle temperature with floating point tolerance
@ -917,7 +922,8 @@ mod tests {
});
// Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: ChatCompletionsRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate top-level fields
assert_eq!(deserialized_request.model, "gpt-4-vision-preview");
@ -953,7 +959,10 @@ mod tests {
let assistant_message = &deserialized_request.messages[1];
assert_eq!(assistant_message.role, Role::Assistant);
if let MessageContent::Text(text) = &assistant_message.content {
assert_eq!(text, "I can see a beautiful cityscape. Let me check the weather for you.");
assert_eq!(
text,
"I can see a beautiful cityscape. Let me check the weather for you."
);
} else {
panic!("Expected text content for assistant message");
}
@ -967,7 +976,10 @@ mod tests {
assert_eq!(tool_call.id, "call_weather123");
assert_eq!(tool_call.call_type, "function");
assert_eq!(tool_call.function.name, "get_weather");
assert_eq!(tool_call.function.arguments, "{\"location\": \"New York, NY\"}");
assert_eq!(
tool_call.function.arguments,
"{\"location\": \"New York, NY\"}"
);
// Validate third message (tool response)
let tool_message = &deserialized_request.messages[2];
@ -977,7 +989,10 @@ mod tests {
} else {
panic!("Expected text content for tool message");
}
assert_eq!(tool_message.tool_call_id, Some("call_weather123".to_string()));
assert_eq!(
tool_message.tool_call_id,
Some("call_weather123".to_string())
);
// Validate tools array
assert!(deserialized_request.tools.is_some());
@ -987,7 +1002,10 @@ mod tests {
let tool = &tools[0];
assert_eq!(tool.tool_type, "function");
assert_eq!(tool.function.name, "get_weather");
assert_eq!(tool.function.description, Some("Get current weather information for a location".to_string()));
assert_eq!(
tool.function.description,
Some("Get current weather information for a location".to_string())
);
assert_eq!(tool.function.strict, Some(true));
// Validate tool parameters schema
@ -1093,7 +1111,8 @@ mod tests {
]
});
let deserialized_assistant: Message = serde_json::from_value(assistant_json.clone()).unwrap();
let deserialized_assistant: Message =
serde_json::from_value(assistant_json.clone()).unwrap();
assert_eq!(deserialized_assistant.role, Role::Assistant);
if let MessageContent::Text(content) = &deserialized_assistant.content {
assert_eq!(content, "I'll help with that.");
@ -1142,9 +1161,13 @@ mod tests {
]
});
let deserialized_response: ResponseMessage = serde_json::from_value(response_json.clone()).unwrap();
let deserialized_response: ResponseMessage =
serde_json::from_value(response_json.clone()).unwrap();
assert_eq!(deserialized_response.role, Role::Assistant);
assert_eq!(deserialized_response.content, Some("Response content".to_string()));
assert_eq!(
deserialized_response.content,
Some("Response content".to_string())
);
assert!(deserialized_response.annotations.is_some());
assert!(deserialized_response.refusal.is_none());
assert!(deserialized_response.function_call.is_none());
@ -1186,7 +1209,10 @@ mod tests {
let none_deserialized: ToolChoice = serde_json::from_value(json!("none")).unwrap();
assert_eq!(auto_deserialized, ToolChoice::Type(ToolChoiceType::Auto));
assert_eq!(required_deserialized, ToolChoice::Type(ToolChoiceType::Required));
assert_eq!(
required_deserialized,
ToolChoice::Type(ToolChoiceType::Required)
);
assert_eq!(none_deserialized, ToolChoice::Type(ToolChoiceType::None));
// Test that invalid string values fail deserialization (type safety!)
@ -1237,7 +1263,10 @@ mod tests {
assert_eq!(response.created, 1756574706);
assert_eq!(response.model, "gpt-4o-2024-08-06");
assert_eq!(response.service_tier, Some("default".to_string()));
assert_eq!(response.system_fingerprint, Some("fp_f33640a400".to_string()));
assert_eq!(
response.system_fingerprint,
Some("fp_f33640a400".to_string())
);
assert_eq!(response.choices.len(), 1);
assert_eq!(response.usage.prompt_tokens, 65);
assert_eq!(response.usage.completion_tokens, 184);

View file

@ -1,9 +1,9 @@
use std::str::FromStr;
use std::fmt;
use std::error::Error;
use serde::{Serialize, Deserialize};
use crate::providers::response::ProviderStreamResponse;
use crate::providers::response::ProviderStreamResponseType;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fmt;
use std::str::FromStr;
// ============================================================================
// SSE EVENT CONTAINER
@ -13,19 +13,19 @@ use crate::providers::response::ProviderStreamResponseType;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SseEvent {
#[serde(rename = "data")]
pub data: Option<String>, // The JSON payload after "data: "
pub data: Option<String>, // The JSON payload after "data: "
#[serde(skip_serializing_if = "Option::is_none")]
pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta")
pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta")
#[serde(skip_serializing, skip_deserializing)]
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
#[serde(skip_serializing, skip_deserializing)]
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
#[serde(skip_serializing, skip_deserializing)]
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
#[serde(skip_serializing, skip_deserializing)]
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
}
impl SseEvent {
@ -48,13 +48,13 @@ impl SseEvent {
/// Get the parsed provider response if available
pub fn provider_response(&self) -> Result<&dyn ProviderStreamResponse, std::io::Error> {
self.provider_stream_response.as_ref()
self.provider_stream_response
.as_ref()
.map(|resp| resp as &dyn ProviderStreamResponse)
.ok_or_else(|| {
std::io::Error::new(std::io::ErrorKind::NotFound, "Provider response not found")
})
}
}
impl FromStr for SseEvent {
@ -75,7 +75,8 @@ impl FromStr for SseEvent {
sse_transform_buffer: line.to_string(),
provider_stream_response: None,
})
} else if line.starts_with("event: ") { //used by Anthropic
} else if line.starts_with("event: ") {
//used by Anthropic
let event_type = line[7..].to_string();
if event_type.is_empty() {
return Err(SseParseError {
@ -123,7 +124,6 @@ impl fmt::Display for SseParseError {
impl Error for SseParseError {}
/// Generic SSE (Server-Sent Events) streaming iterator container
/// Parses raw SSE lines into SseEvent objects
pub struct SseStreamIter<I>
@ -141,7 +141,10 @@ where
I::Item: AsRef<str>,
{
pub fn new(lines: I) -> Self {
Self { lines, done_seen: false }
Self {
lines,
done_seen: false,
}
}
}
@ -151,14 +154,13 @@ impl TryFrom<&[u8]> for SseStreamIter<std::vec::IntoIter<String>> {
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
// Parse as text-based SSE format
let s = std::str::from_utf8(bytes)?;
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
Ok(SseStreamIter::new(lines.into_iter()))
// Parse as text-based SSE format
let s = std::str::from_utf8(bytes)?;
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
Ok(SseStreamIter::new(lines.into_iter()))
}
}
impl<I> Iterator for SseStreamIter<I>
where
I: Iterator,

View file

@ -1,5 +1,5 @@
use crate::{ProviderId};
use crate::apis::{OpenAIApi, AnthropicApi, AmazonBedrockApi, ApiDefinition};
use crate::apis::{AmazonBedrockApi, AnthropicApi, ApiDefinition, OpenAIApi};
use crate::ProviderId;
use std::fmt;
/// Unified enum representing all supported API endpoints across providers
@ -20,8 +20,12 @@ pub enum SupportedUpstreamAPIs {
impl fmt::Display for SupportedAPIs {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SupportedAPIs::OpenAIChatCompletions(api) => write!(f, "OpenAI API ({})", api.endpoint()),
SupportedAPIs::AnthropicMessagesAPI(api) => write!(f, "Anthropic API ({})", api.endpoint()),
SupportedAPIs::OpenAIChatCompletions(api) => {
write!(f, "OpenAI API ({})", api.endpoint())
}
SupportedAPIs::AnthropicMessagesAPI(api) => {
write!(f, "Anthropic API ({})", api.endpoint())
}
}
}
}
@ -48,81 +52,81 @@ impl SupportedAPIs {
}
}
pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str, model_id: &str, is_streaming: bool) -> String {
pub fn target_endpoint_for_provider(
&self,
provider_id: &ProviderId,
request_path: &str,
model_id: &str,
is_streaming: bool,
) -> String {
let default_endpoint = "/v1/chat/completions".to_string();
match self {
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => {
match provider_id {
ProviderId::Anthropic => "/v1/messages".to_string(),
ProviderId::AmazonBedrock => {
if request_path.starts_with("/v1/") && !is_streaming {
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
ProviderId::Anthropic => "/v1/messages".to_string(),
ProviderId::AmazonBedrock => {
if request_path.starts_with("/v1/") && !is_streaming {
format!("/model/{}/converse", model_id)
} else if request_path.starts_with("/v1/") && is_streaming {
format!("/model/{}/converse-stream", model_id)
} else {
default_endpoint
}
}
_ => default_endpoint,
},
_ => match provider_id {
ProviderId::Groq => {
if request_path.starts_with("/v1/") {
format!("/openai{}", request_path)
} else {
default_endpoint
}
}
ProviderId::Zhipu => {
if request_path.starts_with("/v1/") {
"/api/paas/v4/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::Qwen => {
if request_path.starts_with("/v1/") {
"/compatible-mode/v1/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::AzureOpenAI => {
if request_path.starts_with("/v1/") {
format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id)
} else {
default_endpoint
}
}
ProviderId::Gemini => {
if request_path.starts_with("/v1/") {
"/v1beta/openai/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::AmazonBedrock => {
if request_path.starts_with("/v1/") {
if !is_streaming {
format!("/model/{}/converse", model_id)
} else if request_path.starts_with("/v1/") && is_streaming {
} else {
format!("/model/{}/converse-stream", model_id)
} else {
default_endpoint
}
} else {
default_endpoint
}
_ => default_endpoint,
}
}
_ => {
match provider_id {
ProviderId::Groq => {
if request_path.starts_with("/v1/") {
format!("/openai{}", request_path)
} else {
default_endpoint
}
}
ProviderId::Zhipu => {
if request_path.starts_with("/v1/") {
"/api/paas/v4/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::Qwen => {
if request_path.starts_with("/v1/") {
"/compatible-mode/v1/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::AzureOpenAI => {
if request_path.starts_with("/v1/") {
format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id)
} else {
default_endpoint
}
}
ProviderId::Gemini => {
if request_path.starts_with("/v1/") {
"/v1beta/openai/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::AmazonBedrock => {
if request_path.starts_with("/v1/") {
if !is_streaming {
format!("/model/{}/converse", model_id)
} else {
format!("/model/{}/converse-stream", model_id)
}
} else {
default_endpoint
}
}
_ => default_endpoint,
}
}
_ => default_endpoint,
},
}
}
}
/// Get all supported endpoint paths
pub fn supported_endpoints() -> Vec<&'static str> {
let mut endpoints = Vec::new();
@ -164,7 +168,6 @@ mod tests {
// Anthropic endpoints
assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some());
// Unsupported endpoints
assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some());
assert!(!SupportedAPIs::from_endpoint("/v2/chat").is_some());
@ -177,7 +180,6 @@ mod tests {
assert_eq!(endpoints.len(), 2); // We have 2 APIs defined
assert!(endpoints.contains(&"/v1/chat/completions"));
assert!(endpoints.contains(&"/v1/messages"));
}
#[test]
@ -203,14 +205,25 @@ mod tests {
// All OpenAI endpoints should be in the result
for endpoint in openai_endpoints {
assert!(endpoints.contains(&endpoint), "Missing OpenAI endpoint: {}", endpoint);
assert!(
endpoints.contains(&endpoint),
"Missing OpenAI endpoint: {}",
endpoint
);
}
// All Anthropic endpoints should be in the result
for endpoint in anthropic_endpoints {
assert!(endpoints.contains(&endpoint), "Missing Anthropic endpoint: {}", endpoint);
assert!(
endpoints.contains(&endpoint),
"Missing Anthropic endpoint: {}",
endpoint
);
}
// Total should match
assert_eq!(endpoints.len(), OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len());
assert_eq!(
endpoints.len(),
OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len()
);
}
}

View file

@ -1,9 +1,9 @@
pub mod endpoints;
pub mod lib;
pub mod transformer;
pub mod endpoints;
// Re-export the main items for easier access
pub use endpoints::{identify_provider, SupportedAPIs};
pub use lib::*;
pub use endpoints::{SupportedAPIs, identify_provider};
// Note: transformer module contains TryFrom trait implementations that are automatically available

View file

@ -1,6 +1,5 @@
// Re-export new transformation modules for backward compatibility
//KEEPING THE TESTS TO MAKE SURE ALL THE REFACTORING DIDN'T BREAK ANYTHING
// ============================================================================
@ -9,10 +8,10 @@
#[cfg(test)]
mod tests {
use serde_json::json;
use crate::apis::anthropic::*;
use crate::apis::openai::*;
use crate::transforms::*;
use serde_json::json;
type AnthropicMessagesRequest = MessagesRequest;
#[test]
@ -81,11 +80,20 @@ mod tests {
// Check key fields are preserved
assert_eq!(original_anthropic.model, roundtrip_anthropic.model);
assert_eq!(original_anthropic.max_tokens, roundtrip_anthropic.max_tokens);
assert_eq!(original_anthropic.temperature, roundtrip_anthropic.temperature);
assert_eq!(
original_anthropic.max_tokens,
roundtrip_anthropic.max_tokens
);
assert_eq!(
original_anthropic.temperature,
roundtrip_anthropic.temperature
);
assert_eq!(original_anthropic.top_p, roundtrip_anthropic.top_p);
assert_eq!(original_anthropic.stream, roundtrip_anthropic.stream);
assert_eq!(original_anthropic.messages.len(), roundtrip_anthropic.messages.len());
assert_eq!(
original_anthropic.messages.len(),
roundtrip_anthropic.messages.len()
);
}
#[test]
@ -229,7 +237,10 @@ mod tests {
let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id, Some("call_123".to_string()));
assert_eq!(tool_calls[0].function.as_ref().unwrap().name, Some("get_weather".to_string()));
assert_eq!(
tool_calls[0].function.as_ref().unwrap().name,
Some("get_weather".to_string())
);
}
#[test]
@ -249,7 +260,10 @@ mod tests {
let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].function.as_ref().unwrap().arguments, Some(r#"{"location": "San Francisco"#.to_string()));
assert_eq!(
tool_calls[0].function.as_ref().unwrap().arguments,
Some(r#"{"location": "San Francisco"#.to_string())
);
}
#[test]
@ -412,7 +426,10 @@ mod tests {
let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap();
match anthropic_event {
MessagesStreamEvent::ContentBlockStart { index, content_block } => {
MessagesStreamEvent::ContentBlockStart {
index,
content_block,
} => {
assert_eq!(index, 0);
match content_block {
MessagesContentBlock::ToolUse { id, name, .. } => {
@ -555,16 +572,28 @@ mod tests {
// Verify tool start
let tool_calls = &openai_start.choices[0].delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls[0].id, Some("call_weather".to_string()));
assert_eq!(tool_calls[0].function.as_ref().unwrap().name, Some("get_weather".to_string()));
assert_eq!(
tool_calls[0].function.as_ref().unwrap().name,
Some("get_weather".to_string())
);
// Verify argument deltas
let args1 = &openai_delta1.choices[0].delta.tool_calls.as_ref().unwrap()[0]
.function.as_ref().unwrap().arguments;
.function
.as_ref()
.unwrap()
.arguments;
assert_eq!(args1, &Some(r#"{"location": "#.to_string()));
let args2 = &openai_delta2.choices[0].delta.tool_calls.as_ref().unwrap()[0]
.function.as_ref().unwrap().arguments;
assert_eq!(args2, &Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string()));
.function
.as_ref()
.unwrap()
.arguments;
assert_eq!(
args2,
&Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string())
);
}
#[test]
@ -592,14 +621,23 @@ mod tests {
};
let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap();
assert_eq!(openai_resp.choices[0].finish_reason, Some(expected_openai_reason));
assert_eq!(
openai_resp.choices[0].finish_reason,
Some(expected_openai_reason)
);
// Test reverse conversion
let roundtrip_event: MessagesStreamEvent = openai_resp.try_into().unwrap();
match roundtrip_event {
MessagesStreamEvent::MessageDelta { delta, .. } => {
// Note: Some precision may be lost in roundtrip due to mapping differences
assert!(matches!(delta.stop_reason, MessagesStopReason::EndTurn | MessagesStopReason::MaxTokens | MessagesStopReason::ToolUse | MessagesStopReason::StopSequence));
assert!(matches!(
delta.stop_reason,
MessagesStopReason::EndTurn
| MessagesStopReason::MaxTokens
| MessagesStopReason::ToolUse
| MessagesStopReason::StopSequence
));
}
_ => panic!("Expected MessageDelta after roundtrip"),
}
@ -632,7 +670,8 @@ mod tests {
};
// Should convert to Ping when no meaningful content
let anthropic_event: MessagesStreamEvent = openai_resp_with_missing_data.try_into().unwrap();
let anthropic_event: MessagesStreamEvent =
openai_resp_with_missing_data.try_into().unwrap();
assert!(matches!(anthropic_event, MessagesStreamEvent::Ping));
}

View file

@ -1,24 +1,25 @@
//! hermesllm: A library for translating LLM API requests and responses
//! between Mistral, Grok, Gemini, and OpenAI-compliant formats.
pub mod providers;
pub mod apis;
pub mod clients;
pub mod providers;
pub mod transforms;
// Re-export important types and traits
pub use providers::request::{ProviderRequestType, ProviderRequest, ProviderRequestError};
pub use apis::sse::{SseEvent, SseStreamIter};
pub use apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
pub use providers::response::{ProviderResponseType, ProviderStreamResponseType, ProviderResponse, ProviderStreamResponse, ProviderResponseError, TokenUsage};
pub use providers::id::ProviderId;
pub use apis::sse::{SseEvent, SseStreamIter};
pub use aws_smithy_eventstream::frame::DecodedFrame;
pub use providers::id::ProviderId;
pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
pub use providers::response::{
ProviderResponse, ProviderResponseError, ProviderResponseType, ProviderStreamResponse,
ProviderStreamResponseType, TokenUsage,
};
//TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
pub const MESSAGES_PATH: &str = "/v1/messages";
#[cfg(test)]
mod tests {
use crate::clients::endpoints::SupportedUpstreamAPIs;
@ -36,49 +37,51 @@ mod tests {
#[test]
fn test_provider_streaming_response() {
// Test streaming response parsing with sample SSE data
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
data: [DONE]
"#;
use crate::clients::endpoints::SupportedAPIs;
let client_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
use crate::clients::endpoints::SupportedAPIs;
let client_api =
SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
let upstream_api =
SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
// Test the new simplified architecture - create SseStreamIter directly
let sse_iter = SseStreamIter::try_from(sse_data.as_bytes());
assert!(sse_iter.is_ok());
// Test the new simplified architecture - create SseStreamIter directly
let sse_iter = SseStreamIter::try_from(sse_data.as_bytes());
assert!(sse_iter.is_ok());
let mut streaming_iter = sse_iter.unwrap();
let mut streaming_iter = sse_iter.unwrap();
// Test that we can iterate over SseEvents
let first_event = streaming_iter.next();
assert!(first_event.is_some());
// Test that we can iterate over SseEvents
let first_event = streaming_iter.next();
assert!(first_event.is_some());
let sse_event = first_event.unwrap();
let sse_event = first_event.unwrap();
// Test SseEvent properties
assert!(!sse_event.is_done());
assert!(sse_event.data.as_ref().unwrap().contains("Hello"));
// Test SseEvent properties
assert!(!sse_event.is_done());
assert!(sse_event.data.as_ref().unwrap().contains("Hello"));
// Test that we can parse the event into a provider stream response
let transformed_event = SseEvent::try_from((sse_event, &client_api, &upstream_api));
if let Err(e) = &transformed_event {
println!("Transform error: {:?}", e);
}
assert!(transformed_event.is_ok());
// Test that we can parse the event into a provider stream response
let transformed_event = SseEvent::try_from((sse_event, &client_api, &upstream_api));
if let Err(e) = &transformed_event {
println!("Transform error: {:?}", e);
}
assert!(transformed_event.is_ok());
let transformed_event = transformed_event.unwrap();
let provider_response = transformed_event.provider_response();
assert!(provider_response.is_ok());
let transformed_event = transformed_event.unwrap();
let provider_response = transformed_event.provider_response();
assert!(provider_response.is_ok());
let stream_response = provider_response.unwrap();
assert_eq!(stream_response.content_delta(), Some("Hello"));
assert!(!stream_response.is_final());
let stream_response = provider_response.unwrap();
assert_eq!(stream_response.content_delta(), Some("Hello"));
assert!(!stream_response.is_final());
// Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE])
let final_event = streaming_iter.next();
assert!(final_event.is_none()); // Should be None because iterator stops at [DONE]
// Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE])
let final_event = streaming_iter.next();
assert!(final_event.is_none()); // Should be None because iterator stops at [DONE]
}
/// Test AWS Event Stream decoding for Bedrock ConverseStream responses.
@ -95,15 +98,15 @@ mod tests {
/// all complete frames in the buffer.
#[test]
fn test_amazon_bedrock_streaming_response() {
use aws_smithy_eventstream::frame::{MessageFrameDecoder, DecodedFrame};
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
use bytes::{Buf, BytesMut};
use std::fs;
use std::path::PathBuf;
// Read the response.hex file from tests/e2e directory
// Use absolute path to avoid cargo test working directory issues
let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("../../tests/e2e/response.hex");
let test_file =
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
let response_data = fs::read(&test_file)
.unwrap_or_else(|e| panic!("Failed to read {:?}: {}", test_file, e));
@ -134,8 +137,13 @@ mod tests {
simulated_network_buffer.extend_from_slice(chunk);
offset = end;
println!("📦 Chunk {}: Received {} bytes (buffer: {} bytes total, {} bytes remaining)",
chunk_num, chunk.len(), simulated_network_buffer.len(), simulated_network_buffer.remaining());
println!(
"📦 Chunk {}: Received {} bytes (buffer: {} bytes total, {} bytes remaining)",
chunk_num,
chunk.len(),
simulated_network_buffer.len(),
simulated_network_buffer.remaining()
);
// Try to decode all complete frames from buffer
// The Buf trait tracks position automatically!
@ -146,11 +154,16 @@ mod tests {
frame_count += 1;
let consumed = bytes_before - simulated_network_buffer.remaining();
println!(" ✅ Frame {}: decoded ({} bytes, {} bytes remaining)",
frame_count, consumed, simulated_network_buffer.remaining());
println!(
" ✅ Frame {}: decoded ({} bytes, {} bytes remaining)",
frame_count,
consumed,
simulated_network_buffer.remaining()
);
// Get event type from headers
let event_type = message.headers()
let event_type = message
.headers()
.iter()
.find(|h| h.name().as_str() == ":event-type")
.and_then(|h| {
@ -167,7 +180,9 @@ mod tests {
if let Ok(json) = serde_json::from_slice::<serde_json::Value>(payload) {
if event_type.as_deref() == Some("contentBlockDelta") {
if let Some(delta) = json.get("delta") {
if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
if let Some(text) =
delta.get("text").and_then(|t| t.as_str())
{
println!(" 📝 Content: \"{}\"", text);
content_chunks.push(text.to_string());
}
@ -178,7 +193,10 @@ mod tests {
}
Ok(DecodedFrame::Incomplete) => {
// Not enough data for a complete frame - need more chunks
println!(" ⏳ Incomplete frame ({} bytes remaining) - waiting for more data\n", simulated_network_buffer.remaining());
println!(
" ⏳ Incomplete frame ({} bytes remaining) - waiting for more data\n",
simulated_network_buffer.remaining()
);
break; // Wait for next chunk
}
Err(e) => {
@ -193,7 +211,10 @@ mod tests {
println!(" Total chunks received: {}", chunk_num);
println!(" Total frames decoded: {}", frame_count);
println!(" Total content chunks: {}", content_chunks.len());
println!(" Final buffer remaining: {} bytes", simulated_network_buffer.remaining());
println!(
" Final buffer remaining: {} bytes",
simulated_network_buffer.remaining()
);
if !content_chunks.is_empty() {
let full_text = content_chunks.join("");
@ -207,6 +228,11 @@ mod tests {
assert!(frame_count > 0, "Should decode at least one frame");
// Ensure all data was consumed - if buffer has remaining bytes, it's a partial frame
assert_eq!(simulated_network_buffer.remaining(), 0, "All bytes should be consumed, {} bytes remain", simulated_network_buffer.remaining());
assert_eq!(
simulated_network_buffer.remaining(),
0,
"All bytes should be consumed, {} bytes remain",
simulated_network_buffer.remaining()
);
}
}

View file

@ -1,6 +1,6 @@
use std::fmt::Display;
use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi};
use crate::clients::endpoints::{SupportedAPIs, SupportedUpstreamAPIs};
use crate::apis::{OpenAIApi, AnthropicApi, AmazonBedrockApi};
use std::fmt::Display;
/// Provider identifier enum - simple enum for identifying providers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@ -49,60 +49,77 @@ impl From<&str> for ProviderId {
impl ProviderId {
/// Given a client API, return the compatible upstream API for this provider
pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs, is_streaming: bool) -> SupportedUpstreamAPIs {
pub fn compatible_api_for_client(
&self,
client_api: &SupportedAPIs,
is_streaming: bool,
) -> SupportedUpstreamAPIs {
match (self, client_api) {
// Claude/Anthropic providers natively support Anthropic APIs
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
(ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => {
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
}
(
ProviderId::Anthropic,
SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
// OpenAI-compatible providers only support OpenAI chat completions
(ProviderId::OpenAI
| ProviderId::Groq
| ProviderId::Mistral
| ProviderId::Deepseek
| ProviderId::Arch
| ProviderId::Gemini
| ProviderId::GitHub
| ProviderId::AzureOpenAI
| ProviderId::XAI
| ProviderId::TogetherAI
| ProviderId::Ollama
| ProviderId::Moonshotai
| ProviderId::Zhipu
| ProviderId::Qwen,
SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(
ProviderId::OpenAI
| ProviderId::Groq
| ProviderId::Mistral
| ProviderId::Deepseek
| ProviderId::Arch
| ProviderId::Gemini
| ProviderId::GitHub
| ProviderId::AzureOpenAI
| ProviderId::XAI
| ProviderId::TogetherAI
| ProviderId::Ollama
| ProviderId::Moonshotai
| ProviderId::Zhipu
| ProviderId::Qwen,
SupportedAPIs::AnthropicMessagesAPI(_),
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(ProviderId::OpenAI
| ProviderId::Groq
| ProviderId::Mistral
| ProviderId::Deepseek
| ProviderId::Arch
| ProviderId::Gemini
| ProviderId::GitHub
| ProviderId::AzureOpenAI
| ProviderId::XAI
| ProviderId::TogetherAI
| ProviderId::Ollama
| ProviderId::Moonshotai
| ProviderId::Zhipu
| ProviderId::Qwen,
SupportedAPIs::OpenAIChatCompletions(_)) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(
ProviderId::OpenAI
| ProviderId::Groq
| ProviderId::Mistral
| ProviderId::Deepseek
| ProviderId::Arch
| ProviderId::Gemini
| ProviderId::GitHub
| ProviderId::AzureOpenAI
| ProviderId::XAI
| ProviderId::TogetherAI
| ProviderId::Ollama
| ProviderId::Moonshotai
| ProviderId::Zhipu
| ProviderId::Qwen,
SupportedAPIs::OpenAIChatCompletions(_),
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
// Amazon Bedrock natively supports Bedrock APIs
(ProviderId::AmazonBedrock, SupportedAPIs::OpenAIChatCompletions(_)) => {
if is_streaming {
SupportedUpstreamAPIs::AmazonBedrockConverseStream(AmazonBedrockApi::ConverseStream)
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
AmazonBedrockApi::ConverseStream,
)
} else {
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
}
},
}
(ProviderId::AmazonBedrock, SupportedAPIs::AnthropicMessagesAPI(_)) => {
if is_streaming {
SupportedUpstreamAPIs::AmazonBedrockConverseStream(AmazonBedrockApi::ConverseStream)
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
AmazonBedrockApi::ConverseStream,
)
} else {
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
}
},
}
}
}
}

View file

@ -8,5 +8,5 @@ pub mod request;
pub mod response;
pub use id::ProviderId;
pub use request::{ProviderRequestType, ProviderRequest, ProviderRequestError} ;
pub use response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, TokenUsage };
pub use request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
pub use response::{ProviderResponse, ProviderResponseType, ProviderStreamResponse, TokenUsage};

View file

@ -1,14 +1,14 @@
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::anthropic::MessagesRequest;
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest};
use crate::clients::endpoints::SupportedAPIs;
use crate::clients::endpoints::SupportedUpstreamAPIs;
use serde_json::Value;
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use std::collections::HashMap;
#[derive(Clone)]
pub enum ProviderRequestType {
ChatCompletionsRequest(ChatCompletionsRequest),
@ -124,15 +124,18 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
// Use SupportedApi to determine the appropriate request type
match client_api {
SupportedAPIs::OpenAIChatCompletions(_) => {
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
}
SupportedAPIs::AnthropicMessagesAPI(_) => {
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
let chat_completion_request: ChatCompletionsRequest =
ChatCompletionsRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::MessagesRequest(messages_request))
}
Ok(ProviderRequestType::ChatCompletionsRequest(
chat_completion_request,
))
}
SupportedAPIs::AnthropicMessagesAPI(_) => {
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::MessagesRequest(messages_request))
}
}
}
}
@ -141,37 +144,57 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestType {
type Error = ProviderRequestError;
fn try_from((client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs)) -> Result<Self, Self::Error> {
fn try_from(
(client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs),
) -> Result<Self, Self::Error> {
match (client_request, upstream_api) {
// Same API - no conversion needed, just clone the reference
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => {
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
}
(ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => {
Ok(ProviderRequestType::MessagesRequest(messages_req))
}
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)),
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) => Ok(ProviderRequestType::MessagesRequest(messages_req)),
// Cross-API conversion - cloning is necessary for transformation
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => {
let messages_req = MessagesRequest::try_from(chat_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e),
source: Some(Box::new(e))
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) => {
let messages_req =
MessagesRequest::try_from(chat_req).map_err(|e| ProviderRequestError {
message: format!(
"Failed to convert ChatCompletionsRequest to MessagesRequest: {}",
e
),
source: Some(Box::new(e)),
})?;
Ok(ProviderRequestType::MessagesRequest(messages_req))
}
(ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => {
let chat_req = ChatCompletionsRequest::try_from(messages_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e),
source: Some(Box::new(e))
})?;
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
) => {
let chat_req = ChatCompletionsRequest::try_from(messages_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert MessagesRequest to ChatCompletionsRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
}
// Cross-API conversions: OpenAI/Anthropic to Amazon Bedrock
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AmazonBedrockConverse(_)) => {
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
) => {
let bedrock_req = ConverseRequest::try_from(chat_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
@ -180,7 +203,10 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
}
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => {
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
) => {
let bedrock_req = ConverseStreamRequest::try_from(chat_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
@ -188,20 +214,33 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
})?;
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
}
(ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AmazonBedrockConverse(_)) => {
let bedrock_req = ConverseRequest::try_from(messages_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert MessagesRequest to Amazon Bedrock request: {}", e),
source: Some(Box::new(e))
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
) => {
let bedrock_req =
ConverseRequest::try_from(messages_req).map_err(|e| ProviderRequestError {
message: format!(
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
e
),
source: Some(Box::new(e)),
})?;
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
}
(ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => {
let bedrock_req = ConverseStreamRequest::try_from(messages_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert MessagesRequest to Amazon Bedrock request: {}", e),
source: Some(Box::new(e))
})?;
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
) => {
let bedrock_req = ConverseStreamRequest::try_from(messages_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
}
@ -213,13 +252,10 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
(ProviderRequestType::BedrockConverseStream(_), _) => {
todo!("Amazon Bedrock Stream to ChatCompletionsRequest conversion not implemented yet")
}
}
}
}
/// Error types for provider operations
#[derive(Debug)]
pub struct ProviderRequestError {
@ -235,19 +271,20 @@ impl fmt::Display for ProviderRequestError {
impl Error for ProviderRequestError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
self.source
.as_ref()
.map(|e| e.as_ref() as &(dyn Error + 'static))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::clients::endpoints::SupportedAPIs;
use crate::apis::anthropic::AnthropicApi::Messages;
use crate::apis::openai::OpenAIApi::ChatCompletions;
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
use crate::apis::openai::{ChatCompletionsRequest};
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::openai::OpenAIApi::ChatCompletions;
use crate::clients::endpoints::SupportedAPIs;
use crate::transforms::lib::ExtractText;
use serde_json::json;
@ -268,7 +305,7 @@ mod tests {
ProviderRequestType::ChatCompletionsRequest(r) => {
assert_eq!(r.model, "gpt-4");
assert_eq!(r.messages.len(), 2);
},
}
_ => panic!("Expected ChatCompletionsRequest variant"),
}
}
@ -291,7 +328,7 @@ mod tests {
ProviderRequestType::MessagesRequest(r) => {
assert_eq!(r.model, "claude-3-sonnet");
assert_eq!(r.messages.len(), 1);
},
}
_ => panic!("Expected MessagesRequest variant"),
}
}
@ -313,7 +350,7 @@ mod tests {
ProviderRequestType::ChatCompletionsRequest(r) => {
assert_eq!(r.model, "gpt-4");
assert_eq!(r.messages.len(), 2);
},
}
_ => panic!("Expected ChatCompletionsRequest variant"),
}
}
@ -337,7 +374,7 @@ mod tests {
ProviderRequestType::ChatCompletionsRequest(r) => {
assert_eq!(r.model, "claude-3-sonnet");
assert_eq!(r.messages.len(), 1);
},
}
_ => panic!("Expected ChatCompletionsRequest variant"),
}
}
@ -346,13 +383,15 @@ mod tests {
fn test_v1_messages_to_v1_chat_completions_roundtrip() {
let anthropic_req = AnthropicMessagesRequest {
model: "claude-3-sonnet".to_string(),
system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single("You are a helpful assistant".to_string())),
messages: vec![
crate::apis::anthropic::MessagesMessage {
role: crate::apis::anthropic::MessagesRole::User,
content: crate::apis::anthropic::MessagesMessageContent::Single("Hello!".to_string()),
}
],
system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single(
"You are a helpful assistant".to_string(),
)),
messages: vec![crate::apis::anthropic::MessagesMessage {
role: crate::apis::anthropic::MessagesRole::User,
content: crate::apis::anthropic::MessagesMessageContent::Single(
"Hello!".to_string(),
),
}],
max_tokens: 128,
container: None,
mcp_servers: None,
@ -368,16 +407,27 @@ mod tests {
metadata: None,
};
let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone()).expect("Anthropic->OpenAI conversion failed");
let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req).expect("OpenAI->Anthropic conversion failed");
let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone())
.expect("Anthropic->OpenAI conversion failed");
let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req)
.expect("OpenAI->Anthropic conversion failed");
assert_eq!(anthropic_req.model, anthropic_req2.model);
// Compare system prompt text if present
assert_eq!(
anthropic_req.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None }),
anthropic_req2.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None })
anthropic_req.system.as_ref().and_then(|s| match s {
crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t),
_ => None,
}),
anthropic_req2.system.as_ref().and_then(|s| match s {
crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t),
_ => None,
})
);
assert_eq!(
anthropic_req.messages[0].role,
anthropic_req2.messages[0].role
);
assert_eq!(anthropic_req.messages[0].role, anthropic_req2.messages[0].role);
// Compare message content text if present
assert_eq!(
anthropic_req.messages[0].content.extract_text(),
@ -386,49 +436,54 @@ mod tests {
assert_eq!(anthropic_req.max_tokens, anthropic_req2.max_tokens);
}
#[test]
fn test_v1_chat_completions_to_v1_messages_roundtrip() {
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
use crate::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent};
#[test]
fn test_v1_chat_completions_to_v1_messages_roundtrip() {
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
let openai_req = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![
Message {
role: Role::System,
content: MessageContent::Text("You are a helpful assistant".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
},
Message {
role: Role::User,
content: MessageContent::Text("Hello!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
}
],
temperature: Some(0.7),
top_p: Some(1.0),
max_tokens: Some(128),
stream: Some(false),
stop: Some(vec!["\n".to_string()]),
tools: None,
tool_choice: None,
parallel_tool_calls: None,
..Default::default()
};
let openai_req = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![
Message {
role: Role::System,
content: MessageContent::Text("You are a helpful assistant".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
},
Message {
role: Role::User,
content: MessageContent::Text("Hello!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
},
],
temperature: Some(0.7),
top_p: Some(1.0),
max_tokens: Some(128),
stream: Some(false),
stop: Some(vec!["\n".to_string()]),
tools: None,
tool_choice: None,
parallel_tool_calls: None,
..Default::default()
};
let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone()).expect("OpenAI->Anthropic conversion failed");
let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req).expect("Anthropic->OpenAI conversion failed");
let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone())
.expect("OpenAI->Anthropic conversion failed");
let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req)
.expect("Anthropic->OpenAI conversion failed");
assert_eq!(openai_req.model, openai_req2.model);
assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role);
assert_eq!(openai_req.messages[0].content.extract_text(), openai_req2.messages[0].content.extract_text());
// After roundtrip, deprecated max_tokens should be converted to max_completion_tokens
let original_max_tokens = openai_req.max_completion_tokens.or(openai_req.max_tokens);
let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens);
assert_eq!(original_max_tokens, roundtrip_max_tokens);
}
assert_eq!(openai_req.model, openai_req2.model);
assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role);
assert_eq!(
openai_req.messages[0].content.extract_text(),
openai_req2.messages[0].content.extract_text()
);
// After roundtrip, deprecated max_tokens should be converted to max_completion_tokens
let original_max_tokens = openai_req.max_completion_tokens.or(openai_req.max_tokens);
let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens);
assert_eq!(original_max_tokens, roundtrip_max_tokens);
}
}

View file

@ -1,18 +1,18 @@
use serde::Serialize;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt;
use std::convert::TryFrom;
use crate::clients::endpoints::SupportedUpstreamAPIs;
use crate::clients::endpoints::SupportedAPIs;
use crate::providers::id::ProviderId;
use crate::apis::sse::SseEvent;
use crate::apis::openai::ChatCompletionsResponse;
use crate::apis::openai::ChatCompletionsStreamResponse;
use crate::apis::anthropic::MessagesStreamEvent;
use crate::apis::anthropic::MessagesResponse;
use crate::apis::amazon_bedrock::ConverseResponse;
use crate::apis::amazon_bedrock::ConverseStreamEvent;
use crate::apis::anthropic::MessagesResponse;
use crate::apis::anthropic::MessagesStreamEvent;
use crate::apis::openai::ChatCompletionsResponse;
use crate::apis::openai::ChatCompletionsStreamResponse;
use crate::apis::sse::SseEvent;
use crate::clients::endpoints::SupportedAPIs;
use crate::clients::endpoints::SupportedUpstreamAPIs;
use crate::providers::id::ProviderId;
/// Trait for token usage information
pub trait TokenUsage {
@ -34,7 +34,6 @@ pub enum ProviderStreamResponseType {
ChatCompletionsStreamResponse(ChatCompletionsStreamResponse),
MessagesStreamEvent(MessagesStreamEvent),
ConverseStreamEvent(ConverseStreamEvent),
}
pub trait ProviderResponse: Send + Sync {
@ -43,7 +42,8 @@ pub trait ProviderResponse: Send + Sync {
/// Extract token counts for metrics
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
self.usage()
.map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
}
}
@ -74,7 +74,6 @@ pub trait ProviderStreamResponse: Send + Sync {
/// Get event type for SSE streaming (used by Anthropic)
fn event_type(&self) -> Option<&str>;
}
impl ProviderStreamResponse for ProviderStreamResponseType {
@ -135,59 +134,97 @@ impl Into<String> for ProviderStreamResponseType {
impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
type Error = std::io::Error;
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result<Self, Self::Error> {
fn try_from(
(bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId),
) -> Result<Self, Self::Error> {
let upstream_api = provider_id.compatible_api_for_client(client_api, false);
match (&upstream_api, client_api) {
(SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
(
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
SupportedAPIs::OpenAIChatCompletions(_),
) => {
let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::ChatCompletionsResponse(resp))
}
(SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
(
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
let resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::MessagesResponse(resp))
}
(SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
(
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
SupportedAPIs::OpenAIChatCompletions(_),
) => {
let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to OpenAI ChatCompletions format using the transformer
let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
let chat_resp: ChatCompletionsResponse =
anthropic_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
}
(SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
(
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to Anthropic Messages format using the transformer
let messages_resp: MessagesResponse = openai_resp.try_into()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
let messages_resp: MessagesResponse = openai_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::MessagesResponse(messages_resp))
}
// Amazon Bedrock transformations
(SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
(
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
SupportedAPIs::OpenAIChatCompletions(_),
) => {
let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to OpenAI ChatCompletions format using the transformer
let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
}
(SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
(
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to Anthropic Messages format using the transformer
let messages_resp: MessagesResponse = bedrock_resp.try_into()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
let messages_resp: MessagesResponse = bedrock_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::MessagesResponse(messages_resp))
}
_ => {
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Unsupported API combination for response transformation"))
}
_ => Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Unsupported API combination for response transformation",
)),
}
}
}
@ -196,55 +233,86 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
impl TryFrom<(&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)> for ProviderStreamResponseType {
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from((bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result<Self, Self::Error> {
fn try_from(
(bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedUpstreamAPIs),
) -> Result<Self, Self::Error> {
// Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion
if bytes == b"[DONE]" && matches!(client_api, SupportedAPIs::AnthropicMessagesAPI(_)) {
return Ok(ProviderStreamResponseType::MessagesStreamEvent(
crate::apis::anthropic::MessagesStreamEvent::MessageStop
crate::apis::anthropic::MessagesStreamEvent::MessageStop,
));
}
match (upstream_api, client_api) {
// OpenAI upstream
(SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
(
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
SupportedAPIs::OpenAIChatCompletions(_),
) => {
let resp = serde_json::from_slice(bytes)?;
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(resp))
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
resp,
))
}
(SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?;
(
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse =
serde_json::from_slice(bytes)?;
let anthropic_resp = openai_resp.try_into()?;
Ok(ProviderStreamResponseType::MessagesStreamEvent(anthropic_resp))
Ok(ProviderStreamResponseType::MessagesStreamEvent(
anthropic_resp,
))
}
// Anthropic upstream
(SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
(
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
let resp = serde_json::from_slice(bytes)?;
Ok(ProviderStreamResponseType::MessagesStreamEvent(resp))
}
(SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?;
(
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
SupportedAPIs::OpenAIChatCompletions(_),
) => {
let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent =
serde_json::from_slice(bytes)?;
let openai_resp = anthropic_resp.try_into()?;
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(openai_resp))
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
openai_resp,
))
}
// Amazon Bedrock ConverseStream upstream
(SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent = serde_json::from_slice(bytes)?;
(
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent =
serde_json::from_slice(bytes)?;
let anthropic_resp = bedrock_resp.try_into()?;
Ok(ProviderStreamResponseType::MessagesStreamEvent(anthropic_resp))
}
_ => {
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Unsupported API combination for response transformation").into())
Ok(ProviderStreamResponseType::MessagesStreamEvent(
anthropic_resp,
))
}
_ => Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Unsupported API combination for response transformation",
)
.into()),
}
}
}
// TryFrom implementation to convert raw bytes to SseEvent with parsed provider response
impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from((sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result<Self, Self::Error> {
fn try_from(
(sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs),
) -> Result<Self, Self::Error> {
// Create a new transformed event based on the original
let mut transformed_event = sse_event;
@ -252,7 +320,8 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
if transformed_event.data.is_some() {
let data_str = transformed_event.data.as_ref().unwrap();
let data_bytes = data_str.as_bytes();
let transformed_response: ProviderStreamResponseType = ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?;
let transformed_response: ProviderStreamResponseType =
ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?;
// Convert to SSE string explicitly to avoid type ambiguity
let sse_string: String = transformed_response.clone().into();
@ -261,7 +330,10 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
}
match (client_api, upstream_api) {
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => {
(
SupportedAPIs::AnthropicMessagesAPI(_),
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
) => {
if let Some(provider_response) = &transformed_event.provider_stream_response {
if let Some(event_type) = provider_response.event_type() {
// This ensures the required Anthropic sequence: MessageStart → ContentBlockStart → ContentBlockDelta(s)
@ -280,19 +352,18 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
// The sse_transform_buffer already contains the properly formatted MessageStart
transformed_event.sse_transform_buffer = format!(
"{}{}",
transformed_event.sse_transform_buffer,
content_block_start_sse,
transformed_event.sse_transform_buffer, content_block_start_sse,
);
} else if event_type == "message_delta" {
// Create ContentBlockStop event and format it using Into<String>
let content_block_stop = MessagesStreamEvent::ContentBlockStop { index: 0 };
let content_block_stop =
MessagesStreamEvent::ContentBlockStop { index: 0 };
let content_block_stop_sse: String = content_block_stop.into();
// Format as proper SSE: ContentBlockStop first, then MessageDelta
transformed_event.sse_transform_buffer = format!(
"{}{}",
content_block_stop_sse,
transformed_event.sse_transform_buffer
content_block_stop_sse, transformed_event.sse_transform_buffer
);
}
// For other event types, the sse_transform_buffer already has the correct format from Into<String>
@ -301,7 +372,10 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
// This handles cases where the transformation might not produce a valid event type
}
}
(SupportedAPIs::OpenAIChatCompletions(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => {
(
SupportedAPIs::OpenAIChatCompletions(_),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) => {
if transformed_event.is_event_only() && transformed_event.event.is_some() {
transformed_event.sse_transform_buffer = format!("\n"); // suppress the event upstream for OpenAI
}
@ -316,32 +390,56 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
}
// TryFrom implementation to convert AWS Event Stream DecodedFrame to ProviderStreamResponseType
impl TryFrom<(&aws_smithy_eventstream::frame::DecodedFrame, &SupportedAPIs, &SupportedUpstreamAPIs)> for ProviderStreamResponseType {
impl
TryFrom<(
&aws_smithy_eventstream::frame::DecodedFrame,
&SupportedAPIs,
&SupportedUpstreamAPIs,
)> for ProviderStreamResponseType
{
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from((frame, client_api, upstream_api): (&aws_smithy_eventstream::frame::DecodedFrame, &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result<Self, Self::Error> {
fn try_from(
(frame, client_api, upstream_api): (
&aws_smithy_eventstream::frame::DecodedFrame,
&SupportedAPIs,
&SupportedUpstreamAPIs,
),
) -> Result<Self, Self::Error> {
use aws_smithy_eventstream::frame::DecodedFrame;
match frame {
DecodedFrame::Complete(_) => {
// We have a complete frame - parse it based on upstream API
match (upstream_api, client_api) {
(SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
(
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
SupportedAPIs::AnthropicMessagesAPI(_),
) => {
// Parse the DecodedFrame into ConverseStreamEvent
let bedrock_event = crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
let anthropic_event: crate::apis::anthropic::MessagesStreamEvent = bedrock_event.try_into()?;
let bedrock_event =
crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
let anthropic_event: crate::apis::anthropic::MessagesStreamEvent =
bedrock_event.try_into()?;
Ok(ProviderStreamResponseType::MessagesStreamEvent(anthropic_event))
Ok(ProviderStreamResponseType::MessagesStreamEvent(
anthropic_event,
))
}
(SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
(
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
SupportedAPIs::OpenAIChatCompletions(_),
) => {
// Parse the DecodedFrame into ConverseStreamEvent
let bedrock_event = crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
let openai_event: crate::apis::openai::ChatCompletionsStreamResponse = bedrock_event.try_into()?;
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(openai_event))
}
_ => {
Err("Unsupported API combination for event-stream decoding".into())
let bedrock_event =
crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
let openai_event: crate::apis::openai::ChatCompletionsStreamResponse =
bedrock_event.try_into()?;
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
openai_event,
))
}
_ => Err("Unsupported API combination for event-stream decoding".into()),
}
}
DecodedFrame::Incomplete => {
@ -351,15 +449,12 @@ impl TryFrom<(&aws_smithy_eventstream::frame::DecodedFrame, &SupportedAPIs, &Sup
}
}
#[derive(Debug)]
pub struct ProviderResponseError {
pub message: String,
pub source: Option<Box<dyn Error + Send + Sync>>,
}
impl fmt::Display for ProviderResponseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Provider response error: {}", self.message)
@ -368,22 +463,23 @@ impl fmt::Display for ProviderResponseError {
impl Error for ProviderResponseError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
self.source
.as_ref()
.map(|e| e.as_ref() as &(dyn Error + 'static))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::apis::sse::SseStreamIter;
use crate::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
use crate::apis::sse::SseStreamIter;
use crate::clients::endpoints::SupportedAPIs;
use crate::providers::id::ProviderId;
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi;
use serde_json::json;
#[test]
fn test_openai_response_from_bytes() {
let resp = json!({
@ -402,13 +498,17 @@ mod tests {
"system_fingerprint": null
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::OpenAI));
let result = ProviderResponseType::try_from((
bytes.as_slice(),
&SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
&ProviderId::OpenAI,
));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::ChatCompletionsResponse(r) => {
assert_eq!(r.model, "gpt-4");
assert_eq!(r.choices.len(), 1);
},
}
_ => panic!("Expected ChatCompletionsResponse variant"),
}
}
@ -427,13 +527,17 @@ mod tests {
"usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 }
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::Anthropic));
let result = ProviderResponseType::try_from((
bytes.as_slice(),
&SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
&ProviderId::Anthropic,
));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::MessagesResponse(r) => {
assert_eq!(r.model, "claude-3-sonnet-20240229");
assert_eq!(r.content.len(), 1);
},
}
_ => panic!("Expected MessagesResponse variant"),
}
}
@ -457,14 +561,18 @@ mod tests {
"usage": { "prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35 }
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI));
let result = ProviderResponseType::try_from((
bytes.as_slice(),
&SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
&ProviderId::OpenAI,
));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::MessagesResponse(r) => {
assert_eq!(r.model, "gpt-4");
assert_eq!(r.usage.input_tokens, 10);
assert_eq!(r.usage.output_tokens, 25);
},
}
_ => panic!("Expected MessagesResponse variant"),
}
}
@ -495,14 +603,18 @@ mod tests {
}
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::Anthropic));
let result = ProviderResponseType::try_from((
bytes.as_slice(),
&SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
&ProviderId::Anthropic,
));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::ChatCompletionsResponse(r) => {
assert_eq!(r.model, "claude-3-sonnet-20240229");
assert_eq!(r.usage.prompt_tokens, 10);
assert_eq!(r.usage.completion_tokens, 25);
},
}
_ => panic!("Expected ChatCompletionsResponse variant"),
}
}
@ -514,11 +626,17 @@ mod tests {
let event: Result<SseEvent, _> = line.parse();
assert!(event.is_ok());
let event = event.unwrap();
assert_eq!(event.data, Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string()));
assert_eq!(
event.data,
Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string())
);
// Test conversion back to line using Display trait
let wire_format = event.to_string();
assert_eq!(wire_format, "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n");
assert_eq!(
wire_format,
"data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n"
);
// Test [DONE] marker - should be valid SSE event
let done_line = "data: [DONE]";
@ -550,10 +668,12 @@ mod tests {
event: None,
raw_line: r#"data: {"id":"test","object":"chat.completion.chunk"}
"#.to_string(),
"#
.to_string(),
sse_transform_buffer: r#"data: {"id":"test","object":"chat.completion.chunk"}
"#.to_string(),
"#
.to_string(),
provider_stream_response: None,
};
@ -590,7 +710,8 @@ mod tests {
data: Some(r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string()),
event: Some("content_block_delta".to_string()),
raw_line: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
sse_transform_buffer: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
sse_transform_buffer: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#
.to_string(),
provider_stream_response: None,
};
assert!(!normal_event.should_skip());
@ -616,7 +737,7 @@ mod tests {
"data: {\"type\": \"ping\"}".to_string(), // This should be filtered out
"data: {\"id\": \"msg2\", \"object\": \"chat.completion.chunk\"}".to_string(),
"data: {\"type\": \"ping\"}".to_string(), // This should be filtered out
"data: [DONE]".to_string(), // This should end the stream
"data: [DONE]".to_string(), // This should end the stream
];
let mut iter = SseStreamIter::new(test_lines.into_iter());
@ -684,13 +805,15 @@ mod tests {
#[test]
fn test_provider_stream_response_event_type() {
use crate::apis::anthropic::{MessagesStreamEvent, MessagesContentDelta};
use crate::apis::anthropic::{MessagesContentDelta, MessagesStreamEvent};
use crate::apis::openai::ChatCompletionsStreamResponse;
// Test Anthropic event type
let anthropic_event = MessagesStreamEvent::ContentBlockDelta {
index: 0,
delta: MessagesContentDelta::TextDelta { text: "Hello".to_string() },
delta: MessagesContentDelta::TextDelta {
text: "Hello".to_string(),
},
};
let provider_type = ProviderStreamResponseType::MessagesStreamEvent(anthropic_event);
assert_eq!(provider_type.event_type(), Some("content_block_delta"));
@ -717,15 +840,24 @@ mod tests {
// Test that [DONE] marker is properly converted to MessageStop in the transformation layer
let done_bytes = b"[DONE]";
let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::openai::OpenAIApi::ChatCompletions);
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(
crate::apis::openai::OpenAIApi::ChatCompletions,
);
let result = ProviderStreamResponseType::try_from((done_bytes.as_slice(), &client_api, &upstream_api));
let result = ProviderStreamResponseType::try_from((
done_bytes.as_slice(),
&client_api,
&upstream_api,
));
assert!(result.is_ok());
if let Ok(ProviderStreamResponseType::MessagesStreamEvent(event)) = result {
// Verify it's a MessageStop event
assert_eq!(event.event_type(), Some("message_stop"));
assert!(matches!(event, crate::apis::anthropic::MessagesStreamEvent::MessageStop));
assert!(matches!(
event,
crate::apis::anthropic::MessagesStreamEvent::MessageStop
));
} else {
panic!("Expected MessagesStreamEvent::MessageStop");
}
@ -747,7 +879,10 @@ mod tests {
// This signals the caller to wait for more data
let result = decoder.decode_frame();
assert!(result.is_some());
assert!(matches!(result.unwrap(), aws_smithy_eventstream::frame::DecodedFrame::Incomplete));
assert!(matches!(
result.unwrap(),
aws_smithy_eventstream::frame::DecodedFrame::Incomplete
));
// Verify we can still access the buffer
assert!(decoder.has_remaining());
@ -760,8 +895,8 @@ mod tests {
use std::path::PathBuf;
// Read the actual response.hex file from tests/e2e directory
let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("../../tests/e2e/response.hex");
let test_file =
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
// Only run this test if the file exists
if !test_file.exists() {
@ -782,7 +917,8 @@ mod tests {
frame_count += 1;
// Verify we can access headers
let event_type = message.headers()
let event_type = message
.headers()
.iter()
.find(|h| h.name().as_str() == ":event-type")
.and_then(|h| h.value().as_string().ok());
@ -811,8 +947,8 @@ mod tests {
use std::path::PathBuf;
// Read the actual response.hex file
let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("../../tests/e2e/response.hex");
let test_file =
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
if !test_file.exists() {
println!("Skipping test - response.hex not found");
@ -863,7 +999,10 @@ mod tests {
}
}
assert!(total_frames > 0, "Should have decoded frames from chunked data");
assert!(
total_frames > 0,
"Should have decoded frames from chunked data"
);
}
#[test]
@ -872,7 +1011,7 @@ mod tests {
}
#[test]
#[ignore] // Run with: cargo test -- --ignored --nocapture
#[ignore] // Run with: cargo test -- --ignored --nocapture
fn test_bedrock_decoded_frame_to_provider_response_verbose() {
test_bedrock_conversion(true);
}
@ -883,7 +1022,7 @@ mod tests {
}
#[test]
#[ignore] // Run with: cargo test -- --ignored --nocapture
#[ignore] // Run with: cargo test -- --ignored --nocapture
fn test_bedrock_decoded_frame_with_tool_use_verbose() {
test_bedrock_conversion_with_tools(true);
}
@ -894,8 +1033,8 @@ mod tests {
use std::path::PathBuf;
// Read the actual response.hex file from tests/e2e directory
let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("../../tests/e2e/response.hex");
let test_file =
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
// Only run this test if the file exists
if !test_file.exists() {
@ -908,8 +1047,11 @@ mod tests {
let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
let client_api = SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream);
let client_api =
SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
);
let mut conversion_count = 0;
let mut message_start_seen = false;
@ -919,14 +1061,18 @@ mod tests {
match decoder.decode_frame() {
Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => {
// Convert DecodedFrame to ProviderStreamResponseType
let result = ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api));
let result =
ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api));
match result {
Ok(provider_response) => {
conversion_count += 1;
// Verify we got a MessagesStreamEvent
assert!(matches!(provider_response, ProviderStreamResponseType::MessagesStreamEvent(_)));
assert!(matches!(
provider_response,
ProviderStreamResponseType::MessagesStreamEvent(_)
));
if verbose {
// Print the SSE string output
@ -935,8 +1081,13 @@ mod tests {
}
// Check for MessageStart event
if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = provider_response {
if matches!(event, crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. }) {
if let ProviderStreamResponseType::MessagesStreamEvent(ref event) =
provider_response
{
if matches!(
event,
crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. }
) {
message_start_seen = true;
}
}
@ -956,7 +1107,10 @@ mod tests {
}
}
assert!(conversion_count > 0, "Should have converted at least one frame");
assert!(
conversion_count > 0,
"Should have converted at least one frame"
);
assert!(message_start_seen, "Should have seen MessageStart event");
}
@ -980,8 +1134,11 @@ mod tests {
let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
let client_api = SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream);
let client_api =
SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
);
let mut conversion_count = 0;
let mut message_start_seen = false;
@ -993,14 +1150,18 @@ mod tests {
match decoder.decode_frame() {
Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => {
// Convert DecodedFrame to ProviderStreamResponseType
let result = ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api));
let result =
ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api));
match result {
Ok(provider_response) => {
conversion_count += 1;
// Verify we got a MessagesStreamEvent
assert!(matches!(provider_response, ProviderStreamResponseType::MessagesStreamEvent(_)));
assert!(matches!(
provider_response,
ProviderStreamResponseType::MessagesStreamEvent(_)
));
if verbose {
// Print the SSE string output
@ -1009,7 +1170,9 @@ mod tests {
}
// Check for specific events related to tool use
if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = provider_response {
if let ProviderStreamResponseType::MessagesStreamEvent(ref event) =
provider_response
{
match event {
crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. } => {
message_start_seen = true;
@ -1041,16 +1204,25 @@ mod tests {
}
}
assert!(conversion_count > 0, "Should have converted at least one frame");
assert!(
conversion_count > 0,
"Should have converted at least one frame"
);
assert!(message_start_seen, "Should have seen MessageStart event");
assert!(content_block_start_seen, "Should have seen ContentBlockStart event for tool use");
assert!(content_block_delta_tool_use_seen, "Should have seen ContentBlockDelta with ToolUseDelta");
assert!(
content_block_start_seen,
"Should have seen ContentBlockStart event for tool use"
);
assert!(
content_block_delta_tool_use_seen,
"Should have seen ContentBlockDelta with ToolUseDelta"
);
}
#[test]
fn test_sse_event_transformation_openai_to_anthropic_message_start() {
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
// Create an OpenAI stream response that represents a role start (which becomes message_start in Anthropic)
let openai_stream_chunk = json!({
@ -1085,8 +1257,14 @@ mod tests {
// Verify the transformation includes both message_start and content_block_start
let buffer = transformed.sse_transform_buffer;
assert!(buffer.contains("event: message_start"), "Should contain message_start event");
assert!(buffer.contains("event: content_block_start"), "Should contain content_block_start event");
assert!(
buffer.contains("event: message_start"),
"Should contain message_start event"
);
assert!(
buffer.contains("event: content_block_start"),
"Should contain content_block_start event"
);
// Verify proper SSE format with event lines before data lines
assert!(buffer.find("event: message_start").unwrap() < buffer.find("data:").unwrap());
@ -1095,8 +1273,8 @@ mod tests {
#[test]
fn test_sse_event_transformation_openai_to_anthropic_message_delta() {
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
// Create an OpenAI stream response with finish_reason (which becomes message_delta in Anthropic)
let openai_stream_chunk = json!({
@ -1136,19 +1314,28 @@ mod tests {
// Verify the transformation includes both content_block_stop and message_delta
let buffer = transformed.sse_transform_buffer;
assert!(buffer.contains("event: content_block_stop"), "Should contain content_block_stop event");
assert!(buffer.contains("event: message_delta"), "Should contain message_delta event");
assert!(
buffer.contains("event: content_block_stop"),
"Should contain content_block_stop event"
);
assert!(
buffer.contains("event: message_delta"),
"Should contain message_delta event"
);
// Verify content_block_stop comes before message_delta
let stop_pos = buffer.find("content_block_stop").unwrap();
let delta_pos = buffer.find("message_delta").unwrap();
assert!(stop_pos < delta_pos, "content_block_stop should come before message_delta");
assert!(
stop_pos < delta_pos,
"content_block_stop should come before message_delta"
);
}
#[test]
fn test_sse_event_transformation_openai_to_anthropic_content_delta() {
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
// Create an OpenAI stream response with content (which becomes content_block_delta in Anthropic)
let openai_stream_chunk = json!({
@ -1183,9 +1370,18 @@ mod tests {
// Verify the transformation is a content_block_delta (no extra events injected)
let buffer = transformed.sse_transform_buffer;
assert!(buffer.contains("event: content_block_delta"), "Should contain content_block_delta event");
assert!(!buffer.contains("content_block_start"), "Should not inject content_block_start for content delta");
assert!(!buffer.contains("content_block_stop"), "Should not inject content_block_stop for content delta");
assert!(
buffer.contains("event: content_block_delta"),
"Should contain content_block_delta event"
);
assert!(
!buffer.contains("content_block_start"),
"Should not inject content_block_start for content delta"
);
assert!(
!buffer.contains("content_block_stop"),
"Should not inject content_block_stop for content delta"
);
// Verify the content is preserved
assert!(buffer.contains("Hello"), "Should preserve the content text");
@ -1193,8 +1389,8 @@ mod tests {
#[test]
fn test_sse_event_transformation_anthropic_to_openai_suppresses_event_lines() {
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
// Create an Anthropic event-only SSE line (no data)
let sse_event = SseEvent {
@ -1215,14 +1411,20 @@ mod tests {
let transformed = result.unwrap();
// Verify the event line is suppressed (replaced with just newline)
assert_eq!(transformed.sse_transform_buffer, "\n", "Event-only lines should be suppressed to newline for OpenAI");
assert!(transformed.is_event_only(), "Should still be marked as event-only");
assert_eq!(
transformed.sse_transform_buffer, "\n",
"Event-only lines should be suppressed to newline for OpenAI"
);
assert!(
transformed.is_event_only(),
"Should still be marked as event-only"
);
}
#[test]
fn test_sse_event_transformation_anthropic_to_openai_preserves_data() {
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
// Create an Anthropic message_start event with data
let anthropic_event = json!({
@ -1258,7 +1460,10 @@ mod tests {
// Verify data is transformed to OpenAI format
let buffer = transformed.sse_transform_buffer;
assert!(buffer.starts_with("data: "), "Should have data: prefix");
assert!(!buffer.contains("event:"), "Should not have event: lines for OpenAI");
assert!(
!buffer.contains("event:"),
"Should not have event: lines for OpenAI"
);
// Verify provider response was parsed
assert!(transformed.provider_stream_response.is_some());
@ -1310,8 +1515,8 @@ mod tests {
#[test]
fn test_sse_event_transformation_preserves_provider_response() {
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
// Create an OpenAI stream response
let openai_stream_chunk = json!({
@ -1344,11 +1549,17 @@ mod tests {
let transformed = result.unwrap();
// Verify provider_stream_response is populated
assert!(transformed.provider_stream_response.is_some(), "Should parse and store provider response");
assert!(
transformed.provider_stream_response.is_some(),
"Should parse and store provider response"
);
// Verify we can access the provider response
let provider_response = transformed.provider_response();
assert!(provider_response.is_ok(), "Should be able to access provider response");
assert!(
provider_response.is_ok(),
"Should be able to access provider response"
);
// Verify the content delta is accessible
let content = provider_response.unwrap().content_delta();

View file

@ -1,7 +1,7 @@
use serde_json::Value;
use crate::apis::anthropic::{MessagesContentBlock,MessagesImageSource};
use crate::apis::anthropic::{MessagesContentBlock, MessagesImageSource};
use crate::apis::openai::{ContentPart, FunctionCall, ImageUrl, Message, MessageContent, ToolCall};
use crate::clients::TransformError;
use serde_json::Value;
use std::time::{SystemTime, UNIX_EPOCH};
pub trait ExtractText {
@ -11,12 +11,17 @@ pub trait ExtractText {
/// Trait for utility functions on content collections
pub trait ContentUtils<T> {
fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError>;
fn split_for_openai(&self) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
fn split_for_openai(
&self,
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
}
/// Helper to create a current unix timestamp
pub fn current_timestamp() -> u64 {
SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs()
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
}
// Content Utilities
@ -26,24 +31,36 @@ impl ContentUtils<ToolCall> for Vec<MessagesContentBlock> {
for block in self {
match block {
MessagesContentBlock::ToolUse { id, name, input, .. } |
MessagesContentBlock::ServerToolUse { id, name, input } |
MessagesContentBlock::McpToolUse { id, name, input } => {
MessagesContentBlock::ToolUse {
id, name, input, ..
}
| MessagesContentBlock::ServerToolUse { id, name, input }
| MessagesContentBlock::McpToolUse { id, name, input } => {
let arguments = serde_json::to_string(&input)?;
tool_calls.push(ToolCall {
id: id.clone(),
call_type: "function".to_string(),
function: FunctionCall { name: name.clone(), arguments },
function: FunctionCall {
name: name.clone(),
arguments,
},
});
}
_ => continue,
}
}
Ok(if tool_calls.is_empty() { None } else { Some(tool_calls) })
Ok(if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
})
}
fn split_for_openai(&self) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError> {
fn split_for_openai(
&self,
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>
{
let mut content_parts = Vec::new();
let mut tool_calls = Vec::new();
let mut tool_results = Vec::new();
@ -62,25 +79,55 @@ impl ContentUtils<ToolCall> for Vec<MessagesContentBlock> {
},
});
}
MessagesContentBlock::ToolUse { id, name, input, .. } |
MessagesContentBlock::ServerToolUse { id, name, input } |
MessagesContentBlock::McpToolUse { id, name, input } => {
MessagesContentBlock::ToolUse {
id, name, input, ..
}
| MessagesContentBlock::ServerToolUse { id, name, input }
| MessagesContentBlock::McpToolUse { id, name, input } => {
let arguments = serde_json::to_string(&input)?;
tool_calls.push(ToolCall {
id: id.clone(),
call_type: "function".to_string(),
function: FunctionCall { name: name.clone(), arguments },
function: FunctionCall {
name: name.clone(),
arguments,
},
});
}
MessagesContentBlock::ToolResult { tool_use_id, content, is_error, .. } => {
MessagesContentBlock::ToolResult {
tool_use_id,
content,
is_error,
..
} => {
let result_text = content.extract_text();
tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false)));
tool_results.push((
tool_use_id.clone(),
result_text,
is_error.unwrap_or(false),
));
}
MessagesContentBlock::WebSearchToolResult { tool_use_id, content, is_error } |
MessagesContentBlock::CodeExecutionToolResult { tool_use_id, content, is_error } |
MessagesContentBlock::McpToolResult { tool_use_id, content, is_error } => {
MessagesContentBlock::WebSearchToolResult {
tool_use_id,
content,
is_error,
}
| MessagesContentBlock::CodeExecutionToolResult {
tool_use_id,
content,
is_error,
}
| MessagesContentBlock::McpToolResult {
tool_use_id,
content,
is_error,
} => {
let result_text = content.extract_text();
tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false)));
tool_results.push((
tool_use_id.clone(),
result_text,
is_error.unwrap_or(false),
));
}
_ => {
// Skip unsupported content types
@ -122,29 +169,41 @@ fn convert_image_url_to_source(image_url: &ImageUrl) -> MessagesImageSource {
data: data.to_string(),
}
} else {
MessagesImageSource::Url { url: image_url.url.clone() }
MessagesImageSource::Url {
url: image_url.url.clone(),
}
}
} else {
MessagesImageSource::Url { url: image_url.url.clone() }
MessagesImageSource::Url {
url: image_url.url.clone(),
}
}
}
/// Convert OpenAI message to Anthropic content blocks
pub fn convert_openai_message_to_anthropic_content(message: &Message) -> Result<Vec<MessagesContentBlock>, TransformError> {
pub fn convert_openai_message_to_anthropic_content(
message: &Message,
) -> Result<Vec<MessagesContentBlock>, TransformError> {
let mut blocks = Vec::new();
// Handle regular content
match &message.content {
MessageContent::Text(text) => {
if !text.is_empty() {
blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None });
blocks.push(MessagesContentBlock::Text {
text: text.clone(),
cache_control: None,
});
}
}
MessageContent::Parts(parts) => {
for part in parts {
match part {
ContentPart::Text { text } => {
blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None });
blocks.push(MessagesContentBlock::Text {
text: text.clone(),
cache_control: None,
});
}
ContentPart::ImageUrl { image_url } => {
let source = convert_image_url_to_source(image_url);

View file

@ -8,14 +8,14 @@
//! by the gateway, but the external API surface remains these two standard formats.
//! The transformations are split into logical modules for maintainability.
pub mod lib;
pub mod request;
pub mod response;
pub mod lib;
// Re-export commonly used items for convenience
pub use lib::*;
pub use request::*;
pub use response::*;
pub use lib::*;
// ============================================================================
// CONSTANTS

View file

@ -1,14 +1,21 @@
use crate::transforms::lib::*;
use crate::clients::TransformError;
use crate::apis::anthropic::{MessagesMessage, MessagesRequest, MessagesMessageContent, MessagesRole, MessagesStopReason, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, MessagesUsage, MessagesSystemPrompt, ToolResultContent};
use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolCall, ToolChoice, ToolChoiceType, Function, FunctionChoice,FinishReason, Usage, ContentPart};
use crate::apis::amazon_bedrock::{
ConverseRequest, SystemContentBlock, InferenceConfiguration, ToolConfiguration,
Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolInputSchema, ToolSpecDefinition,
AutoChoice, AnyChoice, ToolChoiceSpec,
Message as BedrockMessage, ConversationRole, ContentBlock,
ToolUseBlock, ToolResultBlock, ToolResultContentBlock, ToolResultStatus, ImageBlock, ImageSource
AnyChoice, AutoChoice, ContentBlock, ConversationRole, ConverseRequest, ImageBlock,
ImageSource, InferenceConfiguration, Message as BedrockMessage, SystemContentBlock,
Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolChoiceSpec, ToolConfiguration,
ToolInputSchema, ToolResultBlock, ToolResultContentBlock, ToolResultStatus, ToolSpecDefinition,
ToolUseBlock,
};
use crate::apis::anthropic::{
MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole, MessagesStopReason,
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, MessagesUsage,
ToolResultContent,
};
use crate::apis::openai::{
ChatCompletionsRequest, ContentPart, FinishReason, Function, FunctionChoice, Message,
MessageContent, Role, Tool, ToolCall, ToolChoice, ToolChoiceType, Usage,
};
use crate::clients::TransformError;
use crate::transforms::lib::*;
type AnthropicMessagesRequest = MessagesRequest;
@ -32,7 +39,8 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
// Convert tools and tool choice
let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
let (openai_tool_choice, parallel_tool_calls) = convert_anthropic_tool_choice(req.tool_choice);
let (openai_tool_choice, parallel_tool_calls) =
convert_anthropic_tool_choice(req.tool_choice);
let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest {
model: req.model,
@ -91,7 +99,8 @@ impl TryFrom<AnthropicMessagesRequest> for ConverseRequest {
// Convert tools and tool choice to ToolConfiguration
let tool_config = if req.tools.is_some() || req.tool_choice.is_some() {
let tools = req.tools.map(|anthropic_tools| {
anthropic_tools.into_iter()
anthropic_tools
.into_iter()
.map(|tool| BedrockTool::ToolSpec {
tool_spec: ToolSpecDefinition {
name: tool.name,
@ -106,25 +115,28 @@ impl TryFrom<AnthropicMessagesRequest> for ConverseRequest {
let tool_choice = req.tool_choice.map(|choice| {
match choice.kind {
MessagesToolChoiceType::Auto => BedrockToolChoice::Auto { auto: AutoChoice {} },
MessagesToolChoiceType::Auto => BedrockToolChoice::Auto {
auto: AutoChoice {},
},
MessagesToolChoiceType::Any => BedrockToolChoice::Any { any: AnyChoice {} },
MessagesToolChoiceType::None => BedrockToolChoice::Auto { auto: AutoChoice {} }, // Bedrock doesn't have explicit "none"
MessagesToolChoiceType::None => BedrockToolChoice::Auto {
auto: AutoChoice {},
}, // Bedrock doesn't have explicit "none"
MessagesToolChoiceType::Tool => {
if let Some(name) = choice.name {
BedrockToolChoice::Tool {
tool: ToolChoiceSpec { name }
tool: ToolChoiceSpec { name },
}
} else {
BedrockToolChoice::Auto { auto: AutoChoice {} }
BedrockToolChoice::Auto {
auto: AutoChoice {},
}
}
}
}
});
Some(ToolConfiguration {
tools,
tool_choice,
})
Some(ToolConfiguration { tools, tool_choice })
} else {
None
};
@ -147,7 +159,6 @@ impl TryFrom<AnthropicMessagesRequest> for ConverseRequest {
}
}
// Message Conversions
impl TryFrom<MessagesMessage> for Vec<Message> {
type Error = TransformError;
@ -186,7 +197,11 @@ impl TryFrom<MessagesMessage> for Vec<Message> {
role: message.role.into(),
content,
name: None,
tool_calls: if tool_calls.is_empty() { None } else { Some(tool_calls) },
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
tool_call_id: None,
};
result.push(main_message);
@ -198,7 +213,6 @@ impl TryFrom<MessagesMessage> for Vec<Message> {
}
}
// Role Conversions
impl Into<Role> for MessagesRole {
fn into(self) -> Role {
@ -237,9 +251,7 @@ impl Into<Message> for MessagesSystemPrompt {
fn into(self) -> Message {
let system_content = match self {
MessagesSystemPrompt::Single(text) => MessageContent::Text(text),
MessagesSystemPrompt::Blocks(blocks) => {
MessageContent::Text(blocks.extract_text())
}
MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()),
};
Message {
@ -255,7 +267,8 @@ impl Into<Message> for MessagesSystemPrompt {
//Utility Functions
/// Convert Anthropic tools to OpenAI format
fn convert_anthropic_tools(tools: Vec<MessagesTool>) -> Vec<Tool> {
tools.into_iter()
tools
.into_iter()
.map(|tool| Tool {
tool_type: "function".to_string(),
function: Function {
@ -269,7 +282,9 @@ fn convert_anthropic_tools(tools: Vec<MessagesTool>) -> Vec<Tool> {
}
/// Convert Anthropic tool choice to OpenAI format
fn convert_anthropic_tool_choice(tool_choice: Option<MessagesToolChoice>) -> (Option<ToolChoice>, Option<bool>) {
fn convert_anthropic_tool_choice(
tool_choice: Option<MessagesToolChoice>,
) -> (Option<ToolChoice>, Option<bool>) {
match tool_choice {
Some(choice) => {
let openai_choice = match choice.kind {
@ -290,12 +305,15 @@ fn convert_anthropic_tool_choice(tool_choice: Option<MessagesToolChoice>) -> (Op
let parallel = choice.disable_parallel_tool_use.map(|disable| !disable);
(Some(openai_choice), parallel)
}
None => (None, None)
None => (None, None),
}
}
/// Build OpenAI message content from parts and tool calls
fn build_openai_content(content_parts: Vec<ContentPart>, tool_calls: &[ToolCall]) -> MessageContent {
fn build_openai_content(
content_parts: Vec<ContentPart>,
tool_calls: &[ToolCall],
) -> MessageContent {
if content_parts.len() == 1 && tool_calls.is_empty() {
match &content_parts[0] {
ContentPart::Text { text } => MessageContent::Text(text.clone()),
@ -334,7 +352,12 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
content_blocks.push(ContentBlock::Text { text });
}
}
crate::apis::anthropic::MessagesContentBlock::ToolUse { id, name, input, .. } => {
crate::apis::anthropic::MessagesContentBlock::ToolUse {
id,
name,
input,
..
} => {
content_blocks.push(ContentBlock::ToolUse {
tool_use: ToolUseBlock {
tool_use_id: id,
@ -343,7 +366,12 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
},
});
}
crate::apis::anthropic::MessagesContentBlock::ToolResult { tool_use_id, is_error, content, .. } => {
crate::apis::anthropic::MessagesContentBlock::ToolResult {
tool_use_id,
is_error,
content,
..
} => {
// Convert Anthropic ToolResultContent to Bedrock ToolResultContentBlock
let tool_result_content = match content {
ToolResultContent::Text(text) => {
@ -366,7 +394,9 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
// Ensure we have at least one content block
let final_content = if tool_result_content.is_empty() {
vec![ToolResultContentBlock::Text { text: " ".to_string() }]
vec![ToolResultContentBlock::Text {
text: " ".to_string(),
}]
} else {
tool_result_content
};
@ -388,13 +418,13 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
crate::apis::anthropic::MessagesContentBlock::Image { source } => {
// Convert Anthropic image to Bedrock image format
match source {
crate::apis::anthropic::MessagesImageSource::Base64 { media_type, data } => {
crate::apis::anthropic::MessagesImageSource::Base64 {
media_type,
data,
} => {
content_blocks.push(ContentBlock::Image {
image: ImageBlock {
source: ImageSource::Base64 {
media_type,
data,
},
source: ImageSource::Base64 { media_type, data },
},
});
}
@ -413,7 +443,9 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
// Ensure we have at least one content block
if content_blocks.is_empty() {
content_blocks.push(ContentBlock::Text { text: " ".to_string() });
content_blocks.push(ContentBlock::Text {
text: " ".to_string(),
});
}
Ok(BedrockMessage {
@ -423,29 +455,33 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::apis::anthropic::{MessagesRequest, MessagesMessage, MessagesMessageContent, MessagesRole, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, MessagesSystemPrompt};
use crate::apis::amazon_bedrock::{ConverseRequest, SystemContentBlock, ToolChoice as BedrockToolChoice, ConversationRole, ContentBlock};
use crate::apis::amazon_bedrock::{
ContentBlock, ConversationRole, ConverseRequest, SystemContentBlock,
ToolChoice as BedrockToolChoice,
};
use crate::apis::anthropic::{
MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType,
};
use serde_json::json;
#[test]
fn test_anthropic_to_bedrock_basic_request() {
let anthropic_request = MessagesRequest {
model: "claude-3-5-sonnet-20241022".to_string(),
messages: vec![
MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Single("Hello, how are you?".to_string()),
}
],
messages: vec![MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Single("Hello, how are you?".to_string()),
}],
max_tokens: 1000,
container: None,
mcp_servers: None,
system: Some(MessagesSystemPrompt::Single("You are a helpful assistant.".to_string())),
system: Some(MessagesSystemPrompt::Single(
"You are a helpful assistant.".to_string(),
)),
metadata: None,
service_tier: None,
thinking: None,
@ -478,19 +514,20 @@ mod tests {
assert_eq!(inference_config.temperature, Some(0.7));
assert_eq!(inference_config.top_p, Some(0.9));
assert_eq!(inference_config.max_tokens, Some(1000));
assert_eq!(inference_config.stop_sequences, Some(vec!["STOP".to_string()]));
assert_eq!(
inference_config.stop_sequences,
Some(vec!["STOP".to_string()])
);
}
#[test]
fn test_anthropic_to_bedrock_with_tools() {
let anthropic_request = MessagesRequest {
model: "claude-3-5-sonnet-20241022".to_string(),
messages: vec![
MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Single("What's the weather like?".to_string()),
}
],
messages: vec![MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Single("What's the weather like?".to_string()),
}],
max_tokens: 1000,
container: None,
mcp_servers: None,
@ -503,22 +540,20 @@ mod tests {
top_k: None,
stream: None,
stop_sequences: None,
tools: Some(vec![
MessagesTool {
name: "get_weather".to_string(),
description: Some("Get current weather information".to_string()),
input_schema: json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name"
}
},
"required": ["location"]
}),
}
]),
tools: Some(vec![MessagesTool {
name: "get_weather".to_string(),
description: Some("Get current weather information".to_string()),
input_schema: json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name"
}
},
"required": ["location"]
}),
}]),
tool_choice: Some(MessagesToolChoice {
kind: MessagesToolChoiceType::Tool,
name: Some("get_weather".to_string()),
@ -537,7 +572,10 @@ mod tests {
assert_eq!(tools.len(), 1);
let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0];
assert_eq!(tool_spec.name, "get_weather");
assert_eq!(tool_spec.description, Some("Get current weather information".to_string()));
assert_eq!(
tool_spec.description,
Some("Get current weather information".to_string())
);
if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice {
assert_eq!(tool.name, "get_weather");
@ -550,12 +588,10 @@ mod tests {
fn test_anthropic_to_bedrock_auto_tool_choice() {
let anthropic_request = MessagesRequest {
model: "claude-3-5-sonnet-20241022".to_string(),
messages: vec![
MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Single("Help me with something".to_string()),
}
],
messages: vec![MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Single("Help me with something".to_string()),
}],
max_tokens: 500,
container: None,
mcp_servers: None,
@ -568,16 +604,14 @@ mod tests {
top_k: None,
stream: None,
stop_sequences: None,
tools: Some(vec![
MessagesTool {
name: "help_tool".to_string(),
description: Some("A helpful tool".to_string()),
input_schema: json!({
"type": "object",
"properties": {}
}),
}
]),
tools: Some(vec![MessagesTool {
name: "help_tool".to_string(),
description: Some("A helpful tool".to_string()),
input_schema: json!({
"type": "object",
"properties": {}
}),
}]),
tool_choice: Some(MessagesToolChoice {
kind: MessagesToolChoiceType::Auto,
name: None,
@ -589,7 +623,10 @@ mod tests {
assert!(bedrock_request.tool_config.is_some());
let tool_config = bedrock_request.tool_config.as_ref().unwrap();
assert!(matches!(tool_config.tool_choice, Some(BedrockToolChoice::Auto { .. })));
assert!(matches!(
tool_config.tool_choice,
Some(BedrockToolChoice::Auto { .. })
));
}
#[test]
@ -603,12 +640,14 @@ mod tests {
},
MessagesMessage {
role: MessagesRole::Assistant,
content: MessagesMessageContent::Single("Hi there! How can I help you?".to_string()),
content: MessagesMessageContent::Single(
"Hi there! How can I help you?".to_string(),
),
},
MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Single("What's 2+2?".to_string()),
}
},
],
max_tokens: 100,
container: None,

View file

@ -1,15 +1,21 @@
use crate::apis::amazon_bedrock::{
AnyChoice, AutoChoice, ContentBlock, ConversationRole, ConverseRequest, InferenceConfiguration,
Message as BedrockMessage, SystemContentBlock, Tool as BedrockTool,
ToolChoice as BedrockToolChoice, ToolChoiceSpec, ToolConfiguration, ToolInputSchema,
ToolSpecDefinition,
};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType,
ToolResultContent,
};
use crate::apis::openai::{
ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType,
};
use crate::clients::TransformError;
use crate::transforms::lib::ExtractText;
use crate::transforms::lib::*;
use crate::clients::TransformError;
use crate::transforms::*;
use crate::apis::anthropic::{MessagesSystemPrompt, MessagesMessage,MessagesRequest, MessagesMessageContent, MessagesContentBlock, MessagesRole, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, ToolResultContent};
use crate::apis::openai::{ChatCompletionsRequest, Message, Role, Tool, ToolChoice, ToolChoiceType, MessageContent};
use crate::apis::amazon_bedrock::{
ConverseRequest, SystemContentBlock, InferenceConfiguration, ToolConfiguration,
Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolInputSchema, ToolSpecDefinition,
AutoChoice, AnyChoice, ToolChoiceSpec,
Message as BedrockMessage, ConversationRole, ContentBlock
};
type AnthropicMessagesRequest = MessagesRequest;
@ -21,7 +27,7 @@ impl Into<MessagesSystemPrompt> for Message {
fn into(self) -> MessagesSystemPrompt {
let system_text = match self.content {
MessageContent::Text(text) => text,
MessageContent::Parts(parts) => parts.extract_text()
MessageContent::Parts(parts) => parts.extract_text(),
};
MessagesSystemPrompt::Single(system_text)
}
@ -36,8 +42,11 @@ impl TryFrom<Message> for MessagesMessage {
Role::Assistant => MessagesRole::Assistant,
Role::Tool => {
// Tool messages become user messages with tool results
let tool_call_id = message.tool_call_id
.ok_or_else(|| TransformError::MissingField("tool_call_id required for Tool messages".to_string()))?;
let tool_call_id = message.tool_call_id.ok_or_else(|| {
TransformError::MissingField(
"tool_call_id required for Tool messages".to_string(),
)
})?;
return Ok(MessagesMessage {
role: MessagesRole::User,
@ -55,7 +64,9 @@ impl TryFrom<Message> for MessagesMessage {
});
}
Role::System => {
return Err(TransformError::UnsupportedConversion("System messages should be handled separately".to_string()));
return Err(TransformError::UnsupportedConversion(
"System messages should be handled separately".to_string(),
));
}
};
@ -75,7 +86,9 @@ impl TryFrom<Message> for BedrockMessage {
Role::Assistant => ConversationRole::Assistant,
Role::Tool => ConversationRole::User, // Tool results become user messages in Bedrock
Role::System => {
return Err(TransformError::UnsupportedConversion("System messages should be handled separately in Bedrock".to_string()));
return Err(TransformError::UnsupportedConversion(
"System messages should be handled separately in Bedrock".to_string(),
));
}
};
@ -103,7 +116,9 @@ impl TryFrom<Message> for BedrockMessage {
crate::apis::openai::ContentPart::ImageUrl { image_url } => {
// Convert image URL to Bedrock image format
if image_url.url.starts_with("data:") {
if let Some((media_type, data)) = parse_data_url(&image_url.url) {
if let Some((media_type, data)) =
parse_data_url(&image_url.url)
{
content_blocks.push(ContentBlock::Image {
image: crate::apis::amazon_bedrock::ImageBlock {
source: crate::apis::amazon_bedrock::ImageSource::Base64 {
@ -114,7 +129,10 @@ impl TryFrom<Message> for BedrockMessage {
});
} else {
return Err(TransformError::UnsupportedConversion(
format!("Invalid data URL format: {}", image_url.url)
format!(
"Invalid data URL format: {}",
image_url.url
),
));
}
} else {
@ -130,13 +148,18 @@ impl TryFrom<Message> for BedrockMessage {
// Ensure we have at least one content block
if content_blocks.is_empty() {
content_blocks.push(ContentBlock::Text { text: " ".to_string() });
content_blocks.push(ContentBlock::Text {
text: " ".to_string(),
});
}
}
Role::Assistant => {
// Handle text content - but only add if non-empty OR if we don't have tool calls
let text_content = message.content.extract_text();
let has_tool_calls = message.tool_calls.as_ref().map_or(false, |calls| !calls.is_empty());
let has_tool_calls = message
.tool_calls
.as_ref()
.map_or(false, |calls| !calls.is_empty());
// Add text content if it's non-empty, or if we have no tool calls (to avoid empty content)
if !text_content.is_empty() {
@ -144,17 +167,22 @@ impl TryFrom<Message> for BedrockMessage {
} else if !has_tool_calls {
// If we have empty content and no tool calls, add a minimal placeholder
// This prevents the "blank text field" error
content_blocks.push(ContentBlock::Text { text: " ".to_string() });
content_blocks.push(ContentBlock::Text {
text: " ".to_string(),
});
}
// Convert tool calls to ToolUse content blocks
if let Some(tool_calls) = message.tool_calls {
for tool_call in tool_calls {
// Parse the arguments string as JSON
let input: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
.map_err(|e| TransformError::UnsupportedConversion(
format!("Failed to parse tool arguments as JSON: {}. Arguments: {}", e, tool_call.function.arguments)
))?;
let input: serde_json::Value =
serde_json::from_str(&tool_call.function.arguments).map_err(|e| {
TransformError::UnsupportedConversion(format!(
"Failed to parse tool arguments as JSON: {}. Arguments: {}",
e, tool_call.function.arguments
))
})?;
content_blocks.push(ContentBlock::ToolUse {
tool_use: crate::apis::amazon_bedrock::ToolUseBlock {
@ -168,13 +196,18 @@ impl TryFrom<Message> for BedrockMessage {
// Bedrock requires at least one content block
if content_blocks.is_empty() {
content_blocks.push(ContentBlock::Text { text: " ".to_string() });
content_blocks.push(ContentBlock::Text {
text: " ".to_string(),
});
}
}
Role::Tool => {
// Tool messages become user messages with ToolResult content blocks
let tool_call_id = message.tool_call_id
.ok_or_else(|| TransformError::MissingField("tool_call_id required for Tool messages".to_string()))?;
let tool_call_id = message.tool_call_id.ok_or_else(|| {
TransformError::MissingField(
"tool_call_id required for Tool messages".to_string(),
)
})?;
let tool_content = message.content.extract_text();
@ -182,11 +215,11 @@ impl TryFrom<Message> for BedrockMessage {
let tool_result_content = if tool_content.is_empty() {
// Even for tool results, we need non-empty content
vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text {
text: " ".to_string()
text: " ".to_string(),
}]
} else {
vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text {
text: tool_content
text: tool_content,
}]
};
@ -232,13 +265,15 @@ impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
// Convert tools and tool choice
let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools));
let anthropic_tool_choice = convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
let anthropic_tool_choice =
convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
Ok(AnthropicMessagesRequest {
model: req.model,
system: system_prompt,
messages,
max_tokens: req.max_completion_tokens
max_tokens: req
.max_completion_tokens
.or(req.max_tokens)
.unwrap_or(DEFAULT_MAX_TOKENS),
container: None,
@ -297,8 +332,11 @@ impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
// Build inference configuration
let max_tokens = req.max_completion_tokens.or(req.max_tokens);
let inference_config = if max_tokens.is_some() || req.temperature.is_some() ||
req.top_p.is_some() || req.stop.is_some() {
let inference_config = if max_tokens.is_some()
|| req.temperature.is_some()
|| req.top_p.is_some()
|| req.stop.is_some()
{
Some(InferenceConfiguration {
max_tokens,
temperature: req.temperature,
@ -312,7 +350,8 @@ impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
// Convert tools and tool choice to ToolConfiguration
let tool_config = if req.tools.is_some() || req.tool_choice.is_some() {
let tools = req.tools.map(|openai_tools| {
openai_tools.into_iter()
openai_tools
.into_iter()
.map(|tool| BedrockTool::ToolSpec {
tool_spec: ToolSpecDefinition {
name: tool.function.name,
@ -325,34 +364,40 @@ impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
.collect()
});
let tool_choice = req.tool_choice.map(|choice| {
match choice {
ToolChoice::Type(tool_type) => match tool_type {
ToolChoiceType::Auto => BedrockToolChoice::Auto { auto: AutoChoice {} },
ToolChoiceType::Required => BedrockToolChoice::Any { any: AnyChoice {} },
ToolChoiceType::None => BedrockToolChoice::Auto { auto: AutoChoice {} }, // Bedrock doesn't have explicit "none"
},
ToolChoice::Function { function, .. } => {
BedrockToolChoice::Tool {
tool: ToolChoiceSpec {
name: function.name
let tool_choice = req
.tool_choice
.map(|choice| {
match choice {
ToolChoice::Type(tool_type) => match tool_type {
ToolChoiceType::Auto => BedrockToolChoice::Auto {
auto: AutoChoice {},
},
ToolChoiceType::Required => {
BedrockToolChoice::Any { any: AnyChoice {} }
}
}
ToolChoiceType::None => BedrockToolChoice::Auto {
auto: AutoChoice {},
}, // Bedrock doesn't have explicit "none"
},
ToolChoice::Function { function, .. } => BedrockToolChoice::Tool {
tool: ToolChoiceSpec {
name: function.name,
},
},
}
}
}).or_else(|| {
// If tools are present but no tool_choice specified, default to "auto"
if tools.is_some() {
Some(BedrockToolChoice::Auto { auto: AutoChoice {} })
} else {
None
}
});
})
.or_else(|| {
// If tools are present but no tool_choice specified, default to "auto"
if tools.is_some() {
Some(BedrockToolChoice::Auto {
auto: AutoChoice {},
})
} else {
None
}
});
Some(ToolConfiguration {
tools,
tool_choice,
})
Some(ToolConfiguration { tools, tool_choice })
} else {
None
};
@ -377,7 +422,8 @@ impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
/// Convert OpenAI tools to Anthropic format
fn convert_openai_tools(tools: Vec<Tool>) -> Vec<MessagesTool> {
tools.into_iter()
tools
.into_iter()
.map(|tool| MessagesTool {
name: tool.function.name,
description: tool.function.description,
@ -386,37 +432,34 @@ fn convert_openai_tools(tools: Vec<Tool>) -> Vec<MessagesTool> {
.collect()
}
/// Convert OpenAI tool choice to Anthropic format
fn convert_openai_tool_choice(
tool_choice: Option<ToolChoice>,
parallel_tool_calls: Option<bool>
parallel_tool_calls: Option<bool>,
) -> Option<MessagesToolChoice> {
tool_choice.map(|choice| {
match choice {
ToolChoice::Type(tool_type) => match tool_type {
ToolChoiceType::Auto => MessagesToolChoice {
kind: MessagesToolChoiceType::Auto,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
ToolChoiceType::Required => MessagesToolChoice {
kind: MessagesToolChoiceType::Any,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
ToolChoiceType::None => MessagesToolChoice {
kind: MessagesToolChoiceType::None,
name: None,
disable_parallel_tool_use: None,
},
},
ToolChoice::Function { function, .. } => MessagesToolChoice {
kind: MessagesToolChoiceType::Tool,
name: Some(function.name),
tool_choice.map(|choice| match choice {
ToolChoice::Type(tool_type) => match tool_type {
ToolChoiceType::Auto => MessagesToolChoice {
kind: MessagesToolChoiceType::Auto,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
}
ToolChoiceType::Required => MessagesToolChoice {
kind: MessagesToolChoiceType::Any,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
ToolChoiceType::None => MessagesToolChoice {
kind: MessagesToolChoiceType::None,
name: None,
disable_parallel_tool_use: None,
},
},
ToolChoice::Function { function, .. } => MessagesToolChoice {
kind: MessagesToolChoiceType::Tool,
name: Some(function.name),
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
})
}
@ -434,8 +477,6 @@ fn build_anthropic_content(content_blocks: Vec<MessagesContentBlock>) -> Message
}
}
/// Parse a data URL into media type and base64 data
/// Supports format: data:image/jpeg;base64,<data>
fn parse_data_url(url: &str) -> Option<(String, String)> {
@ -473,8 +514,14 @@ fn parse_data_url(url: &str) -> Option<(String, String)> {
#[cfg(test)]
mod tests {
use super::*;
use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType, Function, FunctionChoice};
use crate::apis::amazon_bedrock::{ConverseRequest, SystemContentBlock, ConversationRole, ContentBlock, ToolChoice as BedrockToolChoice};
use crate::apis::amazon_bedrock::{
ContentBlock, ConversationRole, ConverseRequest, SystemContentBlock,
ToolChoice as BedrockToolChoice,
};
use crate::apis::openai::{
ChatCompletionsRequest, Function, FunctionChoice, Message, MessageContent, Role, Tool,
ToolChoice, ToolChoiceType,
};
use serde_json::json;
#[test]
@ -495,7 +542,7 @@ mod tests {
name: None,
tool_call_id: None,
tool_calls: None,
}
},
],
temperature: Some(0.7),
top_p: Some(0.9),
@ -534,50 +581,51 @@ mod tests {
assert_eq!(inference_config.temperature, Some(0.7));
assert_eq!(inference_config.top_p, Some(0.9));
assert_eq!(inference_config.max_tokens, Some(1000));
assert_eq!(inference_config.stop_sequences, Some(vec!["STOP".to_string()]));
assert_eq!(
inference_config.stop_sequences,
Some(vec!["STOP".to_string()])
);
}
#[test]
fn test_openai_to_bedrock_with_tools() {
let openai_request = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![
Message {
role: Role::User,
content: MessageContent::Text("What's the weather like?".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
}
],
messages: vec![Message {
role: Role::User,
content: MessageContent::Text("What's the weather like?".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
}],
temperature: None,
top_p: None,
max_completion_tokens: Some(1000),
stop: None,
stream: None,
tools: Some(vec![
Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_weather".to_string(),
description: Some("Get current weather information".to_string()),
parameters: json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name"
}
},
"required": ["location"]
}),
strict: None,
},
}
]),
tools: Some(vec![Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_weather".to_string(),
description: Some("Get current weather information".to_string()),
parameters: json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name"
}
},
"required": ["location"]
}),
strict: None,
},
}]),
tool_choice: Some(ToolChoice::Function {
choice_type: "function".to_string(),
function: FunctionChoice { name: "get_weather".to_string() },
function: FunctionChoice {
name: "get_weather".to_string(),
},
}),
..Default::default()
};
@ -594,7 +642,10 @@ mod tests {
let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0];
assert_eq!(tool_spec.name, "get_weather");
assert_eq!(tool_spec.description, Some("Get current weather information".to_string()));
assert_eq!(
tool_spec.description,
Some("Get current weather information".to_string())
);
if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice {
assert_eq!(tool.name, "get_weather");
@ -607,34 +658,30 @@ mod tests {
fn test_openai_to_bedrock_auto_tool_choice() {
let openai_request = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![
Message {
role: Role::User,
content: MessageContent::Text("Help me with something".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
}
],
messages: vec![Message {
role: Role::User,
content: MessageContent::Text("Help me with something".to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
}],
temperature: None,
top_p: None,
max_completion_tokens: Some(500),
stop: None,
stream: None,
tools: Some(vec![
Tool {
tool_type: "function".to_string(),
function: Function {
name: "help_tool".to_string(),
description: Some("A helpful tool".to_string()),
parameters: json!({
"type": "object",
"properties": {}
}),
strict: None,
},
}
]),
tools: Some(vec![Tool {
tool_type: "function".to_string(),
function: Function {
name: "help_tool".to_string(),
description: Some("A helpful tool".to_string()),
parameters: json!({
"type": "object",
"properties": {}
}),
strict: None,
},
}]),
tool_choice: Some(ToolChoice::Type(ToolChoiceType::Auto)),
..Default::default()
};
@ -643,7 +690,10 @@ mod tests {
assert!(bedrock_request.tool_config.is_some());
let tool_config = bedrock_request.tool_config.as_ref().unwrap();
assert!(matches!(tool_config.tool_choice, Some(BedrockToolChoice::Auto { .. })));
assert!(matches!(
tool_config.tool_choice,
Some(BedrockToolChoice::Auto { .. })
));
}
#[test]
@ -678,7 +728,7 @@ mod tests {
name: None,
tool_call_id: None,
tool_calls: None,
}
},
],
temperature: Some(0.5),
top_p: None,

View file

@ -1,16 +1,16 @@
use serde_json::Value;
use crate::transforms::lib::*;
use crate::clients::TransformError;
use crate::apis::openai::{
ChatCompletionsResponse, ChatCompletionsStreamResponse, Role, ToolCallDelta
use crate::apis::amazon_bedrock::{
ContentBlockDelta, ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason,
};
use crate::apis::anthropic::{
MessagesStreamEvent, MessagesStopReason, MessagesMessageDelta, MessagesResponse,
MessagesStreamMessage, MessagesUsage, MessagesContentDelta, MessagesRole, MessagesContentBlock
MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, MessagesResponse,
MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
};
use crate::apis::amazon_bedrock::{
ConverseResponse, ConverseOutput, StopReason, ConverseStreamEvent, ContentBlockDelta
use crate::apis::openai::{
ChatCompletionsResponse, ChatCompletionsStreamResponse, Role, ToolCallDelta,
};
use crate::clients::TransformError;
use crate::transforms::lib::*;
use serde_json::Value;
// ============================================================================
// STANDARD RUST TRAIT IMPLEMENTATIONS - Using Into/TryFrom for convenience
@ -20,11 +20,15 @@ impl TryFrom<ChatCompletionsResponse> for MessagesResponse {
type Error = TransformError;
fn try_from(resp: ChatCompletionsResponse) -> Result<Self, Self::Error> {
let choice = resp.choices.into_iter().next()
let choice = resp
.choices
.into_iter()
.next()
.ok_or_else(|| TransformError::MissingField("choices".to_string()))?;
let content = convert_openai_message_to_anthropic_content(&choice.message.to_message())?;
let stop_reason = choice.finish_reason
let stop_reason = choice
.finish_reason
.map(|fr| fr.into())
.unwrap_or(MessagesStopReason::EndTurn);
@ -86,13 +90,17 @@ impl TryFrom<ConverseResponse> for MessagesResponse {
};
// Generate a response ID (Bedrock doesn't provide one)
let id = format!("bedrock-{}", std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos());
let id = format!(
"bedrock-{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos()
);
// Extract model ID from trace information if available, otherwise use fallback
let model = resp.trace
let model = resp
.trace
.as_ref()
.and_then(|trace| trace.prompt_router.as_ref())
.map(|router| router.invoked_model_id.clone())
@ -231,15 +239,20 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
ConverseStreamEvent::MessageStart(start_event) => {
let role = match start_event.role {
crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User,
crate::apis::amazon_bedrock::ConversationRole::Assistant => MessagesRole::Assistant,
crate::apis::amazon_bedrock::ConversationRole::Assistant => {
MessagesRole::Assistant
}
};
Ok(MessagesStreamEvent::MessageStart {
message: MessagesStreamMessage {
id: format!("bedrock-stream-{}", std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos()),
id: format!(
"bedrock-stream-{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos()
),
obj_type: "message".to_string(),
role,
content: vec![],
@ -278,11 +291,11 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
// ContentBlockDelta - convert to Anthropic ContentBlockDelta
ConverseStreamEvent::ContentBlockDelta(delta_event) => {
let delta = match delta_event.delta {
ContentBlockDelta::Text { text } => {
MessagesContentDelta::TextDelta { text }
}
ContentBlockDelta::Text { text } => MessagesContentDelta::TextDelta { text },
ContentBlockDelta::ToolUse { tool_use } => {
MessagesContentDelta::InputJsonDelta { partial_json: tool_use.input }
MessagesContentDelta::InputJsonDelta {
partial_json: tool_use.input,
}
}
};
@ -342,11 +355,11 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
}
// Exception events - convert to Ping (could be enhanced to return error events)
ConverseStreamEvent::InternalServerException(_) |
ConverseStreamEvent::ModelStreamErrorException(_) |
ConverseStreamEvent::ServiceUnavailableException(_) |
ConverseStreamEvent::ThrottlingException(_) |
ConverseStreamEvent::ValidationException(_) => {
ConverseStreamEvent::InternalServerException(_)
| ConverseStreamEvent::ModelStreamErrorException(_)
| ConverseStreamEvent::ServiceUnavailableException(_)
| ConverseStreamEvent::ThrottlingException(_)
| ConverseStreamEvent::ValidationException(_) => {
// TODO: Consider adding proper error handling/events
Ok(MessagesStreamEvent::Ping)
}
@ -355,7 +368,9 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
}
/// Convert tool call deltas to Anthropic stream events
fn convert_tool_call_deltas(tool_calls: Vec<ToolCallDelta>) -> Result<MessagesStreamEvent, TransformError> {
fn convert_tool_call_deltas(
tool_calls: Vec<ToolCallDelta>,
) -> Result<MessagesStreamEvent, TransformError> {
for tool_call in tool_calls {
if let Some(id) = &tool_call.id {
// Tool call start
@ -403,7 +418,9 @@ fn convert_tool_call_deltas(tool_calls: Vec<ToolCallDelta>) -> Result<MessagesSt
///
/// Note on S3/URL handling: Converting S3 locations or URLs would require async operations
/// to download and convert to base64, which is not implemented in this synchronous function.
fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_bedrock::Message) -> Result<Vec<MessagesContentBlock>, TransformError> {
fn convert_bedrock_message_to_anthropic_content(
message: &crate::apis::amazon_bedrock::Message,
) -> Result<Vec<MessagesContentBlock>, TransformError> {
use crate::apis::amazon_bedrock::ContentBlock;
let mut content_blocks = Vec::new();
@ -438,16 +455,19 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
crate::apis::amazon_bedrock::ToolResultContentBlock::Image { source } => {
// Convert Bedrock ImageSource to Anthropic format
match source {
crate::apis::amazon_bedrock::ImageSource::Base64 { media_type, data } => {
crate::apis::amazon_bedrock::ImageSource::Base64 {
media_type,
data,
} => {
tool_result_blocks.push(MessagesContentBlock::Image {
source: crate::apis::anthropic::MessagesImageSource::Base64 {
media_type: media_type.clone(),
data: data.clone(),
},
source:
crate::apis::anthropic::MessagesImageSource::Base64 {
media_type: media_type.clone(),
data: data.clone(),
},
});
}
// Note: S3Location is not yet implemented in the current Bedrock API definition
// but would need async handling when added
} // Note: S3Location is not yet implemented in the current Bedrock API definition
// but would need async handling when added
}
}
crate::apis::amazon_bedrock::ToolResultContentBlock::Json { json } => {
@ -463,7 +483,10 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
use crate::apis::anthropic::ToolResultContent;
content_blocks.push(MessagesContentBlock::ToolResult {
tool_use_id: tool_result.tool_use_id.clone(),
is_error: tool_result.status.as_ref().map(|s| matches!(s, crate::apis::amazon_bedrock::ToolResultStatus::Error)),
is_error: tool_result
.status
.as_ref()
.map(|s| matches!(s, crate::apis::amazon_bedrock::ToolResultStatus::Error)),
content: ToolResultContent::Blocks(tool_result_blocks),
cache_control: None,
});
@ -478,8 +501,7 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
data: data.clone(),
},
});
}
// Note: S3Location would require async handling if implemented
} // Note: S3Location would require async handling if implemented
}
}
ContentBlock::Document { document } => {
@ -493,8 +515,7 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
data: data.clone(),
},
});
}
// Note: S3Location would require async handling if implemented
} // Note: S3Location would require async handling if implemented
}
}
ContentBlock::GuardContent { guard_content } => {
@ -516,11 +537,13 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
mod tests {
use super::*;
use crate::apis::amazon_bedrock::{
ConverseResponse, ConverseOutput, Message as BedrockMessage, ConversationRole,
ContentBlock, StopReason, BedrockTokenUsage, ToolResultContentBlock, ToolResultStatus,
ConverseTrace, PromptRouterTrace
BedrockTokenUsage, ContentBlock, ConversationRole, ConverseOutput, ConverseResponse,
ConverseTrace, Message as BedrockMessage, PromptRouterTrace, StopReason,
ToolResultContentBlock, ToolResultStatus,
};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesResponse, MessagesRole, MessagesStopReason, ToolResultContent,
};
use crate::apis::anthropic::{MessagesResponse, MessagesContentBlock, MessagesStopReason, MessagesRole, ToolResultContent};
use serde_json::json;
#[test]
@ -529,11 +552,9 @@ mod tests {
output: ConverseOutput::Message {
message: BedrockMessage {
role: ConversationRole::Assistant,
content: vec![
ContentBlock::Text {
text: "Hello! How can I help you today?".to_string(),
}
],
content: vec![ContentBlock::Text {
text: "Hello! How can I help you today?".to_string(),
}],
},
},
stop_reason: StopReason::EndTurn,
@ -592,7 +613,7 @@ mod tests {
"location": "San Francisco"
}),
},
}
},
],
},
},
@ -624,7 +645,10 @@ mod tests {
}
// Check tool use content
if let MessagesContentBlock::ToolUse { id, name, input, .. } = &anthropic_response.content[1] {
if let MessagesContentBlock::ToolUse {
id, name, input, ..
} = &anthropic_response.content[1]
{
assert_eq!(id, "tool_12345");
assert_eq!(name, "get_weather");
assert_eq!(input["location"], "San Francisco");
@ -649,11 +673,9 @@ mod tests {
output: ConverseOutput::Message {
message: BedrockMessage {
role: ConversationRole::Assistant,
content: vec![
ContentBlock::Text {
text: "Test response".to_string(),
}
],
content: vec![ContentBlock::Text {
text: "Test response".to_string(),
}],
},
},
stop_reason: bedrock_stop_reason,
@ -670,7 +692,10 @@ mod tests {
};
let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap();
assert_eq!(anthropic_response.stop_reason, expected_anthropic_stop_reason);
assert_eq!(
anthropic_response.stop_reason,
expected_anthropic_stop_reason
);
}
}
@ -680,11 +705,9 @@ mod tests {
output: ConverseOutput::Message {
message: BedrockMessage {
role: ConversationRole::Assistant,
content: vec![
ContentBlock::Text {
text: "Cached response".to_string(),
}
],
content: vec![ContentBlock::Text {
text: "Cached response".to_string(),
}],
},
},
stop_reason: StopReason::EndTurn,
@ -706,7 +729,10 @@ mod tests {
assert_eq!(anthropic_response.usage.input_tokens, 100);
assert_eq!(anthropic_response.usage.output_tokens, 50);
assert_eq!(anthropic_response.usage.cache_creation_input_tokens, Some(20));
assert_eq!(
anthropic_response.usage.cache_creation_input_tokens,
Some(20)
);
assert_eq!(anthropic_response.usage.cache_read_input_tokens, Some(10));
}
@ -723,14 +749,12 @@ mod tests {
ContentBlock::ToolResult {
tool_result: crate::apis::amazon_bedrock::ToolResultBlock {
tool_use_id: "tool_67890".to_string(),
content: vec![
ToolResultContentBlock::Text {
text: "Temperature: 72°F, Sunny".to_string(),
}
],
content: vec![ToolResultContentBlock::Text {
text: "Temperature: 72°F, Sunny".to_string(),
}],
status: Some(ToolResultStatus::Success),
},
}
},
],
},
},
@ -761,7 +785,12 @@ mod tests {
}
// Check tool result content
if let MessagesContentBlock::ToolResult { tool_use_id, content, .. } = &anthropic_response.content[1] {
if let MessagesContentBlock::ToolResult {
tool_use_id,
content,
..
} = &anthropic_response.content[1]
{
assert_eq!(tool_use_id, "tool_67890");
if let ToolResultContent::Blocks(blocks) = content {
assert_eq!(blocks.len(), 1);
@ -804,7 +833,7 @@ mod tests {
name: "lookup".to_string(),
input: json!({"id": "12345"}),
},
}
},
],
},
},
@ -870,11 +899,12 @@ mod tests {
name: "test_function".to_string(),
input: json!({"param": "value"}),
},
}
},
],
};
let content_blocks = convert_bedrock_message_to_anthropic_content(&bedrock_message).unwrap();
let content_blocks =
convert_bedrock_message_to_anthropic_content(&bedrock_message).unwrap();
assert_eq!(content_blocks.len(), 2);
@ -884,7 +914,10 @@ mod tests {
panic!("Expected text content block");
}
if let MessagesContentBlock::ToolUse { id, name, input, .. } = &content_blocks[1] {
if let MessagesContentBlock::ToolUse {
id, name, input, ..
} = &content_blocks[1]
{
assert_eq!(id, "test_tool");
assert_eq!(name, "test_function");
assert_eq!(input["param"], "value");
@ -900,11 +933,9 @@ mod tests {
output: ConverseOutput::Message {
message: BedrockMessage {
role: ConversationRole::Assistant,
content: vec![
ContentBlock::Text {
text: "I am an assistant".to_string(),
}
],
content: vec![ContentBlock::Text {
text: "I am an assistant".to_string(),
}],
},
},
stop_reason: StopReason::EndTurn,
@ -928,11 +959,9 @@ mod tests {
output: ConverseOutput::Message {
message: BedrockMessage {
role: ConversationRole::User,
content: vec![
ContentBlock::Text {
text: "I am a user".to_string(),
}
],
content: vec![ContentBlock::Text {
text: "I am a user".to_string(),
}],
},
},
stop_reason: StopReason::EndTurn,
@ -985,7 +1014,10 @@ mod tests {
let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap();
// Should extract model ID from trace
assert_eq!(anthropic_response.model, "anthropic.claude-3-sonnet-20240229-v1:0");
assert_eq!(
anthropic_response.model,
"anthropic.claude-3-sonnet-20240229-v1:0"
);
// Test fallback when no trace information is available
let bedrock_response_no_trace = ConverseResponse {
@ -1010,7 +1042,8 @@ mod tests {
performance_config: None,
};
let anthropic_response_fallback: MessagesResponse = bedrock_response_no_trace.try_into().unwrap();
let anthropic_response_fallback: MessagesResponse =
bedrock_response_no_trace.try_into().unwrap();
// Should use fallback model name
assert_eq!(anthropic_response_fallback.model, "bedrock-model");

View file

@ -1,10 +1,18 @@
use crate::apis::openai::{ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason, ResponseMessage, Role, ToolCallDelta, FunctionCallDelta, Usage, StreamChoice, MessageDelta, MessageContent};
use crate::apis::anthropic::{MessagesResponse, MessagesStreamEvent, MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesUsage};
use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason};
use crate::apis::amazon_bedrock::{
ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason,
};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesContentDelta, MessagesResponse, MessagesStopReason,
MessagesStreamEvent, MessagesUsage,
};
use crate::apis::openai::{
ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason,
FunctionCallDelta, MessageContent, MessageDelta, ResponseMessage, Role, StreamChoice,
ToolCallDelta, Usage,
};
use crate::clients::TransformError;
use crate::transforms::lib::*;
// ============================================================================
// MAIN RESPONSE TRANSFORMATIONS
// ============================================================================
@ -35,7 +43,11 @@ impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
MessageContent::Text(text) => Some(text),
MessageContent::Parts(parts) => {
let text = parts.extract_text();
if text.is_empty() { None } else { Some(text) }
if text.is_empty() {
None
} else {
Some(text)
}
}
};
@ -105,7 +117,6 @@ impl TryFrom<ConverseResponse> for ChatCompletionsResponse {
StopReason::ContentFiltered => FinishReason::ContentFilter,
};
// Create response message
let response_message = ResponseMessage {
role,
@ -135,13 +146,17 @@ impl TryFrom<ConverseResponse> for ChatCompletionsResponse {
};
// Generate a response ID (using timestamp since Bedrock doesn't provide one)
let id = format!("bedrock-{}", std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos());
let id = format!(
"bedrock-{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos()
);
// Extract model ID from trace information if available, otherwise use fallback
let model = resp.trace
let model = resp
.trace
.as_ref()
.and_then(|trace| trace.prompt_router.as_ref())
.map(|router| router.invoked_model_id.clone())
@ -160,7 +175,6 @@ impl TryFrom<ConverseResponse> for ChatCompletionsResponse {
}
}
// ============================================================================
// STREAMING TRANSFORMATIONS
// ============================================================================
@ -170,33 +184,27 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
fn try_from(event: MessagesStreamEvent) -> Result<Self, Self::Error> {
match event {
MessagesStreamEvent::MessageStart { message } => {
Ok(create_openai_chunk(
&message.id,
&message.model,
MessageDelta {
role: Some(Role::Assistant),
content: None,
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
))
}
MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk(
&message.id,
&message.model,
MessageDelta {
role: Some(Role::Assistant),
content: None,
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
)),
MessagesStreamEvent::ContentBlockStart { content_block, .. } => {
convert_content_block_start(content_block)
}
MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
convert_content_delta(delta)
}
MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta),
MessagesStreamEvent::ContentBlockStop { .. } => {
Ok(create_empty_openai_chunk())
}
MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
MessagesStreamEvent::MessageDelta { delta, usage } => {
let finish_reason: Option<FinishReason> = Some(delta.stop_reason.into());
@ -217,39 +225,34 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
))
}
MessagesStreamEvent::MessageStop => {
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: None,
},
Some(FinishReason::Stop),
None,
))
}
MessagesStreamEvent::MessageStop => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: None,
},
Some(FinishReason::Stop),
None,
)),
MessagesStreamEvent::Ping => {
Ok(ChatCompletionsStreamResponse {
id: "stream".to_string(),
object: Some("chat.completion.chunk".to_string()),
created: current_timestamp(),
model: "unknown".to_string(),
choices: vec![],
usage: None,
system_fingerprint: None,
service_tier: None,
})
}
MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse {
id: "stream".to_string(),
object: Some("chat.completion.chunk".to_string()),
created: current_timestamp(),
model: "unknown".to_string(),
choices: vec![],
usage: None,
system_fingerprint: None,
service_tier: None,
}),
}
}
}
impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
type Error = TransformError;
@ -280,29 +283,27 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
use crate::apis::amazon_bedrock::ContentBlockStart;
match start_event.start {
ContentBlockStart::ToolUse { tool_use } => {
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: start_event.content_block_index as u32,
id: Some(tool_use.tool_use_id),
call_type: Some("function".to_string()),
function: Some(FunctionCallDelta {
name: Some(tool_use.name),
arguments: Some("".to_string()),
}),
}]),
},
None,
None,
))
}
ContentBlockStart::ToolUse { tool_use } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: start_event.content_block_index as u32,
id: Some(tool_use.tool_use_id),
call_type: Some("function".to_string()),
function: Some(FunctionCallDelta {
name: Some(tool_use.name),
arguments: Some("".to_string()),
}),
}]),
},
None,
None,
)),
}
}
@ -310,50 +311,44 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
use crate::apis::amazon_bedrock::ContentBlockDelta;
match delta_event.delta {
ContentBlockDelta::Text { text } => {
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: Some(text),
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
))
}
ContentBlockDelta::ToolUse { tool_use } => {
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: delta_event.content_block_index as u32,
id: None,
call_type: None,
function: Some(FunctionCallDelta {
name: None,
arguments: Some(tool_use.input),
}),
}]),
},
None,
None,
))
}
ContentBlockDelta::Text { text } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: Some(text),
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
)),
ContentBlockDelta::ToolUse { tool_use } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: delta_event.content_block_index as u32,
id: None,
call_type: None,
function: Some(FunctionCallDelta {
name: None,
arguments: Some(tool_use.input),
}),
}]),
},
None,
None,
)),
}
}
ConverseStreamEvent::ContentBlockStop(_) => {
Ok(create_empty_openai_chunk())
}
ConverseStreamEvent::ContentBlockStop(_) => Ok(create_empty_openai_chunk()),
ConverseStreamEvent::MessageStop(stop_event) => {
let finish_reason = match stop_event.stop_reason {
@ -405,27 +400,27 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
}
// Error events - convert to empty chunks (errors should be handled elsewhere)
ConverseStreamEvent::InternalServerException(_) |
ConverseStreamEvent::ModelStreamErrorException(_) |
ConverseStreamEvent::ServiceUnavailableException(_) |
ConverseStreamEvent::ThrottlingException(_) |
ConverseStreamEvent::ValidationException(_) => {
Ok(create_empty_openai_chunk())
}
ConverseStreamEvent::InternalServerException(_)
| ConverseStreamEvent::ModelStreamErrorException(_)
| ConverseStreamEvent::ServiceUnavailableException(_)
| ConverseStreamEvent::ThrottlingException(_)
| ConverseStreamEvent::ValidationException(_) => Ok(create_empty_openai_chunk()),
}
}
}
/// Convert content block start to OpenAI chunk
fn convert_content_block_start(content_block: MessagesContentBlock) -> Result<ChatCompletionsStreamResponse, TransformError> {
fn convert_content_block_start(
content_block: MessagesContentBlock,
) -> Result<ChatCompletionsStreamResponse, TransformError> {
match content_block {
MessagesContentBlock::Text { .. } => {
// No immediate output for text block start
Ok(create_empty_openai_chunk())
}
MessagesContentBlock::ToolUse { id, name, .. } |
MessagesContentBlock::ServerToolUse { id, name, .. } |
MessagesContentBlock::McpToolUse { id, name, .. } => {
MessagesContentBlock::ToolUse { id, name, .. }
| MessagesContentBlock::ServerToolUse { id, name, .. }
| MessagesContentBlock::McpToolUse { id, name, .. } => {
// Tool use start → OpenAI chunk with tool_calls
Ok(create_openai_chunk(
"stream",
@ -449,66 +444,64 @@ fn convert_content_block_start(content_block: MessagesContentBlock) -> Result<Ch
None,
))
}
_ => Err(TransformError::UnsupportedContent("Unsupported content block type in stream start".to_string())),
_ => Err(TransformError::UnsupportedContent(
"Unsupported content block type in stream start".to_string(),
)),
}
}
/// Convert content delta to OpenAI chunk
fn convert_content_delta(delta: MessagesContentDelta) -> Result<ChatCompletionsStreamResponse, TransformError> {
fn convert_content_delta(
delta: MessagesContentDelta,
) -> Result<ChatCompletionsStreamResponse, TransformError> {
match delta {
MessagesContentDelta::TextDelta { text } => {
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: Some(text),
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
))
}
MessagesContentDelta::ThinkingDelta { thinking } => {
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: Some(format!("thinking: {}", thinking)),
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
))
}
MessagesContentDelta::InputJsonDelta { partial_json } => {
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: 0,
id: None,
call_type: None,
function: Some(FunctionCallDelta {
name: None,
arguments: Some(partial_json),
}),
}]),
},
None,
None,
))
}
MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: Some(text),
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
)),
MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: Some(format!("thinking: {}", thinking)),
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
)),
MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: 0,
id: None,
call_type: None,
function: Some(FunctionCallDelta {
name: None,
arguments: Some(partial_json),
}),
}]),
},
None,
None,
)),
}
}
@ -518,7 +511,7 @@ fn create_openai_chunk(
model: &str,
delta: MessageDelta,
finish_reason: Option<FinishReason>,
usage: Option<Usage>
usage: Option<Usage>,
) -> ChatCompletionsStreamResponse {
ChatCompletionsStreamResponse {
id: id.to_string(),
@ -555,7 +548,9 @@ fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse {
}
/// Convert Anthropic content blocks to OpenAI message content
fn convert_anthropic_content_to_openai(content: &[MessagesContentBlock]) -> Result<MessageContent, TransformError> {
fn convert_anthropic_content_to_openai(
content: &[MessagesContentBlock],
) -> Result<MessageContent, TransformError> {
let mut text_parts = Vec::new();
for block in content {
@ -592,9 +587,11 @@ impl Into<FinishReason> for MessagesStopReason {
/// Convert Bedrock Message to OpenAI content and tool calls
/// This function extracts text content and tool calls from a Bedrock message
fn convert_bedrock_message_to_openai(message: &crate::apis::amazon_bedrock::Message) -> Result<(Option<String>, Option<Vec<crate::apis::openai::ToolCall>>), TransformError> {
fn convert_bedrock_message_to_openai(
message: &crate::apis::amazon_bedrock::Message,
) -> Result<(Option<String>, Option<Vec<crate::apis::openai::ToolCall>>), TransformError> {
use crate::apis::amazon_bedrock::ContentBlock;
use crate::apis::openai::{ToolCall, FunctionCall};
use crate::apis::openai::{FunctionCall, ToolCall};
let mut text_content = String::new();
let mut tool_calls = Vec::new();
@ -614,12 +611,20 @@ fn convert_bedrock_message_to_openai(message: &crate::apis::amazon_bedrock::Mess
},
});
}
_ => continue,
_ => continue,
}
}
let content = if text_content.is_empty() { None } else { Some(text_content) };
let tool_calls = if tool_calls.is_empty() { None } else { Some(tool_calls) };
let content = if text_content.is_empty() {
None
} else {
Some(text_content)
};
let tool_calls = if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
};
Ok((content, tool_calls))
}
@ -628,8 +633,8 @@ fn convert_bedrock_message_to_openai(message: &crate::apis::amazon_bedrock::Mess
mod tests {
use super::*;
use crate::apis::amazon_bedrock::{
ConverseResponse, ConverseOutput, Message as BedrockMessage, ConversationRole,
ContentBlock, StopReason, BedrockTokenUsage, ConverseTrace, PromptRouterTrace
BedrockTokenUsage, ContentBlock, ConversationRole, ConverseOutput, ConverseResponse,
ConverseTrace, Message as BedrockMessage, PromptRouterTrace, StopReason,
};
use crate::apis::openai::{ChatCompletionsResponse, FinishReason, Role};
use serde_json::json;
@ -640,11 +645,9 @@ mod tests {
output: ConverseOutput::Message {
message: BedrockMessage {
role: ConversationRole::Assistant,
content: vec![
ContentBlock::Text {
text: "Hello! How can I help you today?".to_string(),
}
],
content: vec![ContentBlock::Text {
text: "Hello! How can I help you today?".to_string(),
}],
},
},
stop_reason: StopReason::EndTurn,
@ -671,7 +674,10 @@ mod tests {
let choice = &openai_response.choices[0];
assert_eq!(choice.index, 0);
assert_eq!(choice.message.role, Role::Assistant);
assert_eq!(choice.message.content, Some("Hello! How can I help you today?".to_string()));
assert_eq!(
choice.message.content,
Some("Hello! How can I help you today?".to_string())
);
assert_eq!(choice.finish_reason, Some(FinishReason::Stop));
assert!(choice.message.tool_calls.is_none());
@ -699,7 +705,7 @@ mod tests {
"location": "San Francisco"
}),
},
}
},
],
},
},
@ -718,11 +724,21 @@ mod tests {
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
assert_eq!(openai_response.choices[0].finish_reason, Some(FinishReason::ToolCalls));
assert_eq!(openai_response.choices[0].message.content, Some("I'll help you check the weather.".to_string()));
assert_eq!(
openai_response.choices[0].finish_reason,
Some(FinishReason::ToolCalls)
);
assert_eq!(
openai_response.choices[0].message.content,
Some("I'll help you check the weather.".to_string())
);
// Check tool calls
let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap();
let tool_calls = openai_response.choices[0]
.message
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tool_calls.len(), 1);
let tool_call = &tool_calls[0];
@ -750,11 +766,9 @@ mod tests {
output: ConverseOutput::Message {
message: BedrockMessage {
role: ConversationRole::Assistant,
content: vec![
ContentBlock::Text {
text: "Test response".to_string(),
}
],
content: vec![ContentBlock::Text {
text: "Test response".to_string(),
}],
},
},
stop_reason: bedrock_stop_reason,
@ -771,7 +785,10 @@ mod tests {
};
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
assert_eq!(openai_response.choices[0].finish_reason, Some(expected_openai_finish_reason));
assert_eq!(
openai_response.choices[0].finish_reason,
Some(expected_openai_finish_reason)
);
}
}
@ -798,7 +815,7 @@ mod tests {
name: "lookup".to_string(),
input: json!({"id": "12345"}),
},
}
},
],
},
},
@ -817,23 +834,35 @@ mod tests {
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
assert_eq!(openai_response.choices[0].finish_reason, Some(FinishReason::ToolCalls));
assert_eq!(openai_response.choices[0].message.content, Some("I'll help with multiple tasks.".to_string()));
assert_eq!(
openai_response.choices[0].finish_reason,
Some(FinishReason::ToolCalls)
);
assert_eq!(
openai_response.choices[0].message.content,
Some("I'll help with multiple tasks.".to_string())
);
// Check multiple tool calls
let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap();
let tool_calls = openai_response.choices[0]
.message
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tool_calls.len(), 2);
// First tool call
assert_eq!(tool_calls[0].id, "tool_1");
assert_eq!(tool_calls[0].function.name, "search");
let args1: serde_json::Value = serde_json::from_str(&tool_calls[0].function.arguments).unwrap();
let args1: serde_json::Value =
serde_json::from_str(&tool_calls[0].function.arguments).unwrap();
assert_eq!(args1["query"], "weather");
// Second tool call
assert_eq!(tool_calls[1].id, "tool_2");
assert_eq!(tool_calls[1].function.name, "lookup");
let args2: serde_json::Value = serde_json::from_str(&tool_calls[1].function.arguments).unwrap();
let args2: serde_json::Value =
serde_json::from_str(&tool_calls[1].function.arguments).unwrap();
assert_eq!(args2["id"], "12345");
}
@ -856,7 +885,7 @@ mod tests {
},
ContentBlock::Text {
text: "Second part.".to_string(),
}
},
],
},
},
@ -876,10 +905,17 @@ mod tests {
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
// Content should be combined text parts (no separator added)
assert_eq!(openai_response.choices[0].message.content, Some("First part. Second part.".to_string()));
assert_eq!(
openai_response.choices[0].message.content,
Some("First part. Second part.".to_string())
);
// Should have one tool call
let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap();
let tool_calls = openai_response.choices[0]
.message
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id, "tool_mid");
assert_eq!(tool_calls[0].function.name, "calculate");
@ -891,15 +927,13 @@ mod tests {
output: ConverseOutput::Message {
message: BedrockMessage {
role: ConversationRole::Assistant,
content: vec![
ContentBlock::ToolUse {
tool_use: crate::apis::amazon_bedrock::ToolUseBlock {
tool_use_id: "tool_only".to_string(),
name: "action".to_string(),
input: json!({}),
},
}
],
content: vec![ContentBlock::ToolUse {
tool_use: crate::apis::amazon_bedrock::ToolUseBlock {
tool_use_id: "tool_only".to_string(),
name: "action".to_string(),
input: json!({}),
},
}],
},
},
stop_reason: StopReason::ToolUse,
@ -921,7 +955,11 @@ mod tests {
assert_eq!(openai_response.choices[0].message.content, None);
// Should have tool call
let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap();
let tool_calls = openai_response.choices[0]
.message
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id, "tool_only");
}
@ -940,7 +978,7 @@ mod tests {
name: "test_function".to_string(),
input: json!({"param": "value"}),
},
}
},
],
};
@ -953,7 +991,8 @@ mod tests {
assert_eq!(tool_calls[0].id, "test_tool");
assert_eq!(tool_calls[0].function.name, "test_function");
let args: serde_json::Value = serde_json::from_str(&tool_calls[0].function.arguments).unwrap();
let args: serde_json::Value =
serde_json::from_str(&tool_calls[0].function.arguments).unwrap();
assert_eq!(args["param"], "value");
}
@ -964,11 +1003,9 @@ mod tests {
output: ConverseOutput::Message {
message: BedrockMessage {
role: ConversationRole::Assistant,
content: vec![
ContentBlock::Text {
text: "I am an assistant".to_string(),
}
],
content: vec![ContentBlock::Text {
text: "I am an assistant".to_string(),
}],
},
},
stop_reason: StopReason::EndTurn,
@ -992,11 +1029,9 @@ mod tests {
output: ConverseOutput::Message {
message: BedrockMessage {
role: ConversationRole::User,
content: vec![
ContentBlock::Text {
text: "I am a user".to_string(),
}
],
content: vec![ContentBlock::Text {
text: "I am a user".to_string(),
}],
},
},
stop_reason: StopReason::EndTurn,
@ -1049,7 +1084,10 @@ mod tests {
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
// Should extract model ID from trace
assert_eq!(openai_response.model, "anthropic.claude-3-sonnet-20240229-v1:0");
assert_eq!(
openai_response.model,
"anthropic.claude-3-sonnet-20240229-v1:0"
);
// Test fallback when no trace information is available
let bedrock_response_no_trace = ConverseResponse {
@ -1074,7 +1112,8 @@ mod tests {
performance_config: None,
};
let openai_response_fallback: ChatCompletionsResponse = bedrock_response_no_trace.try_into().unwrap();
let openai_response_fallback: ChatCompletionsResponse =
bedrock_response_no_trace.try_into().unwrap();
// Should use fallback model name
assert_eq!(openai_response_fallback.model, "bedrock-model");
@ -1082,7 +1121,7 @@ mod tests {
#[test]
fn test_bedrock_to_openai_with_multimedia_content() {
use crate::apis::amazon_bedrock::{ImageSource};
use crate::apis::amazon_bedrock::ImageSource;
let bedrock_response = ConverseResponse {
output: ConverseOutput::Message {
@ -1118,7 +1157,10 @@ mod tests {
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
assert_eq!(openai_response.choices[0].finish_reason, Some(FinishReason::Stop));
assert_eq!(
openai_response.choices[0].finish_reason,
Some(FinishReason::Stop)
);
let content = openai_response.choices[0].message.content.as_ref().unwrap();