add support for agents (#564)

This commit is contained in:
Adil Hafeez 2025-10-14 14:01:11 -07:00 committed by GitHub
parent f8991a3c4b
commit 96e0732089
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 3571 additions and 856 deletions

View file

@ -5,10 +5,10 @@ use serde_with::skip_serializing_none;
use std::collections::HashMap;
use super::ApiDefinition;
use crate::clients::transformer::ExtractText;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
use crate::clients::transformer::ExtractText;
use crate::{MESSAGES_PATH};
use crate::MESSAGES_PATH;
// Enum for all supported Anthropic APIs
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@ -52,9 +52,7 @@ impl ApiDefinition for AnthropicApi {
}
fn all_variants() -> Vec<Self> {
vec![
AnthropicApi::Messages,
]
vec![AnthropicApi::Messages]
}
}
@ -100,7 +98,6 @@ pub struct McpServer {
pub tool_configuration: Option<McpToolConfiguration>,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesRequest {
@ -121,10 +118,8 @@ pub struct MessagesRequest {
pub stop_sequences: Option<Vec<String>>,
pub tools: Option<Vec<MessagesTool>>,
pub tool_choice: Option<MessagesToolChoice>,
}
// Messages API specific types
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
@ -235,34 +230,21 @@ impl ExtractText for Vec<MessagesContentBlock> {
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum MessagesImageSource {
Base64 {
media_type: String,
data: String,
},
Url {
url: String,
},
Base64 { media_type: String, data: String },
Url { url: String },
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum MessagesDocumentSource {
Base64 {
media_type: String,
data: String,
},
Url {
url: String,
},
File {
file_id: String,
},
Base64 { media_type: String, data: String },
Url { url: String },
File { file_id: String },
}
#[derive(Serialize, Deserialize, Debug, Clone)]
@ -276,7 +258,7 @@ impl ExtractText for MessagesMessageContent {
fn extract_text(&self) -> String {
match self {
MessagesMessageContent::Single(text) => text.clone(),
MessagesMessageContent::Blocks(parts) => parts.extract_text()
MessagesMessageContent::Blocks(parts) => parts.extract_text(),
}
}
}
@ -320,7 +302,6 @@ pub struct MessagesToolChoice {
pub disable_parallel_tool_use: Option<bool>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum MessagesStopReason {
@ -457,7 +438,11 @@ impl ProviderResponse for MessagesResponse {
Some(self)
}
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
Some((self.usage.input_tokens as usize, self.usage.output_tokens as usize, (self.usage.input_tokens + self.usage.output_tokens) as usize))
Some((
self.usage.input_tokens as usize,
self.usage.output_tokens as usize,
(self.usage.input_tokens + self.usage.output_tokens) as usize,
))
}
}
@ -535,7 +520,7 @@ impl ProviderRequest for MessagesRequest {
}
fn metadata(&self) -> &Option<HashMap<String, Value>> {
return &self.metadata;
return &self.metadata;
}
fn remove_metadata_key(&mut self, key: &str) -> bool {
@ -572,13 +557,11 @@ impl MessagesRole {
impl ProviderStreamResponse for MessagesStreamEvent {
fn content_delta(&self) -> Option<&str> {
match self {
MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
match delta {
MessagesContentDelta::TextDelta { text } => Some(text),
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
_ => None,
}
}
MessagesStreamEvent::ContentBlockDelta { delta, .. } => match delta {
MessagesContentDelta::TextDelta { text } => Some(text),
MessagesContentDelta::ThinkingDelta { thinking } => Some(thinking),
_ => None,
},
_ => None,
}
}
@ -627,7 +610,8 @@ mod tests {
});
// Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields are properly set
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
@ -687,7 +671,8 @@ mod tests {
});
// Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
@ -730,7 +715,10 @@ mod tests {
assert_eq!(serialized_json["messages"], original_json["messages"]);
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
assert_eq!(serialized_json["system"], original_json["system"]);
assert_eq!(serialized_json["service_tier"], original_json["service_tier"]);
assert_eq!(
serialized_json["service_tier"],
original_json["service_tier"]
);
assert_eq!(serialized_json["thinking"], original_json["thinking"]);
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
@ -818,7 +806,8 @@ mod tests {
});
// Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate top-level fields
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
@ -833,7 +822,10 @@ mod tests {
// Validate text content block
if let MessagesContentBlock::Text { text, .. } = &content_blocks[0] {
assert_eq!(text, "What can you see in this image and what's the weather like?");
assert_eq!(
text,
"What can you see in this image and what's the weather like?"
);
} else {
panic!("Expected text content block");
}
@ -861,20 +853,32 @@ mod tests {
// Validate thinking content block
if let MessagesContentBlock::Thinking { thinking, .. } = &content_blocks[0] {
assert_eq!(thinking, "Let me analyze the image and then check the weather...");
assert_eq!(
thinking,
"Let me analyze the image and then check the weather..."
);
} else {
panic!("Expected thinking content block");
}
// Validate text content block
if let MessagesContentBlock::Text { text, .. } = &content_blocks[1] {
assert_eq!(text, "I can see the image. Let me check the weather for you.");
assert_eq!(
text,
"I can see the image. Let me check the weather for you."
);
} else {
panic!("Expected text content block");
}
// Validate tool use content block
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = content_blocks[2] {
if let MessagesContentBlock::ToolUse {
ref id,
ref name,
ref input,
..
} = content_blocks[2]
{
assert_eq!(id, "toolu_weather123");
assert_eq!(name, "get_weather");
assert_eq!(input["location"], "San Francisco, CA");
@ -892,7 +896,10 @@ mod tests {
let tool = &tools[0];
assert_eq!(tool.name, "get_weather");
assert_eq!(tool.description, Some("Get current weather information for a location".to_string()));
assert_eq!(
tool.description,
Some("Get current weather information for a location".to_string())
);
assert_eq!(tool.input_schema["type"], "object");
assert!(tool.input_schema["properties"]["location"].is_object());
@ -938,10 +945,16 @@ mod tests {
assert_eq!(deserialized_mcp.name, "test-server");
assert_eq!(deserialized_mcp.server_type, McpServerType::Url);
assert_eq!(deserialized_mcp.url, "https://example.com/mcp");
assert_eq!(deserialized_mcp.authorization_token, Some("secret-token".to_string()));
assert_eq!(
deserialized_mcp.authorization_token,
Some("secret-token".to_string())
);
if let Some(tool_config) = &deserialized_mcp.tool_configuration {
assert_eq!(tool_config.allowed_tools, Some(vec!["tool1".to_string(), "tool2".to_string()]));
assert_eq!(
tool_config.allowed_tools,
Some(vec!["tool1".to_string(), "tool2".to_string()])
);
assert_eq!(tool_config.enabled, Some(true));
} else {
panic!("Expected tool configuration");
@ -957,7 +970,8 @@ mod tests {
"url": "https://minimal.com/mcp"
});
let deserialized_minimal: McpServer = serde_json::from_value(minimal_mcp_json.clone()).unwrap();
let deserialized_minimal: McpServer =
serde_json::from_value(minimal_mcp_json.clone()).unwrap();
assert_eq!(deserialized_minimal.name, "minimal-server");
assert_eq!(deserialized_minimal.server_type, McpServerType::Url);
assert_eq!(deserialized_minimal.url, "https://minimal.com/mcp");
@ -991,12 +1005,16 @@ mod tests {
}
});
let deserialized_response: MessagesResponse = serde_json::from_value(response_json.clone()).unwrap();
let deserialized_response: MessagesResponse =
serde_json::from_value(response_json.clone()).unwrap();
assert_eq!(deserialized_response.id, "msg_01ABC123");
assert_eq!(deserialized_response.obj_type, "message");
assert_eq!(deserialized_response.role, MessagesRole::Assistant);
assert_eq!(deserialized_response.model, "claude-3-sonnet-20240229");
assert_eq!(deserialized_response.stop_reason, MessagesStopReason::EndTurn);
assert_eq!(
deserialized_response.stop_reason,
MessagesStopReason::EndTurn
);
assert!(deserialized_response.stop_sequence.is_none());
assert!(deserialized_response.container.is_none());
@ -1011,7 +1029,10 @@ mod tests {
// Check usage
assert_eq!(deserialized_response.usage.input_tokens, 10);
assert_eq!(deserialized_response.usage.output_tokens, 25);
assert_eq!(deserialized_response.usage.cache_creation_input_tokens, Some(5));
assert_eq!(
deserialized_response.usage.cache_creation_input_tokens,
Some(5)
);
assert_eq!(deserialized_response.usage.cache_read_input_tokens, Some(3));
let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap();
@ -1027,7 +1048,8 @@ mod tests {
}
});
let deserialized_event: MessagesStreamEvent = serde_json::from_value(stream_event_json.clone()).unwrap();
let deserialized_event: MessagesStreamEvent =
serde_json::from_value(stream_event_json.clone()).unwrap();
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
assert_eq!(index, 0);
if let MessagesContentDelta::TextDelta { text } = delta {
@ -1055,8 +1077,15 @@ mod tests {
}
});
let deserialized_tool_use: MessagesContentBlock = serde_json::from_value(tool_use_json.clone()).unwrap();
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input, .. } = deserialized_tool_use {
let deserialized_tool_use: MessagesContentBlock =
serde_json::from_value(tool_use_json.clone()).unwrap();
if let MessagesContentBlock::ToolUse {
ref id,
ref name,
ref input,
..
} = deserialized_tool_use
{
assert_eq!(id, "toolu_01ABC123");
assert_eq!(name, "get_weather");
assert_eq!(input["location"], "San Francisco, CA");
@ -1079,8 +1108,15 @@ mod tests {
]
});
let deserialized_tool_result: MessagesContentBlock = serde_json::from_value(tool_result_json.clone()).unwrap();
if let MessagesContentBlock::ToolResult { ref tool_use_id, ref is_error, ref content, .. } = deserialized_tool_result {
let deserialized_tool_result: MessagesContentBlock =
serde_json::from_value(tool_result_json.clone()).unwrap();
if let MessagesContentBlock::ToolResult {
ref tool_use_id,
ref is_error,
ref content,
..
} = deserialized_tool_result
{
assert_eq!(tool_use_id, "toolu_01ABC123");
assert!(is_error.is_none());
if let ToolResultContent::Blocks(blocks) = content {
@ -1229,7 +1265,8 @@ mod tests {
});
// Deserialize the complex MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(complex_request_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(complex_request_json.clone()).unwrap();
// Verify basic fields
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
@ -1239,8 +1276,15 @@ mod tests {
// Verify system message with cache_control
if let Some(MessagesSystemPrompt::Blocks(ref system_blocks)) = deserialized_request.system {
assert_eq!(system_blocks.len(), 2);
if let MessagesContentBlock::Text { text, cache_control } = &system_blocks[0] {
assert_eq!(text, "You are Claude Code, Anthropic's official CLI for Claude.");
if let MessagesContentBlock::Text {
text,
cache_control,
} = &system_blocks[0]
{
assert_eq!(
text,
"You are Claude Code, Anthropic's official CLI for Claude."
);
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
} else {
panic!("Expected text system message with cache_control");
@ -1253,7 +1297,13 @@ mod tests {
let assistant_message = &deserialized_request.messages[1];
assert_eq!(assistant_message.role, MessagesRole::Assistant);
if let MessagesMessageContent::Blocks(ref content_blocks) = assistant_message.content {
if let MessagesContentBlock::ToolUse { id, name, input, cache_control } = &content_blocks[0] {
if let MessagesContentBlock::ToolUse {
id,
name,
input,
cache_control,
} = &content_blocks[0]
{
assert_eq!(id, "call_kV50LtJQKHvvzZui5TW56DUl");
assert_eq!(name, "TodoWrite");
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
@ -1272,7 +1322,12 @@ mod tests {
let user_message = &deserialized_request.messages[2];
assert_eq!(user_message.role, MessagesRole::User);
if let MessagesMessageContent::Blocks(ref content_blocks) = user_message.content {
if let MessagesContentBlock::ToolResult { tool_use_id, content, .. } = &content_blocks[0] {
if let MessagesContentBlock::ToolResult {
tool_use_id,
content,
..
} = &content_blocks[0]
{
assert_eq!(tool_use_id, "call_kV50LtJQKHvvzZui5TW56DUl");
if let ToolResultContent::Text(text) = content {
assert!(text.contains("Todos have been modified successfully"));
@ -1284,7 +1339,11 @@ mod tests {
}
// Verify text content with cache_control
if let MessagesContentBlock::Text { text, cache_control } = &content_blocks[2] {
if let MessagesContentBlock::Text {
text,
cache_control,
} = &content_blocks[2]
{
assert_eq!(text, "try again");
assert_eq!(cache_control, &Some(MessagesCacheControl::Ephemeral));
} else {
@ -1296,11 +1355,15 @@ mod tests {
// Test serialization round-trip
let serialized_request = serde_json::to_value(&deserialized_request).unwrap();
let re_deserialized_request: MessagesRequest = serde_json::from_value(serialized_request).unwrap();
let re_deserialized_request: MessagesRequest =
serde_json::from_value(serialized_request).unwrap();
// Verify round-trip consistency
assert_eq!(deserialized_request.model, re_deserialized_request.model);
assert_eq!(deserialized_request.messages.len(), re_deserialized_request.messages.len());
assert_eq!(
deserialized_request.messages.len(),
re_deserialized_request.messages.len()
);
}
#[test]
@ -1339,7 +1402,8 @@ mod tests {
}
});
let deserialized_event: MessagesStreamEvent = serde_json::from_value(thinking_delta_json.clone()).unwrap();
let deserialized_event: MessagesStreamEvent =
serde_json::from_value(thinking_delta_json.clone()).unwrap();
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
assert_eq!(index, 0);
if let MessagesContentDelta::ThinkingDelta { thinking } = delta {
@ -1352,7 +1416,10 @@ mod tests {
}
// Test that thinking delta is returned by content_delta()
assert_eq!(deserialized_event.content_delta(), Some(".\n\nI need to consider:\n1. Current"));
assert_eq!(
deserialized_event.content_delta(),
Some(".\n\nI need to consider:\n1. Current")
);
let serialized_event_json = serde_json::to_value(&deserialized_event).unwrap();
assert_eq!(thinking_delta_json, serialized_event_json);
@ -1376,7 +1443,8 @@ mod tests {
}
});
let deserialized_request: MessagesRequest = serde_json::from_value(request_json.clone()).unwrap();
let deserialized_request: MessagesRequest =
serde_json::from_value(request_json.clone()).unwrap();
assert_eq!(deserialized_request.model, "claude-sonnet-4-20250514");
assert_eq!(deserialized_request.max_tokens, 2048);

View file

@ -3,7 +3,6 @@ pub mod openai;
pub use anthropic::*;
pub use openai::*;
pub trait ApiDefinition {
/// Returns the endpoint path for this API
fn endpoint(&self) -> &'static str;
@ -49,11 +48,7 @@ mod tests {
#[test]
fn test_api_detection_from_endpoints() {
// Test that we can detect APIs from endpoints using the trait
let endpoints = vec![
CHAT_COMPLETIONS_PATH,
MESSAGES_PATH,
"/v1/unknown"
];
let endpoints = vec![CHAT_COMPLETIONS_PATH, MESSAGES_PATH, "/v1/unknown"];
let mut detected_apis = Vec::new();
@ -67,11 +62,14 @@ mod tests {
}
}
assert_eq!(detected_apis, vec![
"OpenAI: ChatCompletions",
"Anthropic: Messages",
"Unknown API"
]);
assert_eq!(
detected_apis,
vec![
"OpenAI: ChatCompletions",
"Anthropic: Messages",
"Unknown API"
]
);
}
#[test]

View file

@ -5,11 +5,11 @@ use std::collections::HashMap;
use std::fmt::Display;
use thiserror::Error;
use super::ApiDefinition;
use crate::clients::transformer::ExtractText;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
use super::ApiDefinition;
use crate::clients::transformer::{ExtractText};
use crate::{CHAT_COMPLETIONS_PATH};
use crate::CHAT_COMPLETIONS_PATH;
// ============================================================================
// OPENAI API ENUMERATION
@ -46,7 +46,7 @@ impl ApiDefinition for OpenAIApi {
}
fn supports_tools(&self) -> bool {
match self {
match self {
OpenAIApi::ChatCompletions => true,
}
}
@ -58,9 +58,7 @@ impl ApiDefinition for OpenAIApi {
}
fn all_variants() -> Vec<Self> {
vec![
OpenAIApi::ChatCompletions,
]
vec![OpenAIApi::ChatCompletions]
}
}
@ -190,7 +188,9 @@ impl ResponseMessage {
pub fn to_message(&self) -> Message {
Message {
role: self.role.clone(),
content: self.content.as_ref()
content: self
.content
.as_ref()
.map(|s| MessageContent::Text(s.clone()))
.unwrap_or(MessageContent::Text(String::new())),
name: None, // Response messages don't have names in the same way request messages do
@ -215,7 +215,7 @@ impl ExtractText for MessageContent {
fn extract_text(&self) -> String {
match self {
MessageContent::Text(text) => text.clone(),
MessageContent::Parts(parts) => parts.extract_text()
MessageContent::Parts(parts) => parts.extract_text(),
}
}
}
@ -274,7 +274,6 @@ pub struct ImageUrl {
/// A single message in a chat conversation
/// A tool call made by the assistant
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ToolCall {
@ -374,7 +373,6 @@ pub enum StaticContentType {
Parts(Vec<ContentPart>),
}
/// Chat completions API response
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
@ -496,7 +494,6 @@ pub struct ChatCompletionsStreamResponse {
pub service_tier: Option<String>,
}
/// A choice in a streaming response
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
@ -566,7 +563,6 @@ pub struct Models {
pub data: Vec<ModelDetail>,
}
// Error type for streaming operations
#[derive(Debug, thiserror::Error)]
pub enum OpenAIStreamError {
@ -597,13 +593,13 @@ pub enum OpenAIError {
/// Trait Implementations
/// ===========================================================================
/// Parameterized conversion for ChatCompletionsRequest
impl TryFrom<&[u8]> for ChatCompletionsRequest {
type Error = OpenAIStreamError;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
let mut req: ChatCompletionsRequest = serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
let mut req: ChatCompletionsRequest =
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
// Use the centralized suppression logic
req.suppress_max_tokens_if_o3();
req.fix_temperature_if_gpt5();
@ -651,13 +647,18 @@ impl ProviderRequest for ChatCompletionsRequest {
fn extract_messages_text(&self) -> String {
self.messages.iter().fold(String::new(), |acc, m| {
acc + " " + &match &m.content {
MessageContent::Text(text) => text.clone(),
MessageContent::Parts(parts) => parts.iter().map(|part| match part {
ContentPart::Text { text } => text.clone(),
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
}).collect::<Vec<_>>().join(" ")
}
acc + " "
+ &match &m.content {
MessageContent::Text(text) => text.clone(),
MessageContent::Parts(parts) => parts
.iter()
.map(|part| match part {
ContentPart::Text { text } => text.clone(),
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
})
.collect::<Vec<_>>()
.join(" "),
}
})
}
@ -721,14 +722,14 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
}
fn role(&self) -> Option<&str> {
self.choices
.first()
.and_then(|choice| choice.delta.role.as_ref().map(|r| match r {
self.choices.first().and_then(|choice| {
choice.delta.role.as_ref().map(|r| match r {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
}))
})
})
}
fn event_type(&self) -> Option<&str> {
@ -736,7 +737,6 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -756,7 +756,8 @@ mod tests {
});
// Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: ChatCompletionsRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields are properly set
assert_eq!(deserialized_request.model, "gpt-4");
@ -799,7 +800,8 @@ mod tests {
});
// Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: ChatCompletionsRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields
assert_eq!(deserialized_request.model, "gpt-4");
@ -836,7 +838,10 @@ mod tests {
assert_eq!(serialized_json["messages"], original_json["messages"]);
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
assert_eq!(serialized_json["stream"], original_json["stream"]);
assert_eq!(serialized_json["stream_options"], original_json["stream_options"]);
assert_eq!(
serialized_json["stream_options"],
original_json["stream_options"]
);
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
// Handle temperature with floating point tolerance
@ -917,7 +922,8 @@ mod tests {
});
// Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
let deserialized_request: ChatCompletionsRequest =
serde_json::from_value(original_json.clone()).unwrap();
// Validate top-level fields
assert_eq!(deserialized_request.model, "gpt-4-vision-preview");
@ -953,7 +959,10 @@ mod tests {
let assistant_message = &deserialized_request.messages[1];
assert_eq!(assistant_message.role, Role::Assistant);
if let MessageContent::Text(text) = &assistant_message.content {
assert_eq!(text, "I can see a beautiful cityscape. Let me check the weather for you.");
assert_eq!(
text,
"I can see a beautiful cityscape. Let me check the weather for you."
);
} else {
panic!("Expected text content for assistant message");
}
@ -967,7 +976,10 @@ mod tests {
assert_eq!(tool_call.id, "call_weather123");
assert_eq!(tool_call.call_type, "function");
assert_eq!(tool_call.function.name, "get_weather");
assert_eq!(tool_call.function.arguments, "{\"location\": \"New York, NY\"}");
assert_eq!(
tool_call.function.arguments,
"{\"location\": \"New York, NY\"}"
);
// Validate third message (tool response)
let tool_message = &deserialized_request.messages[2];
@ -977,7 +989,10 @@ mod tests {
} else {
panic!("Expected text content for tool message");
}
assert_eq!(tool_message.tool_call_id, Some("call_weather123".to_string()));
assert_eq!(
tool_message.tool_call_id,
Some("call_weather123".to_string())
);
// Validate tools array
assert!(deserialized_request.tools.is_some());
@ -987,7 +1002,10 @@ mod tests {
let tool = &tools[0];
assert_eq!(tool.tool_type, "function");
assert_eq!(tool.function.name, "get_weather");
assert_eq!(tool.function.description, Some("Get current weather information for a location".to_string()));
assert_eq!(
tool.function.description,
Some("Get current weather information for a location".to_string())
);
assert_eq!(tool.function.strict, Some(true));
// Validate tool parameters schema
@ -1093,7 +1111,8 @@ mod tests {
]
});
let deserialized_assistant: Message = serde_json::from_value(assistant_json.clone()).unwrap();
let deserialized_assistant: Message =
serde_json::from_value(assistant_json.clone()).unwrap();
assert_eq!(deserialized_assistant.role, Role::Assistant);
if let MessageContent::Text(content) = &deserialized_assistant.content {
assert_eq!(content, "I'll help with that.");
@ -1142,9 +1161,13 @@ mod tests {
]
});
let deserialized_response: ResponseMessage = serde_json::from_value(response_json.clone()).unwrap();
let deserialized_response: ResponseMessage =
serde_json::from_value(response_json.clone()).unwrap();
assert_eq!(deserialized_response.role, Role::Assistant);
assert_eq!(deserialized_response.content, Some("Response content".to_string()));
assert_eq!(
deserialized_response.content,
Some("Response content".to_string())
);
assert!(deserialized_response.annotations.is_some());
assert!(deserialized_response.refusal.is_none());
assert!(deserialized_response.function_call.is_none());
@ -1186,7 +1209,10 @@ mod tests {
let none_deserialized: ToolChoice = serde_json::from_value(json!("none")).unwrap();
assert_eq!(auto_deserialized, ToolChoice::Type(ToolChoiceType::Auto));
assert_eq!(required_deserialized, ToolChoice::Type(ToolChoiceType::Required));
assert_eq!(
required_deserialized,
ToolChoice::Type(ToolChoiceType::Required)
);
assert_eq!(none_deserialized, ToolChoice::Type(ToolChoiceType::None));
// Test that invalid string values fail deserialization (type safety!)
@ -1237,7 +1263,10 @@ mod tests {
assert_eq!(response.created, 1756574706);
assert_eq!(response.model, "gpt-4o-2024-08-06");
assert_eq!(response.service_tier, Some("default".to_string()));
assert_eq!(response.system_fingerprint, Some("fp_f33640a400".to_string()));
assert_eq!(
response.system_fingerprint,
Some("fp_f33640a400".to_string())
);
assert_eq!(response.choices.len(), 1);
assert_eq!(response.usage.prompt_tokens, 65);
assert_eq!(response.usage.completion_tokens, 184);