From a74118238cc5223de584536445cd929434830e43 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 27 May 2025 17:35:43 -0700 Subject: [PATCH] Use heuristic based tokenizer --- crates/brightstaff/src/router/llm_router.rs | 2 - .../brightstaff/src/router/router_model_v1.rs | 74 ++++++++++--------- crates/common/src/tokenizer.rs | 13 ---- 3 files changed, 41 insertions(+), 48 deletions(-) diff --git a/crates/brightstaff/src/router/llm_router.rs b/crates/brightstaff/src/router/llm_router.rs index 914b091c..8d8c057f 100644 --- a/crates/brightstaff/src/router/llm_router.rs +++ b/crates/brightstaff/src/router/llm_router.rs @@ -4,7 +4,6 @@ use common::{ api::open_ai::{ChatCompletionsResponse, ContentType, Message}, configuration::LlmProvider, consts::ARCH_PROVIDER_HINT_HEADER, - tokenizer::TiktokenTokenizer, }; use hyper::header; use thiserror::Error; @@ -70,7 +69,6 @@ impl RouterService { llm_providers_with_usage_yaml.clone(), routing_model_name.clone(), router_model_v1::MAX_TOKEN_LEN, - Box::new(TiktokenTokenizer {}), )); RouterService { diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index b9e16103..7c267252 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -1,7 +1,6 @@ use common::{ api::open_ai::{ChatCompletionsRequest, ContentType, Message}, - consts::{SYSTEM_ROLE, USER_ROLE}, - tokenizer::Tokenizer, + consts::{SYSTEM_ROLE, USER_ROLE} }; use serde::{Deserialize, Serialize}; use tracing::debug; @@ -35,20 +34,17 @@ pub struct RouterModelV1 { llm_providers_with_usage_yaml: String, routing_model: String, max_token_length: usize, - tokenizer: Box, } impl RouterModelV1 { pub fn new( llm_providers_with_usage_yaml: String, routing_model: String, max_token_length: usize, - tokenizer: Box, ) -> Self { RouterModelV1 { llm_providers_with_usage_yaml, routing_model, max_token_length, - tokenizer, } } } @@ -58,6 +54,8 @@ struct LlmRouterResponse { pub route: Option, } +const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for UTF-8 characters + impl RouterModel for RouterModelV1 { fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest { let mut messages_vec = messages @@ -69,32 +67,47 @@ impl RouterModel for RouterModelV1 { }) .collect::>(); - let mut messages_content: String; - - loop { - messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT - .replace("{routes}", &self.llm_providers_with_usage_yaml) - .replace("{conversation}", messages_vec.join("\n").as_str()); - - let token_count = self - .tokenizer - .token_count(&messages_content, &self.routing_model) - .unwrap_or(0); - if token_count <= self.max_token_length || messages_vec.len() <= 2 { - if messages_vec.len() <= 2 { - debug!("RouterModelV1: conversation is too short, using remaining messages",) - } + // Following code is to ensure that the conversation does not exceed max token length + // Note: we use a simple heuristic to estimate token count based on character length to optimize for performance + let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR; + let mut selected_messsage_count = 0; + for message in messages_vec.iter().rev() { + let message_token_count = message.len() / TOKEN_LENGTH_DIVISOR; + token_count += message_token_count; + if token_count > self.max_token_length { + debug!( + "RouterModelV1: token count {} exceeds max token length {}, truncating conversation, selected message count {}, total message count: {}", + token_count, + self.max_token_length + , selected_messsage_count, + messages_vec.len() + ); break; } - debug!( - "RouterModelV1: token count {} exceeds max token length {}, truncating conversation", - token_count, - self.max_token_length - ); - // trim top two elements from the conversation - messages_vec = messages_vec.into_iter().skip(2).collect::>(); + selected_messsage_count += 1; } + if selected_messsage_count == 0 { + debug!("RouterModelV1: most recent message in conversation history exceeds max token length {}, keeping only the last message (even if it exceeds max token length)", + self.max_token_length); + messages_vec = messages_vec + .last() + .map_or_else(Vec::new, |last_message| vec![last_message.to_string()]); + } else { + let skip_messages_count = messages_vec.len() - selected_messsage_count; + if skip_messages_count > 0 { + debug!( + "RouterModelV1: skipping first {} messages from the beginning of the conversation", + skip_messages_count + ); + messages_vec = messages_vec.into_iter().skip(skip_messages_count).collect(); + } + } + + let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT + .replace("{routes}", &self.llm_providers_with_usage_yaml) + .replace("{conversation}", messages_vec.join("\n").as_str()); + ChatCompletionsRequest { model: self.routing_model.clone(), messages: vec![Message { @@ -169,7 +182,6 @@ mod tests { use crate::utils::tracing::init_tracer; use super::*; - use common::tokenizer::TiktokenTokenizer; use pretty_assertions::assert_eq; #[test] @@ -204,7 +216,6 @@ user: "seattle" routes_yaml.to_string(), routing_model.clone(), usize::MAX, - Box::new(TiktokenTokenizer {}), ); let messages = vec![ @@ -275,8 +286,7 @@ user: "seattle" let router = RouterModelV1::new( routes_yaml.to_string(), routing_model.clone(), - 210, - Box::new(TiktokenTokenizer {}), + 225 ); let messages = vec![ @@ -354,7 +364,6 @@ user: "Seatte, WA. But I also need to know about the weather there, and if there routes_yaml.to_string(), routing_model.clone(), 210, - Box::new(TiktokenTokenizer {}), ); let messages = vec![ @@ -410,7 +419,6 @@ user: "Seatte, WA. But I also need to know about the weather there, and if there "route1: description1\nroute2: description2".to_string(), "test-model".to_string(), 2000, - Box::new(TiktokenTokenizer {}), ); // Case 1: Valid JSON with non-empty route diff --git a/crates/common/src/tokenizer.rs b/crates/common/src/tokenizer.rs index 03f325de..11ce7295 100644 --- a/crates/common/src/tokenizer.rs +++ b/crates/common/src/tokenizer.rs @@ -1,18 +1,5 @@ use log::debug; -pub trait Tokenizer { - /// Returns the number of tokens in the given text. - fn token_count(&self, text: &str, model_name: &str) -> Result; -} - -pub struct TiktokenTokenizer {} - -impl Tokenizer for TiktokenTokenizer { - fn token_count(&self, text: &str, model_name: &str) -> Result { - token_count(model_name, text) - } -} - #[allow(dead_code)] pub fn token_count(model_name: &str, text: &str) -> Result { debug!("getting token count model={}", model_name);