mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
cargo fmt
This commit is contained in:
parent
aec052a843
commit
d35d068d0d
25 changed files with 1978 additions and 1258 deletions
|
|
@ -1,17 +1,17 @@
|
|||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{ModelAlias, ModelUsagePreference};
|
||||
use common::consts::{ARCH_PROVIDER_HINT_HEADER, ARCH_IS_STREAMING_HEADER};
|
||||
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::clients::SupportedAPIs;
|
||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use hermesllm::clients::SupportedAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
|
|
@ -31,14 +31,19 @@ pub async fn chat(
|
|||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
||||
let request_path = request.uri().path().to_string();
|
||||
let mut request_headers = request.headers().clone();
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
|
||||
debug!(
|
||||
"Received request body (raw utf8): {}",
|
||||
String::from_utf8_lossy(&chat_request_bytes)
|
||||
);
|
||||
|
||||
let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
|
||||
let mut client_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
&SupportedAPIs::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse request as ProviderRequestType: {}", err);
|
||||
|
|
@ -79,18 +84,30 @@ pub async fn chat(
|
|||
|
||||
// Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original)
|
||||
let chat_completions_request_for_arch_router: ChatCompletionsRequest =
|
||||
match ProviderRequestType::try_from((client_request, &SupportedUpstreamAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) {
|
||||
match ProviderRequestType::try_from((
|
||||
client_request,
|
||||
&SupportedUpstreamAPIs::OpenAIChatCompletions(
|
||||
hermesllm::apis::OpenAIApi::ChatCompletions,
|
||||
),
|
||||
)) {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
|
||||
Ok(ProviderRequestType::MessagesRequest(_) | ProviderRequestType::BedrockConverse(_) | ProviderRequestType::BedrockConverseStream(_)) => {
|
||||
Ok(
|
||||
ProviderRequestType::MessagesRequest(_)
|
||||
| ProviderRequestType::BedrockConverse(_)
|
||||
| ProviderRequestType::BedrockConverseStream(_),
|
||||
) => {
|
||||
// This should not happen after conversion to OpenAI format
|
||||
warn!("Unexpected: got MessagesRequest after converting to OpenAI format");
|
||||
let err_msg = "Request conversion failed".to_string();
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
},
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to convert request to ChatCompletionsRequest: {}", err);
|
||||
warn!(
|
||||
"Failed to convert request to ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
let err_msg = format!("Failed to convert request: {}", err);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
|
|
@ -108,28 +125,29 @@ pub async fn chat(
|
|||
.find(|(ty, _)| ty.as_str() == "traceparent")
|
||||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
let usage_preferences_str: Option<String> =
|
||||
routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("archgw_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("archgw_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
|
||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
|
||||
.as_ref()
|
||||
.and_then(|s| serde_yaml::from_str(s).ok());
|
||||
|
||||
let latest_message_for_log =
|
||||
chat_completions_request_for_arch_router
|
||||
.messages
|
||||
.last()
|
||||
.map_or("None".to_string(), |msg| {
|
||||
msg.content.to_string().replace('\n', "\\n")
|
||||
});
|
||||
let latest_message_for_log = chat_completions_request_for_arch_router
|
||||
.messages
|
||||
.last()
|
||||
.map_or("None".to_string(), |msg| {
|
||||
msg.content.to_string().replace('\n', "\\n")
|
||||
});
|
||||
|
||||
const MAX_MESSAGE_LENGTH: usize = 50;
|
||||
let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
|
||||
let truncated: String = latest_message_for_log.chars().take(MAX_MESSAGE_LENGTH).collect();
|
||||
let truncated: String = latest_message_for_log
|
||||
.chars()
|
||||
.take(MAX_MESSAGE_LENGTH)
|
||||
.collect();
|
||||
format!("{}...", truncated)
|
||||
} else {
|
||||
latest_message_for_log
|
||||
|
|
@ -155,12 +173,11 @@ pub async fn chat(
|
|||
Ok(route) => match route {
|
||||
Some((_, model_name)) => model_name,
|
||||
None => {
|
||||
info!(
|
||||
info!(
|
||||
"No route determined, using default model from request: {}",
|
||||
chat_completions_request_for_arch_router.model
|
||||
);
|
||||
chat_completions_request_for_arch_router.model.clone()
|
||||
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
|
|
|
|||
|
|
@ -68,8 +68,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
&serde_json::to_string(arch_config.as_ref()).unwrap()
|
||||
);
|
||||
|
||||
let llm_provider_url = env::var("LLM_PROVIDER_ENDPOINT")
|
||||
.unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
let llm_provider_url =
|
||||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
||||
info!("llm provider url: {}", llm_provider_url);
|
||||
info!("listening on http://{}", bind_address);
|
||||
|
|
@ -96,7 +96,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let model_aliases = Arc::new(arch_config.model_aliases.clone());
|
||||
|
||||
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
let peer_addr = stream.peer_addr()?;
|
||||
|
|
@ -108,7 +107,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let service = service_fn(move |req| {
|
||||
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let parent_cx = extract_context_from_request(&req);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
|
|
@ -118,7 +116,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, req.uri().path());
|
||||
chat(req, router_service, fully_qualified_url, model_aliases)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use common::{
|
||||
configuration::{ModelUsagePreference, RoutingPreference},
|
||||
};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, MessageContent, Message, Role};
|
||||
use common::configuration::{ModelUsagePreference, RoutingPreference};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,20 +1,27 @@
|
|||
use std::sync::OnceLock;
|
||||
use std::fmt;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use opentelemetry::global;
|
||||
use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider};
|
||||
use opentelemetry_stdout::SpanExporter;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use tracing_subscriber::fmt::{format, time::FormatTime, FmtContext, FormatEvent, FormatFields};
|
||||
use tracing::{Event, Subscriber};
|
||||
use time::macros::format_description;
|
||||
use tracing::{Event, Subscriber};
|
||||
use tracing_subscriber::fmt::{format, time::FormatTime, FmtContext, FormatEvent, FormatFields};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
struct BracketedTime;
|
||||
|
||||
impl FormatTime for BracketedTime {
|
||||
fn format_time(&self, w: &mut format::Writer<'_>) -> fmt::Result {
|
||||
let now = time::OffsetDateTime::now_utc();
|
||||
write!(w, "[{}]", now.format(&format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]")).unwrap())
|
||||
write!(
|
||||
w,
|
||||
"[{}]",
|
||||
now.format(&format_description!(
|
||||
"[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -34,7 +41,11 @@ where
|
|||
let timer = BracketedTime;
|
||||
timer.format_time(&mut writer)?;
|
||||
|
||||
write!(writer, "[{}] ", event.metadata().level().to_string().to_lowercase())?;
|
||||
write!(
|
||||
writer,
|
||||
"[{}] ",
|
||||
event.metadata().level().to_string().to_lowercase()
|
||||
)?;
|
||||
|
||||
ctx.field_format().format_fields(writer.by_ref(), event)?;
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ pub fn get_llm_provider(
|
|||
return provider;
|
||||
}
|
||||
|
||||
|
||||
if llm_providers.default().is_some() {
|
||||
return llm_providers.default().unwrap();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::Value;
|
||||
use serde_with::skip_serializing_none;
|
||||
|
||||
use thiserror::Error;
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
|
||||
use super::ApiDefinition;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
|
|
@ -56,10 +56,7 @@ impl ApiDefinition for AmazonBedrockApi {
|
|||
}
|
||||
|
||||
fn all_variants() -> Vec<Self> {
|
||||
vec![
|
||||
AmazonBedrockApi::Converse,
|
||||
AmazonBedrockApi::ConverseStream,
|
||||
]
|
||||
vec![AmazonBedrockApi::Converse, AmazonBedrockApi::ConverseStream]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -173,7 +170,9 @@ impl ProviderRequest for ConverseRequest {
|
|||
SystemContentBlock::Text { text } => {
|
||||
text_parts.push(text.clone());
|
||||
}
|
||||
SystemContentBlock::GuardContent { text: Some(guard_text) } => {
|
||||
SystemContentBlock::GuardContent {
|
||||
text: Some(guard_text),
|
||||
} => {
|
||||
text_parts.push(guard_text.text.clone());
|
||||
}
|
||||
SystemContentBlock::GuardContent { text: None } => {
|
||||
|
|
@ -194,11 +193,9 @@ impl ProviderRequest for ConverseRequest {
|
|||
.find(|msg| msg.role == ConversationRole::User)
|
||||
.and_then(|msg| {
|
||||
// Extract the first text content block from the user message
|
||||
msg.content.iter().find_map(|content| {
|
||||
match content {
|
||||
ContentBlock::Text { text } => Some(text.clone()),
|
||||
_ => None,
|
||||
}
|
||||
msg.content.iter().find_map(|content| match content {
|
||||
ContentBlock::Text { text } => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
@ -294,13 +291,13 @@ pub enum ConversationRole {
|
|||
#[serde(untagged)]
|
||||
pub enum ContentBlock {
|
||||
Text {
|
||||
text: String
|
||||
text: String,
|
||||
},
|
||||
Image {
|
||||
image: ImageBlock
|
||||
image: ImageBlock,
|
||||
},
|
||||
Document {
|
||||
document: DocumentBlock
|
||||
document: DocumentBlock,
|
||||
},
|
||||
ToolUse {
|
||||
#[serde(rename = "toolUse")]
|
||||
|
|
@ -360,9 +357,7 @@ pub enum SystemContentBlock {
|
|||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "guardContent")]
|
||||
GuardContent {
|
||||
text: Option<GuardContentText>,
|
||||
},
|
||||
GuardContent { text: Option<GuardContentText> },
|
||||
}
|
||||
|
||||
/// Image source for vision capabilities
|
||||
|
|
@ -723,10 +718,12 @@ pub struct ContentBlockDeltaEvent {
|
|||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum ContentBlockDelta {
|
||||
Text { text: String },
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
ToolUse {
|
||||
#[serde(rename = "toolUse")]
|
||||
tool_use: ToolUseDelta
|
||||
tool_use: ToolUseDelta,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -1008,7 +1005,9 @@ impl Into<String> for ConverseStreamEvent {
|
|||
ConverseStreamEvent::Metadata { .. } => "metadata",
|
||||
ConverseStreamEvent::InternalServerException { .. } => "internal_server_exception",
|
||||
ConverseStreamEvent::ModelStreamErrorException { .. } => "model_stream_error_exception",
|
||||
ConverseStreamEvent::ServiceUnavailableException { .. } => "service_unavailable_exception",
|
||||
ConverseStreamEvent::ServiceUnavailableException { .. } => {
|
||||
"service_unavailable_exception"
|
||||
}
|
||||
ConverseStreamEvent::ThrottlingException { .. } => "throttling_exception",
|
||||
ConverseStreamEvent::ValidationException { .. } => "validation_exception",
|
||||
};
|
||||
|
|
@ -1019,17 +1018,14 @@ impl Into<String> for ConverseStreamEvent {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// Implement ProviderStreamResponse for ConverseStreamEvent
|
||||
impl ProviderStreamResponse for ConverseStreamEvent {
|
||||
fn content_delta(&self) -> Option<&str> {
|
||||
match self {
|
||||
ConverseStreamEvent::ContentBlockDelta(event) => {
|
||||
match &event.delta {
|
||||
ContentBlockDelta::Text { text } => Some(text),
|
||||
ContentBlockDelta::ToolUse { .. } => None,
|
||||
}
|
||||
}
|
||||
ConverseStreamEvent::ContentBlockDelta(event) => match &event.delta {
|
||||
ContentBlockDelta::Text { text } => Some(text),
|
||||
ContentBlockDelta::ToolUse { .. } => None,
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
@ -1099,7 +1095,10 @@ mod tests {
|
|||
};
|
||||
|
||||
let serialized = serde_json::to_value(&tool).unwrap();
|
||||
println!("Tool serialization: {}", serde_json::to_string_pretty(&serialized).unwrap());
|
||||
println!(
|
||||
"Tool serialization: {}",
|
||||
serde_json::to_string_pretty(&serialized).unwrap()
|
||||
);
|
||||
|
||||
// Verify the structure matches Bedrock API expectations
|
||||
assert!(serialized.get("toolSpec").is_some());
|
||||
|
|
@ -1107,16 +1106,24 @@ mod tests {
|
|||
|
||||
let tool_spec = serialized.get("toolSpec").unwrap();
|
||||
assert_eq!(tool_spec.get("name").unwrap(), "get_weather");
|
||||
assert_eq!(tool_spec.get("description").unwrap(), "Get the current weather for a specified city");
|
||||
assert_eq!(
|
||||
tool_spec.get("description").unwrap(),
|
||||
"Get the current weather for a specified city"
|
||||
);
|
||||
assert!(tool_spec.get("inputSchema").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_choice_serialization_format() {
|
||||
// Test Auto choice
|
||||
let auto_choice = ToolChoice::Auto { auto: AutoChoice {} };
|
||||
let auto_choice = ToolChoice::Auto {
|
||||
auto: AutoChoice {},
|
||||
};
|
||||
let serialized = serde_json::to_value(&auto_choice).unwrap();
|
||||
println!("Auto ToolChoice serialization: {}", serde_json::to_string_pretty(&serialized).unwrap());
|
||||
println!(
|
||||
"Auto ToolChoice serialization: {}",
|
||||
serde_json::to_string_pretty(&serialized).unwrap()
|
||||
);
|
||||
|
||||
assert!(serialized.get("auto").is_some());
|
||||
assert!(serialized.get("type").is_none()); // Should not have a type field
|
||||
|
|
@ -1124,11 +1131,14 @@ mod tests {
|
|||
// Test Tool choice
|
||||
let tool_choice = ToolChoice::Tool {
|
||||
tool: ToolChoiceSpec {
|
||||
name: "get_weather".to_string()
|
||||
}
|
||||
name: "get_weather".to_string(),
|
||||
},
|
||||
};
|
||||
let serialized = serde_json::to_value(&tool_choice).unwrap();
|
||||
println!("Tool ToolChoice serialization: {}", serde_json::to_string_pretty(&serialized).unwrap());
|
||||
println!(
|
||||
"Tool ToolChoice serialization: {}",
|
||||
serde_json::to_string_pretty(&serialized).unwrap()
|
||||
);
|
||||
|
||||
assert!(serialized.get("tool").is_some());
|
||||
assert!(serialized.get("type").is_none()); // Should not have a type field
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use std::collections::HashSet;
|
||||
use bytes::Buf;
|
||||
use aws_smithy_eventstream::frame::DecodedFrame;
|
||||
use aws_smithy_eventstream::frame::MessageFrameDecoder;
|
||||
use bytes::Buf;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// AWS Event Stream frame decoder wrapper
|
||||
pub struct BedrockBinaryFrameDecoder<B>
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ use super::ApiDefinition;
|
|||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::{MESSAGES_PATH};
|
||||
use crate::MESSAGES_PATH;
|
||||
|
||||
// Enum for all supported Anthropic APIs
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
|
|
@ -52,9 +52,7 @@ impl ApiDefinition for AnthropicApi {
|
|||
}
|
||||
|
||||
fn all_variants() -> Vec<Self> {
|
||||
vec![
|
||||
AnthropicApi::Messages,
|
||||
]
|
||||
vec![AnthropicApi::Messages]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -100,7 +98,6 @@ pub struct McpServer {
|
|||
pub tool_configuration: Option<McpToolConfiguration>,
|
||||
}
|
||||
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct MessagesRequest {
|
||||
|
|
@ -121,10 +118,8 @@ pub struct MessagesRequest {
|
|||
pub stop_sequences: Option<Vec<String>>,
|
||||
pub tools: Option<Vec<MessagesTool>>,
|
||||
pub tool_choice: Option<MessagesToolChoice>,
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Messages API specific types
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
|
|
@ -235,34 +230,21 @@ impl ExtractText for Vec<MessagesContentBlock> {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[serde(tag = "type")]
|
||||
pub enum MessagesImageSource {
|
||||
Base64 {
|
||||
media_type: String,
|
||||
data: String,
|
||||
},
|
||||
Url {
|
||||
url: String,
|
||||
},
|
||||
Base64 { media_type: String, data: String },
|
||||
Url { url: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[serde(tag = "type")]
|
||||
pub enum MessagesDocumentSource {
|
||||
Base64 {
|
||||
media_type: String,
|
||||
data: String,
|
||||
},
|
||||
Url {
|
||||
url: String,
|
||||
},
|
||||
File {
|
||||
file_id: String,
|
||||
},
|
||||
Base64 { media_type: String, data: String },
|
||||
Url { url: String },
|
||||
File { file_id: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
|
|
@ -276,7 +258,7 @@ impl ExtractText for MessagesMessageContent {
|
|||
fn extract_text(&self) -> String {
|
||||
match self {
|
||||
MessagesMessageContent::Single(text) => text.clone(),
|
||||
MessagesMessageContent::Blocks(parts) => parts.extract_text()
|
||||
MessagesMessageContent::Blocks(parts) => parts.extract_text(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -320,7 +302,6 @@ pub struct MessagesToolChoice {
|
|||
pub disable_parallel_tool_use: Option<bool>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MessagesStopReason {
|
||||
|
|
@ -457,7 +438,11 @@ impl ProviderResponse for MessagesResponse {
|
|||
Some(self)
|
||||
}
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
Some((self.usage.input_tokens as usize, self.usage.output_tokens as usize, (self.usage.input_tokens + self.usage.output_tokens) as usize))
|
||||
Some((
|
||||
self.usage.input_tokens as usize,
|
||||
self.usage.output_tokens as usize,
|
||||
(self.usage.input_tokens + self.usage.output_tokens) as usize,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -535,7 +520,7 @@ impl ProviderRequest for MessagesRequest {
|
|||
}
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>> {
|
||||
return &self.metadata;
|
||||
return &self.metadata;
|
||||
}
|
||||
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool {
|
||||
|
|
@ -572,13 +557,11 @@ impl MessagesRole {
|
|||
impl ProviderStreamResponse for MessagesStreamEvent {
|
||||
fn content_delta(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
|
||||
match delta {
|
||||
MessagesContentDelta::TextDelta { text } => Some(text),
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => match delta {
|
||||
MessagesContentDelta::TextDelta { text } => Some(text),
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
@ -627,7 +610,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate required fields are properly set
|
||||
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
|
||||
|
|
@ -687,7 +671,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate required fields
|
||||
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
|
||||
|
|
@ -730,7 +715,10 @@ mod tests {
|
|||
assert_eq!(serialized_json["messages"], original_json["messages"]);
|
||||
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
|
||||
assert_eq!(serialized_json["system"], original_json["system"]);
|
||||
assert_eq!(serialized_json["service_tier"], original_json["service_tier"]);
|
||||
assert_eq!(
|
||||
serialized_json["service_tier"],
|
||||
original_json["service_tier"]
|
||||
);
|
||||
assert_eq!(serialized_json["thinking"], original_json["thinking"]);
|
||||
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
|
||||
|
||||
|
|
@ -818,7 +806,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate top-level fields
|
||||
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
|
||||
|
|
@ -833,7 +822,10 @@ mod tests {
|
|||
|
||||
// Validate text content block
|
||||
if let MessagesContentBlock::Text { text, .. } = &content_blocks[0] {
|
||||
assert_eq!(text, "What can you see in this image and what's the weather like?");
|
||||
assert_eq!(
|
||||
text,
|
||||
"What can you see in this image and what's the weather like?"
|
||||
);
|
||||
} else {
|
||||
panic!("Expected text content block");
|
||||
}
|
||||
|
|
@ -861,20 +853,32 @@ mod tests {
|
|||
|
||||
// Validate thinking content block
|
||||
if let MessagesContentBlock::Thinking { thinking, .. } = &content_blocks[0] {
|
||||
assert_eq!(thinking, "Let me analyze the image and then check the weather...");
|
||||
assert_eq!(
|
||||
thinking,
|
||||
"Let me analyze the image and then check the weather..."
|
||||
);
|
||||
} else {
|
||||
panic!("Expected thinking content block");
|
||||
}
|
||||
|
||||
// Validate text content block
|
||||
if let MessagesContentBlock::Text { text, .. } = &content_blocks[1] {
|
||||
assert_eq!(text, "I can see the image. Let me check the weather for you.");
|
||||
assert_eq!(
|
||||
text,
|
||||
"I can see the image. Let me check the weather for you."
|
||||
);
|
||||
} else {
|
||||
panic!("Expected text content block");
|
||||
}
|
||||
|
||||
// Validate tool use content block
|
||||
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = content_blocks[2] {
|
||||
if let MessagesContentBlock::ToolUse {
|
||||
ref id,
|
||||
ref name,
|
||||
ref input,
|
||||
..
|
||||
} = content_blocks[2]
|
||||
{
|
||||
assert_eq!(id, "toolu_weather123");
|
||||
assert_eq!(name, "get_weather");
|
||||
assert_eq!(input["location"], "San Francisco, CA");
|
||||
|
|
@ -892,7 +896,10 @@ mod tests {
|
|||
|
||||
let tool = &tools[0];
|
||||
assert_eq!(tool.name, "get_weather");
|
||||
assert_eq!(tool.description, Some("Get current weather information for a location".to_string()));
|
||||
assert_eq!(
|
||||
tool.description,
|
||||
Some("Get current weather information for a location".to_string())
|
||||
);
|
||||
assert_eq!(tool.input_schema["type"], "object");
|
||||
assert!(tool.input_schema["properties"]["location"].is_object());
|
||||
|
||||
|
|
@ -938,10 +945,16 @@ mod tests {
|
|||
assert_eq!(deserialized_mcp.name, "test-server");
|
||||
assert_eq!(deserialized_mcp.server_type, McpServerType::Url);
|
||||
assert_eq!(deserialized_mcp.url, "https://example.com/mcp");
|
||||
assert_eq!(deserialized_mcp.authorization_token, Some("secret-token".to_string()));
|
||||
assert_eq!(
|
||||
deserialized_mcp.authorization_token,
|
||||
Some("secret-token".to_string())
|
||||
);
|
||||
|
||||
if let Some(tool_config) = &deserialized_mcp.tool_configuration {
|
||||
assert_eq!(tool_config.allowed_tools, Some(vec!["tool1".to_string(), "tool2".to_string()]));
|
||||
assert_eq!(
|
||||
tool_config.allowed_tools,
|
||||
Some(vec!["tool1".to_string(), "tool2".to_string()])
|
||||
);
|
||||
assert_eq!(tool_config.enabled, Some(true));
|
||||
} else {
|
||||
panic!("Expected tool configuration");
|
||||
|
|
@ -957,7 +970,8 @@ mod tests {
|
|||
"url": "https://minimal.com/mcp"
|
||||
});
|
||||
|
||||
let deserialized_minimal: McpServer = serde_json::from_value(minimal_mcp_json.clone()).unwrap();
|
||||
let deserialized_minimal: McpServer =
|
||||
serde_json::from_value(minimal_mcp_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_minimal.name, "minimal-server");
|
||||
assert_eq!(deserialized_minimal.server_type, McpServerType::Url);
|
||||
assert_eq!(deserialized_minimal.url, "https://minimal.com/mcp");
|
||||
|
|
@ -991,12 +1005,16 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_response: MessagesResponse = serde_json::from_value(response_json.clone()).unwrap();
|
||||
let deserialized_response: MessagesResponse =
|
||||
serde_json::from_value(response_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_response.id, "msg_01ABC123");
|
||||
assert_eq!(deserialized_response.obj_type, "message");
|
||||
assert_eq!(deserialized_response.role, MessagesRole::Assistant);
|
||||
assert_eq!(deserialized_response.model, "claude-3-sonnet-20240229");
|
||||
assert_eq!(deserialized_response.stop_reason, MessagesStopReason::EndTurn);
|
||||
assert_eq!(
|
||||
deserialized_response.stop_reason,
|
||||
MessagesStopReason::EndTurn
|
||||
);
|
||||
assert!(deserialized_response.stop_sequence.is_none());
|
||||
assert!(deserialized_response.container.is_none());
|
||||
|
||||
|
|
@ -1011,7 +1029,10 @@ mod tests {
|
|||
// Check usage
|
||||
assert_eq!(deserialized_response.usage.input_tokens, 10);
|
||||
assert_eq!(deserialized_response.usage.output_tokens, 25);
|
||||
assert_eq!(deserialized_response.usage.cache_creation_input_tokens, Some(5));
|
||||
assert_eq!(
|
||||
deserialized_response.usage.cache_creation_input_tokens,
|
||||
Some(5)
|
||||
);
|
||||
assert_eq!(deserialized_response.usage.cache_read_input_tokens, Some(3));
|
||||
|
||||
let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap();
|
||||
|
|
@ -1027,7 +1048,8 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_event: MessagesStreamEvent = serde_json::from_value(stream_event_json.clone()).unwrap();
|
||||
let deserialized_event: MessagesStreamEvent =
|
||||
serde_json::from_value(stream_event_json.clone()).unwrap();
|
||||
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
|
||||
assert_eq!(index, 0);
|
||||
if let MessagesContentDelta::TextDelta { text } = delta {
|
||||
|
|
@ -1055,8 +1077,15 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_tool_use: MessagesContentBlock = serde_json::from_value(tool_use_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = deserialized_tool_use {
|
||||
let deserialized_tool_use: MessagesContentBlock =
|
||||
serde_json::from_value(tool_use_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolUse {
|
||||
ref id,
|
||||
ref name,
|
||||
ref input,
|
||||
..
|
||||
} = deserialized_tool_use
|
||||
{
|
||||
assert_eq!(id, "toolu_01ABC123");
|
||||
assert_eq!(name, "get_weather");
|
||||
assert_eq!(input["location"], "San Francisco, CA");
|
||||
|
|
@ -1079,8 +1108,15 @@ mod tests {
|
|||
]
|
||||
});
|
||||
|
||||
let deserialized_tool_result: MessagesContentBlock = serde_json::from_value(tool_result_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolResult { ref tool_use_id, ref is_error, ref content, .. } = deserialized_tool_result {
|
||||
let deserialized_tool_result: MessagesContentBlock =
|
||||
serde_json::from_value(tool_result_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolResult {
|
||||
ref tool_use_id,
|
||||
ref is_error,
|
||||
ref content,
|
||||
..
|
||||
} = deserialized_tool_result
|
||||
{
|
||||
assert_eq!(tool_use_id, "toolu_01ABC123");
|
||||
assert!(is_error.is_none());
|
||||
if let ToolResultContent::Blocks(blocks) = content {
|
||||
|
|
@ -1229,7 +1265,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize the complex MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(complex_request_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(complex_request_json.clone()).unwrap();
|
||||
|
||||
// Verify basic fields
|
||||
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
|
||||
|
|
@ -1239,8 +1276,15 @@ mod tests {
|
|||
// Verify system message with cache_control
|
||||
if let Some(MessagesSystemPrompt::Blocks(ref system_blocks)) = deserialized_request.system {
|
||||
assert_eq!(system_blocks.len(), 2);
|
||||
if let MessagesContentBlock::Text { text, cache_control } = &system_blocks[0] {
|
||||
assert_eq!(text, "You are Claude Code, Anthropic's official CLI for Claude.");
|
||||
if let MessagesContentBlock::Text {
|
||||
text,
|
||||
cache_control,
|
||||
} = &system_blocks[0]
|
||||
{
|
||||
assert_eq!(
|
||||
text,
|
||||
"You are Claude Code, Anthropic's official CLI for Claude."
|
||||
);
|
||||
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
|
||||
} else {
|
||||
panic!("Expected text system message with cache_control");
|
||||
|
|
@ -1253,7 +1297,13 @@ mod tests {
|
|||
let assistant_message = &deserialized_request.messages[1];
|
||||
assert_eq!(assistant_message.role, MessagesRole::Assistant);
|
||||
if let MessagesMessageContent::Blocks(ref content_blocks) = assistant_message.content {
|
||||
if let MessagesContentBlock::ToolUse { id, name, input, cache_control } = &content_blocks[0] {
|
||||
if let MessagesContentBlock::ToolUse {
|
||||
id,
|
||||
name,
|
||||
input,
|
||||
cache_control,
|
||||
} = &content_blocks[0]
|
||||
{
|
||||
assert_eq!(id, "call_kV50LtJQKHvvzZui5TW56DUl");
|
||||
assert_eq!(name, "TodoWrite");
|
||||
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
|
||||
|
|
@ -1272,7 +1322,12 @@ mod tests {
|
|||
let user_message = &deserialized_request.messages[2];
|
||||
assert_eq!(user_message.role, MessagesRole::User);
|
||||
if let MessagesMessageContent::Blocks(ref content_blocks) = user_message.content {
|
||||
if let MessagesContentBlock::ToolResult { tool_use_id, content, .. } = &content_blocks[0] {
|
||||
if let MessagesContentBlock::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
..
|
||||
} = &content_blocks[0]
|
||||
{
|
||||
assert_eq!(tool_use_id, "call_kV50LtJQKHvvzZui5TW56DUl");
|
||||
if let ToolResultContent::Text(text) = content {
|
||||
assert!(text.contains("Todos have been modified successfully"));
|
||||
|
|
@ -1284,7 +1339,11 @@ mod tests {
|
|||
}
|
||||
|
||||
// Verify text content with cache_control
|
||||
if let MessagesContentBlock::Text { text, cache_control } = &content_blocks[2] {
|
||||
if let MessagesContentBlock::Text {
|
||||
text,
|
||||
cache_control,
|
||||
} = &content_blocks[2]
|
||||
{
|
||||
assert_eq!(text, "try again");
|
||||
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
|
||||
} else {
|
||||
|
|
@ -1296,11 +1355,15 @@ mod tests {
|
|||
|
||||
// Test serialization round-trip
|
||||
let serialized_request = serde_json::to_value(&deserialized_request).unwrap();
|
||||
let re_deserialized_request: MessagesRequest = serde_json::from_value(serialized_request).unwrap();
|
||||
let re_deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(serialized_request).unwrap();
|
||||
|
||||
// Verify round-trip consistency
|
||||
assert_eq!(deserialized_request.model, re_deserialized_request.model);
|
||||
assert_eq!(deserialized_request.messages.len(), re_deserialized_request.messages.len());
|
||||
assert_eq!(
|
||||
deserialized_request.messages.len(),
|
||||
re_deserialized_request.messages.len()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1339,7 +1402,8 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_event: MessagesStreamEvent = serde_json::from_value(thinking_delta_json.clone()).unwrap();
|
||||
let deserialized_event: MessagesStreamEvent =
|
||||
serde_json::from_value(thinking_delta_json.clone()).unwrap();
|
||||
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
|
||||
assert_eq!(index, 0);
|
||||
if let MessagesContentDelta::ThinkingDelta { thinking } = delta {
|
||||
|
|
@ -1352,7 +1416,10 @@ mod tests {
|
|||
}
|
||||
|
||||
// Test that thinking delta is returned by content_delta()
|
||||
assert_eq!(deserialized_event.content_delta(), Some(".\n\nI need to consider:\n1. Current"));
|
||||
assert_eq!(
|
||||
deserialized_event.content_delta(),
|
||||
Some(".\n\nI need to consider:\n1. Current")
|
||||
);
|
||||
|
||||
let serialized_event_json = serde_json::to_value(&deserialized_event).unwrap();
|
||||
assert_eq!(thinking_delta_json, serialized_event_json);
|
||||
|
|
@ -1376,7 +1443,8 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(request_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(request_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
|
||||
assert_eq!(deserialized_request.max_tokens, 2048);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,19 @@
|
|||
pub mod anthropic;
|
||||
pub mod openai;
|
||||
pub mod amazon_bedrock;
|
||||
pub mod amazon_bedrock_binary_frame;
|
||||
pub mod anthropic;
|
||||
pub mod openai;
|
||||
pub mod sse;
|
||||
|
||||
// Explicit exports to avoid naming conflicts
|
||||
pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent};
|
||||
pub use openai::{OpenAIApi, ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse};
|
||||
pub use openai::{Message as OpenAIMessage, Tool as OpenAITool, ToolChoice as OpenAIToolChoice};
|
||||
pub use amazon_bedrock::{AmazonBedrockApi, ConverseRequest, ConverseStreamRequest};
|
||||
pub use amazon_bedrock::{Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice};
|
||||
pub use amazon_bedrock::{
|
||||
Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
|
||||
};
|
||||
pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent};
|
||||
pub use openai::{
|
||||
ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse, OpenAIApi,
|
||||
};
|
||||
pub use openai::{Message as OpenAIMessage, Tool as OpenAITool, ToolChoice as OpenAIToolChoice};
|
||||
|
||||
pub trait ApiDefinition {
|
||||
/// Returns the endpoint path for this API
|
||||
|
|
@ -56,11 +60,7 @@ mod tests {
|
|||
#[test]
|
||||
fn test_api_detection_from_endpoints() {
|
||||
// Test that we can detect APIs from endpoints using the trait
|
||||
let endpoints = vec![
|
||||
CHAT_COMPLETIONS_PATH,
|
||||
MESSAGES_PATH,
|
||||
"/v1/unknown"
|
||||
];
|
||||
let endpoints = vec![CHAT_COMPLETIONS_PATH, MESSAGES_PATH, "/v1/unknown"];
|
||||
|
||||
let mut detected_apis = Vec::new();
|
||||
|
||||
|
|
@ -74,11 +74,14 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
assert_eq!(detected_apis, vec![
|
||||
"OpenAI: ChatCompletions",
|
||||
"Anthropic: Messages",
|
||||
"Unknown API"
|
||||
]);
|
||||
assert_eq!(
|
||||
detected_apis,
|
||||
vec![
|
||||
"OpenAI: ChatCompletions",
|
||||
"Anthropic: Messages",
|
||||
"Unknown API"
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -5,11 +5,11 @@ use std::collections::HashMap;
|
|||
use std::fmt::Display;
|
||||
use thiserror::Error;
|
||||
|
||||
use super::ApiDefinition;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
|
||||
use super::ApiDefinition;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::{CHAT_COMPLETIONS_PATH};
|
||||
use crate::CHAT_COMPLETIONS_PATH;
|
||||
|
||||
// ============================================================================
|
||||
// OPENAI API ENUMERATION
|
||||
|
|
@ -46,7 +46,7 @@ impl ApiDefinition for OpenAIApi {
|
|||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
match self {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => true,
|
||||
}
|
||||
}
|
||||
|
|
@ -58,9 +58,7 @@ impl ApiDefinition for OpenAIApi {
|
|||
}
|
||||
|
||||
fn all_variants() -> Vec<Self> {
|
||||
vec![
|
||||
OpenAIApi::ChatCompletions,
|
||||
]
|
||||
vec![OpenAIApi::ChatCompletions]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -190,7 +188,9 @@ impl ResponseMessage {
|
|||
pub fn to_message(&self) -> Message {
|
||||
Message {
|
||||
role: self.role.clone(),
|
||||
content: self.content.as_ref()
|
||||
content: self
|
||||
.content
|
||||
.as_ref()
|
||||
.map(|s| MessageContent::Text(s.clone()))
|
||||
.unwrap_or(MessageContent::Text(String::new())),
|
||||
name: None, // Response messages don't have names in the same way request messages do
|
||||
|
|
@ -215,7 +215,7 @@ impl ExtractText for MessageContent {
|
|||
fn extract_text(&self) -> String {
|
||||
match self {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts.extract_text()
|
||||
MessageContent::Parts(parts) => parts.extract_text(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -274,7 +274,6 @@ pub struct ImageUrl {
|
|||
|
||||
/// A single message in a chat conversation
|
||||
|
||||
|
||||
/// A tool call made by the assistant
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
|
|
@ -374,7 +373,6 @@ pub enum StaticContentType {
|
|||
Parts(Vec<ContentPart>),
|
||||
}
|
||||
|
||||
|
||||
/// Chat completions API response
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
|
|
@ -496,7 +494,6 @@ pub struct ChatCompletionsStreamResponse {
|
|||
pub service_tier: Option<String>,
|
||||
}
|
||||
|
||||
|
||||
/// A choice in a streaming response
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
|
|
@ -566,7 +563,6 @@ pub struct Models {
|
|||
pub data: Vec<ModelDetail>,
|
||||
}
|
||||
|
||||
|
||||
// Error type for streaming operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OpenAIStreamError {
|
||||
|
|
@ -597,13 +593,13 @@ pub enum OpenAIError {
|
|||
/// Trait Implementations
|
||||
/// ===========================================================================
|
||||
|
||||
|
||||
/// Parameterized conversion for ChatCompletionsRequest
|
||||
impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
||||
type Error = OpenAIStreamError;
|
||||
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
let mut req: ChatCompletionsRequest = serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
|
||||
let mut req: ChatCompletionsRequest =
|
||||
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
|
||||
// Use the centralized suppression logic
|
||||
req.suppress_max_tokens_if_o3();
|
||||
req.fix_temperature_if_gpt5();
|
||||
|
|
@ -651,13 +647,18 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
self.messages.iter().fold(String::new(), |acc, m| {
|
||||
acc + " " + &match &m.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts.iter().map(|part| match part {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
|
||||
}).collect::<Vec<_>>().join(" ")
|
||||
}
|
||||
acc + " "
|
||||
+ &match &m.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts
|
||||
.iter()
|
||||
.map(|part| match part {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(" "),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -721,14 +722,14 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
|||
}
|
||||
|
||||
fn role(&self) -> Option<&str> {
|
||||
self.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.role.as_ref().map(|r| match r {
|
||||
self.choices.first().and_then(|choice| {
|
||||
choice.delta.role.as_ref().map(|r| match r {
|
||||
Role::System => "system",
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::Tool => "tool",
|
||||
}))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn event_type(&self) -> Option<&str> {
|
||||
|
|
@ -736,7 +737,6 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -756,7 +756,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into ChatCompletionsRequest
|
||||
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: ChatCompletionsRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate required fields are properly set
|
||||
assert_eq!(deserialized_request.model, "gpt-4");
|
||||
|
|
@ -799,7 +800,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into ChatCompletionsRequest
|
||||
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: ChatCompletionsRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate required fields
|
||||
assert_eq!(deserialized_request.model, "gpt-4");
|
||||
|
|
@ -836,7 +838,10 @@ mod tests {
|
|||
assert_eq!(serialized_json["messages"], original_json["messages"]);
|
||||
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
|
||||
assert_eq!(serialized_json["stream"], original_json["stream"]);
|
||||
assert_eq!(serialized_json["stream_options"], original_json["stream_options"]);
|
||||
assert_eq!(
|
||||
serialized_json["stream_options"],
|
||||
original_json["stream_options"]
|
||||
);
|
||||
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
|
||||
|
||||
// Handle temperature with floating point tolerance
|
||||
|
|
@ -917,7 +922,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into ChatCompletionsRequest
|
||||
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: ChatCompletionsRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate top-level fields
|
||||
assert_eq!(deserialized_request.model, "gpt-4-vision-preview");
|
||||
|
|
@ -953,7 +959,10 @@ mod tests {
|
|||
let assistant_message = &deserialized_request.messages[1];
|
||||
assert_eq!(assistant_message.role, Role::Assistant);
|
||||
if let MessageContent::Text(text) = &assistant_message.content {
|
||||
assert_eq!(text, "I can see a beautiful cityscape. Let me check the weather for you.");
|
||||
assert_eq!(
|
||||
text,
|
||||
"I can see a beautiful cityscape. Let me check the weather for you."
|
||||
);
|
||||
} else {
|
||||
panic!("Expected text content for assistant message");
|
||||
}
|
||||
|
|
@ -967,7 +976,10 @@ mod tests {
|
|||
assert_eq!(tool_call.id, "call_weather123");
|
||||
assert_eq!(tool_call.call_type, "function");
|
||||
assert_eq!(tool_call.function.name, "get_weather");
|
||||
assert_eq!(tool_call.function.arguments, "{\"location\": \"New York, NY\"}");
|
||||
assert_eq!(
|
||||
tool_call.function.arguments,
|
||||
"{\"location\": \"New York, NY\"}"
|
||||
);
|
||||
|
||||
// Validate third message (tool response)
|
||||
let tool_message = &deserialized_request.messages[2];
|
||||
|
|
@ -977,7 +989,10 @@ mod tests {
|
|||
} else {
|
||||
panic!("Expected text content for tool message");
|
||||
}
|
||||
assert_eq!(tool_message.tool_call_id, Some("call_weather123".to_string()));
|
||||
assert_eq!(
|
||||
tool_message.tool_call_id,
|
||||
Some("call_weather123".to_string())
|
||||
);
|
||||
|
||||
// Validate tools array
|
||||
assert!(deserialized_request.tools.is_some());
|
||||
|
|
@ -987,7 +1002,10 @@ mod tests {
|
|||
let tool = &tools[0];
|
||||
assert_eq!(tool.tool_type, "function");
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
assert_eq!(tool.function.description, Some("Get current weather information for a location".to_string()));
|
||||
assert_eq!(
|
||||
tool.function.description,
|
||||
Some("Get current weather information for a location".to_string())
|
||||
);
|
||||
assert_eq!(tool.function.strict, Some(true));
|
||||
|
||||
// Validate tool parameters schema
|
||||
|
|
@ -1093,7 +1111,8 @@ mod tests {
|
|||
]
|
||||
});
|
||||
|
||||
let deserialized_assistant: Message = serde_json::from_value(assistant_json.clone()).unwrap();
|
||||
let deserialized_assistant: Message =
|
||||
serde_json::from_value(assistant_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_assistant.role, Role::Assistant);
|
||||
if let MessageContent::Text(content) = &deserialized_assistant.content {
|
||||
assert_eq!(content, "I'll help with that.");
|
||||
|
|
@ -1142,9 +1161,13 @@ mod tests {
|
|||
]
|
||||
});
|
||||
|
||||
let deserialized_response: ResponseMessage = serde_json::from_value(response_json.clone()).unwrap();
|
||||
let deserialized_response: ResponseMessage =
|
||||
serde_json::from_value(response_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_response.role, Role::Assistant);
|
||||
assert_eq!(deserialized_response.content, Some("Response content".to_string()));
|
||||
assert_eq!(
|
||||
deserialized_response.content,
|
||||
Some("Response content".to_string())
|
||||
);
|
||||
assert!(deserialized_response.annotations.is_some());
|
||||
assert!(deserialized_response.refusal.is_none());
|
||||
assert!(deserialized_response.function_call.is_none());
|
||||
|
|
@ -1186,7 +1209,10 @@ mod tests {
|
|||
let none_deserialized: ToolChoice = serde_json::from_value(json!("none")).unwrap();
|
||||
|
||||
assert_eq!(auto_deserialized, ToolChoice::Type(ToolChoiceType::Auto));
|
||||
assert_eq!(required_deserialized, ToolChoice::Type(ToolChoiceType::Required));
|
||||
assert_eq!(
|
||||
required_deserialized,
|
||||
ToolChoice::Type(ToolChoiceType::Required)
|
||||
);
|
||||
assert_eq!(none_deserialized, ToolChoice::Type(ToolChoiceType::None));
|
||||
|
||||
// Test that invalid string values fail deserialization (type safety!)
|
||||
|
|
@ -1237,7 +1263,10 @@ mod tests {
|
|||
assert_eq!(response.created, 1756574706);
|
||||
assert_eq!(response.model, "gpt-4o-2024-08-06");
|
||||
assert_eq!(response.service_tier, Some("default".to_string()));
|
||||
assert_eq!(response.system_fingerprint, Some("fp_f33640a400".to_string()));
|
||||
assert_eq!(
|
||||
response.system_fingerprint,
|
||||
Some("fp_f33640a400".to_string())
|
||||
);
|
||||
assert_eq!(response.choices.len(), 1);
|
||||
assert_eq!(response.usage.prompt_tokens, 65);
|
||||
assert_eq!(response.usage.completion_tokens, 184);
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
use std::str::FromStr;
|
||||
use std::fmt;
|
||||
use std::error::Error;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::providers::response::ProviderStreamResponse;
|
||||
use crate::providers::response::ProviderStreamResponseType;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
|
||||
// ============================================================================
|
||||
// SSE EVENT CONTAINER
|
||||
|
|
@ -13,19 +13,19 @@ use crate::providers::response::ProviderStreamResponseType;
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SseEvent {
|
||||
#[serde(rename = "data")]
|
||||
pub data: Option<String>, // The JSON payload after "data: "
|
||||
pub data: Option<String>, // The JSON payload after "data: "
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta")
|
||||
pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta")
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
|
||||
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
|
||||
}
|
||||
|
||||
impl SseEvent {
|
||||
|
|
@ -48,13 +48,13 @@ impl SseEvent {
|
|||
|
||||
/// Get the parsed provider response if available
|
||||
pub fn provider_response(&self) -> Result<&dyn ProviderStreamResponse, std::io::Error> {
|
||||
self.provider_stream_response.as_ref()
|
||||
self.provider_stream_response
|
||||
.as_ref()
|
||||
.map(|resp| resp as &dyn ProviderStreamResponse)
|
||||
.ok_or_else(|| {
|
||||
std::io::Error::new(std::io::ErrorKind::NotFound, "Provider response not found")
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
impl FromStr for SseEvent {
|
||||
|
|
@ -75,7 +75,8 @@ impl FromStr for SseEvent {
|
|||
sse_transform_buffer: line.to_string(),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else if line.starts_with("event: ") { //used by Anthropic
|
||||
} else if line.starts_with("event: ") {
|
||||
//used by Anthropic
|
||||
let event_type = line[7..].to_string();
|
||||
if event_type.is_empty() {
|
||||
return Err(SseParseError {
|
||||
|
|
@ -123,7 +124,6 @@ impl fmt::Display for SseParseError {
|
|||
|
||||
impl Error for SseParseError {}
|
||||
|
||||
|
||||
/// Generic SSE (Server-Sent Events) streaming iterator container
|
||||
/// Parses raw SSE lines into SseEvent objects
|
||||
pub struct SseStreamIter<I>
|
||||
|
|
@ -141,7 +141,10 @@ where
|
|||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub fn new(lines: I) -> Self {
|
||||
Self { lines, done_seen: false }
|
||||
Self {
|
||||
lines,
|
||||
done_seen: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -151,14 +154,13 @@ impl TryFrom<&[u8]> for SseStreamIter<std::vec::IntoIter<String>> {
|
|||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
// Parse as text-based SSE format
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||
Ok(SseStreamIter::new(lines.into_iter()))
|
||||
// Parse as text-based SSE format
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||
Ok(SseStreamIter::new(lines.into_iter()))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl<I> Iterator for SseStreamIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{ProviderId};
|
||||
use crate::apis::{OpenAIApi, AnthropicApi, AmazonBedrockApi, ApiDefinition};
|
||||
use crate::apis::{AmazonBedrockApi, AnthropicApi, ApiDefinition, OpenAIApi};
|
||||
use crate::ProviderId;
|
||||
use std::fmt;
|
||||
|
||||
/// Unified enum representing all supported API endpoints across providers
|
||||
|
|
@ -20,8 +20,12 @@ pub enum SupportedUpstreamAPIs {
|
|||
impl fmt::Display for SupportedAPIs {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
SupportedAPIs::OpenAIChatCompletions(api) => write!(f, "OpenAI API ({})", api.endpoint()),
|
||||
SupportedAPIs::AnthropicMessagesAPI(api) => write!(f, "Anthropic API ({})", api.endpoint()),
|
||||
SupportedAPIs::OpenAIChatCompletions(api) => {
|
||||
write!(f, "OpenAI API ({})", api.endpoint())
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(api) => {
|
||||
write!(f, "Anthropic API ({})", api.endpoint())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -48,81 +52,81 @@ impl SupportedAPIs {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str, model_id: &str, is_streaming: bool) -> String {
|
||||
pub fn target_endpoint_for_provider(
|
||||
&self,
|
||||
provider_id: &ProviderId,
|
||||
request_path: &str,
|
||||
model_id: &str,
|
||||
is_streaming: bool,
|
||||
) -> String {
|
||||
let default_endpoint = "/v1/chat/completions".to_string();
|
||||
match self {
|
||||
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => {
|
||||
match provider_id {
|
||||
ProviderId::Anthropic => "/v1/messages".to_string(),
|
||||
ProviderId::AmazonBedrock => {
|
||||
if request_path.starts_with("/v1/") && !is_streaming {
|
||||
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
|
||||
ProviderId::Anthropic => "/v1/messages".to_string(),
|
||||
ProviderId::AmazonBedrock => {
|
||||
if request_path.starts_with("/v1/") && !is_streaming {
|
||||
format!("/model/{}/converse", model_id)
|
||||
} else if request_path.starts_with("/v1/") && is_streaming {
|
||||
format!("/model/{}/converse-stream", model_id)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
_ => default_endpoint,
|
||||
},
|
||||
_ => match provider_id {
|
||||
ProviderId::Groq => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
format!("/openai{}", request_path)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::Zhipu => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/api/paas/v4/chat/completions".to_string()
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::Qwen => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/compatible-mode/v1/chat/completions".to_string()
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::AzureOpenAI => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::Gemini => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/v1beta/openai/chat/completions".to_string()
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::AmazonBedrock => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
if !is_streaming {
|
||||
format!("/model/{}/converse", model_id)
|
||||
} else if request_path.starts_with("/v1/") && is_streaming {
|
||||
} else {
|
||||
format!("/model/{}/converse-stream", model_id)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
_ => default_endpoint,
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
match provider_id {
|
||||
ProviderId::Groq => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
format!("/openai{}", request_path)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::Zhipu => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/api/paas/v4/chat/completions".to_string()
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::Qwen => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/compatible-mode/v1/chat/completions".to_string()
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::AzureOpenAI => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::Gemini => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/v1beta/openai/chat/completions".to_string()
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::AmazonBedrock => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
if !is_streaming {
|
||||
format!("/model/{}/converse", model_id)
|
||||
} else {
|
||||
format!("/model/{}/converse-stream", model_id)
|
||||
}
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
_ => default_endpoint,
|
||||
}
|
||||
}
|
||||
_ => default_endpoint,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Get all supported endpoint paths
|
||||
pub fn supported_endpoints() -> Vec<&'static str> {
|
||||
let mut endpoints = Vec::new();
|
||||
|
|
@ -164,7 +168,6 @@ mod tests {
|
|||
// Anthropic endpoints
|
||||
assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some());
|
||||
|
||||
|
||||
// Unsupported endpoints
|
||||
assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some());
|
||||
assert!(!SupportedAPIs::from_endpoint("/v2/chat").is_some());
|
||||
|
|
@ -177,7 +180,6 @@ mod tests {
|
|||
assert_eq!(endpoints.len(), 2); // We have 2 APIs defined
|
||||
assert!(endpoints.contains(&"/v1/chat/completions"));
|
||||
assert!(endpoints.contains(&"/v1/messages"));
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -203,14 +205,25 @@ mod tests {
|
|||
|
||||
// All OpenAI endpoints should be in the result
|
||||
for endpoint in openai_endpoints {
|
||||
assert!(endpoints.contains(&endpoint), "Missing OpenAI endpoint: {}", endpoint);
|
||||
assert!(
|
||||
endpoints.contains(&endpoint),
|
||||
"Missing OpenAI endpoint: {}",
|
||||
endpoint
|
||||
);
|
||||
}
|
||||
|
||||
// All Anthropic endpoints should be in the result
|
||||
for endpoint in anthropic_endpoints {
|
||||
assert!(endpoints.contains(&endpoint), "Missing Anthropic endpoint: {}", endpoint);
|
||||
assert!(
|
||||
endpoints.contains(&endpoint),
|
||||
"Missing Anthropic endpoint: {}",
|
||||
endpoint
|
||||
);
|
||||
}
|
||||
// Total should match
|
||||
assert_eq!(endpoints.len(), OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len());
|
||||
assert_eq!(
|
||||
endpoints.len(),
|
||||
OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
pub mod endpoints;
|
||||
pub mod lib;
|
||||
pub mod transformer;
|
||||
pub mod endpoints;
|
||||
|
||||
// Re-export the main items for easier access
|
||||
pub use endpoints::{identify_provider, SupportedAPIs};
|
||||
pub use lib::*;
|
||||
pub use endpoints::{SupportedAPIs, identify_provider};
|
||||
|
||||
// Note: transformer module contains TryFrom trait implementations that are automatically available
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
// Re-export new transformation modules for backward compatibility
|
||||
|
||||
|
||||
//KEEPING THE TESTS TO MAKE SURE ALL THE REFACTORING DIDN'T BREAK ANYTHING
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -9,10 +8,10 @@
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_json::json;
|
||||
use crate::apis::anthropic::*;
|
||||
use crate::apis::openai::*;
|
||||
use crate::transforms::*;
|
||||
use serde_json::json;
|
||||
type AnthropicMessagesRequest = MessagesRequest;
|
||||
|
||||
#[test]
|
||||
|
|
@ -81,11 +80,20 @@ mod tests {
|
|||
|
||||
// Check key fields are preserved
|
||||
assert_eq!(original_anthropic.model, roundtrip_anthropic.model);
|
||||
assert_eq!(original_anthropic.max_tokens, roundtrip_anthropic.max_tokens);
|
||||
assert_eq!(original_anthropic.temperature, roundtrip_anthropic.temperature);
|
||||
assert_eq!(
|
||||
original_anthropic.max_tokens,
|
||||
roundtrip_anthropic.max_tokens
|
||||
);
|
||||
assert_eq!(
|
||||
original_anthropic.temperature,
|
||||
roundtrip_anthropic.temperature
|
||||
);
|
||||
assert_eq!(original_anthropic.top_p, roundtrip_anthropic.top_p);
|
||||
assert_eq!(original_anthropic.stream, roundtrip_anthropic.stream);
|
||||
assert_eq!(original_anthropic.messages.len(), roundtrip_anthropic.messages.len());
|
||||
assert_eq!(
|
||||
original_anthropic.messages.len(),
|
||||
roundtrip_anthropic.messages.len()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -229,7 +237,10 @@ mod tests {
|
|||
let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].id, Some("call_123".to_string()));
|
||||
assert_eq!(tool_calls[0].function.as_ref().unwrap().name, Some("get_weather".to_string()));
|
||||
assert_eq!(
|
||||
tool_calls[0].function.as_ref().unwrap().name,
|
||||
Some("get_weather".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -249,7 +260,10 @@ mod tests {
|
|||
|
||||
let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.as_ref().unwrap().arguments, Some(r#"{"location": "San Francisco"#.to_string()));
|
||||
assert_eq!(
|
||||
tool_calls[0].function.as_ref().unwrap().arguments,
|
||||
Some(r#"{"location": "San Francisco"#.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -412,7 +426,10 @@ mod tests {
|
|||
let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap();
|
||||
|
||||
match anthropic_event {
|
||||
MessagesStreamEvent::ContentBlockStart { index, content_block } => {
|
||||
MessagesStreamEvent::ContentBlockStart {
|
||||
index,
|
||||
content_block,
|
||||
} => {
|
||||
assert_eq!(index, 0);
|
||||
match content_block {
|
||||
MessagesContentBlock::ToolUse { id, name, .. } => {
|
||||
|
|
@ -555,16 +572,28 @@ mod tests {
|
|||
// Verify tool start
|
||||
let tool_calls = &openai_start.choices[0].delta.tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tool_calls[0].id, Some("call_weather".to_string()));
|
||||
assert_eq!(tool_calls[0].function.as_ref().unwrap().name, Some("get_weather".to_string()));
|
||||
assert_eq!(
|
||||
tool_calls[0].function.as_ref().unwrap().name,
|
||||
Some("get_weather".to_string())
|
||||
);
|
||||
|
||||
// Verify argument deltas
|
||||
let args1 = &openai_delta1.choices[0].delta.tool_calls.as_ref().unwrap()[0]
|
||||
.function.as_ref().unwrap().arguments;
|
||||
.function
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.arguments;
|
||||
assert_eq!(args1, &Some(r#"{"location": "#.to_string()));
|
||||
|
||||
let args2 = &openai_delta2.choices[0].delta.tool_calls.as_ref().unwrap()[0]
|
||||
.function.as_ref().unwrap().arguments;
|
||||
assert_eq!(args2, &Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string()));
|
||||
.function
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.arguments;
|
||||
assert_eq!(
|
||||
args2,
|
||||
&Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -592,14 +621,23 @@ mod tests {
|
|||
};
|
||||
|
||||
let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap();
|
||||
assert_eq!(openai_resp.choices[0].finish_reason, Some(expected_openai_reason));
|
||||
assert_eq!(
|
||||
openai_resp.choices[0].finish_reason,
|
||||
Some(expected_openai_reason)
|
||||
);
|
||||
|
||||
// Test reverse conversion
|
||||
let roundtrip_event: MessagesStreamEvent = openai_resp.try_into().unwrap();
|
||||
match roundtrip_event {
|
||||
MessagesStreamEvent::MessageDelta { delta, .. } => {
|
||||
// Note: Some precision may be lost in roundtrip due to mapping differences
|
||||
assert!(matches!(delta.stop_reason, MessagesStopReason::EndTurn | MessagesStopReason::MaxTokens | MessagesStopReason::ToolUse | MessagesStopReason::StopSequence));
|
||||
assert!(matches!(
|
||||
delta.stop_reason,
|
||||
MessagesStopReason::EndTurn
|
||||
| MessagesStopReason::MaxTokens
|
||||
| MessagesStopReason::ToolUse
|
||||
| MessagesStopReason::StopSequence
|
||||
));
|
||||
}
|
||||
_ => panic!("Expected MessageDelta after roundtrip"),
|
||||
}
|
||||
|
|
@ -632,7 +670,8 @@ mod tests {
|
|||
};
|
||||
|
||||
// Should convert to Ping when no meaningful content
|
||||
let anthropic_event: MessagesStreamEvent = openai_resp_with_missing_data.try_into().unwrap();
|
||||
let anthropic_event: MessagesStreamEvent =
|
||||
openai_resp_with_missing_data.try_into().unwrap();
|
||||
assert!(matches!(anthropic_event, MessagesStreamEvent::Ping));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,24 +1,25 @@
|
|||
//! hermesllm: A library for translating LLM API requests and responses
|
||||
//! between Mistral, Grok, Gemini, and OpenAI-compliant formats.
|
||||
|
||||
pub mod providers;
|
||||
pub mod apis;
|
||||
pub mod clients;
|
||||
pub mod providers;
|
||||
pub mod transforms;
|
||||
// Re-export important types and traits
|
||||
pub use providers::request::{ProviderRequestType, ProviderRequest, ProviderRequestError};
|
||||
pub use apis::sse::{SseEvent, SseStreamIter};
|
||||
pub use apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
|
||||
pub use providers::response::{ProviderResponseType, ProviderStreamResponseType, ProviderResponse, ProviderStreamResponse, ProviderResponseError, TokenUsage};
|
||||
pub use providers::id::ProviderId;
|
||||
pub use apis::sse::{SseEvent, SseStreamIter};
|
||||
pub use aws_smithy_eventstream::frame::DecodedFrame;
|
||||
|
||||
pub use providers::id::ProviderId;
|
||||
pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
|
||||
pub use providers::response::{
|
||||
ProviderResponse, ProviderResponseError, ProviderResponseType, ProviderStreamResponse,
|
||||
ProviderStreamResponseType, TokenUsage,
|
||||
};
|
||||
|
||||
//TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const MESSAGES_PATH: &str = "/v1/messages";
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::clients::endpoints::SupportedUpstreamAPIs;
|
||||
|
|
@ -36,49 +37,51 @@ mod tests {
|
|||
#[test]
|
||||
fn test_provider_streaming_response() {
|
||||
// Test streaming response parsing with sample SSE data
|
||||
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
|
||||
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
let client_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
let client_api =
|
||||
SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
let upstream_api =
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test the new simplified architecture - create SseStreamIter directly
|
||||
let sse_iter = SseStreamIter::try_from(sse_data.as_bytes());
|
||||
assert!(sse_iter.is_ok());
|
||||
// Test the new simplified architecture - create SseStreamIter directly
|
||||
let sse_iter = SseStreamIter::try_from(sse_data.as_bytes());
|
||||
assert!(sse_iter.is_ok());
|
||||
|
||||
let mut streaming_iter = sse_iter.unwrap();
|
||||
let mut streaming_iter = sse_iter.unwrap();
|
||||
|
||||
// Test that we can iterate over SseEvents
|
||||
let first_event = streaming_iter.next();
|
||||
assert!(first_event.is_some());
|
||||
// Test that we can iterate over SseEvents
|
||||
let first_event = streaming_iter.next();
|
||||
assert!(first_event.is_some());
|
||||
|
||||
let sse_event = first_event.unwrap();
|
||||
let sse_event = first_event.unwrap();
|
||||
|
||||
// Test SseEvent properties
|
||||
assert!(!sse_event.is_done());
|
||||
assert!(sse_event.data.as_ref().unwrap().contains("Hello"));
|
||||
// Test SseEvent properties
|
||||
assert!(!sse_event.is_done());
|
||||
assert!(sse_event.data.as_ref().unwrap().contains("Hello"));
|
||||
|
||||
// Test that we can parse the event into a provider stream response
|
||||
let transformed_event = SseEvent::try_from((sse_event, &client_api, &upstream_api));
|
||||
if let Err(e) = &transformed_event {
|
||||
println!("Transform error: {:?}", e);
|
||||
}
|
||||
assert!(transformed_event.is_ok());
|
||||
// Test that we can parse the event into a provider stream response
|
||||
let transformed_event = SseEvent::try_from((sse_event, &client_api, &upstream_api));
|
||||
if let Err(e) = &transformed_event {
|
||||
println!("Transform error: {:?}", e);
|
||||
}
|
||||
assert!(transformed_event.is_ok());
|
||||
|
||||
let transformed_event = transformed_event.unwrap();
|
||||
let provider_response = transformed_event.provider_response();
|
||||
assert!(provider_response.is_ok());
|
||||
let transformed_event = transformed_event.unwrap();
|
||||
let provider_response = transformed_event.provider_response();
|
||||
assert!(provider_response.is_ok());
|
||||
|
||||
let stream_response = provider_response.unwrap();
|
||||
assert_eq!(stream_response.content_delta(), Some("Hello"));
|
||||
assert!(!stream_response.is_final());
|
||||
let stream_response = provider_response.unwrap();
|
||||
assert_eq!(stream_response.content_delta(), Some("Hello"));
|
||||
assert!(!stream_response.is_final());
|
||||
|
||||
// Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE])
|
||||
let final_event = streaming_iter.next();
|
||||
assert!(final_event.is_none()); // Should be None because iterator stops at [DONE]
|
||||
// Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE])
|
||||
let final_event = streaming_iter.next();
|
||||
assert!(final_event.is_none()); // Should be None because iterator stops at [DONE]
|
||||
}
|
||||
|
||||
/// Test AWS Event Stream decoding for Bedrock ConverseStream responses.
|
||||
|
|
@ -95,15 +98,15 @@ mod tests {
|
|||
/// all complete frames in the buffer.
|
||||
#[test]
|
||||
fn test_amazon_bedrock_streaming_response() {
|
||||
use aws_smithy_eventstream::frame::{MessageFrameDecoder, DecodedFrame};
|
||||
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
|
||||
use bytes::{Buf, BytesMut};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
// Read the response.hex file from tests/e2e directory
|
||||
// Use absolute path to avoid cargo test working directory issues
|
||||
let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("../../tests/e2e/response.hex");
|
||||
let test_file =
|
||||
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
|
||||
let response_data = fs::read(&test_file)
|
||||
.unwrap_or_else(|e| panic!("Failed to read {:?}: {}", test_file, e));
|
||||
|
||||
|
|
@ -134,8 +137,13 @@ mod tests {
|
|||
simulated_network_buffer.extend_from_slice(chunk);
|
||||
offset = end;
|
||||
|
||||
println!("📦 Chunk {}: Received {} bytes (buffer: {} bytes total, {} bytes remaining)",
|
||||
chunk_num, chunk.len(), simulated_network_buffer.len(), simulated_network_buffer.remaining());
|
||||
println!(
|
||||
"📦 Chunk {}: Received {} bytes (buffer: {} bytes total, {} bytes remaining)",
|
||||
chunk_num,
|
||||
chunk.len(),
|
||||
simulated_network_buffer.len(),
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
|
||||
// Try to decode all complete frames from buffer
|
||||
// The Buf trait tracks position automatically!
|
||||
|
|
@ -146,11 +154,16 @@ mod tests {
|
|||
frame_count += 1;
|
||||
let consumed = bytes_before - simulated_network_buffer.remaining();
|
||||
|
||||
println!(" ✅ Frame {}: decoded ({} bytes, {} bytes remaining)",
|
||||
frame_count, consumed, simulated_network_buffer.remaining());
|
||||
println!(
|
||||
" ✅ Frame {}: decoded ({} bytes, {} bytes remaining)",
|
||||
frame_count,
|
||||
consumed,
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
|
||||
// Get event type from headers
|
||||
let event_type = message.headers()
|
||||
let event_type = message
|
||||
.headers()
|
||||
.iter()
|
||||
.find(|h| h.name().as_str() == ":event-type")
|
||||
.and_then(|h| {
|
||||
|
|
@ -167,7 +180,9 @@ mod tests {
|
|||
if let Ok(json) = serde_json::from_slice::<serde_json::Value>(payload) {
|
||||
if event_type.as_deref() == Some("contentBlockDelta") {
|
||||
if let Some(delta) = json.get("delta") {
|
||||
if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
|
||||
if let Some(text) =
|
||||
delta.get("text").and_then(|t| t.as_str())
|
||||
{
|
||||
println!(" 📝 Content: \"{}\"", text);
|
||||
content_chunks.push(text.to_string());
|
||||
}
|
||||
|
|
@ -178,7 +193,10 @@ mod tests {
|
|||
}
|
||||
Ok(DecodedFrame::Incomplete) => {
|
||||
// Not enough data for a complete frame - need more chunks
|
||||
println!(" ⏳ Incomplete frame ({} bytes remaining) - waiting for more data\n", simulated_network_buffer.remaining());
|
||||
println!(
|
||||
" ⏳ Incomplete frame ({} bytes remaining) - waiting for more data\n",
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
break; // Wait for next chunk
|
||||
}
|
||||
Err(e) => {
|
||||
|
|
@ -193,7 +211,10 @@ mod tests {
|
|||
println!(" Total chunks received: {}", chunk_num);
|
||||
println!(" Total frames decoded: {}", frame_count);
|
||||
println!(" Total content chunks: {}", content_chunks.len());
|
||||
println!(" Final buffer remaining: {} bytes", simulated_network_buffer.remaining());
|
||||
println!(
|
||||
" Final buffer remaining: {} bytes",
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
|
||||
if !content_chunks.is_empty() {
|
||||
let full_text = content_chunks.join("");
|
||||
|
|
@ -207,6 +228,11 @@ mod tests {
|
|||
assert!(frame_count > 0, "Should decode at least one frame");
|
||||
|
||||
// Ensure all data was consumed - if buffer has remaining bytes, it's a partial frame
|
||||
assert_eq!(simulated_network_buffer.remaining(), 0, "All bytes should be consumed, {} bytes remain", simulated_network_buffer.remaining());
|
||||
assert_eq!(
|
||||
simulated_network_buffer.remaining(),
|
||||
0,
|
||||
"All bytes should be consumed, {} bytes remain",
|
||||
simulated_network_buffer.remaining()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use std::fmt::Display;
|
||||
use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi};
|
||||
use crate::clients::endpoints::{SupportedAPIs, SupportedUpstreamAPIs};
|
||||
use crate::apis::{OpenAIApi, AnthropicApi, AmazonBedrockApi};
|
||||
use std::fmt::Display;
|
||||
|
||||
/// Provider identifier enum - simple enum for identifying providers
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
|
|
@ -49,60 +49,77 @@ impl From<&str> for ProviderId {
|
|||
|
||||
impl ProviderId {
|
||||
/// Given a client API, return the compatible upstream API for this provider
|
||||
pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs, is_streaming: bool) -> SupportedUpstreamAPIs {
|
||||
pub fn compatible_api_for_client(
|
||||
&self,
|
||||
client_api: &SupportedAPIs,
|
||||
is_streaming: bool,
|
||||
) -> SupportedUpstreamAPIs {
|
||||
match (self, client_api) {
|
||||
// Claude/Anthropic providers natively support Anthropic APIs
|
||||
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
|
||||
(ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
|
||||
}
|
||||
(
|
||||
ProviderId::Anthropic,
|
||||
SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
// OpenAI-compatible providers only support OpenAI chat completions
|
||||
(ProviderId::OpenAI
|
||||
| ProviderId::Groq
|
||||
| ProviderId::Mistral
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::GitHub
|
||||
| ProviderId::AzureOpenAI
|
||||
| ProviderId::XAI
|
||||
| ProviderId::TogetherAI
|
||||
| ProviderId::Ollama
|
||||
| ProviderId::Moonshotai
|
||||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
(
|
||||
ProviderId::OpenAI
|
||||
| ProviderId::Groq
|
||||
| ProviderId::Mistral
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::GitHub
|
||||
| ProviderId::AzureOpenAI
|
||||
| ProviderId::XAI
|
||||
| ProviderId::TogetherAI
|
||||
| ProviderId::Ollama
|
||||
| ProviderId::Moonshotai
|
||||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
(ProviderId::OpenAI
|
||||
| ProviderId::Groq
|
||||
| ProviderId::Mistral
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::GitHub
|
||||
| ProviderId::AzureOpenAI
|
||||
| ProviderId::XAI
|
||||
| ProviderId::TogetherAI
|
||||
| ProviderId::Ollama
|
||||
| ProviderId::Moonshotai
|
||||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::OpenAIChatCompletions(_)) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
(
|
||||
ProviderId::OpenAI
|
||||
| ProviderId::Groq
|
||||
| ProviderId::Mistral
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::GitHub
|
||||
| ProviderId::AzureOpenAI
|
||||
| ProviderId::XAI
|
||||
| ProviderId::TogetherAI
|
||||
| ProviderId::Ollama
|
||||
| ProviderId::Moonshotai
|
||||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
// Amazon Bedrock natively supports Bedrock APIs
|
||||
(ProviderId::AmazonBedrock, SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(AmazonBedrockApi::ConverseStream)
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
AmazonBedrockApi::ConverseStream,
|
||||
)
|
||||
} else {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
|
||||
}
|
||||
},
|
||||
}
|
||||
(ProviderId::AmazonBedrock, SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(AmazonBedrockApi::ConverseStream)
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
AmazonBedrockApi::ConverseStream,
|
||||
)
|
||||
} else {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,5 +8,5 @@ pub mod request;
|
|||
pub mod response;
|
||||
|
||||
pub use id::ProviderId;
|
||||
pub use request::{ProviderRequestType, ProviderRequest, ProviderRequestError} ;
|
||||
pub use response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, TokenUsage };
|
||||
pub use request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
|
||||
pub use response::{ProviderResponse, ProviderResponseType, ProviderStreamResponse, TokenUsage};
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::apis::anthropic::MessagesRequest;
|
||||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
|
||||
use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest};
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::clients::endpoints::SupportedUpstreamAPIs;
|
||||
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::collections::HashMap;
|
||||
#[derive(Clone)]
|
||||
pub enum ProviderRequestType {
|
||||
ChatCompletionsRequest(ChatCompletionsRequest),
|
||||
|
|
@ -124,15 +124,18 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
|||
// Use SupportedApi to determine the appropriate request type
|
||||
match client_api {
|
||||
SupportedAPIs::OpenAIChatCompletions(_) => {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(_) => {
|
||||
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
|
||||
let chat_completion_request: ChatCompletionsRequest =
|
||||
ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
||||
}
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(
|
||||
chat_completion_request,
|
||||
))
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(_) => {
|
||||
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -141,37 +144,57 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
|||
impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestType {
|
||||
type Error = ProviderRequestError;
|
||||
|
||||
fn try_from((client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs)) -> Result<Self, Self::Error> {
|
||||
fn try_from(
|
||||
(client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs),
|
||||
) -> Result<Self, Self::Error> {
|
||||
match (client_request, upstream_api) {
|
||||
// Same API - no conversion needed, just clone the reference
|
||||
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
(ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => {
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)),
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => Ok(ProviderRequestType::MessagesRequest(messages_req)),
|
||||
|
||||
// Cross-API conversion - cloning is necessary for transformation
|
||||
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let messages_req = MessagesRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
let messages_req =
|
||||
MessagesRequest::try_from(chat_req).map_err(|e| ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ChatCompletionsRequest to MessagesRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
})?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
|
||||
(ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(messages_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(messages_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert MessagesRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
|
||||
// Cross-API conversions: OpenAI/Anthropic to Amazon Bedrock
|
||||
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AmazonBedrockConverse(_)) => {
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
|
||||
|
|
@ -180,7 +203,10 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
|
||||
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => {
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseStreamRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
|
||||
|
|
@ -188,20 +214,33 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
(ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AmazonBedrockConverse(_)) => {
|
||||
let bedrock_req = ConverseRequest::try_from(messages_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert MessagesRequest to Amazon Bedrock request: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
) => {
|
||||
let bedrock_req =
|
||||
ConverseRequest::try_from(messages_req).map_err(|e| ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
(ProviderRequestType::MessagesRequest(messages_req), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)) => {
|
||||
let bedrock_req = ConverseStreamRequest::try_from(messages_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert MessagesRequest to Amazon Bedrock request: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseStreamRequest::try_from(messages_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
|
||||
|
|
@ -213,13 +252,10 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
(ProviderRequestType::BedrockConverseStream(_), _) => {
|
||||
todo!("Amazon Bedrock Stream to ChatCompletionsRequest conversion not implemented yet")
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Error types for provider operations
|
||||
#[derive(Debug)]
|
||||
pub struct ProviderRequestError {
|
||||
|
|
@ -235,19 +271,20 @@ impl fmt::Display for ProviderRequestError {
|
|||
|
||||
impl Error for ProviderRequestError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
self.source
|
||||
.as_ref()
|
||||
.map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::apis::anthropic::AnthropicApi::Messages;
|
||||
use crate::apis::openai::OpenAIApi::ChatCompletions;
|
||||
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
|
||||
use crate::apis::openai::{ChatCompletionsRequest};
|
||||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::apis::openai::OpenAIApi::ChatCompletions;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use serde_json::json;
|
||||
|
||||
|
|
@ -268,7 +305,7 @@ mod tests {
|
|||
ProviderRequestType::ChatCompletionsRequest(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.messages.len(), 2);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected ChatCompletionsRequest variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -291,7 +328,7 @@ mod tests {
|
|||
ProviderRequestType::MessagesRequest(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet");
|
||||
assert_eq!(r.messages.len(), 1);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected MessagesRequest variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -313,7 +350,7 @@ mod tests {
|
|||
ProviderRequestType::ChatCompletionsRequest(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.messages.len(), 2);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected ChatCompletionsRequest variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -337,7 +374,7 @@ mod tests {
|
|||
ProviderRequestType::ChatCompletionsRequest(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet");
|
||||
assert_eq!(r.messages.len(), 1);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected ChatCompletionsRequest variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -346,13 +383,15 @@ mod tests {
|
|||
fn test_v1_messages_to_v1_chat_completions_roundtrip() {
|
||||
let anthropic_req = AnthropicMessagesRequest {
|
||||
model: "claude-3-sonnet".to_string(),
|
||||
system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single("You are a helpful assistant".to_string())),
|
||||
messages: vec![
|
||||
crate::apis::anthropic::MessagesMessage {
|
||||
role: crate::apis::anthropic::MessagesRole::User,
|
||||
content: crate::apis::anthropic::MessagesMessageContent::Single("Hello!".to_string()),
|
||||
}
|
||||
],
|
||||
system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single(
|
||||
"You are a helpful assistant".to_string(),
|
||||
)),
|
||||
messages: vec![crate::apis::anthropic::MessagesMessage {
|
||||
role: crate::apis::anthropic::MessagesRole::User,
|
||||
content: crate::apis::anthropic::MessagesMessageContent::Single(
|
||||
"Hello!".to_string(),
|
||||
),
|
||||
}],
|
||||
max_tokens: 128,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
|
|
@ -368,16 +407,27 @@ mod tests {
|
|||
metadata: None,
|
||||
};
|
||||
|
||||
let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone()).expect("Anthropic->OpenAI conversion failed");
|
||||
let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req).expect("OpenAI->Anthropic conversion failed");
|
||||
let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone())
|
||||
.expect("Anthropic->OpenAI conversion failed");
|
||||
let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req)
|
||||
.expect("OpenAI->Anthropic conversion failed");
|
||||
|
||||
assert_eq!(anthropic_req.model, anthropic_req2.model);
|
||||
// Compare system prompt text if present
|
||||
assert_eq!(
|
||||
anthropic_req.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None }),
|
||||
anthropic_req2.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None })
|
||||
anthropic_req.system.as_ref().and_then(|s| match s {
|
||||
crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t),
|
||||
_ => None,
|
||||
}),
|
||||
anthropic_req2.system.as_ref().and_then(|s| match s {
|
||||
crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t),
|
||||
_ => None,
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
anthropic_req.messages[0].role,
|
||||
anthropic_req2.messages[0].role
|
||||
);
|
||||
assert_eq!(anthropic_req.messages[0].role, anthropic_req2.messages[0].role);
|
||||
// Compare message content text if present
|
||||
assert_eq!(
|
||||
anthropic_req.messages[0].content.extract_text(),
|
||||
|
|
@ -386,49 +436,54 @@ mod tests {
|
|||
assert_eq!(anthropic_req.max_tokens, anthropic_req2.max_tokens);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_v1_chat_completions_to_v1_messages_roundtrip() {
|
||||
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
|
||||
use crate::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent};
|
||||
#[test]
|
||||
fn test_v1_chat_completions_to_v1_messages_roundtrip() {
|
||||
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
|
||||
use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
|
||||
let openai_req = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text("You are a helpful assistant".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
],
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(1.0),
|
||||
max_tokens: Some(128),
|
||||
stream: Some(false),
|
||||
stop: Some(vec!["\n".to_string()]),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
..Default::default()
|
||||
};
|
||||
let openai_req = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text("You are a helpful assistant".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
],
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(1.0),
|
||||
max_tokens: Some(128),
|
||||
stream: Some(false),
|
||||
stop: Some(vec!["\n".to_string()]),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone()).expect("OpenAI->Anthropic conversion failed");
|
||||
let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req).expect("Anthropic->OpenAI conversion failed");
|
||||
let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone())
|
||||
.expect("OpenAI->Anthropic conversion failed");
|
||||
let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req)
|
||||
.expect("Anthropic->OpenAI conversion failed");
|
||||
|
||||
assert_eq!(openai_req.model, openai_req2.model);
|
||||
assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role);
|
||||
assert_eq!(openai_req.messages[0].content.extract_text(), openai_req2.messages[0].content.extract_text());
|
||||
// After roundtrip, deprecated max_tokens should be converted to max_completion_tokens
|
||||
let original_max_tokens = openai_req.max_completion_tokens.or(openai_req.max_tokens);
|
||||
let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens);
|
||||
assert_eq!(original_max_tokens, roundtrip_max_tokens);
|
||||
}
|
||||
assert_eq!(openai_req.model, openai_req2.model);
|
||||
assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role);
|
||||
assert_eq!(
|
||||
openai_req.messages[0].content.extract_text(),
|
||||
openai_req2.messages[0].content.extract_text()
|
||||
);
|
||||
// After roundtrip, deprecated max_tokens should be converted to max_completion_tokens
|
||||
let original_max_tokens = openai_req.max_completion_tokens.or(openai_req.max_tokens);
|
||||
let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens);
|
||||
assert_eq!(original_max_tokens, roundtrip_max_tokens);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,18 +1,18 @@
|
|||
use serde::Serialize;
|
||||
use std::convert::TryFrom;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use crate::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::providers::id::ProviderId;
|
||||
use crate::apis::sse::SseEvent;
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
use crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::apis::anthropic::MessagesResponse;
|
||||
use crate::apis::amazon_bedrock::ConverseResponse;
|
||||
use crate::apis::amazon_bedrock::ConverseStreamEvent;
|
||||
use crate::apis::anthropic::MessagesResponse;
|
||||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
use crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
use crate::apis::sse::SseEvent;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use crate::providers::id::ProviderId;
|
||||
|
||||
/// Trait for token usage information
|
||||
pub trait TokenUsage {
|
||||
|
|
@ -34,7 +34,6 @@ pub enum ProviderStreamResponseType {
|
|||
ChatCompletionsStreamResponse(ChatCompletionsStreamResponse),
|
||||
MessagesStreamEvent(MessagesStreamEvent),
|
||||
ConverseStreamEvent(ConverseStreamEvent),
|
||||
|
||||
}
|
||||
|
||||
pub trait ProviderResponse: Send + Sync {
|
||||
|
|
@ -43,7 +42,8 @@ pub trait ProviderResponse: Send + Sync {
|
|||
|
||||
/// Extract token counts for metrics
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
||||
self.usage()
|
||||
.map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -74,7 +74,6 @@ pub trait ProviderStreamResponse: Send + Sync {
|
|||
|
||||
/// Get event type for SSE streaming (used by Anthropic)
|
||||
fn event_type(&self) -> Option<&str>;
|
||||
|
||||
}
|
||||
|
||||
impl ProviderStreamResponse for ProviderStreamResponseType {
|
||||
|
|
@ -135,59 +134,97 @@ impl Into<String> for ProviderStreamResponseType {
|
|||
impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result<Self, Self::Error> {
|
||||
fn try_from(
|
||||
(bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId),
|
||||
) -> Result<Self, Self::Error> {
|
||||
let upstream_api = provider_id.compatible_api_for_client(client_api, false);
|
||||
match (&upstream_api, client_api) {
|
||||
(SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
(
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(resp))
|
||||
}
|
||||
(SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
(
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
let resp: MessagesResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::MessagesResponse(resp))
|
||||
}
|
||||
(SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
(
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
// Transform to OpenAI ChatCompletions format using the transformer
|
||||
let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into()
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
|
||||
let chat_resp: ChatCompletionsResponse =
|
||||
anthropic_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
|
||||
}
|
||||
(SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
(
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
// Transform to Anthropic Messages format using the transformer
|
||||
let messages_resp: MessagesResponse = openai_resp.try_into()
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
|
||||
let messages_resp: MessagesResponse = openai_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::MessagesResponse(messages_resp))
|
||||
}
|
||||
// Amazon Bedrock transformations
|
||||
(SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
(
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
// Transform to OpenAI ChatCompletions format using the transformer
|
||||
let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into()
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
|
||||
let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
|
||||
}
|
||||
(SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
(
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
// Transform to Anthropic Messages format using the transformer
|
||||
let messages_resp: MessagesResponse = bedrock_resp.try_into()
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
|
||||
let messages_resp: MessagesResponse = bedrock_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::MessagesResponse(messages_resp))
|
||||
}
|
||||
_ => {
|
||||
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Unsupported API combination for response transformation"))
|
||||
}
|
||||
_ => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"Unsupported API combination for response transformation",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -196,55 +233,86 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
|
|||
impl TryFrom<(&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)> for ProviderStreamResponseType {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from((bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result<Self, Self::Error> {
|
||||
fn try_from(
|
||||
(bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedUpstreamAPIs),
|
||||
) -> Result<Self, Self::Error> {
|
||||
// Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion
|
||||
if bytes == b"[DONE]" && matches!(client_api, SupportedAPIs::AnthropicMessagesAPI(_)) {
|
||||
return Ok(ProviderStreamResponseType::MessagesStreamEvent(
|
||||
crate::apis::anthropic::MessagesStreamEvent::MessageStop
|
||||
crate::apis::anthropic::MessagesStreamEvent::MessageStop,
|
||||
));
|
||||
}
|
||||
match (upstream_api, client_api) {
|
||||
// OpenAI upstream
|
||||
(SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
(
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
let resp = serde_json::from_slice(bytes)?;
|
||||
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(resp))
|
||||
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
|
||||
resp,
|
||||
))
|
||||
}
|
||||
(SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?;
|
||||
(
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse =
|
||||
serde_json::from_slice(bytes)?;
|
||||
let anthropic_resp = openai_resp.try_into()?;
|
||||
Ok(ProviderStreamResponseType::MessagesStreamEvent(anthropic_resp))
|
||||
Ok(ProviderStreamResponseType::MessagesStreamEvent(
|
||||
anthropic_resp,
|
||||
))
|
||||
}
|
||||
|
||||
// Anthropic upstream
|
||||
(SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
(
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
let resp = serde_json::from_slice(bytes)?;
|
||||
Ok(ProviderStreamResponseType::MessagesStreamEvent(resp))
|
||||
}
|
||||
(SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?;
|
||||
(
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent =
|
||||
serde_json::from_slice(bytes)?;
|
||||
let openai_resp = anthropic_resp.try_into()?;
|
||||
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(openai_resp))
|
||||
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
|
||||
openai_resp,
|
||||
))
|
||||
}
|
||||
|
||||
// Amazon Bedrock ConverseStream upstream
|
||||
(SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent = serde_json::from_slice(bytes)?;
|
||||
(
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent =
|
||||
serde_json::from_slice(bytes)?;
|
||||
let anthropic_resp = bedrock_resp.try_into()?;
|
||||
Ok(ProviderStreamResponseType::MessagesStreamEvent(anthropic_resp))
|
||||
}
|
||||
_ => {
|
||||
Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Unsupported API combination for response transformation").into())
|
||||
Ok(ProviderStreamResponseType::MessagesStreamEvent(
|
||||
anthropic_resp,
|
||||
))
|
||||
}
|
||||
_ => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"Unsupported API combination for response transformation",
|
||||
)
|
||||
.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// TryFrom implementation to convert raw bytes to SseEvent with parsed provider response
|
||||
impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from((sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result<Self, Self::Error> {
|
||||
fn try_from(
|
||||
(sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs),
|
||||
) -> Result<Self, Self::Error> {
|
||||
// Create a new transformed event based on the original
|
||||
let mut transformed_event = sse_event;
|
||||
|
||||
|
|
@ -252,7 +320,8 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
|
|||
if transformed_event.data.is_some() {
|
||||
let data_str = transformed_event.data.as_ref().unwrap();
|
||||
let data_bytes = data_str.as_bytes();
|
||||
let transformed_response: ProviderStreamResponseType = ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?;
|
||||
let transformed_response: ProviderStreamResponseType =
|
||||
ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?;
|
||||
|
||||
// Convert to SSE string explicitly to avoid type ambiguity
|
||||
let sse_string: String = transformed_response.clone().into();
|
||||
|
|
@ -261,7 +330,10 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
|
|||
}
|
||||
|
||||
match (client_api, upstream_api) {
|
||||
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => {
|
||||
(
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
if let Some(provider_response) = &transformed_event.provider_stream_response {
|
||||
if let Some(event_type) = provider_response.event_type() {
|
||||
// This ensures the required Anthropic sequence: MessageStart → ContentBlockStart → ContentBlockDelta(s)
|
||||
|
|
@ -280,19 +352,18 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
|
|||
// The sse_transform_buffer already contains the properly formatted MessageStart
|
||||
transformed_event.sse_transform_buffer = format!(
|
||||
"{}{}",
|
||||
transformed_event.sse_transform_buffer,
|
||||
content_block_start_sse,
|
||||
transformed_event.sse_transform_buffer, content_block_start_sse,
|
||||
);
|
||||
} else if event_type == "message_delta" {
|
||||
// Create ContentBlockStop event and format it using Into<String>
|
||||
let content_block_stop = MessagesStreamEvent::ContentBlockStop { index: 0 };
|
||||
let content_block_stop =
|
||||
MessagesStreamEvent::ContentBlockStop { index: 0 };
|
||||
let content_block_stop_sse: String = content_block_stop.into();
|
||||
|
||||
// Format as proper SSE: ContentBlockStop first, then MessageDelta
|
||||
transformed_event.sse_transform_buffer = format!(
|
||||
"{}{}",
|
||||
content_block_stop_sse,
|
||||
transformed_event.sse_transform_buffer
|
||||
content_block_stop_sse, transformed_event.sse_transform_buffer
|
||||
);
|
||||
}
|
||||
// For other event types, the sse_transform_buffer already has the correct format from Into<String>
|
||||
|
|
@ -301,7 +372,10 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
|
|||
// This handles cases where the transformation might not produce a valid event type
|
||||
}
|
||||
}
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => {
|
||||
(
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
if transformed_event.is_event_only() && transformed_event.event.is_some() {
|
||||
transformed_event.sse_transform_buffer = format!("\n"); // suppress the event upstream for OpenAI
|
||||
}
|
||||
|
|
@ -316,32 +390,56 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
|
|||
}
|
||||
|
||||
// TryFrom implementation to convert AWS Event Stream DecodedFrame to ProviderStreamResponseType
|
||||
impl TryFrom<(&aws_smithy_eventstream::frame::DecodedFrame, &SupportedAPIs, &SupportedUpstreamAPIs)> for ProviderStreamResponseType {
|
||||
impl
|
||||
TryFrom<(
|
||||
&aws_smithy_eventstream::frame::DecodedFrame,
|
||||
&SupportedAPIs,
|
||||
&SupportedUpstreamAPIs,
|
||||
)> for ProviderStreamResponseType
|
||||
{
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from((frame, client_api, upstream_api): (&aws_smithy_eventstream::frame::DecodedFrame, &SupportedAPIs, &SupportedUpstreamAPIs)) -> Result<Self, Self::Error> {
|
||||
fn try_from(
|
||||
(frame, client_api, upstream_api): (
|
||||
&aws_smithy_eventstream::frame::DecodedFrame,
|
||||
&SupportedAPIs,
|
||||
&SupportedUpstreamAPIs,
|
||||
),
|
||||
) -> Result<Self, Self::Error> {
|
||||
use aws_smithy_eventstream::frame::DecodedFrame;
|
||||
|
||||
match frame {
|
||||
DecodedFrame::Complete(_) => {
|
||||
// We have a complete frame - parse it based on upstream API
|
||||
match (upstream_api, client_api) {
|
||||
(SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
(
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
// Parse the DecodedFrame into ConverseStreamEvent
|
||||
let bedrock_event = crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
|
||||
let anthropic_event: crate::apis::anthropic::MessagesStreamEvent = bedrock_event.try_into()?;
|
||||
let bedrock_event =
|
||||
crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
|
||||
let anthropic_event: crate::apis::anthropic::MessagesStreamEvent =
|
||||
bedrock_event.try_into()?;
|
||||
|
||||
Ok(ProviderStreamResponseType::MessagesStreamEvent(anthropic_event))
|
||||
Ok(ProviderStreamResponseType::MessagesStreamEvent(
|
||||
anthropic_event,
|
||||
))
|
||||
}
|
||||
(SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
(
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
// Parse the DecodedFrame into ConverseStreamEvent
|
||||
let bedrock_event = crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
|
||||
let openai_event: crate::apis::openai::ChatCompletionsStreamResponse = bedrock_event.try_into()?;
|
||||
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(openai_event))
|
||||
}
|
||||
_ => {
|
||||
Err("Unsupported API combination for event-stream decoding".into())
|
||||
let bedrock_event =
|
||||
crate::apis::amazon_bedrock::ConverseStreamEvent::try_from(frame)?;
|
||||
let openai_event: crate::apis::openai::ChatCompletionsStreamResponse =
|
||||
bedrock_event.try_into()?;
|
||||
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
|
||||
openai_event,
|
||||
))
|
||||
}
|
||||
_ => Err("Unsupported API combination for event-stream decoding".into()),
|
||||
}
|
||||
}
|
||||
DecodedFrame::Incomplete => {
|
||||
|
|
@ -351,15 +449,12 @@ impl TryFrom<(&aws_smithy_eventstream::frame::DecodedFrame, &SupportedAPIs, &Sup
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ProviderResponseError {
|
||||
pub message: String,
|
||||
pub source: Option<Box<dyn Error + Send + Sync>>,
|
||||
}
|
||||
|
||||
|
||||
impl fmt::Display for ProviderResponseError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Provider response error: {}", self.message)
|
||||
|
|
@ -368,22 +463,23 @@ impl fmt::Display for ProviderResponseError {
|
|||
|
||||
impl Error for ProviderResponseError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
self.source
|
||||
.as_ref()
|
||||
.map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::apis::sse::SseStreamIter;
|
||||
use crate::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::sse::SseStreamIter;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::providers::id::ProviderId;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use serde_json::json;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_openai_response_from_bytes() {
|
||||
let resp = json!({
|
||||
|
|
@ -402,13 +498,17 @@ mod tests {
|
|||
"system_fingerprint": null
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::OpenAI));
|
||||
let result = ProviderResponseType::try_from((
|
||||
bytes.as_slice(),
|
||||
&SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
&ProviderId::OpenAI,
|
||||
));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::ChatCompletionsResponse(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.choices.len(), 1);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected ChatCompletionsResponse variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -427,13 +527,17 @@ mod tests {
|
|||
"usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 }
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::Anthropic));
|
||||
let result = ProviderResponseType::try_from((
|
||||
bytes.as_slice(),
|
||||
&SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
|
||||
&ProviderId::Anthropic,
|
||||
));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::MessagesResponse(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet-20240229");
|
||||
assert_eq!(r.content.len(), 1);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected MessagesResponse variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -457,14 +561,18 @@ mod tests {
|
|||
"usage": { "prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35 }
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI));
|
||||
let result = ProviderResponseType::try_from((
|
||||
bytes.as_slice(),
|
||||
&SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
|
||||
&ProviderId::OpenAI,
|
||||
));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::MessagesResponse(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.usage.input_tokens, 10);
|
||||
assert_eq!(r.usage.output_tokens, 25);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected MessagesResponse variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -495,14 +603,18 @@ mod tests {
|
|||
}
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::Anthropic));
|
||||
let result = ProviderResponseType::try_from((
|
||||
bytes.as_slice(),
|
||||
&SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
&ProviderId::Anthropic,
|
||||
));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::ChatCompletionsResponse(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet-20240229");
|
||||
assert_eq!(r.usage.prompt_tokens, 10);
|
||||
assert_eq!(r.usage.completion_tokens, 25);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected ChatCompletionsResponse variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -514,11 +626,17 @@ mod tests {
|
|||
let event: Result<SseEvent, _> = line.parse();
|
||||
assert!(event.is_ok());
|
||||
let event = event.unwrap();
|
||||
assert_eq!(event.data, Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string()));
|
||||
assert_eq!(
|
||||
event.data,
|
||||
Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string())
|
||||
);
|
||||
|
||||
// Test conversion back to line using Display trait
|
||||
let wire_format = event.to_string();
|
||||
assert_eq!(wire_format, "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n");
|
||||
assert_eq!(
|
||||
wire_format,
|
||||
"data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n"
|
||||
);
|
||||
|
||||
// Test [DONE] marker - should be valid SSE event
|
||||
let done_line = "data: [DONE]";
|
||||
|
|
@ -550,10 +668,12 @@ mod tests {
|
|||
event: None,
|
||||
raw_line: r#"data: {"id":"test","object":"chat.completion.chunk"}
|
||||
|
||||
"#.to_string(),
|
||||
"#
|
||||
.to_string(),
|
||||
sse_transform_buffer: r#"data: {"id":"test","object":"chat.completion.chunk"}
|
||||
|
||||
"#.to_string(),
|
||||
"#
|
||||
.to_string(),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
|
||||
|
|
@ -590,7 +710,8 @@ mod tests {
|
|||
data: Some(r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string()),
|
||||
event: Some("content_block_delta".to_string()),
|
||||
raw_line: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
|
||||
sse_transform_buffer: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
|
||||
sse_transform_buffer: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#
|
||||
.to_string(),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
assert!(!normal_event.should_skip());
|
||||
|
|
@ -616,7 +737,7 @@ mod tests {
|
|||
"data: {\"type\": \"ping\"}".to_string(), // This should be filtered out
|
||||
"data: {\"id\": \"msg2\", \"object\": \"chat.completion.chunk\"}".to_string(),
|
||||
"data: {\"type\": \"ping\"}".to_string(), // This should be filtered out
|
||||
"data: [DONE]".to_string(), // This should end the stream
|
||||
"data: [DONE]".to_string(), // This should end the stream
|
||||
];
|
||||
|
||||
let mut iter = SseStreamIter::new(test_lines.into_iter());
|
||||
|
|
@ -684,13 +805,15 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_provider_stream_response_event_type() {
|
||||
use crate::apis::anthropic::{MessagesStreamEvent, MessagesContentDelta};
|
||||
use crate::apis::anthropic::{MessagesContentDelta, MessagesStreamEvent};
|
||||
use crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
|
||||
// Test Anthropic event type
|
||||
let anthropic_event = MessagesStreamEvent::ContentBlockDelta {
|
||||
index: 0,
|
||||
delta: MessagesContentDelta::TextDelta { text: "Hello".to_string() },
|
||||
delta: MessagesContentDelta::TextDelta {
|
||||
text: "Hello".to_string(),
|
||||
},
|
||||
};
|
||||
let provider_type = ProviderStreamResponseType::MessagesStreamEvent(anthropic_event);
|
||||
assert_eq!(provider_type.event_type(), Some("content_block_delta"));
|
||||
|
|
@ -717,15 +840,24 @@ mod tests {
|
|||
// Test that [DONE] marker is properly converted to MessageStop in the transformation layer
|
||||
let done_bytes = b"[DONE]";
|
||||
let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::openai::OpenAIApi::ChatCompletions);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(
|
||||
crate::apis::openai::OpenAIApi::ChatCompletions,
|
||||
);
|
||||
|
||||
let result = ProviderStreamResponseType::try_from((done_bytes.as_slice(), &client_api, &upstream_api));
|
||||
let result = ProviderStreamResponseType::try_from((
|
||||
done_bytes.as_slice(),
|
||||
&client_api,
|
||||
&upstream_api,
|
||||
));
|
||||
assert!(result.is_ok());
|
||||
|
||||
if let Ok(ProviderStreamResponseType::MessagesStreamEvent(event)) = result {
|
||||
// Verify it's a MessageStop event
|
||||
assert_eq!(event.event_type(), Some("message_stop"));
|
||||
assert!(matches!(event, crate::apis::anthropic::MessagesStreamEvent::MessageStop));
|
||||
assert!(matches!(
|
||||
event,
|
||||
crate::apis::anthropic::MessagesStreamEvent::MessageStop
|
||||
));
|
||||
} else {
|
||||
panic!("Expected MessagesStreamEvent::MessageStop");
|
||||
}
|
||||
|
|
@ -747,7 +879,10 @@ mod tests {
|
|||
// This signals the caller to wait for more data
|
||||
let result = decoder.decode_frame();
|
||||
assert!(result.is_some());
|
||||
assert!(matches!(result.unwrap(), aws_smithy_eventstream::frame::DecodedFrame::Incomplete));
|
||||
assert!(matches!(
|
||||
result.unwrap(),
|
||||
aws_smithy_eventstream::frame::DecodedFrame::Incomplete
|
||||
));
|
||||
|
||||
// Verify we can still access the buffer
|
||||
assert!(decoder.has_remaining());
|
||||
|
|
@ -760,8 +895,8 @@ mod tests {
|
|||
use std::path::PathBuf;
|
||||
|
||||
// Read the actual response.hex file from tests/e2e directory
|
||||
let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("../../tests/e2e/response.hex");
|
||||
let test_file =
|
||||
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
|
||||
|
||||
// Only run this test if the file exists
|
||||
if !test_file.exists() {
|
||||
|
|
@ -782,7 +917,8 @@ mod tests {
|
|||
frame_count += 1;
|
||||
|
||||
// Verify we can access headers
|
||||
let event_type = message.headers()
|
||||
let event_type = message
|
||||
.headers()
|
||||
.iter()
|
||||
.find(|h| h.name().as_str() == ":event-type")
|
||||
.and_then(|h| h.value().as_string().ok());
|
||||
|
|
@ -811,8 +947,8 @@ mod tests {
|
|||
use std::path::PathBuf;
|
||||
|
||||
// Read the actual response.hex file
|
||||
let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("../../tests/e2e/response.hex");
|
||||
let test_file =
|
||||
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
|
||||
|
||||
if !test_file.exists() {
|
||||
println!("Skipping test - response.hex not found");
|
||||
|
|
@ -863,7 +999,10 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
assert!(total_frames > 0, "Should have decoded frames from chunked data");
|
||||
assert!(
|
||||
total_frames > 0,
|
||||
"Should have decoded frames from chunked data"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -872,7 +1011,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // Run with: cargo test -- --ignored --nocapture
|
||||
#[ignore] // Run with: cargo test -- --ignored --nocapture
|
||||
fn test_bedrock_decoded_frame_to_provider_response_verbose() {
|
||||
test_bedrock_conversion(true);
|
||||
}
|
||||
|
|
@ -883,7 +1022,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // Run with: cargo test -- --ignored --nocapture
|
||||
#[ignore] // Run with: cargo test -- --ignored --nocapture
|
||||
fn test_bedrock_decoded_frame_with_tool_use_verbose() {
|
||||
test_bedrock_conversion_with_tools(true);
|
||||
}
|
||||
|
|
@ -894,8 +1033,8 @@ mod tests {
|
|||
use std::path::PathBuf;
|
||||
|
||||
// Read the actual response.hex file from tests/e2e directory
|
||||
let test_file = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("../../tests/e2e/response.hex");
|
||||
let test_file =
|
||||
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/e2e/response.hex");
|
||||
|
||||
// Only run this test if the file exists
|
||||
if !test_file.exists() {
|
||||
|
|
@ -908,8 +1047,11 @@ mod tests {
|
|||
|
||||
let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
|
||||
|
||||
let client_api = SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream);
|
||||
let client_api =
|
||||
SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
|
||||
);
|
||||
|
||||
let mut conversion_count = 0;
|
||||
let mut message_start_seen = false;
|
||||
|
|
@ -919,14 +1061,18 @@ mod tests {
|
|||
match decoder.decode_frame() {
|
||||
Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => {
|
||||
// Convert DecodedFrame to ProviderStreamResponseType
|
||||
let result = ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api));
|
||||
let result =
|
||||
ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api));
|
||||
|
||||
match result {
|
||||
Ok(provider_response) => {
|
||||
conversion_count += 1;
|
||||
|
||||
// Verify we got a MessagesStreamEvent
|
||||
assert!(matches!(provider_response, ProviderStreamResponseType::MessagesStreamEvent(_)));
|
||||
assert!(matches!(
|
||||
provider_response,
|
||||
ProviderStreamResponseType::MessagesStreamEvent(_)
|
||||
));
|
||||
|
||||
if verbose {
|
||||
// Print the SSE string output
|
||||
|
|
@ -935,8 +1081,13 @@ mod tests {
|
|||
}
|
||||
|
||||
// Check for MessageStart event
|
||||
if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = provider_response {
|
||||
if matches!(event, crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. }) {
|
||||
if let ProviderStreamResponseType::MessagesStreamEvent(ref event) =
|
||||
provider_response
|
||||
{
|
||||
if matches!(
|
||||
event,
|
||||
crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. }
|
||||
) {
|
||||
message_start_seen = true;
|
||||
}
|
||||
}
|
||||
|
|
@ -956,7 +1107,10 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
assert!(conversion_count > 0, "Should have converted at least one frame");
|
||||
assert!(
|
||||
conversion_count > 0,
|
||||
"Should have converted at least one frame"
|
||||
);
|
||||
assert!(message_start_seen, "Should have seen MessageStart event");
|
||||
}
|
||||
|
||||
|
|
@ -980,8 +1134,11 @@ mod tests {
|
|||
|
||||
let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
|
||||
|
||||
let client_api = SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream);
|
||||
let client_api =
|
||||
SupportedAPIs::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
|
||||
);
|
||||
|
||||
let mut conversion_count = 0;
|
||||
let mut message_start_seen = false;
|
||||
|
|
@ -993,14 +1150,18 @@ mod tests {
|
|||
match decoder.decode_frame() {
|
||||
Some(frame @ aws_smithy_eventstream::frame::DecodedFrame::Complete(_)) => {
|
||||
// Convert DecodedFrame to ProviderStreamResponseType
|
||||
let result = ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api));
|
||||
let result =
|
||||
ProviderStreamResponseType::try_from((&frame, &client_api, &upstream_api));
|
||||
|
||||
match result {
|
||||
Ok(provider_response) => {
|
||||
conversion_count += 1;
|
||||
|
||||
// Verify we got a MessagesStreamEvent
|
||||
assert!(matches!(provider_response, ProviderStreamResponseType::MessagesStreamEvent(_)));
|
||||
assert!(matches!(
|
||||
provider_response,
|
||||
ProviderStreamResponseType::MessagesStreamEvent(_)
|
||||
));
|
||||
|
||||
if verbose {
|
||||
// Print the SSE string output
|
||||
|
|
@ -1009,7 +1170,9 @@ mod tests {
|
|||
}
|
||||
|
||||
// Check for specific events related to tool use
|
||||
if let ProviderStreamResponseType::MessagesStreamEvent(ref event) = provider_response {
|
||||
if let ProviderStreamResponseType::MessagesStreamEvent(ref event) =
|
||||
provider_response
|
||||
{
|
||||
match event {
|
||||
crate::apis::anthropic::MessagesStreamEvent::MessageStart { .. } => {
|
||||
message_start_seen = true;
|
||||
|
|
@ -1041,16 +1204,25 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
assert!(conversion_count > 0, "Should have converted at least one frame");
|
||||
assert!(
|
||||
conversion_count > 0,
|
||||
"Should have converted at least one frame"
|
||||
);
|
||||
assert!(message_start_seen, "Should have seen MessageStart event");
|
||||
assert!(content_block_start_seen, "Should have seen ContentBlockStart event for tool use");
|
||||
assert!(content_block_delta_tool_use_seen, "Should have seen ContentBlockDelta with ToolUseDelta");
|
||||
assert!(
|
||||
content_block_start_seen,
|
||||
"Should have seen ContentBlockStart event for tool use"
|
||||
);
|
||||
assert!(
|
||||
content_block_delta_tool_use_seen,
|
||||
"Should have seen ContentBlockDelta with ToolUseDelta"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_event_transformation_openai_to_anthropic_message_start() {
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
|
||||
// Create an OpenAI stream response that represents a role start (which becomes message_start in Anthropic)
|
||||
let openai_stream_chunk = json!({
|
||||
|
|
@ -1085,8 +1257,14 @@ mod tests {
|
|||
|
||||
// Verify the transformation includes both message_start and content_block_start
|
||||
let buffer = transformed.sse_transform_buffer;
|
||||
assert!(buffer.contains("event: message_start"), "Should contain message_start event");
|
||||
assert!(buffer.contains("event: content_block_start"), "Should contain content_block_start event");
|
||||
assert!(
|
||||
buffer.contains("event: message_start"),
|
||||
"Should contain message_start event"
|
||||
);
|
||||
assert!(
|
||||
buffer.contains("event: content_block_start"),
|
||||
"Should contain content_block_start event"
|
||||
);
|
||||
|
||||
// Verify proper SSE format with event lines before data lines
|
||||
assert!(buffer.find("event: message_start").unwrap() < buffer.find("data:").unwrap());
|
||||
|
|
@ -1095,8 +1273,8 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_sse_event_transformation_openai_to_anthropic_message_delta() {
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
|
||||
// Create an OpenAI stream response with finish_reason (which becomes message_delta in Anthropic)
|
||||
let openai_stream_chunk = json!({
|
||||
|
|
@ -1136,19 +1314,28 @@ mod tests {
|
|||
|
||||
// Verify the transformation includes both content_block_stop and message_delta
|
||||
let buffer = transformed.sse_transform_buffer;
|
||||
assert!(buffer.contains("event: content_block_stop"), "Should contain content_block_stop event");
|
||||
assert!(buffer.contains("event: message_delta"), "Should contain message_delta event");
|
||||
assert!(
|
||||
buffer.contains("event: content_block_stop"),
|
||||
"Should contain content_block_stop event"
|
||||
);
|
||||
assert!(
|
||||
buffer.contains("event: message_delta"),
|
||||
"Should contain message_delta event"
|
||||
);
|
||||
|
||||
// Verify content_block_stop comes before message_delta
|
||||
let stop_pos = buffer.find("content_block_stop").unwrap();
|
||||
let delta_pos = buffer.find("message_delta").unwrap();
|
||||
assert!(stop_pos < delta_pos, "content_block_stop should come before message_delta");
|
||||
assert!(
|
||||
stop_pos < delta_pos,
|
||||
"content_block_stop should come before message_delta"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_event_transformation_openai_to_anthropic_content_delta() {
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
|
||||
// Create an OpenAI stream response with content (which becomes content_block_delta in Anthropic)
|
||||
let openai_stream_chunk = json!({
|
||||
|
|
@ -1183,9 +1370,18 @@ mod tests {
|
|||
|
||||
// Verify the transformation is a content_block_delta (no extra events injected)
|
||||
let buffer = transformed.sse_transform_buffer;
|
||||
assert!(buffer.contains("event: content_block_delta"), "Should contain content_block_delta event");
|
||||
assert!(!buffer.contains("content_block_start"), "Should not inject content_block_start for content delta");
|
||||
assert!(!buffer.contains("content_block_stop"), "Should not inject content_block_stop for content delta");
|
||||
assert!(
|
||||
buffer.contains("event: content_block_delta"),
|
||||
"Should contain content_block_delta event"
|
||||
);
|
||||
assert!(
|
||||
!buffer.contains("content_block_start"),
|
||||
"Should not inject content_block_start for content delta"
|
||||
);
|
||||
assert!(
|
||||
!buffer.contains("content_block_stop"),
|
||||
"Should not inject content_block_stop for content delta"
|
||||
);
|
||||
|
||||
// Verify the content is preserved
|
||||
assert!(buffer.contains("Hello"), "Should preserve the content text");
|
||||
|
|
@ -1193,8 +1389,8 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_sse_event_transformation_anthropic_to_openai_suppresses_event_lines() {
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
|
||||
// Create an Anthropic event-only SSE line (no data)
|
||||
let sse_event = SseEvent {
|
||||
|
|
@ -1215,14 +1411,20 @@ mod tests {
|
|||
let transformed = result.unwrap();
|
||||
|
||||
// Verify the event line is suppressed (replaced with just newline)
|
||||
assert_eq!(transformed.sse_transform_buffer, "\n", "Event-only lines should be suppressed to newline for OpenAI");
|
||||
assert!(transformed.is_event_only(), "Should still be marked as event-only");
|
||||
assert_eq!(
|
||||
transformed.sse_transform_buffer, "\n",
|
||||
"Event-only lines should be suppressed to newline for OpenAI"
|
||||
);
|
||||
assert!(
|
||||
transformed.is_event_only(),
|
||||
"Should still be marked as event-only"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_event_transformation_anthropic_to_openai_preserves_data() {
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
|
||||
// Create an Anthropic message_start event with data
|
||||
let anthropic_event = json!({
|
||||
|
|
@ -1258,7 +1460,10 @@ mod tests {
|
|||
// Verify data is transformed to OpenAI format
|
||||
let buffer = transformed.sse_transform_buffer;
|
||||
assert!(buffer.starts_with("data: "), "Should have data: prefix");
|
||||
assert!(!buffer.contains("event:"), "Should not have event: lines for OpenAI");
|
||||
assert!(
|
||||
!buffer.contains("event:"),
|
||||
"Should not have event: lines for OpenAI"
|
||||
);
|
||||
|
||||
// Verify provider response was parsed
|
||||
assert!(transformed.provider_stream_response.is_some());
|
||||
|
|
@ -1310,8 +1515,8 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_sse_event_transformation_preserves_provider_response() {
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
|
||||
// Create an OpenAI stream response
|
||||
let openai_stream_chunk = json!({
|
||||
|
|
@ -1344,11 +1549,17 @@ mod tests {
|
|||
let transformed = result.unwrap();
|
||||
|
||||
// Verify provider_stream_response is populated
|
||||
assert!(transformed.provider_stream_response.is_some(), "Should parse and store provider response");
|
||||
assert!(
|
||||
transformed.provider_stream_response.is_some(),
|
||||
"Should parse and store provider response"
|
||||
);
|
||||
|
||||
// Verify we can access the provider response
|
||||
let provider_response = transformed.provider_response();
|
||||
assert!(provider_response.is_ok(), "Should be able to access provider response");
|
||||
assert!(
|
||||
provider_response.is_ok(),
|
||||
"Should be able to access provider response"
|
||||
);
|
||||
|
||||
// Verify the content delta is accessible
|
||||
let content = provider_response.unwrap().content_delta();
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use serde_json::Value;
|
||||
use crate::apis::anthropic::{MessagesContentBlock,MessagesImageSource};
|
||||
use crate::apis::anthropic::{MessagesContentBlock, MessagesImageSource};
|
||||
use crate::apis::openai::{ContentPart, FunctionCall, ImageUrl, Message, MessageContent, ToolCall};
|
||||
use crate::clients::TransformError;
|
||||
use serde_json::Value;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
pub trait ExtractText {
|
||||
|
|
@ -11,12 +11,17 @@ pub trait ExtractText {
|
|||
/// Trait for utility functions on content collections
|
||||
pub trait ContentUtils<T> {
|
||||
fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError>;
|
||||
fn split_for_openai(&self) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
|
||||
fn split_for_openai(
|
||||
&self,
|
||||
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
|
||||
}
|
||||
|
||||
/// Helper to create a current unix timestamp
|
||||
pub fn current_timestamp() -> u64 {
|
||||
SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs()
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
// Content Utilities
|
||||
|
|
@ -26,24 +31,36 @@ impl ContentUtils<ToolCall> for Vec<MessagesContentBlock> {
|
|||
|
||||
for block in self {
|
||||
match block {
|
||||
MessagesContentBlock::ToolUse { id, name, input, .. } |
|
||||
MessagesContentBlock::ServerToolUse { id, name, input } |
|
||||
MessagesContentBlock::McpToolUse { id, name, input } => {
|
||||
MessagesContentBlock::ToolUse {
|
||||
id, name, input, ..
|
||||
}
|
||||
| MessagesContentBlock::ServerToolUse { id, name, input }
|
||||
| MessagesContentBlock::McpToolUse { id, name, input } => {
|
||||
let arguments = serde_json::to_string(&input)?;
|
||||
tool_calls.push(ToolCall {
|
||||
id: id.clone(),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall { name: name.clone(), arguments },
|
||||
function: FunctionCall {
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(if tool_calls.is_empty() { None } else { Some(tool_calls) })
|
||||
Ok(if tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tool_calls)
|
||||
})
|
||||
}
|
||||
|
||||
fn split_for_openai(&self) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError> {
|
||||
fn split_for_openai(
|
||||
&self,
|
||||
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>
|
||||
{
|
||||
let mut content_parts = Vec::new();
|
||||
let mut tool_calls = Vec::new();
|
||||
let mut tool_results = Vec::new();
|
||||
|
|
@ -62,25 +79,55 @@ impl ContentUtils<ToolCall> for Vec<MessagesContentBlock> {
|
|||
},
|
||||
});
|
||||
}
|
||||
MessagesContentBlock::ToolUse { id, name, input, .. } |
|
||||
MessagesContentBlock::ServerToolUse { id, name, input } |
|
||||
MessagesContentBlock::McpToolUse { id, name, input } => {
|
||||
MessagesContentBlock::ToolUse {
|
||||
id, name, input, ..
|
||||
}
|
||||
| MessagesContentBlock::ServerToolUse { id, name, input }
|
||||
| MessagesContentBlock::McpToolUse { id, name, input } => {
|
||||
let arguments = serde_json::to_string(&input)?;
|
||||
tool_calls.push(ToolCall {
|
||||
id: id.clone(),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall { name: name.clone(), arguments },
|
||||
function: FunctionCall {
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
MessagesContentBlock::ToolResult { tool_use_id, content, is_error, .. } => {
|
||||
MessagesContentBlock::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
..
|
||||
} => {
|
||||
let result_text = content.extract_text();
|
||||
tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false)));
|
||||
tool_results.push((
|
||||
tool_use_id.clone(),
|
||||
result_text,
|
||||
is_error.unwrap_or(false),
|
||||
));
|
||||
}
|
||||
MessagesContentBlock::WebSearchToolResult { tool_use_id, content, is_error } |
|
||||
MessagesContentBlock::CodeExecutionToolResult { tool_use_id, content, is_error } |
|
||||
MessagesContentBlock::McpToolResult { tool_use_id, content, is_error } => {
|
||||
MessagesContentBlock::WebSearchToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
}
|
||||
| MessagesContentBlock::CodeExecutionToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
}
|
||||
| MessagesContentBlock::McpToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
} => {
|
||||
let result_text = content.extract_text();
|
||||
tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false)));
|
||||
tool_results.push((
|
||||
tool_use_id.clone(),
|
||||
result_text,
|
||||
is_error.unwrap_or(false),
|
||||
));
|
||||
}
|
||||
_ => {
|
||||
// Skip unsupported content types
|
||||
|
|
@ -122,29 +169,41 @@ fn convert_image_url_to_source(image_url: &ImageUrl) -> MessagesImageSource {
|
|||
data: data.to_string(),
|
||||
}
|
||||
} else {
|
||||
MessagesImageSource::Url { url: image_url.url.clone() }
|
||||
MessagesImageSource::Url {
|
||||
url: image_url.url.clone(),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MessagesImageSource::Url { url: image_url.url.clone() }
|
||||
MessagesImageSource::Url {
|
||||
url: image_url.url.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert OpenAI message to Anthropic content blocks
|
||||
pub fn convert_openai_message_to_anthropic_content(message: &Message) -> Result<Vec<MessagesContentBlock>, TransformError> {
|
||||
pub fn convert_openai_message_to_anthropic_content(
|
||||
message: &Message,
|
||||
) -> Result<Vec<MessagesContentBlock>, TransformError> {
|
||||
let mut blocks = Vec::new();
|
||||
|
||||
// Handle regular content
|
||||
match &message.content {
|
||||
MessageContent::Text(text) => {
|
||||
if !text.is_empty() {
|
||||
blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None });
|
||||
blocks.push(MessagesContentBlock::Text {
|
||||
text: text.clone(),
|
||||
cache_control: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
MessageContent::Parts(parts) => {
|
||||
for part in parts {
|
||||
match part {
|
||||
ContentPart::Text { text } => {
|
||||
blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None });
|
||||
blocks.push(MessagesContentBlock::Text {
|
||||
text: text.clone(),
|
||||
cache_control: None,
|
||||
});
|
||||
}
|
||||
ContentPart::ImageUrl { image_url } => {
|
||||
let source = convert_image_url_to_source(image_url);
|
||||
|
|
|
|||
|
|
@ -8,14 +8,14 @@
|
|||
//! by the gateway, but the external API surface remains these two standard formats.
|
||||
//! The transformations are split into logical modules for maintainability.
|
||||
|
||||
pub mod lib;
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
pub mod lib;
|
||||
|
||||
// Re-export commonly used items for convenience
|
||||
pub use lib::*;
|
||||
pub use request::*;
|
||||
pub use response::*;
|
||||
pub use lib::*;
|
||||
|
||||
// ============================================================================
|
||||
// CONSTANTS
|
||||
|
|
|
|||
|
|
@ -1,14 +1,21 @@
|
|||
use crate::transforms::lib::*;
|
||||
use crate::clients::TransformError;
|
||||
use crate::apis::anthropic::{MessagesMessage, MessagesRequest, MessagesMessageContent, MessagesRole, MessagesStopReason, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, MessagesUsage, MessagesSystemPrompt, ToolResultContent};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolCall, ToolChoice, ToolChoiceType, Function, FunctionChoice,FinishReason, Usage, ContentPart};
|
||||
use crate::apis::amazon_bedrock::{
|
||||
ConverseRequest, SystemContentBlock, InferenceConfiguration, ToolConfiguration,
|
||||
Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolInputSchema, ToolSpecDefinition,
|
||||
AutoChoice, AnyChoice, ToolChoiceSpec,
|
||||
Message as BedrockMessage, ConversationRole, ContentBlock,
|
||||
ToolUseBlock, ToolResultBlock, ToolResultContentBlock, ToolResultStatus, ImageBlock, ImageSource
|
||||
AnyChoice, AutoChoice, ContentBlock, ConversationRole, ConverseRequest, ImageBlock,
|
||||
ImageSource, InferenceConfiguration, Message as BedrockMessage, SystemContentBlock,
|
||||
Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolChoiceSpec, ToolConfiguration,
|
||||
ToolInputSchema, ToolResultBlock, ToolResultContentBlock, ToolResultStatus, ToolSpecDefinition,
|
||||
ToolUseBlock,
|
||||
};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole, MessagesStopReason,
|
||||
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, MessagesUsage,
|
||||
ToolResultContent,
|
||||
};
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsRequest, ContentPart, FinishReason, Function, FunctionChoice, Message,
|
||||
MessageContent, Role, Tool, ToolCall, ToolChoice, ToolChoiceType, Usage,
|
||||
};
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::*;
|
||||
|
||||
type AnthropicMessagesRequest = MessagesRequest;
|
||||
|
||||
|
|
@ -32,7 +39,8 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
|
|||
|
||||
// Convert tools and tool choice
|
||||
let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
|
||||
let (openai_tool_choice, parallel_tool_calls) = convert_anthropic_tool_choice(req.tool_choice);
|
||||
let (openai_tool_choice, parallel_tool_calls) =
|
||||
convert_anthropic_tool_choice(req.tool_choice);
|
||||
|
||||
let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: req.model,
|
||||
|
|
@ -91,7 +99,8 @@ impl TryFrom<AnthropicMessagesRequest> for ConverseRequest {
|
|||
// Convert tools and tool choice to ToolConfiguration
|
||||
let tool_config = if req.tools.is_some() || req.tool_choice.is_some() {
|
||||
let tools = req.tools.map(|anthropic_tools| {
|
||||
anthropic_tools.into_iter()
|
||||
anthropic_tools
|
||||
.into_iter()
|
||||
.map(|tool| BedrockTool::ToolSpec {
|
||||
tool_spec: ToolSpecDefinition {
|
||||
name: tool.name,
|
||||
|
|
@ -106,25 +115,28 @@ impl TryFrom<AnthropicMessagesRequest> for ConverseRequest {
|
|||
|
||||
let tool_choice = req.tool_choice.map(|choice| {
|
||||
match choice.kind {
|
||||
MessagesToolChoiceType::Auto => BedrockToolChoice::Auto { auto: AutoChoice {} },
|
||||
MessagesToolChoiceType::Auto => BedrockToolChoice::Auto {
|
||||
auto: AutoChoice {},
|
||||
},
|
||||
MessagesToolChoiceType::Any => BedrockToolChoice::Any { any: AnyChoice {} },
|
||||
MessagesToolChoiceType::None => BedrockToolChoice::Auto { auto: AutoChoice {} }, // Bedrock doesn't have explicit "none"
|
||||
MessagesToolChoiceType::None => BedrockToolChoice::Auto {
|
||||
auto: AutoChoice {},
|
||||
}, // Bedrock doesn't have explicit "none"
|
||||
MessagesToolChoiceType::Tool => {
|
||||
if let Some(name) = choice.name {
|
||||
BedrockToolChoice::Tool {
|
||||
tool: ToolChoiceSpec { name }
|
||||
tool: ToolChoiceSpec { name },
|
||||
}
|
||||
} else {
|
||||
BedrockToolChoice::Auto { auto: AutoChoice {} }
|
||||
BedrockToolChoice::Auto {
|
||||
auto: AutoChoice {},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Some(ToolConfiguration {
|
||||
tools,
|
||||
tool_choice,
|
||||
})
|
||||
Some(ToolConfiguration { tools, tool_choice })
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
|
@ -147,7 +159,6 @@ impl TryFrom<AnthropicMessagesRequest> for ConverseRequest {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// Message Conversions
|
||||
impl TryFrom<MessagesMessage> for Vec<Message> {
|
||||
type Error = TransformError;
|
||||
|
|
@ -186,7 +197,11 @@ impl TryFrom<MessagesMessage> for Vec<Message> {
|
|||
role: message.role.into(),
|
||||
content,
|
||||
name: None,
|
||||
tool_calls: if tool_calls.is_empty() { None } else { Some(tool_calls) },
|
||||
tool_calls: if tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tool_calls)
|
||||
},
|
||||
tool_call_id: None,
|
||||
};
|
||||
result.push(main_message);
|
||||
|
|
@ -198,7 +213,6 @@ impl TryFrom<MessagesMessage> for Vec<Message> {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// Role Conversions
|
||||
impl Into<Role> for MessagesRole {
|
||||
fn into(self) -> Role {
|
||||
|
|
@ -237,9 +251,7 @@ impl Into<Message> for MessagesSystemPrompt {
|
|||
fn into(self) -> Message {
|
||||
let system_content = match self {
|
||||
MessagesSystemPrompt::Single(text) => MessageContent::Text(text),
|
||||
MessagesSystemPrompt::Blocks(blocks) => {
|
||||
MessageContent::Text(blocks.extract_text())
|
||||
}
|
||||
MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()),
|
||||
};
|
||||
|
||||
Message {
|
||||
|
|
@ -255,7 +267,8 @@ impl Into<Message> for MessagesSystemPrompt {
|
|||
//Utility Functions
|
||||
/// Convert Anthropic tools to OpenAI format
|
||||
fn convert_anthropic_tools(tools: Vec<MessagesTool>) -> Vec<Tool> {
|
||||
tools.into_iter()
|
||||
tools
|
||||
.into_iter()
|
||||
.map(|tool| Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
|
|
@ -269,7 +282,9 @@ fn convert_anthropic_tools(tools: Vec<MessagesTool>) -> Vec<Tool> {
|
|||
}
|
||||
|
||||
/// Convert Anthropic tool choice to OpenAI format
|
||||
fn convert_anthropic_tool_choice(tool_choice: Option<MessagesToolChoice>) -> (Option<ToolChoice>, Option<bool>) {
|
||||
fn convert_anthropic_tool_choice(
|
||||
tool_choice: Option<MessagesToolChoice>,
|
||||
) -> (Option<ToolChoice>, Option<bool>) {
|
||||
match tool_choice {
|
||||
Some(choice) => {
|
||||
let openai_choice = match choice.kind {
|
||||
|
|
@ -290,12 +305,15 @@ fn convert_anthropic_tool_choice(tool_choice: Option<MessagesToolChoice>) -> (Op
|
|||
let parallel = choice.disable_parallel_tool_use.map(|disable| !disable);
|
||||
(Some(openai_choice), parallel)
|
||||
}
|
||||
None => (None, None)
|
||||
None => (None, None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build OpenAI message content from parts and tool calls
|
||||
fn build_openai_content(content_parts: Vec<ContentPart>, tool_calls: &[ToolCall]) -> MessageContent {
|
||||
fn build_openai_content(
|
||||
content_parts: Vec<ContentPart>,
|
||||
tool_calls: &[ToolCall],
|
||||
) -> MessageContent {
|
||||
if content_parts.len() == 1 && tool_calls.is_empty() {
|
||||
match &content_parts[0] {
|
||||
ContentPart::Text { text } => MessageContent::Text(text.clone()),
|
||||
|
|
@ -334,7 +352,12 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
|
|||
content_blocks.push(ContentBlock::Text { text });
|
||||
}
|
||||
}
|
||||
crate::apis::anthropic::MessagesContentBlock::ToolUse { id, name, input, .. } => {
|
||||
crate::apis::anthropic::MessagesContentBlock::ToolUse {
|
||||
id,
|
||||
name,
|
||||
input,
|
||||
..
|
||||
} => {
|
||||
content_blocks.push(ContentBlock::ToolUse {
|
||||
tool_use: ToolUseBlock {
|
||||
tool_use_id: id,
|
||||
|
|
@ -343,7 +366,12 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
|
|||
},
|
||||
});
|
||||
}
|
||||
crate::apis::anthropic::MessagesContentBlock::ToolResult { tool_use_id, is_error, content, .. } => {
|
||||
crate::apis::anthropic::MessagesContentBlock::ToolResult {
|
||||
tool_use_id,
|
||||
is_error,
|
||||
content,
|
||||
..
|
||||
} => {
|
||||
// Convert Anthropic ToolResultContent to Bedrock ToolResultContentBlock
|
||||
let tool_result_content = match content {
|
||||
ToolResultContent::Text(text) => {
|
||||
|
|
@ -366,7 +394,9 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
|
|||
|
||||
// Ensure we have at least one content block
|
||||
let final_content = if tool_result_content.is_empty() {
|
||||
vec![ToolResultContentBlock::Text { text: " ".to_string() }]
|
||||
vec![ToolResultContentBlock::Text {
|
||||
text: " ".to_string(),
|
||||
}]
|
||||
} else {
|
||||
tool_result_content
|
||||
};
|
||||
|
|
@ -388,13 +418,13 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
|
|||
crate::apis::anthropic::MessagesContentBlock::Image { source } => {
|
||||
// Convert Anthropic image to Bedrock image format
|
||||
match source {
|
||||
crate::apis::anthropic::MessagesImageSource::Base64 { media_type, data } => {
|
||||
crate::apis::anthropic::MessagesImageSource::Base64 {
|
||||
media_type,
|
||||
data,
|
||||
} => {
|
||||
content_blocks.push(ContentBlock::Image {
|
||||
image: ImageBlock {
|
||||
source: ImageSource::Base64 {
|
||||
media_type,
|
||||
data,
|
||||
},
|
||||
source: ImageSource::Base64 { media_type, data },
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -413,7 +443,9 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
|
|||
|
||||
// Ensure we have at least one content block
|
||||
if content_blocks.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text { text: " ".to_string() });
|
||||
content_blocks.push(ContentBlock::Text {
|
||||
text: " ".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(BedrockMessage {
|
||||
|
|
@ -423,29 +455,33 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::apis::anthropic::{MessagesRequest, MessagesMessage, MessagesMessageContent, MessagesRole, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, MessagesSystemPrompt};
|
||||
use crate::apis::amazon_bedrock::{ConverseRequest, SystemContentBlock, ToolChoice as BedrockToolChoice, ConversationRole, ContentBlock};
|
||||
use crate::apis::amazon_bedrock::{
|
||||
ContentBlock, ConversationRole, ConverseRequest, SystemContentBlock,
|
||||
ToolChoice as BedrockToolChoice,
|
||||
};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
|
||||
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_to_bedrock_basic_request() {
|
||||
let anthropic_request = MessagesRequest {
|
||||
model: "claude-3-5-sonnet-20241022".to_string(),
|
||||
messages: vec![
|
||||
MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("Hello, how are you?".to_string()),
|
||||
}
|
||||
],
|
||||
messages: vec![MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("Hello, how are you?".to_string()),
|
||||
}],
|
||||
max_tokens: 1000,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
system: Some(MessagesSystemPrompt::Single("You are a helpful assistant.".to_string())),
|
||||
system: Some(MessagesSystemPrompt::Single(
|
||||
"You are a helpful assistant.".to_string(),
|
||||
)),
|
||||
metadata: None,
|
||||
service_tier: None,
|
||||
thinking: None,
|
||||
|
|
@ -478,19 +514,20 @@ mod tests {
|
|||
assert_eq!(inference_config.temperature, Some(0.7));
|
||||
assert_eq!(inference_config.top_p, Some(0.9));
|
||||
assert_eq!(inference_config.max_tokens, Some(1000));
|
||||
assert_eq!(inference_config.stop_sequences, Some(vec!["STOP".to_string()]));
|
||||
assert_eq!(
|
||||
inference_config.stop_sequences,
|
||||
Some(vec!["STOP".to_string()])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_to_bedrock_with_tools() {
|
||||
let anthropic_request = MessagesRequest {
|
||||
model: "claude-3-5-sonnet-20241022".to_string(),
|
||||
messages: vec![
|
||||
MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("What's the weather like?".to_string()),
|
||||
}
|
||||
],
|
||||
messages: vec![MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("What's the weather like?".to_string()),
|
||||
}],
|
||||
max_tokens: 1000,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
|
|
@ -503,22 +540,20 @@ mod tests {
|
|||
top_k: None,
|
||||
stream: None,
|
||||
stop_sequences: None,
|
||||
tools: Some(vec![
|
||||
MessagesTool {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get current weather information".to_string()),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}),
|
||||
}
|
||||
]),
|
||||
tools: Some(vec![MessagesTool {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get current weather information".to_string()),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}),
|
||||
}]),
|
||||
tool_choice: Some(MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Tool,
|
||||
name: Some("get_weather".to_string()),
|
||||
|
|
@ -537,7 +572,10 @@ mod tests {
|
|||
assert_eq!(tools.len(), 1);
|
||||
let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0];
|
||||
assert_eq!(tool_spec.name, "get_weather");
|
||||
assert_eq!(tool_spec.description, Some("Get current weather information".to_string()));
|
||||
assert_eq!(
|
||||
tool_spec.description,
|
||||
Some("Get current weather information".to_string())
|
||||
);
|
||||
|
||||
if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice {
|
||||
assert_eq!(tool.name, "get_weather");
|
||||
|
|
@ -550,12 +588,10 @@ mod tests {
|
|||
fn test_anthropic_to_bedrock_auto_tool_choice() {
|
||||
let anthropic_request = MessagesRequest {
|
||||
model: "claude-3-5-sonnet-20241022".to_string(),
|
||||
messages: vec![
|
||||
MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("Help me with something".to_string()),
|
||||
}
|
||||
],
|
||||
messages: vec![MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("Help me with something".to_string()),
|
||||
}],
|
||||
max_tokens: 500,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
|
|
@ -568,16 +604,14 @@ mod tests {
|
|||
top_k: None,
|
||||
stream: None,
|
||||
stop_sequences: None,
|
||||
tools: Some(vec![
|
||||
MessagesTool {
|
||||
name: "help_tool".to_string(),
|
||||
description: Some("A helpful tool".to_string()),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}),
|
||||
}
|
||||
]),
|
||||
tools: Some(vec![MessagesTool {
|
||||
name: "help_tool".to_string(),
|
||||
description: Some("A helpful tool".to_string()),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}),
|
||||
}]),
|
||||
tool_choice: Some(MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Auto,
|
||||
name: None,
|
||||
|
|
@ -589,7 +623,10 @@ mod tests {
|
|||
|
||||
assert!(bedrock_request.tool_config.is_some());
|
||||
let tool_config = bedrock_request.tool_config.as_ref().unwrap();
|
||||
assert!(matches!(tool_config.tool_choice, Some(BedrockToolChoice::Auto { .. })));
|
||||
assert!(matches!(
|
||||
tool_config.tool_choice,
|
||||
Some(BedrockToolChoice::Auto { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -603,12 +640,14 @@ mod tests {
|
|||
},
|
||||
MessagesMessage {
|
||||
role: MessagesRole::Assistant,
|
||||
content: MessagesMessageContent::Single("Hi there! How can I help you?".to_string()),
|
||||
content: MessagesMessageContent::Single(
|
||||
"Hi there! How can I help you?".to_string(),
|
||||
),
|
||||
},
|
||||
MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("What's 2+2?".to_string()),
|
||||
}
|
||||
},
|
||||
],
|
||||
max_tokens: 100,
|
||||
container: None,
|
||||
|
|
|
|||
|
|
@ -1,15 +1,21 @@
|
|||
use crate::apis::amazon_bedrock::{
|
||||
AnyChoice, AutoChoice, ContentBlock, ConversationRole, ConverseRequest, InferenceConfiguration,
|
||||
Message as BedrockMessage, SystemContentBlock, Tool as BedrockTool,
|
||||
ToolChoice as BedrockToolChoice, ToolChoiceSpec, ToolConfiguration, ToolInputSchema,
|
||||
ToolSpecDefinition,
|
||||
};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
|
||||
MessagesSystemPrompt, MessagesTool, MessagesToolChoice, MessagesToolChoiceType,
|
||||
ToolResultContent,
|
||||
};
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType,
|
||||
};
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::transforms::lib::*;
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::*;
|
||||
use crate::apis::anthropic::{MessagesSystemPrompt, MessagesMessage,MessagesRequest, MessagesMessageContent, MessagesContentBlock, MessagesRole, MessagesTool, MessagesToolChoice, MessagesToolChoiceType, ToolResultContent};
|
||||
use crate::apis::openai::{ChatCompletionsRequest, Message, Role, Tool, ToolChoice, ToolChoiceType, MessageContent};
|
||||
use crate::apis::amazon_bedrock::{
|
||||
ConverseRequest, SystemContentBlock, InferenceConfiguration, ToolConfiguration,
|
||||
Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolInputSchema, ToolSpecDefinition,
|
||||
AutoChoice, AnyChoice, ToolChoiceSpec,
|
||||
Message as BedrockMessage, ConversationRole, ContentBlock
|
||||
};
|
||||
|
||||
type AnthropicMessagesRequest = MessagesRequest;
|
||||
|
||||
|
|
@ -21,7 +27,7 @@ impl Into<MessagesSystemPrompt> for Message {
|
|||
fn into(self) -> MessagesSystemPrompt {
|
||||
let system_text = match self.content {
|
||||
MessageContent::Text(text) => text,
|
||||
MessageContent::Parts(parts) => parts.extract_text()
|
||||
MessageContent::Parts(parts) => parts.extract_text(),
|
||||
};
|
||||
MessagesSystemPrompt::Single(system_text)
|
||||
}
|
||||
|
|
@ -36,8 +42,11 @@ impl TryFrom<Message> for MessagesMessage {
|
|||
Role::Assistant => MessagesRole::Assistant,
|
||||
Role::Tool => {
|
||||
// Tool messages become user messages with tool results
|
||||
let tool_call_id = message.tool_call_id
|
||||
.ok_or_else(|| TransformError::MissingField("tool_call_id required for Tool messages".to_string()))?;
|
||||
let tool_call_id = message.tool_call_id.ok_or_else(|| {
|
||||
TransformError::MissingField(
|
||||
"tool_call_id required for Tool messages".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
return Ok(MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
|
|
@ -55,7 +64,9 @@ impl TryFrom<Message> for MessagesMessage {
|
|||
});
|
||||
}
|
||||
Role::System => {
|
||||
return Err(TransformError::UnsupportedConversion("System messages should be handled separately".to_string()));
|
||||
return Err(TransformError::UnsupportedConversion(
|
||||
"System messages should be handled separately".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -75,7 +86,9 @@ impl TryFrom<Message> for BedrockMessage {
|
|||
Role::Assistant => ConversationRole::Assistant,
|
||||
Role::Tool => ConversationRole::User, // Tool results become user messages in Bedrock
|
||||
Role::System => {
|
||||
return Err(TransformError::UnsupportedConversion("System messages should be handled separately in Bedrock".to_string()));
|
||||
return Err(TransformError::UnsupportedConversion(
|
||||
"System messages should be handled separately in Bedrock".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -103,7 +116,9 @@ impl TryFrom<Message> for BedrockMessage {
|
|||
crate::apis::openai::ContentPart::ImageUrl { image_url } => {
|
||||
// Convert image URL to Bedrock image format
|
||||
if image_url.url.starts_with("data:") {
|
||||
if let Some((media_type, data)) = parse_data_url(&image_url.url) {
|
||||
if let Some((media_type, data)) =
|
||||
parse_data_url(&image_url.url)
|
||||
{
|
||||
content_blocks.push(ContentBlock::Image {
|
||||
image: crate::apis::amazon_bedrock::ImageBlock {
|
||||
source: crate::apis::amazon_bedrock::ImageSource::Base64 {
|
||||
|
|
@ -114,7 +129,10 @@ impl TryFrom<Message> for BedrockMessage {
|
|||
});
|
||||
} else {
|
||||
return Err(TransformError::UnsupportedConversion(
|
||||
format!("Invalid data URL format: {}", image_url.url)
|
||||
format!(
|
||||
"Invalid data URL format: {}",
|
||||
image_url.url
|
||||
),
|
||||
));
|
||||
}
|
||||
} else {
|
||||
|
|
@ -130,13 +148,18 @@ impl TryFrom<Message> for BedrockMessage {
|
|||
|
||||
// Ensure we have at least one content block
|
||||
if content_blocks.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text { text: " ".to_string() });
|
||||
content_blocks.push(ContentBlock::Text {
|
||||
text: " ".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
Role::Assistant => {
|
||||
// Handle text content - but only add if non-empty OR if we don't have tool calls
|
||||
let text_content = message.content.extract_text();
|
||||
let has_tool_calls = message.tool_calls.as_ref().map_or(false, |calls| !calls.is_empty());
|
||||
let has_tool_calls = message
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.map_or(false, |calls| !calls.is_empty());
|
||||
|
||||
// Add text content if it's non-empty, or if we have no tool calls (to avoid empty content)
|
||||
if !text_content.is_empty() {
|
||||
|
|
@ -144,17 +167,22 @@ impl TryFrom<Message> for BedrockMessage {
|
|||
} else if !has_tool_calls {
|
||||
// If we have empty content and no tool calls, add a minimal placeholder
|
||||
// This prevents the "blank text field" error
|
||||
content_blocks.push(ContentBlock::Text { text: " ".to_string() });
|
||||
content_blocks.push(ContentBlock::Text {
|
||||
text: " ".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Convert tool calls to ToolUse content blocks
|
||||
if let Some(tool_calls) = message.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
// Parse the arguments string as JSON
|
||||
let input: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
|
||||
.map_err(|e| TransformError::UnsupportedConversion(
|
||||
format!("Failed to parse tool arguments as JSON: {}. Arguments: {}", e, tool_call.function.arguments)
|
||||
))?;
|
||||
let input: serde_json::Value =
|
||||
serde_json::from_str(&tool_call.function.arguments).map_err(|e| {
|
||||
TransformError::UnsupportedConversion(format!(
|
||||
"Failed to parse tool arguments as JSON: {}. Arguments: {}",
|
||||
e, tool_call.function.arguments
|
||||
))
|
||||
})?;
|
||||
|
||||
content_blocks.push(ContentBlock::ToolUse {
|
||||
tool_use: crate::apis::amazon_bedrock::ToolUseBlock {
|
||||
|
|
@ -168,13 +196,18 @@ impl TryFrom<Message> for BedrockMessage {
|
|||
|
||||
// Bedrock requires at least one content block
|
||||
if content_blocks.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text { text: " ".to_string() });
|
||||
content_blocks.push(ContentBlock::Text {
|
||||
text: " ".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
Role::Tool => {
|
||||
// Tool messages become user messages with ToolResult content blocks
|
||||
let tool_call_id = message.tool_call_id
|
||||
.ok_or_else(|| TransformError::MissingField("tool_call_id required for Tool messages".to_string()))?;
|
||||
let tool_call_id = message.tool_call_id.ok_or_else(|| {
|
||||
TransformError::MissingField(
|
||||
"tool_call_id required for Tool messages".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let tool_content = message.content.extract_text();
|
||||
|
||||
|
|
@ -182,11 +215,11 @@ impl TryFrom<Message> for BedrockMessage {
|
|||
let tool_result_content = if tool_content.is_empty() {
|
||||
// Even for tool results, we need non-empty content
|
||||
vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text {
|
||||
text: " ".to_string()
|
||||
text: " ".to_string(),
|
||||
}]
|
||||
} else {
|
||||
vec![crate::apis::amazon_bedrock::ToolResultContentBlock::Text {
|
||||
text: tool_content
|
||||
text: tool_content,
|
||||
}]
|
||||
};
|
||||
|
||||
|
|
@ -232,13 +265,15 @@ impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
|
|||
|
||||
// Convert tools and tool choice
|
||||
let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools));
|
||||
let anthropic_tool_choice = convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
|
||||
let anthropic_tool_choice =
|
||||
convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
|
||||
|
||||
Ok(AnthropicMessagesRequest {
|
||||
model: req.model,
|
||||
system: system_prompt,
|
||||
messages,
|
||||
max_tokens: req.max_completion_tokens
|
||||
max_tokens: req
|
||||
.max_completion_tokens
|
||||
.or(req.max_tokens)
|
||||
.unwrap_or(DEFAULT_MAX_TOKENS),
|
||||
container: None,
|
||||
|
|
@ -297,8 +332,11 @@ impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
|
|||
|
||||
// Build inference configuration
|
||||
let max_tokens = req.max_completion_tokens.or(req.max_tokens);
|
||||
let inference_config = if max_tokens.is_some() || req.temperature.is_some() ||
|
||||
req.top_p.is_some() || req.stop.is_some() {
|
||||
let inference_config = if max_tokens.is_some()
|
||||
|| req.temperature.is_some()
|
||||
|| req.top_p.is_some()
|
||||
|| req.stop.is_some()
|
||||
{
|
||||
Some(InferenceConfiguration {
|
||||
max_tokens,
|
||||
temperature: req.temperature,
|
||||
|
|
@ -312,7 +350,8 @@ impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
|
|||
// Convert tools and tool choice to ToolConfiguration
|
||||
let tool_config = if req.tools.is_some() || req.tool_choice.is_some() {
|
||||
let tools = req.tools.map(|openai_tools| {
|
||||
openai_tools.into_iter()
|
||||
openai_tools
|
||||
.into_iter()
|
||||
.map(|tool| BedrockTool::ToolSpec {
|
||||
tool_spec: ToolSpecDefinition {
|
||||
name: tool.function.name,
|
||||
|
|
@ -325,34 +364,40 @@ impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
|
|||
.collect()
|
||||
});
|
||||
|
||||
let tool_choice = req.tool_choice.map(|choice| {
|
||||
match choice {
|
||||
ToolChoice::Type(tool_type) => match tool_type {
|
||||
ToolChoiceType::Auto => BedrockToolChoice::Auto { auto: AutoChoice {} },
|
||||
ToolChoiceType::Required => BedrockToolChoice::Any { any: AnyChoice {} },
|
||||
ToolChoiceType::None => BedrockToolChoice::Auto { auto: AutoChoice {} }, // Bedrock doesn't have explicit "none"
|
||||
},
|
||||
ToolChoice::Function { function, .. } => {
|
||||
BedrockToolChoice::Tool {
|
||||
tool: ToolChoiceSpec {
|
||||
name: function.name
|
||||
let tool_choice = req
|
||||
.tool_choice
|
||||
.map(|choice| {
|
||||
match choice {
|
||||
ToolChoice::Type(tool_type) => match tool_type {
|
||||
ToolChoiceType::Auto => BedrockToolChoice::Auto {
|
||||
auto: AutoChoice {},
|
||||
},
|
||||
ToolChoiceType::Required => {
|
||||
BedrockToolChoice::Any { any: AnyChoice {} }
|
||||
}
|
||||
}
|
||||
ToolChoiceType::None => BedrockToolChoice::Auto {
|
||||
auto: AutoChoice {},
|
||||
}, // Bedrock doesn't have explicit "none"
|
||||
},
|
||||
ToolChoice::Function { function, .. } => BedrockToolChoice::Tool {
|
||||
tool: ToolChoiceSpec {
|
||||
name: function.name,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}).or_else(|| {
|
||||
// If tools are present but no tool_choice specified, default to "auto"
|
||||
if tools.is_some() {
|
||||
Some(BedrockToolChoice::Auto { auto: AutoChoice {} })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
})
|
||||
.or_else(|| {
|
||||
// If tools are present but no tool_choice specified, default to "auto"
|
||||
if tools.is_some() {
|
||||
Some(BedrockToolChoice::Auto {
|
||||
auto: AutoChoice {},
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
Some(ToolConfiguration {
|
||||
tools,
|
||||
tool_choice,
|
||||
})
|
||||
Some(ToolConfiguration { tools, tool_choice })
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
|
@ -377,7 +422,8 @@ impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
|
|||
|
||||
/// Convert OpenAI tools to Anthropic format
|
||||
fn convert_openai_tools(tools: Vec<Tool>) -> Vec<MessagesTool> {
|
||||
tools.into_iter()
|
||||
tools
|
||||
.into_iter()
|
||||
.map(|tool| MessagesTool {
|
||||
name: tool.function.name,
|
||||
description: tool.function.description,
|
||||
|
|
@ -386,37 +432,34 @@ fn convert_openai_tools(tools: Vec<Tool>) -> Vec<MessagesTool> {
|
|||
.collect()
|
||||
}
|
||||
|
||||
|
||||
/// Convert OpenAI tool choice to Anthropic format
|
||||
fn convert_openai_tool_choice(
|
||||
tool_choice: Option<ToolChoice>,
|
||||
parallel_tool_calls: Option<bool>
|
||||
parallel_tool_calls: Option<bool>,
|
||||
) -> Option<MessagesToolChoice> {
|
||||
tool_choice.map(|choice| {
|
||||
match choice {
|
||||
ToolChoice::Type(tool_type) => match tool_type {
|
||||
ToolChoiceType::Auto => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Auto,
|
||||
name: None,
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
ToolChoiceType::Required => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Any,
|
||||
name: None,
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
ToolChoiceType::None => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::None,
|
||||
name: None,
|
||||
disable_parallel_tool_use: None,
|
||||
},
|
||||
},
|
||||
ToolChoice::Function { function, .. } => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Tool,
|
||||
name: Some(function.name),
|
||||
tool_choice.map(|choice| match choice {
|
||||
ToolChoice::Type(tool_type) => match tool_type {
|
||||
ToolChoiceType::Auto => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Auto,
|
||||
name: None,
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
}
|
||||
ToolChoiceType::Required => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Any,
|
||||
name: None,
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
ToolChoiceType::None => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::None,
|
||||
name: None,
|
||||
disable_parallel_tool_use: None,
|
||||
},
|
||||
},
|
||||
ToolChoice::Function { function, .. } => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Tool,
|
||||
name: Some(function.name),
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -434,8 +477,6 @@ fn build_anthropic_content(content_blocks: Vec<MessagesContentBlock>) -> Message
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Parse a data URL into media type and base64 data
|
||||
/// Supports format: data:image/jpeg;base64,<data>
|
||||
fn parse_data_url(url: &str) -> Option<(String, String)> {
|
||||
|
|
@ -473,8 +514,14 @@ fn parse_data_url(url: &str) -> Option<(String, String)> {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType, Function, FunctionChoice};
|
||||
use crate::apis::amazon_bedrock::{ConverseRequest, SystemContentBlock, ConversationRole, ContentBlock, ToolChoice as BedrockToolChoice};
|
||||
use crate::apis::amazon_bedrock::{
|
||||
ContentBlock, ConversationRole, ConverseRequest, SystemContentBlock,
|
||||
ToolChoice as BedrockToolChoice,
|
||||
};
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsRequest, Function, FunctionChoice, Message, MessageContent, Role, Tool,
|
||||
ToolChoice, ToolChoiceType,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
|
|
@ -495,7 +542,7 @@ mod tests {
|
|||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
},
|
||||
],
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(0.9),
|
||||
|
|
@ -534,50 +581,51 @@ mod tests {
|
|||
assert_eq!(inference_config.temperature, Some(0.7));
|
||||
assert_eq!(inference_config.top_p, Some(0.9));
|
||||
assert_eq!(inference_config.max_tokens, Some(1000));
|
||||
assert_eq!(inference_config.stop_sequences, Some(vec!["STOP".to_string()]));
|
||||
assert_eq!(
|
||||
inference_config.stop_sequences,
|
||||
Some(vec!["STOP".to_string()])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_to_bedrock_with_tools() {
|
||||
let openai_request = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("What's the weather like?".to_string()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
],
|
||||
messages: vec![Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("What's the weather like?".to_string()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
max_completion_tokens: Some(1000),
|
||||
stop: None,
|
||||
stream: None,
|
||||
tools: Some(vec![
|
||||
Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get current weather information".to_string()),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
}
|
||||
]),
|
||||
tools: Some(vec![Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get current weather information".to_string()),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city name"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
}]),
|
||||
tool_choice: Some(ToolChoice::Function {
|
||||
choice_type: "function".to_string(),
|
||||
function: FunctionChoice { name: "get_weather".to_string() },
|
||||
function: FunctionChoice {
|
||||
name: "get_weather".to_string(),
|
||||
},
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
|
@ -594,7 +642,10 @@ mod tests {
|
|||
|
||||
let crate::apis::amazon_bedrock::Tool::ToolSpec { tool_spec } = &tools[0];
|
||||
assert_eq!(tool_spec.name, "get_weather");
|
||||
assert_eq!(tool_spec.description, Some("Get current weather information".to_string()));
|
||||
assert_eq!(
|
||||
tool_spec.description,
|
||||
Some("Get current weather information".to_string())
|
||||
);
|
||||
|
||||
if let Some(BedrockToolChoice::Tool { tool }) = &tool_config.tool_choice {
|
||||
assert_eq!(tool.name, "get_weather");
|
||||
|
|
@ -607,34 +658,30 @@ mod tests {
|
|||
fn test_openai_to_bedrock_auto_tool_choice() {
|
||||
let openai_request = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Help me with something".to_string()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
],
|
||||
messages: vec![Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Help me with something".to_string()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
max_completion_tokens: Some(500),
|
||||
stop: None,
|
||||
stream: None,
|
||||
tools: Some(vec![
|
||||
Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "help_tool".to_string(),
|
||||
description: Some("A helpful tool".to_string()),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
}
|
||||
]),
|
||||
tools: Some(vec![Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
name: "help_tool".to_string(),
|
||||
description: Some("A helpful tool".to_string()),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
}]),
|
||||
tool_choice: Some(ToolChoice::Type(ToolChoiceType::Auto)),
|
||||
..Default::default()
|
||||
};
|
||||
|
|
@ -643,7 +690,10 @@ mod tests {
|
|||
|
||||
assert!(bedrock_request.tool_config.is_some());
|
||||
let tool_config = bedrock_request.tool_config.as_ref().unwrap();
|
||||
assert!(matches!(tool_config.tool_choice, Some(BedrockToolChoice::Auto { .. })));
|
||||
assert!(matches!(
|
||||
tool_config.tool_choice,
|
||||
Some(BedrockToolChoice::Auto { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -678,7 +728,7 @@ mod tests {
|
|||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
},
|
||||
],
|
||||
temperature: Some(0.5),
|
||||
top_p: None,
|
||||
|
|
|
|||
|
|
@ -1,16 +1,16 @@
|
|||
use serde_json::Value;
|
||||
use crate::transforms::lib::*;
|
||||
use crate::clients::TransformError;
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsResponse, ChatCompletionsStreamResponse, Role, ToolCallDelta
|
||||
use crate::apis::amazon_bedrock::{
|
||||
ContentBlockDelta, ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason,
|
||||
};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesStreamEvent, MessagesStopReason, MessagesMessageDelta, MessagesResponse,
|
||||
MessagesStreamMessage, MessagesUsage, MessagesContentDelta, MessagesRole, MessagesContentBlock
|
||||
MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, MessagesResponse,
|
||||
MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
|
||||
};
|
||||
use crate::apis::amazon_bedrock::{
|
||||
ConverseResponse, ConverseOutput, StopReason, ConverseStreamEvent, ContentBlockDelta
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsResponse, ChatCompletionsStreamResponse, Role, ToolCallDelta,
|
||||
};
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::*;
|
||||
use serde_json::Value;
|
||||
|
||||
// ============================================================================
|
||||
// STANDARD RUST TRAIT IMPLEMENTATIONS - Using Into/TryFrom for convenience
|
||||
|
|
@ -20,11 +20,15 @@ impl TryFrom<ChatCompletionsResponse> for MessagesResponse {
|
|||
type Error = TransformError;
|
||||
|
||||
fn try_from(resp: ChatCompletionsResponse) -> Result<Self, Self::Error> {
|
||||
let choice = resp.choices.into_iter().next()
|
||||
let choice = resp
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| TransformError::MissingField("choices".to_string()))?;
|
||||
|
||||
let content = convert_openai_message_to_anthropic_content(&choice.message.to_message())?;
|
||||
let stop_reason = choice.finish_reason
|
||||
let stop_reason = choice
|
||||
.finish_reason
|
||||
.map(|fr| fr.into())
|
||||
.unwrap_or(MessagesStopReason::EndTurn);
|
||||
|
||||
|
|
@ -86,13 +90,17 @@ impl TryFrom<ConverseResponse> for MessagesResponse {
|
|||
};
|
||||
|
||||
// Generate a response ID (Bedrock doesn't provide one)
|
||||
let id = format!("bedrock-{}", std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos());
|
||||
let id = format!(
|
||||
"bedrock-{}",
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos()
|
||||
);
|
||||
|
||||
// Extract model ID from trace information if available, otherwise use fallback
|
||||
let model = resp.trace
|
||||
let model = resp
|
||||
.trace
|
||||
.as_ref()
|
||||
.and_then(|trace| trace.prompt_router.as_ref())
|
||||
.map(|router| router.invoked_model_id.clone())
|
||||
|
|
@ -231,15 +239,20 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
|
|||
ConverseStreamEvent::MessageStart(start_event) => {
|
||||
let role = match start_event.role {
|
||||
crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User,
|
||||
crate::apis::amazon_bedrock::ConversationRole::Assistant => MessagesRole::Assistant,
|
||||
crate::apis::amazon_bedrock::ConversationRole::Assistant => {
|
||||
MessagesRole::Assistant
|
||||
}
|
||||
};
|
||||
|
||||
Ok(MessagesStreamEvent::MessageStart {
|
||||
message: MessagesStreamMessage {
|
||||
id: format!("bedrock-stream-{}", std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos()),
|
||||
id: format!(
|
||||
"bedrock-stream-{}",
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos()
|
||||
),
|
||||
obj_type: "message".to_string(),
|
||||
role,
|
||||
content: vec![],
|
||||
|
|
@ -278,11 +291,11 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
|
|||
// ContentBlockDelta - convert to Anthropic ContentBlockDelta
|
||||
ConverseStreamEvent::ContentBlockDelta(delta_event) => {
|
||||
let delta = match delta_event.delta {
|
||||
ContentBlockDelta::Text { text } => {
|
||||
MessagesContentDelta::TextDelta { text }
|
||||
}
|
||||
ContentBlockDelta::Text { text } => MessagesContentDelta::TextDelta { text },
|
||||
ContentBlockDelta::ToolUse { tool_use } => {
|
||||
MessagesContentDelta::InputJsonDelta { partial_json: tool_use.input }
|
||||
MessagesContentDelta::InputJsonDelta {
|
||||
partial_json: tool_use.input,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -342,11 +355,11 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
|
|||
}
|
||||
|
||||
// Exception events - convert to Ping (could be enhanced to return error events)
|
||||
ConverseStreamEvent::InternalServerException(_) |
|
||||
ConverseStreamEvent::ModelStreamErrorException(_) |
|
||||
ConverseStreamEvent::ServiceUnavailableException(_) |
|
||||
ConverseStreamEvent::ThrottlingException(_) |
|
||||
ConverseStreamEvent::ValidationException(_) => {
|
||||
ConverseStreamEvent::InternalServerException(_)
|
||||
| ConverseStreamEvent::ModelStreamErrorException(_)
|
||||
| ConverseStreamEvent::ServiceUnavailableException(_)
|
||||
| ConverseStreamEvent::ThrottlingException(_)
|
||||
| ConverseStreamEvent::ValidationException(_) => {
|
||||
// TODO: Consider adding proper error handling/events
|
||||
Ok(MessagesStreamEvent::Ping)
|
||||
}
|
||||
|
|
@ -355,7 +368,9 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
|
|||
}
|
||||
|
||||
/// Convert tool call deltas to Anthropic stream events
|
||||
fn convert_tool_call_deltas(tool_calls: Vec<ToolCallDelta>) -> Result<MessagesStreamEvent, TransformError> {
|
||||
fn convert_tool_call_deltas(
|
||||
tool_calls: Vec<ToolCallDelta>,
|
||||
) -> Result<MessagesStreamEvent, TransformError> {
|
||||
for tool_call in tool_calls {
|
||||
if let Some(id) = &tool_call.id {
|
||||
// Tool call start
|
||||
|
|
@ -403,7 +418,9 @@ fn convert_tool_call_deltas(tool_calls: Vec<ToolCallDelta>) -> Result<MessagesSt
|
|||
///
|
||||
/// Note on S3/URL handling: Converting S3 locations or URLs would require async operations
|
||||
/// to download and convert to base64, which is not implemented in this synchronous function.
|
||||
fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_bedrock::Message) -> Result<Vec<MessagesContentBlock>, TransformError> {
|
||||
fn convert_bedrock_message_to_anthropic_content(
|
||||
message: &crate::apis::amazon_bedrock::Message,
|
||||
) -> Result<Vec<MessagesContentBlock>, TransformError> {
|
||||
use crate::apis::amazon_bedrock::ContentBlock;
|
||||
|
||||
let mut content_blocks = Vec::new();
|
||||
|
|
@ -438,16 +455,19 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
|
|||
crate::apis::amazon_bedrock::ToolResultContentBlock::Image { source } => {
|
||||
// Convert Bedrock ImageSource to Anthropic format
|
||||
match source {
|
||||
crate::apis::amazon_bedrock::ImageSource::Base64 { media_type, data } => {
|
||||
crate::apis::amazon_bedrock::ImageSource::Base64 {
|
||||
media_type,
|
||||
data,
|
||||
} => {
|
||||
tool_result_blocks.push(MessagesContentBlock::Image {
|
||||
source: crate::apis::anthropic::MessagesImageSource::Base64 {
|
||||
media_type: media_type.clone(),
|
||||
data: data.clone(),
|
||||
},
|
||||
source:
|
||||
crate::apis::anthropic::MessagesImageSource::Base64 {
|
||||
media_type: media_type.clone(),
|
||||
data: data.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
// Note: S3Location is not yet implemented in the current Bedrock API definition
|
||||
// but would need async handling when added
|
||||
} // Note: S3Location is not yet implemented in the current Bedrock API definition
|
||||
// but would need async handling when added
|
||||
}
|
||||
}
|
||||
crate::apis::amazon_bedrock::ToolResultContentBlock::Json { json } => {
|
||||
|
|
@ -463,7 +483,10 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
|
|||
use crate::apis::anthropic::ToolResultContent;
|
||||
content_blocks.push(MessagesContentBlock::ToolResult {
|
||||
tool_use_id: tool_result.tool_use_id.clone(),
|
||||
is_error: tool_result.status.as_ref().map(|s| matches!(s, crate::apis::amazon_bedrock::ToolResultStatus::Error)),
|
||||
is_error: tool_result
|
||||
.status
|
||||
.as_ref()
|
||||
.map(|s| matches!(s, crate::apis::amazon_bedrock::ToolResultStatus::Error)),
|
||||
content: ToolResultContent::Blocks(tool_result_blocks),
|
||||
cache_control: None,
|
||||
});
|
||||
|
|
@ -478,8 +501,7 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
|
|||
data: data.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
// Note: S3Location would require async handling if implemented
|
||||
} // Note: S3Location would require async handling if implemented
|
||||
}
|
||||
}
|
||||
ContentBlock::Document { document } => {
|
||||
|
|
@ -493,8 +515,7 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
|
|||
data: data.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
// Note: S3Location would require async handling if implemented
|
||||
} // Note: S3Location would require async handling if implemented
|
||||
}
|
||||
}
|
||||
ContentBlock::GuardContent { guard_content } => {
|
||||
|
|
@ -516,11 +537,13 @@ fn convert_bedrock_message_to_anthropic_content(message: &crate::apis::amazon_be
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::apis::amazon_bedrock::{
|
||||
ConverseResponse, ConverseOutput, Message as BedrockMessage, ConversationRole,
|
||||
ContentBlock, StopReason, BedrockTokenUsage, ToolResultContentBlock, ToolResultStatus,
|
||||
ConverseTrace, PromptRouterTrace
|
||||
BedrockTokenUsage, ContentBlock, ConversationRole, ConverseOutput, ConverseResponse,
|
||||
ConverseTrace, Message as BedrockMessage, PromptRouterTrace, StopReason,
|
||||
ToolResultContentBlock, ToolResultStatus,
|
||||
};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesResponse, MessagesRole, MessagesStopReason, ToolResultContent,
|
||||
};
|
||||
use crate::apis::anthropic::{MessagesResponse, MessagesContentBlock, MessagesStopReason, MessagesRole, ToolResultContent};
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
|
|
@ -529,11 +552,9 @@ mod tests {
|
|||
output: ConverseOutput::Message {
|
||||
message: BedrockMessage {
|
||||
role: ConversationRole::Assistant,
|
||||
content: vec![
|
||||
ContentBlock::Text {
|
||||
text: "Hello! How can I help you today?".to_string(),
|
||||
}
|
||||
],
|
||||
content: vec![ContentBlock::Text {
|
||||
text: "Hello! How can I help you today?".to_string(),
|
||||
}],
|
||||
},
|
||||
},
|
||||
stop_reason: StopReason::EndTurn,
|
||||
|
|
@ -592,7 +613,7 @@ mod tests {
|
|||
"location": "San Francisco"
|
||||
}),
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
|
|
@ -624,7 +645,10 @@ mod tests {
|
|||
}
|
||||
|
||||
// Check tool use content
|
||||
if let MessagesContentBlock::ToolUse { id, name, input, .. } = &anthropic_response.content[1] {
|
||||
if let MessagesContentBlock::ToolUse {
|
||||
id, name, input, ..
|
||||
} = &anthropic_response.content[1]
|
||||
{
|
||||
assert_eq!(id, "tool_12345");
|
||||
assert_eq!(name, "get_weather");
|
||||
assert_eq!(input["location"], "San Francisco");
|
||||
|
|
@ -649,11 +673,9 @@ mod tests {
|
|||
output: ConverseOutput::Message {
|
||||
message: BedrockMessage {
|
||||
role: ConversationRole::Assistant,
|
||||
content: vec![
|
||||
ContentBlock::Text {
|
||||
text: "Test response".to_string(),
|
||||
}
|
||||
],
|
||||
content: vec![ContentBlock::Text {
|
||||
text: "Test response".to_string(),
|
||||
}],
|
||||
},
|
||||
},
|
||||
stop_reason: bedrock_stop_reason,
|
||||
|
|
@ -670,7 +692,10 @@ mod tests {
|
|||
};
|
||||
|
||||
let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap();
|
||||
assert_eq!(anthropic_response.stop_reason, expected_anthropic_stop_reason);
|
||||
assert_eq!(
|
||||
anthropic_response.stop_reason,
|
||||
expected_anthropic_stop_reason
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -680,11 +705,9 @@ mod tests {
|
|||
output: ConverseOutput::Message {
|
||||
message: BedrockMessage {
|
||||
role: ConversationRole::Assistant,
|
||||
content: vec![
|
||||
ContentBlock::Text {
|
||||
text: "Cached response".to_string(),
|
||||
}
|
||||
],
|
||||
content: vec![ContentBlock::Text {
|
||||
text: "Cached response".to_string(),
|
||||
}],
|
||||
},
|
||||
},
|
||||
stop_reason: StopReason::EndTurn,
|
||||
|
|
@ -706,7 +729,10 @@ mod tests {
|
|||
|
||||
assert_eq!(anthropic_response.usage.input_tokens, 100);
|
||||
assert_eq!(anthropic_response.usage.output_tokens, 50);
|
||||
assert_eq!(anthropic_response.usage.cache_creation_input_tokens, Some(20));
|
||||
assert_eq!(
|
||||
anthropic_response.usage.cache_creation_input_tokens,
|
||||
Some(20)
|
||||
);
|
||||
assert_eq!(anthropic_response.usage.cache_read_input_tokens, Some(10));
|
||||
}
|
||||
|
||||
|
|
@ -723,14 +749,12 @@ mod tests {
|
|||
ContentBlock::ToolResult {
|
||||
tool_result: crate::apis::amazon_bedrock::ToolResultBlock {
|
||||
tool_use_id: "tool_67890".to_string(),
|
||||
content: vec![
|
||||
ToolResultContentBlock::Text {
|
||||
text: "Temperature: 72°F, Sunny".to_string(),
|
||||
}
|
||||
],
|
||||
content: vec![ToolResultContentBlock::Text {
|
||||
text: "Temperature: 72°F, Sunny".to_string(),
|
||||
}],
|
||||
status: Some(ToolResultStatus::Success),
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
|
|
@ -761,7 +785,12 @@ mod tests {
|
|||
}
|
||||
|
||||
// Check tool result content
|
||||
if let MessagesContentBlock::ToolResult { tool_use_id, content, .. } = &anthropic_response.content[1] {
|
||||
if let MessagesContentBlock::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
..
|
||||
} = &anthropic_response.content[1]
|
||||
{
|
||||
assert_eq!(tool_use_id, "tool_67890");
|
||||
if let ToolResultContent::Blocks(blocks) = content {
|
||||
assert_eq!(blocks.len(), 1);
|
||||
|
|
@ -804,7 +833,7 @@ mod tests {
|
|||
name: "lookup".to_string(),
|
||||
input: json!({"id": "12345"}),
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
|
|
@ -870,11 +899,12 @@ mod tests {
|
|||
name: "test_function".to_string(),
|
||||
input: json!({"param": "value"}),
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let content_blocks = convert_bedrock_message_to_anthropic_content(&bedrock_message).unwrap();
|
||||
let content_blocks =
|
||||
convert_bedrock_message_to_anthropic_content(&bedrock_message).unwrap();
|
||||
|
||||
assert_eq!(content_blocks.len(), 2);
|
||||
|
||||
|
|
@ -884,7 +914,10 @@ mod tests {
|
|||
panic!("Expected text content block");
|
||||
}
|
||||
|
||||
if let MessagesContentBlock::ToolUse { id, name, input, .. } = &content_blocks[1] {
|
||||
if let MessagesContentBlock::ToolUse {
|
||||
id, name, input, ..
|
||||
} = &content_blocks[1]
|
||||
{
|
||||
assert_eq!(id, "test_tool");
|
||||
assert_eq!(name, "test_function");
|
||||
assert_eq!(input["param"], "value");
|
||||
|
|
@ -900,11 +933,9 @@ mod tests {
|
|||
output: ConverseOutput::Message {
|
||||
message: BedrockMessage {
|
||||
role: ConversationRole::Assistant,
|
||||
content: vec![
|
||||
ContentBlock::Text {
|
||||
text: "I am an assistant".to_string(),
|
||||
}
|
||||
],
|
||||
content: vec![ContentBlock::Text {
|
||||
text: "I am an assistant".to_string(),
|
||||
}],
|
||||
},
|
||||
},
|
||||
stop_reason: StopReason::EndTurn,
|
||||
|
|
@ -928,11 +959,9 @@ mod tests {
|
|||
output: ConverseOutput::Message {
|
||||
message: BedrockMessage {
|
||||
role: ConversationRole::User,
|
||||
content: vec![
|
||||
ContentBlock::Text {
|
||||
text: "I am a user".to_string(),
|
||||
}
|
||||
],
|
||||
content: vec![ContentBlock::Text {
|
||||
text: "I am a user".to_string(),
|
||||
}],
|
||||
},
|
||||
},
|
||||
stop_reason: StopReason::EndTurn,
|
||||
|
|
@ -985,7 +1014,10 @@ mod tests {
|
|||
let anthropic_response: MessagesResponse = bedrock_response.try_into().unwrap();
|
||||
|
||||
// Should extract model ID from trace
|
||||
assert_eq!(anthropic_response.model, "anthropic.claude-3-sonnet-20240229-v1:0");
|
||||
assert_eq!(
|
||||
anthropic_response.model,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
);
|
||||
|
||||
// Test fallback when no trace information is available
|
||||
let bedrock_response_no_trace = ConverseResponse {
|
||||
|
|
@ -1010,7 +1042,8 @@ mod tests {
|
|||
performance_config: None,
|
||||
};
|
||||
|
||||
let anthropic_response_fallback: MessagesResponse = bedrock_response_no_trace.try_into().unwrap();
|
||||
let anthropic_response_fallback: MessagesResponse =
|
||||
bedrock_response_no_trace.try_into().unwrap();
|
||||
|
||||
// Should use fallback model name
|
||||
assert_eq!(anthropic_response_fallback.model, "bedrock-model");
|
||||
|
|
|
|||
|
|
@ -1,10 +1,18 @@
|
|||
use crate::apis::openai::{ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason, ResponseMessage, Role, ToolCallDelta, FunctionCallDelta, Usage, StreamChoice, MessageDelta, MessageContent};
|
||||
use crate::apis::anthropic::{MessagesResponse, MessagesStreamEvent, MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesUsage};
|
||||
use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason};
|
||||
use crate::apis::amazon_bedrock::{
|
||||
ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason,
|
||||
};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesContentDelta, MessagesResponse, MessagesStopReason,
|
||||
MessagesStreamEvent, MessagesUsage,
|
||||
};
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason,
|
||||
FunctionCallDelta, MessageContent, MessageDelta, ResponseMessage, Role, StreamChoice,
|
||||
ToolCallDelta, Usage,
|
||||
};
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::*;
|
||||
|
||||
|
||||
// ============================================================================
|
||||
// MAIN RESPONSE TRANSFORMATIONS
|
||||
// ============================================================================
|
||||
|
|
@ -35,7 +43,11 @@ impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
|
|||
MessageContent::Text(text) => Some(text),
|
||||
MessageContent::Parts(parts) => {
|
||||
let text = parts.extract_text();
|
||||
if text.is_empty() { None } else { Some(text) }
|
||||
if text.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(text)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -105,7 +117,6 @@ impl TryFrom<ConverseResponse> for ChatCompletionsResponse {
|
|||
StopReason::ContentFiltered => FinishReason::ContentFilter,
|
||||
};
|
||||
|
||||
|
||||
// Create response message
|
||||
let response_message = ResponseMessage {
|
||||
role,
|
||||
|
|
@ -135,13 +146,17 @@ impl TryFrom<ConverseResponse> for ChatCompletionsResponse {
|
|||
};
|
||||
|
||||
// Generate a response ID (using timestamp since Bedrock doesn't provide one)
|
||||
let id = format!("bedrock-{}", std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos());
|
||||
let id = format!(
|
||||
"bedrock-{}",
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos()
|
||||
);
|
||||
|
||||
// Extract model ID from trace information if available, otherwise use fallback
|
||||
let model = resp.trace
|
||||
let model = resp
|
||||
.trace
|
||||
.as_ref()
|
||||
.and_then(|trace| trace.prompt_router.as_ref())
|
||||
.map(|router| router.invoked_model_id.clone())
|
||||
|
|
@ -160,7 +175,6 @@ impl TryFrom<ConverseResponse> for ChatCompletionsResponse {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// ============================================================================
|
||||
// STREAMING TRANSFORMATIONS
|
||||
// ============================================================================
|
||||
|
|
@ -170,33 +184,27 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
|
|||
|
||||
fn try_from(event: MessagesStreamEvent) -> Result<Self, Self::Error> {
|
||||
match event {
|
||||
MessagesStreamEvent::MessageStart { message } => {
|
||||
Ok(create_openai_chunk(
|
||||
&message.id,
|
||||
&message.model,
|
||||
MessageDelta {
|
||||
role: Some(Role::Assistant),
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk(
|
||||
&message.id,
|
||||
&message.model,
|
||||
MessageDelta {
|
||||
role: Some(Role::Assistant),
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
|
||||
MessagesStreamEvent::ContentBlockStart { content_block, .. } => {
|
||||
convert_content_block_start(content_block)
|
||||
}
|
||||
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
|
||||
convert_content_delta(delta)
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta),
|
||||
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => {
|
||||
Ok(create_empty_openai_chunk())
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
|
||||
|
||||
MessagesStreamEvent::MessageDelta { delta, usage } => {
|
||||
let finish_reason: Option<FinishReason> = Some(delta.stop_reason.into());
|
||||
|
|
@ -217,39 +225,34 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
|
|||
))
|
||||
}
|
||||
|
||||
MessagesStreamEvent::MessageStop => {
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Some(FinishReason::Stop),
|
||||
None,
|
||||
))
|
||||
}
|
||||
MessagesStreamEvent::MessageStop => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Some(FinishReason::Stop),
|
||||
None,
|
||||
)),
|
||||
|
||||
MessagesStreamEvent::Ping => {
|
||||
Ok(ChatCompletionsStreamResponse {
|
||||
id: "stream".to_string(),
|
||||
object: Some("chat.completion.chunk".to_string()),
|
||||
created: current_timestamp(),
|
||||
model: "unknown".to_string(),
|
||||
choices: vec![],
|
||||
usage: None,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
})
|
||||
}
|
||||
MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse {
|
||||
id: "stream".to_string(),
|
||||
object: Some("chat.completion.chunk".to_string()),
|
||||
created: current_timestamp(),
|
||||
model: "unknown".to_string(),
|
||||
choices: vec![],
|
||||
usage: None,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
|
|
@ -280,29 +283,27 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
|
|||
use crate::apis::amazon_bedrock::ContentBlockStart;
|
||||
|
||||
match start_event.start {
|
||||
ContentBlockStart::ToolUse { tool_use } => {
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: start_event.content_block_index as u32,
|
||||
id: Some(tool_use.tool_use_id),
|
||||
call_type: Some("function".to_string()),
|
||||
function: Some(FunctionCallDelta {
|
||||
name: Some(tool_use.name),
|
||||
arguments: Some("".to_string()),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
ContentBlockStart::ToolUse { tool_use } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: start_event.content_block_index as u32,
|
||||
id: Some(tool_use.tool_use_id),
|
||||
call_type: Some("function".to_string()),
|
||||
function: Some(FunctionCallDelta {
|
||||
name: Some(tool_use.name),
|
||||
arguments: Some("".to_string()),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -310,50 +311,44 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
|
|||
use crate::apis::amazon_bedrock::ContentBlockDelta;
|
||||
|
||||
match delta_event.delta {
|
||||
ContentBlockDelta::Text { text } => {
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
ContentBlockDelta::ToolUse { tool_use } => {
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: delta_event.content_block_index as u32,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(tool_use.input),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
ContentBlockDelta::Text { text } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
ContentBlockDelta::ToolUse { tool_use } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: delta_event.content_block_index as u32,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(tool_use.input),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
ConverseStreamEvent::ContentBlockStop(_) => {
|
||||
Ok(create_empty_openai_chunk())
|
||||
}
|
||||
ConverseStreamEvent::ContentBlockStop(_) => Ok(create_empty_openai_chunk()),
|
||||
|
||||
ConverseStreamEvent::MessageStop(stop_event) => {
|
||||
let finish_reason = match stop_event.stop_reason {
|
||||
|
|
@ -405,27 +400,27 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
|
|||
}
|
||||
|
||||
// Error events - convert to empty chunks (errors should be handled elsewhere)
|
||||
ConverseStreamEvent::InternalServerException(_) |
|
||||
ConverseStreamEvent::ModelStreamErrorException(_) |
|
||||
ConverseStreamEvent::ServiceUnavailableException(_) |
|
||||
ConverseStreamEvent::ThrottlingException(_) |
|
||||
ConverseStreamEvent::ValidationException(_) => {
|
||||
Ok(create_empty_openai_chunk())
|
||||
}
|
||||
ConverseStreamEvent::InternalServerException(_)
|
||||
| ConverseStreamEvent::ModelStreamErrorException(_)
|
||||
| ConverseStreamEvent::ServiceUnavailableException(_)
|
||||
| ConverseStreamEvent::ThrottlingException(_)
|
||||
| ConverseStreamEvent::ValidationException(_) => Ok(create_empty_openai_chunk()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert content block start to OpenAI chunk
|
||||
fn convert_content_block_start(content_block: MessagesContentBlock) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
fn convert_content_block_start(
|
||||
content_block: MessagesContentBlock,
|
||||
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
match content_block {
|
||||
MessagesContentBlock::Text { .. } => {
|
||||
// No immediate output for text block start
|
||||
Ok(create_empty_openai_chunk())
|
||||
}
|
||||
MessagesContentBlock::ToolUse { id, name, .. } |
|
||||
MessagesContentBlock::ServerToolUse { id, name, .. } |
|
||||
MessagesContentBlock::McpToolUse { id, name, .. } => {
|
||||
MessagesContentBlock::ToolUse { id, name, .. }
|
||||
| MessagesContentBlock::ServerToolUse { id, name, .. }
|
||||
| MessagesContentBlock::McpToolUse { id, name, .. } => {
|
||||
// Tool use start → OpenAI chunk with tool_calls
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
|
|
@ -449,66 +444,64 @@ fn convert_content_block_start(content_block: MessagesContentBlock) -> Result<Ch
|
|||
None,
|
||||
))
|
||||
}
|
||||
_ => Err(TransformError::UnsupportedContent("Unsupported content block type in stream start".to_string())),
|
||||
_ => Err(TransformError::UnsupportedContent(
|
||||
"Unsupported content block type in stream start".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert content delta to OpenAI chunk
|
||||
fn convert_content_delta(delta: MessagesContentDelta) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
fn convert_content_delta(
|
||||
delta: MessagesContentDelta,
|
||||
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
match delta {
|
||||
MessagesContentDelta::TextDelta { text } => {
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => {
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(format!("thinking: {}", thinking)),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
MessagesContentDelta::InputJsonDelta { partial_json } => {
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(partial_json),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(format!("thinking: {}", thinking)),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(partial_json),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -518,7 +511,7 @@ fn create_openai_chunk(
|
|||
model: &str,
|
||||
delta: MessageDelta,
|
||||
finish_reason: Option<FinishReason>,
|
||||
usage: Option<Usage>
|
||||
usage: Option<Usage>,
|
||||
) -> ChatCompletionsStreamResponse {
|
||||
ChatCompletionsStreamResponse {
|
||||
id: id.to_string(),
|
||||
|
|
@ -555,7 +548,9 @@ fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse {
|
|||
}
|
||||
|
||||
/// Convert Anthropic content blocks to OpenAI message content
|
||||
fn convert_anthropic_content_to_openai(content: &[MessagesContentBlock]) -> Result<MessageContent, TransformError> {
|
||||
fn convert_anthropic_content_to_openai(
|
||||
content: &[MessagesContentBlock],
|
||||
) -> Result<MessageContent, TransformError> {
|
||||
let mut text_parts = Vec::new();
|
||||
|
||||
for block in content {
|
||||
|
|
@ -592,9 +587,11 @@ impl Into<FinishReason> for MessagesStopReason {
|
|||
|
||||
/// Convert Bedrock Message to OpenAI content and tool calls
|
||||
/// This function extracts text content and tool calls from a Bedrock message
|
||||
fn convert_bedrock_message_to_openai(message: &crate::apis::amazon_bedrock::Message) -> Result<(Option<String>, Option<Vec<crate::apis::openai::ToolCall>>), TransformError> {
|
||||
fn convert_bedrock_message_to_openai(
|
||||
message: &crate::apis::amazon_bedrock::Message,
|
||||
) -> Result<(Option<String>, Option<Vec<crate::apis::openai::ToolCall>>), TransformError> {
|
||||
use crate::apis::amazon_bedrock::ContentBlock;
|
||||
use crate::apis::openai::{ToolCall, FunctionCall};
|
||||
use crate::apis::openai::{FunctionCall, ToolCall};
|
||||
|
||||
let mut text_content = String::new();
|
||||
let mut tool_calls = Vec::new();
|
||||
|
|
@ -614,12 +611,20 @@ fn convert_bedrock_message_to_openai(message: &crate::apis::amazon_bedrock::Mess
|
|||
},
|
||||
});
|
||||
}
|
||||
_ => continue,
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
|
||||
let content = if text_content.is_empty() { None } else { Some(text_content) };
|
||||
let tool_calls = if tool_calls.is_empty() { None } else { Some(tool_calls) };
|
||||
let content = if text_content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(text_content)
|
||||
};
|
||||
let tool_calls = if tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tool_calls)
|
||||
};
|
||||
|
||||
Ok((content, tool_calls))
|
||||
}
|
||||
|
|
@ -628,8 +633,8 @@ fn convert_bedrock_message_to_openai(message: &crate::apis::amazon_bedrock::Mess
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::apis::amazon_bedrock::{
|
||||
ConverseResponse, ConverseOutput, Message as BedrockMessage, ConversationRole,
|
||||
ContentBlock, StopReason, BedrockTokenUsage, ConverseTrace, PromptRouterTrace
|
||||
BedrockTokenUsage, ContentBlock, ConversationRole, ConverseOutput, ConverseResponse,
|
||||
ConverseTrace, Message as BedrockMessage, PromptRouterTrace, StopReason,
|
||||
};
|
||||
use crate::apis::openai::{ChatCompletionsResponse, FinishReason, Role};
|
||||
use serde_json::json;
|
||||
|
|
@ -640,11 +645,9 @@ mod tests {
|
|||
output: ConverseOutput::Message {
|
||||
message: BedrockMessage {
|
||||
role: ConversationRole::Assistant,
|
||||
content: vec![
|
||||
ContentBlock::Text {
|
||||
text: "Hello! How can I help you today?".to_string(),
|
||||
}
|
||||
],
|
||||
content: vec![ContentBlock::Text {
|
||||
text: "Hello! How can I help you today?".to_string(),
|
||||
}],
|
||||
},
|
||||
},
|
||||
stop_reason: StopReason::EndTurn,
|
||||
|
|
@ -671,7 +674,10 @@ mod tests {
|
|||
let choice = &openai_response.choices[0];
|
||||
assert_eq!(choice.index, 0);
|
||||
assert_eq!(choice.message.role, Role::Assistant);
|
||||
assert_eq!(choice.message.content, Some("Hello! How can I help you today?".to_string()));
|
||||
assert_eq!(
|
||||
choice.message.content,
|
||||
Some("Hello! How can I help you today?".to_string())
|
||||
);
|
||||
assert_eq!(choice.finish_reason, Some(FinishReason::Stop));
|
||||
assert!(choice.message.tool_calls.is_none());
|
||||
|
||||
|
|
@ -699,7 +705,7 @@ mod tests {
|
|||
"location": "San Francisco"
|
||||
}),
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
|
|
@ -718,11 +724,21 @@ mod tests {
|
|||
|
||||
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
|
||||
|
||||
assert_eq!(openai_response.choices[0].finish_reason, Some(FinishReason::ToolCalls));
|
||||
assert_eq!(openai_response.choices[0].message.content, Some("I'll help you check the weather.".to_string()));
|
||||
assert_eq!(
|
||||
openai_response.choices[0].finish_reason,
|
||||
Some(FinishReason::ToolCalls)
|
||||
);
|
||||
assert_eq!(
|
||||
openai_response.choices[0].message.content,
|
||||
Some("I'll help you check the weather.".to_string())
|
||||
);
|
||||
|
||||
// Check tool calls
|
||||
let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap();
|
||||
let tool_calls = openai_response.choices[0]
|
||||
.message
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
|
||||
let tool_call = &tool_calls[0];
|
||||
|
|
@ -750,11 +766,9 @@ mod tests {
|
|||
output: ConverseOutput::Message {
|
||||
message: BedrockMessage {
|
||||
role: ConversationRole::Assistant,
|
||||
content: vec![
|
||||
ContentBlock::Text {
|
||||
text: "Test response".to_string(),
|
||||
}
|
||||
],
|
||||
content: vec![ContentBlock::Text {
|
||||
text: "Test response".to_string(),
|
||||
}],
|
||||
},
|
||||
},
|
||||
stop_reason: bedrock_stop_reason,
|
||||
|
|
@ -771,7 +785,10 @@ mod tests {
|
|||
};
|
||||
|
||||
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
|
||||
assert_eq!(openai_response.choices[0].finish_reason, Some(expected_openai_finish_reason));
|
||||
assert_eq!(
|
||||
openai_response.choices[0].finish_reason,
|
||||
Some(expected_openai_finish_reason)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -798,7 +815,7 @@ mod tests {
|
|||
name: "lookup".to_string(),
|
||||
input: json!({"id": "12345"}),
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
|
|
@ -817,23 +834,35 @@ mod tests {
|
|||
|
||||
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
|
||||
|
||||
assert_eq!(openai_response.choices[0].finish_reason, Some(FinishReason::ToolCalls));
|
||||
assert_eq!(openai_response.choices[0].message.content, Some("I'll help with multiple tasks.".to_string()));
|
||||
assert_eq!(
|
||||
openai_response.choices[0].finish_reason,
|
||||
Some(FinishReason::ToolCalls)
|
||||
);
|
||||
assert_eq!(
|
||||
openai_response.choices[0].message.content,
|
||||
Some("I'll help with multiple tasks.".to_string())
|
||||
);
|
||||
|
||||
// Check multiple tool calls
|
||||
let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap();
|
||||
let tool_calls = openai_response.choices[0]
|
||||
.message
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
assert_eq!(tool_calls.len(), 2);
|
||||
|
||||
// First tool call
|
||||
assert_eq!(tool_calls[0].id, "tool_1");
|
||||
assert_eq!(tool_calls[0].function.name, "search");
|
||||
let args1: serde_json::Value = serde_json::from_str(&tool_calls[0].function.arguments).unwrap();
|
||||
let args1: serde_json::Value =
|
||||
serde_json::from_str(&tool_calls[0].function.arguments).unwrap();
|
||||
assert_eq!(args1["query"], "weather");
|
||||
|
||||
// Second tool call
|
||||
assert_eq!(tool_calls[1].id, "tool_2");
|
||||
assert_eq!(tool_calls[1].function.name, "lookup");
|
||||
let args2: serde_json::Value = serde_json::from_str(&tool_calls[1].function.arguments).unwrap();
|
||||
let args2: serde_json::Value =
|
||||
serde_json::from_str(&tool_calls[1].function.arguments).unwrap();
|
||||
assert_eq!(args2["id"], "12345");
|
||||
}
|
||||
|
||||
|
|
@ -856,7 +885,7 @@ mod tests {
|
|||
},
|
||||
ContentBlock::Text {
|
||||
text: "Second part.".to_string(),
|
||||
}
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
|
|
@ -876,10 +905,17 @@ mod tests {
|
|||
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
|
||||
|
||||
// Content should be combined text parts (no separator added)
|
||||
assert_eq!(openai_response.choices[0].message.content, Some("First part. Second part.".to_string()));
|
||||
assert_eq!(
|
||||
openai_response.choices[0].message.content,
|
||||
Some("First part. Second part.".to_string())
|
||||
);
|
||||
|
||||
// Should have one tool call
|
||||
let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap();
|
||||
let tool_calls = openai_response.choices[0]
|
||||
.message
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].id, "tool_mid");
|
||||
assert_eq!(tool_calls[0].function.name, "calculate");
|
||||
|
|
@ -891,15 +927,13 @@ mod tests {
|
|||
output: ConverseOutput::Message {
|
||||
message: BedrockMessage {
|
||||
role: ConversationRole::Assistant,
|
||||
content: vec![
|
||||
ContentBlock::ToolUse {
|
||||
tool_use: crate::apis::amazon_bedrock::ToolUseBlock {
|
||||
tool_use_id: "tool_only".to_string(),
|
||||
name: "action".to_string(),
|
||||
input: json!({}),
|
||||
},
|
||||
}
|
||||
],
|
||||
content: vec![ContentBlock::ToolUse {
|
||||
tool_use: crate::apis::amazon_bedrock::ToolUseBlock {
|
||||
tool_use_id: "tool_only".to_string(),
|
||||
name: "action".to_string(),
|
||||
input: json!({}),
|
||||
},
|
||||
}],
|
||||
},
|
||||
},
|
||||
stop_reason: StopReason::ToolUse,
|
||||
|
|
@ -921,7 +955,11 @@ mod tests {
|
|||
assert_eq!(openai_response.choices[0].message.content, None);
|
||||
|
||||
// Should have tool call
|
||||
let tool_calls = openai_response.choices[0].message.tool_calls.as_ref().unwrap();
|
||||
let tool_calls = openai_response.choices[0]
|
||||
.message
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].id, "tool_only");
|
||||
}
|
||||
|
|
@ -940,7 +978,7 @@ mod tests {
|
|||
name: "test_function".to_string(),
|
||||
input: json!({"param": "value"}),
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
|
|
@ -953,7 +991,8 @@ mod tests {
|
|||
assert_eq!(tool_calls[0].id, "test_tool");
|
||||
assert_eq!(tool_calls[0].function.name, "test_function");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&tool_calls[0].function.arguments).unwrap();
|
||||
let args: serde_json::Value =
|
||||
serde_json::from_str(&tool_calls[0].function.arguments).unwrap();
|
||||
assert_eq!(args["param"], "value");
|
||||
}
|
||||
|
||||
|
|
@ -964,11 +1003,9 @@ mod tests {
|
|||
output: ConverseOutput::Message {
|
||||
message: BedrockMessage {
|
||||
role: ConversationRole::Assistant,
|
||||
content: vec![
|
||||
ContentBlock::Text {
|
||||
text: "I am an assistant".to_string(),
|
||||
}
|
||||
],
|
||||
content: vec![ContentBlock::Text {
|
||||
text: "I am an assistant".to_string(),
|
||||
}],
|
||||
},
|
||||
},
|
||||
stop_reason: StopReason::EndTurn,
|
||||
|
|
@ -992,11 +1029,9 @@ mod tests {
|
|||
output: ConverseOutput::Message {
|
||||
message: BedrockMessage {
|
||||
role: ConversationRole::User,
|
||||
content: vec![
|
||||
ContentBlock::Text {
|
||||
text: "I am a user".to_string(),
|
||||
}
|
||||
],
|
||||
content: vec![ContentBlock::Text {
|
||||
text: "I am a user".to_string(),
|
||||
}],
|
||||
},
|
||||
},
|
||||
stop_reason: StopReason::EndTurn,
|
||||
|
|
@ -1049,7 +1084,10 @@ mod tests {
|
|||
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
|
||||
|
||||
// Should extract model ID from trace
|
||||
assert_eq!(openai_response.model, "anthropic.claude-3-sonnet-20240229-v1:0");
|
||||
assert_eq!(
|
||||
openai_response.model,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
);
|
||||
|
||||
// Test fallback when no trace information is available
|
||||
let bedrock_response_no_trace = ConverseResponse {
|
||||
|
|
@ -1074,7 +1112,8 @@ mod tests {
|
|||
performance_config: None,
|
||||
};
|
||||
|
||||
let openai_response_fallback: ChatCompletionsResponse = bedrock_response_no_trace.try_into().unwrap();
|
||||
let openai_response_fallback: ChatCompletionsResponse =
|
||||
bedrock_response_no_trace.try_into().unwrap();
|
||||
|
||||
// Should use fallback model name
|
||||
assert_eq!(openai_response_fallback.model, "bedrock-model");
|
||||
|
|
@ -1082,7 +1121,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_bedrock_to_openai_with_multimedia_content() {
|
||||
use crate::apis::amazon_bedrock::{ImageSource};
|
||||
use crate::apis::amazon_bedrock::ImageSource;
|
||||
|
||||
let bedrock_response = ConverseResponse {
|
||||
output: ConverseOutput::Message {
|
||||
|
|
@ -1118,7 +1157,10 @@ mod tests {
|
|||
|
||||
let openai_response: ChatCompletionsResponse = bedrock_response.try_into().unwrap();
|
||||
|
||||
assert_eq!(openai_response.choices[0].finish_reason, Some(FinishReason::Stop));
|
||||
assert_eq!(
|
||||
openai_response.choices[0].finish_reason,
|
||||
Some(FinishReason::Stop)
|
||||
);
|
||||
|
||||
let content = openai_response.choices[0].message.content.as_ref().unwrap();
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue