pushing new apis module for hermes (#547)

This commit is contained in:
Salman Paracha 2025-08-07 12:42:09 -07:00 committed by GitHub
parent 62a092fa63
commit 93ff4d7b1f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 3878 additions and 2 deletions

View file

@ -0,0 +1,898 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use super::ApiDefinition;
/// Enum for all supported Anthropic APIs.
///
/// Each variant maps 1:1 onto an HTTP endpoint path; see the
/// `ApiDefinition` impl for the mapping and capability flags.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AnthropicApi {
    /// The `/v1/messages` chat endpoint.
    Messages,
    // Future APIs can be added here:
    // Embeddings,
    // etc.
}
impl ApiDefinition for AnthropicApi {
    /// HTTP path this API is served under.
    fn endpoint(&self) -> &'static str {
        match self {
            Self::Messages => "/v1/messages",
        }
    }

    /// Reverse lookup: resolve an endpoint path back to its API variant.
    fn from_endpoint(endpoint: &str) -> Option<Self> {
        Self::all_variants()
            .into_iter()
            .find(|api| api.endpoint() == endpoint)
    }

    /// The Messages API supports server-sent-event streaming.
    fn supports_streaming(&self) -> bool {
        match self {
            Self::Messages => true,
        }
    }

    /// The Messages API supports tool/function calling.
    fn supports_tools(&self) -> bool {
        match self {
            Self::Messages => true,
        }
    }

    /// The Messages API accepts image content blocks.
    fn supports_vision(&self) -> bool {
        match self {
            Self::Messages => true,
        }
    }

    /// Every variant of this enum, in declaration order.
    fn all_variants() -> Vec<Self> {
        vec![Self::Messages]
    }
}
/// Service tier enum for request priority.
///
/// Serialized in snake_case: `"auto"` / `"standard_only"`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ServiceTier {
    Auto,
    StandardOnly,
}
/// Thinking configuration (extended-thinking toggle on a request).
// NOTE(review): Anthropic's documented shape is a tagged object
// (`{"type": "enabled", "budget_tokens": N}`), not a bare boolean —
// confirm against the upstream API before relying on round-trips.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ThinkingConfig {
    pub enabled: bool,
}
// MCP Server types

/// Transport type for an MCP server; currently only URL-based servers.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum McpServerType {
    Url,
}
/// Optional per-server tool restrictions for an MCP server entry.
/// `None` fields are omitted from the serialized JSON.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct McpToolConfiguration {
    // Whitelist of tool names the server may expose; None = no restriction.
    pub allowed_tools: Option<Vec<String>>,
    pub enabled: Option<bool>,
}
/// An MCP server attached to a Messages request.
/// `None` fields are omitted from the serialized JSON.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct McpServer {
    pub name: String,
    // Serialized as "type" (reserved word in Rust).
    #[serde(rename = "type")]
    pub server_type: McpServerType,
    pub url: String,
    pub authorization_token: Option<String>,
    pub tool_configuration: Option<McpToolConfiguration>,
}
/// Request body for the Anthropic `/v1/messages` API.
///
/// Only `model`, `messages`, and `max_tokens` are required; every
/// `Option` field is omitted from the serialized JSON when `None`
/// (via `skip_serializing_none`).
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesRequest {
    pub model: String,
    pub messages: Vec<MessagesMessage>,
    pub max_tokens: u32,
    pub container: Option<String>,
    pub mcp_servers: Option<Vec<McpServer>>,
    // Either a plain string or a list of content blocks (untagged enum).
    pub system: Option<MessagesSystemPrompt>,
    pub metadata: Option<HashMap<String, Value>>,
    pub service_tier: Option<ServiceTier>,
    pub thinking: Option<ThinkingConfig>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<u32>,
    pub stream: Option<bool>,
    pub stop_sequences: Option<Vec<String>>,
    pub tools: Option<Vec<MessagesTool>>,
    pub tool_choice: Option<MessagesToolChoice>,
}
// Messages API specific types

