Use heuristic based tokenizer

This commit is contained in:
Adil Hafeez 2025-05-27 17:35:43 -07:00
parent d1542b988a
commit a74118238c
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
3 changed files with 41 additions and 48 deletions

View file

@ -4,7 +4,6 @@ use common::{
api::open_ai::{ChatCompletionsResponse, ContentType, Message},
configuration::LlmProvider,
consts::ARCH_PROVIDER_HINT_HEADER,
tokenizer::TiktokenTokenizer,
};
use hyper::header;
use thiserror::Error;
@ -70,7 +69,6 @@ impl RouterService {
llm_providers_with_usage_yaml.clone(),
routing_model_name.clone(),
router_model_v1::MAX_TOKEN_LEN,
Box::new(TiktokenTokenizer {}),
));
RouterService {

View file

@ -1,7 +1,6 @@
use common::{
api::open_ai::{ChatCompletionsRequest, ContentType, Message},
consts::{SYSTEM_ROLE, USER_ROLE},
tokenizer::Tokenizer,
consts::{SYSTEM_ROLE, USER_ROLE}
};
use serde::{Deserialize, Serialize};
use tracing::debug;
@ -35,20 +34,17 @@ pub struct RouterModelV1 {
llm_providers_with_usage_yaml: String,
routing_model: String,
max_token_length: usize,
tokenizer: Box<dyn Tokenizer + Send + Sync>,
}
impl RouterModelV1 {
pub fn new(
llm_providers_with_usage_yaml: String,
routing_model: String,
max_token_length: usize,
tokenizer: Box<dyn Tokenizer + Send + Sync>,
) -> Self {
RouterModelV1 {
llm_providers_with_usage_yaml,
routing_model,
max_token_length,
tokenizer,
}
}
}
@ -58,6 +54,8 @@ struct LlmRouterResponse {
pub route: Option<String>,
}
const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for UTF-8 characters
impl RouterModel for RouterModelV1 {
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
let mut messages_vec = messages
@ -69,32 +67,47 @@ impl RouterModel for RouterModelV1 {
})
.collect::<Vec<String>>();
let mut messages_content: String;
loop {
messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", &self.llm_providers_with_usage_yaml)
.replace("{conversation}", messages_vec.join("\n").as_str());
let token_count = self
.tokenizer
.token_count(&messages_content, &self.routing_model)
.unwrap_or(0);
if token_count <= self.max_token_length || messages_vec.len() <= 2 {
if messages_vec.len() <= 2 {
debug!("RouterModelV1: conversation is too short, using remaining messages",)
}
// Following code is to ensure that the conversation does not exceed max token length
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
let mut selected_messsage_count = 0;
for message in messages_vec.iter().rev() {
let message_token_count = message.len() / TOKEN_LENGTH_DIVISOR;
token_count += message_token_count;
if token_count > self.max_token_length {
debug!(
"RouterModelV1: token count {} exceeds max token length {}, truncating conversation, selected message count {}, total message count: {}",
token_count,
self.max_token_length
, selected_messsage_count,
messages_vec.len()
);
break;
}
debug!(
"RouterModelV1: token count {} exceeds max token length {}, truncating conversation",
token_count,
self.max_token_length
);
// trim top two elements from the conversation
messages_vec = messages_vec.into_iter().skip(2).collect::<Vec<String>>();
selected_messsage_count += 1;
}
if selected_messsage_count == 0 {
debug!("RouterModelV1: most recent message in conversation history exceeds max token length {}, keeping only the last message (even if it exceeds max token length)",
self.max_token_length);
messages_vec = messages_vec
.last()
.map_or_else(Vec::new, |last_message| vec![last_message.to_string()]);
} else {
let skip_messages_count = messages_vec.len() - selected_messsage_count;
if skip_messages_count > 0 {
debug!(
"RouterModelV1: skipping first {} messages from the beginning of the conversation",
skip_messages_count
);
messages_vec = messages_vec.into_iter().skip(skip_messages_count).collect();
}
}
let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", &self.llm_providers_with_usage_yaml)
.replace("{conversation}", messages_vec.join("\n").as_str());
ChatCompletionsRequest {
model: self.routing_model.clone(),
messages: vec![Message {
@ -169,7 +182,6 @@ mod tests {
use crate::utils::tracing::init_tracer;
use super::*;
use common::tokenizer::TiktokenTokenizer;
use pretty_assertions::assert_eq;
#[test]
@ -204,7 +216,6 @@ user: "seattle"
routes_yaml.to_string(),
routing_model.clone(),
usize::MAX,
Box::new(TiktokenTokenizer {}),
);
let messages = vec![
@ -275,8 +286,7 @@ user: "seattle"
let router = RouterModelV1::new(
routes_yaml.to_string(),
routing_model.clone(),
210,
Box::new(TiktokenTokenizer {}),
225
);
let messages = vec![
@ -354,7 +364,6 @@ user: "Seatte, WA. But I also need to know about the weather there, and if there
routes_yaml.to_string(),
routing_model.clone(),
210,
Box::new(TiktokenTokenizer {}),
);
let messages = vec![
@ -410,7 +419,6 @@ user: "Seatte, WA. But I also need to know about the weather there, and if there
"route1: description1\nroute2: description2".to_string(),
"test-model".to_string(),
2000,
Box::new(TiktokenTokenizer {}),
);
// Case 1: Valid JSON with non-empty route

View file

@ -1,18 +1,5 @@
use log::debug;
pub trait Tokenizer {
/// Returns the number of tokens in the given text.
fn token_count(&self, text: &str, model_name: &str) -> Result<usize, String>;
}
pub struct TiktokenTokenizer {}
impl Tokenizer for TiktokenTokenizer {
fn token_count(&self, text: &str, model_name: &str) -> Result<usize, String> {
token_count(model_name, text)
}
}
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result<usize, String> {
debug!("getting token count model={}", model_name);