add support for agents (#564)

Adil Hafeez 2025-10-14 14:01:11 -07:00 committed by GitHub
parent f8991a3c4b
commit 96e0732089
41 changed files with 3571 additions and 856 deletions


@@ -21,7 +21,10 @@
//! assert!(endpoints.contains(&"/v1/messages"));
//! ```
use crate::{apis::{AnthropicApi, ApiDefinition, OpenAIApi}, ProviderId};
use crate::{
apis::{AnthropicApi, ApiDefinition, OpenAIApi},
ProviderId,
};
use std::fmt;
/// Unified enum representing all supported API endpoints across providers
@@ -34,8 +37,12 @@ pub enum SupportedAPIs {
impl fmt::Display for SupportedAPIs {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SupportedAPIs::OpenAIChatCompletions(api) => write!(f, "OpenAI API ({})", api.endpoint()),
SupportedAPIs::AnthropicMessagesAPI(api) => write!(f, "Anthropic API ({})", api.endpoint()),
SupportedAPIs::OpenAIChatCompletions(api) => {
write!(f, "OpenAI API ({})", api.endpoint())
}
SupportedAPIs::AnthropicMessagesAPI(api) => {
write!(f, "Anthropic API ({})", api.endpoint())
}
}
}
}
@@ -62,61 +69,60 @@ impl SupportedAPIs {
}
}
pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str, model_id: &str) -> String {
pub fn target_endpoint_for_provider(
&self,
provider_id: &ProviderId,
request_path: &str,
model_id: &str,
) -> String {
let default_endpoint = "/v1/chat/completions".to_string();
match self {
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => {
match provider_id {
ProviderId::Anthropic => "/v1/messages".to_string(),
_ => default_endpoint,
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
ProviderId::Anthropic => "/v1/messages".to_string(),
_ => default_endpoint,
},
_ => match provider_id {
ProviderId::Groq => {
if request_path.starts_with("/v1/") {
format!("/openai{}", request_path)
} else {
default_endpoint
}
}
}
_ => {
match provider_id {
ProviderId::Groq => {
if request_path.starts_with("/v1/") {
format!("/openai{}", request_path)
} else {
default_endpoint
}
ProviderId::Zhipu => {
if request_path.starts_with("/v1/") {
"/api/paas/v4/chat/completions".to_string()
} else {
default_endpoint
}
ProviderId::Zhipu => {
if request_path.starts_with("/v1/") {
"/api/paas/v4/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::Qwen => {
if request_path.starts_with("/v1/") {
"/compatible-mode/v1/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::AzureOpenAI => {
if request_path.starts_with("/v1/") {
format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id)
} else {
default_endpoint
}
}
ProviderId::Gemini => {
if request_path.starts_with("/v1/") {
"/v1beta/openai/chat/completions".to_string()
} else {
default_endpoint
}
}
_ => default_endpoint,
}
}
ProviderId::Qwen => {
if request_path.starts_with("/v1/") {
"/compatible-mode/v1/chat/completions".to_string()
} else {
default_endpoint
}
}
ProviderId::AzureOpenAI => {
if request_path.starts_with("/v1/") {
format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id)
} else {
default_endpoint
}
}
ProviderId::Gemini => {
if request_path.starts_with("/v1/") {
"/v1beta/openai/chat/completions".to_string()
} else {
default_endpoint
}
}
_ => default_endpoint,
},
}
}
}
/// Get all supported endpoint paths
pub fn supported_endpoints() -> Vec<&'static str> {
let mut endpoints = Vec::new();
@@ -196,15 +202,26 @@ mod tests {
// All OpenAI endpoints should be in the result
for endpoint in openai_endpoints {
assert!(endpoints.contains(&endpoint), "Missing OpenAI endpoint: {}", endpoint);
assert!(
endpoints.contains(&endpoint),
"Missing OpenAI endpoint: {}",
endpoint
);
}
// All Anthropic endpoints should be in the result
for endpoint in anthropic_endpoints {
assert!(endpoints.contains(&endpoint), "Missing Anthropic endpoint: {}", endpoint);
assert!(
endpoints.contains(&endpoint),
"Missing Anthropic endpoint: {}",
endpoint
);
}
// Total should match
assert_eq!(endpoints.len(), OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len());
assert_eq!(
endpoints.len(),
OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len()
);
}
}


@@ -1,9 +1,9 @@
pub mod endpoints;
pub mod lib;
pub mod transformer;
pub mod endpoints;
// Re-export the main items for easier access
pub use endpoints::{identify_provider, SupportedAPIs};
pub use lib::*;
pub use endpoints::{SupportedAPIs, identify_provider};
// Note: transformer module contains TryFrom trait implementations that are automatically available


@@ -42,10 +42,10 @@
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
use super::TransformError;
use crate::apis::*;
use serde_json::Value;
use std::time::{SystemTime, UNIX_EPOCH};
use crate::apis::*;
use super::TransformError;
// ============================================================================
// CONSTANTS
@@ -66,7 +66,9 @@ pub trait ExtractText {
/// Trait for utility functions on content collections
trait ContentUtils<T> {
fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError>;
fn split_for_openai(&self) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
fn split_for_openai(
&self,
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
}
// ============================================================================
@@ -75,7 +77,6 @@ trait ContentUtils<T> {
type AnthropicMessagesRequest = MessagesRequest;
impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
type Error = TransformError;
@@ -95,7 +96,8 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
// Convert tools and tool choice
let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
let (openai_tool_choice, parallel_tool_calls) = convert_anthropic_tool_choice(req.tool_choice);
let (openai_tool_choice, parallel_tool_calls) =
convert_anthropic_tool_choice(req.tool_choice);
let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest {
model: req.model,
@@ -137,13 +139,15 @@ impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
// Convert tools and tool choice
let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools));
let anthropic_tool_choice = convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
let anthropic_tool_choice =
convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
Ok(AnthropicMessagesRequest {
model: req.model,
system: system_prompt,
messages,
max_tokens: req.max_completion_tokens
max_tokens: req
.max_completion_tokens
.or(req.max_tokens)
.unwrap_or(DEFAULT_MAX_TOKENS),
container: None,
@@ -179,7 +183,11 @@ impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
MessageContent::Text(text) => Some(text),
MessageContent::Parts(parts) => {
let text = parts.extract_text();
if text.is_empty() { None } else { Some(text) }
if text.is_empty() {
None
} else {
Some(text)
}
}
};
@@ -225,11 +233,15 @@ impl TryFrom<ChatCompletionsResponse> for MessagesResponse {
type Error = TransformError;
fn try_from(resp: ChatCompletionsResponse) -> Result<Self, Self::Error> {
let choice = resp.choices.into_iter().next()
let choice = resp
.choices
.into_iter()
.next()
.ok_or_else(|| TransformError::MissingField("choices".to_string()))?;
let content = convert_openai_message_to_anthropic_content(&choice.message.to_message())?;
let stop_reason = choice.finish_reason
let stop_reason = choice
.finish_reason
.map(|fr| fr.into())
.unwrap_or(MessagesStopReason::EndTurn);
@@ -263,33 +275,27 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
fn try_from(event: MessagesStreamEvent) -> Result<Self, Self::Error> {
match event {
MessagesStreamEvent::MessageStart { message } => {
Ok(create_openai_chunk(
&message.id,
&message.model,
MessageDelta {
role: Some(Role::Assistant),
content: None,
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
))
}
MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk(
&message.id,
&message.model,
MessageDelta {
role: Some(Role::Assistant),
content: None,
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
)),
MessagesStreamEvent::ContentBlockStart { content_block, .. } => {
convert_content_block_start(content_block)
}
MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
convert_content_delta(delta)
}
MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta),
MessagesStreamEvent::ContentBlockStop { .. } => {
Ok(create_empty_openai_chunk())
}
MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
MessagesStreamEvent::MessageDelta { delta, usage } => {
let finish_reason: Option<FinishReason> = Some(delta.stop_reason.into());
@@ -310,34 +316,30 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
))
}
MessagesStreamEvent::MessageStop => {
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: None,
},
Some(FinishReason::Stop),
None,
))
}
MessagesStreamEvent::MessageStop => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: None,
},
Some(FinishReason::Stop),
None,
)),
MessagesStreamEvent::Ping => {
Ok(ChatCompletionsStreamResponse {
id: "stream".to_string(),
object: Some("chat.completion.chunk".to_string()),
created: current_timestamp(),
model: "unknown".to_string(),
choices: vec![],
usage: None,
system_fingerprint: None,
service_tier: None,
})
}
MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse {
id: "stream".to_string(),
object: Some("chat.completion.chunk".to_string()),
created: current_timestamp(),
model: "unknown".to_string(),
choices: vec![],
usage: None,
system_fingerprint: None,
service_tier: None,
}),
}
}
}
@@ -442,9 +444,7 @@ impl Into<Message> for MessagesSystemPrompt {
fn into(self) -> Message {
let system_content = match self {
MessagesSystemPrompt::Single(text) => MessageContent::Text(text),
MessagesSystemPrompt::Blocks(blocks) => {
MessageContent::Text(blocks.extract_text())
}
MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()),
};
Message {
@@ -461,7 +461,7 @@ impl Into<MessagesSystemPrompt> for Message {
fn into(self) -> MessagesSystemPrompt {
let system_text = match self.content {
MessageContent::Text(text) => text,
MessageContent::Parts(parts) => parts.extract_text()
MessageContent::Parts(parts) => parts.extract_text(),
};
MessagesSystemPrompt::Single(system_text)
}
@@ -505,7 +505,11 @@ impl TryFrom<MessagesMessage> for Vec<Message> {
role: message.role.into(),
content,
name: None,
tool_calls: if tool_calls.is_empty() { None } else { Some(tool_calls) },
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
tool_call_id: None,
};
result.push(main_message);
@@ -526,8 +530,11 @@ impl TryFrom<Message> for MessagesMessage {
Role::Assistant => MessagesRole::Assistant,
Role::Tool => {
// Tool messages become user messages with tool results
let tool_call_id = message.tool_call_id
.ok_or_else(|| TransformError::MissingField("tool_call_id required for Tool messages".to_string()))?;
let tool_call_id = message.tool_call_id.ok_or_else(|| {
TransformError::MissingField(
"tool_call_id required for Tool messages".to_string(),
)
})?;
return Ok(MessagesMessage {
role: MessagesRole::User,
@@ -545,7 +552,9 @@
});
}
Role::System => {
return Err(TransformError::UnsupportedConversion("System messages should be handled separately".to_string()));
return Err(TransformError::UnsupportedConversion(
"System messages should be handled separately".to_string(),
));
}
};
@@ -573,24 +582,36 @@ impl ContentUtils<ToolCall> for Vec<MessagesContentBlock> {
for block in self {
match block {
MessagesContentBlock::ToolUse { id, name, input, .. } |
MessagesContentBlock::ServerToolUse { id, name, input } |
MessagesContentBlock::McpToolUse { id, name, input } => {
MessagesContentBlock::ToolUse {
id, name, input, ..
}
| MessagesContentBlock::ServerToolUse { id, name, input }
| MessagesContentBlock::McpToolUse { id, name, input } => {
let arguments = serde_json::to_string(&input)?;
tool_calls.push(ToolCall {
id: id.clone(),
call_type: "function".to_string(),
function: FunctionCall { name: name.clone(), arguments },
function: FunctionCall {
name: name.clone(),
arguments,
},
});
}
_ => continue,
}
}
Ok(if tool_calls.is_empty() { None } else { Some(tool_calls) })
Ok(if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
})
}
fn split_for_openai(&self) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError> {
fn split_for_openai(
&self,
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>
{
let mut content_parts = Vec::new();
let mut tool_calls = Vec::new();
let mut tool_results = Vec::new();
@@ -609,25 +630,55 @@ impl ContentUtils<ToolCall> for Vec<MessagesContentBlock> {
},
});
}
MessagesContentBlock::ToolUse { id, name, input, .. } |
MessagesContentBlock::ServerToolUse { id, name, input } |
MessagesContentBlock::McpToolUse { id, name, input } => {
MessagesContentBlock::ToolUse {
id, name, input, ..
}
| MessagesContentBlock::ServerToolUse { id, name, input }
| MessagesContentBlock::McpToolUse { id, name, input } => {
let arguments = serde_json::to_string(&input)?;
tool_calls.push(ToolCall {
id: id.clone(),
call_type: "function".to_string(),
function: FunctionCall { name: name.clone(), arguments },
function: FunctionCall {
name: name.clone(),
arguments,
},
});
}
MessagesContentBlock::ToolResult { tool_use_id, content, is_error, .. } => {
MessagesContentBlock::ToolResult {
tool_use_id,
content,
is_error,
..
} => {
let result_text = content.extract_text();
tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false)));
tool_results.push((
tool_use_id.clone(),
result_text,
is_error.unwrap_or(false),
));
}
MessagesContentBlock::WebSearchToolResult { tool_use_id, content, is_error } |
MessagesContentBlock::CodeExecutionToolResult { tool_use_id, content, is_error } |
MessagesContentBlock::McpToolResult { tool_use_id, content, is_error } => {
MessagesContentBlock::WebSearchToolResult {
tool_use_id,
content,
is_error,
}
| MessagesContentBlock::CodeExecutionToolResult {
tool_use_id,
content,
is_error,
}
| MessagesContentBlock::McpToolResult {
tool_use_id,
content,
is_error,
} => {
let result_text = content.extract_text();
tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false)));
tool_results.push((
tool_use_id.clone(),
result_text,
is_error.unwrap_or(false),
));
}
_ => {
// Skip unsupported content types
@@ -696,7 +747,10 @@ impl Into<MessagesUsage> for Usage {
/// Helper to create a current unix timestamp
fn current_timestamp() -> u64 {
SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs()
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
}
/// Helper to create OpenAI streaming chunk
@@ -705,7 +759,7 @@ fn create_openai_chunk(
model: &str,
delta: MessageDelta,
finish_reason: Option<FinishReason>,
usage: Option<Usage>
usage: Option<Usage>,
) -> ChatCompletionsStreamResponse {
ChatCompletionsStreamResponse {
id: id.to_string(),
@@ -743,7 +797,8 @@ fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse {
/// Convert Anthropic tools to OpenAI format
fn convert_anthropic_tools(tools: Vec<MessagesTool>) -> Vec<Tool> {
tools.into_iter()
tools
.into_iter()
.map(|tool| Tool {
tool_type: "function".to_string(),
function: Function {
@@ -758,7 +813,8 @@ fn convert_anthropic_tools(tools: Vec<MessagesTool>) -> Vec<Tool> {
/// Convert OpenAI tools to Anthropic format
fn convert_openai_tools(tools: Vec<Tool>) -> Vec<MessagesTool> {
tools.into_iter()
tools
.into_iter()
.map(|tool| MessagesTool {
name: tool.function.name,
description: tool.function.description,
@@ -768,7 +824,9 @@ fn convert_openai_tools(tools: Vec<Tool>) -> Vec<MessagesTool> {
}
/// Convert Anthropic tool choice to OpenAI format
fn convert_anthropic_tool_choice(tool_choice: Option<MessagesToolChoice>) -> (Option<ToolChoice>, Option<bool>) {
fn convert_anthropic_tool_choice(
tool_choice: Option<MessagesToolChoice>,
) -> (Option<ToolChoice>, Option<bool>) {
match tool_choice {
Some(choice) => {
let openai_choice = match choice.kind {
@@ -789,45 +847,46 @@ fn convert_anthropic_tool_choice(tool_choice: Option<MessagesToolChoice>) -> (Op
let parallel = choice.disable_parallel_tool_use.map(|disable| !disable);
(Some(openai_choice), parallel)
}
None => (None, None)
None => (None, None),
}
}
/// Convert OpenAI tool choice to Anthropic format
fn convert_openai_tool_choice(
tool_choice: Option<ToolChoice>,
parallel_tool_calls: Option<bool>
parallel_tool_calls: Option<bool>,
) -> Option<MessagesToolChoice> {
tool_choice.map(|choice| {
match choice {
ToolChoice::Type(tool_type) => match tool_type {
ToolChoiceType::Auto => MessagesToolChoice {
kind: MessagesToolChoiceType::Auto,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
ToolChoiceType::Required => MessagesToolChoice {
kind: MessagesToolChoiceType::Any,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
ToolChoiceType::None => MessagesToolChoice {
kind: MessagesToolChoiceType::None,
name: None,
disable_parallel_tool_use: None,
},
},
ToolChoice::Function { function, .. } => MessagesToolChoice {
kind: MessagesToolChoiceType::Tool,
name: Some(function.name),
tool_choice.map(|choice| match choice {
ToolChoice::Type(tool_type) => match tool_type {
ToolChoiceType::Auto => MessagesToolChoice {
kind: MessagesToolChoiceType::Auto,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
}
ToolChoiceType::Required => MessagesToolChoice {
kind: MessagesToolChoiceType::Any,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
ToolChoiceType::None => MessagesToolChoice {
kind: MessagesToolChoiceType::None,
name: None,
disable_parallel_tool_use: None,
},
},
ToolChoice::Function { function, .. } => MessagesToolChoice {
kind: MessagesToolChoiceType::Tool,
name: Some(function.name),
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
})
}
/// Build OpenAI message content from parts and tool calls
fn build_openai_content(content_parts: Vec<ContentPart>, tool_calls: &[ToolCall]) -> MessageContent {
fn build_openai_content(
content_parts: Vec<ContentPart>,
tool_calls: &[ToolCall],
) -> MessageContent {
if content_parts.len() == 1 && tool_calls.is_empty() {
match &content_parts[0] {
ContentPart::Text { text } => MessageContent::Text(text.clone()),
@@ -855,7 +914,9 @@ fn build_anthropic_content(content_blocks: Vec<MessagesContentBlock>) -> Message
}
/// Convert Anthropic content blocks to OpenAI message content
fn convert_anthropic_content_to_openai(content: &[MessagesContentBlock]) -> Result<MessageContent, TransformError> {
fn convert_anthropic_content_to_openai(
content: &[MessagesContentBlock],
) -> Result<MessageContent, TransformError> {
let mut text_parts = Vec::new();
for block in content {
@@ -877,21 +938,29 @@ fn convert_anthropic_content_to_openai(content: &[MessagesContentBlock]) -> Resu
}
/// Convert OpenAI message to Anthropic content blocks
fn convert_openai_message_to_anthropic_content(message: &Message) -> Result<Vec<MessagesContentBlock>, TransformError> {
fn convert_openai_message_to_anthropic_content(
message: &Message,
) -> Result<Vec<MessagesContentBlock>, TransformError> {
let mut blocks = Vec::new();
// Handle regular content
match &message.content {
MessageContent::Text(text) => {
if !text.is_empty() {
blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None });
blocks.push(MessagesContentBlock::Text {
text: text.clone(),
cache_control: None,
});
}
}
MessageContent::Parts(parts) => {
for part in parts {
match part {
ContentPart::Text { text } => {
blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None });
blocks.push(MessagesContentBlock::Text {
text: text.clone(),
cache_control: None,
});
}
ContentPart::ImageUrl { image_url } => {
let source = convert_image_url_to_source(image_url);
@@ -947,23 +1016,29 @@ fn convert_image_url_to_source(image_url: &ImageUrl) -> MessagesImageSource {
data: data.to_string(),
}
} else {
MessagesImageSource::Url { url: image_url.url.clone() }
MessagesImageSource::Url {
url: image_url.url.clone(),
}
}
} else {
MessagesImageSource::Url { url: image_url.url.clone() }
MessagesImageSource::Url {
url: image_url.url.clone(),
}
}
}
/// Convert content block start to OpenAI chunk
fn convert_content_block_start(content_block: MessagesContentBlock) -> Result<ChatCompletionsStreamResponse, TransformError> {
fn convert_content_block_start(
content_block: MessagesContentBlock,
) -> Result<ChatCompletionsStreamResponse, TransformError> {
match content_block {
MessagesContentBlock::Text { .. } => {
// No immediate output for text block start
Ok(create_empty_openai_chunk())
}
MessagesContentBlock::ToolUse { id, name, .. } |
MessagesContentBlock::ServerToolUse { id, name, .. } |
MessagesContentBlock::McpToolUse { id, name, .. } => {
MessagesContentBlock::ToolUse { id, name, .. }
| MessagesContentBlock::ServerToolUse { id, name, .. }
| MessagesContentBlock::McpToolUse { id, name, .. } => {
// Tool use start → OpenAI chunk with tool_calls
Ok(create_openai_chunk(
"stream",
@@ -987,71 +1062,71 @@
None,
))
}
_ => Err(TransformError::UnsupportedContent("Unsupported content block type in stream start".to_string())),
_ => Err(TransformError::UnsupportedContent(
"Unsupported content block type in stream start".to_string(),
)),
}
}
/// Convert content delta to OpenAI chunk
fn convert_content_delta(delta: MessagesContentDelta) -> Result<ChatCompletionsStreamResponse, TransformError> {
fn convert_content_delta(
delta: MessagesContentDelta,
) -> Result<ChatCompletionsStreamResponse, TransformError> {
match delta {
MessagesContentDelta::TextDelta { text } => {
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: Some(text),
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
))
}
MessagesContentDelta::ThinkingDelta { thinking } => {
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: Some(format!("thinking: {}", thinking)),
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
))
}
MessagesContentDelta::InputJsonDelta { partial_json } => {
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: 0,
id: None,
call_type: None,
function: Some(FunctionCallDelta {
name: None,
arguments: Some(partial_json),
}),
}]),
},
None,
None,
))
}
MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: Some(text),
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
)),
MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: Some(format!("thinking: {}", thinking)),
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
)),
MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: 0,
id: None,
call_type: None,
function: Some(FunctionCallDelta {
name: None,
arguments: Some(partial_json),
}),
}]),
},
None,
None,
)),
}
}
/// Convert tool call deltas to Anthropic stream events
fn convert_tool_call_deltas(tool_calls: Vec<ToolCallDelta>) -> Result<MessagesStreamEvent, TransformError> {
fn convert_tool_call_deltas(
tool_calls: Vec<ToolCallDelta>,
) -> Result<MessagesStreamEvent, TransformError> {
for tool_call in tool_calls {
if let Some(id) = &tool_call.id {
// Tool call start
@@ -1160,11 +1235,20 @@ mod tests {
// Check key fields are preserved
assert_eq!(original_anthropic.model, roundtrip_anthropic.model);
assert_eq!(original_anthropic.max_tokens, roundtrip_anthropic.max_tokens);
assert_eq!(original_anthropic.temperature, roundtrip_anthropic.temperature);
assert_eq!(
original_anthropic.max_tokens,
roundtrip_anthropic.max_tokens
);
assert_eq!(
original_anthropic.temperature,
roundtrip_anthropic.temperature
);
assert_eq!(original_anthropic.top_p, roundtrip_anthropic.top_p);
assert_eq!(original_anthropic.stream, roundtrip_anthropic.stream);
assert_eq!(original_anthropic.messages.len(), roundtrip_anthropic.messages.len());
assert_eq!(
original_anthropic.messages.len(),
roundtrip_anthropic.messages.len()
);
}
#[test]
@@ -1308,7 +1392,10 @@ mod tests {
let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id, Some("call_123".to_string()));
assert_eq!(tool_calls[0].function.as_ref().unwrap().name, Some("get_weather".to_string()));
assert_eq!(
tool_calls[0].function.as_ref().unwrap().name,
Some("get_weather".to_string())
);
}
#[test]
@@ -1328,7 +1415,10 @@
let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].function.as_ref().unwrap().arguments, Some(r#"{"location": "San Francisco"#.to_string()));
assert_eq!(
tool_calls[0].function.as_ref().unwrap().arguments,
Some(r#"{"location": "San Francisco"#.to_string())
);
}
#[test]
@@ -1491,7 +1581,10 @@ mod tests {
let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap();
match anthropic_event {
MessagesStreamEvent::ContentBlockStart { index, content_block } => {
MessagesStreamEvent::ContentBlockStart {
index,
content_block,
} => {
assert_eq!(index, 0);
match content_block {
MessagesContentBlock::ToolUse { id, name, .. } => {
@@ -1634,16 +1727,28 @@ mod tests {
// Verify tool start
let tool_calls = &openai_start.choices[0].delta.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls[0].id, Some("call_weather".to_string()));
assert_eq!(tool_calls[0].function.as_ref().unwrap().name, Some("get_weather".to_string()));
assert_eq!(
tool_calls[0].function.as_ref().unwrap().name,
Some("get_weather".to_string())
);
// Verify argument deltas
let args1 = &openai_delta1.choices[0].delta.tool_calls.as_ref().unwrap()[0]
.function.as_ref().unwrap().arguments;
.function
.as_ref()
.unwrap()
.arguments;
assert_eq!(args1, &Some(r#"{"location": "#.to_string()));
let args2 = &openai_delta2.choices[0].delta.tool_calls.as_ref().unwrap()[0]
.function.as_ref().unwrap().arguments;
assert_eq!(args2, &Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string()));
.function
.as_ref()
.unwrap()
.arguments;
assert_eq!(
args2,
&Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string())
);
}
#[test]
@@ -1671,14 +1776,23 @@ mod tests {
};
let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap();
assert_eq!(openai_resp.choices[0].finish_reason, Some(expected_openai_reason));
assert_eq!(
openai_resp.choices[0].finish_reason,
Some(expected_openai_reason)
);
// Test reverse conversion
let roundtrip_event: MessagesStreamEvent = openai_resp.try_into().unwrap();
match roundtrip_event {
MessagesStreamEvent::MessageDelta { delta, .. } => {
// Note: Some precision may be lost in roundtrip due to mapping differences
assert!(matches!(delta.stop_reason, MessagesStopReason::EndTurn | MessagesStopReason::MaxTokens | MessagesStopReason::ToolUse | MessagesStopReason::StopSequence));
assert!(matches!(
delta.stop_reason,
MessagesStopReason::EndTurn
| MessagesStopReason::MaxTokens
| MessagesStopReason::ToolUse
| MessagesStopReason::StopSequence
));
}
_ => panic!("Expected MessageDelta after roundtrip"),
}
@@ -1711,7 +1825,8 @@ mod tests {
};
// Should convert to Ping when no meaningful content
let anthropic_event: MessagesStreamEvent = openai_resp_with_missing_data.try_into().unwrap();
let anthropic_event: MessagesStreamEvent =
openai_resp_with_missing_data.try_into().unwrap();
assert!(matches!(anthropic_event, MessagesStreamEvent::Ping));
}