mirror of
https://github.com/katanemo/plano.git
synced 2026-05-08 15:22:43 +02:00
add support for agents (#564)
This commit is contained in:
parent
f8991a3c4b
commit
96e0732089
41 changed files with 3571 additions and 856 deletions
|
|
@ -5,10 +5,10 @@ use serde_with::skip_serializing_none;
|
|||
use std::collections::HashMap;
|
||||
|
||||
use super::ApiDefinition;
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use crate::{MESSAGES_PATH};
|
||||
use crate::MESSAGES_PATH;
|
||||
|
||||
// Enum for all supported Anthropic APIs
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
|
|
@ -52,9 +52,7 @@ impl ApiDefinition for AnthropicApi {
|
|||
}
|
||||
|
||||
fn all_variants() -> Vec<Self> {
|
||||
vec![
|
||||
AnthropicApi::Messages,
|
||||
]
|
||||
vec![AnthropicApi::Messages]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -100,7 +98,6 @@ pub struct McpServer {
|
|||
pub tool_configuration: Option<McpToolConfiguration>,
|
||||
}
|
||||
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct MessagesRequest {
|
||||
|
|
@ -121,10 +118,8 @@ pub struct MessagesRequest {
|
|||
pub stop_sequences: Option<Vec<String>>,
|
||||
pub tools: Option<Vec<MessagesTool>>,
|
||||
pub tool_choice: Option<MessagesToolChoice>,
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Messages API specific types
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
|
|
@ -235,34 +230,21 @@ impl ExtractText for Vec<MessagesContentBlock> {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[serde(tag = "type")]
|
||||
pub enum MessagesImageSource {
|
||||
Base64 {
|
||||
media_type: String,
|
||||
data: String,
|
||||
},
|
||||
Url {
|
||||
url: String,
|
||||
},
|
||||
Base64 { media_type: String, data: String },
|
||||
Url { url: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[serde(tag = "type")]
|
||||
pub enum MessagesDocumentSource {
|
||||
Base64 {
|
||||
media_type: String,
|
||||
data: String,
|
||||
},
|
||||
Url {
|
||||
url: String,
|
||||
},
|
||||
File {
|
||||
file_id: String,
|
||||
},
|
||||
Base64 { media_type: String, data: String },
|
||||
Url { url: String },
|
||||
File { file_id: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
|
|
@ -276,7 +258,7 @@ impl ExtractText for MessagesMessageContent {
|
|||
fn extract_text(&self) -> String {
|
||||
match self {
|
||||
MessagesMessageContent::Single(text) => text.clone(),
|
||||
MessagesMessageContent::Blocks(parts) => parts.extract_text()
|
||||
MessagesMessageContent::Blocks(parts) => parts.extract_text(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -320,7 +302,6 @@ pub struct MessagesToolChoice {
|
|||
pub disable_parallel_tool_use: Option<bool>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MessagesStopReason {
|
||||
|
|
@ -457,7 +438,11 @@ impl ProviderResponse for MessagesResponse {
|
|||
Some(self)
|
||||
}
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
Some((self.usage.input_tokens as usize, self.usage.output_tokens as usize, (self.usage.input_tokens + self.usage.output_tokens) as usize))
|
||||
Some((
|
||||
self.usage.input_tokens as usize,
|
||||
self.usage.output_tokens as usize,
|
||||
(self.usage.input_tokens + self.usage.output_tokens) as usize,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -535,7 +520,7 @@ impl ProviderRequest for MessagesRequest {
|
|||
}
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>> {
|
||||
return &self.metadata;
|
||||
return &self.metadata;
|
||||
}
|
||||
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool {
|
||||
|
|
@ -572,13 +557,11 @@ impl MessagesRole {
|
|||
impl ProviderStreamResponse for MessagesStreamEvent {
|
||||
fn content_delta(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
|
||||
match delta {
|
||||
MessagesContentDelta::TextDelta { text } => Some(text),
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => match delta {
|
||||
MessagesContentDelta::TextDelta { text } => Some(text),
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
@ -627,7 +610,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate required fields are properly set
|
||||
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
|
||||
|
|
@ -687,7 +671,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate required fields
|
||||
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
|
||||
|
|
@ -730,7 +715,10 @@ mod tests {
|
|||
assert_eq!(serialized_json["messages"], original_json["messages"]);
|
||||
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
|
||||
assert_eq!(serialized_json["system"], original_json["system"]);
|
||||
assert_eq!(serialized_json["service_tier"], original_json["service_tier"]);
|
||||
assert_eq!(
|
||||
serialized_json["service_tier"],
|
||||
original_json["service_tier"]
|
||||
);
|
||||
assert_eq!(serialized_json["thinking"], original_json["thinking"]);
|
||||
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
|
||||
|
||||
|
|
@ -818,7 +806,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate top-level fields
|
||||
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
|
||||
|
|
@ -833,7 +822,10 @@ mod tests {
|
|||
|
||||
// Validate text content block
|
||||
if let MessagesContentBlock::Text { text, .. } = &content_blocks[0] {
|
||||
assert_eq!(text, "What can you see in this image and what's the weather like?");
|
||||
assert_eq!(
|
||||
text,
|
||||
"What can you see in this image and what's the weather like?"
|
||||
);
|
||||
} else {
|
||||
panic!("Expected text content block");
|
||||
}
|
||||
|
|
@ -861,20 +853,32 @@ mod tests {
|
|||
|
||||
// Validate thinking content block
|
||||
if let MessagesContentBlock::Thinking { thinking, .. } = &content_blocks[0] {
|
||||
assert_eq!(thinking, "Let me analyze the image and then check the weather...");
|
||||
assert_eq!(
|
||||
thinking,
|
||||
"Let me analyze the image and then check the weather..."
|
||||
);
|
||||
} else {
|
||||
panic!("Expected thinking content block");
|
||||
}
|
||||
|
||||
// Validate text content block
|
||||
if let MessagesContentBlock::Text { text, .. } = &content_blocks[1] {
|
||||
assert_eq!(text, "I can see the image. Let me check the weather for you.");
|
||||
assert_eq!(
|
||||
text,
|
||||
"I can see the image. Let me check the weather for you."
|
||||
);
|
||||
} else {
|
||||
panic!("Expected text content block");
|
||||
}
|
||||
|
||||
// Validate tool use content block
|
||||
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = content_blocks[2] {
|
||||
if let MessagesContentBlock::ToolUse {
|
||||
ref id,
|
||||
ref name,
|
||||
ref input,
|
||||
..
|
||||
} = content_blocks[2]
|
||||
{
|
||||
assert_eq!(id, "toolu_weather123");
|
||||
assert_eq!(name, "get_weather");
|
||||
assert_eq!(input["location"], "San Francisco, CA");
|
||||
|
|
@ -892,7 +896,10 @@ mod tests {
|
|||
|
||||
let tool = &tools[0];
|
||||
assert_eq!(tool.name, "get_weather");
|
||||
assert_eq!(tool.description, Some("Get current weather information for a location".to_string()));
|
||||
assert_eq!(
|
||||
tool.description,
|
||||
Some("Get current weather information for a location".to_string())
|
||||
);
|
||||
assert_eq!(tool.input_schema["type"], "object");
|
||||
assert!(tool.input_schema["properties"]["location"].is_object());
|
||||
|
||||
|
|
@ -938,10 +945,16 @@ mod tests {
|
|||
assert_eq!(deserialized_mcp.name, "test-server");
|
||||
assert_eq!(deserialized_mcp.server_type, McpServerType::Url);
|
||||
assert_eq!(deserialized_mcp.url, "https://example.com/mcp");
|
||||
assert_eq!(deserialized_mcp.authorization_token, Some("secret-token".to_string()));
|
||||
assert_eq!(
|
||||
deserialized_mcp.authorization_token,
|
||||
Some("secret-token".to_string())
|
||||
);
|
||||
|
||||
if let Some(tool_config) = &deserialized_mcp.tool_configuration {
|
||||
assert_eq!(tool_config.allowed_tools, Some(vec!["tool1".to_string(), "tool2".to_string()]));
|
||||
assert_eq!(
|
||||
tool_config.allowed_tools,
|
||||
Some(vec!["tool1".to_string(), "tool2".to_string()])
|
||||
);
|
||||
assert_eq!(tool_config.enabled, Some(true));
|
||||
} else {
|
||||
panic!("Expected tool configuration");
|
||||
|
|
@ -957,7 +970,8 @@ mod tests {
|
|||
"url": "https://minimal.com/mcp"
|
||||
});
|
||||
|
||||
let deserialized_minimal: McpServer = serde_json::from_value(minimal_mcp_json.clone()).unwrap();
|
||||
let deserialized_minimal: McpServer =
|
||||
serde_json::from_value(minimal_mcp_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_minimal.name, "minimal-server");
|
||||
assert_eq!(deserialized_minimal.server_type, McpServerType::Url);
|
||||
assert_eq!(deserialized_minimal.url, "https://minimal.com/mcp");
|
||||
|
|
@ -991,12 +1005,16 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_response: MessagesResponse = serde_json::from_value(response_json.clone()).unwrap();
|
||||
let deserialized_response: MessagesResponse =
|
||||
serde_json::from_value(response_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_response.id, "msg_01ABC123");
|
||||
assert_eq!(deserialized_response.obj_type, "message");
|
||||
assert_eq!(deserialized_response.role, MessagesRole::Assistant);
|
||||
assert_eq!(deserialized_response.model, "claude-3-sonnet-20240229");
|
||||
assert_eq!(deserialized_response.stop_reason, MessagesStopReason::EndTurn);
|
||||
assert_eq!(
|
||||
deserialized_response.stop_reason,
|
||||
MessagesStopReason::EndTurn
|
||||
);
|
||||
assert!(deserialized_response.stop_sequence.is_none());
|
||||
assert!(deserialized_response.container.is_none());
|
||||
|
||||
|
|
@ -1011,7 +1029,10 @@ mod tests {
|
|||
// Check usage
|
||||
assert_eq!(deserialized_response.usage.input_tokens, 10);
|
||||
assert_eq!(deserialized_response.usage.output_tokens, 25);
|
||||
assert_eq!(deserialized_response.usage.cache_creation_input_tokens, Some(5));
|
||||
assert_eq!(
|
||||
deserialized_response.usage.cache_creation_input_tokens,
|
||||
Some(5)
|
||||
);
|
||||
assert_eq!(deserialized_response.usage.cache_read_input_tokens, Some(3));
|
||||
|
||||
let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap();
|
||||
|
|
@ -1027,7 +1048,8 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_event: MessagesStreamEvent = serde_json::from_value(stream_event_json.clone()).unwrap();
|
||||
let deserialized_event: MessagesStreamEvent =
|
||||
serde_json::from_value(stream_event_json.clone()).unwrap();
|
||||
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
|
||||
assert_eq!(index, 0);
|
||||
if let MessagesContentDelta::TextDelta { text } = delta {
|
||||
|
|
@ -1055,8 +1077,15 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_tool_use: MessagesContentBlock = serde_json::from_value(tool_use_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = deserialized_tool_use {
|
||||
let deserialized_tool_use: MessagesContentBlock =
|
||||
serde_json::from_value(tool_use_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolUse {
|
||||
ref id,
|
||||
ref name,
|
||||
ref input,
|
||||
..
|
||||
} = deserialized_tool_use
|
||||
{
|
||||
assert_eq!(id, "toolu_01ABC123");
|
||||
assert_eq!(name, "get_weather");
|
||||
assert_eq!(input["location"], "San Francisco, CA");
|
||||
|
|
@ -1079,8 +1108,15 @@ mod tests {
|
|||
]
|
||||
});
|
||||
|
||||
let deserialized_tool_result: MessagesContentBlock = serde_json::from_value(tool_result_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolResult { ref tool_use_id, ref is_error, ref content, .. } = deserialized_tool_result {
|
||||
let deserialized_tool_result: MessagesContentBlock =
|
||||
serde_json::from_value(tool_result_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolResult {
|
||||
ref tool_use_id,
|
||||
ref is_error,
|
||||
ref content,
|
||||
..
|
||||
} = deserialized_tool_result
|
||||
{
|
||||
assert_eq!(tool_use_id, "toolu_01ABC123");
|
||||
assert!(is_error.is_none());
|
||||
if let ToolResultContent::Blocks(blocks) = content {
|
||||
|
|
@ -1229,7 +1265,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize the complex MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(complex_request_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(complex_request_json.clone()).unwrap();
|
||||
|
||||
// Verify basic fields
|
||||
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
|
||||
|
|
@ -1239,8 +1276,15 @@ mod tests {
|
|||
// Verify system message with cache_control
|
||||
if let Some(MessagesSystemPrompt::Blocks(ref system_blocks)) = deserialized_request.system {
|
||||
assert_eq!(system_blocks.len(), 2);
|
||||
if let MessagesContentBlock::Text { text, cache_control } = &system_blocks[0] {
|
||||
assert_eq!(text, "You are Claude Code, Anthropic's official CLI for Claude.");
|
||||
if let MessagesContentBlock::Text {
|
||||
text,
|
||||
cache_control,
|
||||
} = &system_blocks[0]
|
||||
{
|
||||
assert_eq!(
|
||||
text,
|
||||
"You are Claude Code, Anthropic's official CLI for Claude."
|
||||
);
|
||||
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
|
||||
} else {
|
||||
panic!("Expected text system message with cache_control");
|
||||
|
|
@ -1253,7 +1297,13 @@ mod tests {
|
|||
let assistant_message = &deserialized_request.messages[1];
|
||||
assert_eq!(assistant_message.role, MessagesRole::Assistant);
|
||||
if let MessagesMessageContent::Blocks(ref content_blocks) = assistant_message.content {
|
||||
if let MessagesContentBlock::ToolUse { id, name, input, cache_control } = &content_blocks[0] {
|
||||
if let MessagesContentBlock::ToolUse {
|
||||
id,
|
||||
name,
|
||||
input,
|
||||
cache_control,
|
||||
} = &content_blocks[0]
|
||||
{
|
||||
assert_eq!(id, "call_kV50LtJQKHvvzZui5TW56DUl");
|
||||
assert_eq!(name, "TodoWrite");
|
||||
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
|
||||
|
|
@ -1272,7 +1322,12 @@ mod tests {
|
|||
let user_message = &deserialized_request.messages[2];
|
||||
assert_eq!(user_message.role, MessagesRole::User);
|
||||
if let MessagesMessageContent::Blocks(ref content_blocks) = user_message.content {
|
||||
if let MessagesContentBlock::ToolResult { tool_use_id, content, .. } = &content_blocks[0] {
|
||||
if let MessagesContentBlock::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
..
|
||||
} = &content_blocks[0]
|
||||
{
|
||||
assert_eq!(tool_use_id, "call_kV50LtJQKHvvzZui5TW56DUl");
|
||||
if let ToolResultContent::Text(text) = content {
|
||||
assert!(text.contains("Todos have been modified successfully"));
|
||||
|
|
@ -1284,7 +1339,11 @@ mod tests {
|
|||
}
|
||||
|
||||
// Verify text content with cache_control
|
||||
if let MessagesContentBlock::Text { text, cache_control } = &content_blocks[2] {
|
||||
if let MessagesContentBlock::Text {
|
||||
text,
|
||||
cache_control,
|
||||
} = &content_blocks[2]
|
||||
{
|
||||
assert_eq!(text, "try again");
|
||||
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
|
||||
} else {
|
||||
|
|
@ -1296,11 +1355,15 @@ mod tests {
|
|||
|
||||
// Test serialization round-trip
|
||||
let serialized_request = serde_json::to_value(&deserialized_request).unwrap();
|
||||
let re_deserialized_request: MessagesRequest = serde_json::from_value(serialized_request).unwrap();
|
||||
let re_deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(serialized_request).unwrap();
|
||||
|
||||
// Verify round-trip consistency
|
||||
assert_eq!(deserialized_request.model, re_deserialized_request.model);
|
||||
assert_eq!(deserialized_request.messages.len(), re_deserialized_request.messages.len());
|
||||
assert_eq!(
|
||||
deserialized_request.messages.len(),
|
||||
re_deserialized_request.messages.len()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1339,7 +1402,8 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_event: MessagesStreamEvent = serde_json::from_value(thinking_delta_json.clone()).unwrap();
|
||||
let deserialized_event: MessagesStreamEvent =
|
||||
serde_json::from_value(thinking_delta_json.clone()).unwrap();
|
||||
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
|
||||
assert_eq!(index, 0);
|
||||
if let MessagesContentDelta::ThinkingDelta { thinking } = delta {
|
||||
|
|
@ -1352,7 +1416,10 @@ mod tests {
|
|||
}
|
||||
|
||||
// Test that thinking delta is returned by content_delta()
|
||||
assert_eq!(deserialized_event.content_delta(), Some(".\n\nI need to consider:\n1. Current"));
|
||||
assert_eq!(
|
||||
deserialized_event.content_delta(),
|
||||
Some(".\n\nI need to consider:\n1. Current")
|
||||
);
|
||||
|
||||
let serialized_event_json = serde_json::to_value(&deserialized_event).unwrap();
|
||||
assert_eq!(thinking_delta_json, serialized_event_json);
|
||||
|
|
@ -1376,7 +1443,8 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(request_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(request_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
|
||||
assert_eq!(deserialized_request.max_tokens, 2048);
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ pub mod openai;
|
|||
pub use anthropic::*;
|
||||
pub use openai::*;
|
||||
|
||||
|
||||
pub trait ApiDefinition {
|
||||
/// Returns the endpoint path for this API
|
||||
fn endpoint(&self) -> &'static str;
|
||||
|
|
@ -49,11 +48,7 @@ mod tests {
|
|||
#[test]
|
||||
fn test_api_detection_from_endpoints() {
|
||||
// Test that we can detect APIs from endpoints using the trait
|
||||
let endpoints = vec![
|
||||
CHAT_COMPLETIONS_PATH,
|
||||
MESSAGES_PATH,
|
||||
"/v1/unknown"
|
||||
];
|
||||
let endpoints = vec![CHAT_COMPLETIONS_PATH, MESSAGES_PATH, "/v1/unknown"];
|
||||
|
||||
let mut detected_apis = Vec::new();
|
||||
|
||||
|
|
@ -67,11 +62,14 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
assert_eq!(detected_apis, vec![
|
||||
"OpenAI: ChatCompletions",
|
||||
"Anthropic: Messages",
|
||||
"Unknown API"
|
||||
]);
|
||||
assert_eq!(
|
||||
detected_apis,
|
||||
vec![
|
||||
"OpenAI: ChatCompletions",
|
||||
"Anthropic: Messages",
|
||||
"Unknown API"
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -5,11 +5,11 @@ use std::collections::HashMap;
|
|||
use std::fmt::Display;
|
||||
use thiserror::Error;
|
||||
|
||||
use super::ApiDefinition;
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
|
||||
use super::ApiDefinition;
|
||||
use crate::clients::transformer::{ExtractText};
|
||||
use crate::{CHAT_COMPLETIONS_PATH};
|
||||
use crate::CHAT_COMPLETIONS_PATH;
|
||||
|
||||
// ============================================================================
|
||||
// OPENAI API ENUMERATION
|
||||
|
|
@ -46,7 +46,7 @@ impl ApiDefinition for OpenAIApi {
|
|||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
match self {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => true,
|
||||
}
|
||||
}
|
||||
|
|
@ -58,9 +58,7 @@ impl ApiDefinition for OpenAIApi {
|
|||
}
|
||||
|
||||
fn all_variants() -> Vec<Self> {
|
||||
vec![
|
||||
OpenAIApi::ChatCompletions,
|
||||
]
|
||||
vec![OpenAIApi::ChatCompletions]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -190,7 +188,9 @@ impl ResponseMessage {
|
|||
pub fn to_message(&self) -> Message {
|
||||
Message {
|
||||
role: self.role.clone(),
|
||||
content: self.content.as_ref()
|
||||
content: self
|
||||
.content
|
||||
.as_ref()
|
||||
.map(|s| MessageContent::Text(s.clone()))
|
||||
.unwrap_or(MessageContent::Text(String::new())),
|
||||
name: None, // Response messages don't have names in the same way request messages do
|
||||
|
|
@ -215,7 +215,7 @@ impl ExtractText for MessageContent {
|
|||
fn extract_text(&self) -> String {
|
||||
match self {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts.extract_text()
|
||||
MessageContent::Parts(parts) => parts.extract_text(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -274,7 +274,6 @@ pub struct ImageUrl {
|
|||
|
||||
/// A single message in a chat conversation
|
||||
|
||||
|
||||
/// A tool call made by the assistant
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
|
|
@ -374,7 +373,6 @@ pub enum StaticContentType {
|
|||
Parts(Vec<ContentPart>),
|
||||
}
|
||||
|
||||
|
||||
/// Chat completions API response
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
|
|
@ -496,7 +494,6 @@ pub struct ChatCompletionsStreamResponse {
|
|||
pub service_tier: Option<String>,
|
||||
}
|
||||
|
||||
|
||||
/// A choice in a streaming response
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
|
|
@ -566,7 +563,6 @@ pub struct Models {
|
|||
pub data: Vec<ModelDetail>,
|
||||
}
|
||||
|
||||
|
||||
// Error type for streaming operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OpenAIStreamError {
|
||||
|
|
@ -597,13 +593,13 @@ pub enum OpenAIError {
|
|||
/// Trait Implementations
|
||||
/// ===========================================================================
|
||||
|
||||
|
||||
/// Parameterized conversion for ChatCompletionsRequest
|
||||
impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
||||
type Error = OpenAIStreamError;
|
||||
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
let mut req: ChatCompletionsRequest = serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
|
||||
let mut req: ChatCompletionsRequest =
|
||||
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
|
||||
// Use the centralized suppression logic
|
||||
req.suppress_max_tokens_if_o3();
|
||||
req.fix_temperature_if_gpt5();
|
||||
|
|
@ -651,13 +647,18 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
self.messages.iter().fold(String::new(), |acc, m| {
|
||||
acc + " " + &match &m.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts.iter().map(|part| match part {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
|
||||
}).collect::<Vec<_>>().join(" ")
|
||||
}
|
||||
acc + " "
|
||||
+ &match &m.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts
|
||||
.iter()
|
||||
.map(|part| match part {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(" "),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -721,14 +722,14 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
|||
}
|
||||
|
||||
fn role(&self) -> Option<&str> {
|
||||
self.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.role.as_ref().map(|r| match r {
|
||||
self.choices.first().and_then(|choice| {
|
||||
choice.delta.role.as_ref().map(|r| match r {
|
||||
Role::System => "system",
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::Tool => "tool",
|
||||
}))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn event_type(&self) -> Option<&str> {
|
||||
|
|
@ -736,7 +737,6 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -756,7 +756,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into ChatCompletionsRequest
|
||||
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: ChatCompletionsRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate required fields are properly set
|
||||
assert_eq!(deserialized_request.model, "gpt-4");
|
||||
|
|
@ -799,7 +800,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into ChatCompletionsRequest
|
||||
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: ChatCompletionsRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate required fields
|
||||
assert_eq!(deserialized_request.model, "gpt-4");
|
||||
|
|
@ -836,7 +838,10 @@ mod tests {
|
|||
assert_eq!(serialized_json["messages"], original_json["messages"]);
|
||||
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
|
||||
assert_eq!(serialized_json["stream"], original_json["stream"]);
|
||||
assert_eq!(serialized_json["stream_options"], original_json["stream_options"]);
|
||||
assert_eq!(
|
||||
serialized_json["stream_options"],
|
||||
original_json["stream_options"]
|
||||
);
|
||||
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
|
||||
|
||||
// Handle temperature with floating point tolerance
|
||||
|
|
@ -917,7 +922,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into ChatCompletionsRequest
|
||||
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: ChatCompletionsRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate top-level fields
|
||||
assert_eq!(deserialized_request.model, "gpt-4-vision-preview");
|
||||
|
|
@ -953,7 +959,10 @@ mod tests {
|
|||
let assistant_message = &deserialized_request.messages[1];
|
||||
assert_eq!(assistant_message.role, Role::Assistant);
|
||||
if let MessageContent::Text(text) = &assistant_message.content {
|
||||
assert_eq!(text, "I can see a beautiful cityscape. Let me check the weather for you.");
|
||||
assert_eq!(
|
||||
text,
|
||||
"I can see a beautiful cityscape. Let me check the weather for you."
|
||||
);
|
||||
} else {
|
||||
panic!("Expected text content for assistant message");
|
||||
}
|
||||
|
|
@ -967,7 +976,10 @@ mod tests {
|
|||
assert_eq!(tool_call.id, "call_weather123");
|
||||
assert_eq!(tool_call.call_type, "function");
|
||||
assert_eq!(tool_call.function.name, "get_weather");
|
||||
assert_eq!(tool_call.function.arguments, "{\"location\": \"New York, NY\"}");
|
||||
assert_eq!(
|
||||
tool_call.function.arguments,
|
||||
"{\"location\": \"New York, NY\"}"
|
||||
);
|
||||
|
||||
// Validate third message (tool response)
|
||||
let tool_message = &deserialized_request.messages[2];
|
||||
|
|
@ -977,7 +989,10 @@ mod tests {
|
|||
} else {
|
||||
panic!("Expected text content for tool message");
|
||||
}
|
||||
assert_eq!(tool_message.tool_call_id, Some("call_weather123".to_string()));
|
||||
assert_eq!(
|
||||
tool_message.tool_call_id,
|
||||
Some("call_weather123".to_string())
|
||||
);
|
||||
|
||||
// Validate tools array
|
||||
assert!(deserialized_request.tools.is_some());
|
||||
|
|
@ -987,7 +1002,10 @@ mod tests {
|
|||
let tool = &tools[0];
|
||||
assert_eq!(tool.tool_type, "function");
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
assert_eq!(tool.function.description, Some("Get current weather information for a location".to_string()));
|
||||
assert_eq!(
|
||||
tool.function.description,
|
||||
Some("Get current weather information for a location".to_string())
|
||||
);
|
||||
assert_eq!(tool.function.strict, Some(true));
|
||||
|
||||
// Validate tool parameters schema
|
||||
|
|
@ -1093,7 +1111,8 @@ mod tests {
|
|||
]
|
||||
});
|
||||
|
||||
let deserialized_assistant: Message = serde_json::from_value(assistant_json.clone()).unwrap();
|
||||
let deserialized_assistant: Message =
|
||||
serde_json::from_value(assistant_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_assistant.role, Role::Assistant);
|
||||
if let MessageContent::Text(content) = &deserialized_assistant.content {
|
||||
assert_eq!(content, "I'll help with that.");
|
||||
|
|
@ -1142,9 +1161,13 @@ mod tests {
|
|||
]
|
||||
});
|
||||
|
||||
let deserialized_response: ResponseMessage = serde_json::from_value(response_json.clone()).unwrap();
|
||||
let deserialized_response: ResponseMessage =
|
||||
serde_json::from_value(response_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_response.role, Role::Assistant);
|
||||
assert_eq!(deserialized_response.content, Some("Response content".to_string()));
|
||||
assert_eq!(
|
||||
deserialized_response.content,
|
||||
Some("Response content".to_string())
|
||||
);
|
||||
assert!(deserialized_response.annotations.is_some());
|
||||
assert!(deserialized_response.refusal.is_none());
|
||||
assert!(deserialized_response.function_call.is_none());
|
||||
|
|
@ -1186,7 +1209,10 @@ mod tests {
|
|||
let none_deserialized: ToolChoice = serde_json::from_value(json!("none")).unwrap();
|
||||
|
||||
assert_eq!(auto_deserialized, ToolChoice::Type(ToolChoiceType::Auto));
|
||||
assert_eq!(required_deserialized, ToolChoice::Type(ToolChoiceType::Required));
|
||||
assert_eq!(
|
||||
required_deserialized,
|
||||
ToolChoice::Type(ToolChoiceType::Required)
|
||||
);
|
||||
assert_eq!(none_deserialized, ToolChoice::Type(ToolChoiceType::None));
|
||||
|
||||
// Test that invalid string values fail deserialization (type safety!)
|
||||
|
|
@ -1237,7 +1263,10 @@ mod tests {
|
|||
assert_eq!(response.created, 1756574706);
|
||||
assert_eq!(response.model, "gpt-4o-2024-08-06");
|
||||
assert_eq!(response.service_tier, Some("default".to_string()));
|
||||
assert_eq!(response.system_fingerprint, Some("fp_f33640a400".to_string()));
|
||||
assert_eq!(
|
||||
response.system_fingerprint,
|
||||
Some("fp_f33640a400".to_string())
|
||||
);
|
||||
assert_eq!(response.choices.len(), 1);
|
||||
assert_eq!(response.usage.prompt_tokens, 65);
|
||||
assert_eq!(response.usage.completion_tokens, 184);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue