mirror of
https://github.com/katanemo/plano.git
synced 2026-05-24 14:05:14 +02:00
add support for agents (#564)
This commit is contained in:
parent
f8991a3c4b
commit
96e0732089
41 changed files with 3571 additions and 856 deletions
|
|
@ -5,10 +5,10 @@ use serde_with::skip_serializing_none;
|
|||
use std::collections::HashMap;
|
||||
|
||||
use super::ApiDefinition;
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use crate::{MESSAGES_PATH};
|
||||
use crate::MESSAGES_PATH;
|
||||
|
||||
// Enum for all supported Anthropic APIs
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
|
|
@ -52,9 +52,7 @@ impl ApiDefinition for AnthropicApi {
|
|||
}
|
||||
|
||||
fn all_variants() -> Vec<Self> {
|
||||
vec![
|
||||
AnthropicApi::Messages,
|
||||
]
|
||||
vec![AnthropicApi::Messages]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -100,7 +98,6 @@ pub struct McpServer {
|
|||
pub tool_configuration: Option<McpToolConfiguration>,
|
||||
}
|
||||
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct MessagesRequest {
|
||||
|
|
@ -121,10 +118,8 @@ pub struct MessagesRequest {
|
|||
pub stop_sequences: Option<Vec<String>>,
|
||||
pub tools: Option<Vec<MessagesTool>>,
|
||||
pub tool_choice: Option<MessagesToolChoice>,
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Messages API specific types
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
|
|
@ -235,34 +230,21 @@ impl ExtractText for Vec<MessagesContentBlock> {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[serde(tag = "type")]
|
||||
pub enum MessagesImageSource {
|
||||
Base64 {
|
||||
media_type: String,
|
||||
data: String,
|
||||
},
|
||||
Url {
|
||||
url: String,
|
||||
},
|
||||
Base64 { media_type: String, data: String },
|
||||
Url { url: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[serde(tag = "type")]
|
||||
pub enum MessagesDocumentSource {
|
||||
Base64 {
|
||||
media_type: String,
|
||||
data: String,
|
||||
},
|
||||
Url {
|
||||
url: String,
|
||||
},
|
||||
File {
|
||||
file_id: String,
|
||||
},
|
||||
Base64 { media_type: String, data: String },
|
||||
Url { url: String },
|
||||
File { file_id: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
|
|
@ -276,7 +258,7 @@ impl ExtractText for MessagesMessageContent {
|
|||
fn extract_text(&self) -> String {
|
||||
match self {
|
||||
MessagesMessageContent::Single(text) => text.clone(),
|
||||
MessagesMessageContent::Blocks(parts) => parts.extract_text()
|
||||
MessagesMessageContent::Blocks(parts) => parts.extract_text(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -320,7 +302,6 @@ pub struct MessagesToolChoice {
|
|||
pub disable_parallel_tool_use: Option<bool>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MessagesStopReason {
|
||||
|
|
@ -457,7 +438,11 @@ impl ProviderResponse for MessagesResponse {
|
|||
Some(self)
|
||||
}
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
Some((self.usage.input_tokens as usize, self.usage.output_tokens as usize, (self.usage.input_tokens + self.usage.output_tokens) as usize))
|
||||
Some((
|
||||
self.usage.input_tokens as usize,
|
||||
self.usage.output_tokens as usize,
|
||||
(self.usage.input_tokens + self.usage.output_tokens) as usize,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -535,7 +520,7 @@ impl ProviderRequest for MessagesRequest {
|
|||
}
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>> {
|
||||
return &self.metadata;
|
||||
return &self.metadata;
|
||||
}
|
||||
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool {
|
||||
|
|
@ -572,13 +557,11 @@ impl MessagesRole {
|
|||
impl ProviderStreamResponse for MessagesStreamEvent {
|
||||
fn content_delta(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
|
||||
match delta {
|
||||
MessagesContentDelta::TextDelta { text } => Some(text),
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => match delta {
|
||||
MessagesContentDelta::TextDelta { text } => Some(text),
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
@ -627,7 +610,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate required fields are properly set
|
||||
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
|
||||
|
|
@ -687,7 +671,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate required fields
|
||||
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
|
||||
|
|
@ -730,7 +715,10 @@ mod tests {
|
|||
assert_eq!(serialized_json["messages"], original_json["messages"]);
|
||||
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
|
||||
assert_eq!(serialized_json["system"], original_json["system"]);
|
||||
assert_eq!(serialized_json["service_tier"], original_json["service_tier"]);
|
||||
assert_eq!(
|
||||
serialized_json["service_tier"],
|
||||
original_json["service_tier"]
|
||||
);
|
||||
assert_eq!(serialized_json["thinking"], original_json["thinking"]);
|
||||
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
|
||||
|
||||
|
|
@ -818,7 +806,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize JSON into MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(original_json.clone()).unwrap();
|
||||
|
||||
// Validate top-level fields
|
||||
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
|
||||
|
|
@ -833,7 +822,10 @@ mod tests {
|
|||
|
||||
// Validate text content block
|
||||
if let MessagesContentBlock::Text { text, .. } = &content_blocks[0] {
|
||||
assert_eq!(text, "What can you see in this image and what's the weather like?");
|
||||
assert_eq!(
|
||||
text,
|
||||
"What can you see in this image and what's the weather like?"
|
||||
);
|
||||
} else {
|
||||
panic!("Expected text content block");
|
||||
}
|
||||
|
|
@ -861,20 +853,32 @@ mod tests {
|
|||
|
||||
// Validate thinking content block
|
||||
if let MessagesContentBlock::Thinking { thinking, .. } = &content_blocks[0] {
|
||||
assert_eq!(thinking, "Let me analyze the image and then check the weather...");
|
||||
assert_eq!(
|
||||
thinking,
|
||||
"Let me analyze the image and then check the weather..."
|
||||
);
|
||||
} else {
|
||||
panic!("Expected thinking content block");
|
||||
}
|
||||
|
||||
// Validate text content block
|
||||
if let MessagesContentBlock::Text { text, .. } = &content_blocks[1] {
|
||||
assert_eq!(text, "I can see the image. Let me check the weather for you.");
|
||||
assert_eq!(
|
||||
text,
|
||||
"I can see the image. Let me check the weather for you."
|
||||
);
|
||||
} else {
|
||||
panic!("Expected text content block");
|
||||
}
|
||||
|
||||
// Validate tool use content block
|
||||
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = content_blocks[2] {
|
||||
if let MessagesContentBlock::ToolUse {
|
||||
ref id,
|
||||
ref name,
|
||||
ref input,
|
||||
..
|
||||
} = content_blocks[2]
|
||||
{
|
||||
assert_eq!(id, "toolu_weather123");
|
||||
assert_eq!(name, "get_weather");
|
||||
assert_eq!(input["location"], "San Francisco, CA");
|
||||
|
|
@ -892,7 +896,10 @@ mod tests {
|
|||
|
||||
let tool = &tools[0];
|
||||
assert_eq!(tool.name, "get_weather");
|
||||
assert_eq!(tool.description, Some("Get current weather information for a location".to_string()));
|
||||
assert_eq!(
|
||||
tool.description,
|
||||
Some("Get current weather information for a location".to_string())
|
||||
);
|
||||
assert_eq!(tool.input_schema["type"], "object");
|
||||
assert!(tool.input_schema["properties"]["location"].is_object());
|
||||
|
||||
|
|
@ -938,10 +945,16 @@ mod tests {
|
|||
assert_eq!(deserialized_mcp.name, "test-server");
|
||||
assert_eq!(deserialized_mcp.server_type, McpServerType::Url);
|
||||
assert_eq!(deserialized_mcp.url, "https://example.com/mcp");
|
||||
assert_eq!(deserialized_mcp.authorization_token, Some("secret-token".to_string()));
|
||||
assert_eq!(
|
||||
deserialized_mcp.authorization_token,
|
||||
Some("secret-token".to_string())
|
||||
);
|
||||
|
||||
if let Some(tool_config) = &deserialized_mcp.tool_configuration {
|
||||
assert_eq!(tool_config.allowed_tools, Some(vec!["tool1".to_string(), "tool2".to_string()]));
|
||||
assert_eq!(
|
||||
tool_config.allowed_tools,
|
||||
Some(vec!["tool1".to_string(), "tool2".to_string()])
|
||||
);
|
||||
assert_eq!(tool_config.enabled, Some(true));
|
||||
} else {
|
||||
panic!("Expected tool configuration");
|
||||
|
|
@ -957,7 +970,8 @@ mod tests {
|
|||
"url": "https://minimal.com/mcp"
|
||||
});
|
||||
|
||||
let deserialized_minimal: McpServer = serde_json::from_value(minimal_mcp_json.clone()).unwrap();
|
||||
let deserialized_minimal: McpServer =
|
||||
serde_json::from_value(minimal_mcp_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_minimal.name, "minimal-server");
|
||||
assert_eq!(deserialized_minimal.server_type, McpServerType::Url);
|
||||
assert_eq!(deserialized_minimal.url, "https://minimal.com/mcp");
|
||||
|
|
@ -991,12 +1005,16 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_response: MessagesResponse = serde_json::from_value(response_json.clone()).unwrap();
|
||||
let deserialized_response: MessagesResponse =
|
||||
serde_json::from_value(response_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_response.id, "msg_01ABC123");
|
||||
assert_eq!(deserialized_response.obj_type, "message");
|
||||
assert_eq!(deserialized_response.role, MessagesRole::Assistant);
|
||||
assert_eq!(deserialized_response.model, "claude-3-sonnet-20240229");
|
||||
assert_eq!(deserialized_response.stop_reason, MessagesStopReason::EndTurn);
|
||||
assert_eq!(
|
||||
deserialized_response.stop_reason,
|
||||
MessagesStopReason::EndTurn
|
||||
);
|
||||
assert!(deserialized_response.stop_sequence.is_none());
|
||||
assert!(deserialized_response.container.is_none());
|
||||
|
||||
|
|
@ -1011,7 +1029,10 @@ mod tests {
|
|||
// Check usage
|
||||
assert_eq!(deserialized_response.usage.input_tokens, 10);
|
||||
assert_eq!(deserialized_response.usage.output_tokens, 25);
|
||||
assert_eq!(deserialized_response.usage.cache_creation_input_tokens, Some(5));
|
||||
assert_eq!(
|
||||
deserialized_response.usage.cache_creation_input_tokens,
|
||||
Some(5)
|
||||
);
|
||||
assert_eq!(deserialized_response.usage.cache_read_input_tokens, Some(3));
|
||||
|
||||
let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap();
|
||||
|
|
@ -1027,7 +1048,8 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_event: MessagesStreamEvent = serde_json::from_value(stream_event_json.clone()).unwrap();
|
||||
let deserialized_event: MessagesStreamEvent =
|
||||
serde_json::from_value(stream_event_json.clone()).unwrap();
|
||||
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
|
||||
assert_eq!(index, 0);
|
||||
if let MessagesContentDelta::TextDelta { text } = delta {
|
||||
|
|
@ -1055,8 +1077,15 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_tool_use: MessagesContentBlock = serde_json::from_value(tool_use_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = deserialized_tool_use {
|
||||
let deserialized_tool_use: MessagesContentBlock =
|
||||
serde_json::from_value(tool_use_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolUse {
|
||||
ref id,
|
||||
ref name,
|
||||
ref input,
|
||||
..
|
||||
} = deserialized_tool_use
|
||||
{
|
||||
assert_eq!(id, "toolu_01ABC123");
|
||||
assert_eq!(name, "get_weather");
|
||||
assert_eq!(input["location"], "San Francisco, CA");
|
||||
|
|
@ -1079,8 +1108,15 @@ mod tests {
|
|||
]
|
||||
});
|
||||
|
||||
let deserialized_tool_result: MessagesContentBlock = serde_json::from_value(tool_result_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolResult { ref tool_use_id, ref is_error, ref content, .. } = deserialized_tool_result {
|
||||
let deserialized_tool_result: MessagesContentBlock =
|
||||
serde_json::from_value(tool_result_json.clone()).unwrap();
|
||||
if let MessagesContentBlock::ToolResult {
|
||||
ref tool_use_id,
|
||||
ref is_error,
|
||||
ref content,
|
||||
..
|
||||
} = deserialized_tool_result
|
||||
{
|
||||
assert_eq!(tool_use_id, "toolu_01ABC123");
|
||||
assert!(is_error.is_none());
|
||||
if let ToolResultContent::Blocks(blocks) = content {
|
||||
|
|
@ -1229,7 +1265,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Deserialize the complex MessagesRequest
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(complex_request_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(complex_request_json.clone()).unwrap();
|
||||
|
||||
// Verify basic fields
|
||||
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
|
||||
|
|
@ -1239,8 +1276,15 @@ mod tests {
|
|||
// Verify system message with cache_control
|
||||
if let Some(MessagesSystemPrompt::Blocks(ref system_blocks)) = deserialized_request.system {
|
||||
assert_eq!(system_blocks.len(), 2);
|
||||
if let MessagesContentBlock::Text { text, cache_control } = &system_blocks[0] {
|
||||
assert_eq!(text, "You are Claude Code, Anthropic's official CLI for Claude.");
|
||||
if let MessagesContentBlock::Text {
|
||||
text,
|
||||
cache_control,
|
||||
} = &system_blocks[0]
|
||||
{
|
||||
assert_eq!(
|
||||
text,
|
||||
"You are Claude Code, Anthropic's official CLI for Claude."
|
||||
);
|
||||
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
|
||||
} else {
|
||||
panic!("Expected text system message with cache_control");
|
||||
|
|
@ -1253,7 +1297,13 @@ mod tests {
|
|||
let assistant_message = &deserialized_request.messages[1];
|
||||
assert_eq!(assistant_message.role, MessagesRole::Assistant);
|
||||
if let MessagesMessageContent::Blocks(ref content_blocks) = assistant_message.content {
|
||||
if let MessagesContentBlock::ToolUse { id, name, input, cache_control } = &content_blocks[0] {
|
||||
if let MessagesContentBlock::ToolUse {
|
||||
id,
|
||||
name,
|
||||
input,
|
||||
cache_control,
|
||||
} = &content_blocks[0]
|
||||
{
|
||||
assert_eq!(id, "call_kV50LtJQKHvvzZui5TW56DUl");
|
||||
assert_eq!(name, "TodoWrite");
|
||||
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
|
||||
|
|
@ -1272,7 +1322,12 @@ mod tests {
|
|||
let user_message = &deserialized_request.messages[2];
|
||||
assert_eq!(user_message.role, MessagesRole::User);
|
||||
if let MessagesMessageContent::Blocks(ref content_blocks) = user_message.content {
|
||||
if let MessagesContentBlock::ToolResult { tool_use_id, content, .. } = &content_blocks[0] {
|
||||
if let MessagesContentBlock::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
..
|
||||
} = &content_blocks[0]
|
||||
{
|
||||
assert_eq!(tool_use_id, "call_kV50LtJQKHvvzZui5TW56DUl");
|
||||
if let ToolResultContent::Text(text) = content {
|
||||
assert!(text.contains("Todos have been modified successfully"));
|
||||
|
|
@ -1284,7 +1339,11 @@ mod tests {
|
|||
}
|
||||
|
||||
// Verify text content with cache_control
|
||||
if let MessagesContentBlock::Text { text, cache_control } = &content_blocks[2] {
|
||||
if let MessagesContentBlock::Text {
|
||||
text,
|
||||
cache_control,
|
||||
} = &content_blocks[2]
|
||||
{
|
||||
assert_eq!(text, "try again");
|
||||
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
|
||||
} else {
|
||||
|
|
@ -1296,11 +1355,15 @@ mod tests {
|
|||
|
||||
// Test serialization round-trip
|
||||
let serialized_request = serde_json::to_value(&deserialized_request).unwrap();
|
||||
let re_deserialized_request: MessagesRequest = serde_json::from_value(serialized_request).unwrap();
|
||||
let re_deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(serialized_request).unwrap();
|
||||
|
||||
// Verify round-trip consistency
|
||||
assert_eq!(deserialized_request.model, re_deserialized_request.model);
|
||||
assert_eq!(deserialized_request.messages.len(), re_deserialized_request.messages.len());
|
||||
assert_eq!(
|
||||
deserialized_request.messages.len(),
|
||||
re_deserialized_request.messages.len()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1339,7 +1402,8 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_event: MessagesStreamEvent = serde_json::from_value(thinking_delta_json.clone()).unwrap();
|
||||
let deserialized_event: MessagesStreamEvent =
|
||||
serde_json::from_value(thinking_delta_json.clone()).unwrap();
|
||||
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
|
||||
assert_eq!(index, 0);
|
||||
if let MessagesContentDelta::ThinkingDelta { thinking } = delta {
|
||||
|
|
@ -1352,7 +1416,10 @@ mod tests {
|
|||
}
|
||||
|
||||
// Test that thinking delta is returned by content_delta()
|
||||
assert_eq!(deserialized_event.content_delta(), Some(".\n\nI need to consider:\n1. Current"));
|
||||
assert_eq!(
|
||||
deserialized_event.content_delta(),
|
||||
Some(".\n\nI need to consider:\n1. Current")
|
||||
);
|
||||
|
||||
let serialized_event_json = serde_json::to_value(&deserialized_event).unwrap();
|
||||
assert_eq!(thinking_delta_json, serialized_event_json);
|
||||
|
|
@ -1376,7 +1443,8 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let deserialized_request: MessagesRequest = serde_json::from_value(request_json.clone()).unwrap();
|
||||
let deserialized_request: MessagesRequest =
|
||||
serde_json::from_value(request_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
|
||||
assert_eq!(deserialized_request.max_tokens, 2048);
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue