trim conversation if it exceed max limit of what router model can handle (#488)

This commit is contained in:
Adil Hafeez 2025-05-27 20:28:22 -07:00 committed by GitHub
parent 79cbcb5fe1
commit d29eba4102
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 393 additions and 83 deletions

View file

@ -1,2 +1,3 @@
pub mod handlers; pub mod handlers;
pub mod router; pub mod router;
pub mod utils;

View file

@ -1,5 +1,6 @@
use brightstaff::handlers::chat_completions::chat_completions; use brightstaff::handlers::chat_completions::chat_completions;
use brightstaff::router::llm_router::RouterService; use brightstaff::router::llm_router::RouterService;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes; use bytes::Bytes;
use common::configuration::Configuration; use common::configuration::Configuration;
use http_body_util::{combinators::BoxBody, BodyExt, Empty}; use http_body_util::{combinators::BoxBody, BodyExt, Empty};
@ -11,13 +12,10 @@ use hyper_util::rt::TokioIo;
use opentelemetry::trace::FutureExt; use opentelemetry::trace::FutureExt;
use opentelemetry::{global, Context}; use opentelemetry::{global, Context};
use opentelemetry_http::HeaderExtractor; use opentelemetry_http::HeaderExtractor;
use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider};
use opentelemetry_stdout::SpanExporter;
use std::sync::Arc; use std::sync::Arc;
use std::{env, fs}; use std::{env, fs};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tracing::{debug, info}; use tracing::{debug, info};
use tracing_subscriber::EnvFilter;
pub mod router; pub mod router;
@ -30,18 +28,6 @@ fn extract_context_from_request(req: &Request<Incoming>) -> Context {
}) })
} }
fn init_tracer() -> SdkTracerProvider {
global::set_text_map_propagator(TraceContextPropagator::new());
// Install stdout exporter pipeline to be able to retrieve the collected spans.
// For the demonstration, use `Sampler::AlwaysOn` sampler to sample all traces.
let provider = SdkTracerProvider::builder()
.with_simple_exporter(SpanExporter::default())
.build();
global::set_tracer_provider(provider.clone());
provider
}
fn empty() -> BoxBody<Bytes, hyper::Error> { fn empty() -> BoxBody<Bytes, hyper::Error> {
Empty::<Bytes>::new() Empty::<Bytes>::new()
.map_err(|never| match never {}) .map_err(|never| match never {})
@ -51,15 +37,9 @@ fn empty() -> BoxBody<Bytes, hyper::Error> {
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> { async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let _tracer_provider = init_tracer(); let _tracer_provider = init_tracer();
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
)
.init();
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string()); let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
//loading arch_config.yaml file // loading arch_config.yaml file
let arch_config_path = let arch_config_path =
env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "./arch_config.yaml".to_string()); env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "./arch_config.yaml".to_string());
info!("Loading arch_config.yaml from {}", arch_config_path); info!("Loading arch_config.yaml from {}", arch_config_path);
@ -87,7 +67,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let model = arch_config let model = arch_config
.routing .routing
.as_ref() .as_ref()
.and_then(|r| Some(r.model.clone())) .map(|r| r.model.clone())
.unwrap_or_else(|| "none".to_string()); .unwrap_or_else(|| "none".to_string());
let router_service: Arc<RouterService> = Arc::new(RouterService::new( let router_service: Arc<RouterService> = Arc::new(RouterService::new(

View file

@ -9,6 +9,8 @@ use hyper::header;
use thiserror::Error; use thiserror::Error;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::router::router_model_v1::{self};
use super::router_model::RouterModel; use super::router_model::RouterModel;
pub struct RouterService { pub struct RouterService {
@ -63,9 +65,10 @@ impl RouterService {
llm_providers_with_usage_yaml.replace("\n", "\\n") llm_providers_with_usage_yaml.replace("\n", "\\n")
); );
let router_model = Arc::new(super::router_model_v1::RouterModelV1::new( let router_model = Arc::new(router_model_v1::RouterModelV1::new(
llm_providers_with_usage_yaml.clone(), llm_providers_with_usage_yaml.clone(),
routing_model_name.clone(), routing_model_name.clone(),
router_model_v1::MAX_TOKEN_LEN,
)); ));
RouterService { RouterService {

View file

@ -3,9 +3,11 @@ use common::{
consts::{SYSTEM_ROLE, USER_ROLE}, consts::{SYSTEM_ROLE, USER_ROLE},
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
use super::router_model::{RouterModel, RoutingModelError}; use super::router_model::{RouterModel, RoutingModelError};
pub const MAX_TOKEN_LEN: usize = 2048; // Default max token length for the routing model
pub const ARCH_ROUTER_V1_SYSTEM_PROMPT: &str = r#" pub const ARCH_ROUTER_V1_SYSTEM_PROMPT: &str = r#"
You are a helpful assistant designed to find the best suited route. You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags: You are provided with route description within <routes></routes> XML tags:
@ -28,17 +30,21 @@ Based on your analysis, provide your response in the following JSON formats if y
"#; "#;
pub type Result<T> = std::result::Result<T, RoutingModelError>; pub type Result<T> = std::result::Result<T, RoutingModelError>;
pub struct RouterModelV1 { pub struct RouterModelV1 {
llm_providers_with_usage_yaml: String, llm_providers_with_usage_yaml: String,
routing_model: String, routing_model: String,
max_token_length: usize,
} }
impl RouterModelV1 { impl RouterModelV1 {
pub fn new(llm_providers_with_usage_yaml: String, routing_model: String) -> Self { pub fn new(
llm_providers_with_usage_yaml: String,
routing_model: String,
max_token_length: usize,
) -> Self {
RouterModelV1 { RouterModelV1 {
llm_providers_with_usage_yaml, llm_providers_with_usage_yaml,
routing_model, routing_model,
max_token_length,
} }
} }
} }
@ -48,26 +54,89 @@ struct LlmRouterResponse {
pub route: Option<String>, pub route: Option<String>,
} }
const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for UTF-8 characters
impl RouterModel for RouterModelV1 { impl RouterModel for RouterModelV1 {
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest { fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
let messages_str = messages let messages_vec = messages
.iter() .iter()
.filter(|m| m.role != SYSTEM_ROLE) .filter(|m| m.role != SYSTEM_ROLE)
.collect::<Vec<&Message>>();
// Following code is to ensure that the conversation does not exceed max token length
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
let message_token_count = message
.content
.as_ref()
.unwrap_or(&ContentType::Text("".to_string()))
.to_string()
.len()
/ TOKEN_LENGTH_DIVISOR;
token_count += message_token_count;
if token_count > self.max_token_length {
debug!(
"RouterModelV1: token count {} exceeds max token length {}, truncating conversation, selected message count {}, total message count: {}",
token_count,
self.max_token_length
, selected_messsage_count,
messages_vec.len()
);
if message.role == USER_ROLE {
// If message that exceeds max token length is from user, we need to keep it
selected_messages_list_reversed.push(message);
}
break;
}
// If we are here, it means that the message is within the max token length
selected_messages_list_reversed.push(message);
}
if selected_messages_list_reversed.is_empty() {
debug!(
"RouterModelV1: no messages selected, using the last message in the conversation"
);
if let Some(last_message) = messages_vec.last() {
selected_messages_list_reversed.push(last_message);
}
}
// ensure that first and last selected message is from user
if let Some(first_message) = selected_messages_list_reversed.first() {
if first_message.role != USER_ROLE {
warn!("RouterModelV1: last message in the conversation is not from user, this may lead to incorrect routing");
}
}
if let Some(last_message) = selected_messages_list_reversed.last() {
if last_message.role != USER_ROLE {
warn!("RouterModelV1: first message in the conversation is not from user, this may lead to incorrect routing");
}
}
// Reverse the selected messages to maintain the conversation order
let selected_conversation_list_str = selected_messages_list_reversed
.iter()
.rev()
.map(|m| { .map(|m| {
let content_json_str = serde_json::to_string(&m.content).unwrap_or_default(); let content_json_str = serde_json::to_string(&m.content).unwrap_or_default();
format!("{}: {}", m.role, content_json_str) format!("{}: {}", m.role, content_json_str)
}) })
.collect::<Vec<String>>() .collect::<Vec<String>>();
.join("\n");
let message = ARCH_ROUTER_V1_SYSTEM_PROMPT let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", &self.llm_providers_with_usage_yaml) .replace("{routes}", &self.llm_providers_with_usage_yaml)
.replace("{conversation}", messages_str.as_str()); .replace(
"{conversation}",
selected_conversation_list_str.join("\n").as_str(),
);
ChatCompletionsRequest { ChatCompletionsRequest {
model: self.routing_model.clone(), model: self.routing_model.clone(),
messages: vec![Message { messages: vec![Message {
content: Some(ContentType::Text(message)), content: Some(ContentType::Text(messages_content)),
role: USER_ROLE.to_string(), role: USER_ROLE.to_string(),
model: None, model: None,
tool_calls: None, tool_calls: None,
@ -135,6 +204,8 @@ impl std::fmt::Debug for dyn RouterModel {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::utils::tracing::init_tracer;
use super::*; use super::*;
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
@ -166,7 +237,7 @@ user: "seattle"
let routes_yaml = "route1: description1\nroute2: description2"; let routes_yaml = "route1: description1\nroute2: description2";
let routing_model = "test-model".to_string(); let routing_model = "test-model".to_string();
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone()); let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), usize::MAX);
let messages = vec![ let messages = vec![
Message { Message {
@ -201,56 +272,281 @@ user: "seattle"
let prompt = req.messages[0].content.as_ref().unwrap(); let prompt = req.messages[0].content.as_ref().unwrap();
println!("Prompt: {}", prompt); assert_eq!(expected_prompt, prompt.to_string());
}
#[test]
fn test_conversation_exceed_token_count() {
let _tracer = init_tracer();
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
route1: description1
route2: description2
</routes>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant, response with empty route {"route": ""}.
2. If the user request is full fill and user thank or ending the conversation , response with empty route {"route": ""}.
3. Understand user latest intent and find the best match route in <routes></routes> xml tags.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
<conversation>
user: "I want to book a flight."
assistant: "Sure, where would you like to go?"
user: "seattle"
</conversation>
"#;
let routes_yaml = "route1: description1\nroute2: description2";
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), 223);
let messages = vec![
Message {
role: "system".to_string(),
content: Some(ContentType::Text(
"You are a helpful assistant.".to_string(),
)),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some(ContentType::Text("Hi".to_string())),
..Default::default()
},
Message {
role: "assistant".to_string(),
content: Some(ContentType::Text("Hello! How can I assist you".to_string())),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some(ContentType::Text("I want to book a flight.".to_string())),
..Default::default()
},
Message {
role: "assistant".to_string(),
content: Some(ContentType::Text(
"Sure, where would you like to go?".to_string(),
)),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some(ContentType::Text("seattle".to_string())),
..Default::default()
},
];
let req = router.generate_request(&messages);
let prompt = req.messages[0].content.as_ref().unwrap();
assert_eq!(expected_prompt, prompt.to_string()); assert_eq!(expected_prompt, prompt.to_string());
} }
}
#[test]
#[test] fn test_conversation_exceed_token_count_large_single_message() {
fn test_parse_response() { let _tracer = init_tracer();
let router = RouterModelV1::new( let expected_prompt = r#"
"route1: description1\nroute2: description2".to_string(), You are a helpful assistant designed to find the best suited route.
"test-model".to_string(), You are provided with route description within <routes></routes> XML tags:
); <routes>
route1: description1
// Case 1: Valid JSON with non-empty route route2: description2
let input = r#"{"route": "route1"}"#; </routes>
let result = router.parse_response(input).unwrap();
assert_eq!(result, Some("route1".to_string())); Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant, response with empty route {"route": ""}.
// Case 2: Valid JSON with empty route 2. If the user request is full fill and user thank or ending the conversation , response with empty route {"route": ""}.
let input = r#"{"route": ""}"#; 3. Understand user latest intent and find the best match route in <routes></routes> xml tags.
let result = router.parse_response(input).unwrap();
assert_eq!(result, None); Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
// Case 3: Valid JSON with null route
let input = r#"{"route": null}"#;
let result = router.parse_response(input).unwrap(); <conversation>
assert_eq!(result, None); user: "Seatte, WA. But I also need to know about the weather there, and if there are any good restaurants nearby, and what the best time to visit is, and also if there are any events happening in the city."
</conversation>
// Case 4: JSON missing route field "#;
let input = r#"{}"#;
let result = router.parse_response(input).unwrap(); let routes_yaml = "route1: description1\nroute2: description2";
assert_eq!(result, None); let routing_model = "test-model".to_string();
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), 210);
// Case 4.1: empty string
let input = r#""#; let messages = vec![
let result = router.parse_response(input).unwrap(); Message {
assert_eq!(result, None); role: "system".to_string(),
content: Some(ContentType::Text(
// Case 5: Malformed JSON "You are a helpful assistant.".to_string(),
let input = r#"{"route": "route1""#; // missing closing } )),
let result = router.parse_response(input); ..Default::default()
assert!(result.is_err()); },
Message {
// Case 6: Single quotes and \n in JSON role: "user".to_string(),
let input = "{'route': 'route2'}\\n"; content: Some(ContentType::Text("Hi".to_string())),
let result = router.parse_response(input).unwrap(); ..Default::default()
assert_eq!(result, Some("route2".to_string())); },
Message {
// Case 7: Code block marker role: "assistant".to_string(),
let input = "```json\n{\"route\": \"route1\"}\n```"; content: Some(ContentType::Text("Hello! How can I assist you".to_string())),
let result = router.parse_response(input).unwrap(); ..Default::default()
assert_eq!(result, Some("route1".to_string())); },
Message {
role: "user".to_string(),
content: Some(ContentType::Text("I want to book a flight.".to_string())),
..Default::default()
},
Message {
role: "assistant".to_string(),
content: Some(ContentType::Text(
"Sure, where would you like to go?".to_string(),
)),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some(ContentType::Text("Seatte, WA. But I also need to know about the weather there, \
and if there are any good restaurants nearby, and what the \
best time to visit is, and also if there are any events \
happening in the city.".to_string())),
..Default::default()
},
];
let req = router.generate_request(&messages);
let prompt = req.messages[0].content.as_ref().unwrap();
assert_eq!(expected_prompt, prompt.to_string());
}
#[test]
fn test_conversation_trim_upto_user_message() {
let _tracer = init_tracer();
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
route1: description1
route2: description2
</routes>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant, response with empty route {"route": ""}.
2. If the user request is full fill and user thank or ending the conversation , response with empty route {"route": ""}.
3. Understand user latest intent and find the best match route in <routes></routes> xml tags.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
<conversation>
user: "I want to book a flight."
assistant: "Sure, where would you like to go?"
user: "seattle"
</conversation>
"#;
let routes_yaml = "route1: description1\nroute2: description2";
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone(), 220);
let messages = vec![
Message {
role: "system".to_string(),
content: Some(ContentType::Text(
"You are a helpful assistant.".to_string(),
)),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some(ContentType::Text("Hi".to_string())),
..Default::default()
},
Message {
role: "assistant".to_string(),
content: Some(ContentType::Text("Hello! How can I assist you".to_string())),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some(ContentType::Text("I want to book a flight.".to_string())),
..Default::default()
},
Message {
role: "assistant".to_string(),
content: Some(ContentType::Text(
"Sure, where would you like to go?".to_string(),
)),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some(ContentType::Text("seattle".to_string())),
..Default::default()
},
];
let req = router.generate_request(&messages);
let prompt = req.messages[0].content.as_ref().unwrap();
assert_eq!(expected_prompt, prompt.to_string());
}
#[test]
fn test_parse_response() {
let router = RouterModelV1::new(
"route1: description1\nroute2: description2".to_string(),
"test-model".to_string(),
2000,
);
// Case 1: Valid JSON with non-empty route
let input = r#"{"route": "route1"}"#;
let result = router.parse_response(input).unwrap();
assert_eq!(result, Some("route1".to_string()));
// Case 2: Valid JSON with empty route
let input = r#"{"route": ""}"#;
let result = router.parse_response(input).unwrap();
assert_eq!(result, None);
// Case 3: Valid JSON with null route
let input = r#"{"route": null}"#;
let result = router.parse_response(input).unwrap();
assert_eq!(result, None);
// Case 4: JSON missing route field
let input = r#"{}"#;
let result = router.parse_response(input).unwrap();
assert_eq!(result, None);
// Case 4.1: empty string
let input = r#""#;
let result = router.parse_response(input).unwrap();
assert_eq!(result, None);
// Case 5: Malformed JSON
let input = r#"{"route": "route1""#; // missing closing }
let result = router.parse_response(input);
assert!(result.is_err());
// Case 6: Single quotes and \n in JSON
let input = "{'route': 'route2'}\\n";
let result = router.parse_response(input).unwrap();
assert_eq!(result, Some("route2".to_string()));
// Case 7: Code block marker
let input = "```json\n{\"route\": \"route1\"}\n```";
let result = router.parse_response(input).unwrap();
assert_eq!(result, Some("route1".to_string()));
}
} }

View file

@ -0,0 +1 @@
pub mod tracing;

View file

@ -0,0 +1,29 @@
use std::sync::OnceLock;
use opentelemetry::global;
use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider};
use opentelemetry_stdout::SpanExporter;
use tracing_subscriber::EnvFilter;
static INIT_LOGGER: OnceLock<SdkTracerProvider> = OnceLock::new();
pub fn init_tracer() -> &'static SdkTracerProvider {
INIT_LOGGER.get_or_init(|| {
global::set_text_map_propagator(TraceContextPropagator::new());
// Install stdout exporter pipeline to be able to retrieve the collected spans.
// For the demonstration, use `Sampler::AlwaysOn` sampler to sample all traces.
let provider = SdkTracerProvider::builder()
.with_simple_exporter(SpanExporter::default())
.build();
global::set_tracer_provider(provider.clone());
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
)
.init();
provider
})
}