mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
pushing new apis module for hermes
This commit is contained in:
parent
62a092fa63
commit
da7c3e2c37
3 changed files with 2068 additions and 0 deletions
568
crates/hermesllm/src/apis/anthropic.rs
Normal file
568
crates/hermesllm/src/apis/anthropic.rs
Normal file
|
|
@ -0,0 +1,568 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_with::skip_serializing_none;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::ApiDefinition;
|
||||
|
||||
// Enum for all supported Anthropic APIs
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum AnthropicApi {
|
||||
Messages,
|
||||
// Future APIs can be added here:
|
||||
// Embeddings,
|
||||
// etc.
|
||||
}
|
||||
|
||||
impl ApiDefinition for AnthropicApi {
|
||||
fn endpoint(&self) -> &'static str {
|
||||
match self {
|
||||
AnthropicApi::Messages => "/v1/messages",
|
||||
}
|
||||
}
|
||||
|
||||
fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
match endpoint {
|
||||
"/v1/messages" => Some(AnthropicApi::Messages),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_streaming(&self) -> bool {
|
||||
match self {
|
||||
AnthropicApi::Messages => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
match self {
|
||||
AnthropicApi::Messages => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_vision(&self) -> bool {
|
||||
match self {
|
||||
AnthropicApi::Messages => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn all_variants() -> Vec<Self> {
|
||||
vec![
|
||||
AnthropicApi::Messages,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
// Service tier enum for request priority
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ServiceTier {
|
||||
Auto,
|
||||
StandardOnly,
|
||||
}
|
||||
|
||||
// Thinking configuration
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct ThinkingConfig {
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
// MCP Server types
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum McpServerType {
|
||||
Url,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct McpToolConfiguration {
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct McpServer {
|
||||
pub name: String,
|
||||
#[serde(rename = "type")]
|
||||
pub server_type: McpServerType,
|
||||
pub url: String,
|
||||
pub authorization_token: Option<String>,
|
||||
pub tool_configuration: Option<McpToolConfiguration>,
|
||||
}
|
||||
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct MessagesRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<MessagesMessage>,
|
||||
pub max_tokens: u32,
|
||||
pub container: Option<String>,
|
||||
pub mcp_servers: Option<Vec<McpServer>>,
|
||||
pub system: Option<MessagesSystemPrompt>,
|
||||
pub metadata: Option<HashMap<String, Value>>,
|
||||
pub service_tier: Option<ServiceTier>,
|
||||
pub thinking: Option<ThinkingConfig>,
|
||||
|
||||
pub temperature: Option<f32>,
|
||||
pub top_p: Option<f32>,
|
||||
pub top_k: Option<u32>,
|
||||
pub stream: Option<bool>,
|
||||
pub stop_sequences: Option<Vec<String>>,
|
||||
pub tools: Option<Vec<MessagesTool>>,
|
||||
pub tool_choice: Option<MessagesToolChoice>,
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Messages API specific types
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum MessagesRole {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[serde(tag = "type")]
|
||||
pub enum MessagesContentBlock {
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
Thinking {
|
||||
text: String,
|
||||
},
|
||||
Image {
|
||||
source: MessagesImageSource,
|
||||
},
|
||||
Document {
|
||||
source: MessagesDocumentSource,
|
||||
},
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
},
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
is_error: Option<bool>,
|
||||
content: Vec<MessagesContentBlock>,
|
||||
},
|
||||
ServerToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
},
|
||||
WebSearchToolResult {
|
||||
tool_use_id: String,
|
||||
is_error: Option<bool>,
|
||||
content: Vec<MessagesContentBlock>,
|
||||
},
|
||||
CodeExecutionToolResult {
|
||||
tool_use_id: String,
|
||||
is_error: Option<bool>,
|
||||
content: Vec<MessagesContentBlock>,
|
||||
},
|
||||
McpToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
},
|
||||
McpToolResult {
|
||||
tool_use_id: String,
|
||||
is_error: Option<bool>,
|
||||
content: Vec<MessagesContentBlock>,
|
||||
},
|
||||
ContainerUpload {
|
||||
id: String,
|
||||
name: String,
|
||||
media_type: String,
|
||||
data: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MessagesImageSource {
|
||||
Base64 {
|
||||
media_type: String,
|
||||
data: String,
|
||||
},
|
||||
Url {
|
||||
url: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MessagesDocumentSource {
|
||||
Base64 {
|
||||
media_type: String,
|
||||
data: String,
|
||||
},
|
||||
Url {
|
||||
url: String,
|
||||
},
|
||||
File {
|
||||
file_id: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessagesMessageContent {
|
||||
Single(String),
|
||||
Blocks(Vec<MessagesContentBlock>),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessagesSystemPrompt {
|
||||
Single(String),
|
||||
Blocks(Vec<MessagesContentBlock>),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct MessagesMessage {
|
||||
pub role: MessagesRole,
|
||||
pub content: MessagesMessageContent,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct MessagesTool {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub input_schema: Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MessagesToolChoiceType {
|
||||
Auto,
|
||||
Any,
|
||||
Tool,
|
||||
None,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct MessagesToolChoice {
|
||||
#[serde(rename = "type")]
|
||||
pub kind: MessagesToolChoiceType,
|
||||
pub name: Option<String>,
|
||||
pub disable_parallel_tool_use: Option<bool>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MessagesStopReason {
|
||||
EndTurn,
|
||||
MaxTokens,
|
||||
StopSequence,
|
||||
ToolUse,
|
||||
PauseTurn,
|
||||
Refusal,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct MessagesUsage {
|
||||
pub input_tokens: u32,
|
||||
pub output_tokens: u32,
|
||||
pub cache_creation_input_tokens: Option<u32>,
|
||||
pub cache_read_input_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
// Container response object
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct MessagesContainer {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub container_type: String,
|
||||
pub name: String,
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct MessagesResponse {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub obj_type: String,
|
||||
pub role: MessagesRole,
|
||||
pub content: Vec<MessagesContentBlock>,
|
||||
pub model: String,
|
||||
pub stop_reason: MessagesStopReason,
|
||||
pub stop_sequence: Option<String>,
|
||||
pub usage: MessagesUsage,
|
||||
pub container: Option<MessagesContainer>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[serde(tag = "type")]
|
||||
pub enum MessagesStreamEvent {
|
||||
MessageStart {
|
||||
message: MessagesStreamMessage,
|
||||
},
|
||||
ContentBlockStart {
|
||||
index: u32,
|
||||
content_block: MessagesContentBlock,
|
||||
},
|
||||
ContentBlockDelta {
|
||||
index: u32,
|
||||
delta: MessagesContentDelta,
|
||||
},
|
||||
ContentBlockStop {
|
||||
index: u32,
|
||||
},
|
||||
MessageDelta {
|
||||
delta: MessagesMessageDelta,
|
||||
usage: MessagesUsage,
|
||||
},
|
||||
MessageStop,
|
||||
Ping,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct MessagesStreamMessage {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub obj_type: String,
|
||||
pub role: MessagesRole,
|
||||
pub content: Vec<Value>, // Initially empty
|
||||
pub model: String,
|
||||
pub stop_reason: Option<MessagesStopReason>,
|
||||
pub stop_sequence: Option<String>,
|
||||
pub usage: MessagesUsage,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum MessagesContentDelta {
|
||||
#[serde(rename = "text_delta")]
|
||||
TextDelta { text: String },
|
||||
#[serde(rename = "input_json_delta")]
|
||||
InputJsonDelta { partial_json: String },
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct MessagesMessageDelta {
|
||||
pub stop_reason: MessagesStopReason,
|
||||
pub stop_sequence: Option<String>,
|
||||
}
|
||||
|
||||
// Helper functions for API detection and conversion
|
||||
impl MessagesRequest {
|
||||
pub fn api_type() -> AnthropicApi {
|
||||
AnthropicApi::Messages
|
||||
}
|
||||
}
|
||||
|
||||
impl MessagesResponse {
|
||||
pub fn api_type() -> AnthropicApi {
|
||||
AnthropicApi::Messages
|
||||
}
|
||||
}
|
||||
|
||||
impl MessagesStreamEvent {
|
||||
pub fn api_type() -> AnthropicApi {
|
||||
AnthropicApi::Messages
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json;
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_skip_serializing_none_annotations() {
|
||||
// Test that skip_serializing_none works correctly for MessagesRequest
|
||||
let request = MessagesRequest {
|
||||
model: "claude-3-sonnet-20240229".to_string(),
|
||||
system: None, // Should be skipped
|
||||
messages: vec![MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("Hello".to_string()),
|
||||
}],
|
||||
max_tokens: 100,
|
||||
container: None, // Should be skipped
|
||||
mcp_servers: None, // Should be skipped
|
||||
service_tier: None, // Should be skipped
|
||||
thinking: None, // Should be skipped
|
||||
temperature: None, // Should be skipped
|
||||
top_p: Some(0.9), // Should be included
|
||||
top_k: None, // Should be skipped
|
||||
stream: None, // Should be skipped
|
||||
stop_sequences: None, // Should be skipped
|
||||
tools: None, // Should be skipped
|
||||
tool_choice: None, // Should be skipped
|
||||
metadata: None, // Should be skipped
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(&request).unwrap();
|
||||
let obj = json.as_object().unwrap();
|
||||
|
||||
// Verify that None fields are not present in the JSON
|
||||
assert!(!obj.contains_key("system"));
|
||||
assert!(!obj.contains_key("container"));
|
||||
assert!(!obj.contains_key("mcp_servers"));
|
||||
assert!(!obj.contains_key("service_tier"));
|
||||
assert!(!obj.contains_key("thinking"));
|
||||
assert!(!obj.contains_key("temperature"));
|
||||
assert!(!obj.contains_key("top_k"));
|
||||
assert!(!obj.contains_key("stream"));
|
||||
assert!(!obj.contains_key("stop_sequences"));
|
||||
assert!(!obj.contains_key("tools"));
|
||||
assert!(!obj.contains_key("tool_choice"));
|
||||
assert!(!obj.contains_key("metadata"));
|
||||
|
||||
// Verify that required fields and Some fields are present
|
||||
assert!(obj.contains_key("model"));
|
||||
assert!(obj.contains_key("messages"));
|
||||
assert!(obj.contains_key("max_tokens"));
|
||||
assert!(obj.contains_key("top_p")); // This was Some(0.9)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_tool_serialization() {
|
||||
// Test MessagesTool with skip_serializing_none
|
||||
let tool = MessagesTool {
|
||||
name: "get_weather".to_string(),
|
||||
description: None, // Should be skipped
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(&tool).unwrap();
|
||||
let obj = json.as_object().unwrap();
|
||||
|
||||
assert!(obj.contains_key("name"));
|
||||
assert!(obj.contains_key("input_schema"));
|
||||
assert!(!obj.contains_key("description")); // Should be skipped
|
||||
|
||||
// Test with description present
|
||||
let tool_with_desc = MessagesTool {
|
||||
name: "get_weather".to_string(),
|
||||
description: Some("Get weather information".to_string()),
|
||||
input_schema: serde_json::json!({"type": "object"}),
|
||||
};
|
||||
|
||||
let json_with_desc = serde_json::to_value(&tool_with_desc).unwrap();
|
||||
let obj_with_desc = json_with_desc.as_object().unwrap();
|
||||
|
||||
assert!(obj_with_desc.contains_key("description")); // Should be included
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mcp_server_serialization() {
|
||||
// Test McpServer with skip_serializing_none
|
||||
let mcp_server = McpServer {
|
||||
name: "test-server".to_string(),
|
||||
server_type: McpServerType::Url,
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
authorization_token: None, // Should be skipped
|
||||
tool_configuration: Some(McpToolConfiguration {
|
||||
allowed_tools: Some(vec!["tool1".to_string(), "tool2".to_string()]),
|
||||
enabled: None, // Should be skipped
|
||||
}),
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(&mcp_server).unwrap();
|
||||
let obj = json.as_object().unwrap();
|
||||
|
||||
// Verify required fields are present
|
||||
assert!(obj.contains_key("name"));
|
||||
assert!(obj.contains_key("type"));
|
||||
assert!(obj.contains_key("url"));
|
||||
assert!(obj.contains_key("tool_configuration"));
|
||||
|
||||
// Verify None fields are not present
|
||||
assert!(!obj.contains_key("authorization_token"));
|
||||
|
||||
// Check tool_configuration
|
||||
let tool_config = obj.get("tool_configuration").unwrap().as_object().unwrap();
|
||||
assert!(tool_config.contains_key("allowed_tools"));
|
||||
assert!(!tool_config.contains_key("enabled")); // Should be skipped
|
||||
|
||||
// Verify type serialization
|
||||
assert_eq!(obj.get("type").unwrap().as_str().unwrap(), "url");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_service_tier_and_thinking_serialization() {
|
||||
// Test with service_tier and thinking enabled
|
||||
let request_with_fields = MessagesRequest {
|
||||
model: "claude-3-sonnet".to_string(),
|
||||
system: None,
|
||||
messages: vec![MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("Hello".to_string()),
|
||||
}],
|
||||
max_tokens: 100,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
service_tier: Some(ServiceTier::Auto),
|
||||
thinking: Some(ThinkingConfig { enabled: true }),
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
stream: None,
|
||||
stop_sequences: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(&request_with_fields).unwrap();
|
||||
let obj = json.as_object().unwrap();
|
||||
|
||||
// Verify that Some fields are present
|
||||
assert!(obj.contains_key("service_tier"));
|
||||
assert!(obj.contains_key("thinking"));
|
||||
|
||||
// Verify service_tier serialization
|
||||
assert_eq!(obj.get("service_tier").unwrap().as_str().unwrap(), "auto");
|
||||
|
||||
// Verify thinking serialization
|
||||
let thinking = obj.get("thinking").unwrap().as_object().unwrap();
|
||||
assert!(thinking.contains_key("enabled"));
|
||||
assert_eq!(thinking.get("enabled").unwrap().as_bool().unwrap(), true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_api_provider_trait_implementation() {
|
||||
use super::ApiDefinition;
|
||||
|
||||
// Test that AnthropicApi implements ApiDefinition trait correctly
|
||||
let api = AnthropicApi::Messages;
|
||||
|
||||
// Test trait methods
|
||||
assert_eq!(ApiDefinition::endpoint(&api), "/v1/messages");
|
||||
assert!(ApiDefinition::supports_streaming(&api));
|
||||
assert!(ApiDefinition::supports_tools(&api));
|
||||
assert!(ApiDefinition::supports_vision(&api));
|
||||
|
||||
// Test from_endpoint trait method
|
||||
let found_api = AnthropicApi::from_endpoint("/v1/messages");
|
||||
assert_eq!(found_api, Some(AnthropicApi::Messages));
|
||||
|
||||
let not_found = AnthropicApi::from_endpoint("/v1/unknown");
|
||||
assert_eq!(not_found, None);
|
||||
}
|
||||
}
|
||||
197
crates/hermesllm/src/apis/mod.rs
Normal file
197
crates/hermesllm/src/apis/mod.rs
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
pub mod anthropic;
|
||||
pub mod openai;
|
||||
|
||||
// Re-export all types for convenience
|
||||
pub use anthropic::*;
|
||||
pub use openai::*;
|
||||
|
||||
/// Common trait that all API definitions must implement
|
||||
///
|
||||
/// This trait ensures consistency across different AI provider API definitions
|
||||
/// and makes it easy to add new providers like Gemini, Claude, etc.
|
||||
///
|
||||
/// Note: This is different from the `ApiProvider` enum in `clients::endpoints`
|
||||
/// which represents provider identification, while this trait defines API capabilities.
|
||||
///
|
||||
/// # Benefits
|
||||
///
|
||||
/// - **Consistency**: All API providers implement the same interface
|
||||
/// - **Extensibility**: Easy to add new providers without breaking existing code
|
||||
/// - **Type Safety**: Compile-time guarantees that all providers implement required methods
|
||||
/// - **Discoverability**: Clear documentation of what capabilities each API supports
|
||||
///
|
||||
/// # Example implementation for a new provider:
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use serde::{Deserialize, Serialize};
|
||||
/// use super::ApiDefinition;
|
||||
///
|
||||
/// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
/// pub enum GeminiApi {
|
||||
/// GenerateContent,
|
||||
/// ChatCompletions,
|
||||
/// }
|
||||
///
|
||||
/// impl GeminiApi {
|
||||
/// pub fn endpoint(&self) -> &'static str {
|
||||
/// match self {
|
||||
/// GeminiApi::GenerateContent => "/v1/models/gemini-pro:generateContent",
|
||||
/// GeminiApi::ChatCompletions => "/v1/models/gemini-pro:chat",
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
/// match endpoint {
|
||||
/// "/v1/models/gemini-pro:generateContent" => Some(GeminiApi::GenerateContent),
|
||||
/// "/v1/models/gemini-pro:chat" => Some(GeminiApi::ChatCompletions),
|
||||
/// _ => None,
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// pub fn supports_streaming(&self) -> bool {
|
||||
/// match self {
|
||||
/// GeminiApi::GenerateContent => true,
|
||||
/// GeminiApi::ChatCompletions => true,
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// pub fn supports_tools(&self) -> bool {
|
||||
/// match self {
|
||||
/// GeminiApi::GenerateContent => true,
|
||||
/// GeminiApi::ChatCompletions => false,
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// pub fn supports_vision(&self) -> bool {
|
||||
/// match self {
|
||||
/// GeminiApi::GenerateContent => true,
|
||||
/// GeminiApi::ChatCompletions => false,
|
||||
/// }
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// impl ApiDefinition for GeminiApi {
|
||||
/// fn endpoint(&self) -> &'static str {
|
||||
/// self.endpoint()
|
||||
/// }
|
||||
///
|
||||
/// fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
/// Self::from_endpoint(endpoint)
|
||||
/// }
|
||||
///
|
||||
/// fn supports_streaming(&self) -> bool {
|
||||
/// self.supports_streaming()
|
||||
/// }
|
||||
///
|
||||
/// fn supports_tools(&self) -> bool {
|
||||
/// self.supports_tools()
|
||||
/// }
|
||||
///
|
||||
/// fn supports_vision(&self) -> bool {
|
||||
/// self.supports_vision()
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// // Now you can use generic code that works with any API:
|
||||
/// fn print_api_info<T: ApiDefinition>(api: &T) {
|
||||
/// println!("Endpoint: {}", api.endpoint());
|
||||
/// println!("Supports streaming: {}", api.supports_streaming());
|
||||
/// println!("Supports tools: {}", api.supports_tools());
|
||||
/// println!("Supports vision: {}", api.supports_vision());
|
||||
/// }
|
||||
///
|
||||
/// // Works with both OpenAI and Anthropic (and future Gemini)
|
||||
/// print_api_info(&OpenAIApi::ChatCompletions);
|
||||
/// print_api_info(&AnthropicApi::Messages);
|
||||
/// print_api_info(&GeminiApi::GenerateContent);
|
||||
/// ```
|
||||
pub trait ApiDefinition {
|
||||
/// Returns the endpoint path for this API
|
||||
fn endpoint(&self) -> &'static str;
|
||||
|
||||
/// Creates an API instance from an endpoint path
|
||||
fn from_endpoint(endpoint: &str) -> Option<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
|
||||
/// Returns whether this API supports streaming responses
|
||||
fn supports_streaming(&self) -> bool;
|
||||
|
||||
/// Returns whether this API supports tool/function calling
|
||||
fn supports_tools(&self) -> bool;
|
||||
|
||||
/// Returns whether this API supports vision/image processing
|
||||
fn supports_vision(&self) -> bool;
|
||||
|
||||
/// Returns all variants of this API enum
|
||||
fn all_variants() -> Vec<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_generic_api_functionality() {
|
||||
// Test that our generic API functionality works with both providers
|
||||
fn test_api<T: ApiDefinition>(api: &T) {
|
||||
let endpoint = api.endpoint();
|
||||
assert!(!endpoint.is_empty());
|
||||
assert!(endpoint.starts_with('/'));
|
||||
}
|
||||
|
||||
test_api(&OpenAIApi::ChatCompletions);
|
||||
test_api(&AnthropicApi::Messages);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_api_detection_from_endpoints() {
|
||||
// Test that we can detect APIs from endpoints using the trait
|
||||
let endpoints = vec![
|
||||
"/v1/chat/completions",
|
||||
"/v1/messages",
|
||||
"/v1/unknown"
|
||||
];
|
||||
|
||||
let mut detected_apis = Vec::new();
|
||||
|
||||
for endpoint in endpoints {
|
||||
if let Some(api) = OpenAIApi::from_endpoint(endpoint) {
|
||||
detected_apis.push(format!("OpenAI: {:?}", api));
|
||||
} else if let Some(api) = AnthropicApi::from_endpoint(endpoint) {
|
||||
detected_apis.push(format!("Anthropic: {:?}", api));
|
||||
} else {
|
||||
detected_apis.push("Unknown API".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(detected_apis, vec![
|
||||
"OpenAI: ChatCompletions",
|
||||
"Anthropic: Messages",
|
||||
"Unknown API"
|
||||
]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_variants_method() {
|
||||
// Test that all_variants returns the expected variants
|
||||
let openai_variants = OpenAIApi::all_variants();
|
||||
assert_eq!(openai_variants.len(), 1);
|
||||
assert!(openai_variants.contains(&OpenAIApi::ChatCompletions));
|
||||
|
||||
let anthropic_variants = AnthropicApi::all_variants();
|
||||
assert_eq!(anthropic_variants.len(), 1);
|
||||
assert!(anthropic_variants.contains(&AnthropicApi::Messages));
|
||||
|
||||
// Verify each variant has a valid endpoint
|
||||
for variant in openai_variants {
|
||||
assert!(!variant.endpoint().is_empty());
|
||||
}
|
||||
|
||||
for variant in anthropic_variants {
|
||||
assert!(!variant.endpoint().is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
1303
crates/hermesllm/src/apis/openai.rs
Normal file
1303
crates/hermesllm/src/apis/openai.rs
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue