mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Use heuristic based tokenizer
This commit is contained in:
parent
d1542b988a
commit
a74118238c
3 changed files with 41 additions and 48 deletions
|
|
@ -4,7 +4,6 @@ use common::{
|
|||
api::open_ai::{ChatCompletionsResponse, ContentType, Message},
|
||||
configuration::LlmProvider,
|
||||
consts::ARCH_PROVIDER_HINT_HEADER,
|
||||
tokenizer::TiktokenTokenizer,
|
||||
};
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
|
|
@ -70,7 +69,6 @@ impl RouterService {
|
|||
llm_providers_with_usage_yaml.clone(),
|
||||
routing_model_name.clone(),
|
||||
router_model_v1::MAX_TOKEN_LEN,
|
||||
Box::new(TiktokenTokenizer {}),
|
||||
));
|
||||
|
||||
RouterService {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
use common::{
|
||||
api::open_ai::{ChatCompletionsRequest, ContentType, Message},
|
||||
consts::{SYSTEM_ROLE, USER_ROLE},
|
||||
tokenizer::Tokenizer,
|
||||
consts::{SYSTEM_ROLE, USER_ROLE}
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::debug;
|
||||
|
|
@ -35,20 +34,17 @@ pub struct RouterModelV1 {
|
|||
llm_providers_with_usage_yaml: String,
|
||||
routing_model: String,
|
||||
max_token_length: usize,
|
||||
tokenizer: Box<dyn Tokenizer + Send + Sync>,
|
||||
}
|
||||
impl RouterModelV1 {
|
||||
pub fn new(
|
||||
llm_providers_with_usage_yaml: String,
|
||||
routing_model: String,
|
||||
max_token_length: usize,
|
||||
tokenizer: Box<dyn Tokenizer + Send + Sync>,
|
||||
) -> Self {
|
||||
RouterModelV1 {
|
||||
llm_providers_with_usage_yaml,
|
||||
routing_model,
|
||||
max_token_length,
|
||||
tokenizer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -58,6 +54,8 @@ struct LlmRouterResponse {
|
|||
pub route: Option<String>,
|
||||
}
|
||||
|
||||
const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for UTF-8 characters
|
||||
|
||||
impl RouterModel for RouterModelV1 {
|
||||
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
|
||||
let mut messages_vec = messages
|
||||
|
|
@ -69,32 +67,47 @@ impl RouterModel for RouterModelV1 {
|
|||
})
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
let mut messages_content: String;
|
||||
|
||||
loop {
|
||||
messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
|
||||
.replace("{routes}", &self.llm_providers_with_usage_yaml)
|
||||
.replace("{conversation}", messages_vec.join("\n").as_str());
|
||||
|
||||
let token_count = self
|
||||
.tokenizer
|
||||
.token_count(&messages_content, &self.routing_model)
|
||||
.unwrap_or(0);
|
||||
if token_count <= self.max_token_length || messages_vec.len() <= 2 {
|
||||
if messages_vec.len() <= 2 {
|
||||
debug!("RouterModelV1: conversation is too short, using remaining messages",)
|
||||
}
|
||||
// Following code is to ensure that the conversation does not exceed max token length
|
||||
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
|
||||
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messsage_count = 0;
|
||||
for message in messages_vec.iter().rev() {
|
||||
let message_token_count = message.len() / TOKEN_LENGTH_DIVISOR;
|
||||
token_count += message_token_count;
|
||||
if token_count > self.max_token_length {
|
||||
debug!(
|
||||
"RouterModelV1: token count {} exceeds max token length {}, truncating conversation, selected message count {}, total message count: {}",
|
||||
token_count,
|
||||
self.max_token_length
|
||||
, selected_messsage_count,
|
||||
messages_vec.len()
|
||||
);
|
||||
break;
|
||||
}
|
||||
debug!(
|
||||
"RouterModelV1: token count {} exceeds max token length {}, truncating conversation",
|
||||
token_count,
|
||||
self.max_token_length
|
||||
);
|
||||
// trim top two elements from the conversation
|
||||
messages_vec = messages_vec.into_iter().skip(2).collect::<Vec<String>>();
|
||||
selected_messsage_count += 1;
|
||||
}
|
||||
|
||||
if selected_messsage_count == 0 {
|
||||
debug!("RouterModelV1: most recent message in conversation history exceeds max token length {}, keeping only the last message (even if it exceeds max token length)",
|
||||
self.max_token_length);
|
||||
messages_vec = messages_vec
|
||||
.last()
|
||||
.map_or_else(Vec::new, |last_message| vec![last_message.to_string()]);
|
||||
} else {
|
||||
let skip_messages_count = messages_vec.len() - selected_messsage_count;
|
||||
if skip_messages_count > 0 {
|
||||
debug!(
|
||||
"RouterModelV1: skipping first {} messages from the beginning of the conversation",
|
||||
skip_messages_count
|
||||
);
|
||||
messages_vec = messages_vec.into_iter().skip(skip_messages_count).collect();
|
||||
}
|
||||
}
|
||||
|
||||
let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
|
||||
.replace("{routes}", &self.llm_providers_with_usage_yaml)
|
||||
.replace("{conversation}", messages_vec.join("\n").as_str());
|
||||
|
||||
ChatCompletionsRequest {
|
||||
model: self.routing_model.clone(),
|
||||
messages: vec![Message {
|
||||
|
|
@ -169,7 +182,6 @@ mod tests {
|
|||
use crate::utils::tracing::init_tracer;
|
||||
|
||||
use super::*;
|
||||
use common::tokenizer::TiktokenTokenizer;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
|
|
@ -204,7 +216,6 @@ user: "seattle"
|
|||
routes_yaml.to_string(),
|
||||
routing_model.clone(),
|
||||
usize::MAX,
|
||||
Box::new(TiktokenTokenizer {}),
|
||||
);
|
||||
|
||||
let messages = vec![
|
||||
|
|
@ -275,8 +286,7 @@ user: "seattle"
|
|||
let router = RouterModelV1::new(
|
||||
routes_yaml.to_string(),
|
||||
routing_model.clone(),
|
||||
210,
|
||||
Box::new(TiktokenTokenizer {}),
|
||||
225
|
||||
);
|
||||
|
||||
let messages = vec![
|
||||
|
|
@ -354,7 +364,6 @@ user: "Seatte, WA. But I also need to know about the weather there, and if there
|
|||
routes_yaml.to_string(),
|
||||
routing_model.clone(),
|
||||
210,
|
||||
Box::new(TiktokenTokenizer {}),
|
||||
);
|
||||
|
||||
let messages = vec![
|
||||
|
|
@ -410,7 +419,6 @@ user: "Seatte, WA. But I also need to know about the weather there, and if there
|
|||
"route1: description1\nroute2: description2".to_string(),
|
||||
"test-model".to_string(),
|
||||
2000,
|
||||
Box::new(TiktokenTokenizer {}),
|
||||
);
|
||||
|
||||
// Case 1: Valid JSON with non-empty route
|
||||
|
|
|
|||
|
|
@ -1,18 +1,5 @@
|
|||
use log::debug;
|
||||
|
||||
pub trait Tokenizer {
|
||||
/// Returns the number of tokens in the given text.
|
||||
fn token_count(&self, text: &str, model_name: &str) -> Result<usize, String>;
|
||||
}
|
||||
|
||||
pub struct TiktokenTokenizer {}
|
||||
|
||||
impl Tokenizer for TiktokenTokenizer {
|
||||
fn token_count(&self, text: &str, model_name: &str) -> Result<usize, String> {
|
||||
token_count(model_name, text)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn token_count(model_name: &str, text: &str) -> Result<usize, String> {
|
||||
debug!("getting token count model={}", model_name);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue