mirror of
https://github.com/katanemo/plano.git
synced 2026-05-24 14:05:14 +02:00
use passed in model name in chat completion request (#445)
This commit is contained in:
parent
bd8004d1ae
commit
eb48f3d5bb
20 changed files with 364 additions and 89 deletions
|
|
@ -249,7 +249,7 @@ client = OpenAI(
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
# we select model from arch_config file
|
# we select model from arch_config file
|
||||||
model="--",
|
model="None",
|
||||||
messages=[{"role": "user", "content": "What is the capital of France?"}],
|
messages=[{"role": "user", "content": "What is the capital of France?"}],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,7 @@ def docker_start_archgw_detached(
|
||||||
volume_mappings = [
|
volume_mappings = [
|
||||||
f"{logs_path_abs}:/var/log:rw",
|
f"{logs_path_abs}:/var/log:rw",
|
||||||
f"{arch_config_file}:/app/arch_config.yaml:ro",
|
f"{arch_config_file}:/app/arch_config.yaml:ro",
|
||||||
|
# "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro",
|
||||||
]
|
]
|
||||||
volume_mappings_args = [
|
volume_mappings_args = [
|
||||||
item for volume in volume_mappings for item in ("-v", volume)
|
item for volume in volume_mappings for item in ("-v", volume)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
use crate::configuration;
|
use crate::configuration;
|
||||||
use configuration::{Limit, Ratelimit, TimeUnit};
|
use configuration::{Limit, Ratelimit, TimeUnit};
|
||||||
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
|
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
|
||||||
use log::debug;
|
use log::trace;
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
use std::num::{NonZero, NonZeroU32};
|
use std::num::{NonZero, NonZeroU32};
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
|
|
@ -99,9 +99,11 @@ impl RatelimitMap {
|
||||||
selector: Header,
|
selector: Header,
|
||||||
tokens_used: NonZeroU32,
|
tokens_used: NonZeroU32,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
debug!(
|
trace!(
|
||||||
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
|
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
|
||||||
provider, selector, tokens_used
|
provider,
|
||||||
|
selector,
|
||||||
|
tokens_used
|
||||||
);
|
);
|
||||||
|
|
||||||
let provider_limits = match self.datastore.get(&provider) {
|
let provider_limits = match self.datastore.get(&provider) {
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,25 @@
|
||||||
use log::trace;
|
use log::trace;
|
||||||
|
|
||||||
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub enum Error {
|
pub fn token_count(model_name: &str, text: &str) -> Result<usize, String> {
|
||||||
#[error("Unknown model: {model_name}")]
|
|
||||||
UnknownModel { model_name: String },
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
|
|
||||||
trace!("getting token count model={}", model_name);
|
trace!("getting token count model={}", model_name);
|
||||||
|
//HACK: add support for tokenizing mistral and other models
|
||||||
|
//filed issue https://github.com/katanemo/arch/issues/222
|
||||||
|
|
||||||
|
let updated_model = match model_name.starts_with("gpt") {
|
||||||
|
false => {
|
||||||
|
trace!(
|
||||||
|
"tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
|
||||||
|
model_name
|
||||||
|
);
|
||||||
|
|
||||||
|
"gpt-4"
|
||||||
|
}
|
||||||
|
true => model_name,
|
||||||
|
};
|
||||||
|
|
||||||
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
|
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
|
||||||
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel {
|
let bpe = tiktoken_rs::get_bpe_from_model(updated_model).map_err(|e| e.to_string())?;
|
||||||
model_name: model_name.to_string(),
|
|
||||||
})?;
|
|
||||||
Ok(bpe.encode_ordinary(text).len())
|
Ok(bpe.encode_ordinary(text).len())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -30,14 +36,4 @@ mod test {
|
||||||
token_count(model_name, text).expect("correct tokenization")
|
token_count(model_name, text).expect("correct tokenization")
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn unrecognized_model() {
|
|
||||||
assert_eq!(
|
|
||||||
Error::UnknownModel {
|
|
||||||
model_name: "unknown".to_string()
|
|
||||||
},
|
|
||||||
token_count("unknown", "").expect_err("unknown model")
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -89,7 +89,7 @@ impl StreamContext {
|
||||||
provider_hint,
|
provider_hint,
|
||||||
));
|
));
|
||||||
|
|
||||||
debug!(
|
trace!(
|
||||||
"request received: llm provider hint: {}, selected llm: {}, model: {}",
|
"request received: llm provider hint: {}, selected llm: {}, model: {}",
|
||||||
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||||
.unwrap_or_default(),
|
.unwrap_or_default(),
|
||||||
|
|
@ -167,7 +167,7 @@ impl StreamContext {
|
||||||
|
|
||||||
// Check if rate limiting needs to be applied.
|
// Check if rate limiting needs to be applied.
|
||||||
if let Some(selector) = self.ratelimit_selector.take() {
|
if let Some(selector) = self.ratelimit_selector.take() {
|
||||||
log::debug!("Applying ratelimit for model: {}", model);
|
log::trace!("Applying ratelimit for model: {}", model);
|
||||||
ratelimit::ratelimits(None).read().unwrap().check_limit(
|
ratelimit::ratelimits(None).read().unwrap().check_limit(
|
||||||
model.to_owned(),
|
model.to_owned(),
|
||||||
selector,
|
selector,
|
||||||
|
|
@ -300,22 +300,45 @@ impl HttpContext for StreamContext {
|
||||||
.cloned();
|
.cloned();
|
||||||
|
|
||||||
let model_name = match self.llm_provider.as_ref() {
|
let model_name = match self.llm_provider.as_ref() {
|
||||||
Some(llm_provider) => match llm_provider.model.as_ref() {
|
Some(llm_provider) => llm_provider.model.as_ref(),
|
||||||
Some(model) => model,
|
None => None,
|
||||||
None => "--",
|
|
||||||
},
|
|
||||||
None => "--",
|
|
||||||
};
|
};
|
||||||
|
|
||||||
deserialized_body.model = model_name.to_string();
|
let use_agent_orchestrator = match self.overrides.as_ref() {
|
||||||
|
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
|
||||||
|
None => false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let model_requested = deserialized_body.model.clone();
|
||||||
|
if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
|
||||||
|
deserialized_body.model = match model_name {
|
||||||
|
Some(model_name) => model_name.clone(),
|
||||||
|
None => {
|
||||||
|
if use_agent_orchestrator {
|
||||||
|
"agent_orchestrator".to_string()
|
||||||
|
} else {
|
||||||
|
self.send_server_error(
|
||||||
|
ServerError::BadRequest {
|
||||||
|
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
|
||||||
|
},
|
||||||
|
Some(StatusCode::BAD_REQUEST),
|
||||||
|
);
|
||||||
|
return Action::Continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"provider: {:?}, model requested: {}, model selected: {:?}",
|
||||||
|
self.llm_provider().name,
|
||||||
|
model_requested,
|
||||||
|
model_name,
|
||||||
|
);
|
||||||
|
|
||||||
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
|
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
|
||||||
|
|
||||||
trace!(
|
trace!("request body: {}", chat_completion_request_str);
|
||||||
"arch => {:?}, body: {}",
|
|
||||||
deserialized_body.model,
|
|
||||||
chat_completion_request_str
|
|
||||||
);
|
|
||||||
|
|
||||||
if deserialized_body.stream {
|
if deserialized_body.stream {
|
||||||
self.streaming_response = true;
|
self.streaming_response = true;
|
||||||
|
|
@ -528,22 +551,13 @@ impl HttpContext for StreamContext {
|
||||||
return Action::Continue;
|
return Action::Continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut model = chat_completions_chunk_response_events
|
let model = chat_completions_chunk_response_events
|
||||||
.events
|
.events
|
||||||
.first()
|
.first()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.model
|
.model
|
||||||
.clone();
|
.clone();
|
||||||
let tokens_str = chat_completions_chunk_response_events.to_string();
|
let tokens_str = chat_completions_chunk_response_events.to_string();
|
||||||
//HACK: add support for tokenizing mistral and other models
|
|
||||||
//filed issue https://github.com/katanemo/arch/issues/222
|
|
||||||
if !model.as_ref().unwrap().starts_with("gpt") {
|
|
||||||
trace!(
|
|
||||||
"tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
|
|
||||||
model.as_ref().unwrap()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
model = Some("gpt-4".to_string());
|
|
||||||
|
|
||||||
let token_count =
|
let token_count =
|
||||||
match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str())
|
match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str())
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
||||||
Some("x-arch-llm-provider-hint"),
|
Some("x-arch-llm-provider-hint"),
|
||||||
)
|
)
|
||||||
.returning(None)
|
.returning(None)
|
||||||
.expect_log(Some(LogLevel::Debug), Some("request received: llm provider hint: default, selected llm: open-ai-gpt-4, model: gpt-4"))
|
.expect_log(Some(LogLevel::Trace), Some("request received: llm provider hint: default, selected llm: open-ai-gpt-4, model: gpt-4"))
|
||||||
.expect_add_header_map_value(
|
.expect_add_header_map_value(
|
||||||
Some(MapType::HttpRequestHeaders),
|
Some(MapType::HttpRequestHeaders),
|
||||||
Some("x-arch-llm-provider"),
|
Some("x-arch-llm-provider"),
|
||||||
|
|
@ -225,12 +225,13 @@ fn llm_gateway_successful_request_to_open_ai_chat_completions() {
|
||||||
)
|
)
|
||||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||||
.returning(Some(chat_completions_request_body))
|
.returning(Some(chat_completions_request_body))
|
||||||
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_log(Some(LogLevel::Trace), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Trace), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Trace), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_metric_record("input_sequence_length", 21)
|
.expect_metric_record("input_sequence_length", 21)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -287,7 +288,7 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
|
||||||
)
|
)
|
||||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||||
.returning(Some(incomplete_chat_completions_request_body))
|
.returning(Some(incomplete_chat_completions_request_body))
|
||||||
.expect_log(Some(LogLevel::Trace), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_send_local_response(
|
.expect_send_local_response(
|
||||||
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
||||||
None,
|
None,
|
||||||
|
|
@ -296,9 +297,10 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
|
||||||
)
|
)
|
||||||
.expect_log(Some(LogLevel::Trace), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Trace), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_metric_record("input_sequence_length", 14)
|
.expect_metric_record("input_sequence_length", 14)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
@ -351,13 +353,14 @@ fn llm_gateway_request_ratelimited() {
|
||||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||||
.returning(Some(chat_completions_request_body))
|
.returning(Some(chat_completions_request_body))
|
||||||
// The actual call is not important in this test, we just need to grab the token_id
|
// The actual call is not important in this test, we just need to grab the token_id
|
||||||
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Trace), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Trace), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_metric_record("input_sequence_length", 107)
|
.expect_metric_record("input_sequence_length", 107)
|
||||||
.expect_log(Some(LogLevel::Trace), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some("server error occurred: exceeded limit provider=gpt-4, selector=Header { key: \"selector-key\", value: \"selector-value\" }, tokens_used=107"))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
|
||||||
.expect_send_local_response(
|
.expect_send_local_response(
|
||||||
Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()),
|
Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()),
|
||||||
None,
|
None,
|
||||||
|
|
@ -417,12 +420,196 @@ fn llm_gateway_request_not_ratelimited() {
|
||||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||||
.returning(Some(chat_completions_request_body))
|
.returning(Some(chat_completions_request_body))
|
||||||
// The actual call is not important in this test, we just need to grab the token_id
|
// The actual call is not important in this test, we just need to grab the token_id
|
||||||
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Trace), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Trace), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_metric_record("input_sequence_length", 29)
|
.expect_metric_record("input_sequence_length", 29)
|
||||||
.expect_log(Some(LogLevel::Trace), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||||
|
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[serial]
|
||||||
|
fn llm_gateway_override_model_name() {
|
||||||
|
let args = tester::MockSettings {
|
||||||
|
wasm_path: wasm_module(),
|
||||||
|
quiet: false,
|
||||||
|
allow_unexpected: false,
|
||||||
|
};
|
||||||
|
let mut module = tester::mock(args).unwrap();
|
||||||
|
|
||||||
|
module
|
||||||
|
.call_start()
|
||||||
|
.execute_and_expect(ReturnType::None)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Setup Filter
|
||||||
|
let filter_context = setup_filter(&mut module, default_config());
|
||||||
|
|
||||||
|
// Setup HTTP Stream
|
||||||
|
let http_context = 2;
|
||||||
|
|
||||||
|
normal_flow(&mut module, filter_context, http_context);
|
||||||
|
|
||||||
|
// give shorter body to avoid rate limiting
|
||||||
|
let chat_completions_request_body = "\
|
||||||
|
{\
|
||||||
|
\"model\": \"o1-mini\",\
|
||||||
|
\"messages\": [\
|
||||||
|
{\
|
||||||
|
\"role\": \"system\",\
|
||||||
|
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
|
||||||
|
},\
|
||||||
|
{\
|
||||||
|
\"role\": \"user\",\
|
||||||
|
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||||
|
}\
|
||||||
|
]
|
||||||
|
}";
|
||||||
|
|
||||||
|
module
|
||||||
|
.call_proxy_on_request_body(
|
||||||
|
http_context,
|
||||||
|
chat_completions_request_body.len() as i32,
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||||
|
.returning(Some(chat_completions_request_body))
|
||||||
|
// The actual call is not important in this test, we just need to grab the token_id
|
||||||
|
.expect_log(Some(LogLevel::Debug), Some("provider: \"open-ai-gpt-4\", model requested: o1-mini, model selected: Some(\"gpt-4\")"))
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_metric_record("input_sequence_length", 29)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||||
|
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[serial]
|
||||||
|
fn llm_gateway_override_use_default_model() {
|
||||||
|
let args = tester::MockSettings {
|
||||||
|
wasm_path: wasm_module(),
|
||||||
|
quiet: false,
|
||||||
|
allow_unexpected: false,
|
||||||
|
};
|
||||||
|
let mut module = tester::mock(args).unwrap();
|
||||||
|
|
||||||
|
module
|
||||||
|
.call_start()
|
||||||
|
.execute_and_expect(ReturnType::None)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Setup Filter
|
||||||
|
let filter_context = setup_filter(&mut module, default_config());
|
||||||
|
|
||||||
|
// Setup HTTP Stream
|
||||||
|
let http_context = 2;
|
||||||
|
|
||||||
|
normal_flow(&mut module, filter_context, http_context);
|
||||||
|
|
||||||
|
// give shorter body to avoid rate limiting
|
||||||
|
let chat_completions_request_body = "\
|
||||||
|
{\
|
||||||
|
\"messages\": [\
|
||||||
|
{\
|
||||||
|
\"role\": \"system\",\
|
||||||
|
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
|
||||||
|
},\
|
||||||
|
{\
|
||||||
|
\"role\": \"user\",\
|
||||||
|
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||||
|
}\
|
||||||
|
]
|
||||||
|
}";
|
||||||
|
|
||||||
|
module
|
||||||
|
.call_proxy_on_request_body(
|
||||||
|
http_context,
|
||||||
|
chat_completions_request_body.len() as i32,
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||||
|
.returning(Some(chat_completions_request_body))
|
||||||
|
// The actual call is not important in this test, we just need to grab the token_id
|
||||||
|
.expect_log(
|
||||||
|
Some(LogLevel::Debug),
|
||||||
|
Some("provider: \"open-ai-gpt-4\", model requested: , model selected: Some(\"gpt-4\")"),
|
||||||
|
)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_metric_record("input_sequence_length", 29)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||||
|
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[serial]
|
||||||
|
fn llm_gateway_override_use_model_name_none() {
|
||||||
|
let args = tester::MockSettings {
|
||||||
|
wasm_path: wasm_module(),
|
||||||
|
quiet: false,
|
||||||
|
allow_unexpected: false,
|
||||||
|
};
|
||||||
|
let mut module = tester::mock(args).unwrap();
|
||||||
|
|
||||||
|
module
|
||||||
|
.call_start()
|
||||||
|
.execute_and_expect(ReturnType::None)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Setup Filter
|
||||||
|
let filter_context = setup_filter(&mut module, default_config());
|
||||||
|
|
||||||
|
// Setup HTTP Stream
|
||||||
|
let http_context = 2;
|
||||||
|
|
||||||
|
normal_flow(&mut module, filter_context, http_context);
|
||||||
|
|
||||||
|
// give shorter body to avoid rate limiting
|
||||||
|
let chat_completions_request_body = "\
|
||||||
|
{\
|
||||||
|
\"model\": \"none\",\
|
||||||
|
\"messages\": [\
|
||||||
|
{\
|
||||||
|
\"role\": \"system\",\
|
||||||
|
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
|
||||||
|
},\
|
||||||
|
{\
|
||||||
|
\"role\": \"user\",\
|
||||||
|
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
||||||
|
}\
|
||||||
|
]
|
||||||
|
}";
|
||||||
|
|
||||||
|
module
|
||||||
|
.call_proxy_on_request_body(
|
||||||
|
http_context,
|
||||||
|
chat_completions_request_body.len() as i32,
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||||
|
.returning(Some(chat_completions_request_body))
|
||||||
|
// The actual call is not important in this test, we just need to grab the token_id
|
||||||
|
.expect_log(Some(LogLevel::Debug), Some("provider: \"open-ai-gpt-4\", model requested: none, model selected: Some(\"gpt-4\")"))
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_metric_record("input_sequence_length", 29)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,8 @@ impl HttpContext for StreamContext {
|
||||||
warn!("Need single endpoint when use_agent_orchestrator is set");
|
warn!("Need single endpoint when use_agent_orchestrator is set");
|
||||||
self.send_server_error(
|
self.send_server_error(
|
||||||
ServerError::LogicError(
|
ServerError::LogicError(
|
||||||
"Need single endpoint when use_agent_orchestrator is set".to_string(),
|
"Need single endpoint when use_agent_orchestrator is set"
|
||||||
|
.to_string(),
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
@ -190,7 +191,7 @@ impl HttpContext for StreamContext {
|
||||||
messages: deserialized_body.messages.clone(),
|
messages: deserialized_body.messages.clone(),
|
||||||
metadata,
|
metadata,
|
||||||
stream: deserialized_body.stream,
|
stream: deserialized_body.stream,
|
||||||
model: "--".to_string(),
|
model: deserialized_body.model.clone(),
|
||||||
stream_options: deserialized_body.stream_options.clone(),
|
stream_options: deserialized_body.stream_options.clone(),
|
||||||
tools: Some(tool_calls),
|
tools: Some(tool_calls),
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -427,7 +427,6 @@ impl StreamContext {
|
||||||
headers.insert(key.as_str(), value.as_str());
|
headers.insert(key.as_str(), value.as_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
let call_args = CallArgs::new(
|
let call_args = CallArgs::new(
|
||||||
ARCH_INTERNAL_CLUSTER_NAME,
|
ARCH_INTERNAL_CLUSTER_NAME,
|
||||||
&path,
|
&path,
|
||||||
|
|
@ -499,10 +498,7 @@ impl StreamContext {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if !prompt_target
|
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(true) {
|
||||||
.auto_llm_dispatch_on_response
|
|
||||||
.unwrap_or(true)
|
|
||||||
{
|
|
||||||
let tool_call_response = self.tool_call_response.as_ref().unwrap().clone();
|
let tool_call_response = self.tool_call_response.as_ref().unwrap().clone();
|
||||||
|
|
||||||
let direct_response_str = if self.streaming_response {
|
let direct_response_str = if self.streaming_response {
|
||||||
|
|
@ -655,10 +651,7 @@ impl StreamContext {
|
||||||
.clone();
|
.clone();
|
||||||
|
|
||||||
// check if the default target should be dispatched to the LLM provider
|
// check if the default target should be dispatched to the LLM provider
|
||||||
if !prompt_target
|
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(true) {
|
||||||
.auto_llm_dispatch_on_response
|
|
||||||
.unwrap_or(true)
|
|
||||||
{
|
|
||||||
let default_target_response_str = if self.streaming_response {
|
let default_target_response_str = if self.streaming_response {
|
||||||
let chat_completion_response =
|
let chat_completion_response =
|
||||||
match serde_json::from_slice::<ChatCompletionsResponse>(&body) {
|
match serde_json::from_slice::<ChatCompletionsResponse>(&body) {
|
||||||
|
|
|
||||||
|
|
@ -22,12 +22,12 @@
|
||||||
</natures>
|
</natures>
|
||||||
<filteredResources>
|
<filteredResources>
|
||||||
<filter>
|
<filter>
|
||||||
<id>1739479945236</id>
|
<id>1742579142020</id>
|
||||||
<name></name>
|
<name></name>
|
||||||
<type>30</type>
|
<type>30</type>
|
||||||
<matcher>
|
<matcher>
|
||||||
<id>org.eclipse.core.resources.regexFilterMatcher</id>
|
<id>org.eclipse.core.resources.regexFilterMatcher</id>
|
||||||
<arguments>node_modules|.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__</arguments>
|
<arguments>node_modules|\.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__</arguments>
|
||||||
</matcher>
|
</matcher>
|
||||||
</filter>
|
</filter>
|
||||||
</filteredResources>
|
</filteredResources>
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ def chat(
|
||||||
try:
|
try:
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
# we select model from arch_config file
|
# we select model from arch_config file
|
||||||
model="--",
|
model="None",
|
||||||
messages=history,
|
messages=history,
|
||||||
temperature=1.0,
|
temperature=1.0,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
|
|
||||||
|
|
@ -54,13 +54,13 @@ def chat(
|
||||||
if model_selector and model_selector != "":
|
if model_selector and model_selector != "":
|
||||||
headers["x-arch-llm-provider-hint"] = model_selector
|
headers["x-arch-llm-provider-hint"] = model_selector
|
||||||
client = OpenAI(
|
client = OpenAI(
|
||||||
api_key="--",
|
api_key="None",
|
||||||
base_url=CHAT_COMPLETION_ENDPOINT,
|
base_url=CHAT_COMPLETION_ENDPOINT,
|
||||||
default_headers=headers,
|
default_headers=headers,
|
||||||
)
|
)
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
# we select model from arch_config file
|
# we select model from arch_config file
|
||||||
model="--",
|
model="None",
|
||||||
messages=history,
|
messages=history,
|
||||||
temperature=1.0,
|
temperature=1.0,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
|
|
||||||
|
|
@ -14,11 +14,6 @@ llm_providers:
|
||||||
model: gpt-4o-mini
|
model: gpt-4o-mini
|
||||||
default: true
|
default: true
|
||||||
|
|
||||||
- name: gpt-3.5-turbo-0125
|
|
||||||
access_key: $OPENAI_API_KEY
|
|
||||||
provider_interface: openai
|
|
||||||
model: gpt-3.5-turbo-0125
|
|
||||||
|
|
||||||
- name: gpt-4o
|
- name: gpt-4o
|
||||||
access_key: $OPENAI_API_KEY
|
access_key: $OPENAI_API_KEY
|
||||||
provider_interface: openai
|
provider_interface: openai
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
POST http://localhost:10000/v1/chat/completions
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "I bought a package recently and it not working properly"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
HTTP 200
|
||||||
|
[Asserts]
|
||||||
|
header "content-type" == "application/json"
|
||||||
|
jsonpath "$.model" matches /^gpt-4o-2/
|
||||||
|
jsonpath "$.metadata.x-arch-state" != null
|
||||||
|
jsonpath "$.usage" != null
|
||||||
|
jsonpath "$.choices[0].message.content" != null
|
||||||
|
jsonpath "$.choices[0].message.role" == "assistant"
|
||||||
|
|
@ -31,9 +31,10 @@ openai_client = openai.OpenAI(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def call_openai(messages: List[Dict[str, str]], stream: bool):
|
def call_openai(messages: List[Dict[str, str]], stream: bool, model: str):
|
||||||
|
logger.info(f"llm agent model: {model}")
|
||||||
completion = openai_client.chat.completions.create(
|
completion = openai_client.chat.completions.create(
|
||||||
model="None", # archgw picks the default LLM configured in the config file
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|
@ -53,14 +54,19 @@ def call_openai(messages: List[Dict[str, str]], stream: bool):
|
||||||
|
|
||||||
|
|
||||||
class Agent:
|
class Agent:
|
||||||
def __init__(self, role: str, instructions: str):
|
def __init__(self, role: str, instructions: str, model: str = ""):
|
||||||
|
self.model = model
|
||||||
self.system_prompt = f"You are a {role}.\n{instructions}"
|
self.system_prompt = f"You are a {role}.\n{instructions}"
|
||||||
|
|
||||||
def handle(self, req: ChatCompletionsRequest):
|
def handle(self, req: ChatCompletionsRequest):
|
||||||
messages = [{"role": "system", "content": self.get_system_prompt()}] + [
|
messages = [{"role": "system", "content": self.get_system_prompt()}] + [
|
||||||
message.model_dump() for message in req.messages
|
message.model_dump() for message in req.messages
|
||||||
]
|
]
|
||||||
return call_openai(messages, req.stream)
|
|
||||||
|
model = req.model
|
||||||
|
if self.model:
|
||||||
|
model = self.model
|
||||||
|
return call_openai(messages, req.stream, model)
|
||||||
|
|
||||||
def get_system_prompt(self) -> str:
|
def get_system_prompt(self) -> str:
|
||||||
return self.system_prompt
|
return self.system_prompt
|
||||||
|
|
@ -77,13 +83,17 @@ AGENTS = {
|
||||||
"2. Quote ridiculous price\n"
|
"2. Quote ridiculous price\n"
|
||||||
"3. Reveal caveat if user agrees."
|
"3. Reveal caveat if user agrees."
|
||||||
),
|
),
|
||||||
|
model="gpt-4o-mini",
|
||||||
),
|
),
|
||||||
"issues_and_repairs": Agent(
|
"issues_and_repairs": Agent(
|
||||||
role="issues and repairs agent",
|
role="issues and repairs agent",
|
||||||
instructions="Propose a solution, offer refund if necessary.",
|
instructions="Propose a solution, offer refund if necessary.",
|
||||||
|
model="gpt-4o",
|
||||||
),
|
),
|
||||||
"escalate_to_human": Agent(
|
"escalate_to_human": Agent(
|
||||||
role="human escalation agent", instructions="Escalate issues to a human."
|
role="human escalation agent",
|
||||||
|
instructions="Escalate issues to a human.",
|
||||||
|
# skipping model name here as arch gateway will pick the default model from the config file
|
||||||
),
|
),
|
||||||
"unknown_agent": Agent(
|
"unknown_agent": Agent(
|
||||||
role="general assistant", instructions="Assist the user in general queries."
|
role="general assistant", instructions="Assist the user in general queries."
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ POST http://localhost:12000/v1/chat/completions
|
||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
|
|
||||||
{
|
{
|
||||||
"model": "--",
|
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
|
@ -13,5 +12,13 @@ Content-Type: application/json
|
||||||
"content": "I want to sell red shoes"
|
"content": "I want to sell red shoes"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"stream": true
|
"stream": false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HTTP 200
|
||||||
|
[Asserts]
|
||||||
|
header "content-type" == "application/json"
|
||||||
|
jsonpath "$.model" matches /^gpt-4o-mini/
|
||||||
|
jsonpath "$.usage" != null
|
||||||
|
jsonpath "$.choices[0].message.content" != null
|
||||||
|
jsonpath "$.choices[0].message.role" == "assistant"
|
||||||
25
tests/hurl/llm_gateway_model_explicit_model.hurl
Normal file
25
tests/hurl/llm_gateway_model_explicit_model.hurl
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
POST http://localhost:12000/v1/chat/completions
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant.\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "I want to sell red shoes"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stream": false
|
||||||
|
}
|
||||||
|
|
||||||
|
HTTP 200
|
||||||
|
[Asserts]
|
||||||
|
header "content-type" == "application/json"
|
||||||
|
jsonpath "$.model" matches /^gpt-3.5-turbo/
|
||||||
|
jsonpath "$.usage" != null
|
||||||
|
jsonpath "$.choices[0].message.content" != null
|
||||||
|
jsonpath "$.choices[0].message.role" == "assistant"
|
||||||
25
tests/hurl/llm_gateway_model_hint.hurl
Normal file
25
tests/hurl/llm_gateway_model_hint.hurl
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
POST http://localhost:12000/v1/chat/completions
|
||||||
|
Content-Type: application/json
|
||||||
|
x-arch-llm-provider-hint: gpt-4o
|
||||||
|
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant.\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "I want to sell red shoes"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stream": false
|
||||||
|
}
|
||||||
|
|
||||||
|
HTTP 200
|
||||||
|
[Asserts]
|
||||||
|
header "content-type" == "application/json"
|
||||||
|
jsonpath "$.model" matches /^gpt-4o-2/
|
||||||
|
jsonpath "$.usage" != null
|
||||||
|
jsonpath "$.choices[0].message.content" != null
|
||||||
|
jsonpath "$.choices[0].message.role" == "assistant"
|
||||||
|
|
@ -238,7 +238,7 @@ POST {{model_server_endpoint}}/function_calling HTTP/1.1
|
||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
|
|
||||||
{
|
{
|
||||||
"model": "--",
|
"model": "None",
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,7 @@ POST {{prompt_endpoint}}/v1/chat/completions HTTP/1.1
|
||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
|
|
||||||
{
|
{
|
||||||
"model": "--",
|
"model": "None",
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue