use passed in model name in chat completion request

This commit is contained in:
Adil Hafeez 2025-03-21 14:46:45 -07:00
parent bd8004d1ae
commit 7331c415aa
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
10 changed files with 299 additions and 49 deletions

View file

@ -1,7 +1,7 @@
use crate::configuration;
use configuration::{Limit, Ratelimit, TimeUnit};
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
use log::debug;
use log::trace;
use std::fmt::Display;
use std::num::{NonZero, NonZeroU32};
use std::sync::RwLock;
@ -99,7 +99,7 @@ impl RatelimitMap {
selector: Header,
tokens_used: NonZeroU32,
) -> Result<(), Error> {
debug!(
trace!(
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
provider, selector, tokens_used
);

View file

@ -10,9 +10,24 @@ pub enum Error {
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
trace!("getting token count model={}", model_name);
//HACK: add support for tokenizing mistral and other models
//filed issue https://github.com/katanemo/arch/issues/222
let updated_model = match model_name.starts_with("gpt") {
false => {
trace!(
"tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
model_name
);
"gpt-4"
}
true => model_name,
};
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel {
model_name: model_name.to_string(),
let bpe = tiktoken_rs::get_bpe_from_model(updated_model).map_err(|_| Error::UnknownModel {
model_name: updated_model.to_string(),
})?;
Ok(bpe.encode_ordinary(text).len())
}
@ -34,10 +49,8 @@ mod test {
#[test]
fn unrecognized_model() {
assert_eq!(
Error::UnknownModel {
model_name: "unknown".to_string()
},
token_count("unknown", "").expect_err("unknown model")
2,
token_count("unknown model", "hello world").expect("correct tokenization")
)
}
}