diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs new file mode 100644 index 00000000..7ae085d5 --- /dev/null +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -0,0 +1,568 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use serde_with::skip_serializing_none; +use std::collections::HashMap; + +use super::ApiDefinition; + +// Enum for all supported Anthropic APIs +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum AnthropicApi { + Messages, + // Future APIs can be added here: + // Embeddings, + // etc. +} + +impl ApiDefinition for AnthropicApi { + fn endpoint(&self) -> &'static str { + match self { + AnthropicApi::Messages => "/v1/messages", + } + } + + fn from_endpoint(endpoint: &str) -> Option { + match endpoint { + "/v1/messages" => Some(AnthropicApi::Messages), + _ => None, + } + } + + fn supports_streaming(&self) -> bool { + match self { + AnthropicApi::Messages => true, + } + } + + fn supports_tools(&self) -> bool { + match self { + AnthropicApi::Messages => true, + } + } + + fn supports_vision(&self) -> bool { + match self { + AnthropicApi::Messages => true, + } + } + + fn all_variants() -> Vec { + vec![ + AnthropicApi::Messages, + ] + } +} + +// Service tier enum for request priority +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ServiceTier { + Auto, + StandardOnly, +} + +// Thinking configuration +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ThinkingConfig { + pub enabled: bool, +} + +// MCP Server types +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum McpServerType { + Url, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct McpToolConfiguration { + pub allowed_tools: Option>, + pub enabled: Option, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct McpServer { + pub name: String, + #[serde(rename = "type")] + pub server_type: McpServerType, + pub url: String, + pub authorization_token: Option, + pub tool_configuration: Option, +} + + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesRequest { + pub model: String, + pub messages: Vec, + pub max_tokens: u32, + pub container: Option, + pub mcp_servers: Option>, + pub system: Option, + pub metadata: Option>, + pub service_tier: Option, + pub thinking: Option, + + pub temperature: Option, + pub top_p: Option, + pub top_k: Option, + pub stream: Option, + pub stop_sequences: Option>, + pub tools: Option>, + pub tool_choice: Option, + +} + + +// Messages API specific types +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum MessagesRole { + User, + Assistant, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +#[serde(tag = "type")] +pub enum MessagesContentBlock { + Text { + text: String, + }, + Thinking { + text: String, + }, + Image { + source: MessagesImageSource, + }, + Document { + source: MessagesDocumentSource, + }, + ToolUse { + id: String, + name: String, + input: Value, + }, + ToolResult { + tool_use_id: String, + is_error: Option, + content: Vec, + }, + ServerToolUse { + id: String, + name: String, + input: Value, + }, + WebSearchToolResult { + tool_use_id: String, + is_error: Option, + content: Vec, + }, + CodeExecutionToolResult { + tool_use_id: String, + is_error: Option, + content: Vec, + }, + McpToolUse { + id: String, + name: String, + input: Value, + }, + McpToolResult { + tool_use_id: String, + is_error: Option, + content: Vec, + }, + ContainerUpload { + id: String, + name: String, + media_type: String, + data: String, + }, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +pub enum MessagesImageSource { + Base64 { + media_type: String, + data: String, + }, + Url { + url: String, + }, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +pub enum MessagesDocumentSource { + Base64 { + media_type: String, + data: String, + }, + Url { + url: String, + }, + File { + file_id: String, + }, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum MessagesMessageContent { + Single(String), + Blocks(Vec), +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum MessagesSystemPrompt { + Single(String), + Blocks(Vec), +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesMessage { + pub role: MessagesRole, + pub content: MessagesMessageContent, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesTool { + pub name: String, + pub description: Option, + pub input_schema: Value, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum MessagesToolChoiceType { + Auto, + Any, + Tool, + None, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesToolChoice { + #[serde(rename = "type")] + pub kind: MessagesToolChoiceType, + pub name: Option, + pub disable_parallel_tool_use: Option, +} + + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum MessagesStopReason { + EndTurn, + MaxTokens, + StopSequence, + ToolUse, + PauseTurn, + Refusal, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesUsage { + pub input_tokens: u32, + pub output_tokens: u32, + pub cache_creation_input_tokens: Option, + pub cache_read_input_tokens: Option, +} + +// Container response object +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesContainer { + pub id: String, + #[serde(rename = "type")] + pub container_type: String, + pub name: String, + pub status: String, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesResponse { + pub id: String, + #[serde(rename = "type")] + pub obj_type: String, + pub role: MessagesRole, + pub content: Vec, + pub model: String, + pub stop_reason: MessagesStopReason, + pub stop_sequence: Option, + pub usage: MessagesUsage, + pub container: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +#[serde(tag = "type")] +pub enum MessagesStreamEvent { + MessageStart { + message: MessagesStreamMessage, + }, + ContentBlockStart { + index: u32, + content_block: MessagesContentBlock, + }, + ContentBlockDelta { + index: u32, + delta: MessagesContentDelta, + }, + ContentBlockStop { + index: u32, + }, + MessageDelta { + delta: MessagesMessageDelta, + usage: MessagesUsage, + }, + MessageStop, + Ping, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesStreamMessage { + pub id: String, + #[serde(rename = "type")] + pub obj_type: String, + pub role: MessagesRole, + pub content: Vec, // Initially empty + pub model: String, + pub stop_reason: Option, + pub stop_sequence: Option, + pub usage: MessagesUsage, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(tag = "type")] +pub enum MessagesContentDelta { + #[serde(rename = "text_delta")] + TextDelta { text: String }, + #[serde(rename = "input_json_delta")] + InputJsonDelta { partial_json: String }, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesMessageDelta { + pub stop_reason: MessagesStopReason, + pub stop_sequence: Option, +} + +// Helper functions for API detection and conversion +impl MessagesRequest { + pub fn api_type() -> AnthropicApi { + AnthropicApi::Messages + } +} + +impl MessagesResponse { + pub fn api_type() -> AnthropicApi { + AnthropicApi::Messages + } +} + +impl MessagesStreamEvent { + pub fn api_type() -> AnthropicApi { + AnthropicApi::Messages + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json; + + #[test] + fn test_anthropic_skip_serializing_none_annotations() { + // Test that skip_serializing_none works correctly for MessagesRequest + let request = MessagesRequest { + model: "claude-3-sonnet-20240229".to_string(), + system: None, // Should be skipped + messages: vec![MessagesMessage { + role: MessagesRole::User, + content: MessagesMessageContent::Single("Hello".to_string()), + }], + max_tokens: 100, + container: None, // Should be skipped + mcp_servers: None, // Should be skipped + service_tier: None, // Should be skipped + thinking: None, // Should be skipped + temperature: None, // Should be skipped + top_p: Some(0.9), // Should be included + top_k: None, // Should be skipped + stream: None, // Should be skipped + stop_sequences: None, // Should be skipped + tools: None, // Should be skipped + tool_choice: None, // Should be skipped + metadata: None, // Should be skipped + }; + + let json = serde_json::to_value(&request).unwrap(); + let obj = json.as_object().unwrap(); + + // Verify that None fields are not present in the JSON + assert!(!obj.contains_key("system")); + assert!(!obj.contains_key("container")); + assert!(!obj.contains_key("mcp_servers")); + assert!(!obj.contains_key("service_tier")); + assert!(!obj.contains_key("thinking")); + assert!(!obj.contains_key("temperature")); + assert!(!obj.contains_key("top_k")); + assert!(!obj.contains_key("stream")); + assert!(!obj.contains_key("stop_sequences")); + assert!(!obj.contains_key("tools")); + assert!(!obj.contains_key("tool_choice")); + assert!(!obj.contains_key("metadata")); + + // Verify that required fields and Some fields are present + assert!(obj.contains_key("model")); + assert!(obj.contains_key("messages")); + assert!(obj.contains_key("max_tokens")); + assert!(obj.contains_key("top_p")); // This was Some(0.9) + } + + #[test] + fn test_anthropic_tool_serialization() { + // Test MessagesTool with skip_serializing_none + let tool = MessagesTool { + name: "get_weather".to_string(), + description: None, // Should be skipped + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "location": {"type": "string"} + } + }), + }; + + let json = serde_json::to_value(&tool).unwrap(); + let obj = json.as_object().unwrap(); + + assert!(obj.contains_key("name")); + assert!(obj.contains_key("input_schema")); + assert!(!obj.contains_key("description")); // Should be skipped + + // Test with description present + let tool_with_desc = MessagesTool { + name: "get_weather".to_string(), + description: Some("Get weather information".to_string()), + input_schema: serde_json::json!({"type": "object"}), + }; + + let json_with_desc = serde_json::to_value(&tool_with_desc).unwrap(); + let obj_with_desc = json_with_desc.as_object().unwrap(); + + assert!(obj_with_desc.contains_key("description")); // Should be included + } + + #[test] + fn test_mcp_server_serialization() { + // Test McpServer with skip_serializing_none + let mcp_server = McpServer { + name: "test-server".to_string(), + server_type: McpServerType::Url, + url: "https://example.com/mcp".to_string(), + authorization_token: None, // Should be skipped + tool_configuration: Some(McpToolConfiguration { + allowed_tools: Some(vec!["tool1".to_string(), "tool2".to_string()]), + enabled: None, // Should be skipped + }), + }; + + let json = serde_json::to_value(&mcp_server).unwrap(); + let obj = json.as_object().unwrap(); + + // Verify required fields are present + assert!(obj.contains_key("name")); + assert!(obj.contains_key("type")); + assert!(obj.contains_key("url")); + assert!(obj.contains_key("tool_configuration")); + + // Verify None fields are not present + assert!(!obj.contains_key("authorization_token")); + + // Check tool_configuration + let tool_config = obj.get("tool_configuration").unwrap().as_object().unwrap(); + assert!(tool_config.contains_key("allowed_tools")); + assert!(!tool_config.contains_key("enabled")); // Should be skipped + + // Verify type serialization + assert_eq!(obj.get("type").unwrap().as_str().unwrap(), "url"); + } + + #[test] + fn test_service_tier_and_thinking_serialization() { + // Test with service_tier and thinking enabled + let request_with_fields = MessagesRequest { + model: "claude-3-sonnet".to_string(), + system: None, + messages: vec![MessagesMessage { + role: MessagesRole::User, + content: MessagesMessageContent::Single("Hello".to_string()), + }], + max_tokens: 100, + container: None, + mcp_servers: None, + service_tier: Some(ServiceTier::Auto), + thinking: Some(ThinkingConfig { enabled: true }), + temperature: None, + top_p: None, + top_k: None, + stream: None, + stop_sequences: None, + tools: None, + tool_choice: None, + metadata: None, + }; + + let json = serde_json::to_value(&request_with_fields).unwrap(); + let obj = json.as_object().unwrap(); + + // Verify that Some fields are present + assert!(obj.contains_key("service_tier")); + assert!(obj.contains_key("thinking")); + + // Verify service_tier serialization + assert_eq!(obj.get("service_tier").unwrap().as_str().unwrap(), "auto"); + + // Verify thinking serialization + let thinking = obj.get("thinking").unwrap().as_object().unwrap(); + assert!(thinking.contains_key("enabled")); + assert_eq!(thinking.get("enabled").unwrap().as_bool().unwrap(), true); + } + + #[test] + fn test_anthropic_api_provider_trait_implementation() { + use super::ApiDefinition; + + // Test that AnthropicApi implements ApiDefinition trait correctly + let api = AnthropicApi::Messages; + + // Test trait methods + assert_eq!(ApiDefinition::endpoint(&api), "/v1/messages"); + assert!(ApiDefinition::supports_streaming(&api)); + assert!(ApiDefinition::supports_tools(&api)); + assert!(ApiDefinition::supports_vision(&api)); + + // Test from_endpoint trait method + let found_api = AnthropicApi::from_endpoint("/v1/messages"); + assert_eq!(found_api, Some(AnthropicApi::Messages)); + + let not_found = AnthropicApi::from_endpoint("/v1/unknown"); + assert_eq!(not_found, None); + } +} diff --git a/crates/hermesllm/src/apis/mod.rs b/crates/hermesllm/src/apis/mod.rs new file mode 100644 index 00000000..78b634d5 --- /dev/null +++ b/crates/hermesllm/src/apis/mod.rs @@ -0,0 +1,197 @@ +pub mod anthropic; +pub mod openai; + +// Re-export all types for convenience +pub use anthropic::*; +pub use openai::*; + +/// Common trait that all API definitions must implement +/// +/// This trait ensures consistency across different AI provider API definitions +/// and makes it easy to add new providers like Gemini, Claude, etc. +/// +/// Note: This is different from the `ApiProvider` enum in `clients::endpoints` +/// which represents provider identification, while this trait defines API capabilities. +/// +/// # Benefits +/// +/// - **Consistency**: All API providers implement the same interface +/// - **Extensibility**: Easy to add new providers without breaking existing code +/// - **Type Safety**: Compile-time guarantees that all providers implement required methods +/// - **Discoverability**: Clear documentation of what capabilities each API supports +/// +/// # Example implementation for a new provider: +/// +/// ```rust,ignore +/// use serde::{Deserialize, Serialize}; +/// use super::ApiDefinition; +/// +/// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +/// pub enum GeminiApi { +/// GenerateContent, +/// ChatCompletions, +/// } +/// +/// impl GeminiApi { +/// pub fn endpoint(&self) -> &'static str { +/// match self { +/// GeminiApi::GenerateContent => "/v1/models/gemini-pro:generateContent", +/// GeminiApi::ChatCompletions => "/v1/models/gemini-pro:chat", +/// } +/// } +/// +/// pub fn from_endpoint(endpoint: &str) -> Option { +/// match endpoint { +/// "/v1/models/gemini-pro:generateContent" => Some(GeminiApi::GenerateContent), +/// "/v1/models/gemini-pro:chat" => Some(GeminiApi::ChatCompletions), +/// _ => None, +/// } +/// } +/// +/// pub fn supports_streaming(&self) -> bool { +/// match self { +/// GeminiApi::GenerateContent => true, +/// GeminiApi::ChatCompletions => true, +/// } +/// } +/// +/// pub fn supports_tools(&self) -> bool { +/// match self { +/// GeminiApi::GenerateContent => true, +/// GeminiApi::ChatCompletions => false, +/// } +/// } +/// +/// pub fn supports_vision(&self) -> bool { +/// match self { +/// GeminiApi::GenerateContent => true, +/// GeminiApi::ChatCompletions => false, +/// } +/// } +/// } +/// +/// impl ApiDefinition for GeminiApi { +/// fn endpoint(&self) -> &'static str { +/// self.endpoint() +/// } +/// +/// fn from_endpoint(endpoint: &str) -> Option { +/// Self::from_endpoint(endpoint) +/// } +/// +/// fn supports_streaming(&self) -> bool { +/// self.supports_streaming() +/// } +/// +/// fn supports_tools(&self) -> bool { +/// self.supports_tools() +/// } +/// +/// fn supports_vision(&self) -> bool { +/// self.supports_vision() +/// } +/// } +/// +/// // Now you can use generic code that works with any API: +/// fn print_api_info(api: &T) { +/// println!("Endpoint: {}", api.endpoint()); +/// println!("Supports streaming: {}", api.supports_streaming()); +/// println!("Supports tools: {}", api.supports_tools()); +/// println!("Supports vision: {}", api.supports_vision()); +/// } +/// +/// // Works with both OpenAI and Anthropic (and future Gemini) +/// print_api_info(&OpenAIApi::ChatCompletions); +/// print_api_info(&AnthropicApi::Messages); +/// print_api_info(&GeminiApi::GenerateContent); +/// ``` +pub trait ApiDefinition { + /// Returns the endpoint path for this API + fn endpoint(&self) -> &'static str; + + /// Creates an API instance from an endpoint path + fn from_endpoint(endpoint: &str) -> Option + where + Self: Sized; + + /// Returns whether this API supports streaming responses + fn supports_streaming(&self) -> bool; + + /// Returns whether this API supports tool/function calling + fn supports_tools(&self) -> bool; + + /// Returns whether this API supports vision/image processing + fn supports_vision(&self) -> bool; + + /// Returns all variants of this API enum + fn all_variants() -> Vec + where + Self: Sized; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generic_api_functionality() { + // Test that our generic API functionality works with both providers + fn test_api(api: &T) { + let endpoint = api.endpoint(); + assert!(!endpoint.is_empty()); + assert!(endpoint.starts_with('/')); + } + + test_api(&OpenAIApi::ChatCompletions); + test_api(&AnthropicApi::Messages); + } + + #[test] + fn test_api_detection_from_endpoints() { + // Test that we can detect APIs from endpoints using the trait + let endpoints = vec![ + "/v1/chat/completions", + "/v1/messages", + "/v1/unknown" + ]; + + let mut detected_apis = Vec::new(); + + for endpoint in endpoints { + if let Some(api) = OpenAIApi::from_endpoint(endpoint) { + detected_apis.push(format!("OpenAI: {:?}", api)); + } else if let Some(api) = AnthropicApi::from_endpoint(endpoint) { + detected_apis.push(format!("Anthropic: {:?}", api)); + } else { + detected_apis.push("Unknown API".to_string()); + } + } + + assert_eq!(detected_apis, vec![ + "OpenAI: ChatCompletions", + "Anthropic: Messages", + "Unknown API" + ]); + } + + #[test] + fn test_all_variants_method() { + // Test that all_variants returns the expected variants + let openai_variants = OpenAIApi::all_variants(); + assert_eq!(openai_variants.len(), 1); + assert!(openai_variants.contains(&OpenAIApi::ChatCompletions)); + + let anthropic_variants = AnthropicApi::all_variants(); + assert_eq!(anthropic_variants.len(), 1); + assert!(anthropic_variants.contains(&AnthropicApi::Messages)); + + // Verify each variant has a valid endpoint + for variant in openai_variants { + assert!(!variant.endpoint().is_empty()); + } + + for variant in anthropic_variants { + assert!(!variant.endpoint().is_empty()); + } + } +} diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs new file mode 100644 index 00000000..eb270510 --- /dev/null +++ b/crates/hermesllm/src/apis/openai.rs @@ -0,0 +1,1303 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use serde_with::skip_serializing_none; +use std::collections::HashMap; + +use super::ApiDefinition; + +// ============================================================================ +// OPENAI API ENUMERATION +// ============================================================================ + +/// Enum for all supported OpenAI APIs +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum OpenAIApi { + ChatCompletions, + // Future APIs can be added here: + // Embeddings, + // FineTuning, + // etc. +} + +impl ApiDefinition for OpenAIApi { + fn endpoint(&self) -> &'static str { + match self { + OpenAIApi::ChatCompletions => "/v1/chat/completions", + } + } + + fn from_endpoint(endpoint: &str) -> Option { + match endpoint { + "/v1/chat/completions" => Some(OpenAIApi::ChatCompletions), + _ => None, + } + } + + fn supports_streaming(&self) -> bool { + match self { + OpenAIApi::ChatCompletions => true, + } + } + + fn supports_tools(&self) -> bool { + match self { + OpenAIApi::ChatCompletions => true, + } + } + + fn supports_vision(&self) -> bool { + match self { + OpenAIApi::ChatCompletions => true, + } + } + + fn all_variants() -> Vec { + vec![ + OpenAIApi::ChatCompletions, + ] + } +} + +/// Chat completions API request +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ChatCompletionsRequest { + pub messages: Vec, + pub model: String, + // pub auduio: Option