diff --git a/crates/brightstaff/src/handlers/agent_chat_completions.rs b/crates/brightstaff/src/handlers/agent_chat_completions.rs index bac9607b..5ced34c0 100644 --- a/crates/brightstaff/src/handlers/agent_chat_completions.rs +++ b/crates/brightstaff/src/handlers/agent_chat_completions.rs @@ -402,7 +402,7 @@ async fn handle_agent_chat( // and add it to the conversation history current_messages.push(OpenAIMessage { role: hermesllm::apis::openai::Role::Assistant, - content: hermesllm::apis::openai::MessageContent::Text(response_text), + content: Some(hermesllm::apis::openai::MessageContent::Text(response_text)), name: Some(agent_name.clone()), tool_calls: None, tool_call_id: None, diff --git a/crates/brightstaff/src/handlers/function_calling.rs b/crates/brightstaff/src/handlers/function_calling.rs index 8f641df6..7ba15e2d 100644 --- a/crates/brightstaff/src/handlers/function_calling.rs +++ b/crates/brightstaff/src/handlers/function_calling.rs @@ -638,7 +638,7 @@ impl ArchFunctionHandler { let system_prompt = self.format_system_prompt(tools)?; processed_messages.push(Message { role: Role::System, - content: MessageContent::Text(system_prompt), + content: Some(MessageContent::Text(system_prompt)), name: None, tool_calls: None, tool_call_id: None, @@ -649,8 +649,9 @@ impl ArchFunctionHandler { for (idx, message) in messages.iter().enumerate() { let mut role = message.role.clone(); let mut content = match &message.content { - MessageContent::Text(text) => text.clone(), - MessageContent::Parts(_) => String::new(), + Some(MessageContent::Text(text)) => text.clone(), + Some(MessageContent::Parts(_)) => String::new(), + None => String::new(), }; // Handle tool calls @@ -675,7 +676,8 @@ impl ArchFunctionHandler { } else { // Get the tool call from previous message if idx > 0 { - if let MessageContent::Text(prev_content) = &messages[idx - 1].content { + if let Some(MessageContent::Text(prev_content)) = &messages[idx - 1].content + { let mut tool_call_msg = prev_content.clone(); // Strip markdown code blocks @@ -721,7 +723,7 @@ impl ArchFunctionHandler { processed_messages.push(Message { role, - content: MessageContent::Text(content), + content: Some(MessageContent::Text(content)), name: message.name.clone(), tool_calls: None, tool_call_id: None, @@ -740,7 +742,7 @@ impl ArchFunctionHandler { // Add extra instruction if provided if let Some(instruction) = extra_instruction { if let Some(last) = processed_messages.last_mut() { - if let MessageContent::Text(content) = &mut last.content { + if let Some(MessageContent::Text(content)) = &mut last.content { content.push('\n'); content.push_str(instruction); } @@ -761,7 +763,7 @@ impl ArchFunctionHandler { // Keep system message if present if let Some(first) = messages.first() { if first.role == Role::System { - if let MessageContent::Text(content) = &first.content { + if let Some(MessageContent::Text(content)) = &first.content { num_tokens += content.len() / 4; // Approximate 4 chars per token } conversation_idx = 1; @@ -772,7 +774,7 @@ impl ArchFunctionHandler { // Start with message_idx pointing past the end (will be used if no truncation needed) let mut message_idx = messages.len(); for i in (conversation_idx..messages.len()).rev() { - if let MessageContent::Text(content) = &messages[i].content { + if let Some(MessageContent::Text(content)) = &messages[i].content { num_tokens += content.len() / 4; if num_tokens >= max_tokens && messages[i].role == Role::User { // Set message_idx to current position and break @@ -802,7 +804,7 @@ impl ArchFunctionHandler { 
pub fn prefill_message(&self, mut messages: Vec, prefill: &str) -> Vec { messages.push(Message { role: Role::Assistant, - content: MessageContent::Text(prefill.to_string()), + content: Some(MessageContent::Text(prefill.to_string())), name: None, tool_calls: None, tool_call_id: None, diff --git a/crates/brightstaff/src/handlers/integration_tests.rs b/crates/brightstaff/src/handlers/integration_tests.rs index 29552f83..9239f94a 100644 --- a/crates/brightstaff/src/handlers/integration_tests.rs +++ b/crates/brightstaff/src/handlers/integration_tests.rs @@ -28,7 +28,7 @@ mod tests { fn create_test_message(role: Role, content: &str) -> Message { Message { role, - content: MessageContent::Text(content.to_string()), + content: Some(MessageContent::Text(content.to_string())), name: None, tool_calls: None, tool_call_id: None, @@ -129,7 +129,7 @@ mod tests { let processed_messages = result.unwrap(); // With empty filter chain, should return the original messages unchanged assert_eq!(processed_messages.len(), 1); - if let MessageContent::Text(content) = &processed_messages[0].content { + if let Some(MessageContent::Text(content)) = &processed_messages[0].content { assert_eq!(content, "Hello world!"); } else { panic!("Expected text content"); diff --git a/crates/brightstaff/src/handlers/pipeline_processor.rs b/crates/brightstaff/src/handlers/pipeline_processor.rs index 09520617..bc36de01 100644 --- a/crates/brightstaff/src/handlers/pipeline_processor.rs +++ b/crates/brightstaff/src/handlers/pipeline_processor.rs @@ -887,7 +887,7 @@ mod tests { fn create_test_message(role: Role, content: &str) -> Message { Message { role, - content: MessageContent::Text(content.to_string()), + content: Some(MessageContent::Text(content.to_string())), name: None, tool_calls: None, tool_call_id: None, diff --git a/crates/brightstaff/src/handlers/router_chat.rs b/crates/brightstaff/src/handlers/router_chat.rs index 67e25338..701e8e51 100644 --- a/crates/brightstaff/src/handlers/router_chat.rs +++ b/crates/brightstaff/src/handlers/router_chat.rs @@ -95,7 +95,9 @@ pub async fn router_chat_get_upstream_model( .messages .last() .map_or("None".to_string(), |msg| { - msg.content.to_string().replace('\n', "\\n") + msg.content + .as_ref() + .map_or("None".to_string(), |c| c.to_string().replace('\n', "\\n")) }); const MAX_MESSAGE_LENGTH: usize = 50; diff --git a/crates/brightstaff/src/router/orchestrator_model_v1.rs b/crates/brightstaff/src/router/orchestrator_model_v1.rs index ef32db83..8d64f8e7 100644 --- a/crates/brightstaff/src/router/orchestrator_model_v1.rs +++ b/crates/brightstaff/src/router/orchestrator_model_v1.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use common::configuration::{AgentUsagePreference, OrchestrationPreference}; use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role}; +use hermesllm::transforms::lib::ExtractText; use serde::{ser::Serialize as SerializeTrait, Deserialize, Serialize}; use tracing::{debug, warn}; @@ -181,7 +182,9 @@ impl OrchestratorModel for OrchestratorModelV1 { let messages_vec = messages .iter() .filter(|m| { - m.role != Role::System && m.role != Role::Tool && !m.content.to_string().is_empty() + m.role != Role::System + && m.role != Role::Tool + && !m.content.extract_text().is_empty() }) .collect::>(); @@ -190,7 +193,7 @@ impl OrchestratorModel for OrchestratorModelV1 { let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR; let mut selected_messages_list_reversed: Vec<&Message> = vec![]; for (selected_messsage_count, 
message) in messages_vec.iter().rev().enumerate() { - let message_token_count = message.content.to_string().len() / TOKEN_LENGTH_DIVISOR; + let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR; token_count += message_token_count; if token_count > self.max_token_length { debug!( @@ -240,7 +243,12 @@ impl OrchestratorModel for OrchestratorModelV1 { .rev() .map(|message| Message { role: message.role.clone(), - content: MessageContent::Text(message.content.to_string()), + content: Some(MessageContent::Text( + message + .content + .as_ref() + .map_or(String::new(), |c| c.to_string()), + )), name: None, tool_calls: None, tool_call_id: None, @@ -262,7 +270,7 @@ impl OrchestratorModel for OrchestratorModelV1 { ChatCompletionsRequest { model: self.orchestration_model.clone(), messages: vec![Message { - content: MessageContent::Text(orchestrator_message), + content: Some(MessageContent::Text(orchestrator_message)), role: Role::User, name: None, tool_calls: None, @@ -539,7 +547,7 @@ If no routes are needed, return an empty list for `route`. let req = orchestrator.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -618,7 +626,7 @@ If no routes are needed, return an empty list for `route`. }]); let req = orchestrator.generate_request(&conversation, &usage_preferences); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -689,7 +697,7 @@ If no routes are needed, return an empty list for `route`. let req = orchestrator.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -761,7 +769,7 @@ If no routes are needed, return an empty list for `route`. let req = orchestrator.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -848,7 +856,7 @@ If no routes are needed, return an empty list for `route`. let req = orchestrator.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -940,7 +948,7 @@ If no routes are needed, return an empty list for `route`. let req = orchestrator.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -1058,7 +1066,7 @@ If no routes are needed, return an empty list for `route`. 
let req: ChatCompletionsRequest = orchestrator.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index 84680928..796dfaac 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use common::configuration::{ModelUsagePreference, RoutingPreference}; use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role}; +use hermesllm::transforms::lib::ExtractText; use serde::{Deserialize, Serialize}; use tracing::{debug, warn}; @@ -78,7 +79,9 @@ impl RouterModel for RouterModelV1 { let messages_vec = messages .iter() .filter(|m| { - m.role != Role::System && m.role != Role::Tool && !m.content.to_string().is_empty() + m.role != Role::System + && m.role != Role::Tool + && !m.content.extract_text().is_empty() }) .collect::>(); @@ -87,7 +90,7 @@ impl RouterModel for RouterModelV1 { let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR; let mut selected_messages_list_reversed: Vec<&Message> = vec![]; for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() { - let message_token_count = message.content.to_string().len() / TOKEN_LENGTH_DIVISOR; + let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR; token_count += message_token_count; if token_count > self.max_token_length { debug!( @@ -136,7 +139,12 @@ impl RouterModel for RouterModelV1 { Message { role: message.role.clone(), // we can unwrap here because we have already filtered out messages without content - content: MessageContent::Text(message.content.to_string()), + content: Some(MessageContent::Text( + message + .content + .as_ref() + .map_or(String::new(), |c| c.to_string()), + )), name: None, tool_calls: None, tool_call_id: None, @@ -154,7 +162,7 @@ impl RouterModel for RouterModelV1 { ChatCompletionsRequest { model: self.routing_model.clone(), messages: vec![Message { - content: MessageContent::Text(router_message), + content: Some(MessageContent::Text(router_message)), role: Role::User, name: None, tool_calls: None, @@ -344,7 +352,7 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -409,7 +417,7 @@ Based on your analysis, provide your response in the following JSON formats if y }]); let req = router.generate_request(&conversation, &usage_preferences); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -469,7 +477,7 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -530,7 +538,7 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); 
assert_eq!(expected_prompt, prompt); } @@ -598,7 +606,7 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -667,7 +675,7 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } @@ -762,7 +770,7 @@ Based on your analysis, provide your response in the following JSON formats if y let req: ChatCompletionsRequest = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.to_string(); + let prompt = req.messages[0].content.extract_text(); assert_eq!(expected_prompt, prompt); } diff --git a/crates/brightstaff/src/signals/analyzer.rs b/crates/brightstaff/src/signals/analyzer.rs index 9880bf2c..5ee3c7d9 100644 --- a/crates/brightstaff/src/signals/analyzer.rs +++ b/crates/brightstaff/src/signals/analyzer.rs @@ -1122,9 +1122,9 @@ pub struct TextBasedSignalAnalyzer { impl TextBasedSignalAnalyzer { /// Extract text content from MessageContent, skipping non-text content - fn extract_text(content: &hermesllm::apis::openai::MessageContent) -> Option { + fn extract_text(content: &Option) -> Option { match content { - hermesllm::apis::openai::MessageContent::Text(text) => Some(text.clone()), + Some(hermesllm::apis::openai::MessageContent::Text(text)) => Some(text.clone()), // Tool calls and other structured content are skipped _ => None, } @@ -1941,12 +1941,13 @@ impl Default for TextBasedSignalAnalyzer { mod tests { use super::*; use hermesllm::apis::openai::MessageContent; + use hermesllm::transforms::lib::ExtractText; use std::time::Instant; fn create_message(role: Role, content: &str) -> Message { Message { role, - content: MessageContent::Text(content.to_string()), + content: Some(MessageContent::Text(content.to_string())), name: None, tool_calls: None, tool_call_id: None, @@ -2130,7 +2131,7 @@ mod tests { .iter() .enumerate() .map(|(i, msg)| { - let text = msg.content.to_string(); + let text = msg.content.extract_text(); (i, msg.role.clone(), NormalizedMessage::from_text(&text)) }) .collect() @@ -2532,7 +2533,7 @@ mod tests { |content: &str, tool_id: &str, tool_name: &str, args: &str| -> Message { Message { role: Role::Assistant, - content: MessageContent::Text(content.to_string()), + content: Some(MessageContent::Text(content.to_string())), name: None, tool_calls: Some(vec![ToolCall { id: tool_id.to_string(), @@ -2550,7 +2551,7 @@ mod tests { let create_tool_message = |tool_call_id: &str, content: &str| -> Message { Message { role: Role::Tool, - content: MessageContent::Text(content.to_string()), + content: Some(MessageContent::Text(content.to_string())), name: None, tool_calls: None, tool_call_id: Some(tool_call_id.to_string()), @@ -2665,7 +2666,7 @@ mod tests { |content: &str, tool_id: &str, tool_name: &str, args: &str| -> Message { Message { role: Role::Assistant, - content: MessageContent::Text(content.to_string()), + content: Some(MessageContent::Text(content.to_string())), name: None, tool_calls: Some(vec![ToolCall { id: tool_id.to_string(), @@ -2683,7 +2684,7 @@ mod tests { let create_tool_message = |tool_call_id: &str, content: &str| -> Message { Message { role: Role::Tool, - content: 
MessageContent::Text(content.to_string()), + content: Some(MessageContent::Text(content.to_string())), name: None, tool_calls: None, tool_call_id: Some(tool_call_id.to_string()), diff --git a/crates/hermesllm/src/apis/amazon_bedrock.rs b/crates/hermesllm/src/apis/amazon_bedrock.rs index dbada283..e2cbc201 100644 --- a/crates/hermesllm/src/apis/amazon_bedrock.rs +++ b/crates/hermesllm/src/apis/amazon_bedrock.rs @@ -225,7 +225,7 @@ impl ProviderRequest for ConverseRequest { if let SystemContentBlock::Text { text } = sys_block { openai_messages.push(Message { role: Role::System, - content: MessageContent::Text(text.clone()), + content: Some(MessageContent::Text(text.clone())), name: None, tool_calls: None, tool_call_id: None, @@ -258,7 +258,7 @@ impl ProviderRequest for ConverseRequest { openai_messages.push(Message { role, - content: MessageContent::Text(content), + content: Some(MessageContent::Text(content)), name: None, tool_calls: None, tool_call_id: None, @@ -279,7 +279,7 @@ impl ProviderRequest for ConverseRequest { for msg in messages { match msg.role { crate::apis::openai::Role::System => { - if let crate::apis::openai::MessageContent::Text(text) = &msg.content { + if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content { system_blocks.push(SystemContentBlock::Text { text: text.clone() }); } } @@ -290,12 +290,13 @@ impl ProviderRequest for ConverseRequest { _ => continue, }; - let content = - if let crate::apis::openai::MessageContent::Text(text) = &msg.content { - vec![ContentBlock::Text { text: text.clone() }] - } else { - vec![] - }; + let content = if let Some(crate::apis::openai::MessageContent::Text(text)) = + &msg.content + { + vec![ContentBlock::Text { text: text.clone() }] + } else { + vec![] + }; bedrock_messages.push(crate::apis::amazon_bedrock::Message { role, content }); } diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs index ed3317ce..6e53e6db 100644 --- a/crates/hermesllm/src/apis/anthropic.rs +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -584,7 +584,7 @@ impl ProviderRequest for MessagesRequest { let system_text = system_messages .iter() .filter_map(|msg| { - if let crate::apis::openai::MessageContent::Text(text) = &msg.content { + if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content { Some(text.as_str()) } else { None diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 834c33ec..cd4e7d0b 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -155,7 +155,8 @@ pub enum Role { #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Message { pub role: Role, - pub content: MessageContent, + /// The contents of the message. 
Required unless tool_calls is specified (for assistant role) + pub content: Option, pub name: Option, /// Tool calls made by the assistant (only present for assistant role) pub tool_calls: Option>, @@ -204,8 +205,7 @@ impl ResponseMessage { content: self .content .as_ref() - .map(|s| MessageContent::Text(s.clone())) - .unwrap_or(MessageContent::Text(String::new())), + .map(|s| MessageContent::Text(s.clone())), name: None, // Response messages don't have names in the same way request messages do tool_calls: self.tool_calls.clone(), tool_call_id: None, // Response messages don't have tool_call_id @@ -233,6 +233,12 @@ impl ExtractText for MessageContent { } } +impl ExtractText for Option { + fn extract_text(&self) -> String { + self.as_ref().map(|c| c.extract_text()).unwrap_or_default() + } +} + impl ExtractText for Vec { fn extract_text(&self) -> String { self.iter() @@ -247,23 +253,7 @@ impl ExtractText for Vec { impl Display for MessageContent { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MessageContent::Text(text) => write!(f, "{}", text), - MessageContent::Parts(parts) => { - let text_parts: Vec = parts - .iter() - .filter_map(|part| match part { - ContentPart::Text { text } => Some(text.clone()), - ContentPart::ImageUrl { .. } => { - // skip image URLs or their data in text representation - None - } - }) - .collect(); - let combined_text = text_parts.join("\n"); - write!(f, "{}", combined_text) - } - } + write!(f, "{}", self.extract_text()) } } @@ -622,8 +612,10 @@ impl ProviderRequest for ChatCompletionsRequest { fn extract_messages_text(&self) -> String { self.messages.iter().fold(String::new(), |acc, m| { - acc + " " - + &match &m.content { + let content_text = m + .content + .as_ref() + .map(|content| match content { MessageContent::Text(text) => text.clone(), MessageContent::Parts(parts) => parts .iter() @@ -633,16 +625,18 @@ impl ProviderRequest for ChatCompletionsRequest { }) .collect::>() .join(" "), - } + }) + .unwrap_or_default(); + acc + " " + &content_text }) } fn get_recent_user_message(&self) -> Option { self.messages.last().and_then(|msg| { - match &msg.content { + msg.content.as_ref().and_then(|content| match content { MessageContent::Text(text) => Some(text.clone()), MessageContent::Parts(_) => None, // No user message in parts - } + }) }) } @@ -778,7 +772,8 @@ mod tests { let message = &deserialized_request.messages[0]; assert_eq!(message.role, Role::User); - if let MessageContent::Text(content) = &message.content { + assert!(message.content.is_some()); + if let Some(MessageContent::Text(content)) = &message.content { assert_eq!(content, "Hello, world!"); } else { panic!("Expected text content"); @@ -822,7 +817,8 @@ mod tests { let message = &deserialized_request.messages[0]; assert_eq!(message.role, Role::User); - if let MessageContent::Text(content) = &message.content { + assert!(message.content.is_some()); + if let Some(MessageContent::Text(content)) = &message.content { assert_eq!(content, "Test message"); } else { panic!("Expected text content"); @@ -947,7 +943,8 @@ mod tests { // Validate first message (user with multimodal content) let user_message = &deserialized_request.messages[0]; assert_eq!(user_message.role, Role::User); - if let MessageContent::Parts(ref content_parts) = user_message.content { + assert!(user_message.content.is_some()); + if let Some(MessageContent::Parts(ref content_parts)) = user_message.content { assert_eq!(content_parts.len(), 2); // Validate text content part @@ -971,7 +968,8 @@ mod tests { // 
Validate second message (assistant with tool calls) let assistant_message = &deserialized_request.messages[1]; assert_eq!(assistant_message.role, Role::Assistant); - if let MessageContent::Text(text) = &assistant_message.content { + assert!(assistant_message.content.is_some()); + if let Some(MessageContent::Text(text)) = &assistant_message.content { assert_eq!( text, "I can see a beautiful cityscape. Let me check the weather for you." @@ -997,7 +995,8 @@ mod tests { // Validate third message (tool response) let tool_message = &deserialized_request.messages[2]; assert_eq!(tool_message.role, Role::Tool); - if let MessageContent::Text(text) = &tool_message.content { + assert!(tool_message.content.is_some()); + if let Some(MessageContent::Text(text)) = &tool_message.content { assert_eq!(text, "Current weather in New York: 72°F, sunny"); } else { panic!("Expected text content for tool message"); @@ -1061,6 +1060,62 @@ mod tests { assert!((original_temp - serialized_temp).abs() < 1e-6); } + #[test] + fn test_assistant_message_with_tool_calls_no_content() { + // Test that assistant messages can have tool_calls without content + let json_with_tool_calls_no_content = json!({ + "model": "gpt-4", + "messages": [ + { + "role": "user", + "content": "What's the weather in San Francisco?" + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\": \"San Francisco, CA\"}" + } + } + ] + } + ] + }); + + // Should deserialize successfully + let request: ChatCompletionsRequest = + serde_json::from_value(json_with_tool_calls_no_content.clone()).unwrap(); + + assert_eq!(request.messages.len(), 2); + + // Check user message + let user_msg = &request.messages[0]; + assert_eq!(user_msg.role, Role::User); + assert!(user_msg.content.is_some()); + + // Check assistant message - should have tool_calls but no content + let assistant_msg = &request.messages[1]; + assert_eq!(assistant_msg.role, Role::Assistant); + assert!(assistant_msg.content.is_none()); + assert!(assistant_msg.tool_calls.is_some()); + + let tool_calls = assistant_msg.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id, "call_123"); + assert_eq!(tool_calls[0].function.name, "get_weather"); + + // Should serialize back without content field + let serialized = serde_json::to_value(&request).unwrap(); + // Verify the assistant message doesn't have a content field in serialized JSON + let serialized_assistant_msg = &serialized["messages"][1]; + assert!(serialized_assistant_msg.get("content").is_none()); + assert!(serialized_assistant_msg.get("tool_calls").is_some()); + } + #[test] fn test_api_provider_trait() { // Test the ApiDefinition trait implementation @@ -1097,7 +1152,7 @@ mod tests { let deserialized_user: Message = serde_json::from_value(user_json.clone()).unwrap(); assert_eq!(deserialized_user.role, Role::User); - if let MessageContent::Text(content) = &deserialized_user.content { + if let Some(MessageContent::Text(content)) = &deserialized_user.content { assert_eq!(content, "Hello!"); } else { panic!("Expected text content"); @@ -1128,7 +1183,7 @@ mod tests { let deserialized_assistant: Message = serde_json::from_value(assistant_json.clone()).unwrap(); assert_eq!(deserialized_assistant.role, Role::Assistant); - if let MessageContent::Text(content) = &deserialized_assistant.content { + if let Some(MessageContent::Text(content)) = &deserialized_assistant.content { assert_eq!(content, "I'll help with 
that."); } else { panic!("Expected text content"); @@ -1154,7 +1209,7 @@ mod tests { let deserialized_tool: Message = serde_json::from_value(tool_json.clone()).unwrap(); assert_eq!(deserialized_tool.role, Role::Tool); - if let MessageContent::Text(content) = &deserialized_tool.content { + if let Some(MessageContent::Text(content)) = &deserialized_tool.content { assert_eq!(content, "Weather is sunny"); } else { panic!("Expected text content"); @@ -1193,7 +1248,7 @@ mod tests { // Test conversion from ResponseMessage to Message let converted = deserialized_response.to_message(); assert_eq!(converted.role, Role::Assistant); - if let MessageContent::Text(text) = converted.content { + if let Some(MessageContent::Text(text)) = converted.content { assert_eq!(text, "Response content"); } else { panic!("Expected text content"); diff --git a/crates/hermesllm/src/apis/openai_responses.rs b/crates/hermesllm/src/apis/openai_responses.rs index 720e24d3..dbc82f8b 100644 --- a/crates/hermesllm/src/apis/openai_responses.rs +++ b/crates/hermesllm/src/apis/openai_responses.rs @@ -1146,7 +1146,7 @@ impl ProviderRequest for ResponsesAPIRequest { .iter() .filter(|msg| msg.role == crate::apis::openai::Role::System) .filter_map(|msg| { - if let crate::apis::openai::MessageContent::Text(text) = &msg.content { + if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content { Some(text.as_str()) } else { None @@ -1170,7 +1170,8 @@ impl ProviderRequest for ResponsesAPIRequest { if !input_messages.is_empty() { // If there's only one message, use Text format if input_messages.len() == 1 { - if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content + if let Some(crate::apis::openai::MessageContent::Text(text)) = + &input_messages[0].content { self.input = crate::apis::openai_responses::InputParam::Text(text.clone()); } @@ -1180,7 +1181,8 @@ impl ProviderRequest for ResponsesAPIRequest { let combined_text = input_messages .iter() .filter_map(|msg| { - if let crate::apis::openai::MessageContent::Text(text) = &msg.content { + if let Some(crate::apis::openai::MessageContent::Text(text)) = &msg.content + { Some(format!( "{}: {}", match msg.role { diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index d1d85888..e97e8a68 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -671,14 +671,16 @@ mod tests { messages: vec![ Message { role: Role::System, - content: MessageContent::Text("You are a helpful assistant".to_string()), + content: Some(MessageContent::Text( + "You are a helpful assistant".to_string(), + )), name: None, tool_calls: None, tool_call_id: None, }, Message { role: Role::User, - content: MessageContent::Text("Hello!".to_string()), + content: Some(MessageContent::Text("Hello!".to_string())), name: None, tool_calls: None, tool_call_id: None, @@ -900,7 +902,7 @@ mod tests { model: "gpt-4".to_string(), messages: vec![Message { role: Role::User, - content: MessageContent::Text("Hello!".to_string()), + content: Some(MessageContent::Text("Hello!".to_string())), name: None, tool_calls: None, tool_call_id: None, @@ -993,14 +995,14 @@ mod tests { messages: vec![ Message { role: Role::System, - content: MessageContent::Text("You are helpful".to_string()), + content: Some(MessageContent::Text("You are helpful".to_string())), name: None, tool_calls: None, tool_call_id: None, }, Message { role: Role::User, - content: MessageContent::Text("Hello!".to_string()), + content: 
Some(MessageContent::Text("Hello!".to_string())), name: None, tool_calls: None, tool_call_id: None, diff --git a/crates/hermesllm/src/transforms/lib.rs b/crates/hermesllm/src/transforms/lib.rs index a44f8d79..115f061c 100644 --- a/crates/hermesllm/src/transforms/lib.rs +++ b/crates/hermesllm/src/transforms/lib.rs @@ -188,7 +188,7 @@ pub fn convert_openai_message_to_anthropic_content( // Handle regular content match &message.content { - MessageContent::Text(text) => { + Some(MessageContent::Text(text)) => { if !text.is_empty() { blocks.push(MessagesContentBlock::Text { text: text.clone(), @@ -196,7 +196,7 @@ pub fn convert_openai_message_to_anthropic_content( }); } } - MessageContent::Parts(parts) => { + Some(MessageContent::Parts(parts)) => { for part in parts { match part { ContentPart::Text { text } => { @@ -212,6 +212,7 @@ pub fn convert_openai_message_to_anthropic_content( } } } + None => {} } // Handle tool calls diff --git a/crates/hermesllm/src/transforms/request/from_anthropic.rs b/crates/hermesllm/src/transforms/request/from_anthropic.rs index c07be4e5..82dbe547 100644 --- a/crates/hermesllm/src/transforms/request/from_anthropic.rs +++ b/crates/hermesllm/src/transforms/request/from_anthropic.rs @@ -174,7 +174,7 @@ impl TryFrom for Vec { MessagesMessageContent::Single(text) => { result.push(Message { role: message.role.into(), - content: MessageContent::Text(text), + content: Some(MessageContent::Text(text)), name: None, tool_calls: None, tool_call_id: None, @@ -186,7 +186,7 @@ impl TryFrom for Vec { for (tool_use_id, result_text, _is_error) in tool_results { result.push(Message { role: Role::Tool, - content: MessageContent::Text(result_text), + content: Some(MessageContent::Text(result_text)), name: None, tool_calls: None, tool_call_id: Some(tool_use_id), @@ -260,7 +260,7 @@ impl From for Message { Message { role: Role::System, - content: system_content, + content: Some(system_content), name: None, tool_calls: None, tool_call_id: None, @@ -317,16 +317,19 @@ fn convert_anthropic_tool_choice( fn build_openai_content( content_parts: Vec, tool_calls: &[ToolCall], -) -> MessageContent { - if content_parts.len() == 1 && tool_calls.is_empty() { +) -> Option { + if content_parts.is_empty() && !tool_calls.is_empty() { + // For assistant messages with only tool calls, content is optional + None + } else if content_parts.len() == 1 && tool_calls.is_empty() { match &content_parts[0] { - ContentPart::Text { text } => MessageContent::Text(text.clone()), - _ => MessageContent::Parts(content_parts), + ContentPart::Text { text } => Some(MessageContent::Text(text.clone())), + _ => Some(MessageContent::Parts(content_parts)), } } else if content_parts.is_empty() { - MessageContent::Text("".to_string()) + Some(MessageContent::Text("".to_string())) } else { - MessageContent::Parts(content_parts) + Some(MessageContent::Parts(content_parts)) } } diff --git a/crates/hermesllm/src/transforms/request/from_openai.rs b/crates/hermesllm/src/transforms/request/from_openai.rs index e39cfed3..ddc3b1ca 100644 --- a/crates/hermesllm/src/transforms/request/from_openai.rs +++ b/crates/hermesllm/src/transforms/request/from_openai.rs @@ -18,7 +18,6 @@ use crate::apis::openai_responses::{ ResponsesAPIRequest, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice, }; use crate::clients::TransformError; -use crate::transforms::lib::ExtractText; use crate::transforms::lib::*; use crate::transforms::*; @@ -48,7 +47,7 @@ impl TryFrom for Vec { if let Some(instructions) = converter.instructions { messages.push(Message { 
role: Role::System, - content: MessageContent::Text(instructions), + content: Some(MessageContent::Text(instructions)), name: None, tool_call_id: None, tool_calls: None, @@ -58,7 +57,7 @@ impl TryFrom for Vec { // Add the user message messages.push(Message { role: Role::User, - content: MessageContent::Text(text), + content: Some(MessageContent::Text(text)), name: None, tool_call_id: None, tool_calls: None, @@ -74,7 +73,7 @@ impl TryFrom for Vec { if let Some(instructions) = converter.instructions { converted_messages.push(Message { role: Role::System, - content: MessageContent::Text(instructions), + content: Some(MessageContent::Text(instructions)), name: None, tool_call_id: None, tool_calls: None, @@ -154,7 +153,7 @@ impl TryFrom for Vec { converted_messages.push(Message { role, - content, + content: Some(content), name: None, tool_call_id: None, tool_calls: None, @@ -174,11 +173,7 @@ impl TryFrom for Vec { impl From for MessagesSystemPrompt { fn from(val: Message) -> Self { - let system_text = match val.content { - MessageContent::Text(text) => text, - MessageContent::Parts(parts) => parts.extract_text(), - }; - MessagesSystemPrompt::Single(system_text) + MessagesSystemPrompt::Single(val.content.extract_text()) } } @@ -191,6 +186,8 @@ impl TryFrom for MessagesMessage { Role::Assistant => MessagesRole::Assistant, Role::Tool => { // Tool messages become user messages with tool results + // Extract content text first, before moving tool_call_id + let content_text = message.content.extract_text(); let tool_call_id = message.tool_call_id.ok_or_else(|| { TransformError::MissingField( "tool_call_id required for Tool messages".to_string(), @@ -204,7 +201,7 @@ impl TryFrom for MessagesMessage { tool_use_id: tool_call_id, is_error: None, content: ToolResultContent::Blocks(vec![MessagesContentBlock::Text { - text: message.content.extract_text(), + text: content_text, cache_control: None, }]), cache_control: None, @@ -248,12 +245,12 @@ impl TryFrom for BedrockMessage { Role::User => { // Convert user message content to content blocks match message.content { - MessageContent::Text(text) => { + Some(MessageContent::Text(text)) => { if !text.is_empty() { content_blocks.push(ContentBlock::Text { text }); } } - MessageContent::Parts(parts) => { + Some(MessageContent::Parts(parts)) => { // Convert OpenAI content parts to Bedrock ContentBlocks for part in parts { match part { @@ -293,6 +290,9 @@ impl TryFrom for BedrockMessage { } } } + None => { + // Empty content for user - shouldn't happen but handle gracefully + } } // Ensure we have at least one content block @@ -550,10 +550,7 @@ impl TryFrom for ConverseRequest { for message in req.messages { match message.role { Role::System => { - let system_text = match message.content { - MessageContent::Text(text) => text, - MessageContent::Parts(parts) => parts.extract_text(), - }; + let system_text = message.content.extract_text(); system_messages.push(SystemContentBlock::Text { text: system_text }); } _ => { @@ -778,14 +775,16 @@ mod tests { messages: vec![ Message { role: Role::System, - content: MessageContent::Text("You are a helpful assistant.".to_string()), + content: Some(MessageContent::Text( + "You are a helpful assistant.".to_string(), + )), name: None, tool_call_id: None, tool_calls: None, }, Message { role: Role::User, - content: MessageContent::Text("Hello, how are you?".to_string()), + content: Some(MessageContent::Text("Hello, how are you?".to_string())), name: None, tool_call_id: None, tool_calls: None, @@ -840,7 +839,7 @@ mod tests { model: 
"gpt-4".to_string(), messages: vec![Message { role: Role::User, - content: MessageContent::Text("What's the weather like?".to_string()), + content: Some(MessageContent::Text("What's the weather like?".to_string())), name: None, tool_call_id: None, tool_calls: None, @@ -907,7 +906,7 @@ mod tests { model: "gpt-4".to_string(), messages: vec![Message { role: Role::User, - content: MessageContent::Text("Help me with something".to_string()), + content: Some(MessageContent::Text("Help me with something".to_string())), name: None, tool_call_id: None, tool_calls: None, @@ -950,28 +949,30 @@ mod tests { messages: vec![ Message { role: Role::System, - content: MessageContent::Text("Be concise".to_string()), + content: Some(MessageContent::Text("Be concise".to_string())), name: None, tool_call_id: None, tool_calls: None, }, Message { role: Role::User, - content: MessageContent::Text("Hello".to_string()), + content: Some(MessageContent::Text("Hello".to_string())), name: None, tool_call_id: None, tool_calls: None, }, Message { role: Role::Assistant, - content: MessageContent::Text("Hi there! How can I help you?".to_string()), + content: Some(MessageContent::Text( + "Hi there! How can I help you?".to_string(), + )), name: None, tool_call_id: None, tool_calls: None, }, Message { role: Role::User, - content: MessageContent::Text("What's 2+2?".to_string()), + content: Some(MessageContent::Text("What's 2+2?".to_string())), name: None, tool_call_id: None, tool_calls: None, @@ -1009,7 +1010,7 @@ mod tests { fn test_openai_message_to_bedrock_conversion() { let openai_message = Message { role: Role::User, - content: MessageContent::Text("Test message".to_string()), + content: Some(MessageContent::Text("Test message".to_string())), name: None, tool_call_id: None, tool_calls: None, diff --git a/tests/e2e/test_model_alias_routing.py b/tests/e2e/test_model_alias_routing.py index 7af14df1..f20c24af 100644 --- a/tests/e2e/test_model_alias_routing.py +++ b/tests/e2e/test_model_alias_routing.py @@ -22,6 +22,81 @@ LLM_GATEWAY_ENDPOINT = os.getenv( # ============================================================================= +def test_assistant_message_with_null_content_and_tool_calls(): + """Test that assistant messages with null content and tool_calls are properly handled""" + logger.info( + "Testing assistant message with null content and tool_calls (multi-turn conversation)" + ) + + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI( + api_key="test-key", + base_url=f"{base_url}/v1", + ) + + # Simulate a multi-turn conversation where: + # 1. User asks a question + # 2. Assistant makes a tool call (with null content) + # 3. Tool responds + # 4. Assistant should provide final answer + completion = client.chat.completions.create( + model="gpt-4o", + max_tokens=500, + messages=[ + { + "role": "system", + "content": "You are a weather assistant. 
Use the get_weather tool to fetch weather information.", + }, + {"role": "user", "content": "What's the weather in Seattle?"}, + { + "role": "assistant", + "content": None, # This is the key test - null content with tool_calls + "tool_calls": [ + { + "id": "call_test123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Seattle"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_test123", + "content": '{"location": "Seattle", "temperature": "10°C", "condition": "Partly cloudy"}', + }, + ], + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"} + }, + "required": ["city"], + }, + }, + } + ], + ) + + response_content = completion.choices[0].message.content + logger.info(f"Response after tool call: {response_content}") + + # The assistant should provide a final response using the tool result + assert response_content is not None + assert len(response_content) > 0 + logger.info( + "✓ Assistant message with null content and tool_calls handled correctly" + ) + + def test_openai_client_with_alias_arch_summarize_v1(): """Test OpenAI client using model alias 'arch.summarize.v1' which should resolve to '4o-mini'""" logger.info("Testing OpenAI client with alias 'arch.summarize.v1' -> '4o-mini'")
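Usage sketch for the Option<MessageContent> change above (a minimal example, not part of the patch; it assumes the crate paths shown in this diff and mirrors the new test_assistant_message_with_tool_calls_no_content unit test rather than adding any new API surface):

use hermesllm::apis::openai::{ChatCompletionsRequest, MessageContent, Role};
use hermesllm::transforms::lib::ExtractText;
use serde_json::json;

fn main() -> Result<(), serde_json::Error> {
    // An assistant turn that only carries tool_calls: "content" is absent,
    // which now deserializes to `content: None` instead of failing on a
    // missing required field.
    let request: ChatCompletionsRequest = serde_json::from_value(json!({
        "model": "gpt-4",
        "messages": [
            { "role": "user", "content": "What's the weather in San Francisco?" },
            {
                "role": "assistant",
                "tool_calls": [{
                    "id": "call_123",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": "{\"location\": \"San Francisco, CA\"}"
                    }
                }]
            }
        ]
    }))?;

    let assistant = &request.messages[1];
    assert_eq!(assistant.role, Role::Assistant);
    assert!(assistant.content.is_none());

    // Call sites that previously did `msg.content.to_string()` switch to
    // `extract_text()`, which this diff also implements for
    // Option<MessageContent>: it returns the text for Text/Parts content and
    // an empty string for None.
    assert_eq!(assistant.content.extract_text(), "");

    // Explicit matching still works where the Text/Parts distinction matters.
    if let Some(MessageContent::Text(text)) = &request.messages[0].content {
        assert_eq!(text, "What's the weather in San Francisco?");
    }
    Ok(())
}

This follows the OpenAI Chat Completions contract, where an assistant message may omit content when it carries tool_calls, so such requests are no longer coerced into an empty-string content on the way through.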