/// Role of a message author; serialized lowercase ("user"/"assistant").
/// Note: unlike OpenAI, the Messages API has no "system" role — system
/// prompts travel in the request's `system` field instead.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessagesRole {
    User,
    Assistant,
}
/// A single content block inside a message.
///
/// Internally tagged: the JSON carries a `"type"` discriminator in
/// snake_case (e.g. `"text"`, `"tool_use"`, `"mcp_tool_result"`).
/// `None` fields are omitted when serializing.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum MessagesContentBlock {
    Text {
        text: String,
    },
    // NOTE(review): Anthropic's documented thinking block carries
    // `thinking` + `signature` fields rather than `text` — verify.
    Thinking {
        text: String,
    },
    Image {
        source: MessagesImageSource,
    },
    Document {
        source: MessagesDocumentSource,
    },
    ToolUse {
        id: String,
        name: String,
        // Arbitrary JSON arguments produced by the model.
        input: Value,
    },
    ToolResult {
        tool_use_id: String,
        is_error: Option<bool>,
        content: Vec<MessagesContentBlock>,
    },
    ServerToolUse {
        id: String,
        name: String,
        input: Value,
    },
    WebSearchToolResult {
        tool_use_id: String,
        is_error: Option<bool>,
        content: Vec<MessagesContentBlock>,
    },
    CodeExecutionToolResult {
        tool_use_id: String,
        is_error: Option<bool>,
        content: Vec<MessagesContentBlock>,
    },
    McpToolUse {
        id: String,
        name: String,
        input: Value,
    },
    McpToolResult {
        tool_use_id: String,
        is_error: Option<bool>,
        content: Vec<MessagesContentBlock>,
    },
    // NOTE(review): upstream `container_upload` blocks are documented as
    // `{"type": "container_upload", "file_id": ...}` — confirm this shape.
    ContainerUpload {
        id: String,
        name: String,
        media_type: String,
        data: String,
    },
}
/// Source of an image content block.
// NOTE(review): this is externally tagged (`{"base64": {...}}` /
// `{"url": {...}}`), whereas Anthropic's wire format is internally
// tagged (`{"type": "base64", "media_type": ..., "data": ...}`).
// The in-file round-trip tests depend on the current form; changing the
// tagging would also require updating them — verify against the API.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum MessagesImageSource {
    Base64 {
        // e.g. "image/jpeg"
        media_type: String,
        data: String,
    },
    Url {
        url: String,
    },
}
/// Source of a document content block.
// NOTE(review): externally tagged like `MessagesImageSource`; Anthropic's
// documented wire format uses an internal `"type"` tag — verify.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum MessagesDocumentSource {
    Base64 {
        media_type: String,
        data: String,
    },
    Url {
        url: String,
    },
    // Reference to a previously uploaded file.
    File {
        file_id: String,
    },
}
/// Message content: either a bare string or a list of content blocks.
/// Untagged — serde tries each shape in order when deserializing.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum MessagesMessageContent {
    Single(String),
    Blocks(Vec<MessagesContentBlock>),
}
/// System prompt: either a bare string or a list of content blocks.
/// Untagged — serde tries each shape in order when deserializing.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum MessagesSystemPrompt {
    Single(String),
    Blocks(Vec<MessagesContentBlock>),
}
/// One turn in a Messages conversation.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesMessage {
    pub role: MessagesRole,
    pub content: MessagesMessageContent,
}
/// A tool the model may call; `input_schema` is a JSON Schema object.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesTool {
    pub name: String,
    pub description: Option<String>,
    pub input_schema: Value,
}
/// How the model should choose tools; serialized in snake_case.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum MessagesToolChoiceType {
    Auto,
    Any,
    // Force a specific tool; its name goes in MessagesToolChoice::name.
    Tool,
    None,
}
/// Tool-choice directive on a request.
/// `None` fields are omitted from the serialized JSON.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesToolChoice {
    // Serialized as "type" (reserved word in Rust).
    #[serde(rename = "type")]
    pub kind: MessagesToolChoiceType,
    // Required only when kind == Tool.
    pub name: Option<String>,
    pub disable_parallel_tool_use: Option<bool>,
}
/// Why the model stopped generating; serialized in snake_case.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum MessagesStopReason {
    EndTurn,
    MaxTokens,
    StopSequence,
    ToolUse,
    PauseTurn,
    Refusal,
}
/// Token accounting attached to responses and stream events.
/// Cache counters are omitted from JSON when `None`.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    pub cache_creation_input_tokens: Option<u32>,
    pub cache_read_input_tokens: Option<u32>,
}
/// Container response object (code-execution container metadata).
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesContainer {
    pub id: String,
    // Serialized as "type" (reserved word in Rust).
    #[serde(rename = "type")]
    pub container_type: String,
    pub name: String,
    pub status: String,
}
/// Non-streaming response body from `/v1/messages`.
/// `None` fields are omitted from the serialized JSON.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesResponse {
    pub id: String,
    // The literal object type, e.g. "message"; serialized as "type".
    #[serde(rename = "type")]
    pub obj_type: String,
    pub role: MessagesRole,
    pub content: Vec<MessagesContentBlock>,
    pub model: String,
    // NOTE(review): the upstream field is documented as nullable; this
    // non-optional field will fail to deserialize a null stop_reason.
    pub stop_reason: MessagesStopReason,
    pub stop_sequence: Option<String>,
    pub usage: MessagesUsage,
    pub container: Option<MessagesContainer>,
}
/// Server-sent event payloads for a streaming Messages response.
///
/// Internally tagged: the JSON `"type"` field (snake_case) selects the
/// variant, e.g. "message_start", "content_block_delta", "ping".
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum MessagesStreamEvent {
    MessageStart {
        message: MessagesStreamMessage,
    },
    ContentBlockStart {
        // Position of the block within the response content array.
        index: u32,
        content_block: MessagesContentBlock,
    },
    ContentBlockDelta {
        index: u32,
        delta: MessagesContentDelta,
    },
    ContentBlockStop {
        index: u32,
    },
    MessageDelta {
        delta: MessagesMessageDelta,
        usage: MessagesUsage,
    },
    MessageStop,
    // Keep-alive event; carries no payload.
    Ping,
}
/// Skeleton message delivered in a `message_start` event.
/// `None` fields are omitted from the serialized JSON.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesStreamMessage {
    pub id: String,
    #[serde(rename = "type")]
    pub obj_type: String,
    pub role: MessagesRole,
    pub content: Vec<Value>, // Initially empty
    pub model: String,
    // Optional here (unlike MessagesResponse): unknown until the stream ends.
    pub stop_reason: Option<MessagesStopReason>,
    pub stop_sequence: Option<String>,
    pub usage: MessagesUsage,
}
/// Incremental payload inside a `content_block_delta` event.
// NOTE(review): upstream streaming also emits `thinking_delta` and
// `signature_delta` types, which this enum cannot deserialize — verify
// whether those events can occur on the configured request paths.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum MessagesContentDelta {
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    // Streamed fragment of a tool_use input's JSON arguments.
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
/// Final message metadata delivered in a `message_delta` event.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesMessageDelta {
    pub stop_reason: MessagesStopReason,
    pub stop_sequence: Option<String>,
}
// Helper functions for API detection and conversion
impl MessagesRequest {
    /// The API this request type belongs to (`AnthropicApi::Messages`).
    pub fn api_type() -> AnthropicApi {
        AnthropicApi::Messages
    }
}
impl MessagesResponse {
    /// The API this response type belongs to (`AnthropicApi::Messages`).
    pub fn api_type() -> AnthropicApi {
        AnthropicApi::Messages
    }
}
impl MessagesStreamEvent {
    /// The API this stream-event type belongs to (`AnthropicApi::Messages`).
    pub fn api_type() -> AnthropicApi {
        AnthropicApi::Messages
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Round-trip tests: each test deserializes a hand-written JSON payload
    // into the typed structs, checks the fields, then re-serializes and
    // compares against the input (serde_json::Value comparison is
    // key-order-insensitive).

    #[test]
    fn test_anthropic_required_fields() {
        // Create a JSON object with only required fields
        let original_json = json!({
            "model": "claude-3-sonnet-20240229",
            "messages": [
                {
                    "role": "user",
                    "content": "Hello"
                }
            ],
            "max_tokens": 100
        });
        // Deserialize JSON into MessagesRequest
        let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
        // Validate required fields are properly set
        assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
        assert_eq!(deserialized_request.messages.len(), 1);
        assert_eq!(deserialized_request.max_tokens, 100);
        let message = &deserialized_request.messages[0];
        assert_eq!(message.role, MessagesRole::User);
        if let MessagesMessageContent::Single(content) = &message.content {
            assert_eq!(content, "Hello");
        } else {
            panic!("Expected single content");
        }
        // Validate optional fields are None
        assert!(deserialized_request.system.is_none());
        assert!(deserialized_request.container.is_none());
        assert!(deserialized_request.mcp_servers.is_none());
        assert!(deserialized_request.service_tier.is_none());
        assert!(deserialized_request.thinking.is_none());
        assert!(deserialized_request.temperature.is_none());
        assert!(deserialized_request.top_p.is_none());
        assert!(deserialized_request.top_k.is_none());
        assert!(deserialized_request.stream.is_none());
        assert!(deserialized_request.stop_sequences.is_none());
        assert!(deserialized_request.tools.is_none());
        assert!(deserialized_request.tool_choice.is_none());
        assert!(deserialized_request.metadata.is_none());
        // Serialize back to JSON and compare
        // (skip_serializing_none means the None fields drop out again)
        let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
        assert_eq!(original_json, serialized_json);
    }

    #[test]
    fn test_anthropic_optional_fields() {
        // Create a JSON object with optional fields set
        let original_json = json!({
            "model": "claude-3-sonnet-20240229",
            "messages": [
                {
                    "role": "user",
                    "content": "Hello"
                }
            ],
            "max_tokens": 100,
            "temperature": 0.7,
            "top_p": 0.9,
            "system": "You are a helpful assistant",
            "service_tier": "auto",
            "thinking": {
                "enabled": true
            },
            "metadata": {
                "user_id": "123"
            }
        });
        // Deserialize JSON into MessagesRequest
        let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
        // Validate required fields
        assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
        assert_eq!(deserialized_request.messages.len(), 1);
        assert_eq!(deserialized_request.max_tokens, 100);
        // Validate optional fields are properly set
        // (floats compared with a tolerance; f32 -> f64 via serde loses exactness)
        assert!((deserialized_request.temperature.unwrap() - 0.7).abs() < 1e-6);
        assert!((deserialized_request.top_p.unwrap() - 0.9).abs() < 1e-6);
        assert_eq!(deserialized_request.service_tier, Some(ServiceTier::Auto));
        if let Some(MessagesSystemPrompt::Single(system)) = &deserialized_request.system {
            assert_eq!(system, "You are a helpful assistant");
        } else {
            panic!("Expected single system prompt");
        }
        if let Some(thinking) = &deserialized_request.thinking {
            assert_eq!(thinking.enabled, true);
        } else {
            panic!("Expected thinking config");
        }
        assert!(deserialized_request.metadata.is_some());
        // Validate fields not in JSON are None
        assert!(deserialized_request.container.is_none());
        assert!(deserialized_request.mcp_servers.is_none());
        assert!(deserialized_request.top_k.is_none());
        assert!(deserialized_request.stream.is_none());
        assert!(deserialized_request.stop_sequences.is_none());
        assert!(deserialized_request.tools.is_none());
        assert!(deserialized_request.tool_choice.is_none());
        // Serialize back to JSON and compare (handle floating point precision)
        let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
        // Compare all fields except floating point ones
        assert_eq!(serialized_json["model"], original_json["model"]);
        assert_eq!(serialized_json["messages"], original_json["messages"]);
        assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
        assert_eq!(serialized_json["system"], original_json["system"]);
        assert_eq!(serialized_json["service_tier"], original_json["service_tier"]);
        assert_eq!(serialized_json["thinking"], original_json["thinking"]);
        assert_eq!(serialized_json["metadata"], original_json["metadata"]);
        // Handle floating point fields with tolerance
        let original_temp = original_json["temperature"].as_f64().unwrap();
        let serialized_temp = serialized_json["temperature"].as_f64().unwrap();
        assert!((original_temp - serialized_temp).abs() < 1e-6);
        let original_top_p = original_json["top_p"].as_f64().unwrap();
        let serialized_top_p = serialized_json["top_p"].as_f64().unwrap();
        assert!((original_top_p - serialized_top_p).abs() < 1e-6);
    }

    #[test]
    fn test_anthropic_nested_types() {
        // Create a comprehensive JSON object with nested types - a MessagesRequest with complex message content and tools
        // (note the image source uses this module's externally-tagged
        // {"base64": {...}} encoding)
        let original_json = json!({
            "model": "claude-3-sonnet-20240229",
            "max_tokens": 1000,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "What can you see in this image and what's the weather like?"
                        },
                        {
                            "type": "image",
                            "source": {
                                "base64": {
                                    "media_type": "image/jpeg",
                                    "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
                                }
                            }
                        }
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {
                            "type": "thinking",
                            "text": "Let me analyze the image and then check the weather..."
                        },
                        {
                            "type": "text",
                            "text": "I can see the image. Let me check the weather for you."
                        },
                        {
                            "type": "tool_use",
                            "id": "toolu_weather123",
                            "name": "get_weather",
                            "input": {
                                "location": "San Francisco, CA"
                            }
                        }
                    ]
                }
            ],
            "tools": [
                {
                    "name": "get_weather",
                    "description": "Get current weather information for a location",
                    "input_schema": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA"
                            }
                        },
                        "required": ["location"]
                    }
                }
            ],
            "tool_choice": {
                "type": "auto"
            },
            "system": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant that can analyze images and provide weather information."
                }
            ]
        });
        // Deserialize JSON into MessagesRequest
        let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
        // Validate top-level fields
        assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
        assert_eq!(deserialized_request.max_tokens, 1000);
        assert_eq!(deserialized_request.messages.len(), 2);
        // Validate first message (user with text and image content)
        let user_message = &deserialized_request.messages[0];
        assert_eq!(user_message.role, MessagesRole::User);
        if let MessagesMessageContent::Blocks(ref content_blocks) = user_message.content {
            assert_eq!(content_blocks.len(), 2);
            // Validate text content block
            if let MessagesContentBlock::Text { text } = &content_blocks[0] {
                assert_eq!(text, "What can you see in this image and what's the weather like?");
            } else {
                panic!("Expected text content block");
            }
            // Validate image content block
            if let MessagesContentBlock::Image { ref source } = content_blocks[1] {
                if let MessagesImageSource::Base64 { media_type, data } = source {
                    assert_eq!(media_type, "image/jpeg");
                    assert_eq!(data, "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==");
                } else {
                    panic!("Expected base64 image source");
                }
            } else {
                panic!("Expected image content block");
            }
        } else {
            panic!("Expected content blocks for user message");
        }
        // Validate second message (assistant with thinking, text, and tool use)
        let assistant_message = &deserialized_request.messages[1];
        assert_eq!(assistant_message.role, MessagesRole::Assistant);
        if let MessagesMessageContent::Blocks(ref content_blocks) = assistant_message.content {
            assert_eq!(content_blocks.len(), 3);
            // Validate thinking content block
            if let MessagesContentBlock::Thinking { text } = &content_blocks[0] {
                assert_eq!(text, "Let me analyze the image and then check the weather...");
            } else {
                panic!("Expected thinking content block");
            }
            // Validate text content block
            if let MessagesContentBlock::Text { text } = &content_blocks[1] {
                assert_eq!(text, "I can see the image. Let me check the weather for you.");
            } else {
                panic!("Expected text content block");
            }
            // Validate tool use content block
            if let MessagesContentBlock::ToolUse { ref id, ref name, ref input } = content_blocks[2] {
                assert_eq!(id, "toolu_weather123");
                assert_eq!(name, "get_weather");
                assert_eq!(input["location"], "San Francisco, CA");
            } else {
                panic!("Expected tool use content block");
            }
        } else {
            panic!("Expected content blocks for assistant message");
        }
        // Validate tools array
        assert!(deserialized_request.tools.is_some());
        let tools = deserialized_request.tools.as_ref().unwrap();
        assert_eq!(tools.len(), 1);
        let tool = &tools[0];
        assert_eq!(tool.name, "get_weather");
        assert_eq!(tool.description, Some("Get current weather information for a location".to_string()));
        assert_eq!(tool.input_schema["type"], "object");
        assert!(tool.input_schema["properties"]["location"].is_object());
        // Validate tool choice
        assert!(deserialized_request.tool_choice.is_some());
        let tool_choice = deserialized_request.tool_choice.as_ref().unwrap();
        assert_eq!(tool_choice.kind, MessagesToolChoiceType::Auto);
        assert!(tool_choice.name.is_none());
        // Validate system prompt with content blocks
        assert!(deserialized_request.system.is_some());
        if let Some(MessagesSystemPrompt::Blocks(ref system_blocks)) = deserialized_request.system {
            assert_eq!(system_blocks.len(), 1);
            if let MessagesContentBlock::Text { text } = &system_blocks[0] {
                assert_eq!(text, "You are a helpful assistant that can analyze images and provide weather information.");
            } else {
                panic!("Expected text content block in system prompt");
            }
        } else {
            panic!("Expected system prompt with content blocks");
        }
        // Serialize back to JSON and compare
        let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
        assert_eq!(original_json, serialized_json);
    }

    #[test]
    fn test_anthropic_mcp_server_configuration() {
        // Test MCP Server configuration with JSON-first approach
        let mcp_server_json = json!({
            "name": "test-server",
            "type": "url",
            "url": "https://example.com/mcp",
            "authorization_token": "secret-token",
            "tool_configuration": {
                "allowed_tools": ["tool1", "tool2"],
                "enabled": true
            }
        });
        let deserialized_mcp: McpServer = serde_json::from_value(mcp_server_json.clone()).unwrap();
        assert_eq!(deserialized_mcp.name, "test-server");
        assert_eq!(deserialized_mcp.server_type, McpServerType::Url);
        assert_eq!(deserialized_mcp.url, "https://example.com/mcp");
        assert_eq!(deserialized_mcp.authorization_token, Some("secret-token".to_string()));
        if let Some(tool_config) = &deserialized_mcp.tool_configuration {
            assert_eq!(tool_config.allowed_tools, Some(vec!["tool1".to_string(), "tool2".to_string()]));
            assert_eq!(tool_config.enabled, Some(true));
        } else {
            panic!("Expected tool configuration");
        }
        let serialized_mcp_json = serde_json::to_value(&deserialized_mcp).unwrap();
        assert_eq!(mcp_server_json, serialized_mcp_json);
        // Test MCP Server with minimal configuration (optional fields as None)
        let minimal_mcp_json = json!({
            "name": "minimal-server",
            "type": "url",
            "url": "https://minimal.com/mcp"
        });
        let deserialized_minimal: McpServer = serde_json::from_value(minimal_mcp_json.clone()).unwrap();
        assert_eq!(deserialized_minimal.name, "minimal-server");
        assert_eq!(deserialized_minimal.server_type, McpServerType::Url);
        assert_eq!(deserialized_minimal.url, "https://minimal.com/mcp");
        assert!(deserialized_minimal.authorization_token.is_none());
        assert!(deserialized_minimal.tool_configuration.is_none());
        let serialized_minimal_json = serde_json::to_value(&deserialized_minimal).unwrap();
        assert_eq!(minimal_mcp_json, serialized_minimal_json);
    }

    #[test]
    fn test_anthropic_response_types() {
        // Test MessagesResponse deserialization
        let response_json = json!({
            "id": "msg_01ABC123",
            "type": "message",
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "Hello! How can I help you today?"
                }
            ],
            "model": "claude-3-sonnet-20240229",
            "stop_reason": "end_turn",
            "usage": {
                "input_tokens": 10,
                "output_tokens": 25,
                "cache_creation_input_tokens": 5,
                "cache_read_input_tokens": 3
            }
        });
        let deserialized_response: MessagesResponse = serde_json::from_value(response_json.clone()).unwrap();
        assert_eq!(deserialized_response.id, "msg_01ABC123");
        assert_eq!(deserialized_response.obj_type, "message");
        assert_eq!(deserialized_response.role, MessagesRole::Assistant);
        assert_eq!(deserialized_response.model, "claude-3-sonnet-20240229");
        assert_eq!(deserialized_response.stop_reason, MessagesStopReason::EndTurn);
        assert!(deserialized_response.stop_sequence.is_none());
        assert!(deserialized_response.container.is_none());
        // Check content
        assert_eq!(deserialized_response.content.len(), 1);
        if let MessagesContentBlock::Text { text } = &deserialized_response.content[0] {
            assert_eq!(text, "Hello! How can I help you today?");
        } else {
            panic!("Expected text content block");
        }
        // Check usage
        assert_eq!(deserialized_response.usage.input_tokens, 10);
        assert_eq!(deserialized_response.usage.output_tokens, 25);
        assert_eq!(deserialized_response.usage.cache_creation_input_tokens, Some(5));
        assert_eq!(deserialized_response.usage.cache_read_input_tokens, Some(3));
        let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap();
        assert_eq!(response_json, serialized_response_json);
        // Test streaming event (internally tagged on "type")
        let stream_event_json = json!({
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "text_delta",
                "text": " How"
            }
        });
        let deserialized_event: MessagesStreamEvent = serde_json::from_value(stream_event_json.clone()).unwrap();
        if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
            assert_eq!(index, 0);
            if let MessagesContentDelta::TextDelta { text } = delta {
                assert_eq!(text, " How");
            } else {
                panic!("Expected text delta");
            }
        } else {
            panic!("Expected content block delta event");
        }
        let serialized_event_json = serde_json::to_value(&deserialized_event).unwrap();
        assert_eq!(stream_event_json, serialized_event_json);
    }

    #[test]
    fn test_anthropic_tool_use_content() {
        // Test tool use and tool result content blocks
        let tool_use_json = json!({
            "type": "tool_use",
            "id": "toolu_01ABC123",
            "name": "get_weather",
            "input": {
                "location": "San Francisco, CA"
            }
        });
        let deserialized_tool_use: MessagesContentBlock = serde_json::from_value(tool_use_json.clone()).unwrap();
        if let MessagesContentBlock::ToolUse { ref id, ref name, ref input } = deserialized_tool_use {
            assert_eq!(id, "toolu_01ABC123");
            assert_eq!(name, "get_weather");
            assert_eq!(input["location"], "San Francisco, CA");
        } else {
            panic!("Expected tool use content block");
        }
        let serialized_tool_use_json = serde_json::to_value(&deserialized_tool_use).unwrap();
        assert_eq!(tool_use_json, serialized_tool_use_json);
        // Test tool result content block (is_error omitted -> None)
        let tool_result_json = json!({
            "type": "tool_result",
            "tool_use_id": "toolu_01ABC123",
            "content": [
                {
                    "type": "text",
                    "text": "The weather in San Francisco is sunny, 72°F"
                }
            ]
        });
        let deserialized_tool_result: MessagesContentBlock = serde_json::from_value(tool_result_json.clone()).unwrap();
        if let MessagesContentBlock::ToolResult { ref tool_use_id, ref is_error, ref content } = deserialized_tool_result {
            assert_eq!(tool_use_id, "toolu_01ABC123");
            assert!(is_error.is_none());
            assert_eq!(content.len(), 1);
            if let MessagesContentBlock::Text { text } = &content[0] {
                assert_eq!(text, "The weather in San Francisco is sunny, 72°F");
            } else {
                panic!("Expected text content in tool result");
            }
        } else {
            panic!("Expected tool result content block");
        }
        let serialized_tool_result_json = serde_json::to_value(&deserialized_tool_result).unwrap();
        assert_eq!(tool_result_json, serialized_tool_result_json);
    }

    #[test]
    fn test_anthropic_api_provider_trait_implementation() {
        // Test that AnthropicApi implements ApiDefinition trait correctly
        let api = AnthropicApi::Messages;
        // Test trait methods
        assert_eq!(api.endpoint(), "/v1/messages");
        assert!(api.supports_streaming());
        assert!(api.supports_tools());
        assert!(api.supports_vision());
        // Test from_endpoint trait method
        let found_api = AnthropicApi::from_endpoint("/v1/messages");
        assert_eq!(found_api, Some(AnthropicApi::Messages));
        let not_found = AnthropicApi::from_endpoint("/v1/unknown");
        assert_eq!(not_found, None);
        // Test all_variants
        let all_variants = AnthropicApi::all_variants();
        assert_eq!(all_variants.len(), 1);
        assert_eq!(all_variants[0], AnthropicApi::Messages);
    }
}

View file

@ -0,0 +1,197 @@
pub mod anthropic;
pub mod openai;
// Re-export all types for convenience
pub use anthropic::*;
pub use openai::*;
/// Common trait that all API definitions must implement
///
/// This trait ensures consistency across different AI provider API definitions
/// and makes it easy to add new providers like Gemini, Claude, etc.
///
/// Note: This is different from the `ApiProvider` enum in `clients::endpoints`
/// which represents provider identification, while this trait defines API capabilities.
///
/// # Benefits
///
/// - **Consistency**: All API providers implement the same interface
/// - **Extensibility**: Easy to add new providers without breaking existing code
/// - **Type Safety**: Compile-time guarantees that all providers implement required methods
/// - **Discoverability**: Clear documentation of what capabilities each API supports
///
/// # Example implementation for a new provider:
///
/// ```rust,ignore
/// use serde::{Deserialize, Serialize};
/// use super::ApiDefinition;
///
/// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
/// pub enum GeminiApi {
///     GenerateContent,
///     ChatCompletions,
/// }
///
/// impl GeminiApi {
///     pub fn endpoint(&self) -> &'static str {
///         match self {
///             GeminiApi::GenerateContent => "/v1/models/gemini-pro:generateContent",
///             GeminiApi::ChatCompletions => "/v1/models/gemini-pro:chat",
///         }
///     }
///
///     pub fn from_endpoint(endpoint: &str) -> Option<Self> {
///         match endpoint {
///             "/v1/models/gemini-pro:generateContent" => Some(GeminiApi::GenerateContent),
///             "/v1/models/gemini-pro:chat" => Some(GeminiApi::ChatCompletions),
///             _ => None,
///         }
///     }
///
///     pub fn supports_streaming(&self) -> bool {
///         match self {
///             GeminiApi::GenerateContent => true,
///             GeminiApi::ChatCompletions => true,
///         }
///     }
///
///     pub fn supports_tools(&self) -> bool {
///         match self {
///             GeminiApi::GenerateContent => true,
///             GeminiApi::ChatCompletions => false,
///         }
///     }
///
///     pub fn supports_vision(&self) -> bool {
///         match self {
///             GeminiApi::GenerateContent => true,
///             GeminiApi::ChatCompletions => false,
///         }
///     }
/// }
///
/// impl ApiDefinition for GeminiApi {
///     fn endpoint(&self) -> &'static str {
///         self.endpoint()
///     }
///
///     fn from_endpoint(endpoint: &str) -> Option<Self> {
///         Self::from_endpoint(endpoint)
///     }
///
///     fn supports_streaming(&self) -> bool {
///         self.supports_streaming()
///     }
///
///     fn supports_tools(&self) -> bool {
///         self.supports_tools()
///     }
///
///     fn supports_vision(&self) -> bool {
///         self.supports_vision()
///     }
///
///     // all_variants is required by the trait; without it the impl
///     // above would not compile.
///     fn all_variants() -> Vec<Self> {
///         vec![GeminiApi::GenerateContent, GeminiApi::ChatCompletions]
///     }
/// }
///
/// // Now you can use generic code that works with any API:
/// fn print_api_info<T: ApiDefinition>(api: &T) {
///     println!("Endpoint: {}", api.endpoint());
///     println!("Supports streaming: {}", api.supports_streaming());
///     println!("Supports tools: {}", api.supports_tools());
///     println!("Supports vision: {}", api.supports_vision());
/// }
///
/// // Works with both OpenAI and Anthropic (and future Gemini)
/// print_api_info(&OpenAIApi::ChatCompletions);
/// print_api_info(&AnthropicApi::Messages);
/// print_api_info(&GeminiApi::GenerateContent);
/// ```
pub trait ApiDefinition {
    /// Returns the endpoint path for this API
    fn endpoint(&self) -> &'static str;
    /// Creates an API instance from an endpoint path
    fn from_endpoint(endpoint: &str) -> Option<Self>
    where
        Self: Sized;
    /// Returns whether this API supports streaming responses
    fn supports_streaming(&self) -> bool;
    /// Returns whether this API supports tool/function calling
    fn supports_tools(&self) -> bool;
    /// Returns whether this API supports vision/image processing
    fn supports_vision(&self) -> bool;
    /// Returns all variants of this API enum
    fn all_variants() -> Vec<Self>
    where
        Self: Sized;
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generic_api_functionality() {
        // Test that our generic API functionality works with both providers
        // (exercises the trait through a generic fn, not a concrete type)
        fn test_api<T: ApiDefinition>(api: &T) {
            let endpoint = api.endpoint();
            assert!(!endpoint.is_empty());
            assert!(endpoint.starts_with('/'));
        }
        test_api(&OpenAIApi::ChatCompletions);
        test_api(&AnthropicApi::Messages);
    }

    #[test]
    fn test_api_detection_from_endpoints() {
        // Test that we can detect APIs from endpoints using the trait
        // (OpenAI is tried first, then Anthropic, mirroring router order)
        let endpoints = vec![
            "/v1/chat/completions",
            "/v1/messages",
            "/v1/unknown"
        ];
        let mut detected_apis = Vec::new();
        for endpoint in endpoints {
            if let Some(api) = OpenAIApi::from_endpoint(endpoint) {
                detected_apis.push(format!("OpenAI: {:?}", api));
            } else if let Some(api) = AnthropicApi::from_endpoint(endpoint) {
                detected_apis.push(format!("Anthropic: {:?}", api));
            } else {
                detected_apis.push("Unknown API".to_string());
            }
        }
        assert_eq!(detected_apis, vec![
            "OpenAI: ChatCompletions",
            "Anthropic: Messages",
            "Unknown API"
        ]);
    }

    #[test]
    fn test_all_variants_method() {
        // Test that all_variants returns the expected variants
        let openai_variants = OpenAIApi::all_variants();
        assert_eq!(openai_variants.len(), 1);
        assert!(openai_variants.contains(&OpenAIApi::ChatCompletions));
        let anthropic_variants = AnthropicApi::all_variants();
        assert_eq!(anthropic_variants.len(), 1);
        assert!(anthropic_variants.contains(&AnthropicApi::Messages));
        // Verify each variant has a valid endpoint
        for variant in openai_variants {
            assert!(!variant.endpoint().is_empty());
        }
        for variant in anthropic_variants {
            assert!(!variant.endpoint().is_empty());
        }
    }
}

View file

@ -0,0 +1,883 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use super::ApiDefinition;
// ============================================================================
// OPENAI API ENUMERATION
// ============================================================================

/// Enum for all supported OpenAI APIs.
///
/// Each variant maps 1:1 onto an HTTP endpoint path; see the
/// `ApiDefinition` impl for the mapping and capability flags.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OpenAIApi {
    /// The `/v1/chat/completions` endpoint.
    ChatCompletions,
    // Future APIs can be added here:
    // Embeddings,
    // FineTuning,
    // etc.
}
impl ApiDefinition for OpenAIApi {
    /// HTTP path this API is served under.
    fn endpoint(&self) -> &'static str {
        match self {
            Self::ChatCompletions => "/v1/chat/completions",
        }
    }

    /// Reverse lookup: resolve an endpoint path back to its API variant.
    fn from_endpoint(endpoint: &str) -> Option<Self> {
        Self::all_variants()
            .into_iter()
            .find(|api| api.endpoint() == endpoint)
    }

    /// Chat completions supports server-sent-event streaming.
    fn supports_streaming(&self) -> bool {
        match self {
            Self::ChatCompletions => true,
        }
    }

    /// Chat completions supports tool/function calling.
    fn supports_tools(&self) -> bool {
        match self {
            Self::ChatCompletions => true,
        }
    }

    /// Chat completions accepts image inputs.
    fn supports_vision(&self) -> bool {
        match self {
            Self::ChatCompletions => true,
        }
    }

    /// Every variant of this enum, in declaration order.
    fn all_variants() -> Vec<Self> {
        vec![Self::ChatCompletions]
    }
}
/// Chat completions API request.
///
/// Only `messages` and `model` are required; every `Option` field is
/// omitted from the serialized JSON when `None` (`skip_serializing_none`).
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ChatCompletionsRequest {
    pub messages: Vec<Message>,
    pub model: String,
    // pub audio: Option<Audio> // GOOD FIRST ISSUE: future support for audio input
    pub frequency_penalty: Option<f32>,
    // Function calling configuration has been deprecated, but we keep it for compatibility
    pub function_call: Option<FunctionChoice>,
    pub functions: Option<Vec<Tool>>,
    pub logit_bias: Option<HashMap<String, i32>>,
    pub logprobs: Option<bool>,
    pub max_completion_tokens: Option<u32>,
    // Maximum tokens in the response has been deprecated, but we keep it for compatibility
    pub max_tokens: Option<u32>,
    pub modalities: Option<Vec<String>>,
    pub metadata: Option<HashMap<String, String>>,
    pub n: Option<u32>,
    pub presence_penalty: Option<f32>,
    pub parallel_tool_calls: Option<bool>,
    pub prediction: Option<StaticContent>,
    // pub reasoning_effect: Option<bool>, // GOOD FIRST ISSUE: Future support for reasoning effects
    // Kept loosely typed (raw JSON) to cover text/json_object/json_schema forms.
    pub response_format: Option<Value>,
    // pub safety_identifier: Option<String>, // GOOD FIRST ISSUE: Future support for safety identifiers
    pub seed: Option<i32>,
    pub service_tier: Option<String>,
    // NOTE(review): OpenAI accepts either a bare string or an array here;
    // this field only deserializes the array form — confirm callers never
    // send a plain string.
    pub stop: Option<Vec<String>>,
    pub store: Option<bool>,
    pub stream: Option<bool>,
    pub stream_options: Option<StreamOptions>,
    pub temperature: Option<f32>,
    pub tool_choice: Option<ToolChoice>,
    pub tools: Option<Vec<Tool>>,
    pub top_p: Option<f32>,
    pub top_logprobs: Option<u32>,
    pub user: Option<String>,
    // pub web_search: Option<bool>, // GOOD FIRST ISSUE: Future support for web search
}
// ============================================================================
// CHAT COMPLETIONS API TYPES
// ============================================================================
/// Message role in a chat conversation.
///
/// Serialized as lowercase strings: "system", "user", "assistant", "tool".
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// System / instruction message.
    System,
    /// End-user message.
    User,
    /// Model-generated message.
    Assistant,
    /// Output of a tool call, fed back to the model.
    Tool,
}
/// A single message in a chat completions *request*.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
    /// Text or multimodal content of the message.
    pub content: MessageContent,
    /// Who authored the message.
    pub role: Role,
    /// Optional participant name for disambiguation.
    pub name: Option<String>,
    /// Tool calls made by the assistant (only present for assistant role)
    pub tool_calls: Option<Vec<ToolCall>>,
    /// ID of the tool call that this message is responding to (only present for tool role)
    pub tool_call_id: Option<String>,
}
/// A message as it appears in a chat completions *response* choice.
///
/// Differs from the request-side [`Message`]: content is a plain optional
/// string, and response-only fields (refusal, annotations, audio) exist.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ResponseMessage {
    /// Role of the message author.
    pub role: Role,
    /// The contents of the message (can be null for some cases)
    pub content: Option<String>,
    /// The refusal message generated by the model
    pub refusal: Option<String>,
    /// Annotations for the message, when applicable, as when using the web search tool
    pub annotations: Option<Vec<Value>>,
    /// If the audio output modality is requested, this object contains data about the audio response
    pub audio: Option<Value>,
    /// Deprecated and replaced by tool_calls. The name and arguments of a function that should be called
    pub function_call: Option<FunctionCall>,
    /// The tool calls generated by the model, such as function calls
    pub tool_calls: Option<Vec<ToolCall>>,
}
impl ResponseMessage {
/// Convert ResponseMessage to Message for internal processing
/// This is useful for transformations that need to work with the request Message type
pub fn to_message(&self) -> Message {
Message {
role: self.role.clone(),
content: self.content.as_ref()
.map(|s| MessageContent::Text(s.clone()))
.unwrap_or(MessageContent::Text(String::new())),
name: None, // Response messages don't have names in the same way request messages do
tool_calls: self.tool_calls.clone(),
tool_call_id: None, // Response messages don't have tool_call_id
}
}
}
/// Content of a request message.
///
/// In the OpenAI API, this is represented as either:
/// - A string for simple text content
/// - An array of content parts for multimodal content (text + images)
///
/// `#[serde(untagged)]` lets both wire shapes (de)serialize transparently.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain text content.
    Text(String),
    /// Multimodal content parts (text and/or image URLs).
    Parts(Vec<ContentPart>),
}
/// Individual content part within a message (text or image).
///
/// Tagged on the wire by a `"type"` field: "text" or "image_url".
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum ContentPart {
    /// A plain text fragment.
    #[serde(rename = "text")]
    Text { text: String },
    /// An image reference for vision-capable models.
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}
/// Image URL configuration for vision capabilities
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ImageUrl {
    /// URL of the image.
    pub url: String,
    /// Detail-level hint (e.g. "high"); passed through verbatim.
    pub detail: Option<String>,
}
/// A tool call made by the assistant
// NOTE(review): a stray "/// A single message in a chat conversation" doc line
// preceded this item; it belonged to Message and was removed.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ToolCall {
    /// Unique identifier for this call (echoed back via `tool_call_id`).
    pub id: String,
    /// Call kind; "function" in current usage.
    #[serde(rename = "type")]
    pub call_type: String,
    /// The function being invoked and its arguments.
    pub function: FunctionCall,
}
/// Function call within a tool call
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct FunctionCall {
    /// Name of the function to invoke.
    pub name: String,
    /// Arguments as a raw JSON-encoded string (not parsed here).
    pub arguments: String,
}
/// Tool definition for function calling
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Tool {
    /// Tool kind; "function" in current usage.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Schema of the callable function.
    pub function: Function,
}
/// Function definition within a tool
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Function {
    /// Function name the model references in tool calls.
    pub name: String,
    /// Human-readable description shown to the model.
    pub description: Option<String>,
    /// JSON Schema of the accepted arguments, kept as raw `Value`.
    pub parameters: Value,
    /// Whether strict schema adherence is requested.
    pub strict: Option<bool>,
}
/// Tool choice string values.
///
/// Serialized as the lowercase strings "auto", "required", "none".
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoiceType {
    /// Let the model automatically decide whether to call tools
    Auto,
    /// Force the model to call at least one tool
    Required,
    /// Prevent the model from calling any tools
    None,
}
/// Tool choice configuration.
///
/// `#[serde(untagged)]`: deserializes either a bare string
/// ("auto"/"required"/"none") or an object naming a specific function.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged)]
pub enum ToolChoice {
    /// String-based tool choice (auto, required, none)
    Type(ToolChoiceType),
    /// Specific function to call
    Function {
        /// Discriminator field, serialized as "type".
        #[serde(rename = "type")]
        choice_type: String,
        /// The function the model must call.
        function: FunctionChoice,
    },
}
/// Specific function choice
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct FunctionChoice {
    /// Name of the function the model must call.
    pub name: String,
}
/// Static content for prediction/prefill functionality
///
/// Static predicted output content, such as the content of a text file
/// that is being regenerated.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StaticContent {
    /// The type of the predicted content you want to provide.
    /// This type is currently always "content".
    #[serde(rename = "type")]
    pub content_type: String,
    /// The content that should be matched when generating a model response.
    /// If generated tokens would match this content, the entire model response
    /// can be returned much more quickly.
    ///
    /// Can be either:
    /// - A string for simple text content
    /// - An array of content parts for structured content
    pub content: StaticContentType,
}
/// Content type for static/predicted content.
///
/// `#[serde(untagged)]`: a bare string and a parts array both deserialize.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum StaticContentType {
    /// Simple text content - the content used for a Predicted Output.
    /// This is often the text of a file you are regenerating with minor changes.
    Text(String),
    /// An array of content parts with a defined type.
    /// Can contain text inputs and other supported content types.
    Parts(Vec<ContentPart>),
}
/// Chat completions API response (non-streaming).
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatCompletionsResponse {
    /// Unique completion identifier.
    pub id: String,
    /// Object type marker from the provider.
    pub object: String,
    /// Unix timestamp (seconds) of creation.
    pub created: u64,
    /// Model that produced the completion.
    pub model: String,
    /// One entry per requested completion (`n`).
    pub choices: Vec<Choice>,
    /// Token accounting for the request/response.
    pub usage: Usage,
    /// Backend configuration fingerprint, when provided.
    pub system_fingerprint: Option<String>,
}
/// Finish reason for completion.
///
/// Serialized in snake_case, e.g. `tool_calls`, `content_filter`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Natural stop or a stop sequence was hit.
    Stop,
    /// Token limit reached.
    Length,
    /// The model emitted tool calls.
    ToolCalls,
    /// Content was withheld by a filter.
    ContentFilter,
    FunctionCall, // Legacy
}
/// Token usage information
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
pub prompt_tokens_details: Option<PromptTokensDetails>,
pub completion_tokens_details: Option<CompletionTokensDetails>,
}
/// Detailed breakdown of prompt tokens
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PromptTokensDetails {
    /// Prompt tokens served from cache.
    pub cached_tokens: Option<u32>,
    /// Prompt tokens that were audio input.
    pub audio_tokens: Option<u32>,
}
/// Detailed breakdown of completion tokens
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CompletionTokensDetails {
    /// Tokens spent on internal reasoning.
    pub reasoning_tokens: Option<u32>,
    /// Tokens that were audio output.
    pub audio_tokens: Option<u32>,
    /// Predicted-output tokens that were accepted.
    pub accepted_prediction_tokens: Option<u32>,
    /// Predicted-output tokens that were rejected.
    pub rejected_prediction_tokens: Option<u32>,
}
/// A single choice in the response
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Choice {
    /// Position of this choice among the `n` completions.
    pub index: u32,
    /// The generated message.
    pub message: ResponseMessage,
    /// Why generation stopped; absent mid-stream.
    pub finish_reason: Option<FinishReason>,
    /// Log-probability data, kept as raw JSON.
    pub logprobs: Option<Value>,
}
// ============================================================================
// STREAMING API TYPES
// ============================================================================

