fixes based on code review

This commit is contained in:
Salman Paracha 2025-08-06 23:45:58 -07:00
parent 7490571f66
commit 268aea763e
2 changed files with 103 additions and 52 deletions

View file

@ -60,7 +60,7 @@ impl ApiDefinition for OpenAIApi {
/// Chat completions API request
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ChatCompletionsRequest {
pub messages: Vec<Message>,
pub model: String,
@ -139,7 +139,6 @@ pub struct ResponseMessage {
/// If the audio output modality is requested, this object contains data about the audio response
pub audio: Option<Value>,
/// Deprecated and replaced by tool_calls. The name and arguments of a function that should be called
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
/// The tool calls generated by the model, such as function calls
pub tool_calls: Option<Vec<ToolCall>>,
@ -226,11 +225,25 @@ pub struct Function {
pub strict: Option<bool>,
}
/// Tool choice string values
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoiceType {
/// Let the model automatically decide whether to call tools
Auto,
/// Force the model to call at least one tool
Required,
/// Prevent the model from calling any tools
None,
}
/// Tool choice configuration
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged)]
pub enum ToolChoice {
String(String), // "none", "auto", "required"
/// String-based tool choice (auto, required, none)
Type(ToolChoiceType),
/// Specific function to call
Function {
#[serde(rename = "type")]
choice_type: String,
@ -239,7 +252,7 @@ pub enum ToolChoice {
}
/// Specific function choice
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct FunctionChoice {
pub name: String,
}
@ -671,10 +684,10 @@ mod tests {
assert_eq!(parameters["required"], json!(["location"]));
// Validate tool choice
if let Some(ToolChoice::String(choice)) = &deserialized_request.tool_choice {
assert_eq!(choice, "auto");
if let Some(ToolChoice::Type(choice)) = &deserialized_request.tool_choice {
assert_eq!(choice, &ToolChoiceType::Auto);
} else {
panic!("Expected auto tool choice string");
panic!("Expected auto tool choice");
}
// Validate prediction
@ -838,4 +851,33 @@ mod tests {
assert!(converted.name.is_none());
assert!(converted.tool_call_id.is_none());
}
#[test]
fn test_tool_choice_type_serialization() {
// Test that the enum serializes to the correct string values
let auto_choice = ToolChoice::Type(ToolChoiceType::Auto);
let required_choice = ToolChoice::Type(ToolChoiceType::Required);
let none_choice = ToolChoice::Type(ToolChoiceType::None);
let auto_json = serde_json::to_value(&auto_choice).unwrap();
let required_json = serde_json::to_value(&required_choice).unwrap();
let none_json = serde_json::to_value(&none_choice).unwrap();
assert_eq!(auto_json, "auto");
assert_eq!(required_json, "required");
assert_eq!(none_json, "none");
// Test deserialization from string values
let auto_deserialized: ToolChoice = serde_json::from_value(json!("auto")).unwrap();
let required_deserialized: ToolChoice = serde_json::from_value(json!("required")).unwrap();
let none_deserialized: ToolChoice = serde_json::from_value(json!("none")).unwrap();
assert_eq!(auto_deserialized, ToolChoice::Type(ToolChoiceType::Auto));
assert_eq!(required_deserialized, ToolChoice::Type(ToolChoiceType::Required));
assert_eq!(none_deserialized, ToolChoice::Type(ToolChoiceType::None));
// Test that invalid string values fail deserialization (type safety!)
let invalid_result: Result<ToolChoice, _> = serde_json::from_value(json!("invalid"));
assert!(invalid_result.is_err());
}
}

View file

@ -13,14 +13,14 @@
//!
//! ```rust
//! use hermesllm::apis::{
//! MessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage,
//! AnthropicMessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage,
//! MessagesMessageContent, MessagesSystemPrompt,
//! };
//! use hermesllm::clients::TransformError;
//! use std::convert::TryInto;
//!
//! // Transform Anthropic to OpenAI
//! let anthropic_req = MessagesRequest {
//! let anthropic_req = AnthropicMessagesRequest {
//! model: "claude-3-sonnet".to_string(),
//! system: None,
//! messages: vec![],
@ -49,6 +49,13 @@ use std::time::{SystemTime, UNIX_EPOCH};
use crate::apis::*;
use super::TransformError;
// ============================================================================
// CONSTANTS
// ============================================================================
/// Default maximum tokens when converting from OpenAI to Anthropic and no max_tokens is specified
const DEFAULT_MAX_TOKENS: u32 = 4096;
// ============================================================================
// UTILITY TRAITS - Shared traits for content manipulation
// ============================================================================
@ -68,10 +75,13 @@ trait ContentUtils<T> {
// MAIN REQUEST TRANSFORMATIONS
// ============================================================================
impl TryFrom<MessagesRequest> for ChatCompletionsRequest {
type AnthropicMessagesRequest = MessagesRequest;
impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
type Error = TransformError;
fn try_from(req: MessagesRequest) -> Result<Self, Self::Error> {
fn try_from(req: AnthropicMessagesRequest) -> Result<Self, Self::Error> {
let mut openai_messages: Vec<Message> = Vec::new();
// Convert system prompt to system message if present
@ -95,34 +105,17 @@ impl TryFrom<MessagesRequest> for ChatCompletionsRequest {
temperature: req.temperature,
top_p: req.top_p,
max_tokens: Some(req.max_tokens),
max_completion_tokens: None,
stream: req.stream,
stream_options: None,
stop: req.stop_sequences,
tools: openai_tools,
tool_choice: openai_tool_choice,
parallel_tool_calls,
user: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
logprobs: None,
top_logprobs: None,
n: None,
seed: None,
response_format: None,
service_tier: None,
store: None,
metadata: None,
modalities: None,
function_call: None,
functions: None,
prediction: None,
..Default::default()
})
}
}
impl TryFrom<ChatCompletionsRequest> for MessagesRequest {
impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
type Error = TransformError;
fn try_from(req: ChatCompletionsRequest) -> Result<Self, Self::Error> {
@ -145,11 +138,11 @@ impl TryFrom<ChatCompletionsRequest> for MessagesRequest {
let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools));
let anthropic_tool_choice = convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
Ok(MessagesRequest {
Ok(AnthropicMessagesRequest {
model: req.model,
system: system_prompt,
messages,
max_tokens: req.max_tokens.unwrap_or(4096),
max_tokens: req.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
container: None,
mcp_servers: None,
service_tier: None,
@ -785,9 +778,9 @@ fn convert_anthropic_tool_choice(tool_choice: Option<MessagesToolChoice>) -> (Op
match tool_choice {
Some(choice) => {
let openai_choice = match choice.kind {
MessagesToolChoiceType::Auto => ToolChoice::String("auto".to_string()),
MessagesToolChoiceType::Any => ToolChoice::String("required".to_string()),
MessagesToolChoiceType::None => ToolChoice::String("none".to_string()),
MessagesToolChoiceType::Auto => ToolChoice::Type(ToolChoiceType::Auto),
MessagesToolChoiceType::Any => ToolChoice::Type(ToolChoiceType::Required),
MessagesToolChoiceType::None => ToolChoice::Type(ToolChoiceType::None),
MessagesToolChoiceType::Tool => {
if let Some(name) = choice.name {
ToolChoice::Function {
@ -795,7 +788,7 @@ fn convert_anthropic_tool_choice(tool_choice: Option<MessagesToolChoice>) -> (Op
function: FunctionChoice { name },
}
} else {
ToolChoice::String("auto".to_string())
ToolChoice::Type(ToolChoiceType::Auto)
}
}
};
@ -813,27 +806,22 @@ fn convert_openai_tool_choice(
) -> Option<MessagesToolChoice> {
tool_choice.map(|choice| {
match choice {
ToolChoice::String(s) => match s.as_str() {
"auto" => MessagesToolChoice {
ToolChoice::Type(tool_type) => match tool_type {
ToolChoiceType::Auto => MessagesToolChoice {
kind: MessagesToolChoiceType::Auto,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
"required" => MessagesToolChoice {
ToolChoiceType::Required => MessagesToolChoice {
kind: MessagesToolChoiceType::Any,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
"none" => MessagesToolChoice {
ToolChoiceType::None => MessagesToolChoice {
kind: MessagesToolChoiceType::None,
name: None,
disable_parallel_tool_use: None,
},
_ => MessagesToolChoice {
kind: MessagesToolChoiceType::Auto,
name: None,
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
},
},
ToolChoice::Function { function, .. } => MessagesToolChoice {
kind: MessagesToolChoiceType::Tool,
@ -1098,7 +1086,7 @@ mod tests {
#[test]
fn test_anthropic_to_openai_basic_request() {
let anthropic_req = MessagesRequest {
let anthropic_req = AnthropicMessagesRequest {
model: "claude-3-sonnet-20240229".to_string(),
system: Some(MessagesSystemPrompt::Single("You are helpful".to_string())),
messages: vec![MessagesMessage {
@ -1134,7 +1122,7 @@ mod tests {
#[test]
fn test_roundtrip_consistency() {
// Test that converting back and forth maintains consistency
let original_anthropic = MessagesRequest {
let original_anthropic = AnthropicMessagesRequest {
model: "claude-3-sonnet".to_string(),
system: Some(MessagesSystemPrompt::Single("System prompt".to_string())),
messages: vec![MessagesMessage {
@ -1158,7 +1146,7 @@ mod tests {
// Convert to OpenAI and back
let openai_req: ChatCompletionsRequest = original_anthropic.clone().try_into().unwrap();
let roundtrip_anthropic: MessagesRequest = openai_req.try_into().unwrap();
let roundtrip_anthropic: AnthropicMessagesRequest = openai_req.try_into().unwrap();
// Check key fields are preserved
assert_eq!(original_anthropic.model, roundtrip_anthropic.model);
@ -1171,7 +1159,7 @@ mod tests {
#[test]
fn test_tool_choice_auto() {
let anthropic_req = MessagesRequest {
let anthropic_req = AnthropicMessagesRequest {
model: "claude-3".to_string(),
system: None,
messages: vec![],
@ -1203,8 +1191,8 @@ mod tests {
assert!(openai_req.tools.is_some());
assert_eq!(openai_req.tools.as_ref().unwrap().len(), 1);
if let Some(ToolChoice::String(choice)) = openai_req.tool_choice {
assert_eq!(choice, "auto");
if let Some(ToolChoice::Type(choice)) = openai_req.tool_choice {
assert_eq!(choice, ToolChoiceType::Auto);
} else {
panic!("Expected auto tool choice");
}
@ -1212,6 +1200,27 @@ mod tests {
assert_eq!(openai_req.parallel_tool_calls, Some(false));
}
#[test]
fn test_default_max_tokens_used_when_openai_has_none() {
// Test that DEFAULT_MAX_TOKENS is used when OpenAI request has no max_tokens
let openai_req = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![Message {
role: Role::User,
content: MessageContent::Text("Hello".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
}],
max_tokens: None, // No max_tokens specified
..Default::default()
};
let anthropic_req: AnthropicMessagesRequest = openai_req.try_into().unwrap();
assert_eq!(anthropic_req.max_tokens, DEFAULT_MAX_TOKENS);
}
#[test]
fn test_anthropic_message_start_streaming() {
let event = MessagesStreamEvent::MessageStart {