mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
use passed in model name in chat completion request
This commit is contained in:
parent
bd8004d1ae
commit
7331c415aa
10 changed files with 299 additions and 49 deletions
|
|
@ -1,7 +1,7 @@
|
|||
use crate::configuration;
|
||||
use configuration::{Limit, Ratelimit, TimeUnit};
|
||||
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
|
||||
use log::debug;
|
||||
use log::trace;
|
||||
use std::fmt::Display;
|
||||
use std::num::{NonZero, NonZeroU32};
|
||||
use std::sync::RwLock;
|
||||
|
|
@ -99,7 +99,7 @@ impl RatelimitMap {
|
|||
selector: Header,
|
||||
tokens_used: NonZeroU32,
|
||||
) -> Result<(), Error> {
|
||||
debug!(
|
||||
trace!(
|
||||
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
|
||||
provider, selector, tokens_used
|
||||
);
|
||||
|
|
|
|||
|
|
@ -10,9 +10,24 @@ pub enum Error {
|
|||
#[allow(dead_code)]
|
||||
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
|
||||
trace!("getting token count model={}", model_name);
|
||||
//HACK: add support for tokenizing mistral and other models
|
||||
//filed issue https://github.com/katanemo/arch/issues/222
|
||||
|
||||
let updated_model = match model_name.starts_with("gpt") {
|
||||
false => {
|
||||
trace!(
|
||||
"tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
|
||||
model_name
|
||||
);
|
||||
|
||||
"gpt-4"
|
||||
}
|
||||
true => model_name,
|
||||
};
|
||||
|
||||
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
|
||||
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel {
|
||||
model_name: model_name.to_string(),
|
||||
let bpe = tiktoken_rs::get_bpe_from_model(updated_model).map_err(|_| Error::UnknownModel {
|
||||
model_name: updated_model.to_string(),
|
||||
})?;
|
||||
Ok(bpe.encode_ordinary(text).len())
|
||||
}
|
||||
|
|
@ -34,10 +49,8 @@ mod test {
|
|||
#[test]
|
||||
fn unrecognized_model() {
|
||||
assert_eq!(
|
||||
Error::UnknownModel {
|
||||
model_name: "unknown".to_string()
|
||||
},
|
||||
token_count("unknown", "").expect_err("unknown model")
|
||||
2,
|
||||
token_count("unknown model", "hello world").expect("correct tokenization")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ impl StreamContext {
|
|||
provider_hint,
|
||||
));
|
||||
|
||||
debug!(
|
||||
trace!(
|
||||
"request received: llm provider hint: {}, selected llm: {}, model: {}",
|
||||
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
.unwrap_or_default(),
|
||||
|
|
@ -167,7 +167,7 @@ impl StreamContext {
|
|||
|
||||
// Check if rate limiting needs to be applied.
|
||||
if let Some(selector) = self.ratelimit_selector.take() {
|
||||
log::debug!("Applying ratelimit for model: {}", model);
|
||||
log::trace!("Applying ratelimit for model: {}", model);
|
||||
ratelimit::ratelimits(None).read().unwrap().check_limit(
|
||||
model.to_owned(),
|
||||
selector,
|
||||
|
|
@ -301,21 +301,27 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let model_name = match self.llm_provider.as_ref() {
|
||||
Some(llm_provider) => match llm_provider.model.as_ref() {
|
||||
Some(model) => model,
|
||||
None => "--",
|
||||
Some(model) => Some(model),
|
||||
None => None,
|
||||
},
|
||||
None => "--",
|
||||
None => None,
|
||||
};
|
||||
|
||||
deserialized_body.model = model_name.to_string();
|
||||
let model_requested = deserialized_body.model.clone();
|
||||
if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
|
||||
deserialized_body.model = model_name.unwrap().to_string();
|
||||
}
|
||||
|
||||
debug!(
|
||||
"provider: {:?}, model requested: {}, model selected: {:?}",
|
||||
self.llm_provider().name,
|
||||
model_requested,
|
||||
model_name,
|
||||
);
|
||||
|
||||
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
|
||||
|
||||
trace!(
|
||||
"arch => {:?}, body: {}",
|
||||
deserialized_body.model,
|
||||
chat_completion_request_str
|
||||
);
|
||||
trace!("request body: {}", chat_completion_request_str);
|
||||
|
||||
if deserialized_body.stream {
|
||||
self.streaming_response = true;
|
||||
|
|
@ -528,22 +534,13 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
let mut model = chat_completions_chunk_response_events
|
||||
let model = chat_completions_chunk_response_events
|
||||
.events
|
||||
.first()
|
||||
.unwrap()
|
||||
.model
|
||||
.clone();
|
||||
let tokens_str = chat_completions_chunk_response_events.to_string();
|
||||
//HACK: add support for tokenizing mistral and other models
|
||||
//filed issue https://github.com/katanemo/arch/issues/222
|
||||
if !model.as_ref().unwrap().starts_with("gpt") {
|
||||
trace!(
|
||||
"tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
|
||||
model.as_ref().unwrap()
|
||||
);
|
||||
}
|
||||
model = Some("gpt-4".to_string());
|
||||
|
||||
let token_count =
|
||||
match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str())
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
|||
Some("x-arch-llm-provider-hint"),
|
||||
)
|
||||
.returning(None)
|
||||
.expect_log(Some(LogLevel::Debug), Some("request received: llm provider hint: default, selected llm: open-ai-gpt-4, model: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Trace), Some("request received: llm provider hint: default, selected llm: open-ai-gpt-4, model: gpt-4"))
|
||||
.expect_add_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-llm-provider"),
|
||||
|
|
@ -225,12 +225,13 @@ fn llm_gateway_successful_request_to_open_ai_chat_completions() {
|
|||
)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_metric_record("input_sequence_length", 21)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
|
@ -287,7 +288,7 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
|
|||
)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(incomplete_chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_send_local_response(
|
||||
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
||||
None,
|
||||
|
|
@ -296,9 +297,10 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
|
|||
)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_metric_record("input_sequence_length", 14)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -351,13 +353,14 @@ fn llm_gateway_request_ratelimited() {
|
|||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_metric_record("input_sequence_length", 107)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Debug), Some("server error occurred: exceeded limit provider=gpt-4, selector=Header { key: \"selector-key\", value: \"selector-value\" }, tokens_used=107"))
|
||||
.expect_send_local_response(
|
||||
Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()),
|
||||
None,
|
||||
|
|
@ -417,12 +420,196 @@ fn llm_gateway_request_not_ratelimited() {
|
|||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_metric_record("input_sequence_length", 29)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn llm_gateway_override_model_name() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let filter_context = setup_filter(&mut module, default_config());
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
// give shorter body to avoid rate limiting
|
||||
let chat_completions_request_body = "\
|
||||
{\
|
||||
\"model\": \"o1-mini\",\
|
||||
\"messages\": [\
|
||||
{\
|
||||
\"role\": \"system\",\
|
||||
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
|
||||
},\
|
||||
{\
|
||||
\"role\": \"user\",\
|
||||
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||
}\
|
||||
]
|
||||
}";
|
||||
|
||||
module
|
||||
.call_proxy_on_request_body(
|
||||
http_context,
|
||||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Debug), Some("provider: \"open-ai-gpt-4\", model requested: o1-mini, model selected: Some(\"gpt-4\")"))
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_metric_record("input_sequence_length", 29)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn llm_gateway_override_use_default_model() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let filter_context = setup_filter(&mut module, default_config());
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
// give shorter body to avoid rate limiting
|
||||
let chat_completions_request_body = "\
|
||||
{\
|
||||
\"messages\": [\
|
||||
{\
|
||||
\"role\": \"system\",\
|
||||
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
|
||||
},\
|
||||
{\
|
||||
\"role\": \"user\",\
|
||||
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||
}\
|
||||
]
|
||||
}";
|
||||
|
||||
module
|
||||
.call_proxy_on_request_body(
|
||||
http_context,
|
||||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(
|
||||
Some(LogLevel::Debug),
|
||||
Some("provider: \"open-ai-gpt-4\", model requested: , model selected: Some(\"gpt-4\")"),
|
||||
)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_metric_record("input_sequence_length", 29)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn llm_gateway_override_use_model_name_none() {
|
||||
let args = tester::MockSettings {
|
||||
wasm_path: wasm_module(),
|
||||
quiet: false,
|
||||
allow_unexpected: false,
|
||||
};
|
||||
let mut module = tester::mock(args).unwrap();
|
||||
|
||||
module
|
||||
.call_start()
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// Setup Filter
|
||||
let filter_context = setup_filter(&mut module, default_config());
|
||||
|
||||
// Setup HTTP Stream
|
||||
let http_context = 2;
|
||||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
// give shorter body to avoid rate limiting
|
||||
let chat_completions_request_body = "\
|
||||
{\
|
||||
\"model\": \"none\",\
|
||||
\"messages\": [\
|
||||
{\
|
||||
\"role\": \"system\",\
|
||||
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
|
||||
},\
|
||||
{\
|
||||
\"role\": \"user\",\
|
||||
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||
}\
|
||||
]
|
||||
}";
|
||||
|
||||
module
|
||||
.call_proxy_on_request_body(
|
||||
http_context,
|
||||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Debug), Some("provider: \"open-ai-gpt-4\", model requested: none, model selected: Some(\"gpt-4\")"))
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_metric_record("input_sequence_length", 29)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue