ensure that recent and last message is from user role

This commit is contained in:
Adil Hafeez 2025-05-27 19:12:45 -07:00
parent a74118238c
commit a39ef5f215
No known key found for this signature in database
GPG key ID: 9B18EF7691369645

View file

@ -1,6 +1,6 @@
use common::{
api::open_ai::{ChatCompletionsRequest, ContentType, Message},
consts::{SYSTEM_ROLE, USER_ROLE}
consts::{SYSTEM_ROLE, USER_ROLE},
};
use serde::{Deserialize, Serialize};
use tracing::debug;
@ -58,21 +58,28 @@ const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for U
impl RouterModel for RouterModelV1 {
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
let mut messages_vec = messages
let messages_vec = messages
.iter()
.filter(|m| m.role != SYSTEM_ROLE)
.map(|m| {
let content_json_str = serde_json::to_string(&m.content).unwrap_or_default();
format!("{}: {}", m.role, content_json_str)
})
.collect::<Vec<String>>();
// .map(|m| {
// let content_json_str = serde_json::to_string(&m.content).unwrap_or_default();
// format!("{}: {}", m.role, content_json_str)
// })
// .collect::<Vec<String>>();
.collect::<Vec<&Message>>();
// Following code is to ensure that the conversation does not exceed max token length
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
let mut selected_messsage_count = 0;
for message in messages_vec.iter().rev() {
let message_token_count = message.len() / TOKEN_LENGTH_DIVISOR;
let mut selected_messages_list: Vec<&Message> = vec![];
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
let message_token_count = message
.content
.as_ref()
.unwrap_or(&ContentType::Text("".to_string()))
.to_string()
.len()
/ TOKEN_LENGTH_DIVISOR;
token_count += message_token_count;
if token_count > self.max_token_length {
debug!(
@ -82,31 +89,57 @@ impl RouterModel for RouterModelV1 {
, selected_messsage_count,
messages_vec.len()
);
if message.role == USER_ROLE {
// If message that exceeds max token length is from user, we need to keep it
selected_messages_list.push(message);
}
break;
}
selected_messsage_count += 1;
// If we are here, it means that the message is within the max token length
selected_messages_list.push(message);
}
if selected_messsage_count == 0 {
debug!("RouterModelV1: most recent message in conversation history exceeds max token length {}, keeping only the last message (even if it exceeds max token length)",
self.max_token_length);
messages_vec = messages_vec
.last()
.map_or_else(Vec::new, |last_message| vec![last_message.to_string()]);
} else {
let skip_messages_count = messages_vec.len() - selected_messsage_count;
if skip_messages_count > 0 {
debug!(
"RouterModelV1: skipping first {} messages from the beginning of the conversation",
skip_messages_count
);
messages_vec = messages_vec.into_iter().skip(skip_messages_count).collect();
if selected_messages_list.is_empty() {
debug!(
"RouterModelV1: no messages selected, using the last message in the conversation"
);
if let Some(last_message) = messages_vec.last() {
selected_messages_list.push(last_message);
}
}
// if selected_messsage_count == 0 {
// debug!("RouterModelV1: most recent message in conversation history exceeds max token length {}, keeping only the last message (even if it exceeds max token length)",
// self.max_token_length);
// messages_vec = messages_vec
// .last()
// .map_or_else(Vec::new, |last_message| vec![last_message.to_string()]);
// } else {
// let skip_messages_count = messages_vec.len() - selected_messsage_count;
// if skip_messages_count > 0 {
// debug!(
// "RouterModelV1: skipping first {} messages from the beginning of the conversation",
// skip_messages_count
// );
// messages_vec = messages_vec.into_iter().skip(skip_messages_count).collect();
// }
// }
let selected_conversation_list_str = selected_messages_list
.iter()
.rev()
.map(|m| {
let content_json_str = serde_json::to_string(&m.content).unwrap_or_default();
format!("{}: {}", m.role, content_json_str)
})
.collect::<Vec<String>>();
let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", &self.llm_providers_with_usage_yaml)
.replace("{conversation}", messages_vec.join("\n").as_str());
.replace(
"{conversation}",
selected_conversation_list_str.join("\n").as_str(),
);
ChatCompletionsRequest {
model: self.routing_model.clone(),
@ -212,11 +245,7 @@ user: "seattle"
let routes_yaml = "route1: description1\nroute2: description2";
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(
routes_yaml.to_string(),
routing_model.clone(),
usize::MAX,
);
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), usize::MAX);
let messages = vec![
Message {
@ -283,11 +312,7 @@ user: "seattle"
let routes_yaml = "route1: description1\nroute2: description2";
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(
routes_yaml.to_string(),
routing_model.clone(),
225
);
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), 223);
let messages = vec![
Message {
@ -360,11 +385,7 @@ user: "Seatte, WA. But I also need to know about the weather there, and if there
let routes_yaml = "route1: description1\nroute2: description2";
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(
routes_yaml.to_string(),
routing_model.clone(),
210,
);
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), 210);
let messages = vec![
Message {
@ -413,6 +434,81 @@ user: "Seatte, WA. But I also need to know about the weather there, and if there
assert_eq!(expected_prompt, prompt.to_string());
}
#[test]
fn test_conversation_trim_upto_user_message() {
let _tracer = init_tracer();
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
route1: description1
route2: description2
</routes>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant, response with empty route {"route": ""}.
2. If the user request is full fill and user thank or ending the conversation , response with empty route {"route": ""}.
3. Understand user latest intent and find the best match route in <routes></routes> xml tags.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
<conversation>
user: "I want to book a flight."
assistant: "Sure, where would you like to go?"
user: "seattle"
</conversation>
"#;
let routes_yaml = "route1: description1\nroute2: description2";
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), 220);
let messages = vec![
Message {
role: "system".to_string(),
content: Some(ContentType::Text(
"You are a helpful assistant.".to_string(),
)),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some(ContentType::Text("Hi".to_string())),
..Default::default()
},
Message {
role: "assistant".to_string(),
content: Some(ContentType::Text("Hello! How can I assist you".to_string())),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some(ContentType::Text("I want to book a flight.".to_string())),
..Default::default()
},
Message {
role: "assistant".to_string(),
content: Some(ContentType::Text(
"Sure, where would you like to go?".to_string(),
)),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some(ContentType::Text("seattle".to_string())),
..Default::default()
},
];
let req = router.generate_request(&messages);
let prompt = req.messages[0].content.as_ref().unwrap();
assert_eq!(expected_prompt, prompt.to_string());
}
#[test]
fn test_parse_response() {
let router = RouterModelV1::new(