/// Streaming response chunk from the chat completions API.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatCompletionsStreamResponse {
    /// Completion identifier, repeated on every chunk.
    pub id: String,
    /// Object type marker from the provider.
    pub object: String,
    /// Unix timestamp (seconds) of creation.
    pub created: u64,
    /// Model that produced the completion.
    pub model: String,
    /// Incremental deltas, one per choice.
    pub choices: Vec<StreamChoice>,
    pub usage: Option<Usage>, // Only in final chunk
    /// Backend configuration fingerprint, when provided.
    pub system_fingerprint: Option<String>,
    /// Specifies the processing type used for serving the request
    pub service_tier: Option<String>,
}
/// A choice in a streaming response
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamChoice {
    /// Position of this choice among the `n` completions.
    pub index: u32,
    /// Incremental update to accumulate into the full message.
    pub delta: MessageDelta,
    /// Why generation stopped; only set on the final chunk for this choice.
    pub finish_reason: Option<FinishReason>,
    /// Log-probability data, kept as raw JSON.
    pub logprobs: Option<Value>,
}
/// Message delta for streaming updates.
///
/// Every field is optional: each chunk carries only what changed.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessageDelta {
    /// Role, typically only present on the first chunk.
    pub role: Option<Role>,
    /// Text fragment to append to the accumulated content.
    pub content: Option<String>,
    /// The refusal message generated by the model
    pub refusal: Option<String>,
    /// Deprecated and replaced by tool_calls. The name and arguments of a function that should be called
    pub function_call: Option<FunctionCall>,
    /// Incremental tool-call fragments, addressed by `index`.
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
/// Tool call delta for streaming tool call updates
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ToolCallDelta {
    /// Which tool call this fragment belongs to.
    pub index: u32,
    /// Call id, typically only present on the first fragment.
    pub id: Option<String>,
    /// Call kind, serialized as "type".
    #[serde(rename = "type")]
    pub call_type: Option<String>,
    /// Partial function name/arguments to accumulate.
    pub function: Option<FunctionCallDelta>,
}
/// Function call delta for streaming function call updates
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct FunctionCallDelta {
    /// Function name fragment.
    pub name: Option<String>,
    /// JSON-arguments fragment to append.
    pub arguments: Option<String>,
}
/// Stream options for controlling streaming behavior
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamOptions {
    /// When true, the final chunk includes a `usage` object.
    pub include_usage: Option<bool>,
}
// Round-trip (deserialize -> serialize) tests for the chat completions types.
// Floating-point fields are compared with a tolerance because f32 -> JSON -> f64
// conversion is not exact.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Minimal request: only `model` and `messages` must round-trip unchanged.
    #[test]
    fn test_required_fields() {
        // Create a JSON object with only required fields
        let original_json = json!({
            "model": "gpt-4",
            "messages": [
                {
                    "content": "Hello, world!",
                    "role": "user"
                }
            ]
        });
        // Deserialize JSON into ChatCompletionsRequest
        let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
        // Validate required fields are properly set
        assert_eq!(deserialized_request.model, "gpt-4");
        assert_eq!(deserialized_request.messages.len(), 1);
        let message = &deserialized_request.messages[0];
        assert_eq!(message.role, Role::User);
        if let MessageContent::Text(content) = &message.content {
            assert_eq!(content, "Hello, world!");
        } else {
            panic!("Expected text content");
        }
        // Serialize the ChatCompletionsRequest back to JSON
        let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
        assert_eq!(original_json, serialized_json);
    }

    // Optional fields round-trip when present and stay absent (None) otherwise,
    // thanks to #[skip_serializing_none].
    #[test]
    fn test_optional_fields_serialization() {
        // Create a JSON object with optional fields set
        let original_json = json!({
            "model": "gpt-4",
            "messages": [
                {
                    "content": "Test message",
                    "role": "user",
                    "name": "test_user"
                }
            ],
            "temperature": 0.7,
            "max_tokens": 150,
            "stream": true,
            "stream_options": {
                "include_usage": true
            },
            "metadata": {
                "user_id": "123"
            }
        });
        // Deserialize JSON into ChatCompletionsRequest
        let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
        // Validate required fields
        assert_eq!(deserialized_request.model, "gpt-4");
        assert_eq!(deserialized_request.messages.len(), 1);
        let message = &deserialized_request.messages[0];
        assert_eq!(message.role, Role::User);
        if let MessageContent::Text(content) = &message.content {
            assert_eq!(content, "Test message");
        } else {
            panic!("Expected text content");
        }
        assert_eq!(message.name, Some("test_user".to_string()));
        // Validate optional fields are properly set
        assert!((deserialized_request.temperature.unwrap() - 0.7).abs() < 1e-6);
        assert_eq!(deserialized_request.max_tokens, Some(150));
        assert_eq!(deserialized_request.stream, Some(true));
        assert!(deserialized_request.stream_options.is_some());
        assert!(deserialized_request.metadata.is_some());
        // Validate fields not in JSON are None
        assert!(deserialized_request.top_p.is_none());
        assert!(deserialized_request.frequency_penalty.is_none());
        assert!(deserialized_request.presence_penalty.is_none());
        assert!(deserialized_request.stop.is_none());
        assert!(deserialized_request.tools.is_none());
        // Serialize back to JSON and compare (handle floating point precision)
        let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
        // Compare all fields except temperature which needs floating point comparison
        assert_eq!(serialized_json["model"], original_json["model"]);
        assert_eq!(serialized_json["messages"], original_json["messages"]);
        assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
        assert_eq!(serialized_json["stream"], original_json["stream"]);
        assert_eq!(serialized_json["stream_options"], original_json["stream_options"]);
        assert_eq!(serialized_json["metadata"], original_json["metadata"]);
        // Handle temperature with floating point tolerance
        let original_temp = original_json["temperature"].as_f64().unwrap();
        let serialized_temp = serialized_json["temperature"].as_f64().unwrap();
        assert!((original_temp - serialized_temp).abs() < 1e-6);
    }

    // Exercises the untagged/tagged enums together: multimodal content parts,
    // assistant tool calls, a tool response, a tool schema, and prediction.
    #[test]
    fn test_nested_types_serialization() {
        // Create a comprehensive JSON object with nested types - a ChatCompletionsRequest with complex message content and tools
        let original_json = json!({
            "model": "gpt-4-vision-preview",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "What can you see in this image and what's the weather like in the location shown?"
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "https://example.com/cityscape.jpg",
                                "detail": "high"
                            }
                        }
                    ]
                },
                {
                    "role": "assistant",
                    "content": "I can see a beautiful cityscape. Let me check the weather for you.",
                    "tool_calls": [
                        {
                            "id": "call_weather123",
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": "{\"location\": \"New York, NY\"}"
                            }
                        }
                    ]
                },
                {
                    "role": "tool",
                    "content": "Current weather in New York: 72°F, sunny",
                    "tool_call_id": "call_weather123"
                }
            ],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get current weather information for a location",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "The city and state, e.g. San Francisco, CA"
                                }
                            },
                            "required": ["location"]
                        },
                        "strict": true
                    }
                }
            ],
            "tool_choice": "auto",
            "temperature": 0.7,
            "max_tokens": 1000,
            "prediction": {
                "type": "content",
                "content": "Based on the image analysis and weather data, I can provide you with comprehensive information."
            }
        });
        // Deserialize JSON into ChatCompletionsRequest
        let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
        // Validate top-level fields
        assert_eq!(deserialized_request.model, "gpt-4-vision-preview");
        assert_eq!(deserialized_request.messages.len(), 3);
        assert!((deserialized_request.temperature.unwrap() - 0.7).abs() < 1e-6);
        assert_eq!(deserialized_request.max_tokens, Some(1000));
        // Validate first message (user with multimodal content)
        let user_message = &deserialized_request.messages[0];
        assert_eq!(user_message.role, Role::User);
        if let MessageContent::Parts(ref content_parts) = user_message.content {
            assert_eq!(content_parts.len(), 2);
            // Validate text content part
            if let ContentPart::Text { text } = &content_parts[0] {
                assert_eq!(text, "What can you see in this image and what's the weather like in the location shown?");
            } else {
                panic!("Expected text content part");
            }
            // Validate image URL content part
            if let ContentPart::ImageUrl { ref image_url } = content_parts[1] {
                assert_eq!(image_url.url, "https://example.com/cityscape.jpg");
                assert_eq!(image_url.detail, Some("high".to_string()));
            } else {
                panic!("Expected image URL content part");
            }
        } else {
            panic!("Expected multimodal content parts for user message");
        }
        // Validate second message (assistant with tool calls)
        let assistant_message = &deserialized_request.messages[1];
        assert_eq!(assistant_message.role, Role::Assistant);
        if let MessageContent::Text(text) = &assistant_message.content {
            assert_eq!(text, "I can see a beautiful cityscape. Let me check the weather for you.");
        } else {
            panic!("Expected text content for assistant message");
        }
        // Validate tool calls in assistant message
        assert!(assistant_message.tool_calls.is_some());
        let tool_calls = assistant_message.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        let tool_call = &tool_calls[0];
        assert_eq!(tool_call.id, "call_weather123");
        assert_eq!(tool_call.call_type, "function");
        assert_eq!(tool_call.function.name, "get_weather");
        assert_eq!(tool_call.function.arguments, "{\"location\": \"New York, NY\"}");
        // Validate third message (tool response)
        let tool_message = &deserialized_request.messages[2];
        assert_eq!(tool_message.role, Role::Tool);
        if let MessageContent::Text(text) = &tool_message.content {
            assert_eq!(text, "Current weather in New York: 72°F, sunny");
        } else {
            panic!("Expected text content for tool message");
        }
        assert_eq!(tool_message.tool_call_id, Some("call_weather123".to_string()));
        // Validate tools array
        assert!(deserialized_request.tools.is_some());
        let tools = deserialized_request.tools.as_ref().unwrap();
        assert_eq!(tools.len(), 1);
        let tool = &tools[0];
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "get_weather");
        assert_eq!(tool.function.description, Some("Get current weather information for a location".to_string()));
        assert_eq!(tool.function.strict, Some(true));
        // Validate tool parameters schema
        let parameters = &tool.function.parameters;
        assert_eq!(parameters["type"], "object");
        assert!(parameters["properties"]["location"].is_object());
        assert_eq!(parameters["required"], json!(["location"]));
        // Validate tool choice
        if let Some(ToolChoice::Type(choice)) = &deserialized_request.tool_choice {
            assert_eq!(choice, &ToolChoiceType::Auto);
        } else {
            panic!("Expected auto tool choice");
        }
        // Validate prediction
        assert!(deserialized_request.prediction.is_some());
        let prediction = deserialized_request.prediction.as_ref().unwrap();
        assert_eq!(prediction.content_type, "content");
        if let StaticContentType::Text(text) = &prediction.content {
            assert_eq!(text, "Based on the image analysis and weather data, I can provide you with comprehensive information.");
        } else {
            panic!("Expected text prediction content");
        }
        // Serialize back to JSON and compare (handle floating point precision)
        let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
        // Compare all fields except floating point ones
        assert_eq!(serialized_json["model"], original_json["model"]);
        assert_eq!(serialized_json["messages"], original_json["messages"]);
        assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
        assert_eq!(serialized_json["tools"], original_json["tools"]);
        assert_eq!(serialized_json["tool_choice"], original_json["tool_choice"]);
        assert_eq!(serialized_json["prediction"], original_json["prediction"]);
        // Handle floating point field with tolerance
        let original_temp = original_json["temperature"].as_f64().unwrap();
        let serialized_temp = serialized_json["temperature"].as_f64().unwrap();
        assert!((original_temp - serialized_temp).abs() < 1e-6);
    }

    // The ApiDefinition trait methods agree with the declared endpoint table.
    #[test]
    fn test_api_provider_trait() {
        // Test the ApiDefinition trait implementation
        let api = OpenAIApi::ChatCompletions;
        // Test trait methods
        assert_eq!(api.endpoint(), "/v1/chat/completions");
        assert!(api.supports_streaming());
        assert!(api.supports_tools());
        assert!(api.supports_vision());
        // Test from_endpoint
        let found_api = OpenAIApi::from_endpoint("/v1/chat/completions");
        assert_eq!(found_api, Some(OpenAIApi::ChatCompletions));
        let not_found = OpenAIApi::from_endpoint("/v1/unknown");
        assert_eq!(not_found, None);
        // Test all_variants
        let all_variants = OpenAIApi::all_variants();
        assert_eq!(all_variants.len(), 1);
        assert_eq!(all_variants[0], OpenAIApi::ChatCompletions);
    }

    // Role-dependent optional fields (name, tool_calls, tool_call_id) appear
    // only where expected, and ResponseMessage converts cleanly to Message.
    #[test]
    fn test_role_specific_behavior() {
        // Test 1: User message - basic content, no tool-related fields
        let user_json = json!({
            "content": "Hello!",
            "role": "user",
            "name": "user123"
        });
        let deserialized_user: Message = serde_json::from_value(user_json.clone()).unwrap();
        assert_eq!(deserialized_user.role, Role::User);
        if let MessageContent::Text(content) = &deserialized_user.content {
            assert_eq!(content, "Hello!");
        } else {
            panic!("Expected text content");
        }
        assert_eq!(deserialized_user.name, Some("user123".to_string()));
        assert!(deserialized_user.tool_calls.is_none());
        assert!(deserialized_user.tool_call_id.is_none());
        let serialized_user_json = serde_json::to_value(&deserialized_user).unwrap();
        assert_eq!(user_json, serialized_user_json);
        // Test 2: Assistant message with tool calls
        let assistant_json = json!({
            "content": "I'll help with that.",
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_456",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": r#"{"location":"SF"}"#
                    }
                }
            ]
        });
        let deserialized_assistant: Message = serde_json::from_value(assistant_json.clone()).unwrap();
        assert_eq!(deserialized_assistant.role, Role::Assistant);
        if let MessageContent::Text(content) = &deserialized_assistant.content {
            assert_eq!(content, "I'll help with that.");
        } else {
            panic!("Expected text content");
        }
        assert!(deserialized_assistant.tool_calls.is_some());
        assert!(deserialized_assistant.tool_call_id.is_none());
        assert!(deserialized_assistant.name.is_none());
        let tool_calls = deserialized_assistant.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_456");
        assert_eq!(tool_calls[0].function.name, "get_weather");
        let serialized_assistant_json = serde_json::to_value(&deserialized_assistant).unwrap();
        assert_eq!(assistant_json, serialized_assistant_json);
        // Test 3: Tool message responding to a call
        let tool_json = json!({
            "content": "Weather is sunny",
            "role": "tool",
            "tool_call_id": "call_456"
        });
        let deserialized_tool: Message = serde_json::from_value(tool_json.clone()).unwrap();
        assert_eq!(deserialized_tool.role, Role::Tool);
        if let MessageContent::Text(content) = &deserialized_tool.content {
            assert_eq!(content, "Weather is sunny");
        } else {
            panic!("Expected text content");
        }
        assert_eq!(deserialized_tool.tool_call_id, Some("call_456".to_string()));
        assert!(deserialized_tool.tool_calls.is_none());
        assert!(deserialized_tool.name.is_none());
        let serialized_tool_json = serde_json::to_value(&deserialized_tool).unwrap();
        assert_eq!(tool_json, serialized_tool_json);
        // Test 4: ResponseMessage vs Message differences
        let response_json = json!({
            "role": "assistant",
            "content": "Response content",
            "annotations": [
                {"type": "citation"}
            ]
        });
        let deserialized_response: ResponseMessage = serde_json::from_value(response_json.clone()).unwrap();
        assert_eq!(deserialized_response.role, Role::Assistant);
        assert_eq!(deserialized_response.content, Some("Response content".to_string()));
        assert!(deserialized_response.annotations.is_some());
        assert!(deserialized_response.refusal.is_none());
        assert!(deserialized_response.function_call.is_none());
        assert!(deserialized_response.tool_calls.is_none());
        let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap();
        assert_eq!(response_json, serialized_response_json);
        // Test conversion from ResponseMessage to Message
        let converted = deserialized_response.to_message();
        assert_eq!(converted.role, Role::Assistant);
        if let MessageContent::Text(text) = converted.content {
            assert_eq!(text, "Response content");
        } else {
            panic!("Expected text content");
        }
        assert!(converted.name.is_none());
        assert!(converted.tool_call_id.is_none());
    }

    // The untagged ToolChoice enum accepts exactly the three known strings
    // and rejects anything else.
    #[test]
    fn test_tool_choice_type_serialization() {
        // Test that the enum serializes to the correct string values
        let auto_choice = ToolChoice::Type(ToolChoiceType::Auto);
        let required_choice = ToolChoice::Type(ToolChoiceType::Required);
        let none_choice = ToolChoice::Type(ToolChoiceType::None);
        let auto_json = serde_json::to_value(&auto_choice).unwrap();
        let required_json = serde_json::to_value(&required_choice).unwrap();
        let none_json = serde_json::to_value(&none_choice).unwrap();
        assert_eq!(auto_json, "auto");
        assert_eq!(required_json, "required");
        assert_eq!(none_json, "none");
        // Test deserialization from string values
        let auto_deserialized: ToolChoice = serde_json::from_value(json!("auto")).unwrap();
        let required_deserialized: ToolChoice = serde_json::from_value(json!("required")).unwrap();
        let none_deserialized: ToolChoice = serde_json::from_value(json!("none")).unwrap();
        assert_eq!(auto_deserialized, ToolChoice::Type(ToolChoiceType::Auto));
        assert_eq!(required_deserialized, ToolChoice::Type(ToolChoiceType::Required));
        assert_eq!(none_deserialized, ToolChoice::Type(ToolChoiceType::None));
        // Test that invalid string values fail deserialization (type safety!)
        let invalid_result: Result<ToolChoice, _> = serde_json::from_value(json!("invalid"));
        assert!(invalid_result.is_err());
    }
}

View file

@ -0,0 +1,130 @@
//! Supported endpoint registry for LLM APIs
//!
//! This module provides a simple registry to check which API endpoint paths
//! we support across different providers.
//!
//! # Examples
//!
//! ```rust
//! use hermesllm::clients::endpoints::{is_supported_endpoint, supported_endpoints};
//!
//! // Check if we support an endpoint
//! assert!(is_supported_endpoint("/v1/chat/completions"));
//! assert!(is_supported_endpoint("/v1/messages"));
//! assert!(!is_supported_endpoint("/v1/unknown"));
//!
//! // Get all supported endpoints
//! let endpoints = supported_endpoints();
//! assert_eq!(endpoints.len(), 2);
//! assert!(endpoints.contains(&"/v1/chat/completions"));
//! assert!(endpoints.contains(&"/v1/messages"));
//! ```
use crate::apis::{AnthropicApi, OpenAIApi, ApiDefinition};
/// Check if the given endpoint path is supported
///
/// Returns `true` when either the OpenAI or the Anthropic registry
/// recognizes `endpoint`.
pub fn is_supported_endpoint(endpoint: &str) -> bool {
    OpenAIApi::from_endpoint(endpoint).is_some()
        || AnthropicApi::from_endpoint(endpoint).is_some()
}
/// Get all supported endpoint paths
///
/// The list is derived from every `ApiDefinition` variant of each provider,
/// so it stays in sync automatically as new APIs are added.
pub fn supported_endpoints() -> Vec<&'static str> {
    OpenAIApi::all_variants()
        .into_iter()
        .map(|api| api.endpoint())
        .chain(
            AnthropicApi::all_variants()
                .into_iter()
                .map(|api| api.endpoint()),
        )
        .collect()
}
/// Identify which provider supports a given endpoint
///
/// Returns `"openai"` or `"anthropic"`, or `None` when no provider claims
/// the path. OpenAI is consulted first, matching `is_supported_endpoint`.
pub fn identify_provider(endpoint: &str) -> Option<&'static str> {
    OpenAIApi::from_endpoint(endpoint)
        .map(|_| "openai")
        .or_else(|| AnthropicApi::from_endpoint(endpoint).map(|_| "anthropic"))
}
// Registry behavior tests: membership, enumeration, and provider lookup
// must all agree with the per-provider ApiDefinition impls.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_supported_endpoint() {
        // OpenAI endpoints
        assert!(is_supported_endpoint("/v1/chat/completions"));
        // Anthropic endpoints
        assert!(is_supported_endpoint("/v1/messages"));
        // Unsupported endpoints
        assert!(!is_supported_endpoint("/v1/unknown"));
        assert!(!is_supported_endpoint("/v2/chat"));
        assert!(!is_supported_endpoint(""));
    }

    #[test]
    fn test_supported_endpoints() {
        let endpoints = supported_endpoints();
        // One endpoint per provider at present.
        assert_eq!(endpoints.len(), 2);
        assert!(endpoints.contains(&"/v1/chat/completions"));
        assert!(endpoints.contains(&"/v1/messages"));
    }

    #[test]
    fn test_identify_provider() {
        assert_eq!(identify_provider("/v1/chat/completions"), Some("openai"));
        assert_eq!(identify_provider("/v1/messages"), Some("anthropic"));
        assert_eq!(identify_provider("/v1/unknown"), None);
    }

    // The registry must be purely derived from the ApiDefinition variants,
    // with no hand-maintained duplicates that could drift.
    #[test]
    fn test_endpoints_generated_from_api_definitions() {
        let endpoints = supported_endpoints();
        // Verify that we get endpoints from all API variants
        let openai_endpoints: Vec<_> = OpenAIApi::all_variants()
            .iter()
            .map(|api| api.endpoint())
            .collect();
        let anthropic_endpoints: Vec<_> = AnthropicApi::all_variants()
            .iter()
            .map(|api| api.endpoint())
            .collect();
        // All OpenAI endpoints should be in the result
        for endpoint in openai_endpoints {
            assert!(endpoints.contains(&endpoint), "Missing OpenAI endpoint: {}", endpoint);
        }
        // All Anthropic endpoints should be in the result
        for endpoint in anthropic_endpoints {
            assert!(endpoints.contains(&endpoint), "Missing Anthropic endpoint: {}", endpoint);
        }
        // Total should match
        assert_eq!(endpoints.len(), OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len());
    }
}

View file

@ -0,0 +1,33 @@
//! Helper functions and utilities for API transformations
//! Contains error types and shared utilities
use thiserror::Error;
// ============================================================================
// ERROR TYPES
// ============================================================================

/// Errors that can occur while transforming requests/responses between
/// provider API formats.
#[derive(Error, Debug)]
pub enum TransformError {
    /// Wraps any serde_json failure during (de)serialization.
    #[error("JSON serialization error: {0}")]
    JsonError(#[from] serde_json::Error),
    /// A content type this library does not know how to translate.
    #[error("Unsupported content type: {0}")]
    UnsupportedContent(String),
    /// Tool input did not match the expected structure.
    #[error("Invalid tool input format")]
    InvalidToolInput,
    /// A required field was absent from the source payload.
    #[error("Missing required field: {0}")]
    MissingField(String),
    /// The requested conversion between formats is not implemented.
    #[error("Unsupported conversion: {0}")]
    UnsupportedConversion(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: variants construct and pattern-match as expected.
    #[test]
    fn test_error_types() {
        let error = TransformError::MissingField("test".to_string());
        assert!(matches!(error, TransformError::MissingField(_)));
    }
}

View file

@ -0,0 +1,9 @@
pub mod lib;
pub mod transformer;
pub mod endpoints;
// Re-export the main items for easier access
pub use lib::*;
pub use endpoints::{is_supported_endpoint, supported_endpoints, identify_provider};
// Note: transformer module contains TryFrom trait implementations that are automatically available

File diff suppressed because it is too large Load diff

View file

@ -1,10 +1,12 @@
//! hermesllm: A library for translating LLM API requests and responses //! hermesllm: A library for translating LLM API requests and responses
//! between Mistral, Grok, Gemini, and OpenAI-compliant formats. //! between Mistral, Grok, Gemini, and OpenAI-compliant formats.
use std::fmt::Display;
pub mod providers; pub mod providers;
pub mod apis;
pub mod clients;
use std::fmt::Display;
pub enum Provider { pub enum Provider {
Arch, Arch,
Mistral, Mistral,

View file

@ -0,0 +1,2 @@
pub mod providers;
pub mod clients;