From 7331c415aaa1b40501e55a06e9b8faed8ac116a3 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Fri, 21 Mar 2025 14:46:45 -0700 Subject: [PATCH] use passed in model name in chat completion request --- arch/tools/cli/docker_cli.py | 1 + crates/common/src/ratelimit.rs | 4 +- crates/common/src/tokenizer.rs | 25 ++- crates/llm_gateway/src/stream_context.rs | 39 ++-- crates/llm_gateway/tests/integration.rs | 209 +++++++++++++++++- .../weather_forcecast_service/.project | 4 +- demos/use_cases/llm_routing/arch_config.yaml | 5 - ...url => llm_gateway_model_default_llm.hurl} | 11 +- .../llm_gateway_model_explicit_model.hurl | 25 +++ tests/hurl/llm_gateway_model_hint.hurl | 25 +++ 10 files changed, 299 insertions(+), 49 deletions(-) rename tests/hurl/{llm_gateway_simple.hurl => llm_gateway_model_default_llm.hurl} (51%) create mode 100644 tests/hurl/llm_gateway_model_explicit_model.hurl create mode 100644 tests/hurl/llm_gateway_model_hint.hurl diff --git a/arch/tools/cli/docker_cli.py b/arch/tools/cli/docker_cli.py index 59a0553c..6edfb8dc 100644 --- a/arch/tools/cli/docker_cli.py +++ b/arch/tools/cli/docker_cli.py @@ -56,6 +56,7 @@ def docker_start_archgw_detached( volume_mappings = [ f"{logs_path_abs}:/var/log:rw", f"{arch_config_file}:/app/arch_config.yaml:ro", + # "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro", ] volume_mappings_args = [ item for volume in volume_mappings for item in ("-v", volume) diff --git a/crates/common/src/ratelimit.rs b/crates/common/src/ratelimit.rs index 66c3facd..39a79b9d 100644 --- a/crates/common/src/ratelimit.rs +++ b/crates/common/src/ratelimit.rs @@ -1,7 +1,7 @@ use crate::configuration; use configuration::{Limit, Ratelimit, TimeUnit}; use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota}; -use log::debug; +use log::trace; use std::fmt::Display; use std::num::{NonZero, NonZeroU32}; use std::sync::RwLock; @@ -99,7 +99,7 @@ impl RatelimitMap { selector: Header, tokens_used: NonZeroU32, ) -> Result<(), Error> { - debug!( + trace!( "Checking limit for provider={}, with selector={:?}, consuming tokens={:?}", provider, selector, tokens_used ); diff --git a/crates/common/src/tokenizer.rs b/crates/common/src/tokenizer.rs index c424e344..11aada32 100644 --- a/crates/common/src/tokenizer.rs +++ b/crates/common/src/tokenizer.rs @@ -10,9 +10,24 @@ pub enum Error { #[allow(dead_code)] pub fn token_count(model_name: &str, text: &str) -> Result { trace!("getting token count model={}", model_name); + //HACK: add support for tokenizing mistral and other models + //filed issue https://github.com/katanemo/arch/issues/222 + + let updated_model = match model_name.starts_with("gpt") { + false => { + trace!( + "tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count", + model_name + ); + + "gpt-4" + } + true => model_name, + }; + // Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton? - let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel { - model_name: model_name.to_string(), + let bpe = tiktoken_rs::get_bpe_from_model(updated_model).map_err(|_| Error::UnknownModel { + model_name: updated_model.to_string(), })?; Ok(bpe.encode_ordinary(text).len()) } @@ -34,10 +49,8 @@ mod test { #[test] fn unrecognized_model() { assert_eq!( - Error::UnknownModel { - model_name: "unknown".to_string() - }, - token_count("unknown", "").expect_err("unknown model") + 2, + token_count("unknown model", "hello world").expect("correct tokenization") ) } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index b4a05575..d51aba23 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -89,7 +89,7 @@ impl StreamContext { provider_hint, )); - debug!( + trace!( "request received: llm provider hint: {}, selected llm: {}, model: {}", self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER) .unwrap_or_default(), @@ -167,7 +167,7 @@ impl StreamContext { // Check if rate limiting needs to be applied. if let Some(selector) = self.ratelimit_selector.take() { - log::debug!("Applying ratelimit for model: {}", model); + log::trace!("Applying ratelimit for model: {}", model); ratelimit::ratelimits(None).read().unwrap().check_limit( model.to_owned(), selector, @@ -301,21 +301,27 @@ impl HttpContext for StreamContext { let model_name = match self.llm_provider.as_ref() { Some(llm_provider) => match llm_provider.model.as_ref() { - Some(model) => model, - None => "--", + Some(model) => Some(model), + None => None, }, - None => "--", + None => None, }; - deserialized_body.model = model_name.to_string(); + let model_requested = deserialized_body.model.clone(); + if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" { + deserialized_body.model = model_name.unwrap().to_string(); + } + + debug!( + "provider: {:?}, model requested: {}, model selected: {:?}", + self.llm_provider().name, + model_requested, + model_name, + ); let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap(); - trace!( - "arch => {:?}, body: {}", - deserialized_body.model, - chat_completion_request_str - ); + trace!("request body: {}", chat_completion_request_str); if deserialized_body.stream { self.streaming_response = true; @@ -528,22 +534,13 @@ impl HttpContext for StreamContext { return Action::Continue; } - let mut model = chat_completions_chunk_response_events + let model = chat_completions_chunk_response_events .events .first() .unwrap() .model .clone(); let tokens_str = chat_completions_chunk_response_events.to_string(); - //HACK: add support for tokenizing mistral and other models - //filed issue https://github.com/katanemo/arch/issues/222 - if !model.as_ref().unwrap().starts_with("gpt") { - trace!( - "tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count", - model.as_ref().unwrap() - ); - } - model = Some("gpt-4".to_string()); let token_count = match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str()) diff --git a/crates/llm_gateway/tests/integration.rs b/crates/llm_gateway/tests/integration.rs index 71ec2e16..9d87c00c 100644 --- a/crates/llm_gateway/tests/integration.rs +++ b/crates/llm_gateway/tests/integration.rs @@ -30,7 +30,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) { Some("x-arch-llm-provider-hint"), ) .returning(None) - .expect_log(Some(LogLevel::Debug), Some("request received: llm provider hint: default, selected llm: open-ai-gpt-4, model: gpt-4")) + .expect_log(Some(LogLevel::Trace), Some("request received: llm provider hint: default, selected llm: open-ai-gpt-4, model: gpt-4")) .expect_add_header_map_value( Some(MapType::HttpRequestHeaders), Some("x-arch-llm-provider"), @@ -225,12 +225,13 @@ fn llm_gateway_successful_request_to_open_ai_chat_completions() { ) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) + .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None) .expect_metric_record("input_sequence_length", 21) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); @@ -287,7 +288,7 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() { ) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(incomplete_chat_completions_request_body)) - .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Debug), None) .expect_send_local_response( Some(StatusCode::BAD_REQUEST.as_u16().into()), None, @@ -296,9 +297,10 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() { ) .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) .expect_metric_record("input_sequence_length", 14) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); } @@ -351,13 +353,14 @@ fn llm_gateway_request_ratelimited() { .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) // The actual call is not important in this test, we just need to grab the token_id + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None) .expect_metric_record("input_sequence_length", 107) .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Debug), Some("server error occurred: exceeded limit provider=gpt-4, selector=Header { key: \"selector-key\", value: \"selector-value\" }, tokens_used=107")) .expect_send_local_response( Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()), None, @@ -417,12 +420,196 @@ fn llm_gateway_request_not_ratelimited() { .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) // The actual call is not important in this test, we just need to grab the token_id + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None) .expect_metric_record("input_sequence_length", 29) .expect_log(Some(LogLevel::Trace), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) + .execute_and_expect(ReturnType::Action(Action::Continue)) + .unwrap(); +} + +#[test] +#[serial] +fn llm_gateway_override_model_name() { + let args = tester::MockSettings { + wasm_path: wasm_module(), + quiet: false, + allow_unexpected: false, + }; + let mut module = tester::mock(args).unwrap(); + + module + .call_start() + .execute_and_expect(ReturnType::None) + .unwrap(); + + // Setup Filter + let filter_context = setup_filter(&mut module, default_config()); + + // Setup HTTP Stream + let http_context = 2; + + normal_flow(&mut module, filter_context, http_context); + + // give shorter body to avoid rate limiting + let chat_completions_request_body = "\ +{\ + \"model\": \"o1-mini\",\ + \"messages\": [\ + {\ + \"role\": \"system\",\ + \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ + },\ + {\ + \"role\": \"user\",\ + \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ + }\ + ] +}"; + + module + .call_proxy_on_request_body( + http_context, + chat_completions_request_body.len() as i32, + true, + ) + .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) + .returning(Some(chat_completions_request_body)) + // The actual call is not important in this test, we just need to grab the token_id + .expect_log(Some(LogLevel::Debug), Some("provider: \"open-ai-gpt-4\", model requested: o1-mini, model selected: Some(\"gpt-4\")")) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_metric_record("input_sequence_length", 29) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) + .execute_and_expect(ReturnType::Action(Action::Continue)) + .unwrap(); +} + +#[test] +#[serial] +fn llm_gateway_override_use_default_model() { + let args = tester::MockSettings { + wasm_path: wasm_module(), + quiet: false, + allow_unexpected: false, + }; + let mut module = tester::mock(args).unwrap(); + + module + .call_start() + .execute_and_expect(ReturnType::None) + .unwrap(); + + // Setup Filter + let filter_context = setup_filter(&mut module, default_config()); + + // Setup HTTP Stream + let http_context = 2; + + normal_flow(&mut module, filter_context, http_context); + + // give shorter body to avoid rate limiting + let chat_completions_request_body = "\ +{\ + \"messages\": [\ + {\ + \"role\": \"system\",\ + \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ + },\ + {\ + \"role\": \"user\",\ + \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ + }\ + ] +}"; + + module + .call_proxy_on_request_body( + http_context, + chat_completions_request_body.len() as i32, + true, + ) + .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) + .returning(Some(chat_completions_request_body)) + // The actual call is not important in this test, we just need to grab the token_id + .expect_log( + Some(LogLevel::Debug), + Some("provider: \"open-ai-gpt-4\", model requested: , model selected: Some(\"gpt-4\")"), + ) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_metric_record("input_sequence_length", 29) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) + .execute_and_expect(ReturnType::Action(Action::Continue)) + .unwrap(); +} + +#[test] +#[serial] +fn llm_gateway_override_use_model_name_none() { + let args = tester::MockSettings { + wasm_path: wasm_module(), + quiet: false, + allow_unexpected: false, + }; + let mut module = tester::mock(args).unwrap(); + + module + .call_start() + .execute_and_expect(ReturnType::None) + .unwrap(); + + // Setup Filter + let filter_context = setup_filter(&mut module, default_config()); + + // Setup HTTP Stream + let http_context = 2; + + normal_flow(&mut module, filter_context, http_context); + + // give shorter body to avoid rate limiting + let chat_completions_request_body = "\ +{\ + \"model\": \"none\",\ + \"messages\": [\ + {\ + \"role\": \"system\",\ + \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ + },\ + {\ + \"role\": \"user\",\ + \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ + }\ + ] +}"; + + module + .call_proxy_on_request_body( + http_context, + chat_completions_request_body.len() as i32, + true, + ) + .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) + .returning(Some(chat_completions_request_body)) + // The actual call is not important in this test, we just need to grab the token_id + .expect_log(Some(LogLevel::Debug), Some("provider: \"open-ai-gpt-4\", model requested: none, model selected: Some(\"gpt-4\")")) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_metric_record("input_sequence_length", 29) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Trace), None) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); diff --git a/demos/samples_java/weather_forcecast_service/.project b/demos/samples_java/weather_forcecast_service/.project index 51363145..efcdddf7 100644 --- a/demos/samples_java/weather_forcecast_service/.project +++ b/demos/samples_java/weather_forcecast_service/.project @@ -22,12 +22,12 @@ - 1739479945236 + 1742579142020 30 org.eclipse.core.resources.regexFilterMatcher - node_modules|.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__ + node_modules|\.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__ diff --git a/demos/use_cases/llm_routing/arch_config.yaml b/demos/use_cases/llm_routing/arch_config.yaml index 4cccd718..289d8bf2 100644 --- a/demos/use_cases/llm_routing/arch_config.yaml +++ b/demos/use_cases/llm_routing/arch_config.yaml @@ -14,11 +14,6 @@ llm_providers: model: gpt-4o-mini default: true - - name: gpt-3.5-turbo-0125 - access_key: $OPENAI_API_KEY - provider_interface: openai - model: gpt-3.5-turbo-0125 - - name: gpt-4o access_key: $OPENAI_API_KEY provider_interface: openai diff --git a/tests/hurl/llm_gateway_simple.hurl b/tests/hurl/llm_gateway_model_default_llm.hurl similarity index 51% rename from tests/hurl/llm_gateway_simple.hurl rename to tests/hurl/llm_gateway_model_default_llm.hurl index f16e0074..0bca20c8 100644 --- a/tests/hurl/llm_gateway_simple.hurl +++ b/tests/hurl/llm_gateway_model_default_llm.hurl @@ -2,7 +2,6 @@ POST http://localhost:12000/v1/chat/completions Content-Type: application/json { - "model": "--", "messages": [ { "role": "system", @@ -13,5 +12,13 @@ Content-Type: application/json "content": "I want to sell red shoes" } ], - "stream": true + "stream": false } + +HTTP 200 +[Asserts] +header "content-type" == "application/json" +jsonpath "$.model" matches /^gpt-4o-mini/ +jsonpath "$.usage" != null +jsonpath "$.choices[0].message.content" != null +jsonpath "$.choices[0].message.role" == "assistant" diff --git a/tests/hurl/llm_gateway_model_explicit_model.hurl b/tests/hurl/llm_gateway_model_explicit_model.hurl new file mode 100644 index 00000000..8ba7ccd1 --- /dev/null +++ b/tests/hurl/llm_gateway_model_explicit_model.hurl @@ -0,0 +1,25 @@ +POST http://localhost:12000/v1/chat/completions +Content-Type: application/json + +{ + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant.\n" + }, + { + "role": "user", + "content": "I want to sell red shoes" + } + ], + "stream": false +} + +HTTP 200 +[Asserts] +header "content-type" == "application/json" +jsonpath "$.model" matches /^gpt-3.5-turbo/ +jsonpath "$.usage" != null +jsonpath "$.choices[0].message.content" != null +jsonpath "$.choices[0].message.role" == "assistant" diff --git a/tests/hurl/llm_gateway_model_hint.hurl b/tests/hurl/llm_gateway_model_hint.hurl new file mode 100644 index 00000000..9f45c670 --- /dev/null +++ b/tests/hurl/llm_gateway_model_hint.hurl @@ -0,0 +1,25 @@ +POST http://localhost:12000/v1/chat/completions +Content-Type: application/json +x-arch-llm-provider-hint: gpt-4o + +{ + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant.\n" + }, + { + "role": "user", + "content": "I want to sell red shoes" + } + ], + "stream": false +} + +HTTP 200 +[Asserts] +header "content-type" == "application/json" +jsonpath "$.model" matches /^gpt-4o-2/ +jsonpath "$.usage" != null +jsonpath "$.choices[0].message.content" != null +jsonpath "$.choices[0].message.role" == "assistant"