mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
ensure that recent and last message is from user role
This commit is contained in:
parent
a74118238c
commit
a39ef5f215
1 changed files with 137 additions and 41 deletions
|
|
@ -1,6 +1,6 @@
|
|||
use common::{
|
||||
api::open_ai::{ChatCompletionsRequest, ContentType, Message},
|
||||
consts::{SYSTEM_ROLE, USER_ROLE}
|
||||
consts::{SYSTEM_ROLE, USER_ROLE},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::debug;
|
||||
|
|
@ -58,21 +58,28 @@ const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for U
|
|||
|
||||
impl RouterModel for RouterModelV1 {
|
||||
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
|
||||
let mut messages_vec = messages
|
||||
let messages_vec = messages
|
||||
.iter()
|
||||
.filter(|m| m.role != SYSTEM_ROLE)
|
||||
.map(|m| {
|
||||
let content_json_str = serde_json::to_string(&m.content).unwrap_or_default();
|
||||
format!("{}: {}", m.role, content_json_str)
|
||||
})
|
||||
.collect::<Vec<String>>();
|
||||
// .map(|m| {
|
||||
// let content_json_str = serde_json::to_string(&m.content).unwrap_or_default();
|
||||
// format!("{}: {}", m.role, content_json_str)
|
||||
// })
|
||||
// .collect::<Vec<String>>();
|
||||
.collect::<Vec<&Message>>();
|
||||
|
||||
// Following code is to ensure that the conversation does not exceed max token length
|
||||
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
|
||||
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messsage_count = 0;
|
||||
for message in messages_vec.iter().rev() {
|
||||
let message_token_count = message.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messages_list: Vec<&Message> = vec![];
|
||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||
let message_token_count = message
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap_or(&ContentType::Text("".to_string()))
|
||||
.to_string()
|
||||
.len()
|
||||
/ TOKEN_LENGTH_DIVISOR;
|
||||
token_count += message_token_count;
|
||||
if token_count > self.max_token_length {
|
||||
debug!(
|
||||
|
|
@ -82,31 +89,57 @@ impl RouterModel for RouterModelV1 {
|
|||
, selected_messsage_count,
|
||||
messages_vec.len()
|
||||
);
|
||||
if message.role == USER_ROLE {
|
||||
// If message that exceeds max token length is from user, we need to keep it
|
||||
selected_messages_list.push(message);
|
||||
}
|
||||
break;
|
||||
}
|
||||
selected_messsage_count += 1;
|
||||
// If we are here, it means that the message is within the max token length
|
||||
selected_messages_list.push(message);
|
||||
}
|
||||
|
||||
if selected_messsage_count == 0 {
|
||||
debug!("RouterModelV1: most recent message in conversation history exceeds max token length {}, keeping only the last message (even if it exceeds max token length)",
|
||||
self.max_token_length);
|
||||
messages_vec = messages_vec
|
||||
.last()
|
||||
.map_or_else(Vec::new, |last_message| vec![last_message.to_string()]);
|
||||
} else {
|
||||
let skip_messages_count = messages_vec.len() - selected_messsage_count;
|
||||
if skip_messages_count > 0 {
|
||||
debug!(
|
||||
"RouterModelV1: skipping first {} messages from the beginning of the conversation",
|
||||
skip_messages_count
|
||||
);
|
||||
messages_vec = messages_vec.into_iter().skip(skip_messages_count).collect();
|
||||
if selected_messages_list.is_empty() {
|
||||
debug!(
|
||||
"RouterModelV1: no messages selected, using the last message in the conversation"
|
||||
);
|
||||
if let Some(last_message) = messages_vec.last() {
|
||||
selected_messages_list.push(last_message);
|
||||
}
|
||||
}
|
||||
|
||||
// if selected_messsage_count == 0 {
|
||||
// debug!("RouterModelV1: most recent message in conversation history exceeds max token length {}, keeping only the last message (even if it exceeds max token length)",
|
||||
// self.max_token_length);
|
||||
// messages_vec = messages_vec
|
||||
// .last()
|
||||
// .map_or_else(Vec::new, |last_message| vec![last_message.to_string()]);
|
||||
// } else {
|
||||
// let skip_messages_count = messages_vec.len() - selected_messsage_count;
|
||||
// if skip_messages_count > 0 {
|
||||
// debug!(
|
||||
// "RouterModelV1: skipping first {} messages from the beginning of the conversation",
|
||||
// skip_messages_count
|
||||
// );
|
||||
// messages_vec = messages_vec.into_iter().skip(skip_messages_count).collect();
|
||||
// }
|
||||
// }
|
||||
|
||||
let selected_conversation_list_str = selected_messages_list
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|m| {
|
||||
let content_json_str = serde_json::to_string(&m.content).unwrap_or_default();
|
||||
format!("{}: {}", m.role, content_json_str)
|
||||
})
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
|
||||
.replace("{routes}", &self.llm_providers_with_usage_yaml)
|
||||
.replace("{conversation}", messages_vec.join("\n").as_str());
|
||||
.replace(
|
||||
"{conversation}",
|
||||
selected_conversation_list_str.join("\n").as_str(),
|
||||
);
|
||||
|
||||
ChatCompletionsRequest {
|
||||
model: self.routing_model.clone(),
|
||||
|
|
@ -212,11 +245,7 @@ user: "seattle"
|
|||
|
||||
let routes_yaml = "route1: description1\nroute2: description2";
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(
|
||||
routes_yaml.to_string(),
|
||||
routing_model.clone(),
|
||||
usize::MAX,
|
||||
);
|
||||
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), usize::MAX);
|
||||
|
||||
let messages = vec![
|
||||
Message {
|
||||
|
|
@ -283,11 +312,7 @@ user: "seattle"
|
|||
|
||||
let routes_yaml = "route1: description1\nroute2: description2";
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(
|
||||
routes_yaml.to_string(),
|
||||
routing_model.clone(),
|
||||
225
|
||||
);
|
||||
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), 223);
|
||||
|
||||
let messages = vec![
|
||||
Message {
|
||||
|
|
@ -360,11 +385,7 @@ user: "Seatte, WA. But I also need to know about the weather there, and if there
|
|||
|
||||
let routes_yaml = "route1: description1\nroute2: description2";
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(
|
||||
routes_yaml.to_string(),
|
||||
routing_model.clone(),
|
||||
210,
|
||||
);
|
||||
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), 210);
|
||||
|
||||
let messages = vec![
|
||||
Message {
|
||||
|
|
@ -413,6 +434,81 @@ user: "Seatte, WA. But I also need to know about the weather there, and if there
|
|||
assert_eq!(expected_prompt, prompt.to_string());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversation_trim_upto_user_message() {
|
||||
let _tracer = init_tracer();
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
route1: description1
|
||||
route2: description2
|
||||
</routes>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant, response with empty route {"route": ""}.
|
||||
2. If the user request is full fill and user thank or ending the conversation , response with empty route {"route": ""}.
|
||||
3. Understand user latest intent and find the best match route in <routes></routes> xml tags.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
|
||||
|
||||
<conversation>
|
||||
user: "I want to book a flight."
|
||||
assistant: "Sure, where would you like to go?"
|
||||
user: "seattle"
|
||||
</conversation>
|
||||
"#;
|
||||
|
||||
let routes_yaml = "route1: description1\nroute2: description2";
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), 220);
|
||||
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "system".to_string(),
|
||||
content: Some(ContentType::Text(
|
||||
"You are a helpful assistant.".to_string(),
|
||||
)),
|
||||
..Default::default()
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some(ContentType::Text("Hi".to_string())),
|
||||
..Default::default()
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some(ContentType::Text("Hello! How can I assist you".to_string())),
|
||||
..Default::default()
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some(ContentType::Text("I want to book a flight.".to_string())),
|
||||
..Default::default()
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some(ContentType::Text(
|
||||
"Sure, where would you like to go?".to_string(),
|
||||
)),
|
||||
..Default::default()
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some(ContentType::Text("seattle".to_string())),
|
||||
..Default::default()
|
||||
},
|
||||
];
|
||||
|
||||
let req = router.generate_request(&messages);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
|
||||
assert_eq!(expected_prompt, prompt.to_string());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_response() {
|
||||
let router = RouterModelV1::new(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue