diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 1fdf49c8..7f75c6be 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -60,7 +60,7 @@ impl ApiDefinition for OpenAIApi { /// Chat completions API request #[skip_serializing_none] -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, Default)] pub struct ChatCompletionsRequest { pub messages: Vec, pub model: String, @@ -139,7 +139,6 @@ pub struct ResponseMessage { /// If the audio output modality is requested, this object contains data about the audio response pub audio: Option, /// Deprecated and replaced by tool_calls. The name and arguments of a function that should be called - #[serde(skip_serializing_if = "Option::is_none")] pub function_call: Option, /// The tool calls generated by the model, such as function calls pub tool_calls: Option>, @@ -226,11 +225,25 @@ pub struct Function { pub strict: Option, } +/// Tool choice string values +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum ToolChoiceType { + /// Let the model automatically decide whether to call tools + Auto, + /// Force the model to call at least one tool + Required, + /// Prevent the model from calling any tools + None, +} + /// Tool choice configuration -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[serde(untagged)] pub enum ToolChoice { - String(String), // "none", "auto", "required" + /// String-based tool choice (auto, required, none) + Type(ToolChoiceType), + /// Specific function to call Function { #[serde(rename = "type")] choice_type: String, @@ -239,7 +252,7 @@ pub enum ToolChoice { } /// Specific function choice -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct FunctionChoice { pub name: String, } @@ -671,10 +684,10 @@ mod tests { assert_eq!(parameters["required"], json!(["location"])); // Validate tool choice - if let Some(ToolChoice::String(choice)) = &deserialized_request.tool_choice { - assert_eq!(choice, "auto"); + if let Some(ToolChoice::Type(choice)) = &deserialized_request.tool_choice { + assert_eq!(choice, &ToolChoiceType::Auto); } else { - panic!("Expected auto tool choice string"); + panic!("Expected auto tool choice"); } // Validate prediction @@ -838,4 +851,33 @@ mod tests { assert!(converted.name.is_none()); assert!(converted.tool_call_id.is_none()); } + + #[test] + fn test_tool_choice_type_serialization() { + // Test that the enum serializes to the correct string values + let auto_choice = ToolChoice::Type(ToolChoiceType::Auto); + let required_choice = ToolChoice::Type(ToolChoiceType::Required); + let none_choice = ToolChoice::Type(ToolChoiceType::None); + + let auto_json = serde_json::to_value(&auto_choice).unwrap(); + let required_json = serde_json::to_value(&required_choice).unwrap(); + let none_json = serde_json::to_value(&none_choice).unwrap(); + + assert_eq!(auto_json, "auto"); + assert_eq!(required_json, "required"); + assert_eq!(none_json, "none"); + + // Test deserialization from string values + let auto_deserialized: ToolChoice = serde_json::from_value(json!("auto")).unwrap(); + let required_deserialized: ToolChoice = serde_json::from_value(json!("required")).unwrap(); + let none_deserialized: ToolChoice = serde_json::from_value(json!("none")).unwrap(); + + assert_eq!(auto_deserialized, ToolChoice::Type(ToolChoiceType::Auto)); + assert_eq!(required_deserialized, ToolChoice::Type(ToolChoiceType::Required)); + assert_eq!(none_deserialized, ToolChoice::Type(ToolChoiceType::None)); + + // Test that invalid string values fail deserialization (type safety!) + let invalid_result: Result = serde_json::from_value(json!("invalid")); + assert!(invalid_result.is_err()); + } } diff --git a/crates/hermesllm/src/clients/transformer.rs b/crates/hermesllm/src/clients/transformer.rs index ab4bb48e..c6d524f4 100644 --- a/crates/hermesllm/src/clients/transformer.rs +++ b/crates/hermesllm/src/clients/transformer.rs @@ -13,14 +13,14 @@ //! //! ```rust //! use hermesllm::apis::{ -//! MessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage, +//! AnthropicMessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage, //! MessagesMessageContent, MessagesSystemPrompt, //! }; //! use hermesllm::clients::TransformError; //! use std::convert::TryInto; //! //! // Transform Anthropic to OpenAI -//! let anthropic_req = MessagesRequest { +//! let anthropic_req = AnthropicMessagesRequest { //! model: "claude-3-sonnet".to_string(), //! system: None, //! messages: vec![], @@ -49,6 +49,13 @@ use std::time::{SystemTime, UNIX_EPOCH}; use crate::apis::*; use super::TransformError; +// ============================================================================ +// CONSTANTS +// ============================================================================ + +/// Default maximum tokens when converting from OpenAI to Anthropic and no max_tokens is specified +const DEFAULT_MAX_TOKENS: u32 = 4096; + // ============================================================================ // UTILITY TRAITS - Shared traits for content manipulation // ============================================================================ @@ -68,10 +75,13 @@ trait ContentUtils { // MAIN REQUEST TRANSFORMATIONS // ============================================================================ -impl TryFrom for ChatCompletionsRequest { +type AnthropicMessagesRequest = MessagesRequest; + + +impl TryFrom for ChatCompletionsRequest { type Error = TransformError; - fn try_from(req: MessagesRequest) -> Result { + fn try_from(req: AnthropicMessagesRequest) -> Result { let mut openai_messages: Vec = Vec::new(); // Convert system prompt to system message if present @@ -95,34 +105,17 @@ impl TryFrom for ChatCompletionsRequest { temperature: req.temperature, top_p: req.top_p, max_tokens: Some(req.max_tokens), - max_completion_tokens: None, stream: req.stream, - stream_options: None, stop: req.stop_sequences, tools: openai_tools, tool_choice: openai_tool_choice, parallel_tool_calls, - user: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - logprobs: None, - top_logprobs: None, - n: None, - seed: None, - response_format: None, - service_tier: None, - store: None, - metadata: None, - modalities: None, - function_call: None, - functions: None, - prediction: None, + ..Default::default() }) } } -impl TryFrom for MessagesRequest { +impl TryFrom for AnthropicMessagesRequest { type Error = TransformError; fn try_from(req: ChatCompletionsRequest) -> Result { @@ -145,11 +138,11 @@ impl TryFrom for MessagesRequest { let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools)); let anthropic_tool_choice = convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls); - Ok(MessagesRequest { + Ok(AnthropicMessagesRequest { model: req.model, system: system_prompt, messages, - max_tokens: req.max_tokens.unwrap_or(4096), + max_tokens: req.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS), container: None, mcp_servers: None, service_tier: None, @@ -785,9 +778,9 @@ fn convert_anthropic_tool_choice(tool_choice: Option) -> (Op match tool_choice { Some(choice) => { let openai_choice = match choice.kind { - MessagesToolChoiceType::Auto => ToolChoice::String("auto".to_string()), - MessagesToolChoiceType::Any => ToolChoice::String("required".to_string()), - MessagesToolChoiceType::None => ToolChoice::String("none".to_string()), + MessagesToolChoiceType::Auto => ToolChoice::Type(ToolChoiceType::Auto), + MessagesToolChoiceType::Any => ToolChoice::Type(ToolChoiceType::Required), + MessagesToolChoiceType::None => ToolChoice::Type(ToolChoiceType::None), MessagesToolChoiceType::Tool => { if let Some(name) = choice.name { ToolChoice::Function { @@ -795,7 +788,7 @@ fn convert_anthropic_tool_choice(tool_choice: Option) -> (Op function: FunctionChoice { name }, } } else { - ToolChoice::String("auto".to_string()) + ToolChoice::Type(ToolChoiceType::Auto) } } }; @@ -813,27 +806,22 @@ fn convert_openai_tool_choice( ) -> Option { tool_choice.map(|choice| { match choice { - ToolChoice::String(s) => match s.as_str() { - "auto" => MessagesToolChoice { + ToolChoice::Type(tool_type) => match tool_type { + ToolChoiceType::Auto => MessagesToolChoice { kind: MessagesToolChoiceType::Auto, name: None, disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), }, - "required" => MessagesToolChoice { + ToolChoiceType::Required => MessagesToolChoice { kind: MessagesToolChoiceType::Any, name: None, disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), }, - "none" => MessagesToolChoice { + ToolChoiceType::None => MessagesToolChoice { kind: MessagesToolChoiceType::None, name: None, disable_parallel_tool_use: None, }, - _ => MessagesToolChoice { - kind: MessagesToolChoiceType::Auto, - name: None, - disable_parallel_tool_use: parallel_tool_calls.map(|p| !p), - }, }, ToolChoice::Function { function, .. } => MessagesToolChoice { kind: MessagesToolChoiceType::Tool, @@ -1098,7 +1086,7 @@ mod tests { #[test] fn test_anthropic_to_openai_basic_request() { - let anthropic_req = MessagesRequest { + let anthropic_req = AnthropicMessagesRequest { model: "claude-3-sonnet-20240229".to_string(), system: Some(MessagesSystemPrompt::Single("You are helpful".to_string())), messages: vec![MessagesMessage { @@ -1134,7 +1122,7 @@ mod tests { #[test] fn test_roundtrip_consistency() { // Test that converting back and forth maintains consistency - let original_anthropic = MessagesRequest { + let original_anthropic = AnthropicMessagesRequest { model: "claude-3-sonnet".to_string(), system: Some(MessagesSystemPrompt::Single("System prompt".to_string())), messages: vec![MessagesMessage { @@ -1158,7 +1146,7 @@ mod tests { // Convert to OpenAI and back let openai_req: ChatCompletionsRequest = original_anthropic.clone().try_into().unwrap(); - let roundtrip_anthropic: MessagesRequest = openai_req.try_into().unwrap(); + let roundtrip_anthropic: AnthropicMessagesRequest = openai_req.try_into().unwrap(); // Check key fields are preserved assert_eq!(original_anthropic.model, roundtrip_anthropic.model); @@ -1171,7 +1159,7 @@ mod tests { #[test] fn test_tool_choice_auto() { - let anthropic_req = MessagesRequest { + let anthropic_req = AnthropicMessagesRequest { model: "claude-3".to_string(), system: None, messages: vec![], @@ -1203,8 +1191,8 @@ mod tests { assert!(openai_req.tools.is_some()); assert_eq!(openai_req.tools.as_ref().unwrap().len(), 1); - if let Some(ToolChoice::String(choice)) = openai_req.tool_choice { - assert_eq!(choice, "auto"); + if let Some(ToolChoice::Type(choice)) = openai_req.tool_choice { + assert_eq!(choice, ToolChoiceType::Auto); } else { panic!("Expected auto tool choice"); } @@ -1212,6 +1200,27 @@ mod tests { assert_eq!(openai_req.parallel_tool_calls, Some(false)); } + #[test] + fn test_default_max_tokens_used_when_openai_has_none() { + // Test that DEFAULT_MAX_TOKENS is used when OpenAI request has no max_tokens + let openai_req = ChatCompletionsRequest { + model: "gpt-4".to_string(), + messages: vec![Message { + role: Role::User, + content: MessageContent::Text("Hello".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }], + max_tokens: None, // No max_tokens specified + ..Default::default() + }; + + let anthropic_req: AnthropicMessagesRequest = openai_req.try_into().unwrap(); + + assert_eq!(anthropic_req.max_tokens, DEFAULT_MAX_TOKENS); + } + #[test] fn test_anthropic_message_start_streaming() { let event = MessagesStreamEvent::MessageStart {