mirror of
https://github.com/katanemo/plano.git
synced 2026-06-02 14:35:14 +02:00
making Messages.Content optional, and having the upstream LLM fail if the right fields aren't set (#699)
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-342.local>
This commit is contained in:
parent
626f556cc6
commit
cdc1d7cee2
17 changed files with 294 additions and 133 deletions
|
|
@ -402,7 +402,7 @@ async fn handle_agent_chat(
|
|||
// and add it to the conversation history
|
||||
current_messages.push(OpenAIMessage {
|
||||
role: hermesllm::apis::openai::Role::Assistant,
|
||||
content: hermesllm::apis::openai::MessageContent::Text(response_text),
|
||||
content: Some(hermesllm::apis::openai::MessageContent::Text(response_text)),
|
||||
name: Some(agent_name.clone()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
|
|||
|
|
@ -638,7 +638,7 @@ impl ArchFunctionHandler {
|
|||
let system_prompt = self.format_system_prompt(tools)?;
|
||||
processed_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(system_prompt),
|
||||
content: Some(MessageContent::Text(system_prompt)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -649,8 +649,9 @@ impl ArchFunctionHandler {
|
|||
for (idx, message) in messages.iter().enumerate() {
|
||||
let mut role = message.role.clone();
|
||||
let mut content = match &message.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(_) => String::new(),
|
||||
Some(MessageContent::Text(text)) => text.clone(),
|
||||
Some(MessageContent::Parts(_)) => String::new(),
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
// Handle tool calls
|
||||
|
|
@ -675,7 +676,8 @@ impl ArchFunctionHandler {
|
|||
} else {
|
||||
// Get the tool call from previous message
|
||||
if idx > 0 {
|
||||
if let MessageContent::Text(prev_content) = &messages[idx - 1].content {
|
||||
if let Some(MessageContent::Text(prev_content)) = &messages[idx - 1].content
|
||||
{
|
||||
let mut tool_call_msg = prev_content.clone();
|
||||
|
||||
// Strip markdown code blocks
|
||||
|
|
@ -721,7 +723,7 @@ impl ArchFunctionHandler {
|
|||
|
||||
processed_messages.push(Message {
|
||||
role,
|
||||
content: MessageContent::Text(content),
|
||||
content: Some(MessageContent::Text(content)),
|
||||
name: message.name.clone(),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -740,7 +742,7 @@ impl ArchFunctionHandler {
|
|||
// Add extra instruction if provided
|
||||
if let Some(instruction) = extra_instruction {
|
||||
if let Some(last) = processed_messages.last_mut() {
|
||||
if let MessageContent::Text(content) = &mut last.content {
|
||||
if let Some(MessageContent::Text(content)) = &mut last.content {
|
||||
content.push('\n');
|
||||
content.push_str(instruction);
|
||||
}
|
||||
|
|
@ -761,7 +763,7 @@ impl ArchFunctionHandler {
|
|||
// Keep system message if present
|
||||
if let Some(first) = messages.first() {
|
||||
if first.role == Role::System {
|
||||
if let MessageContent::Text(content) = &first.content {
|
||||
if let Some(MessageContent::Text(content)) = &first.content {
|
||||
num_tokens += content.len() / 4; // Approximate 4 chars per token
|
||||
}
|
||||
conversation_idx = 1;
|
||||
|
|
@ -772,7 +774,7 @@ impl ArchFunctionHandler {
|
|||
// Start with message_idx pointing past the end (will be used if no truncation needed)
|
||||
let mut message_idx = messages.len();
|
||||
for i in (conversation_idx..messages.len()).rev() {
|
||||
if let MessageContent::Text(content) = &messages[i].content {
|
||||
if let Some(MessageContent::Text(content)) = &messages[i].content {
|
||||
num_tokens += content.len() / 4;
|
||||
if num_tokens >= max_tokens && messages[i].role == Role::User {
|
||||
// Set message_idx to current position and break
|
||||
|
|
@ -802,7 +804,7 @@ impl ArchFunctionHandler {
|
|||
pub fn prefill_message(&self, mut messages: Vec<Message>, prefill: &str) -> Vec<Message> {
|
||||
messages.push(Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::Text(prefill.to_string()),
|
||||
content: Some(MessageContent::Text(prefill.to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ mod tests {
|
|||
fn create_test_message(role: Role, content: &str) -> Message {
|
||||
Message {
|
||||
role,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -129,7 +129,7 @@ mod tests {
|
|||
let processed_messages = result.unwrap();
|
||||
// With empty filter chain, should return the original messages unchanged
|
||||
assert_eq!(processed_messages.len(), 1);
|
||||
if let MessageContent::Text(content) = &processed_messages[0].content {
|
||||
if let Some(MessageContent::Text(content)) = &processed_messages[0].content {
|
||||
assert_eq!(content, "Hello world!");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
|
|||
|
|
@ -887,7 +887,7 @@ mod tests {
|
|||
fn create_test_message(role: Role, content: &str) -> Message {
|
||||
Message {
|
||||
role,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
|
|||
|
|
@ -95,7 +95,9 @@ pub async fn router_chat_get_upstream_model(
|
|||
.messages
|
||||
.last()
|
||||
.map_or("None".to_string(), |msg| {
|
||||
msg.content.to_string().replace('\n', "\\n")
|
||||
msg.content
|
||||
.as_ref()
|
||||
.map_or("None".to_string(), |c| c.to_string().replace('\n', "\\n"))
|
||||
});
|
||||
|
||||
const MAX_MESSAGE_LENGTH: usize = 50;
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ use std::collections::HashMap;
|
|||
|
||||
use common::configuration::{AgentUsagePreference, OrchestrationPreference};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use hermesllm::transforms::lib::ExtractText;
|
||||
use serde::{ser::Serialize as SerializeTrait, Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
|
|
@ -181,7 +182,9 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
let messages_vec = messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
m.role != Role::System && m.role != Role::Tool && !m.content.to_string().is_empty()
|
||||
m.role != Role::System
|
||||
&& m.role != Role::Tool
|
||||
&& !m.content.extract_text().is_empty()
|
||||
})
|
||||
.collect::<Vec<&Message>>();
|
||||
|
||||
|
|
@ -190,7 +193,7 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||
let message_token_count = message.content.to_string().len() / TOKEN_LENGTH_DIVISOR;
|
||||
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR;
|
||||
token_count += message_token_count;
|
||||
if token_count > self.max_token_length {
|
||||
debug!(
|
||||
|
|
@ -240,7 +243,12 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
.rev()
|
||||
.map(|message| Message {
|
||||
role: message.role.clone(),
|
||||
content: MessageContent::Text(message.content.to_string()),
|
||||
content: Some(MessageContent::Text(
|
||||
message
|
||||
.content
|
||||
.as_ref()
|
||||
.map_or(String::new(), |c| c.to_string()),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -262,7 +270,7 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
ChatCompletionsRequest {
|
||||
model: self.orchestration_model.clone(),
|
||||
messages: vec![Message {
|
||||
content: MessageContent::Text(orchestrator_message),
|
||||
content: Some(MessageContent::Text(orchestrator_message)),
|
||||
role: Role::User,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -539,7 +547,7 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -618,7 +626,7 @@ If no routes are needed, return an empty list for `route`.
|
|||
}]);
|
||||
let req = orchestrator.generate_request(&conversation, &usage_preferences);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -689,7 +697,7 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -761,7 +769,7 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -848,7 +856,7 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -940,7 +948,7 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -1058,7 +1066,7 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
let req: ChatCompletionsRequest = orchestrator.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ use std::collections::HashMap;
|
|||
|
||||
use common::configuration::{ModelUsagePreference, RoutingPreference};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use hermesllm::transforms::lib::ExtractText;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
|
|
@ -78,7 +79,9 @@ impl RouterModel for RouterModelV1 {
|
|||
let messages_vec = messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
m.role != Role::System && m.role != Role::Tool && !m.content.to_string().is_empty()
|
||||
m.role != Role::System
|
||||
&& m.role != Role::Tool
|
||||
&& !m.content.extract_text().is_empty()
|
||||
})
|
||||
.collect::<Vec<&Message>>();
|
||||
|
||||
|
|
@ -87,7 +90,7 @@ impl RouterModel for RouterModelV1 {
|
|||
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||
let message_token_count = message.content.to_string().len() / TOKEN_LENGTH_DIVISOR;
|
||||
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR;
|
||||
token_count += message_token_count;
|
||||
if token_count > self.max_token_length {
|
||||
debug!(
|
||||
|
|
@ -136,7 +139,12 @@ impl RouterModel for RouterModelV1 {
|
|||
Message {
|
||||
role: message.role.clone(),
|
||||
// we can unwrap here because we have already filtered out messages without content
|
||||
content: MessageContent::Text(message.content.to_string()),
|
||||
content: Some(MessageContent::Text(
|
||||
message
|
||||
.content
|
||||
.as_ref()
|
||||
.map_or(String::new(), |c| c.to_string()),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -154,7 +162,7 @@ impl RouterModel for RouterModelV1 {
|
|||
ChatCompletionsRequest {
|
||||
model: self.routing_model.clone(),
|
||||
messages: vec![Message {
|
||||
content: MessageContent::Text(router_message),
|
||||
content: Some(MessageContent::Text(router_message)),
|
||||
role: Role::User,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -344,7 +352,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -409,7 +417,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
}]);
|
||||
let req = router.generate_request(&conversation, &usage_preferences);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -469,7 +477,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -530,7 +538,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -598,7 +606,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -667,7 +675,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -762,7 +770,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req: ChatCompletionsRequest = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1122,9 +1122,9 @@ pub struct TextBasedSignalAnalyzer {
|
|||
|
||||
impl TextBasedSignalAnalyzer {
|
||||
/// Extract text content from MessageContent, skipping non-text content
|
||||
fn extract_text(content: &hermesllm::apis::openai::MessageContent) -> Option<String> {
|
||||
fn extract_text(content: &Option<hermesllm::apis::openai::MessageContent>) -> Option<String> {
|
||||
match content {
|
||||
hermesllm::apis::openai::MessageContent::Text(text) => Some(text.clone()),
|
||||
Some(hermesllm::apis::openai::MessageContent::Text(text)) => Some(text.clone()),
|
||||
// Tool calls and other structured content are skipped
|
||||
_ => None,
|
||||
}
|
||||
|
|
@ -1941,12 +1941,13 @@ impl Default for TextBasedSignalAnalyzer {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use hermesllm::apis::openai::MessageContent;
|
||||
use hermesllm::transforms::lib::ExtractText;
|
||||
use std::time::Instant;
|
||||
|
||||
fn create_message(role: Role, content: &str) -> Message {
|
||||
Message {
|
||||
role,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -2130,7 +2131,7 @@ mod tests {
|
|||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, msg)| {
|
||||
let text = msg.content.to_string();
|
||||
let text = msg.content.extract_text();
|
||||
(i, msg.role.clone(), NormalizedMessage::from_text(&text))
|
||||
})
|
||||
.collect()
|
||||
|
|
@ -2532,7 +2533,7 @@ mod tests {
|
|||
|content: &str, tool_id: &str, tool_name: &str, args: &str| -> Message {
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: tool_id.to_string(),
|
||||
|
|
@ -2550,7 +2551,7 @@ mod tests {
|
|||
let create_tool_message = |tool_call_id: &str, content: &str| -> Message {
|
||||
Message {
|
||||
role: Role::Tool,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_call_id.to_string()),
|
||||
|
|
@ -2665,7 +2666,7 @@ mod tests {
|
|||
|content: &str, tool_id: &str, tool_name: &str, args: &str| -> Message {
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: tool_id.to_string(),
|
||||
|
|
@ -2683,7 +2684,7 @@ mod tests {
|
|||
let create_tool_message = |tool_call_id: &str, content: &str| -> Message {
|
||||
Message {
|
||||
role: Role::Tool,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_call_id.to_string()),
|
||||
|
|
|
|||
|
|
@ -225,7 +225,7 @@ impl ProviderRequest for ConverseRequest {
|
|||
if let SystemContentBlock::Text { text } = sys_block {
|
||||
openai_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(text.clone()),
|
||||
content: Some(MessageContent::Text(text.clone())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -258,7 +258,7 @@ impl ProviderRequest for ConverseRequest {
|
|||
|
||||
openai_messages.push(Message {
|
||||
role,
|
||||
content: MessageContent::Text(content),
|
||||
content: Some(MessageContent::Text(content)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -279,7 +279,7 @@ impl ProviderRequest for ConverseRequest {
|
|||
for msg in messages {
|
||||
match msg.role {
|
||||
crate::apis::openai::Role::System => {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content {
|
||||
system_blocks.push(SystemContentBlock::Text { text: text.clone() });
|
||||
}
|
||||
}
|
||||
|
|
@ -290,12 +290,13 @@ impl ProviderRequest for ConverseRequest {
|
|||
_ => continue,
|
||||
};
|
||||
|
||||
let content =
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
vec![ContentBlock::Text { text: text.clone() }]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
let content = if let Some(crate::apis::openai::MessageContent::Text(text)) =
|
||||
&msg.content
|
||||
{
|
||||
vec![ContentBlock::Text { text: text.clone() }]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
bedrock_messages.push(crate::apis::amazon_bedrock::Message { role, content });
|
||||
}
|
||||
|
|
|
|||
|
|
@ -584,7 +584,7 @@ impl ProviderRequest for MessagesRequest {
|
|||
let system_text = system_messages
|
||||
.iter()
|
||||
.filter_map(|msg| {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
|
|
|
|||
|
|
@ -155,7 +155,8 @@ pub enum Role {
|
|||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub content: MessageContent,
|
||||
/// The contents of the message. Required unless tool_calls is specified (for assistant role)
|
||||
pub content: Option<MessageContent>,
|
||||
pub name: Option<String>,
|
||||
/// Tool calls made by the assistant (only present for assistant role)
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
|
@ -204,8 +205,7 @@ impl ResponseMessage {
|
|||
content: self
|
||||
.content
|
||||
.as_ref()
|
||||
.map(|s| MessageContent::Text(s.clone()))
|
||||
.unwrap_or(MessageContent::Text(String::new())),
|
||||
.map(|s| MessageContent::Text(s.clone())),
|
||||
name: None, // Response messages don't have names in the same way request messages do
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
tool_call_id: None, // Response messages don't have tool_call_id
|
||||
|
|
@ -233,6 +233,12 @@ impl ExtractText for MessageContent {
|
|||
}
|
||||
}
|
||||
|
||||
impl ExtractText for Option<MessageContent> {
|
||||
fn extract_text(&self) -> String {
|
||||
self.as_ref().map(|c| c.extract_text()).unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtractText for Vec<ContentPart> {
|
||||
fn extract_text(&self) -> String {
|
||||
self.iter()
|
||||
|
|
@ -247,23 +253,7 @@ impl ExtractText for Vec<ContentPart> {
|
|||
|
||||
impl Display for MessageContent {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MessageContent::Text(text) => write!(f, "{}", text),
|
||||
MessageContent::Parts(parts) => {
|
||||
let text_parts: Vec<String> = parts
|
||||
.iter()
|
||||
.filter_map(|part| match part {
|
||||
ContentPart::Text { text } => Some(text.clone()),
|
||||
ContentPart::ImageUrl { .. } => {
|
||||
// skip image URLs or their data in text representation
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let combined_text = text_parts.join("\n");
|
||||
write!(f, "{}", combined_text)
|
||||
}
|
||||
}
|
||||
write!(f, "{}", self.extract_text())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -622,8 +612,10 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
self.messages.iter().fold(String::new(), |acc, m| {
|
||||
acc + " "
|
||||
+ &match &m.content {
|
||||
let content_text = m
|
||||
.content
|
||||
.as_ref()
|
||||
.map(|content| match content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts
|
||||
.iter()
|
||||
|
|
@ -633,16 +625,18 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(" "),
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
acc + " " + &content_text
|
||||
})
|
||||
}
|
||||
|
||||
fn get_recent_user_message(&self) -> Option<String> {
|
||||
self.messages.last().and_then(|msg| {
|
||||
match &msg.content {
|
||||
msg.content.as_ref().and_then(|content| match content {
|
||||
MessageContent::Text(text) => Some(text.clone()),
|
||||
MessageContent::Parts(_) => None, // No user message in parts
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -778,7 +772,8 @@ mod tests {
|
|||
|
||||
let message = &deserialized_request.messages[0];
|
||||
assert_eq!(message.role, Role::User);
|
||||
if let MessageContent::Text(content) = &message.content {
|
||||
assert!(message.content.is_some());
|
||||
if let Some(MessageContent::Text(content)) = &message.content {
|
||||
assert_eq!(content, "Hello, world!");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
@ -822,7 +817,8 @@ mod tests {
|
|||
|
||||
let message = &deserialized_request.messages[0];
|
||||
assert_eq!(message.role, Role::User);
|
||||
if let MessageContent::Text(content) = &message.content {
|
||||
assert!(message.content.is_some());
|
||||
if let Some(MessageContent::Text(content)) = &message.content {
|
||||
assert_eq!(content, "Test message");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
@ -947,7 +943,8 @@ mod tests {
|
|||
// Validate first message (user with multimodal content)
|
||||
let user_message = &deserialized_request.messages[0];
|
||||
assert_eq!(user_message.role, Role::User);
|
||||
if let MessageContent::Parts(ref content_parts) = user_message.content {
|
||||
assert!(user_message.content.is_some());
|
||||
if let Some(MessageContent::Parts(ref content_parts)) = user_message.content {
|
||||
assert_eq!(content_parts.len(), 2);
|
||||
|
||||
// Validate text content part
|
||||
|
|
@ -971,7 +968,8 @@ mod tests {
|
|||
// Validate second message (assistant with tool calls)
|
||||
let assistant_message = &deserialized_request.messages[1];
|
||||
assert_eq!(assistant_message.role, Role::Assistant);
|
||||
if let MessageContent::Text(text) = &assistant_message.content {
|
||||
assert!(assistant_message.content.is_some());
|
||||
if let Some(MessageContent::Text(text)) = &assistant_message.content {
|
||||
assert_eq!(
|
||||
text,
|
||||
"I can see a beautiful cityscape. Let me check the weather for you."
|
||||
|
|
@ -997,7 +995,8 @@ mod tests {
|
|||
// Validate third message (tool response)
|
||||
let tool_message = &deserialized_request.messages[2];
|
||||
assert_eq!(tool_message.role, Role::Tool);
|
||||
if let MessageContent::Text(text) = &tool_message.content {
|
||||
assert!(tool_message.content.is_some());
|
||||
if let Some(MessageContent::Text(text)) = &tool_message.content {
|
||||
assert_eq!(text, "Current weather in New York: 72°F, sunny");
|
||||
} else {
|
||||
panic!("Expected text content for tool message");
|
||||
|
|
@ -1061,6 +1060,62 @@ mod tests {
|
|||
assert!((original_temp - serialized_temp).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_assistant_message_with_tool_calls_no_content() {
|
||||
// Test that assistant messages can have tool_calls without content
|
||||
let json_with_tool_calls_no_content = json!({
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather in San Francisco?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"location\": \"San Francisco, CA\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
// Should deserialize successfully
|
||||
let request: ChatCompletionsRequest =
|
||||
serde_json::from_value(json_with_tool_calls_no_content.clone()).unwrap();
|
||||
|
||||
assert_eq!(request.messages.len(), 2);
|
||||
|
||||
// Check user message
|
||||
let user_msg = &request.messages[0];
|
||||
assert_eq!(user_msg.role, Role::User);
|
||||
assert!(user_msg.content.is_some());
|
||||
|
||||
// Check assistant message - should have tool_calls but no content
|
||||
let assistant_msg = &request.messages[1];
|
||||
assert_eq!(assistant_msg.role, Role::Assistant);
|
||||
assert!(assistant_msg.content.is_none());
|
||||
assert!(assistant_msg.tool_calls.is_some());
|
||||
|
||||
let tool_calls = assistant_msg.tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].id, "call_123");
|
||||
assert_eq!(tool_calls[0].function.name, "get_weather");
|
||||
|
||||
// Should serialize back without content field
|
||||
let serialized = serde_json::to_value(&request).unwrap();
|
||||
// Verify the assistant message doesn't have a content field in serialized JSON
|
||||
let serialized_assistant_msg = &serialized["messages"][1];
|
||||
assert!(serialized_assistant_msg.get("content").is_none());
|
||||
assert!(serialized_assistant_msg.get("tool_calls").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_api_provider_trait() {
|
||||
// Test the ApiDefinition trait implementation
|
||||
|
|
@ -1097,7 +1152,7 @@ mod tests {
|
|||
|
||||
let deserialized_user: Message = serde_json::from_value(user_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_user.role, Role::User);
|
||||
if let MessageContent::Text(content) = &deserialized_user.content {
|
||||
if let Some(MessageContent::Text(content)) = &deserialized_user.content {
|
||||
assert_eq!(content, "Hello!");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
@ -1128,7 +1183,7 @@ mod tests {
|
|||
let deserialized_assistant: Message =
|
||||
serde_json::from_value(assistant_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_assistant.role, Role::Assistant);
|
||||
if let MessageContent::Text(content) = &deserialized_assistant.content {
|
||||
if let Some(MessageContent::Text(content)) = &deserialized_assistant.content {
|
||||
assert_eq!(content, "I'll help with that.");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
@ -1154,7 +1209,7 @@ mod tests {
|
|||
|
||||
let deserialized_tool: Message = serde_json::from_value(tool_json.clone()).unwrap();
|
||||
assert_eq!(deserialized_tool.role, Role::Tool);
|
||||
if let MessageContent::Text(content) = &deserialized_tool.content {
|
||||
if let Some(MessageContent::Text(content)) = &deserialized_tool.content {
|
||||
assert_eq!(content, "Weather is sunny");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
@ -1193,7 +1248,7 @@ mod tests {
|
|||
// Test conversion from ResponseMessage to Message
|
||||
let converted = deserialized_response.to_message();
|
||||
assert_eq!(converted.role, Role::Assistant);
|
||||
if let MessageContent::Text(text) = converted.content {
|
||||
if let Some(MessageContent::Text(text)) = converted.content {
|
||||
assert_eq!(text, "Response content");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
|
|||
|
|
@ -1146,7 +1146,7 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
.iter()
|
||||
.filter(|msg| msg.role == crate::apis::openai::Role::System)
|
||||
.filter_map(|msg| {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
|
|
@ -1170,7 +1170,8 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
if !input_messages.is_empty() {
|
||||
// If there's only one message, use Text format
|
||||
if input_messages.len() == 1 {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content
|
||||
if let Some(crate::apis::openai::MessageContent::Text(text)) =
|
||||
&input_messages[0].content
|
||||
{
|
||||
self.input = crate::apis::openai_responses::InputParam::Text(text.clone());
|
||||
}
|
||||
|
|
@ -1180,7 +1181,8 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
let combined_text = input_messages
|
||||
.iter()
|
||||
.filter_map(|msg| {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content
|
||||
{
|
||||
Some(format!(
|
||||
"{}: {}",
|
||||
match msg.role {
|
||||
|
|
|
|||
|
|
@ -671,14 +671,16 @@ mod tests {
|
|||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text("You are a helpful assistant".to_string()),
|
||||
content: Some(MessageContent::Text(
|
||||
"You are a helpful assistant".to_string(),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello!".to_string()),
|
||||
content: Some(MessageContent::Text("Hello!".to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -900,7 +902,7 @@ mod tests {
|
|||
model: "gpt-4".to_string(),
|
||||
messages: vec![Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello!".to_string()),
|
||||
content: Some(MessageContent::Text("Hello!".to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -993,14 +995,14 @@ mod tests {
|
|||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text("You are helpful".to_string()),
|
||||
content: Some(MessageContent::Text("You are helpful".to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello!".to_string()),
|
||||
content: Some(MessageContent::Text("Hello!".to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
|
|||
|
|
@ -188,7 +188,7 @@ pub fn convert_openai_message_to_anthropic_content(
|
|||
|
||||
// Handle regular content
|
||||
match &message.content {
|
||||
MessageContent::Text(text) => {
|
||||
Some(MessageContent::Text(text)) => {
|
||||
if !text.is_empty() {
|
||||
blocks.push(MessagesContentBlock::Text {
|
||||
text: text.clone(),
|
||||
|
|
@ -196,7 +196,7 @@ pub fn convert_openai_message_to_anthropic_content(
|
|||
});
|
||||
}
|
||||
}
|
||||
MessageContent::Parts(parts) => {
|
||||
Some(MessageContent::Parts(parts)) => {
|
||||
for part in parts {
|
||||
match part {
|
||||
ContentPart::Text { text } => {
|
||||
|
|
@ -212,6 +212,7 @@ pub fn convert_openai_message_to_anthropic_content(
|
|||
}
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
|
|
|
|||
|
|
@ -174,7 +174,7 @@ impl TryFrom<MessagesMessage> for Vec<Message> {
|
|||
MessagesMessageContent::Single(text) => {
|
||||
result.push(Message {
|
||||
role: message.role.into(),
|
||||
content: MessageContent::Text(text),
|
||||
content: Some(MessageContent::Text(text)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -186,7 +186,7 @@ impl TryFrom<MessagesMessage> for Vec<Message> {
|
|||
for (tool_use_id, result_text, _is_error) in tool_results {
|
||||
result.push(Message {
|
||||
role: Role::Tool,
|
||||
content: MessageContent::Text(result_text),
|
||||
content: Some(MessageContent::Text(result_text)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_use_id),
|
||||
|
|
@ -260,7 +260,7 @@ impl From<MessagesSystemPrompt> for Message {
|
|||
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: system_content,
|
||||
content: Some(system_content),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -317,16 +317,19 @@ fn convert_anthropic_tool_choice(
|
|||
fn build_openai_content(
|
||||
content_parts: Vec<ContentPart>,
|
||||
tool_calls: &[ToolCall],
|
||||
) -> MessageContent {
|
||||
if content_parts.len() == 1 && tool_calls.is_empty() {
|
||||
) -> Option<MessageContent> {
|
||||
if content_parts.is_empty() && !tool_calls.is_empty() {
|
||||
// For assistant messages with only tool calls, content is optional
|
||||
None
|
||||
} else if content_parts.len() == 1 && tool_calls.is_empty() {
|
||||
match &content_parts[0] {
|
||||
ContentPart::Text { text } => MessageContent::Text(text.clone()),
|
||||
_ => MessageContent::Parts(content_parts),
|
||||
ContentPart::Text { text } => Some(MessageContent::Text(text.clone())),
|
||||
_ => Some(MessageContent::Parts(content_parts)),
|
||||
}
|
||||
} else if content_parts.is_empty() {
|
||||
MessageContent::Text("".to_string())
|
||||
Some(MessageContent::Text("".to_string()))
|
||||
} else {
|
||||
MessageContent::Parts(content_parts)
|
||||
Some(MessageContent::Parts(content_parts))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ use crate::apis::openai_responses::{
|
|||
ResponsesAPIRequest, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice,
|
||||
};
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::transforms::lib::*;
|
||||
use crate::transforms::*;
|
||||
|
||||
|
|
@ -48,7 +47,7 @@ impl TryFrom<ResponsesInputConverter> for Vec<Message> {
|
|||
if let Some(instructions) = converter.instructions {
|
||||
messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(instructions),
|
||||
content: Some(MessageContent::Text(instructions)),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -58,7 +57,7 @@ impl TryFrom<ResponsesInputConverter> for Vec<Message> {
|
|||
// Add the user message
|
||||
messages.push(Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text(text),
|
||||
content: Some(MessageContent::Text(text)),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -74,7 +73,7 @@ impl TryFrom<ResponsesInputConverter> for Vec<Message> {
|
|||
if let Some(instructions) = converter.instructions {
|
||||
converted_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(instructions),
|
||||
content: Some(MessageContent::Text(instructions)),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -154,7 +153,7 @@ impl TryFrom<ResponsesInputConverter> for Vec<Message> {
|
|||
|
||||
converted_messages.push(Message {
|
||||
role,
|
||||
content,
|
||||
content: Some(content),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -174,11 +173,7 @@ impl TryFrom<ResponsesInputConverter> for Vec<Message> {
|
|||
|
||||
impl From<Message> for MessagesSystemPrompt {
|
||||
fn from(val: Message) -> Self {
|
||||
let system_text = match val.content {
|
||||
MessageContent::Text(text) => text,
|
||||
MessageContent::Parts(parts) => parts.extract_text(),
|
||||
};
|
||||
MessagesSystemPrompt::Single(system_text)
|
||||
MessagesSystemPrompt::Single(val.content.extract_text())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -191,6 +186,8 @@ impl TryFrom<Message> for MessagesMessage {
|
|||
Role::Assistant => MessagesRole::Assistant,
|
||||
Role::Tool => {
|
||||
// Tool messages become user messages with tool results
|
||||
// Extract content text first, before moving tool_call_id
|
||||
let content_text = message.content.extract_text();
|
||||
let tool_call_id = message.tool_call_id.ok_or_else(|| {
|
||||
TransformError::MissingField(
|
||||
"tool_call_id required for Tool messages".to_string(),
|
||||
|
|
@ -204,7 +201,7 @@ impl TryFrom<Message> for MessagesMessage {
|
|||
tool_use_id: tool_call_id,
|
||||
is_error: None,
|
||||
content: ToolResultContent::Blocks(vec![MessagesContentBlock::Text {
|
||||
text: message.content.extract_text(),
|
||||
text: content_text,
|
||||
cache_control: None,
|
||||
}]),
|
||||
cache_control: None,
|
||||
|
|
@ -248,12 +245,12 @@ impl TryFrom<Message> for BedrockMessage {
|
|||
Role::User => {
|
||||
// Convert user message content to content blocks
|
||||
match message.content {
|
||||
MessageContent::Text(text) => {
|
||||
Some(MessageContent::Text(text)) => {
|
||||
if !text.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text { text });
|
||||
}
|
||||
}
|
||||
MessageContent::Parts(parts) => {
|
||||
Some(MessageContent::Parts(parts)) => {
|
||||
// Convert OpenAI content parts to Bedrock ContentBlocks
|
||||
for part in parts {
|
||||
match part {
|
||||
|
|
@ -293,6 +290,9 @@ impl TryFrom<Message> for BedrockMessage {
|
|||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// Empty content for user - shouldn't happen but handle gracefully
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we have at least one content block
|
||||
|
|
@ -550,10 +550,7 @@ impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
|
|||
for message in req.messages {
|
||||
match message.role {
|
||||
Role::System => {
|
||||
let system_text = match message.content {
|
||||
MessageContent::Text(text) => text,
|
||||
MessageContent::Parts(parts) => parts.extract_text(),
|
||||
};
|
||||
let system_text = message.content.extract_text();
|
||||
system_messages.push(SystemContentBlock::Text { text: system_text });
|
||||
}
|
||||
_ => {
|
||||
|
|
@ -778,14 +775,16 @@ mod tests {
|
|||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text("You are a helpful assistant.".to_string()),
|
||||
content: Some(MessageContent::Text(
|
||||
"You are a helpful assistant.".to_string(),
|
||||
)),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello, how are you?".to_string()),
|
||||
content: Some(MessageContent::Text("Hello, how are you?".to_string())),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -840,7 +839,7 @@ mod tests {
|
|||
model: "gpt-4".to_string(),
|
||||
messages: vec![Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("What's the weather like?".to_string()),
|
||||
content: Some(MessageContent::Text("What's the weather like?".to_string())),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -907,7 +906,7 @@ mod tests {
|
|||
model: "gpt-4".to_string(),
|
||||
messages: vec![Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Help me with something".to_string()),
|
||||
content: Some(MessageContent::Text("Help me with something".to_string())),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -950,28 +949,30 @@ mod tests {
|
|||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text("Be concise".to_string()),
|
||||
content: Some(MessageContent::Text("Be concise".to_string())),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello".to_string()),
|
||||
content: Some(MessageContent::Text("Hello".to_string())),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::Text("Hi there! How can I help you?".to_string()),
|
||||
content: Some(MessageContent::Text(
|
||||
"Hi there! How can I help you?".to_string(),
|
||||
)),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("What's 2+2?".to_string()),
|
||||
content: Some(MessageContent::Text("What's 2+2?".to_string())),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -1009,7 +1010,7 @@ mod tests {
|
|||
fn test_openai_message_to_bedrock_conversion() {
|
||||
let openai_message = Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Test message".to_string()),
|
||||
content: Some(MessageContent::Text("Test message".to_string())),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
|
|
|
|||
|
|
@ -22,6 +22,81 @@ LLM_GATEWAY_ENDPOINT = os.getenv(
|
|||
# =============================================================================
|
||||
|
||||
|
||||
def test_assistant_message_with_null_content_and_tool_calls():
|
||||
"""Test that assistant messages with null content and tool_calls are properly handled"""
|
||||
logger.info(
|
||||
"Testing assistant message with null content and tool_calls (multi-turn conversation)"
|
||||
)
|
||||
|
||||
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||
client = openai.OpenAI(
|
||||
api_key="test-key",
|
||||
base_url=f"{base_url}/v1",
|
||||
)
|
||||
|
||||
# Simulate a multi-turn conversation where:
|
||||
# 1. User asks a question
|
||||
# 2. Assistant makes a tool call (with null content)
|
||||
# 3. Tool responds
|
||||
# 4. Assistant should provide final answer
|
||||
completion = client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
max_tokens=500,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a weather assistant. Use the get_weather tool to fetch weather information.",
|
||||
},
|
||||
{"role": "user", "content": "What's the weather in Seattle?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None, # This is the key test - null content with tool_calls
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_test123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Seattle"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_test123",
|
||||
"content": '{"location": "Seattle", "temperature": "10°C", "condition": "Partly cloudy"}',
|
||||
},
|
||||
],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather information for a city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "City name"}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
response_content = completion.choices[0].message.content
|
||||
logger.info(f"Response after tool call: {response_content}")
|
||||
|
||||
# The assistant should provide a final response using the tool result
|
||||
assert response_content is not None
|
||||
assert len(response_content) > 0
|
||||
logger.info(
|
||||
"✓ Assistant message with null content and tool_calls handled correctly"
|
||||
)
|
||||
|
||||
|
||||
def test_openai_client_with_alias_arch_summarize_v1():
|
||||
"""Test OpenAI client using model alias 'arch.summarize.v1' which should resolve to '4o-mini'"""
|
||||
logger.info("Testing OpenAI client with alias 'arch.summarize.v1' -> '4o-mini'")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue