From c80a04eb39da025d8a1fc3257b8cdd43656da774 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 27 May 2025 19:18:39 -0700 Subject: [PATCH] add warning if first/last message is not from user --- .../brightstaff/src/router/router_model_v1.rs | 48 ++++++++----------- 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index 3b88eb5b..cbea39ff 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -3,7 +3,7 @@ use common::{ consts::{SYSTEM_ROLE, USER_ROLE}, }; use serde::{Deserialize, Serialize}; -use tracing::debug; +use tracing::{debug, warn}; use super::router_model::{RouterModel, RoutingModelError}; @@ -61,17 +61,12 @@ impl RouterModel for RouterModelV1 { let messages_vec = messages .iter() .filter(|m| m.role != SYSTEM_ROLE) - // .map(|m| { - // let content_json_str = serde_json::to_string(&m.content).unwrap_or_default(); - // format!("{}: {}", m.role, content_json_str) - // }) - // .collect::>(); .collect::>(); // Following code is to ensure that the conversation does not exceed max token length // Note: we use a simple heuristic to estimate token count based on character length to optimize for performance let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR; - let mut selected_messages_list: Vec<&Message> = vec![]; + let mut selected_messages_list_reversed: Vec<&Message> = vec![]; for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() { let message_token_count = message .content @@ -91,41 +86,38 @@ impl RouterModel for RouterModelV1 { ); if message.role == USER_ROLE { // If message that exceeds max token length is from user, we need to keep it - selected_messages_list.push(message); + selected_messages_list_reversed.push(message); } break; } // If we are here, it means that the message is within the max token length - selected_messages_list.push(message); + selected_messages_list_reversed.push(message); } - if selected_messages_list.is_empty() { + if selected_messages_list_reversed.is_empty() { debug!( "RouterModelV1: no messages selected, using the last message in the conversation" ); if let Some(last_message) = messages_vec.last() { - selected_messages_list.push(last_message); + selected_messages_list_reversed.push(last_message); } } - // if selected_messsage_count == 0 { - // debug!("RouterModelV1: most recent message in conversation history exceeds max token length {}, keeping only the last message (even if it exceeds max token length)", - // self.max_token_length); - // messages_vec = messages_vec - // .last() - // .map_or_else(Vec::new, |last_message| vec![last_message.to_string()]); - // } else { - // let skip_messages_count = messages_vec.len() - selected_messsage_count; - // if skip_messages_count > 0 { - // debug!( - // "RouterModelV1: skipping first {} messages from the beginning of the conversation", - // skip_messages_count - // ); - // messages_vec = messages_vec.into_iter().skip(skip_messages_count).collect(); - // } - // } + // ensure that first and last selected message is from user + if let Some(first_message) = selected_messages_list_reversed.first() { + if first_message.role != USER_ROLE { + warn!("RouterModelV1: last message in the conversation is not from user, this may lead to incorrect routing"); + } + } + if let Some(last_message) = selected_messages_list_reversed.last() { + if last_message.role != USER_ROLE { + warn!("RouterModelV1: first message in the conversation is not from user, this may lead to incorrect routing"); + } + } - let selected_conversation_list_str = selected_messages_list + // Reverse the selected messages to maintain the conversation order + + let selected_conversation_list_str = selected_messages_list_reversed .iter() .rev() .map(|m| {