add warning if first/last message is not from user

This commit is contained in:
Adil Hafeez 2025-05-27 19:18:39 -07:00
parent a39ef5f215
commit c80a04eb39
No known key found for this signature in database
GPG key ID: 9B18EF7691369645

View file

@@ -3,7 +3,7 @@ use common::{
consts::{SYSTEM_ROLE, USER_ROLE},
};
use serde::{Deserialize, Serialize};
use tracing::debug;
use tracing::{debug, warn};
use super::router_model::{RouterModel, RoutingModelError};
@@ -61,17 +61,12 @@ impl RouterModel for RouterModelV1 {
let messages_vec = messages
.iter()
.filter(|m| m.role != SYSTEM_ROLE)
// .map(|m| {
// let content_json_str = serde_json::to_string(&m.content).unwrap_or_default();
// format!("{}: {}", m.role, content_json_str)
// })
// .collect::<Vec<String>>();
.collect::<Vec<&Message>>();
// Following code is to ensure that the conversation does not exceed max token length
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
let mut selected_messages_list: Vec<&Message> = vec![];
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
let message_token_count = message
.content
@@ -91,41 +86,38 @@ impl RouterModel for RouterModelV1 {
);
if message.role == USER_ROLE {
// If message that exceeds max token length is from user, we need to keep it
selected_messages_list.push(message);
selected_messages_list_reversed.push(message);
}
break;
}
// If we are here, it means that the message is within the max token length
selected_messages_list.push(message);
selected_messages_list_reversed.push(message);
}
if selected_messages_list.is_empty() {
if selected_messages_list_reversed.is_empty() {
debug!(
"RouterModelV1: no messages selected, using the last message in the conversation"
);
if let Some(last_message) = messages_vec.last() {
selected_messages_list.push(last_message);
selected_messages_list_reversed.push(last_message);
}
}
// if selected_messsage_count == 0 {
// debug!("RouterModelV1: most recent message in conversation history exceeds max token length {}, keeping only the last message (even if it exceeds max token length)",
// self.max_token_length);
// messages_vec = messages_vec
// .last()
// .map_or_else(Vec::new, |last_message| vec![last_message.to_string()]);
// } else {
// let skip_messages_count = messages_vec.len() - selected_messsage_count;
// if skip_messages_count > 0 {
// debug!(
// "RouterModelV1: skipping first {} messages from the beginning of the conversation",
// skip_messages_count
// );
// messages_vec = messages_vec.into_iter().skip(skip_messages_count).collect();
// }
// }
// ensure that first and last selected message is from user
if let Some(first_message) = selected_messages_list_reversed.first() {
if first_message.role != USER_ROLE {
warn!("RouterModelV1: last message in the conversation is not from user, this may lead to incorrect routing");
}
}
if let Some(last_message) = selected_messages_list_reversed.last() {
if last_message.role != USER_ROLE {
warn!("RouterModelV1: first message in the conversation is not from user, this may lead to incorrect routing");
}
}
let selected_conversation_list_str = selected_messages_list
// Reverse the selected messages to maintain the conversation order
let selected_conversation_list_str = selected_messages_list_reversed
.iter()
.rev()
.map(|m| {