add support for agents (#564)

This commit is contained in:
Adil Hafeez 2025-10-14 14:01:11 -07:00 committed by GitHub
parent f8991a3c4b
commit 96e0732089
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 3571 additions and 856 deletions

View file

@ -5,10 +5,10 @@ use serde_with::skip_serializing_none;
use std::collections::HashMap;
use super::ApiDefinition;
use crate::clients::transformer::ExtractText;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
use crate::clients::transformer::ExtractText;
use crate::{MESSAGES_PATH};
use crate::MESSAGES_PATH;
// Enum for all supported Anthropic APIs
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@ -52,9 +52,7 @@ impl ApiDefinition for AnthropicApi {
}
fn all_variants() -> Vec<Self> {
vec![
AnthropicApi::Messages,
]
vec![AnthropicApi::Messages]
}
}
@ -100,7 +98,6 @@ pub struct McpServer {
pub tool_configuration: Option<McpToolConfiguration>,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesRequest {
@ -121,10 +118,8 @@ pub struct MessagesRequest {
pub stop_sequences: Option<Vec<String>>,
pub tools: Option<Vec<MessagesTool>>,
pub tool_choice: Option<MessagesToolChoice>,
}
// Messages API specific types
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
@ -235,34 +230,21 @@ impl ExtractText for Vec<MessagesContentBlock> {
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum MessagesImageSource {
Base64 {
media_type: String,
data: String,
},
Url {
url: String,
},
Base64 { media_type: String, data: String },
Url { url: String },
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum MessagesDocumentSource {
Base64 {
media_type: String,
data: String,
},
Url {
url: String,
},
File {
file_id: String,
},
Base64 { media_type: String, data: String },
Url { url: String },
File { file_id: String },
}
#[derive(Serialize, Deserialize, Debug, Clone)]
@ -276,7 +258,7 @@ impl ExtractText for MessagesMessageContent {
fn extract_text(&self) -> String {
match self {
MessagesMessageContent::Single(text) => text.clone(),
MessagesMessageContent::Blocks(parts) => parts.extract_text()
MessagesMessageContent::Blocks(parts) => parts.extract_text(),
}
}
}
@ -320,7 +302,6 @@ pub struct MessagesToolChoice {
pub disable_parallel_tool_use: Option<bool>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum MessagesStopReason {
@ -457,7 +438,11 @@ impl ProviderResponse for MessagesResponse {
Some(self)
}
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
Some((self.usage.input_tokens as usize, self.usage.output_tokens as usize, (self.usage.input_tokens + self.usage.output_tokens) as usize))
Some((
self.usage.input_tokens as usize,
self.usage.output_tokens as usize,
(self.usage.input_tokens + self.usage.output_tokens) as usize,
))
}
}
@ -535,7 +520,7 @@ impl ProviderRequest for MessagesRequest {
}
fn metadata(&self) -> &Option<HashMap<String, Value>> {
return &self.metadata;
return &self.metadata;
}
fn remove_metadata_key(&mut self, key: &str) -> bool {
@ -572,13 +557,11 @@ impl MessagesRole {
impl ProviderStreamResponse for MessagesStreamEvent {
fn content_delta(&self) -> Option<&str> {
match self {
MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
match delta {
MessagesContentDelta::TextDelta { text } => Some(text),
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
_ => None,
}
}
MessagesStreamEvent::ContentBlockDelta { delta, .. } => match delta {
MessagesContentDelta::TextDelta { text } => Some(text),
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
_ => None,
},
_ => None,
}
}
@ -627,7 +610,8 @@ mod tests {
});
// Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields are properly set
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
@ -687,7 +671,8 @@ mod tests {
});
// Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
@ -730,7 +715,10 @@ mod tests {
assert_eq!(serialized_json["messages"], original_json["messages"]);
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
assert_eq!(serialized_json["system"], original_json["system"]);
assert_eq!(serialized_json["service_tier"], original_json["service_tier"]);
assert_eq!(
serialized_json["service_tier"],
original_json["service_tier"]
);
assert_eq!(serialized_json["thinking"], original_json["thinking"]);
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
@ -818,7 +806,8 @@ mod tests {
});
// Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate top-level fields
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
@ -833,7 +822,10 @@ mod tests {
// Validate text content block
if let MessagesContentBlock::Text { text, .. } = &content_blocks[0] {
assert_eq!(text, "What can you see in this image and what's the weather like?");
assert_eq!(
text,
"What can you see in this image and what's the weather like?"
);
} else {
panic!("Expected text content block");
}
@ -861,20 +853,32 @@ mod tests {
// Validate thinking content block
if let MessagesContentBlock::Thinking { thinking, .. } = &content_blocks[0] {
assert_eq!(thinking, "Let me analyze the image and then check the weather...");
assert_eq!(
thinking,
"Let me analyze the image and then check the weather..."
);
} else {
panic!("Expected thinking content block");
}
// Validate text content block
if let MessagesContentBlock::Text { text, .. } = &content_blocks[1] {
assert_eq!(text, "I can see the image. Let me check the weather for you.");
assert_eq!(
text,
"I can see the image. Let me check the weather for you."
);
} else {
panic!("Expected text content block");
}
// Validate tool use content block
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = content_blocks[2] {
if let MessagesContentBlock::ToolUse {
ref id,
ref name,
ref input,
..
} = content_blocks[2]
{
assert_eq!(id, "toolu_weather123");
assert_eq!(name, "get_weather");
assert_eq!(input["location"], "San Francisco, CA");
@ -892,7 +896,10 @@ mod tests {
let tool = &tools[0];
assert_eq!(tool.name, "get_weather");
assert_eq!(tool.description, Some("Get current weather information for a location".to_string()));
assert_eq!(
tool.description,
Some("Get current weather information for a location".to_string())
);
assert_eq!(tool.input_schema["type"], "object");
assert!(tool.input_schema["properties"]["location"].is_object());
@ -938,10 +945,16 @@ mod tests {
assert_eq!(deserialized_mcp.name, "test-server");
assert_eq!(deserialized_mcp.server_type, McpServerType::Url);
assert_eq!(deserialized_mcp.url, "https://example.com/mcp");
assert_eq!(deserialized_mcp.authorization_token, Some("secret-token".to_string()));
assert_eq!(
deserialized_mcp.authorization_token,
Some("secret-token".to_string())
);
if let Some(tool_config) = &deserialized_mcp.tool_configuration {
assert_eq!(tool_config.allowed_tools, Some(vec!["tool1".to_string(), "tool2".to_string()]));
assert_eq!(
tool_config.allowed_tools,
Some(vec!["tool1".to_string(), "tool2".to_string()])
);
assert_eq!(tool_config.enabled, Some(true));
} else {
panic!("Expected tool configuration");
@ -957,7 +970,8 @@ mod tests {
"url": "https://minimal.com/mcp"
});
let deserialized_minimal: McpServer = serde_json::from_value(minimal_mcp_json.clone()).unwrap();
let deserialized_minimal: McpServer =
serde_json::from_value(minimal_mcp_json.clone()).unwrap();
assert_eq!(deserialized_minimal.name, "minimal-server");
assert_eq!(deserialized_minimal.server_type, McpServerType::Url);
assert_eq!(deserialized_minimal.url, "https://minimal.com/mcp");
@ -991,12 +1005,16 @@ mod tests {
}
});
let deserialized_response: MessagesResponse = serde_json::from_value(response_json.clone()).unwrap();
let deserialized_response: MessagesResponse =
serde_json::from_value(response_json.clone()).unwrap();
assert_eq!(deserialized_response.id, "msg_01ABC123");
assert_eq!(deserialized_response.obj_type, "message");
assert_eq!(deserialized_response.role, MessagesRole::Assistant);
assert_eq!(deserialized_response.model, "claude-3-sonnet-20240229");
assert_eq!(deserialized_response.stop_reason, MessagesStopReason::EndTurn);
assert_eq!(
deserialized_response.stop_reason,
MessagesStopReason::EndTurn
);
assert!(deserialized_response.stop_sequence.is_none());
assert!(deserialized_response.container.is_none());
@ -1011,7 +1029,10 @@ mod tests {
// Check usage
assert_eq!(deserialized_response.usage.input_tokens, 10);
assert_eq!(deserialized_response.usage.output_tokens, 25);
assert_eq!(deserialized_response.usage.cache_creation_input_tokens, Some(5));
assert_eq!(
deserialized_response.usage.cache_creation_input_tokens,
Some(5)
);
assert_eq!(deserialized_response.usage.cache_read_input_tokens, Some(3));
let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap();
@ -1027,7 +1048,8 @@ mod tests {
}
});
let deserialized_event: MessagesStreamEvent = serde_json::from_value(stream_event_json.clone()).unwrap();
let deserialized_event: MessagesStreamEvent =
serde_json::from_value(stream_event_json.clone()).unwrap();
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
assert_eq!(index, 0);
if let MessagesContentDelta::TextDelta { text } = delta {
@ -1055,8 +1077,15 @@ mod tests {
}
});
let deserialized_tool_use: MessagesContentBlock = serde_json::from_value(tool_use_json.clone()).unwrap();
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = deserialized_tool_use {
let deserialized_tool_use: MessagesContentBlock =
serde_json::from_value(tool_use_json.clone()).unwrap();
if let MessagesContentBlock::ToolUse {
ref id,
ref name,
ref input,
..
} = deserialized_tool_use
{
assert_eq!(id, "toolu_01ABC123");
assert_eq!(name, "get_weather");
assert_eq!(input["location"], "San Francisco, CA");
@ -1079,8 +1108,15 @@ mod tests {
]
});
let deserialized_tool_result: MessagesContentBlock = serde_json::from_value(tool_result_json.clone()).unwrap();
if let MessagesContentBlock::ToolResult { ref tool_use_id, ref is_error, ref content, .. } = deserialized_tool_result {
let deserialized_tool_result: MessagesContentBlock =
serde_json::from_value(tool_result_json.clone()).unwrap();
if let MessagesContentBlock::ToolResult {
ref tool_use_id,
ref is_error,
ref content,
..
} = deserialized_tool_result
{
assert_eq!(tool_use_id, "toolu_01ABC123");
assert!(is_error.is_none());
if let ToolResultContent::Blocks(blocks) = content {
@ -1229,7 +1265,8 @@ mod tests {
});
// Deserialize the complex MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(complex_request_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(complex_request_json.clone()).unwrap();
// Verify basic fields
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
@ -1239,8 +1276,15 @@ mod tests {
// Verify system message with cache_control
if let Some(MessagesSystemPrompt::Blocks(ref system_blocks)) = deserialized_request.system {
assert_eq!(system_blocks.len(), 2);
if let MessagesContentBlock::Text { text, cache_control } = &system_blocks[0] {
assert_eq!(text, "You are Claude Code, Anthropic's official CLI for Claude.");
if let MessagesContentBlock::Text {
text,
cache_control,
} = &system_blocks[0]
{
assert_eq!(
text,
"You are Claude Code, Anthropic's official CLI for Claude."
);
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
} else {
panic!("Expected text system message with cache_control");
@ -1253,7 +1297,13 @@ mod tests {
let assistant_message = &deserialized_request.messages[1];
assert_eq!(assistant_message.role, MessagesRole::Assistant);
if let MessagesMessageContent::Blocks(ref content_blocks) = assistant_message.content {
if let MessagesContentBlock::ToolUse { id, name, input, cache_control } = &content_blocks[0] {
if let MessagesContentBlock::ToolUse {
id,
name,
input,
cache_control,
} = &content_blocks[0]
{
assert_eq!(id, "call_kV50LtJQKHvvzZui5TW56DUl");
assert_eq!(name, "TodoWrite");
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
@ -1272,7 +1322,12 @@ mod tests {
let user_message = &deserialized_request.messages[2];
assert_eq!(user_message.role, MessagesRole::User);
if let MessagesMessageContent::Blocks(ref content_blocks) = user_message.content {
if let MessagesContentBlock::ToolResult { tool_use_id, content, .. } = &content_blocks[0] {
if let MessagesContentBlock::ToolResult {
tool_use_id,
content,
..
} = &content_blocks[0]
{
assert_eq!(tool_use_id, "call_kV50LtJQKHvvzZui5TW56DUl");
if let ToolResultContent::Text(text) = content {
assert!(text.contains("Todos have been modified successfully"));
@ -1284,7 +1339,11 @@ mod tests {
}
// Verify text content with cache_control
if let MessagesContentBlock::Text { text, cache_control } = &content_blocks[2] {
if let MessagesContentBlock::Text {
text,
cache_control,
} = &content_blocks[2]
{
assert_eq!(text, "try again");
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
} else {
@ -1296,11 +1355,15 @@ mod tests {
// Test serialization round-trip
let serialized_request = serde_json::to_value(&deserialized_request).unwrap();
let re_deserialized_request: MessagesRequest = serde_json::from_value(serialized_request).unwrap();
let re_deserialized_request: MessagesRequest =
serde_json::from_value(serialized_request).unwrap();
// Verify round-trip consistency
assert_eq!(deserialized_request.model, re_deserialized_request.model);
assert_eq!(deserialized_request.messages.len(), re_deserialized_request.messages.len());
assert_eq!(
deserialized_request.messages.len(),
re_deserialized_request.messages.len()
);
}
#[test]
@ -1339,7 +1402,8 @@ mod tests {
}
});
let deserialized_event: MessagesStreamEvent = serde_json::from_value(thinking_delta_json.clone()).unwrap();
let deserialized_event: MessagesStreamEvent =
serde_json::from_value(thinking_delta_json.clone()).unwrap();
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
assert_eq!(index, 0);
if let MessagesContentDelta::ThinkingDelta { thinking } = delta {
@ -1352,7 +1416,10 @@ mod tests {
}
// Test that thinking delta is returned by content_delta()
assert_eq!(deserialized_event.content_delta(), Some(".\n\nI need to consider:\n1. Current"));
assert_eq!(
deserialized_event.content_delta(),
Some(".\n\nI need to consider:\n1. Current")
);
let serialized_event_json = serde_json::to_value(&deserialized_event).unwrap();
assert_eq!(thinking_delta_json, serialized_event_json);
@ -1376,7 +1443,8 @@ mod tests {
}
});
let deserialized_request: MessagesRequest = serde_json::from_value(request_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(request_json.clone()).unwrap();
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
assert_eq!(deserialized_request.max_tokens, 2048);