mirror of
https://github.com/katanemo/plano.git
synced 2026-05-06 22:32:42 +02:00
making Messages.Content optional, and having the upstream LLM fail if the right fields aren't set (#699)
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-342.local>
This commit is contained in:
parent
626f556cc6
commit
cdc1d7cee2
17 changed files with 294 additions and 133 deletions
|
|
@ -402,7 +402,7 @@ async fn handle_agent_chat(
|
|||
// and add it to the conversation history
|
||||
current_messages.push(OpenAIMessage {
|
||||
role: hermesllm::apis::openai::Role::Assistant,
|
||||
content: hermesllm::apis::openai::MessageContent::Text(response_text),
|
||||
content: Some(hermesllm::apis::openai::MessageContent::Text(response_text)),
|
||||
name: Some(agent_name.clone()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
|
|||
|
|
@ -638,7 +638,7 @@ impl ArchFunctionHandler {
|
|||
let system_prompt = self.format_system_prompt(tools)?;
|
||||
processed_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(system_prompt),
|
||||
content: Some(MessageContent::Text(system_prompt)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -649,8 +649,9 @@ impl ArchFunctionHandler {
|
|||
for (idx, message) in messages.iter().enumerate() {
|
||||
let mut role = message.role.clone();
|
||||
let mut content = match &message.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(_) => String::new(),
|
||||
Some(MessageContent::Text(text)) => text.clone(),
|
||||
Some(MessageContent::Parts(_)) => String::new(),
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
// Handle tool calls
|
||||
|
|
@ -675,7 +676,8 @@ impl ArchFunctionHandler {
|
|||
} else {
|
||||
// Get the tool call from previous message
|
||||
if idx > 0 {
|
||||
if let MessageContent::Text(prev_content) = &messages[idx - 1].content {
|
||||
if let Some(MessageContent::Text(prev_content)) = &messages[idx - 1].content
|
||||
{
|
||||
let mut tool_call_msg = prev_content.clone();
|
||||
|
||||
// Strip markdown code blocks
|
||||
|
|
@ -721,7 +723,7 @@ impl ArchFunctionHandler {
|
|||
|
||||
processed_messages.push(Message {
|
||||
role,
|
||||
content: MessageContent::Text(content),
|
||||
content: Some(MessageContent::Text(content)),
|
||||
name: message.name.clone(),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -740,7 +742,7 @@ impl ArchFunctionHandler {
|
|||
// Add extra instruction if provided
|
||||
if let Some(instruction) = extra_instruction {
|
||||
if let Some(last) = processed_messages.last_mut() {
|
||||
if let MessageContent::Text(content) = &mut last.content {
|
||||
if let Some(MessageContent::Text(content)) = &mut last.content {
|
||||
content.push('\n');
|
||||
content.push_str(instruction);
|
||||
}
|
||||
|
|
@ -761,7 +763,7 @@ impl ArchFunctionHandler {
|
|||
// Keep system message if present
|
||||
if let Some(first) = messages.first() {
|
||||
if first.role == Role::System {
|
||||
if let MessageContent::Text(content) = &first.content {
|
||||
if let Some(MessageContent::Text(content)) = &first.content {
|
||||
num_tokens += content.len() / 4; // Approximate 4 chars per token
|
||||
}
|
||||
conversation_idx = 1;
|
||||
|
|
@ -772,7 +774,7 @@ impl ArchFunctionHandler {
|
|||
// Start with message_idx pointing past the end (will be used if no truncation needed)
|
||||
let mut message_idx = messages.len();
|
||||
for i in (conversation_idx..messages.len()).rev() {
|
||||
if let MessageContent::Text(content) = &messages[i].content {
|
||||
if let Some(MessageContent::Text(content)) = &messages[i].content {
|
||||
num_tokens += content.len() / 4;
|
||||
if num_tokens >= max_tokens && messages[i].role == Role::User {
|
||||
// Set message_idx to current position and break
|
||||
|
|
@ -802,7 +804,7 @@ impl ArchFunctionHandler {
|
|||
pub fn prefill_message(&self, mut messages: Vec<Message>, prefill: &str) -> Vec<Message> {
|
||||
messages.push(Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::Text(prefill.to_string()),
|
||||
content: Some(MessageContent::Text(prefill.to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ mod tests {
|
|||
fn create_test_message(role: Role, content: &str) -> Message {
|
||||
Message {
|
||||
role,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -129,7 +129,7 @@ mod tests {
|
|||
let processed_messages = result.unwrap();
|
||||
// With empty filter chain, should return the original messages unchanged
|
||||
assert_eq!(processed_messages.len(), 1);
|
||||
if let MessageContent::Text(content) = &processed_messages[0].content {
|
||||
if let Some(MessageContent::Text(content)) = &processed_messages[0].content {
|
||||
assert_eq!(content, "Hello world!");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
|
|||
|
|
@ -887,7 +887,7 @@ mod tests {
|
|||
fn create_test_message(role: Role, content: &str) -> Message {
|
||||
Message {
|
||||
role,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
|
|||
|
|
@ -95,7 +95,9 @@ pub async fn router_chat_get_upstream_model(
|
|||
.messages
|
||||
.last()
|
||||
.map_or("None".to_string(), |msg| {
|
||||
msg.content.to_string().replace('\n', "\\n")
|
||||
msg.content
|
||||
.as_ref()
|
||||
.map_or("None".to_string(), |c| c.to_string().replace('\n', "\\n"))
|
||||
});
|
||||
|
||||
const MAX_MESSAGE_LENGTH: usize = 50;
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ use std::collections::HashMap;
|
|||
|
||||
use common::configuration::{AgentUsagePreference, OrchestrationPreference};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use hermesllm::transforms::lib::ExtractText;
|
||||
use serde::{ser::Serialize as SerializeTrait, Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
|
|
@ -181,7 +182,9 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
let messages_vec = messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
m.role != Role::System && m.role != Role::Tool && !m.content.to_string().is_empty()
|
||||
m.role != Role::System
|
||||
&& m.role != Role::Tool
|
||||
&& !m.content.extract_text().is_empty()
|
||||
})
|
||||
.collect::<Vec<&Message>>();
|
||||
|
||||
|
|
@ -190,7 +193,7 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||
let message_token_count = message.content.to_string().len() / TOKEN_LENGTH_DIVISOR;
|
||||
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR;
|
||||
token_count += message_token_count;
|
||||
if token_count > self.max_token_length {
|
||||
debug!(
|
||||
|
|
@ -240,7 +243,12 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
.rev()
|
||||
.map(|message| Message {
|
||||
role: message.role.clone(),
|
||||
content: MessageContent::Text(message.content.to_string()),
|
||||
content: Some(MessageContent::Text(
|
||||
message
|
||||
.content
|
||||
.as_ref()
|
||||
.map_or(String::new(), |c| c.to_string()),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -262,7 +270,7 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
ChatCompletionsRequest {
|
||||
model: self.orchestration_model.clone(),
|
||||
messages: vec![Message {
|
||||
content: MessageContent::Text(orchestrator_message),
|
||||
content: Some(MessageContent::Text(orchestrator_message)),
|
||||
role: Role::User,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -539,7 +547,7 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -618,7 +626,7 @@ If no routes are needed, return an empty list for `route`.
|
|||
}]);
|
||||
let req = orchestrator.generate_request(&conversation, &usage_preferences);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -689,7 +697,7 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -761,7 +769,7 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -848,7 +856,7 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -940,7 +948,7 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -1058,7 +1066,7 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
let req: ChatCompletionsRequest = orchestrator.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ use std::collections::HashMap;
|
|||
|
||||
use common::configuration::{ModelUsagePreference, RoutingPreference};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use hermesllm::transforms::lib::ExtractText;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
|
|
@ -78,7 +79,9 @@ impl RouterModel for RouterModelV1 {
|
|||
let messages_vec = messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
m.role != Role::System && m.role != Role::Tool && !m.content.to_string().is_empty()
|
||||
m.role != Role::System
|
||||
&& m.role != Role::Tool
|
||||
&& !m.content.extract_text().is_empty()
|
||||
})
|
||||
.collect::<Vec<&Message>>();
|
||||
|
||||
|
|
@ -87,7 +90,7 @@ impl RouterModel for RouterModelV1 {
|
|||
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||
let message_token_count = message.content.to_string().len() / TOKEN_LENGTH_DIVISOR;
|
||||
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR;
|
||||
token_count += message_token_count;
|
||||
if token_count > self.max_token_length {
|
||||
debug!(
|
||||
|
|
@ -136,7 +139,12 @@ impl RouterModel for RouterModelV1 {
|
|||
Message {
|
||||
role: message.role.clone(),
|
||||
// we can unwrap here because we have already filtered out messages without content
|
||||
content: MessageContent::Text(message.content.to_string()),
|
||||
content: Some(MessageContent::Text(
|
||||
message
|
||||
.content
|
||||
.as_ref()
|
||||
.map_or(String::new(), |c| c.to_string()),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -154,7 +162,7 @@ impl RouterModel for RouterModelV1 {
|
|||
ChatCompletionsRequest {
|
||||
model: self.routing_model.clone(),
|
||||
messages: vec![Message {
|
||||
content: MessageContent::Text(router_message),
|
||||
content: Some(MessageContent::Text(router_message)),
|
||||
role: Role::User,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -344,7 +352,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -409,7 +417,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
}]);
|
||||
let req = router.generate_request(&conversation, &usage_preferences);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -469,7 +477,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -530,7 +538,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -598,7 +606,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -667,7 +675,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
@ -762,7 +770,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let req: ChatCompletionsRequest = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.to_string();
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1122,9 +1122,9 @@ pub struct TextBasedSignalAnalyzer {
|
|||
|
||||
impl TextBasedSignalAnalyzer {
|
||||
/// Extract text content from MessageContent, skipping non-text content
|
||||
fn extract_text(content: &hermesllm::apis::openai::MessageContent) -> Option<String> {
|
||||
fn extract_text(content: &Option<hermesllm::apis::openai::MessageContent>) -> Option<String> {
|
||||
match content {
|
||||
hermesllm::apis::openai::MessageContent::Text(text) => Some(text.clone()),
|
||||
Some(hermesllm::apis::openai::MessageContent::Text(text)) => Some(text.clone()),
|
||||
// Tool calls and other structured content are skipped
|
||||
_ => None,
|
||||
}
|
||||
|
|
@ -1941,12 +1941,13 @@ impl Default for TextBasedSignalAnalyzer {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use hermesllm::apis::openai::MessageContent;
|
||||
use hermesllm::transforms::lib::ExtractText;
|
||||
use std::time::Instant;
|
||||
|
||||
fn create_message(role: Role, content: &str) -> Message {
|
||||
Message {
|
||||
role,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
@ -2130,7 +2131,7 @@ mod tests {
|
|||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, msg)| {
|
||||
let text = msg.content.to_string();
|
||||
let text = msg.content.extract_text();
|
||||
(i, msg.role.clone(), NormalizedMessage::from_text(&text))
|
||||
})
|
||||
.collect()
|
||||
|
|
@ -2532,7 +2533,7 @@ mod tests {
|
|||
|content: &str, tool_id: &str, tool_name: &str, args: &str| -> Message {
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: tool_id.to_string(),
|
||||
|
|
@ -2550,7 +2551,7 @@ mod tests {
|
|||
let create_tool_message = |tool_call_id: &str, content: &str| -> Message {
|
||||
Message {
|
||||
role: Role::Tool,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_call_id.to_string()),
|
||||
|
|
@ -2665,7 +2666,7 @@ mod tests {
|
|||
|content: &str, tool_id: &str, tool_name: &str, args: &str| -> Message {
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: tool_id.to_string(),
|
||||
|
|
@ -2683,7 +2684,7 @@ mod tests {
|
|||
let create_tool_message = |tool_call_id: &str, content: &str| -> Message {
|
||||
Message {
|
||||
role: Role::Tool,
|
||||
content: MessageContent::Text(content.to_string()),
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_call_id.to_string()),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue