2025-01-31 10:37:53 -08:00
|
|
|
use log::trace;
|
2024-09-04 17:28:12 -07:00
|
|
|
|
2024-10-28 20:05:06 -04:00
|
|
|
/// Errors returned by the token-counting API in this module.
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
// NOTE(review): `dead_code` is allowed here, presumably because not every
// build configuration references this module yet — confirm before removing.
#[allow(dead_code)]
pub enum Error {
    /// `tiktoken_rs` has no tokenizer registered for the requested model
    /// name (see `token_count`, which maps the lookup failure to this).
    #[error("Unknown model: {model_name}")]
    UnknownModel { model_name: String },
}
|
|
|
|
|
|
|
|
|
|
#[allow(dead_code)]
|
|
|
|
|
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
|
2025-01-31 10:37:53 -08:00
|
|
|
trace!("getting token count model={}", model_name);
|
2024-09-04 17:28:12 -07:00
|
|
|
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
|
2024-10-28 20:05:06 -04:00
|
|
|
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel {
|
|
|
|
|
model_name: model_name.to_string(),
|
|
|
|
|
})?;
|
2024-09-04 17:28:12 -07:00
|
|
|
Ok(bpe.encode_ordinary(text).len())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[cfg(test)]
mod test {
    use super::*;

    /// A known sentence tokenizes to exactly 8 tokens under the
    /// gpt-3.5-turbo tokenizer.
    #[test]
    fn encode_ordinary() {
        let count = token_count("gpt-3.5-turbo", "How many tokens does this sentence have?")
            .expect("correct tokenization");
        assert_eq!(count, 8);
    }

    /// An unrecognized model name surfaces as `Error::UnknownModel`
    /// carrying the offending name.
    #[test]
    fn unrecognized_model() {
        let err = token_count("unknown", "").expect_err("unknown model");
        assert_eq!(
            err,
            Error::UnknownModel {
                model_name: "unknown".to_string()
            }
        )
    }
}
|