mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fixes based on code review
This commit is contained in:
parent
7490571f66
commit
268aea763e
2 changed files with 103 additions and 52 deletions
|
|
@ -60,7 +60,7 @@ impl ApiDefinition for OpenAIApi {
|
|||
|
||||
/// Chat completions API request
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
|
||||
pub struct ChatCompletionsRequest {
|
||||
pub messages: Vec<Message>,
|
||||
pub model: String,
|
||||
|
|
@ -139,7 +139,6 @@ pub struct ResponseMessage {
|
|||
/// If the audio output modality is requested, this object contains data about the audio response
|
||||
pub audio: Option<Value>,
|
||||
/// Deprecated and replaced by tool_calls. The name and arguments of a function that should be called
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function_call: Option<FunctionCall>,
|
||||
/// The tool calls generated by the model, such as function calls
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
|
@ -226,11 +225,25 @@ pub struct Function {
|
|||
pub strict: Option<bool>,
|
||||
}
|
||||
|
||||
/// Tool choice string values
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ToolChoiceType {
|
||||
/// Let the model automatically decide whether to call tools
|
||||
Auto,
|
||||
/// Force the model to call at least one tool
|
||||
Required,
|
||||
/// Prevent the model from calling any tools
|
||||
None,
|
||||
}
|
||||
|
||||
/// Tool choice configuration
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum ToolChoice {
|
||||
String(String), // "none", "auto", "required"
|
||||
/// String-based tool choice (auto, required, none)
|
||||
Type(ToolChoiceType),
|
||||
/// Specific function to call
|
||||
Function {
|
||||
#[serde(rename = "type")]
|
||||
choice_type: String,
|
||||
|
|
@ -239,7 +252,7 @@ pub enum ToolChoice {
|
|||
}
|
||||
|
||||
/// Specific function choice
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct FunctionChoice {
|
||||
pub name: String,
|
||||
}
|
||||
|
|
@ -671,10 +684,10 @@ mod tests {
|
|||
assert_eq!(parameters["required"], json!(["location"]));
|
||||
|
||||
// Validate tool choice
|
||||
if let Some(ToolChoice::String(choice)) = &deserialized_request.tool_choice {
|
||||
assert_eq!(choice, "auto");
|
||||
if let Some(ToolChoice::Type(choice)) = &deserialized_request.tool_choice {
|
||||
assert_eq!(choice, &ToolChoiceType::Auto);
|
||||
} else {
|
||||
panic!("Expected auto tool choice string");
|
||||
panic!("Expected auto tool choice");
|
||||
}
|
||||
|
||||
// Validate prediction
|
||||
|
|
@ -838,4 +851,33 @@ mod tests {
|
|||
assert!(converted.name.is_none());
|
||||
assert!(converted.tool_call_id.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_choice_type_serialization() {
|
||||
// Test that the enum serializes to the correct string values
|
||||
let auto_choice = ToolChoice::Type(ToolChoiceType::Auto);
|
||||
let required_choice = ToolChoice::Type(ToolChoiceType::Required);
|
||||
let none_choice = ToolChoice::Type(ToolChoiceType::None);
|
||||
|
||||
let auto_json = serde_json::to_value(&auto_choice).unwrap();
|
||||
let required_json = serde_json::to_value(&required_choice).unwrap();
|
||||
let none_json = serde_json::to_value(&none_choice).unwrap();
|
||||
|
||||
assert_eq!(auto_json, "auto");
|
||||
assert_eq!(required_json, "required");
|
||||
assert_eq!(none_json, "none");
|
||||
|
||||
// Test deserialization from string values
|
||||
let auto_deserialized: ToolChoice = serde_json::from_value(json!("auto")).unwrap();
|
||||
let required_deserialized: ToolChoice = serde_json::from_value(json!("required")).unwrap();
|
||||
let none_deserialized: ToolChoice = serde_json::from_value(json!("none")).unwrap();
|
||||
|
||||
assert_eq!(auto_deserialized, ToolChoice::Type(ToolChoiceType::Auto));
|
||||
assert_eq!(required_deserialized, ToolChoice::Type(ToolChoiceType::Required));
|
||||
assert_eq!(none_deserialized, ToolChoice::Type(ToolChoiceType::None));
|
||||
|
||||
// Test that invalid string values fail deserialization (type safety!)
|
||||
let invalid_result: Result<ToolChoice, _> = serde_json::from_value(json!("invalid"));
|
||||
assert!(invalid_result.is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,14 +13,14 @@
|
|||
//!
|
||||
//! ```rust
|
||||
//! use hermesllm::apis::{
|
||||
//! MessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage,
|
||||
//! AnthropicMessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage,
|
||||
//! MessagesMessageContent, MessagesSystemPrompt,
|
||||
//! };
|
||||
//! use hermesllm::clients::TransformError;
|
||||
//! use std::convert::TryInto;
|
||||
//!
|
||||
//! // Transform Anthropic to OpenAI
|
||||
//! let anthropic_req = MessagesRequest {
|
||||
//! let anthropic_req = AnthropicMessagesRequest {
|
||||
//! model: "claude-3-sonnet".to_string(),
|
||||
//! system: None,
|
||||
//! messages: vec![],
|
||||
|
|
@ -49,6 +49,13 @@ use std::time::{SystemTime, UNIX_EPOCH};
|
|||
use crate::apis::*;
|
||||
use super::TransformError;
|
||||
|
||||
// ============================================================================
|
||||
// CONSTANTS
|
||||
// ============================================================================
|
||||
|
||||
/// Default maximum tokens when converting from OpenAI to Anthropic and no max_tokens is specified
|
||||
const DEFAULT_MAX_TOKENS: u32 = 4096;
|
||||
|
||||
// ============================================================================
|
||||
// UTILITY TRAITS - Shared traits for content manipulation
|
||||
// ============================================================================
|
||||
|
|
@ -68,10 +75,13 @@ trait ContentUtils<T> {
|
|||
// MAIN REQUEST TRANSFORMATIONS
|
||||
// ============================================================================
|
||||
|
||||
impl TryFrom<MessagesRequest> for ChatCompletionsRequest {
|
||||
type AnthropicMessagesRequest = MessagesRequest;
|
||||
|
||||
|
||||
impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(req: MessagesRequest) -> Result<Self, Self::Error> {
|
||||
fn try_from(req: AnthropicMessagesRequest) -> Result<Self, Self::Error> {
|
||||
let mut openai_messages: Vec<Message> = Vec::new();
|
||||
|
||||
// Convert system prompt to system message if present
|
||||
|
|
@ -95,34 +105,17 @@ impl TryFrom<MessagesRequest> for ChatCompletionsRequest {
|
|||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
max_tokens: Some(req.max_tokens),
|
||||
max_completion_tokens: None,
|
||||
stream: req.stream,
|
||||
stream_options: None,
|
||||
stop: req.stop_sequences,
|
||||
tools: openai_tools,
|
||||
tool_choice: openai_tool_choice,
|
||||
parallel_tool_calls,
|
||||
user: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
logit_bias: None,
|
||||
logprobs: None,
|
||||
top_logprobs: None,
|
||||
n: None,
|
||||
seed: None,
|
||||
response_format: None,
|
||||
service_tier: None,
|
||||
store: None,
|
||||
metadata: None,
|
||||
modalities: None,
|
||||
function_call: None,
|
||||
functions: None,
|
||||
prediction: None,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ChatCompletionsRequest> for MessagesRequest {
|
||||
impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(req: ChatCompletionsRequest) -> Result<Self, Self::Error> {
|
||||
|
|
@ -145,11 +138,11 @@ impl TryFrom<ChatCompletionsRequest> for MessagesRequest {
|
|||
let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools));
|
||||
let anthropic_tool_choice = convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
|
||||
|
||||
Ok(MessagesRequest {
|
||||
Ok(AnthropicMessagesRequest {
|
||||
model: req.model,
|
||||
system: system_prompt,
|
||||
messages,
|
||||
max_tokens: req.max_tokens.unwrap_or(4096),
|
||||
max_tokens: req.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
service_tier: None,
|
||||
|
|
@ -785,9 +778,9 @@ fn convert_anthropic_tool_choice(tool_choice: Option<MessagesToolChoice>) -> (Op
|
|||
match tool_choice {
|
||||
Some(choice) => {
|
||||
let openai_choice = match choice.kind {
|
||||
MessagesToolChoiceType::Auto => ToolChoice::String("auto".to_string()),
|
||||
MessagesToolChoiceType::Any => ToolChoice::String("required".to_string()),
|
||||
MessagesToolChoiceType::None => ToolChoice::String("none".to_string()),
|
||||
MessagesToolChoiceType::Auto => ToolChoice::Type(ToolChoiceType::Auto),
|
||||
MessagesToolChoiceType::Any => ToolChoice::Type(ToolChoiceType::Required),
|
||||
MessagesToolChoiceType::None => ToolChoice::Type(ToolChoiceType::None),
|
||||
MessagesToolChoiceType::Tool => {
|
||||
if let Some(name) = choice.name {
|
||||
ToolChoice::Function {
|
||||
|
|
@ -795,7 +788,7 @@ fn convert_anthropic_tool_choice(tool_choice: Option<MessagesToolChoice>) -> (Op
|
|||
function: FunctionChoice { name },
|
||||
}
|
||||
} else {
|
||||
ToolChoice::String("auto".to_string())
|
||||
ToolChoice::Type(ToolChoiceType::Auto)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
@ -813,27 +806,22 @@ fn convert_openai_tool_choice(
|
|||
) -> Option<MessagesToolChoice> {
|
||||
tool_choice.map(|choice| {
|
||||
match choice {
|
||||
ToolChoice::String(s) => match s.as_str() {
|
||||
"auto" => MessagesToolChoice {
|
||||
ToolChoice::Type(tool_type) => match tool_type {
|
||||
ToolChoiceType::Auto => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Auto,
|
||||
name: None,
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
"required" => MessagesToolChoice {
|
||||
ToolChoiceType::Required => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Any,
|
||||
name: None,
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
"none" => MessagesToolChoice {
|
||||
ToolChoiceType::None => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::None,
|
||||
name: None,
|
||||
disable_parallel_tool_use: None,
|
||||
},
|
||||
_ => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Auto,
|
||||
name: None,
|
||||
disable_parallel_tool_use: parallel_tool_calls.map(|p| !p),
|
||||
},
|
||||
},
|
||||
ToolChoice::Function { function, .. } => MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Tool,
|
||||
|
|
@ -1098,7 +1086,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_anthropic_to_openai_basic_request() {
|
||||
let anthropic_req = MessagesRequest {
|
||||
let anthropic_req = AnthropicMessagesRequest {
|
||||
model: "claude-3-sonnet-20240229".to_string(),
|
||||
system: Some(MessagesSystemPrompt::Single("You are helpful".to_string())),
|
||||
messages: vec![MessagesMessage {
|
||||
|
|
@ -1134,7 +1122,7 @@ mod tests {
|
|||
#[test]
|
||||
fn test_roundtrip_consistency() {
|
||||
// Test that converting back and forth maintains consistency
|
||||
let original_anthropic = MessagesRequest {
|
||||
let original_anthropic = AnthropicMessagesRequest {
|
||||
model: "claude-3-sonnet".to_string(),
|
||||
system: Some(MessagesSystemPrompt::Single("System prompt".to_string())),
|
||||
messages: vec![MessagesMessage {
|
||||
|
|
@ -1158,7 +1146,7 @@ mod tests {
|
|||
|
||||
// Convert to OpenAI and back
|
||||
let openai_req: ChatCompletionsRequest = original_anthropic.clone().try_into().unwrap();
|
||||
let roundtrip_anthropic: MessagesRequest = openai_req.try_into().unwrap();
|
||||
let roundtrip_anthropic: AnthropicMessagesRequest = openai_req.try_into().unwrap();
|
||||
|
||||
// Check key fields are preserved
|
||||
assert_eq!(original_anthropic.model, roundtrip_anthropic.model);
|
||||
|
|
@ -1171,7 +1159,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_tool_choice_auto() {
|
||||
let anthropic_req = MessagesRequest {
|
||||
let anthropic_req = AnthropicMessagesRequest {
|
||||
model: "claude-3".to_string(),
|
||||
system: None,
|
||||
messages: vec![],
|
||||
|
|
@ -1203,8 +1191,8 @@ mod tests {
|
|||
assert!(openai_req.tools.is_some());
|
||||
assert_eq!(openai_req.tools.as_ref().unwrap().len(), 1);
|
||||
|
||||
if let Some(ToolChoice::String(choice)) = openai_req.tool_choice {
|
||||
assert_eq!(choice, "auto");
|
||||
if let Some(ToolChoice::Type(choice)) = openai_req.tool_choice {
|
||||
assert_eq!(choice, ToolChoiceType::Auto);
|
||||
} else {
|
||||
panic!("Expected auto tool choice");
|
||||
}
|
||||
|
|
@ -1212,6 +1200,27 @@ mod tests {
|
|||
assert_eq!(openai_req.parallel_tool_calls, Some(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_max_tokens_used_when_openai_has_none() {
|
||||
// Test that DEFAULT_MAX_TOKENS is used when OpenAI request has no max_tokens
|
||||
let openai_req = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
max_tokens: None, // No max_tokens specified
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let anthropic_req: AnthropicMessagesRequest = openai_req.try_into().unwrap();
|
||||
|
||||
assert_eq!(anthropic_req.max_tokens, DEFAULT_MAX_TOKENS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_message_start_streaming() {
|
||||
let event = MessagesStreamEvent::MessageStart {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue