use passed in model name in chat completion request

This commit is contained in:
Adil Hafeez 2025-03-21 14:46:45 -07:00
parent bd8004d1ae
commit 7331c415aa
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
10 changed files with 299 additions and 49 deletions

View file

@ -1,7 +1,7 @@
use crate::configuration;
use configuration::{Limit, Ratelimit, TimeUnit};
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
use log::debug;
use log::trace;
use std::fmt::Display;
use std::num::{NonZero, NonZeroU32};
use std::sync::RwLock;
@ -99,7 +99,7 @@ impl RatelimitMap {
selector: Header,
tokens_used: NonZeroU32,
) -> Result<(), Error> {
debug!(
trace!(
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
provider, selector, tokens_used
);

View file

@ -10,9 +10,24 @@ pub enum Error {
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
trace!("getting token count model={}", model_name);
//HACK: add support for tokenizing mistral and other models
//filed issue https://github.com/katanemo/arch/issues/222
let updated_model = match model_name.starts_with("gpt") {
false => {
trace!(
"tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
model_name
);
"gpt-4"
}
true => model_name,
};
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel {
model_name: model_name.to_string(),
let bpe = tiktoken_rs::get_bpe_from_model(updated_model).map_err(|_| Error::UnknownModel {
model_name: updated_model.to_string(),
})?;
Ok(bpe.encode_ordinary(text).len())
}
@ -34,10 +49,8 @@ mod test {
#[test]
fn unrecognized_model() {
assert_eq!(
Error::UnknownModel {
model_name: "unknown".to_string()
},
token_count("unknown", "").expect_err("unknown model")
2,
token_count("unknown model", "hello world").expect("correct tokenization")
)
}
}

View file

@ -89,7 +89,7 @@ impl StreamContext {
provider_hint,
));
debug!(
trace!(
"request received: llm provider hint: {}, selected llm: {}, model: {}",
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
.unwrap_or_default(),
@ -167,7 +167,7 @@ impl StreamContext {
// Check if rate limiting needs to be applied.
if let Some(selector) = self.ratelimit_selector.take() {
log::debug!("Applying ratelimit for model: {}", model);
log::trace!("Applying ratelimit for model: {}", model);
ratelimit::ratelimits(None).read().unwrap().check_limit(
model.to_owned(),
selector,
@ -301,21 +301,27 @@ impl HttpContext for StreamContext {
let model_name = match self.llm_provider.as_ref() {
Some(llm_provider) => match llm_provider.model.as_ref() {
Some(model) => model,
None => "--",
Some(model) => Some(model),
None => None,
},
None => "--",
None => None,
};
deserialized_body.model = model_name.to_string();
let model_requested = deserialized_body.model.clone();
if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
deserialized_body.model = model_name.unwrap().to_string();
}
debug!(
"provider: {:?}, model requested: {}, model selected: {:?}",
self.llm_provider().name,
model_requested,
model_name,
);
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
trace!(
"arch => {:?}, body: {}",
deserialized_body.model,
chat_completion_request_str
);
trace!("request body: {}", chat_completion_request_str);
if deserialized_body.stream {
self.streaming_response = true;
@ -528,22 +534,13 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
let mut model = chat_completions_chunk_response_events
let model = chat_completions_chunk_response_events
.events
.first()
.unwrap()
.model
.clone();
let tokens_str = chat_completions_chunk_response_events.to_string();
//HACK: add support for tokenizing mistral and other models
//filed issue https://github.com/katanemo/arch/issues/222
if !model.as_ref().unwrap().starts_with("gpt") {
trace!(
"tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
model.as_ref().unwrap()
);
}
model = Some("gpt-4".to_string());
let token_count =
match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str())

View file

@ -30,7 +30,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
Some("x-arch-llm-provider-hint"),
)
.returning(None)
.expect_log(Some(LogLevel::Debug), Some("request received: llm provider hint: default, selected llm: open-ai-gpt-4, model: gpt-4"))
.expect_log(Some(LogLevel::Trace), Some("request received: llm provider hint: default, selected llm: open-ai-gpt-4, model: gpt-4"))
.expect_add_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider"),
@ -225,12 +225,13 @@ fn llm_gateway_successful_request_to_open_ai_chat_completions() {
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 21)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
@ -287,7 +288,7 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(incomplete_chat_completions_request_body))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response(
Some(StatusCode::BAD_REQUEST.as_u16().into()),
None,
@ -296,9 +297,10 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 14)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
@ -351,13 +353,14 @@ fn llm_gateway_request_ratelimited() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 107)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), Some("server error occurred: exceeded limit provider=gpt-4, selector=Header { key: \"selector-key\", value: \"selector-value\" }, tokens_used=107"))
.expect_send_local_response(
Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()),
None,
@ -417,12 +420,196 @@ fn llm_gateway_request_not_ratelimited() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 29)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
#[test]
#[serial]
fn llm_gateway_override_model_name() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
// give shorter body to avoid rate limiting
let chat_completions_request_body = "\
{\
\"model\": \"o1-mini\",\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
]
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), Some("provider: \"open-ai-gpt-4\", model requested: o1-mini, model selected: Some(\"gpt-4\")"))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 29)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
#[test]
#[serial]
fn llm_gateway_override_use_default_model() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
// give shorter body to avoid rate limiting
let chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
]
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(
Some(LogLevel::Debug),
Some("provider: \"open-ai-gpt-4\", model requested: , model selected: Some(\"gpt-4\")"),
)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 29)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
#[test]
#[serial]
fn llm_gateway_override_use_model_name_none() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
// give shorter body to avoid rate limiting
let chat_completions_request_body = "\
{\
\"model\": \"none\",\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
]
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), Some("provider: \"open-ai-gpt-4\", model requested: none, model selected: Some(\"gpt-4\")"))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 29)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();