mirror of
https://github.com/katanemo/plano.git
synced 2026-06-05 14:45:15 +02:00
add support for agents (#564)
This commit is contained in:
parent
f8991a3c4b
commit
96e0732089
41 changed files with 3571 additions and 856 deletions
|
|
@ -5,10 +5,10 @@ use serde_with::skip_serializing_none;
|
|||
use std::collections::HashMap;
|
||||
|
||||
use super::ApiDefinition;
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use crate::{MESSAGES_PATH};
|
||||
use crate::MESSAGES_PATH;
|
||||
|
||||
// Enum for all supported Anthropic APIs
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
|
|
@ -52,9 +52,7 @@ impl ApiDefinition for AnthropicApi {
|
|||
}
|
||||
|
||||
fn all_variants() -> Vec<Self> {
|
||||
vec![
|
||||
AnthropicApi::Messages,
|
||||
]
|
||||
vec![AnthropicApi::Messages]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -100,7 +98,6 @@ pub struct McpServer {
|
|||
pub tool_configuration: Option<McpToolConfiguration>,
|
||||
}
|
||||
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct MessagesRequest {
|
||||
|
|
@ -121,10 +118,8 @@ pub struct MessagesRequest {
|
|||
pub stop_sequences: Option<Vec<String>>,
|
||||
pub tools: Option<Vec<MessagesTool>>,
|
||||
pub tool_choice: Option<MessagesToolChoice>,
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Messages API specific types
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
|
|
@ -235,34 +230,21 @@ impl ExtractText for Vec<MessagesContentBlock> {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[serde(tag = "type")]
|
||||
pub enum MessagesImageSource {
|
||||
Base64 {
|
||||
media_type: String,
|
||||
data: String,
|
||||
},
|
||||
Url {
|
||||
url: String,
|
||||
},
|
||||
Base64 { media_type: String, data: String },
|
||||
Url { url: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[serde(tag = "type")]
|
||||
pub enum MessagesDocumentSource {
|
||||
Base64 {
|
||||
media_type: String,
|
||||
data: String,
|
||||
},
|
||||
Url {
|
||||
url: String,
|
||||
},
|
||||
File {
|
||||
file_id: String,
|
||||
},
|
||||
Base64 { media_type: String, data: String },
|
||||
Url { url: String },
|
||||
File { file_id: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
|
|
@ -276,7 +258,7 @@ impl ExtractText for MessagesMessageContent {
|
|||
fn extract_text(&self) -> String {
|
||||
match self {
|
||||
MessagesMessageContent::Single(text) => text.clone(),
|
||||
MessagesMessageContent::Blocks(parts) => parts.extract_text()
|
||||
MessagesMessageContent::Blocks(parts) => parts.extract_text(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -320,7 +302,6 @@ pub struct MessagesToolChoice {
|
|||
pub disable_parallel_tool_use: Option<bool>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MessagesStopReason {
|
||||
|
|
@ -457,7 +438,11 @@ impl ProviderResponse for MessagesResponse {
|
|||
Some(self)
|
||||
}
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
Some((self.usage.input_tokens as usize, self.usage.output_tokens as usize, (self.usage.input_tokens + self.usage.output_tokens) as usize))
|
||||
Some((
|
||||
self.usage.input_tokens as usize,
|
||||
self.usage.output_tokens as usize,
|
||||
(self.usage.input_tokens + self.usage.output_tokens) as usize,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -535,7 +520,7 @@ impl ProviderRequest for MessagesRequest {
|
|||
}
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>> {
|
||||
return &self.metadata;
|
||||
return &self.metadata;
|
||||
}
|
||||
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool {
|
||||
|
|
@ -572,13 +557,11 @@ impl MessagesRole {
|
|||
impl ProviderStreamResponse for MessagesStreamEvent {
|
||||
fn content_delta(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
|
||||
match delta {
|
||||
MessagesContentDelta::TextDelta { text } => Some(text),
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => match delta {
|
||||
MessagesContentDelta::TextDelta { text } => Some(text),
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
@ -627,7 +610,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate required fields are properly set
|
||||
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
|
||||
|
|
@ -687,7 +671,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate required fields
|
||||
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
|
||||
|
|
@ -730,7 +715,10 @@ mod tests {
|
|||
assert_eq!(serialized_json["messages"], original_json["messages"]);
|
||||
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
|
||||
assert_eq!(serialized_json["system"], original_json["system"]);
|
||||
assert_eq!(serialized_json["service_tier"], original_json["service_tier"]);
|
||||
assert_eq!(
|
||||
serialized_json["service_tier"],
|
||||
original_json["service_tier"]
|
||||
);
|
||||
assert_eq!(serialized_json["thinking"], original_json["thinking"]);
|
||||
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
|
||||
|
||||
|
|
@ -818,7 +806,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate top-level fields
|
||||
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
|
||||
|
|
@ -833,7 +822,10 @@ mod tests {
|
|||
|
||||
// Validate text content block
|
||||
if let MessagesContentBlock::Text { text, .. } = &content_blocks[0] {
|
||||
assert_eq!(text, "What can you see in this image and what's the weather like?");
|
||||
assert_eq!(
|
||||
text,
|
||||
"What can you see in this image and what's the weather like?"
|
||||
);
|
||||
} else {
|
||||
panic!("Expected text content block");
|
||||
}
|
||||
|
|
@ -861,20 +853,32 @@ mod tests {
|
|||
|
||||
// Validate thinking content block
|
||||
if let MessagesContentBlock::Thinking { thinking, .. } = &content_blocks[0] {
|
||||
assert_eq!(thinking, "Let me analyze the image and then check the weather...");
|
||||
assert_eq!(
|
||||
thinking,
|
||||
"Let me analyze the image and then check the weather..."
|
||||
);
|
||||
} else {
|
||||
panic!("Expected thinking content block");
|
||||
}
|
||||
|
||||
// Validate text content block
|
||||
if let MessagesContentBlock::Text { text, .. } = &content_blocks[1] {
|
||||
assert_eq!(text, "I can see the image. Let me check the weather for you.");
|
||||
assert_eq!(
|
||||
text,
|
||||
"I can see the image. Let me check the weather for you."
|
||||
);
|
||||
} else {
|
||||
panic!("Expected text content block");
|
||||
}
|
||||
|
||||
// Validate tool use content block
|
||||
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = content_blocks[2] {
|
||||
if let MessagesContentBlock::ToolUse {
|
||||
ref id,
|
||||
ref name,
|
||||
ref input,
|
||||
..
|
||||
} = content_blocks[2]
|
||||
{
|
||||
assert_eq!(id, "toolu_weather123");
|
||||
assert_eq!(name, "get_weather");
|
||||
assert_eq!(input["location"], "San Francisco, CA");
|
||||
|
|
@ -892,7 +896,10 @@ mod tests {
|
|||
|
||||
let tool = &tools[0];
|
||||
assert_eq!(tool.name, "get_weather");
|
||||
assert_eq!(tool.description, Some("Get current weather information for a location".to_string()));
|
||||
assert_eq!(
|
||||
tool.description,
|
||||
Some("Get current weather information for a location".to_string())
|
||||
);
|
||||
assert_eq!(tool.input_schema["type"], "object");
|
||||
assert!(tool.input_schema["properties"]["location"].is_object());
|
||||
|
||||
|
|
@ -938,10 +945,16 @@ mod tests {
|
|||
assert_eq!(deserialized_mcp.name, "test-server");
|
||||
assert_eq!(deserialized_mcp.server_type, McpServerType::Url);
|
||||
assert_eq!(deserialized_mcp.url, "https://example.com/mcp");
|
||||
assert_eq!(deserialized_mcp.authorization_token, Some("secret-token".to_string()));
|
||||
assert_eq!(
|
||||
deserialized_mcp.authorization_token,
|
||||
Some("secret-token".to_string())
|
||||
);
|
||||
|
||||
if let Some(tool_config) = &deserialized_mcp.tool_configuration {
|
||||
assert_eq!(tool_config.allowed_tools, Some(vec!["tool1".to_string(), "tool2".to_string()]));
|
||||
assert_eq!(
|
||||
tool_config.allowed_tools,
|
||||
Some(vec!["tool1".to_string(), "tool2".to_string()])
|
||||
);
|
||||
assert_eq!(tool_config.enabled, Some(true));
|
||||
} else {
|
||||
panic!("Expected tool configuration");
|
||||
|
|
@ -957,7 +970,8 @@ mod tests {
|
|||
"url": "https://minimal.com/mcp"
|
||||
});
|
||||
|
||||
let deserialized_minimal: McpServer = serde_json::from_value(minimal_mcp_json.clone()).unwrap();
|
||||
let deserialized_minimal: McpServer =
|
||||
serde_json::from_value(minimal_mcp_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_minimal.name, "minimal-server");
|
||||
assert_eq!(deserialized_minimal.server_type, McpServerType::Url);
|
||||
assert_eq!(deserialized_minimal.url, "https://minimal.com/mcp");
|
||||
|
|
@ -991,12 +1005,16 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_response: MessagesResponse = serde_json::from_value(response_json.clone()).unwrap();
|
||||
let deserialized_response: MessagesResponse =
|
||||
serde_json::from_value(response_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_response.id, "msg_01ABC123");
|
||||
assert_eq!(deserialized_response.obj_type, "message");
|
||||
assert_eq!(deserialized_response.role, MessagesRole::Assistant);
|
||||
assert_eq!(deserialized_response.model, "claude-3-sonnet-20240229");
|
||||
assert_eq!(deserialized_response.stop_reason, MessagesStopReason::EndTurn);
|
||||
assert_eq!(
|
||||
deserialized_response.stop_reason,
|
||||
MessagesStopReason::EndTurn
|
||||
);
|
||||
assert!(deserialized_response.stop_sequence.is_none());
|
||||
assert!(deserialized_response.container.is_none());
|
||||
|
||||
|
|
@ -1011,7 +1029,10 @@ mod tests {
|
|||
// Check usage
|
||||
assert_eq!(deserialized_response.usage.input_tokens, 10);
|
||||
assert_eq!(deserialized_response.usage.output_tokens, 25);
|
||||
assert_eq!(deserialized_response.usage.cache_creation_input_tokens, Some(5));
|
||||
assert_eq!(
|
||||
deserialized_response.usage.cache_creation_input_tokens,
|
||||
Some(5)
|
||||
);
|
||||
assert_eq!(deserialized_response.usage.cache_read_input_tokens, Some(3));
|
||||
|
||||
let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap();
|
||||
|
|
@ -1027,7 +1048,8 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_event: MessagesStreamEvent = serde_json::from_value(stream_event_json.clone()).unwrap();
|
||||
let deserialized_event: MessagesStreamEvent =
|
||||
serde_json::from_value(stream_event_json.clone()).unwrap();
|
||||
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
|
||||
assert_eq!(index, 0);
|
||||
if let MessagesContentDelta::TextDelta { text } = delta {
|
||||
|
|
@ -1055,8 +1077,15 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_tool_use: MessagesContentBlock = serde_json::from_value(tool_use_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = deserialized_tool_use {
|
||||
let deserialized_tool_use: MessagesContentBlock =
|
||||
serde_json::from_value(tool_use_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolUse {
|
||||
ref id,
|
||||
ref name,
|
||||
ref input,
|
||||
..
|
||||
} = deserialized_tool_use
|
||||
{
|
||||
assert_eq!(id, "toolu_01ABC123");
|
||||
assert_eq!(name, "get_weather");
|
||||
assert_eq!(input["location"], "San Francisco, CA");
|
||||
|
|
@ -1079,8 +1108,15 @@ mod tests {
|
|||
]
|
||||
});
|
||||
|
||||
let deserialized_tool_result: MessagesContentBlock = serde_json::from_value(tool_result_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolResult { ref tool_use_id, ref is_error, ref content, .. } = deserialized_tool_result {
|
||||
let deserialized_tool_result: MessagesContentBlock =
|
||||
serde_json::from_value(tool_result_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolResult {
|
||||
ref tool_use_id,
|
||||
ref is_error,
|
||||
ref content,
|
||||
..
|
||||
} = deserialized_tool_result
|
||||
{
|
||||
assert_eq!(tool_use_id, "toolu_01ABC123");
|
||||
assert!(is_error.is_none());
|
||||
if let ToolResultContent::Blocks(blocks) = content {
|
||||
|
|
@ -1229,7 +1265,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize the complex MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(complex_request_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(complex_request_json.clone()).unwrap();
|
||||
|
||||
// Verify basic fields
|
||||
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
|
||||
|
|
@ -1239,8 +1276,15 @@ mod tests {
|
|||
// Verify system message with cache_control
|
||||
if let Some(MessagesSystemPrompt::Blocks(ref system_blocks)) = deserialized_request.system {
|
||||
assert_eq!(system_blocks.len(), 2);
|
||||
if let MessagesContentBlock::Text { text, cache_control } = &system_blocks[0] {
|
||||
assert_eq!(text, "You are Claude Code, Anthropic's official CLI for Claude.");
|
||||
if let MessagesContentBlock::Text {
|
||||
text,
|
||||
cache_control,
|
||||
} = &system_blocks[0]
|
||||
{
|
||||
assert_eq!(
|
||||
text,
|
||||
"You are Claude Code, Anthropic's official CLI for Claude."
|
||||
);
|
||||
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
|
||||
} else {
|
||||
panic!("Expected text system message with cache_control");
|
||||
|
|
@ -1253,7 +1297,13 @@ mod tests {
|
|||
let assistant_message = &deserialized_request.messages[1];
|
||||
assert_eq!(assistant_message.role, MessagesRole::Assistant);
|
||||
if let MessagesMessageContent::Blocks(ref content_blocks) = assistant_message.content {
|
||||
if let MessagesContentBlock::ToolUse { id, name, input, cache_control } = &content_blocks[0] {
|
||||
if let MessagesContentBlock::ToolUse {
|
||||
id,
|
||||
name,
|
||||
input,
|
||||
cache_control,
|
||||
} = &content_blocks[0]
|
||||
{
|
||||
assert_eq!(id, "call_kV50LtJQKHvvzZui5TW56DUl");
|
||||
assert_eq!(name, "TodoWrite");
|
||||
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
|
||||
|
|
@ -1272,7 +1322,12 @@ mod tests {
|
|||
let user_message = &deserialized_request.messages[2];
|
||||
assert_eq!(user_message.role, MessagesRole::User);
|
||||
if let MessagesMessageContent::Blocks(ref content_blocks) = user_message.content {
|
||||
if let MessagesContentBlock::ToolResult { tool_use_id, content, .. } = &content_blocks[0] {
|
||||
if let MessagesContentBlock::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
..
|
||||
} = &content_blocks[0]
|
||||
{
|
||||
assert_eq!(tool_use_id, "call_kV50LtJQKHvvzZui5TW56DUl");
|
||||
if let ToolResultContent::Text(text) = content {
|
||||
assert!(text.contains("Todos have been modified successfully"));
|
||||
|
|
@ -1284,7 +1339,11 @@ mod tests {
|
|||
}
|
||||
|
||||
// Verify text content with cache_control
|
||||
if let MessagesContentBlock::Text { text, cache_control } = &content_blocks[2] {
|
||||
if let MessagesContentBlock::Text {
|
||||
text,
|
||||
cache_control,
|
||||
} = &content_blocks[2]
|
||||
{
|
||||
assert_eq!(text, "try again");
|
||||
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
|
||||
} else {
|
||||
|
|
@ -1296,11 +1355,15 @@ mod tests {
|
|||
|
||||
// Test serialization round-trip
|
||||
let serialized_request = serde_json::to_value(&deserialized_request).unwrap();
|
||||
let re_deserialized_request: MessagesRequest = serde_json::from_value(serialized_request).unwrap();
|
||||
let re_deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(serialized_request).unwrap();
|
||||
|
||||
// Verify round-trip consistency
|
||||
assert_eq!(deserialized_request.model, re_deserialized_request.model);
|
||||
assert_eq!(deserialized_request.messages.len(), re_deserialized_request.messages.len());
|
||||
assert_eq!(
|
||||
deserialized_request.messages.len(),
|
||||
re_deserialized_request.messages.len()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1339,7 +1402,8 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_event: MessagesStreamEvent = serde_json::from_value(thinking_delta_json.clone()).unwrap();
|
||||
let deserialized_event: MessagesStreamEvent =
|
||||
serde_json::from_value(thinking_delta_json.clone()).unwrap();
|
||||
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
|
||||
assert_eq!(index, 0);
|
||||
if let MessagesContentDelta::ThinkingDelta { thinking } = delta {
|
||||
|
|
@ -1352,7 +1416,10 @@ mod tests {
|
|||
}
|
||||
|
||||
// Test that thinking delta is returned by content_delta()
|
||||
assert_eq!(deserialized_event.content_delta(), Some(".\n\nI need to consider:\n1. Current"));
|
||||
assert_eq!(
|
||||
deserialized_event.content_delta(),
|
||||
Some(".\n\nI need to consider:\n1. Current")
|
||||
);
|
||||
|
||||
let serialized_event_json = serde_json::to_value(&deserialized_event).unwrap();
|
||||
assert_eq!(thinking_delta_json, serialized_event_json);
|
||||
|
|
@ -1376,7 +1443,8 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(request_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(request_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
|
||||
assert_eq!(deserialized_request.max_tokens, 2048);
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ pub mod openai;
|
|||
pub use anthropic::*;
|
||||
pub use openai::*;
|
||||
|
||||
|
||||
pub trait ApiDefinition {
|
||||
/// Returns the endpoint path for this API
|
||||
fn endpoint(&self) -> &'static str;
|
||||
|
|
@ -49,11 +48,7 @@ mod tests {
|
|||
#[test]
|
||||
fn test_api_detection_from_endpoints() {
|
||||
// Test that we can detect APIs from endpoints using the trait
|
||||
let endpoints = vec![
|
||||
CHAT_COMPLETIONS_PATH,
|
||||
MESSAGES_PATH,
|
||||
"/v1/unknown"
|
||||
];
|
||||
let endpoints = vec![CHAT_COMPLETIONS_PATH, MESSAGES_PATH, "/v1/unknown"];
|
||||
|
||||
let mut detected_apis = Vec::new();
|
||||
|
||||
|
|
@ -67,11 +62,14 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
assert_eq!(detected_apis, vec![
|
||||
"OpenAI: ChatCompletions",
|
||||
"Anthropic: Messages",
|
||||
"Unknown API"
|
||||
]);
|
||||
assert_eq!(
|
||||
detected_apis,
|
||||
vec![
|
||||
"OpenAI: ChatCompletions",
|
||||
"Anthropic: Messages",
|
||||
"Unknown API"
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -5,11 +5,11 @@ use std::collections::HashMap;
|
|||
use std::fmt::Display;
|
||||
use thiserror::Error;
|
||||
|
||||
use super::ApiDefinition;
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
|
||||
use super::ApiDefinition;
|
||||
use crate::clients::transformer::{ExtractText};
|
||||
use crate::{CHAT_COMPLETIONS_PATH};
|
||||
use crate::CHAT_COMPLETIONS_PATH;
|
||||
|
||||
// ============================================================================
|
||||
// OPENAI API ENUMERATION
|
||||
|
|
@ -46,7 +46,7 @@ impl ApiDefinition for OpenAIApi {
|
|||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
match self {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => true,
|
||||
}
|
||||
}
|
||||
|
|
@ -58,9 +58,7 @@ impl ApiDefinition for OpenAIApi {
|
|||
}
|
||||
|
||||
fn all_variants() -> Vec<Self> {
|
||||
vec![
|
||||
OpenAIApi::ChatCompletions,
|
||||
]
|
||||
vec![OpenAIApi::ChatCompletions]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -190,7 +188,9 @@ impl ResponseMessage {
|
|||
pub fn to_message(&self) -> Message {
|
||||
Message {
|
||||
role: self.role.clone(),
|
||||
content: self.content.as_ref()
|
||||
content: self
|
||||
.content
|
||||
.as_ref()
|
||||
.map(|s| MessageContent::Text(s.clone()))
|
||||
.unwrap_or(MessageContent::Text(String::new())),
|
||||
name: None, // Response messages don't have names in the same way request messages do
|
||||
|
|
@ -215,7 +215,7 @@ impl ExtractText for MessageContent {
|
|||
fn extract_text(&self) -> String {
|
||||
match self {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts.extract_text()
|
||||
MessageContent::Parts(parts) => parts.extract_text(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -274,7 +274,6 @@ pub struct ImageUrl {
|
|||
|
||||
/// A single message in a chat conversation
|
||||
|
||||
|
||||
/// A tool call made by the assistant
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
|
|
@ -374,7 +373,6 @@ pub enum StaticContentType {
|
|||
Parts(Vec<ContentPart>),
|
||||
}
|
||||
|
||||
|
||||
/// Chat completions API response
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
|
|
@ -496,7 +494,6 @@ pub struct ChatCompletionsStreamResponse {
|
|||
pub service_tier: Option<String>,
|
||||
}
|
||||
|
||||
|
||||
/// A choice in a streaming response
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
|
|
@ -566,7 +563,6 @@ pub struct Models {
|
|||
pub data: Vec<ModelDetail>,
|
||||
}
|
||||
|
||||
|
||||
// Error type for streaming operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OpenAIStreamError {
|
||||
|
|
@ -597,13 +593,13 @@ pub enum OpenAIError {
|
|||
/// Trait Implementations
|
||||
/// ===========================================================================
|
||||
|
||||
|
||||
/// Parameterized conversion for ChatCompletionsRequest
|
||||
impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
||||
type Error = OpenAIStreamError;
|
||||
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
let mut req: ChatCompletionsRequest = serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
|
||||
let mut req: ChatCompletionsRequest =
|
||||
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
|
||||
// Use the centralized suppression logic
|
||||
req.suppress_max_tokens_if_o3();
|
||||
req.fix_temperature_if_gpt5();
|
||||
|
|
@ -651,13 +647,18 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
self.messages.iter().fold(String::new(), |acc, m| {
|
||||
acc + " " + &match &m.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts.iter().map(|part| match part {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
|
||||
}).collect::<Vec<_>>().join(" ")
|
||||
}
|
||||
acc + " "
|
||||
+ &match &m.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts
|
||||
.iter()
|
||||
.map(|part| match part {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(" "),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -721,14 +722,14 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
|||
}
|
||||
|
||||
fn role(&self) -> Option<&str> {
|
||||
self.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.role.as_ref().map(|r| match r {
|
||||
self.choices.first().and_then(|choice| {
|
||||
choice.delta.role.as_ref().map(|r| match r {
|
||||
Role::System => "system",
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::Tool => "tool",
|
||||
}))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn event_type(&self) -> Option<&str> {
|
||||
|
|
@ -736,7 +737,6 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -756,7 +756,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into ChatCompletionsRequest
|
||||
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: ChatCompletionsRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate required fields are properly set
|
||||
assert_eq!(deserialized_request.model, "gpt-4");
|
||||
|
|
@ -799,7 +800,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into ChatCompletionsRequest
|
||||
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: ChatCompletionsRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate required fields
|
||||
assert_eq!(deserialized_request.model, "gpt-4");
|
||||
|
|
@ -836,7 +838,10 @@ mod tests {
|
|||
assert_eq!(serialized_json["messages"], original_json["messages"]);
|
||||
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
|
||||
assert_eq!(serialized_json["stream"], original_json["stream"]);
|
||||
assert_eq!(serialized_json["stream_options"], original_json["stream_options"]);
|
||||
assert_eq!(
|
||||
serialized_json["stream_options"],
|
||||
original_json["stream_options"]
|
||||
);
|
||||
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
|
||||
|
||||
// Handle temperature with floating point tolerance
|
||||
|
|
@ -917,7 +922,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into ChatCompletionsRequest
|
||||
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: ChatCompletionsRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate top-level fields
|
||||
assert_eq!(deserialized_request.model, "gpt-4-vision-preview");
|
||||
|
|
@ -953,7 +959,10 @@ mod tests {
|
|||
let assistant_message = &deserialized_request.messages[1];
|
||||
assert_eq!(assistant_message.role, Role::Assistant);
|
||||
if let MessageContent::Text(text) = &assistant_message.content {
|
||||
assert_eq!(text, "I can see a beautiful cityscape. Let me check the weather for you.");
|
||||
assert_eq!(
|
||||
text,
|
||||
"I can see a beautiful cityscape. Let me check the weather for you."
|
||||
);
|
||||
} else {
|
||||
panic!("Expected text content for assistant message");
|
||||
}
|
||||
|
|
@ -967,7 +976,10 @@ mod tests {
|
|||
assert_eq!(tool_call.id, "call_weather123");
|
||||
assert_eq!(tool_call.call_type, "function");
|
||||
assert_eq!(tool_call.function.name, "get_weather");
|
||||
assert_eq!(tool_call.function.arguments, "{\"location\": \"New York, NY\"}");
|
||||
assert_eq!(
|
||||
tool_call.function.arguments,
|
||||
"{\"location\": \"New York, NY\"}"
|
||||
);
|
||||
|
||||
// Validate third message (tool response)
|
||||
let tool_message = &deserialized_request.messages[2];
|
||||
|
|
@ -977,7 +989,10 @@ mod tests {
|
|||
} else {
|
||||
panic!("Expected text content for tool message");
|
||||
}
|
||||
assert_eq!(tool_message.tool_call_id, Some("call_weather123".to_string()));
|
||||
assert_eq!(
|
||||
tool_message.tool_call_id,
|
||||
Some("call_weather123".to_string())
|
||||
);
|
||||
|
||||
// Validate tools array
|
||||
assert!(deserialized_request.tools.is_some());
|
||||
|
|
@ -987,7 +1002,10 @@ mod tests {
|
|||
let tool = &tools[0];
|
||||
assert_eq!(tool.tool_type, "function");
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
assert_eq!(tool.function.description, Some("Get current weather information for a location".to_string()));
|
||||
assert_eq!(
|
||||
tool.function.description,
|
||||
Some("Get current weather information for a location".to_string())
|
||||
);
|
||||
assert_eq!(tool.function.strict, Some(true));
|
||||
|
||||
// Validate tool parameters schema
|
||||
|
|
@ -1093,7 +1111,8 @@ mod tests {
|
|||
]
|
||||
});
|
||||
|
||||
let deserialized_assistant: Message = serde_json::from_value(assistant_json.clone()).unwrap();
|
||||
let deserialized_assistant: Message =
|
||||
serde_json::from_value(assistant_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_assistant.role, Role::Assistant);
|
||||
if let MessageContent::Text(content) = &deserialized_assistant.content {
|
||||
assert_eq!(content, "I'll help with that.");
|
||||
|
|
@ -1142,9 +1161,13 @@ mod tests {
|
|||
]
|
||||
});
|
||||
|
||||
let deserialized_response: ResponseMessage = serde_json::from_value(response_json.clone()).unwrap();
|
||||
let deserialized_response: ResponseMessage =
|
||||
serde_json::from_value(response_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_response.role, Role::Assistant);
|
||||
assert_eq!(deserialized_response.content, Some("Response content".to_string()));
|
||||
assert_eq!(
|
||||
deserialized_response.content,
|
||||
Some("Response content".to_string())
|
||||
);
|
||||
assert!(deserialized_response.annotations.is_some());
|
||||
assert!(deserialized_response.refusal.is_none());
|
||||
assert!(deserialized_response.function_call.is_none());
|
||||
|
|
@ -1186,7 +1209,10 @@ mod tests {
|
|||
let none_deserialized: ToolChoice = serde_json::from_value(json!("none")).unwrap();
|
||||
|
||||
assert_eq!(auto_deserialized, ToolChoice::Type(ToolChoiceType::Auto));
|
||||
assert_eq!(required_deserialized, ToolChoice::Type(ToolChoiceType::Required));
|
||||
assert_eq!(
|
||||
required_deserialized,
|
||||
ToolChoice::Type(ToolChoiceType::Required)
|
||||
);
|
||||
assert_eq!(none_deserialized, ToolChoice::Type(ToolChoiceType::None));
|
||||
|
||||
// Test that invalid string values fail deserialization (type safety!)
|
||||
|
|
@ -1237,7 +1263,10 @@ mod tests {
|
|||
assert_eq!(response.created, 1756574706);
|
||||
assert_eq!(response.model, "gpt-4o-2024-08-06");
|
||||
assert_eq!(response.service_tier, Some("default".to_string()));
|
||||
assert_eq!(response.system_fingerprint, Some("fp_f33640a400".to_string()));
|
||||
assert_eq!(
|
||||
response.system_fingerprint,
|
||||
Some("fp_f33640a400".to_string())
|
||||
);
|
||||
assert_eq!(response.choices.len(), 1);
|
||||
assert_eq!(response.usage.prompt_tokens, 65);
|
||||
assert_eq!(response.usage.completion_tokens, 184);
|
||||
|
|
|
|||
|
|
@ -21,7 +21,10 @@
|
|||
//! assert!(endpoints.contains(&"/v1/messages"));
|
||||
//! ```
|
||||
|
||||
use crate::{apis::{AnthropicApi, ApiDefinition, OpenAIApi}, ProviderId};
|
||||
use crate::{
|
||||
apis::{AnthropicApi, ApiDefinition, OpenAIApi},
|
||||
ProviderId,
|
||||
};
|
||||
use std::fmt;
|
||||
|
||||
/// Unified enum representing all supported API endpoints across providers
|
||||
|
|
@ -34,8 +37,12 @@ pub enum SupportedAPIs {
|
|||
impl fmt::Display for SupportedAPIs {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
SupportedAPIs::OpenAIChatCompletions(api) => write!(f, "OpenAI API ({})", api.endpoint()),
|
||||
SupportedAPIs::AnthropicMessagesAPI(api) => write!(f, "Anthropic API ({})", api.endpoint()),
|
||||
SupportedAPIs::OpenAIChatCompletions(api) => {
|
||||
write!(f, "OpenAI API ({})", api.endpoint())
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(api) => {
|
||||
write!(f, "Anthropic API ({})", api.endpoint())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -62,61 +69,60 @@ impl SupportedAPIs {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str, model_id: &str) -> String {
|
||||
pub fn target_endpoint_for_provider(
|
||||
&self,
|
||||
provider_id: &ProviderId,
|
||||
request_path: &str,
|
||||
model_id: &str,
|
||||
) -> String {
|
||||
let default_endpoint = "/v1/chat/completions".to_string();
|
||||
match self {
|
||||
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => {
|
||||
match provider_id {
|
||||
ProviderId::Anthropic => "/v1/messages".to_string(),
|
||||
_ => default_endpoint,
|
||||
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
|
||||
ProviderId::Anthropic => "/v1/messages".to_string(),
|
||||
_ => default_endpoint,
|
||||
},
|
||||
_ => match provider_id {
|
||||
ProviderId::Groq => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
format!("/openai{}", request_path)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
match provider_id {
|
||||
ProviderId::Groq => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
format!("/openai{}", request_path)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
ProviderId::Zhipu => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/api/paas/v4/chat/completions".to_string()
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
ProviderId::Zhipu => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/api/paas/v4/chat/completions".to_string()
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::Qwen => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/compatible-mode/v1/chat/completions".to_string()
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::AzureOpenAI => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::Gemini => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/v1beta/openai/chat/completions".to_string()
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
_ => default_endpoint,
|
||||
}
|
||||
}
|
||||
ProviderId::Qwen => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/compatible-mode/v1/chat/completions".to_string()
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::AzureOpenAI => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
format!("/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview", model_id)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::Gemini => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/v1beta/openai/chat/completions".to_string()
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
_ => default_endpoint,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Get all supported endpoint paths
|
||||
pub fn supported_endpoints() -> Vec<&'static str> {
|
||||
let mut endpoints = Vec::new();
|
||||
|
|
@ -196,15 +202,26 @@ mod tests {
|
|||
|
||||
// All OpenAI endpoints should be in the result
|
||||
for endpoint in openai_endpoints {
|
||||
assert!(endpoints.contains(&endpoint), "Missing OpenAI endpoint: {}", endpoint);
|
||||
assert!(
|
||||
endpoints.contains(&endpoint),
|
||||
"Missing OpenAI endpoint: {}",
|
||||
endpoint
|
||||
);
|
||||
}
|
||||
|
||||
// All Anthropic endpoints should be in the result
|
||||
for endpoint in anthropic_endpoints {
|
||||
assert!(endpoints.contains(&endpoint), "Missing Anthropic endpoint: {}", endpoint);
|
||||
assert!(
|
||||
endpoints.contains(&endpoint),
|
||||
"Missing Anthropic endpoint: {}",
|
||||
endpoint
|
||||
);
|
||||
}
|
||||
|
||||
// Total should match
|
||||
assert_eq!(endpoints.len(), OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len());
|
||||
assert_eq!(
|
||||
endpoints.len(),
|
||||
OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
pub mod endpoints;
|
||||
pub mod lib;
|
||||
pub mod transformer;
|
||||
pub mod endpoints;
|
||||
|
||||
// Re-export the main items for easier access
|
||||
pub use endpoints::{identify_provider, SupportedAPIs};
|
||||
pub use lib::*;
|
||||
pub use endpoints::{SupportedAPIs, identify_provider};
|
||||
|
||||
// Note: transformer module contains TryFrom trait implementations that are automatically available
|
||||
|
|
|
|||
|
|
@ -42,10 +42,10 @@
|
|||
//! # Ok::<(), Box<dyn std::error::Error>>(())
|
||||
//! ```
|
||||
|
||||
use super::TransformError;
|
||||
use crate::apis::*;
|
||||
use serde_json::Value;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use crate::apis::*;
|
||||
use super::TransformError;
|
||||
|
||||
// ============================================================================
|
||||
// CONSTANTS
|
||||
|
|
@ -66,7 +66,9 @@ pub trait ExtractText {
|
|||
/// Trait for utility functions on content collections
|
||||
trait ContentUtils<T> {
|
||||
fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError>;
|
||||
fn split_for_openai(&self) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
|
||||
fn split_for_openai(
|
||||
&self,
|
||||
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -75,7 +77,6 @@ trait ContentUtils<T> {
|
|||
|
||||
type AnthropicMessagesRequest = MessagesRequest;
|
||||
|
||||
|
||||
impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
|
|
@ -95,7 +96,8 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
|
|||
|
||||
// Convert tools and tool choice
|
||||
let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
|
||||
let (openai_tool_choice, parallel_tool_calls) = convert_anthropic_tool_choice(req.tool_choice);
|
||||
let (openai_tool_choice, parallel_tool_calls) =
|
||||
convert_anthropic_tool_choice(req.tool_choice);
|
||||
|
||||
let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: req.model,
|
||||
|
|
@ -137,13 +139,15 @@ impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
|
|||
|
||||
// Convert tools and tool choice
|
||||
let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools));
|
||||
let anthropic_tool_choice = convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
|
||||
let anthropic_tool_choice =
|
||||
convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
|
||||
|
||||
Ok(AnthropicMessagesRequest {
|
||||
model: req.model,
|
||||
system: system_prompt,
|
||||
messages,
|
||||
max_tokens: req.max_completion_tokens
|
||||
max_tokens: req
|
||||
.max_completion_tokens
|
||||
.or(req.max_tokens)
|
||||
.unwrap_or(DEFAULT_MAX_TOKENS),
|
||||
container: None,
|
||||
|
|
@ -179,7 +183,11 @@ impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
|
|||
MessageContent::Text(text) => Some(text),
|
||||
MessageContent::Parts(parts) => {
|
||||
let text = parts.extract_text();
|
||||
if text.is_empty() { None } else { Some(text) }
|
||||
if text.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(text)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -225,11 +233,15 @@ impl TryFrom<ChatCompletionsResponse> for MessagesResponse {
|
|||
type Error = TransformError;
|
||||
|
||||
fn try_from(resp: ChatCompletionsResponse) -> Result<Self, Self::Error> {
|
||||
let choice = resp.choices.into_iter().next()
|
||||
let choice = resp
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| TransformError::MissingField("choices".to_string()))?;
|
||||
|
||||
let content = convert_openai_message_to_anthropic_content(&choice.message.to_message())?;
|
||||
let stop_reason = choice.finish_reason
|
||||
let stop_reason = choice
|
||||
.finish_reason
|
||||
.map(|fr| fr.into())
|
||||
.unwrap_or(MessagesStopReason::EndTurn);
|
||||
|
||||
|
|
@ -263,33 +275,27 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
|
|||
|
||||
fn try_from(event: MessagesStreamEvent) -> Result<Self, Self::Error> {
|
||||
match event {
|
||||
MessagesStreamEvent::MessageStart { message } => {
|
||||
Ok(create_openai_chunk(
|
||||
&message.id,
|
||||
&message.model,
|
||||
MessageDelta {
|
||||
role: Some(Role::Assistant),
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk(
|
||||
&message.id,
|
||||
&message.model,
|
||||
MessageDelta {
|
||||
role: Some(Role::Assistant),
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
|
||||
MessagesStreamEvent::ContentBlockStart { content_block, .. } => {
|
||||
convert_content_block_start(content_block)
|
||||
}
|
||||
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
|
||||
convert_content_delta(delta)
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta),
|
||||
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => {
|
||||
Ok(create_empty_openai_chunk())
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
|
||||
|
||||
MessagesStreamEvent::MessageDelta { delta, usage } => {
|
||||
let finish_reason: Option<FinishReason> = Some(delta.stop_reason.into());
|
||||
|
|
@ -310,34 +316,30 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
|
|||
))
|
||||
}
|
||||
|
||||
MessagesStreamEvent::MessageStop => {
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Some(FinishReason::Stop),
|
||||
None,
|
||||
))
|
||||
}
|
||||
MessagesStreamEvent::MessageStop => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Some(FinishReason::Stop),
|
||||
None,
|
||||
)),
|
||||
|
||||
MessagesStreamEvent::Ping => {
|
||||
Ok(ChatCompletionsStreamResponse {
|
||||
id: "stream".to_string(),
|
||||
object: Some("chat.completion.chunk".to_string()),
|
||||
created: current_timestamp(),
|
||||
model: "unknown".to_string(),
|
||||
choices: vec![],
|
||||
usage: None,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
})
|
||||
}
|
||||
MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse {
|
||||
id: "stream".to_string(),
|
||||
object: Some("chat.completion.chunk".to_string()),
|
||||
created: current_timestamp(),
|
||||
model: "unknown".to_string(),
|
||||
choices: vec![],
|
||||
usage: None,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -442,9 +444,7 @@ impl Into<Message> for MessagesSystemPrompt {
|
|||
fn into(self) -> Message {
|
||||
let system_content = match self {
|
||||
MessagesSystemPrompt::Single(text) => MessageContent::Text(text),
|
||||
MessagesSystemPrompt::Blocks(blocks) => {
|
||||
MessageContent::Text(blocks.extract_text())
|
||||
}
|
||||
MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()),
|
||||
};
|
||||
|
||||
Message {
|
||||
|
|
@ -461,7 +461,7 @@ impl Into<MessagesSystemPrompt> for Message {
|
|||
fn into(self) -> MessagesSystemPrompt {
|
||||
let system_text = match self.content {
|
||||
MessageContent::Text(text) => text,
|
||||
MessageContent::Parts(parts) => parts.extract_text()
|
||||
MessageContent::Parts(parts) => parts.extract_text(),
|
||||
};
|
||||
MessagesSystemPrompt::Single(system_text)
|
||||
}
|
||||
|
|
@ -505,7 +505,11 @@ impl TryFrom<MessagesMessage> for Vec<Message> {
|
|||
role: message.role.into(),
|
||||
content,
|
||||
name: None,
|
||||
tool_calls: if tool_calls.is_empty() { None } else { Some(tool_calls) },
|
||||
tool_calls: if tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tool_calls)
|
||||
},
|
||||
tool_call_id: None,
|
||||
};
|
||||
result.push(main_message);
|
||||
|
|
@ -526,8 +530,11 @@ impl TryFrom<Message> for MessagesMessage {
|
|||
Role::Assistant => MessagesRole::Assistant,
|
||||
Role::Tool => {
|
||||
// Tool messages become user messages with tool results
|
||||
let tool_call_id = message.tool_call_id
|
||||
.ok_or_else(|| TransformError::MissingField("tool_call_id required for Tool messages".to_string()))?;
|
||||
let tool_call_id = message.tool_call_id.ok_or_else(|| {
|
||||
TransformError::MissingField(
|
||||
"tool_call_id required for Tool messages".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
return Ok(MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
|
|
@ -545,7 +552,9 @@ impl TryFrom<Message> for MessagesMessage {
|
|||
});
|
||||
}
|
||||
Role::System => {
|
||||
return Err(TransformError::UnsupportedConversion("System messages should be handled separately".to_string()));
|
||||
return Err(TransformError::UnsupportedConversion(
|
||||
"System messages should be handled separately".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -573,24 +582,36 @@ impl ContentUtils<ToolCall> for Vec<MessagesContentBlock> {
|
|||
|
||||
for block in self {
|
||||
match block {
|
||||
MessagesContentBlock::ToolUse { id, name, input, .. } |
|
||||
MessagesContentBlock::ServerToolUse { id, name, input } |
|
||||
MessagesContentBlock::McpToolUse { id, name, input } => {
|
||||
MessagesContentBlock::ToolUse {
|
||||
id, name, input, ..
|
||||
}
|
||||
| MessagesContentBlock::ServerToolUse { id, name, input }
|
||||
| MessagesContentBlock::McpToolUse { id, name, input } => {
|
||||
let arguments = serde_json::to_string(&input)?;
|
||||
tool_calls.push(ToolCall {
|
||||
id: id.clone(),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall { name: name.clone(), arguments },
|
||||
function: FunctionCall {
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(if tool_calls.is_empty() { None } else { Some(tool_calls) })
|
||||
Ok(if tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tool_calls)
|
||||
})
|
||||
}
|
||||
|
||||
fn split_for_openai(&self) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError> {
|
||||
fn split_for_openai(
|
||||
&self,
|
||||
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>
|
||||
{
|
||||
let mut content_parts = Vec::new();
|
||||
let mut tool_calls = Vec::new();
|
||||
let mut tool_results = Vec::new();
|
||||
|
|
@ -609,25 +630,55 @@ impl ContentUtils<ToolCall> for Vec<MessagesContentBlock> {
|
|||
},
|
||||
});
|
||||
}
|
||||
MessagesContentBlock::ToolUse { id, name, input, .. } |
|
||||
MessagesContentBlock::ServerToolUse { id, name, input } |
|
||||
MessagesContentBlock::McpToolUse { id, name, input } => {
|
||||
MessagesContentBlock::ToolUse {
|
||||
id, name, input, ..
|
||||
}
|
||||
| MessagesContentBlock::ServerToolUse { id, name, input }
|
||||
| MessagesContentBlock::McpToolUse { id, name, input } => {
|
||||
let arguments = serde_json::to_string(&input)?;
|
||||
tool_calls.push(ToolCall {
|
||||
id: id.clone(),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall { name: name.clone(), arguments },
|
||||
function: FunctionCall {
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
MessagesContentBlock::ToolResult { tool_use_id, content, is_error, .. } => {
|
||||
MessagesContentBlock::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
..
|
||||
} => {
|
||||
let result_text = content.extract_text();
|
||||
tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false)));
|
||||
tool_results.push((
|
||||
tool_use_id.clone(),
|
||||
result_text,
|
||||
is_error.unwrap_or(false),
|
||||
));
|
||||
}
|
||||
MessagesContentBlock::WebSearchToolResult { tool_use_id, content, is_error } |
|
||||
MessagesContentBlock::CodeExecutionToolResult { tool_use_id, content, is_error } |
|
||||
MessagesContentBlock::McpToolResult { tool_use_id, content, is_error } => {
|
||||
MessagesContentBlock::WebSearchToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
}
|
||||
| MessagesContentBlock::CodeExecutionToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
}
|
||||
| MessagesContentBlock::McpToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
is_error,
|
||||
} => {
|
||||
let result_text = content.extract_text();
|
||||
tool_results.push((tool_use_id.clone(), result_text, is_error.unwrap_or(false)));
|
||||
tool_results.push((
|
||||
tool_use_id.clone(),
|
||||
result_text,
|
||||
is_error.unwrap_or(false),
|
||||
));
|
||||
}
|
||||
_ => {
|
||||
// Skip unsupported content types
|
||||
|
|
@ -696,7 +747,10 @@ impl Into<MessagesUsage> for Usage {
|
|||
|
||||
/// Helper to create a current unix timestamp
|
||||
fn current_timestamp() -> u64 {
|
||||
SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs()
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
/// Helper to create OpenAI streaming chunk
|
||||
|
|
@ -705,7 +759,7 @@ fn create_openai_chunk(
|
|||
model: &str,
|
||||
delta: MessageDelta,
|
||||
finish_reason: Option<FinishReason>,
|
||||
usage: Option<Usage>
|
||||
usage: Option<Usage>,
|
||||
) -> ChatCompletionsStreamResponse {
|
||||
ChatCompletionsStreamResponse {
|
||||
id: id.to_string(),
|
||||
|
|
@ -743,7 +797,8 @@ fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse {
|
|||
|
||||
/// Convert Anthropic tools to OpenAI format
|
||||
fn convert_anthropic_tools(tools: Vec<MessagesTool>) -> Vec<Tool> {
|
||||
tools.into_iter()
|
||||
tools
|
||||
.into_iter()
|
||||
.map(|tool| Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: Function {
|
||||
|
|
@ -758,7 +813,8 @@ fn convert_anthropic_tools(tools: Vec<MessagesTool>) -> Vec<Tool> {
|
|||
|
||||
/// Convert OpenAI tools to Anthropic format
|
||||
fn convert_openai_tools(tools: Vec<Tool>) -> Vec<MessagesTool> {
|
||||
tools.into_iter()
|
||||
tools
|
||||
.into_iter()
|
||||
.map(|tool| MessagesTool {
|
||||
name: tool.function.name,
|
||||
description: tool.function.description,
|
||||
|
|
@ -768,7 +824,9 @@ fn convert_openai_tools(tools: Vec<Tool>) -> Vec<MessagesTool> {
|
|||
}
|
||||
|
||||
/// Convert Anthropic tool choice to OpenAI format
|
||||
fn convert_anthropic_tool_choice(tool_choice: Option<MessagesToolChoice>) -> (Option<ToolChoice>, Option<bool>) {
|
||||
fn convert_anthropic_tool_choice(
|
||||
tool_choice: Option<MessagesToolChoice>,
|
||||
) -> (Option<ToolChoice>, Option<bool>) {
|
||||
match tool_choice {
|
||||
Some(choice) => {
|
||||
let openai_choice = match choice.kind {
|
||||
|
|
@ -789,45 +847,46 @@ fn convert_anthropic_tool_choice(tool_choice: Option<MessagesToolChoice>) -> (Op
|
|||
let parallel = choice.disable_parallel_tool_use.map(|disable| !disable);
|
||||
(Some(openai_choice), parallel)
|
||||
}
|
||||
None => (None, None)
|
||||
None => (None, None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert OpenAI tool choice to Anthropic format
|
||||
fn convert_openai_tool_choice(
|
||||
tool_choice: Option<ToolChoice>,
|
||||
parallel_tool_calls: Option<bool>
|
||||
parallel_tool_calls: Option<bool>,
|
||||
) -> Option<MessagesToolChoice> {
|
||||
tool_choice.map(|choice| {
|
||||
match choice {
|
||||
ToolChoice::Type(tool_type) => match tool_type {
|
||||
ToolChoiceType::Auto => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Auto,
|
||||
name: None,
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
ToolChoiceType::Required => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Any,
|
||||
name: None,
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
ToolChoiceType::None => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::None,
|
||||
name: None,
|
||||
disable_parallel_tool_use: None,
|
||||
},
|
||||
},
|
||||
ToolChoice::Function { function, .. } => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Tool,
|
||||
name: Some(function.name),
|
||||
tool_choice.map(|choice| match choice {
|
||||
ToolChoice::Type(tool_type) => match tool_type {
|
||||
ToolChoiceType::Auto => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Auto,
|
||||
name: None,
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
}
|
||||
ToolChoiceType::Required => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Any,
|
||||
name: None,
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
ToolChoiceType::None => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::None,
|
||||
name: None,
|
||||
disable_parallel_tool_use: None,
|
||||
},
|
||||
},
|
||||
ToolChoice::Function { function, .. } => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Tool,
|
||||
name: Some(function.name),
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Build OpenAI message content from parts and tool calls
|
||||
fn build_openai_content(content_parts: Vec<ContentPart>, tool_calls: &[ToolCall]) -> MessageContent {
|
||||
fn build_openai_content(
|
||||
content_parts: Vec<ContentPart>,
|
||||
tool_calls: &[ToolCall],
|
||||
) -> MessageContent {
|
||||
if content_parts.len() == 1 && tool_calls.is_empty() {
|
||||
match &content_parts[0] {
|
||||
ContentPart::Text { text } => MessageContent::Text(text.clone()),
|
||||
|
|
@ -855,7 +914,9 @@ fn build_anthropic_content(content_blocks: Vec<MessagesContentBlock>) -> Message
|
|||
}
|
||||
|
||||
/// Convert Anthropic content blocks to OpenAI message content
|
||||
fn convert_anthropic_content_to_openai(content: &[MessagesContentBlock]) -> Result<MessageContent, TransformError> {
|
||||
fn convert_anthropic_content_to_openai(
|
||||
content: &[MessagesContentBlock],
|
||||
) -> Result<MessageContent, TransformError> {
|
||||
let mut text_parts = Vec::new();
|
||||
|
||||
for block in content {
|
||||
|
|
@ -877,21 +938,29 @@ fn convert_anthropic_content_to_openai(content: &[MessagesContentBlock]) -> Resu
|
|||
}
|
||||
|
||||
/// Convert OpenAI message to Anthropic content blocks
|
||||
fn convert_openai_message_to_anthropic_content(message: &Message) -> Result<Vec<MessagesContentBlock>, TransformError> {
|
||||
fn convert_openai_message_to_anthropic_content(
|
||||
message: &Message,
|
||||
) -> Result<Vec<MessagesContentBlock>, TransformError> {
|
||||
let mut blocks = Vec::new();
|
||||
|
||||
// Handle regular content
|
||||
match &message.content {
|
||||
MessageContent::Text(text) => {
|
||||
if !text.is_empty() {
|
||||
blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None });
|
||||
blocks.push(MessagesContentBlock::Text {
|
||||
text: text.clone(),
|
||||
cache_control: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
MessageContent::Parts(parts) => {
|
||||
for part in parts {
|
||||
match part {
|
||||
ContentPart::Text { text } => {
|
||||
blocks.push(MessagesContentBlock::Text { text: text.clone(), cache_control: None });
|
||||
blocks.push(MessagesContentBlock::Text {
|
||||
text: text.clone(),
|
||||
cache_control: None,
|
||||
});
|
||||
}
|
||||
ContentPart::ImageUrl { image_url } => {
|
||||
let source = convert_image_url_to_source(image_url);
|
||||
|
|
@ -947,23 +1016,29 @@ fn convert_image_url_to_source(image_url: &ImageUrl) -> MessagesImageSource {
|
|||
data: data.to_string(),
|
||||
}
|
||||
} else {
|
||||
MessagesImageSource::Url { url: image_url.url.clone() }
|
||||
MessagesImageSource::Url {
|
||||
url: image_url.url.clone(),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MessagesImageSource::Url { url: image_url.url.clone() }
|
||||
MessagesImageSource::Url {
|
||||
url: image_url.url.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert content block start to OpenAI chunk
|
||||
fn convert_content_block_start(content_block: MessagesContentBlock) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
fn convert_content_block_start(
|
||||
content_block: MessagesContentBlock,
|
||||
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
match content_block {
|
||||
MessagesContentBlock::Text { .. } => {
|
||||
// No immediate output for text block start
|
||||
Ok(create_empty_openai_chunk())
|
||||
}
|
||||
MessagesContentBlock::ToolUse { id, name, .. } |
|
||||
MessagesContentBlock::ServerToolUse { id, name, .. } |
|
||||
MessagesContentBlock::McpToolUse { id, name, .. } => {
|
||||
MessagesContentBlock::ToolUse { id, name, .. }
|
||||
| MessagesContentBlock::ServerToolUse { id, name, .. }
|
||||
| MessagesContentBlock::McpToolUse { id, name, .. } => {
|
||||
// Tool use start → OpenAI chunk with tool_calls
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
|
|
@ -987,71 +1062,71 @@ fn convert_content_block_start(content_block: MessagesContentBlock) -> Result<Ch
|
|||
None,
|
||||
))
|
||||
}
|
||||
_ => Err(TransformError::UnsupportedContent("Unsupported content block type in stream start".to_string())),
|
||||
_ => Err(TransformError::UnsupportedContent(
|
||||
"Unsupported content block type in stream start".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert content delta to OpenAI chunk
|
||||
fn convert_content_delta(delta: MessagesContentDelta) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
fn convert_content_delta(
|
||||
delta: MessagesContentDelta,
|
||||
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
match delta {
|
||||
MessagesContentDelta::TextDelta { text } => {
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => {
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(format!("thinking: {}", thinking)),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
MessagesContentDelta::InputJsonDelta { partial_json } => {
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(partial_json),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(format!("thinking: {}", thinking)),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(partial_json),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert tool call deltas to Anthropic stream events
|
||||
fn convert_tool_call_deltas(tool_calls: Vec<ToolCallDelta>) -> Result<MessagesStreamEvent, TransformError> {
|
||||
fn convert_tool_call_deltas(
|
||||
tool_calls: Vec<ToolCallDelta>,
|
||||
) -> Result<MessagesStreamEvent, TransformError> {
|
||||
for tool_call in tool_calls {
|
||||
if let Some(id) = &tool_call.id {
|
||||
// Tool call start
|
||||
|
|
@ -1160,11 +1235,20 @@ mod tests {
|
|||
|
||||
// Check key fields are preserved
|
||||
assert_eq!(original_anthropic.model, roundtrip_anthropic.model);
|
||||
assert_eq!(original_anthropic.max_tokens, roundtrip_anthropic.max_tokens);
|
||||
assert_eq!(original_anthropic.temperature, roundtrip_anthropic.temperature);
|
||||
assert_eq!(
|
||||
original_anthropic.max_tokens,
|
||||
roundtrip_anthropic.max_tokens
|
||||
);
|
||||
assert_eq!(
|
||||
original_anthropic.temperature,
|
||||
roundtrip_anthropic.temperature
|
||||
);
|
||||
assert_eq!(original_anthropic.top_p, roundtrip_anthropic.top_p);
|
||||
assert_eq!(original_anthropic.stream, roundtrip_anthropic.stream);
|
||||
assert_eq!(original_anthropic.messages.len(), roundtrip_anthropic.messages.len());
|
||||
assert_eq!(
|
||||
original_anthropic.messages.len(),
|
||||
roundtrip_anthropic.messages.len()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1308,7 +1392,10 @@ mod tests {
|
|||
let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].id, Some("call_123".to_string()));
|
||||
assert_eq!(tool_calls[0].function.as_ref().unwrap().name, Some("get_weather".to_string()));
|
||||
assert_eq!(
|
||||
tool_calls[0].function.as_ref().unwrap().name,
|
||||
Some("get_weather".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1328,7 +1415,10 @@ mod tests {
|
|||
|
||||
let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.as_ref().unwrap().arguments, Some(r#"{"location": "San Francisco"#.to_string()));
|
||||
assert_eq!(
|
||||
tool_calls[0].function.as_ref().unwrap().arguments,
|
||||
Some(r#"{"location": "San Francisco"#.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1491,7 +1581,10 @@ mod tests {
|
|||
let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap();
|
||||
|
||||
match anthropic_event {
|
||||
MessagesStreamEvent::ContentBlockStart { index, content_block } => {
|
||||
MessagesStreamEvent::ContentBlockStart {
|
||||
index,
|
||||
content_block,
|
||||
} => {
|
||||
assert_eq!(index, 0);
|
||||
match content_block {
|
||||
MessagesContentBlock::ToolUse { id, name, .. } => {
|
||||
|
|
@ -1634,16 +1727,28 @@ mod tests {
|
|||
// Verify tool start
|
||||
let tool_calls = &openai_start.choices[0].delta.tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tool_calls[0].id, Some("call_weather".to_string()));
|
||||
assert_eq!(tool_calls[0].function.as_ref().unwrap().name, Some("get_weather".to_string()));
|
||||
assert_eq!(
|
||||
tool_calls[0].function.as_ref().unwrap().name,
|
||||
Some("get_weather".to_string())
|
||||
);
|
||||
|
||||
// Verify argument deltas
|
||||
let args1 = &openai_delta1.choices[0].delta.tool_calls.as_ref().unwrap()[0]
|
||||
.function.as_ref().unwrap().arguments;
|
||||
.function
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.arguments;
|
||||
assert_eq!(args1, &Some(r#"{"location": "#.to_string()));
|
||||
|
||||
let args2 = &openai_delta2.choices[0].delta.tool_calls.as_ref().unwrap()[0]
|
||||
.function.as_ref().unwrap().arguments;
|
||||
assert_eq!(args2, &Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string()));
|
||||
.function
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.arguments;
|
||||
assert_eq!(
|
||||
args2,
|
||||
&Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1671,14 +1776,23 @@ mod tests {
|
|||
};
|
||||
|
||||
let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap();
|
||||
assert_eq!(openai_resp.choices[0].finish_reason, Some(expected_openai_reason));
|
||||
assert_eq!(
|
||||
openai_resp.choices[0].finish_reason,
|
||||
Some(expected_openai_reason)
|
||||
);
|
||||
|
||||
// Test reverse conversion
|
||||
let roundtrip_event: MessagesStreamEvent = openai_resp.try_into().unwrap();
|
||||
match roundtrip_event {
|
||||
MessagesStreamEvent::MessageDelta { delta, .. } => {
|
||||
// Note: Some precision may be lost in roundtrip due to mapping differences
|
||||
assert!(matches!(delta.stop_reason, MessagesStopReason::EndTurn | MessagesStopReason::MaxTokens | MessagesStopReason::ToolUse | MessagesStopReason::StopSequence));
|
||||
assert!(matches!(
|
||||
delta.stop_reason,
|
||||
MessagesStopReason::EndTurn
|
||||
| MessagesStopReason::MaxTokens
|
||||
| MessagesStopReason::ToolUse
|
||||
| MessagesStopReason::StopSequence
|
||||
));
|
||||
}
|
||||
_ => panic!("Expected MessageDelta after roundtrip"),
|
||||
}
|
||||
|
|
@ -1711,7 +1825,8 @@ mod tests {
|
|||
};
|
||||
|
||||
// Should convert to Ping when no meaningful content
|
||||
let anthropic_event: MessagesStreamEvent = openai_resp_with_missing_data.try_into().unwrap();
|
||||
let anthropic_event: MessagesStreamEvent =
|
||||
openai_resp_with_missing_data.try_into().unwrap();
|
||||
assert!(matches!(anthropic_event, MessagesStreamEvent::Ping));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,20 +1,21 @@
|
|||
//! hermesllm: A library for translating LLM API requests and responses
|
||||
//! between Mistral, Grok, Gemini, and OpenAI-compliant formats.
|
||||
|
||||
pub mod providers;
|
||||
pub mod apis;
|
||||
pub mod clients;
|
||||
pub mod providers;
|
||||
// Re-export important types and traits
|
||||
pub use providers::request::{ProviderRequestType, ProviderRequest, ProviderRequestError};
|
||||
pub use providers::response::{ProviderResponseType, ProviderStreamResponseType, ProviderResponse, ProviderStreamResponse, ProviderResponseError, TokenUsage, SseEvent, SseStreamIter};
|
||||
pub use providers::id::ProviderId;
|
||||
|
||||
pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
|
||||
pub use providers::response::{
|
||||
ProviderResponse, ProviderResponseError, ProviderResponseType, ProviderStreamResponse,
|
||||
ProviderStreamResponseType, SseEvent, SseStreamIter, TokenUsage,
|
||||
};
|
||||
|
||||
//TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const MESSAGES_PATH: &str = "/v1/messages";
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -30,48 +31,50 @@ mod tests {
|
|||
#[test]
|
||||
fn test_provider_streaming_response() {
|
||||
// Test streaming response parsing with sample SSE data
|
||||
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
|
||||
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
let client_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
let upstream_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
let client_api =
|
||||
SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
let upstream_api =
|
||||
SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test the new simplified architecture - create SseStreamIter directly
|
||||
let sse_iter = SseStreamIter::try_from(sse_data.as_bytes());
|
||||
assert!(sse_iter.is_ok());
|
||||
// Test the new simplified architecture - create SseStreamIter directly
|
||||
let sse_iter = SseStreamIter::try_from(sse_data.as_bytes());
|
||||
assert!(sse_iter.is_ok());
|
||||
|
||||
let mut streaming_iter = sse_iter.unwrap();
|
||||
let mut streaming_iter = sse_iter.unwrap();
|
||||
|
||||
// Test that we can iterate over SseEvents
|
||||
let first_event = streaming_iter.next();
|
||||
assert!(first_event.is_some());
|
||||
// Test that we can iterate over SseEvents
|
||||
let first_event = streaming_iter.next();
|
||||
assert!(first_event.is_some());
|
||||
|
||||
let sse_event = first_event.unwrap();
|
||||
let sse_event = first_event.unwrap();
|
||||
|
||||
// Test SseEvent properties
|
||||
assert!(!sse_event.is_done());
|
||||
assert!(sse_event.data.as_ref().unwrap().contains("Hello"));
|
||||
// Test SseEvent properties
|
||||
assert!(!sse_event.is_done());
|
||||
assert!(sse_event.data.as_ref().unwrap().contains("Hello"));
|
||||
|
||||
// Test that we can parse the event into a provider stream response
|
||||
let transformed_event = SseEvent::try_from((sse_event, &client_api, &upstream_api));
|
||||
if let Err(e) = &transformed_event {
|
||||
println!("Transform error: {:?}", e);
|
||||
}
|
||||
assert!(transformed_event.is_ok());
|
||||
// Test that we can parse the event into a provider stream response
|
||||
let transformed_event = SseEvent::try_from((sse_event, &client_api, &upstream_api));
|
||||
if let Err(e) = &transformed_event {
|
||||
println!("Transform error: {:?}", e);
|
||||
}
|
||||
assert!(transformed_event.is_ok());
|
||||
|
||||
let transformed_event = transformed_event.unwrap();
|
||||
let provider_response = transformed_event.provider_response();
|
||||
assert!(provider_response.is_ok());
|
||||
let transformed_event = transformed_event.unwrap();
|
||||
let provider_response = transformed_event.provider_response();
|
||||
assert!(provider_response.is_ok());
|
||||
|
||||
let stream_response = provider_response.unwrap();
|
||||
assert_eq!(stream_response.content_delta(), Some("Hello"));
|
||||
assert!(!stream_response.is_final());
|
||||
let stream_response = provider_response.unwrap();
|
||||
assert_eq!(stream_response.content_delta(), Some("Hello"));
|
||||
assert!(!stream_response.is_final());
|
||||
|
||||
// Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE])
|
||||
let final_event = streaming_iter.next();
|
||||
assert!(final_event.is_none()); // Should be None because iterator stops at [DONE]
|
||||
// Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE])
|
||||
let final_event = streaming_iter.next();
|
||||
assert!(final_event.is_none()); // Should be None because iterator stops at [DONE]
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use std::fmt::Display;
|
||||
use crate::apis::{AnthropicApi, OpenAIApi};
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::apis::{OpenAIApi, AnthropicApi};
|
||||
use std::fmt::Display;
|
||||
|
||||
/// Provider identifier enum - simple enum for identifying providers
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
|
|
@ -50,41 +50,50 @@ impl ProviderId {
|
|||
pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs) -> SupportedAPIs {
|
||||
match (self, client_api) {
|
||||
// Claude/Anthropic providers natively support Anthropic APIs
|
||||
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
|
||||
(ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
|
||||
}
|
||||
(
|
||||
ProviderId::Anthropic,
|
||||
SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
// OpenAI-compatible providers only support OpenAI chat completions
|
||||
(ProviderId::OpenAI
|
||||
| ProviderId::Groq
|
||||
| ProviderId::Mistral
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::GitHub
|
||||
| ProviderId::AzureOpenAI
|
||||
| ProviderId::XAI
|
||||
| ProviderId::TogetherAI
|
||||
| ProviderId::Ollama
|
||||
| ProviderId::Moonshotai
|
||||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
(
|
||||
ProviderId::OpenAI
|
||||
| ProviderId::Groq
|
||||
| ProviderId::Mistral
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::GitHub
|
||||
| ProviderId::AzureOpenAI
|
||||
| ProviderId::XAI
|
||||
| ProviderId::TogetherAI
|
||||
| ProviderId::Ollama
|
||||
| ProviderId::Moonshotai
|
||||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
(ProviderId::OpenAI
|
||||
| ProviderId::Groq
|
||||
| ProviderId::Mistral
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::GitHub
|
||||
| ProviderId::AzureOpenAI
|
||||
| ProviderId::XAI
|
||||
| ProviderId::TogetherAI
|
||||
| ProviderId::Ollama
|
||||
| ProviderId::Moonshotai
|
||||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
(
|
||||
ProviderId::OpenAI
|
||||
| ProviderId::Groq
|
||||
| ProviderId::Mistral
|
||||
| ProviderId::Deepseek
|
||||
| ProviderId::Arch
|
||||
| ProviderId::Gemini
|
||||
| ProviderId::GitHub
|
||||
| ProviderId::AzureOpenAI
|
||||
| ProviderId::XAI
|
||||
| ProviderId::TogetherAI
|
||||
| ProviderId::Ollama
|
||||
| ProviderId::Moonshotai
|
||||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,5 +8,5 @@ pub mod request;
|
|||
pub mod response;
|
||||
|
||||
pub use id::ProviderId;
|
||||
pub use request::{ProviderRequestType, ProviderRequest, ProviderRequestError} ;
|
||||
pub use response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, TokenUsage };
|
||||
pub use request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
|
||||
pub use response::{ProviderResponse, ProviderResponseType, ProviderStreamResponse, TokenUsage};
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::apis::anthropic::MessagesRequest;
|
||||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::collections::HashMap;
|
||||
#[derive(Clone)]
|
||||
pub enum ProviderRequestType {
|
||||
ChatCompletionsRequest(ChatCompletionsRequest),
|
||||
|
|
@ -103,15 +103,18 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
|||
// Use SupportedApi to determine the appropriate request type
|
||||
match client_api {
|
||||
SupportedAPIs::OpenAIChatCompletions(_) => {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(_) => {
|
||||
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
|
||||
let chat_completion_request: ChatCompletionsRequest =
|
||||
ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
||||
}
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(
|
||||
chat_completion_request,
|
||||
))
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(_) => {
|
||||
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -120,40 +123,55 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
|||
impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
|
||||
type Error = ProviderRequestError;
|
||||
|
||||
fn try_from((request, upstream_api): (ProviderRequestType, &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
fn try_from(
|
||||
(request, upstream_api): (ProviderRequestType, &SupportedAPIs),
|
||||
) -> Result<Self, Self::Error> {
|
||||
match (request, upstream_api) {
|
||||
// Same API - no conversion needed, just clone the reference
|
||||
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)),
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
) => Ok(ProviderRequestType::MessagesRequest(messages_req)),
|
||||
|
||||
// Cross-API conversion - cloning is necessary for transformation
|
||||
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let messages_req = MessagesRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
let messages_req =
|
||||
MessagesRequest::try_from(chat_req).map_err(|e| ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ChatCompletionsRequest to MessagesRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
})?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
|
||||
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(messages_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(messages_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert MessagesRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Error types for provider operations
|
||||
#[derive(Debug)]
|
||||
pub struct ProviderRequestError {
|
||||
|
|
@ -169,19 +187,20 @@ impl fmt::Display for ProviderRequestError {
|
|||
|
||||
impl Error for ProviderRequestError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
self.source
|
||||
.as_ref()
|
||||
.map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::apis::anthropic::AnthropicApi::Messages;
|
||||
use crate::apis::openai::OpenAIApi::ChatCompletions;
|
||||
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
|
||||
use crate::apis::openai::{ChatCompletionsRequest};
|
||||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::apis::openai::OpenAIApi::ChatCompletions;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use serde_json::json;
|
||||
|
||||
|
|
@ -202,7 +221,7 @@ mod tests {
|
|||
ProviderRequestType::ChatCompletionsRequest(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.messages.len(), 2);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected ChatCompletionsRequest variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -225,7 +244,7 @@ mod tests {
|
|||
ProviderRequestType::MessagesRequest(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet");
|
||||
assert_eq!(r.messages.len(), 1);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected MessagesRequest variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -247,7 +266,7 @@ mod tests {
|
|||
ProviderRequestType::ChatCompletionsRequest(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.messages.len(), 2);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected ChatCompletionsRequest variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -271,7 +290,7 @@ mod tests {
|
|||
ProviderRequestType::ChatCompletionsRequest(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet");
|
||||
assert_eq!(r.messages.len(), 1);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected ChatCompletionsRequest variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -280,13 +299,15 @@ mod tests {
|
|||
fn test_v1_messages_to_v1_chat_completions_roundtrip() {
|
||||
let anthropic_req = AnthropicMessagesRequest {
|
||||
model: "claude-3-sonnet".to_string(),
|
||||
system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single("You are a helpful assistant".to_string())),
|
||||
messages: vec![
|
||||
crate::apis::anthropic::MessagesMessage {
|
||||
role: crate::apis::anthropic::MessagesRole::User,
|
||||
content: crate::apis::anthropic::MessagesMessageContent::Single("Hello!".to_string()),
|
||||
}
|
||||
],
|
||||
system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single(
|
||||
"You are a helpful assistant".to_string(),
|
||||
)),
|
||||
messages: vec![crate::apis::anthropic::MessagesMessage {
|
||||
role: crate::apis::anthropic::MessagesRole::User,
|
||||
content: crate::apis::anthropic::MessagesMessageContent::Single(
|
||||
"Hello!".to_string(),
|
||||
),
|
||||
}],
|
||||
max_tokens: 128,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
|
|
@ -302,16 +323,27 @@ mod tests {
|
|||
metadata: None,
|
||||
};
|
||||
|
||||
let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone()).expect("Anthropic->OpenAI conversion failed");
|
||||
let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req).expect("OpenAI->Anthropic conversion failed");
|
||||
let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone())
|
||||
.expect("Anthropic->OpenAI conversion failed");
|
||||
let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req)
|
||||
.expect("OpenAI->Anthropic conversion failed");
|
||||
|
||||
assert_eq!(anthropic_req.model, anthropic_req2.model);
|
||||
// Compare system prompt text if present
|
||||
assert_eq!(
|
||||
anthropic_req.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None }),
|
||||
anthropic_req2.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None })
|
||||
anthropic_req.system.as_ref().and_then(|s| match s {
|
||||
crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t),
|
||||
_ => None,
|
||||
}),
|
||||
anthropic_req2.system.as_ref().and_then(|s| match s {
|
||||
crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t),
|
||||
_ => None,
|
||||
})
|
||||
);
|
||||
assert_eq!(
|
||||
anthropic_req.messages[0].role,
|
||||
anthropic_req2.messages[0].role
|
||||
);
|
||||
assert_eq!(anthropic_req.messages[0].role, anthropic_req2.messages[0].role);
|
||||
// Compare message content text if present
|
||||
assert_eq!(
|
||||
anthropic_req.messages[0].content.extract_text(),
|
||||
|
|
@ -320,49 +352,54 @@ mod tests {
|
|||
assert_eq!(anthropic_req.max_tokens, anthropic_req2.max_tokens);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_v1_chat_completions_to_v1_messages_roundtrip() {
|
||||
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
|
||||
use crate::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent};
|
||||
#[test]
|
||||
fn test_v1_chat_completions_to_v1_messages_roundtrip() {
|
||||
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
|
||||
use crate::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
|
||||
let openai_req = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text("You are a helpful assistant".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
],
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(1.0),
|
||||
max_tokens: Some(128),
|
||||
stream: Some(false),
|
||||
stop: Some(vec!["\n".to_string()]),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
..Default::default()
|
||||
};
|
||||
let openai_req = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text("You are a helpful assistant".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
],
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(1.0),
|
||||
max_tokens: Some(128),
|
||||
stream: Some(false),
|
||||
stop: Some(vec!["\n".to_string()]),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone()).expect("OpenAI->Anthropic conversion failed");
|
||||
let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req).expect("Anthropic->OpenAI conversion failed");
|
||||
let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone())
|
||||
.expect("OpenAI->Anthropic conversion failed");
|
||||
let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req)
|
||||
.expect("Anthropic->OpenAI conversion failed");
|
||||
|
||||
assert_eq!(openai_req.model, openai_req2.model);
|
||||
assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role);
|
||||
assert_eq!(openai_req.messages[0].content.extract_text(), openai_req2.messages[0].content.extract_text());
|
||||
// After roundtrip, deprecated max_tokens should be converted to max_completion_tokens
|
||||
let original_max_tokens = openai_req.max_completion_tokens.or(openai_req.max_tokens);
|
||||
let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens);
|
||||
assert_eq!(original_max_tokens, roundtrip_max_tokens);
|
||||
}
|
||||
assert_eq!(openai_req.model, openai_req2.model);
|
||||
assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role);
|
||||
assert_eq!(
|
||||
openai_req.messages[0].content.extract_text(),
|
||||
openai_req2.messages[0].content.extract_text()
|
||||
);
|
||||
// After roundtrip, deprecated max_tokens should be converted to max_completion_tokens
|
||||
let original_max_tokens = openai_req.max_completion_tokens.or(openai_req.max_tokens);
|
||||
let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens);
|
||||
assert_eq!(original_max_tokens, roundtrip_max_tokens);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
use crate::providers::id::ProviderId;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::convert::TryFrom;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::convert::TryFrom;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::apis::anthropic::MessagesResponse;
|
||||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
use crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::apis::anthropic::MessagesResponse;
|
||||
|
||||
/// Trait for token usage information
|
||||
pub trait TokenUsage {
|
||||
|
|
@ -38,7 +38,8 @@ pub trait ProviderResponse: Send + Sync {
|
|||
|
||||
/// Extract token counts for metrics
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
||||
self.usage()
|
||||
.map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -110,19 +111,19 @@ impl ProviderStreamResponse for ProviderStreamResponseType {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SseEvent {
|
||||
#[serde(rename = "data")]
|
||||
pub data: Option<String>, // The JSON payload after "data: "
|
||||
pub data: Option<String>, // The JSON payload after "data: "
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta")
|
||||
pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta")
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
|
||||
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
|
||||
}
|
||||
|
||||
impl SseEvent {
|
||||
|
|
@ -145,13 +146,13 @@ impl SseEvent {
|
|||
|
||||
/// Get the parsed provider response if available
|
||||
pub fn provider_response(&self) -> Result<&dyn ProviderStreamResponse, std::io::Error> {
|
||||
self.provider_stream_response.as_ref()
|
||||
self.provider_stream_response
|
||||
.as_ref()
|
||||
.map(|resp| resp as &dyn ProviderStreamResponse)
|
||||
.ok_or_else(|| {
|
||||
std::io::Error::new(std::io::ErrorKind::NotFound, "Provider response not found")
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
impl FromStr for SseEvent {
|
||||
|
|
@ -172,7 +173,8 @@ impl FromStr for SseEvent {
|
|||
sse_transform_buffer: line.to_string(),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else if line.starts_with("event: ") { //used by Anthropic
|
||||
} else if line.starts_with("event: ") {
|
||||
//used by Anthropic
|
||||
let event_type = line[7..].to_string();
|
||||
if event_type.is_empty() {
|
||||
return Err(SseParseError {
|
||||
|
|
@ -207,12 +209,13 @@ impl Into<Vec<u8>> for SseEvent {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// --- Response transformation logic for client API compatibility ---
|
||||
impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result<Self, Self::Error> {
|
||||
fn try_from(
|
||||
(bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId),
|
||||
) -> Result<Self, Self::Error> {
|
||||
let upstream_api = provider_id.compatible_api_for_client(client_api);
|
||||
match (&upstream_api, client_api) {
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
|
|
@ -230,8 +233,13 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
|
|||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
// Transform to OpenAI ChatCompletions format using the transformer
|
||||
let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into()
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
|
||||
let chat_resp: ChatCompletionsResponse =
|
||||
anthropic_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
|
||||
}
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
|
|
@ -239,8 +247,12 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
|
|||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
// Transform to Anthropic Messages format using the transformer
|
||||
let messages_resp: MessagesResponse = openai_resp.try_into()
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
|
||||
let messages_resp: MessagesResponse = openai_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::MessagesResponse(messages_resp))
|
||||
}
|
||||
}
|
||||
|
|
@ -251,36 +263,50 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
|
|||
impl TryFrom<(&[u8], &SupportedAPIs, &SupportedAPIs)> for ProviderStreamResponseType {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from((bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
fn try_from(
|
||||
(bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedAPIs),
|
||||
) -> Result<Self, Self::Error> {
|
||||
match (upstream_api, client_api) {
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?;
|
||||
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(resp))
|
||||
let resp: crate::apis::openai::ChatCompletionsStreamResponse =
|
||||
serde_json::from_slice(bytes)?;
|
||||
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
|
||||
resp,
|
||||
))
|
||||
}
|
||||
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?;
|
||||
let resp: crate::apis::anthropic::MessagesStreamEvent =
|
||||
serde_json::from_slice(bytes)?;
|
||||
Ok(ProviderStreamResponseType::MessagesStreamEvent(resp))
|
||||
}
|
||||
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?;
|
||||
let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent =
|
||||
serde_json::from_slice(bytes)?;
|
||||
|
||||
// Transform to OpenAI ChatCompletions stream format using the transformer
|
||||
let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = anthropic_resp.try_into()?;
|
||||
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(chat_resp))
|
||||
let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse =
|
||||
anthropic_resp.try_into()?;
|
||||
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(
|
||||
chat_resp,
|
||||
))
|
||||
}
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
// Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion
|
||||
if bytes == b"[DONE]" {
|
||||
return Ok(ProviderStreamResponseType::MessagesStreamEvent(
|
||||
crate::apis::anthropic::MessagesStreamEvent::MessageStop
|
||||
crate::apis::anthropic::MessagesStreamEvent::MessageStop,
|
||||
));
|
||||
}
|
||||
|
||||
let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?;
|
||||
let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse =
|
||||
serde_json::from_slice(bytes)?;
|
||||
|
||||
// Transform to Anthropic Messages stream format using the transformer
|
||||
let messages_resp: crate::apis::anthropic::MessagesStreamEvent = openai_resp.try_into()?;
|
||||
Ok(ProviderStreamResponseType::MessagesStreamEvent(messages_resp))
|
||||
let messages_resp: crate::apis::anthropic::MessagesStreamEvent =
|
||||
openai_resp.try_into()?;
|
||||
Ok(ProviderStreamResponseType::MessagesStreamEvent(
|
||||
messages_resp,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -290,7 +316,9 @@ impl TryFrom<(&[u8], &SupportedAPIs, &SupportedAPIs)> for ProviderStreamResponse
|
|||
impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from((sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
fn try_from(
|
||||
(sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedAPIs),
|
||||
) -> Result<Self, Self::Error> {
|
||||
// Create a new transformed event based on the original
|
||||
let mut transformed_event = sse_event;
|
||||
|
||||
|
|
@ -298,7 +326,8 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent {
|
|||
if transformed_event.data.is_some() {
|
||||
let data_str = transformed_event.data.as_ref().unwrap();
|
||||
let data_bytes = data_str.as_bytes();
|
||||
let transformed_response = ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?;
|
||||
let transformed_response =
|
||||
ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?;
|
||||
let transformed_json = serde_json::to_string(&transformed_response)?;
|
||||
transformed_event.sse_transform_buffer = format!("data: {}\n\n", transformed_json);
|
||||
transformed_event.provider_stream_response = Some(transformed_response);
|
||||
|
|
@ -344,7 +373,10 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent {
|
|||
transformed_event.sse_transform_buffer
|
||||
);
|
||||
} else {
|
||||
transformed_event.sse_transform_buffer = format!("event: {}\n{}", event_type, transformed_event.sse_transform_buffer);
|
||||
transformed_event.sse_transform_buffer = format!(
|
||||
"event: {}\n{}",
|
||||
event_type, transformed_event.sse_transform_buffer
|
||||
);
|
||||
}
|
||||
}
|
||||
// If event_type is None, we just keep the data line as-is without an event line
|
||||
|
|
@ -396,7 +428,10 @@ where
|
|||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub fn new(lines: I) -> Self {
|
||||
Self { lines, done_seen: false }
|
||||
Self {
|
||||
lines,
|
||||
done_seen: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -451,7 +486,6 @@ pub struct ProviderResponseError {
|
|||
pub source: Option<Box<dyn Error + Send + Sync>>,
|
||||
}
|
||||
|
||||
|
||||
impl fmt::Display for ProviderResponseError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Provider response error: {}", self.message)
|
||||
|
|
@ -460,17 +494,19 @@ impl fmt::Display for ProviderResponseError {
|
|||
|
||||
impl Error for ProviderResponseError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
self.source
|
||||
.as_ref()
|
||||
.map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::providers::id::ProviderId;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
|
|
@ -491,13 +527,17 @@ mod tests {
|
|||
"system_fingerprint": null
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::OpenAI));
|
||||
let result = ProviderResponseType::try_from((
|
||||
bytes.as_slice(),
|
||||
&SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
&ProviderId::OpenAI,
|
||||
));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::ChatCompletionsResponse(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.choices.len(), 1);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected ChatCompletionsResponse variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -516,13 +556,17 @@ mod tests {
|
|||
"usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 }
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::Anthropic));
|
||||
let result = ProviderResponseType::try_from((
|
||||
bytes.as_slice(),
|
||||
&SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
|
||||
&ProviderId::Anthropic,
|
||||
));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::MessagesResponse(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet-20240229");
|
||||
assert_eq!(r.content.len(), 1);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected MessagesResponse variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -546,14 +590,18 @@ mod tests {
|
|||
"usage": { "prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35 }
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI));
|
||||
let result = ProviderResponseType::try_from((
|
||||
bytes.as_slice(),
|
||||
&SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
|
||||
&ProviderId::OpenAI,
|
||||
));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::MessagesResponse(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.usage.input_tokens, 10);
|
||||
assert_eq!(r.usage.output_tokens, 25);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected MessagesResponse variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -584,14 +632,18 @@ mod tests {
|
|||
}
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::Anthropic));
|
||||
let result = ProviderResponseType::try_from((
|
||||
bytes.as_slice(),
|
||||
&SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
&ProviderId::Anthropic,
|
||||
));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::ChatCompletionsResponse(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet-20240229");
|
||||
assert_eq!(r.usage.prompt_tokens, 10);
|
||||
assert_eq!(r.usage.completion_tokens, 25);
|
||||
},
|
||||
}
|
||||
_ => panic!("Expected ChatCompletionsResponse variant"),
|
||||
}
|
||||
}
|
||||
|
|
@ -603,11 +655,17 @@ mod tests {
|
|||
let event: Result<SseEvent, _> = line.parse();
|
||||
assert!(event.is_ok());
|
||||
let event = event.unwrap();
|
||||
assert_eq!(event.data, Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string()));
|
||||
assert_eq!(
|
||||
event.data,
|
||||
Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string())
|
||||
);
|
||||
|
||||
// Test conversion back to line using Display trait
|
||||
let wire_format = event.to_string();
|
||||
assert_eq!(wire_format, "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n");
|
||||
assert_eq!(
|
||||
wire_format,
|
||||
"data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n"
|
||||
);
|
||||
|
||||
// Test [DONE] marker - should be valid SSE event
|
||||
let done_line = "data: [DONE]";
|
||||
|
|
@ -639,10 +697,12 @@ mod tests {
|
|||
event: None,
|
||||
raw_line: r#"data: {"id":"test","object":"chat.completion.chunk"}
|
||||
|
||||
"#.to_string(),
|
||||
"#
|
||||
.to_string(),
|
||||
sse_transform_buffer: r#"data: {"id":"test","object":"chat.completion.chunk"}
|
||||
|
||||
"#.to_string(),
|
||||
"#
|
||||
.to_string(),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
|
||||
|
|
@ -679,7 +739,8 @@ mod tests {
|
|||
data: Some(r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string()),
|
||||
event: Some("content_block_delta".to_string()),
|
||||
raw_line: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
|
||||
sse_transform_buffer: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
|
||||
sse_transform_buffer: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#
|
||||
.to_string(),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
assert!(!normal_event.should_skip());
|
||||
|
|
@ -705,7 +766,7 @@ mod tests {
|
|||
"data: {\"type\": \"ping\"}".to_string(), // This should be filtered out
|
||||
"data: {\"id\": \"msg2\", \"object\": \"chat.completion.chunk\"}".to_string(),
|
||||
"data: {\"type\": \"ping\"}".to_string(), // This should be filtered out
|
||||
"data: [DONE]".to_string(), // This should end the stream
|
||||
"data: [DONE]".to_string(), // This should end the stream
|
||||
];
|
||||
|
||||
let mut iter = SseStreamIter::new(test_lines.into_iter());
|
||||
|
|
@ -773,13 +834,15 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_provider_stream_response_event_type() {
|
||||
use crate::apis::anthropic::{MessagesStreamEvent, MessagesContentDelta};
|
||||
use crate::apis::anthropic::{MessagesContentDelta, MessagesStreamEvent};
|
||||
use crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
|
||||
// Test Anthropic event type
|
||||
let anthropic_event = MessagesStreamEvent::ContentBlockDelta {
|
||||
index: 0,
|
||||
delta: MessagesContentDelta::TextDelta { text: "Hello".to_string() },
|
||||
delta: MessagesContentDelta::TextDelta {
|
||||
text: "Hello".to_string(),
|
||||
},
|
||||
};
|
||||
let provider_type = ProviderStreamResponseType::MessagesStreamEvent(anthropic_event);
|
||||
assert_eq!(provider_type.event_type(), Some("content_block_delta"));
|
||||
|
|
@ -806,15 +869,23 @@ mod tests {
|
|||
// Test that [DONE] marker is properly converted to MessageStop in the transformation layer
|
||||
let done_bytes = b"[DONE]";
|
||||
let client_api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let upstream_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::openai::OpenAIApi::ChatCompletions);
|
||||
let upstream_api =
|
||||
SupportedAPIs::OpenAIChatCompletions(crate::apis::openai::OpenAIApi::ChatCompletions);
|
||||
|
||||
let result = ProviderStreamResponseType::try_from((done_bytes.as_slice(), &client_api, &upstream_api));
|
||||
let result = ProviderStreamResponseType::try_from((
|
||||
done_bytes.as_slice(),
|
||||
&client_api,
|
||||
&upstream_api,
|
||||
));
|
||||
assert!(result.is_ok());
|
||||
|
||||
if let Ok(ProviderStreamResponseType::MessagesStreamEvent(event)) = result {
|
||||
// Verify it's a MessageStop event
|
||||
assert_eq!(event.event_type(), Some("message_stop"));
|
||||
assert!(matches!(event, crate::apis::anthropic::MessagesStreamEvent::MessageStop));
|
||||
assert!(matches!(
|
||||
event,
|
||||
crate::apis::anthropic::MessagesStreamEvent::MessageStop
|
||||
));
|
||||
} else {
|
||||
panic!("Expected MessagesStreamEvent::MessageStop");
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue