From a39ef5f215134c0d8e46b9b8db06236682e589cd Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 27 May 2025 19:12:45 -0700 Subject: [PATCH] ensure that recent and last message is from user role --- .../brightstaff/src/router/router_model_v1.rs | 178 ++++++++++++++---- 1 file changed, 137 insertions(+), 41 deletions(-) diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index 7c267252..3b88eb5b 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -1,6 +1,6 @@ use common::{ api::open_ai::{ChatCompletionsRequest, ContentType, Message}, - consts::{SYSTEM_ROLE, USER_ROLE} + consts::{SYSTEM_ROLE, USER_ROLE}, }; use serde::{Deserialize, Serialize}; use tracing::debug; @@ -58,21 +58,28 @@ const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for U impl RouterModel for RouterModelV1 { fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest { - let mut messages_vec = messages + let messages_vec = messages .iter() .filter(|m| m.role != SYSTEM_ROLE) - .map(|m| { - let content_json_str = serde_json::to_string(&m.content).unwrap_or_default(); - format!("{}: {}", m.role, content_json_str) - }) - .collect::>(); + // .map(|m| { + // let content_json_str = serde_json::to_string(&m.content).unwrap_or_default(); + // format!("{}: {}", m.role, content_json_str) + // }) + // .collect::>(); + .collect::>(); // Following code is to ensure that the conversation does not exceed max token length // Note: we use a simple heuristic to estimate token count based on character length to optimize for performance let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR; - let mut selected_messsage_count = 0; - for message in messages_vec.iter().rev() { - let message_token_count = message.len() / TOKEN_LENGTH_DIVISOR; + let mut selected_messages_list: Vec<&Message> = vec![]; + for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() { + let message_token_count = message + .content + .as_ref() + .unwrap_or(&ContentType::Text("".to_string())) + .to_string() + .len() + / TOKEN_LENGTH_DIVISOR; token_count += message_token_count; if token_count > self.max_token_length { debug!( @@ -82,31 +89,57 @@ impl RouterModel for RouterModelV1 { , selected_messsage_count, messages_vec.len() ); + if message.role == USER_ROLE { + // If message that exceeds max token length is from user, we need to keep it + selected_messages_list.push(message); + } break; } - selected_messsage_count += 1; + // If we are here, it means that the message is within the max token length + selected_messages_list.push(message); } - if selected_messsage_count == 0 { - debug!("RouterModelV1: most recent message in conversation history exceeds max token length {}, keeping only the last message (even if it exceeds max token length)", - self.max_token_length); - messages_vec = messages_vec - .last() - .map_or_else(Vec::new, |last_message| vec![last_message.to_string()]); - } else { - let skip_messages_count = messages_vec.len() - selected_messsage_count; - if skip_messages_count > 0 { - debug!( - "RouterModelV1: skipping first {} messages from the beginning of the conversation", - skip_messages_count - ); - messages_vec = messages_vec.into_iter().skip(skip_messages_count).collect(); + if selected_messages_list.is_empty() { + debug!( + "RouterModelV1: no messages selected, using the last message in the conversation" + ); + if let Some(last_message) = messages_vec.last() { + selected_messages_list.push(last_message); } } + // if selected_messsage_count == 0 { + // debug!("RouterModelV1: most recent message in conversation history exceeds max token length {}, keeping only the last message (even if it exceeds max token length)", + // self.max_token_length); + // messages_vec = messages_vec + // .last() + // .map_or_else(Vec::new, |last_message| vec![last_message.to_string()]); + // } else { + // let skip_messages_count = messages_vec.len() - selected_messsage_count; + // if skip_messages_count > 0 { + // debug!( + // "RouterModelV1: skipping first {} messages from the beginning of the conversation", + // skip_messages_count + // ); + // messages_vec = messages_vec.into_iter().skip(skip_messages_count).collect(); + // } + // } + + let selected_conversation_list_str = selected_messages_list + .iter() + .rev() + .map(|m| { + let content_json_str = serde_json::to_string(&m.content).unwrap_or_default(); + format!("{}: {}", m.role, content_json_str) + }) + .collect::>(); + let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT .replace("{routes}", &self.llm_providers_with_usage_yaml) - .replace("{conversation}", messages_vec.join("\n").as_str()); + .replace( + "{conversation}", + selected_conversation_list_str.join("\n").as_str(), + ); ChatCompletionsRequest { model: self.routing_model.clone(), @@ -212,11 +245,7 @@ user: "seattle" let routes_yaml = "route1: description1\nroute2: description2"; let routing_model = "test-model".to_string(); - let router = RouterModelV1::new( - routes_yaml.to_string(), - routing_model.clone(), - usize::MAX, - ); + let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), usize::MAX); let messages = vec![ Message { @@ -283,11 +312,7 @@ user: "seattle" let routes_yaml = "route1: description1\nroute2: description2"; let routing_model = "test-model".to_string(); - let router = RouterModelV1::new( - routes_yaml.to_string(), - routing_model.clone(), - 225 - ); + let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), 223); let messages = vec![ Message { @@ -360,11 +385,7 @@ user: "Seatte, WA. But I also need to know about the weather there, and if there let routes_yaml = "route1: description1\nroute2: description2"; let routing_model = "test-model".to_string(); - let router = RouterModelV1::new( - routes_yaml.to_string(), - routing_model.clone(), - 210, - ); + let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), 210); let messages = vec![ Message { @@ -413,6 +434,81 @@ user: "Seatte, WA. But I also need to know about the weather there, and if there assert_eq!(expected_prompt, prompt.to_string()); } + #[test] + fn test_conversation_trim_upto_user_message() { + let _tracer = init_tracer(); + let expected_prompt = r#" +You are a helpful assistant designed to find the best suited route. +You are provided with route description within XML tags: + +route1: description1 +route2: description2 + + +Your task is to decide which route is best suit with user intent on the conversation in XML tags. Follow the instruction: +1. If the latest intent from user is irrelevant, response with empty route {"route": ""}. +2. If the user request is full fill and user thank or ending the conversation , response with empty route {"route": ""}. +3. Understand user latest intent and find the best match route in xml tags. + +Based on your analysis, provide your response in the following JSON formats if you decide to match any route: +{"route": "route_name"} + + + +user: "I want to book a flight." +assistant: "Sure, where would you like to go?" +user: "seattle" + +"#; + + let routes_yaml = "route1: description1\nroute2: description2"; + let routing_model = "test-model".to_string(); + let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), 220); + + let messages = vec![ + Message { + role: "system".to_string(), + content: Some(ContentType::Text( + "You are a helpful assistant.".to_string(), + )), + ..Default::default() + }, + Message { + role: "user".to_string(), + content: Some(ContentType::Text("Hi".to_string())), + ..Default::default() + }, + Message { + role: "assistant".to_string(), + content: Some(ContentType::Text("Hello! How can I assist you".to_string())), + ..Default::default() + }, + Message { + role: "user".to_string(), + content: Some(ContentType::Text("I want to book a flight.".to_string())), + ..Default::default() + }, + Message { + role: "assistant".to_string(), + content: Some(ContentType::Text( + "Sure, where would you like to go?".to_string(), + )), + ..Default::default() + }, + Message { + role: "user".to_string(), + content: Some(ContentType::Text("seattle".to_string())), + ..Default::default() + }, + ]; + + let req = router.generate_request(&messages); + + let prompt = req.messages[0].content.as_ref().unwrap(); + + assert_eq!(expected_prompt, prompt.to_string()); + } + #[test] fn test_parse_response() { let router = RouterModelV1::new(