use passed in model name in chat completion request (#445)

This commit is contained in:
Adil Hafeez 2025-03-21 15:56:17 -07:00 committed by GitHub
parent bd8004d1ae
commit eb48f3d5bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 364 additions and 89 deletions

View file

@ -249,7 +249,7 @@ client = OpenAI(
response = client.chat.completions.create(
# we select model from arch_config file
model="--",
model="None",
messages=[{"role": "user", "content": "What is the capital of France?"}],
)

View file

@ -56,6 +56,7 @@ def docker_start_archgw_detached(
volume_mappings = [
f"{logs_path_abs}:/var/log:rw",
f"{arch_config_file}:/app/arch_config.yaml:ro",
# "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro",
]
volume_mappings_args = [
item for volume in volume_mappings for item in ("-v", volume)

View file

@ -1,7 +1,7 @@
use crate::configuration;
use configuration::{Limit, Ratelimit, TimeUnit};
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
use log::debug;
use log::trace;
use std::fmt::Display;
use std::num::{NonZero, NonZeroU32};
use std::sync::RwLock;
@ -99,9 +99,11 @@ impl RatelimitMap {
selector: Header,
tokens_used: NonZeroU32,
) -> Result<(), Error> {
debug!(
trace!(
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
provider, selector, tokens_used
provider,
selector,
tokens_used
);
let provider_limits = match self.datastore.get(&provider) {

View file

@ -1,19 +1,25 @@
use log::trace;
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
#[allow(dead_code)]
pub enum Error {
#[error("Unknown model: {model_name}")]
UnknownModel { model_name: String },
}
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
pub fn token_count(model_name: &str, text: &str) -> Result<usize, String> {
trace!("getting token count model={}", model_name);
//HACK: add support for tokenizing mistral and other models
//filed issue https://github.com/katanemo/arch/issues/222
let updated_model = match model_name.starts_with("gpt") {
false => {
trace!(
"tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
model_name
);
"gpt-4"
}
true => model_name,
};
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel {
model_name: model_name.to_string(),
})?;
let bpe = tiktoken_rs::get_bpe_from_model(updated_model).map_err(|e| e.to_string())?;
Ok(bpe.encode_ordinary(text).len())
}
@ -30,14 +36,4 @@ mod test {
token_count(model_name, text).expect("correct tokenization")
);
}
#[test]
fn unrecognized_model() {
assert_eq!(
Error::UnknownModel {
model_name: "unknown".to_string()
},
token_count("unknown", "").expect_err("unknown model")
)
}
}

View file

@ -89,7 +89,7 @@ impl StreamContext {
provider_hint,
));
debug!(
trace!(
"request received: llm provider hint: {}, selected llm: {}, model: {}",
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
.unwrap_or_default(),
@ -167,7 +167,7 @@ impl StreamContext {
// Check if rate limiting needs to be applied.
if let Some(selector) = self.ratelimit_selector.take() {
log::debug!("Applying ratelimit for model: {}", model);
log::trace!("Applying ratelimit for model: {}", model);
ratelimit::ratelimits(None).read().unwrap().check_limit(
model.to_owned(),
selector,
@ -300,22 +300,45 @@ impl HttpContext for StreamContext {
.cloned();
let model_name = match self.llm_provider.as_ref() {
Some(llm_provider) => match llm_provider.model.as_ref() {
Some(model) => model,
None => "--",
},
None => "--",
Some(llm_provider) => llm_provider.model.as_ref(),
None => None,
};
deserialized_body.model = model_name.to_string();
let use_agent_orchestrator = match self.overrides.as_ref() {
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
None => false,
};
let model_requested = deserialized_body.model.clone();
if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
deserialized_body.model = match model_name {
Some(model_name) => model_name.clone(),
None => {
if use_agent_orchestrator {
"agent_orchestrator".to_string()
} else {
self.send_server_error(
ServerError::BadRequest {
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
},
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
}
}
}
debug!(
"provider: {:?}, model requested: {}, model selected: {:?}",
self.llm_provider().name,
model_requested,
model_name,
);
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
trace!(
"arch => {:?}, body: {}",
deserialized_body.model,
chat_completion_request_str
);
trace!("request body: {}", chat_completion_request_str);
if deserialized_body.stream {
self.streaming_response = true;
@ -528,22 +551,13 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
let mut model = chat_completions_chunk_response_events
let model = chat_completions_chunk_response_events
.events
.first()
.unwrap()
.model
.clone();
let tokens_str = chat_completions_chunk_response_events.to_string();
//HACK: add support for tokenizing mistral and other models
//filed issue https://github.com/katanemo/arch/issues/222
if !model.as_ref().unwrap().starts_with("gpt") {
trace!(
"tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
model.as_ref().unwrap()
);
}
model = Some("gpt-4".to_string());
let token_count =
match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str())

View file

@ -30,7 +30,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
Some("x-arch-llm-provider-hint"),
)
.returning(None)
.expect_log(Some(LogLevel::Debug), Some("request received: llm provider hint: default, selected llm: open-ai-gpt-4, model: gpt-4"))
.expect_log(Some(LogLevel::Trace), Some("request received: llm provider hint: default, selected llm: open-ai-gpt-4, model: gpt-4"))
.expect_add_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider"),
@ -225,12 +225,13 @@ fn llm_gateway_successful_request_to_open_ai_chat_completions() {
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 21)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
@ -287,7 +288,7 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(incomplete_chat_completions_request_body))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response(
Some(StatusCode::BAD_REQUEST.as_u16().into()),
None,
@ -296,9 +297,10 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 14)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
@ -351,13 +353,14 @@ fn llm_gateway_request_ratelimited() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 107)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), Some("server error occurred: exceeded limit provider=gpt-4, selector=Header { key: \"selector-key\", value: \"selector-value\" }, tokens_used=107"))
.expect_send_local_response(
Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()),
None,
@ -417,12 +420,196 @@ fn llm_gateway_request_not_ratelimited() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 29)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
#[test]
#[serial]
fn llm_gateway_override_model_name() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
// give shorter body to avoid rate limiting
let chat_completions_request_body = "\
{\
\"model\": \"o1-mini\",\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
]
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), Some("provider: \"open-ai-gpt-4\", model requested: o1-mini, model selected: Some(\"gpt-4\")"))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 29)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
#[test]
#[serial]
fn llm_gateway_override_use_default_model() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
// give shorter body to avoid rate limiting
let chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
]
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(
Some(LogLevel::Debug),
Some("provider: \"open-ai-gpt-4\", model requested: , model selected: Some(\"gpt-4\")"),
)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 29)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
#[test]
#[serial]
fn llm_gateway_override_use_model_name_none() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
// give shorter body to avoid rate limiting
let chat_completions_request_body = "\
{\
\"model\": \"none\",\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
]
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), Some("provider: \"open-ai-gpt-4\", model requested: none, model selected: Some(\"gpt-4\")"))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 29)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();

View file

@ -45,7 +45,8 @@ impl HttpContext for StreamContext {
warn!("Need single endpoint when use_agent_orchestrator is set");
self.send_server_error(
ServerError::LogicError(
"Need single endpoint when use_agent_orchestrator is set".to_string(),
"Need single endpoint when use_agent_orchestrator is set"
.to_string(),
),
None,
);
@ -190,7 +191,7 @@ impl HttpContext for StreamContext {
messages: deserialized_body.messages.clone(),
metadata,
stream: deserialized_body.stream,
model: "--".to_string(),
model: deserialized_body.model.clone(),
stream_options: deserialized_body.stream_options.clone(),
tools: Some(tool_calls),
};

View file

@ -427,7 +427,6 @@ impl StreamContext {
headers.insert(key.as_str(), value.as_str());
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&path,
@ -499,10 +498,7 @@ impl StreamContext {
}
};
if !prompt_target
.auto_llm_dispatch_on_response
.unwrap_or(true)
{
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(true) {
let tool_call_response = self.tool_call_response.as_ref().unwrap().clone();
let direct_response_str = if self.streaming_response {
@ -655,10 +651,7 @@ impl StreamContext {
.clone();
// check if the default target should be dispatched to the LLM provider
if !prompt_target
.auto_llm_dispatch_on_response
.unwrap_or(true)
{
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(true) {
let default_target_response_str = if self.streaming_response {
let chat_completion_response =
match serde_json::from_slice::<ChatCompletionsResponse>(&body) {

View file

@ -22,12 +22,12 @@
</natures>
<filteredResources>
<filter>
<id>1739479945236</id>
<id>1742579142020</id>
<name></name>
<type>30</type>
<matcher>
<id>org.eclipse.core.resources.regexFilterMatcher</id>
<arguments>node_modules|.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__</arguments>
<arguments>node_modules|\.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__</arguments>
</matcher>
</filter>
</filteredResources>

View file

@ -38,7 +38,7 @@ def chat(
try:
response = client.chat.completions.create(
# we select model from arch_config file
model="--",
model="None",
messages=history,
temperature=1.0,
stream=True,

View file

@ -54,13 +54,13 @@ def chat(
if model_selector and model_selector != "":
headers["x-arch-llm-provider-hint"] = model_selector
client = OpenAI(
api_key="--",
api_key="None",
base_url=CHAT_COMPLETION_ENDPOINT,
default_headers=headers,
)
response = client.chat.completions.create(
# we select model from arch_config file
model="--",
model="None",
messages=history,
temperature=1.0,
stream=True,

View file

@ -14,11 +14,6 @@ llm_providers:
model: gpt-4o-mini
default: true
- name: gpt-3.5-turbo-0125
access_key: $OPENAI_API_KEY
provider_interface: openai
model: gpt-3.5-turbo-0125
- name: gpt-4o
access_key: $OPENAI_API_KEY
provider_interface: openai

View file

@ -0,0 +1,19 @@
POST http://localhost:10000/v1/chat/completions
Content-Type: application/json
{
"messages": [
{
"role": "user",
"content": "I bought a package recently and it not working properly"
}
]
}
HTTP 200
[Asserts]
header "content-type" == "application/json"
jsonpath "$.model" matches /^gpt-4o-2/
jsonpath "$.metadata.x-arch-state" != null
jsonpath "$.usage" != null
jsonpath "$.choices[0].message.content" != null
jsonpath "$.choices[0].message.role" == "assistant"

View file

@ -31,9 +31,10 @@ openai_client = openai.OpenAI(
)
def call_openai(messages: List[Dict[str, str]], stream: bool):
def call_openai(messages: List[Dict[str, str]], stream: bool, model: str):
logger.info(f"llm agent model: {model}")
completion = openai_client.chat.completions.create(
model="None", # archgw picks the default LLM configured in the config file
model=model,
messages=messages,
stream=stream,
)
@ -53,14 +54,19 @@ def call_openai(messages: List[Dict[str, str]], stream: bool):
class Agent:
def __init__(self, role: str, instructions: str):
def __init__(self, role: str, instructions: str, model: str = ""):
self.model = model
self.system_prompt = f"You are a {role}.\n{instructions}"
def handle(self, req: ChatCompletionsRequest):
messages = [{"role": "system", "content": self.get_system_prompt()}] + [
message.model_dump() for message in req.messages
]
return call_openai(messages, req.stream)
model = req.model
if self.model:
model = self.model
return call_openai(messages, req.stream, model)
def get_system_prompt(self) -> str:
return self.system_prompt
@ -77,13 +83,17 @@ AGENTS = {
"2. Quote ridiculous price\n"
"3. Reveal caveat if user agrees."
),
model="gpt-4o-mini",
),
"issues_and_repairs": Agent(
role="issues and repairs agent",
instructions="Propose a solution, offer refund if necessary.",
model="gpt-4o",
),
"escalate_to_human": Agent(
role="human escalation agent", instructions="Escalate issues to a human."
role="human escalation agent",
instructions="Escalate issues to a human.",
# skipping model name here as arch gateway will pick the default model from the config file
),
"unknown_agent": Agent(
role="general assistant", instructions="Assist the user in general queries."

View file

@ -2,7 +2,6 @@ POST http://localhost:12000/v1/chat/completions
Content-Type: application/json
{
"model": "--",
"messages": [
{
"role": "system",
@ -13,5 +12,13 @@ Content-Type: application/json
"content": "I want to sell red shoes"
}
],
"stream": true
"stream": false
}
HTTP 200
[Asserts]
header "content-type" == "application/json"
jsonpath "$.model" matches /^gpt-4o-mini/
jsonpath "$.usage" != null
jsonpath "$.choices[0].message.content" != null
jsonpath "$.choices[0].message.role" == "assistant"

View file

@ -0,0 +1,25 @@
POST http://localhost:12000/v1/chat/completions
Content-Type: application/json
{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant.\n"
},
{
"role": "user",
"content": "I want to sell red shoes"
}
],
"stream": false
}
HTTP 200
[Asserts]
header "content-type" == "application/json"
jsonpath "$.model" matches /^gpt-3.5-turbo/
jsonpath "$.usage" != null
jsonpath "$.choices[0].message.content" != null
jsonpath "$.choices[0].message.role" == "assistant"

View file

@ -0,0 +1,25 @@
POST http://localhost:12000/v1/chat/completions
Content-Type: application/json
x-arch-llm-provider-hint: gpt-4o
{
"messages": [
{
"role": "system",
"content": "You are a helpful assistant.\n"
},
{
"role": "user",
"content": "I want to sell red shoes"
}
],
"stream": false
}
HTTP 200
[Asserts]
header "content-type" == "application/json"
jsonpath "$.model" matches /^gpt-4o-2/
jsonpath "$.usage" != null
jsonpath "$.choices[0].message.content" != null
jsonpath "$.choices[0].message.role" == "assistant"

View file

@ -238,7 +238,7 @@ POST {{model_server_endpoint}}/function_calling HTTP/1.1
Content-Type: application/json
{
"model": "--",
"model": "None",
"messages": [
{
"role": "user",

View file

@ -82,7 +82,7 @@ POST {{prompt_endpoint}}/v1/chat/completions HTTP/1.1
Content-Type: application/json
{
"model": "--",
"model": "None",
"messages": [
{
"role": "user",