Mirror of https://github.com/katanemo/plano.git (synced 2026-05-03 12:52:56 +02:00)

Commit eb48f3d5bb (parent bd8004d1ae): use passed in model name in chat completion request (#445)
20 changed files with 364 additions and 89 deletions

@@ -249,7 +249,7 @@ client = OpenAI(
 response = client.chat.completions.create(
     # we select model from arch_config file
-    model="--",
+    model="None",
     messages=[{"role": "user", "content": "What is the capital of France?"}],
 )
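
For context, a minimal client-side sketch of what this change means for callers, assuming the gateway listens on localhost:12000 as in the Hurl tests below (the api_key placeholder follows the demo clients in this commit and is not validated by the gateway):

from openai import OpenAI

client = OpenAI(
    api_key="None",                        # placeholder; provider keys come from arch_config
    base_url="http://localhost:12000/v1",  # assumed archgw listener address
)

# "None" (or an empty model) lets archgw substitute the default model from arch_config.
gateway_default = client.chat.completions.create(
    model="None",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
)

# With this commit, an explicit model name in the request is passed through unchanged.
explicit_model = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
)
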
@@ -56,6 +56,7 @@ def docker_start_archgw_detached(
     volume_mappings = [
         f"{logs_path_abs}:/var/log:rw",
         f"{arch_config_file}:/app/arch_config.yaml:ro",
+        # "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro",
     ]
     volume_mappings_args = [
         item for volume in volume_mappings for item in ("-v", volume)
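
As a side note, the comprehension above expands every mapping into a ("-v", mapping) pair for the docker command line; a hypothetical standalone run (paths are made up for illustration):

volume_mappings = ["/tmp/logs:/var/log:rw", "arch_config.yaml:/app/arch_config.yaml:ro"]
volume_mappings_args = [item for volume in volume_mappings for item in ("-v", volume)]
print(volume_mappings_args)
# ['-v', '/tmp/logs:/var/log:rw', '-v', 'arch_config.yaml:/app/arch_config.yaml:ro']
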
@@ -1,7 +1,7 @@
 use crate::configuration;
 use configuration::{Limit, Ratelimit, TimeUnit};
 use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
-use log::debug;
+use log::trace;
 use std::fmt::Display;
 use std::num::{NonZero, NonZeroU32};
 use std::sync::RwLock;

@@ -99,9 +99,11 @@ impl RatelimitMap {
         selector: Header,
         tokens_used: NonZeroU32,
     ) -> Result<(), Error> {
-        debug!(
+        trace!(
             "Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
-            provider, selector, tokens_used
+            provider,
+            selector,
+            tokens_used
         );

         let provider_limits = match self.datastore.get(&provider) {

@@ -1,19 +1,25 @@
+use log::trace;
+
 #[derive(thiserror::Error, Debug, PartialEq, Eq)]
+#[allow(dead_code)]
 pub enum Error {
     #[error("Unknown model: {model_name}")]
     UnknownModel { model_name: String },
 }

 #[allow(dead_code)]
-pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
+pub fn token_count(model_name: &str, text: &str) -> Result<usize, String> {
+    trace!("getting token count model={}", model_name);
     //HACK: add support for tokenizing mistral and other models
     //filed issue https://github.com/katanemo/arch/issues/222

+    let updated_model = match model_name.starts_with("gpt") {
+        false => {
+            trace!(
+                "tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
+                model_name
+            );
+
+            "gpt-4"
+        }
+        true => model_name,
+    };

-    // Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
-    let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel {
-        model_name: model_name.to_string(),
-    })?;
+    let bpe = tiktoken_rs::get_bpe_from_model(updated_model).map_err(|e| e.to_string())?;
     Ok(bpe.encode_ordinary(text).len())
 }

@@ -30,14 +36,4 @@ mod test {
             token_count(model_name, text).expect("correct tokenization")
         );
     }
-
-    #[test]
-    fn unrecognized_model() {
-        assert_eq!(
-            Error::UnknownModel {
-                model_name: "unknown".to_string()
-            },
-            token_count("unknown", "").expect_err("unknown model")
-        )
-    }
 }
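
The fallback above (count non-GPT models with the gpt-4 encoding) can be illustrated with the Python tiktoken package; this is only an illustration of the same idea, the gateway itself uses tiktoken_rs as shown:

import tiktoken

def token_count(model_name: str, text: str) -> int:
    # Mirror the Rust fallback: models that are not gpt-* are counted with the gpt-4 encoding.
    updated_model = model_name if model_name.startswith("gpt") else "gpt-4"
    encoding = tiktoken.encoding_for_model(updated_model)
    return len(encoding.encode_ordinary(text))

print(token_count("gpt-4", "hello world"))       # exact count for a supported model
print(token_count("mistral-7b", "hello world"))  # approximated with the gpt-4 encoding
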
@@ -89,7 +89,7 @@ impl StreamContext {
             provider_hint,
         ));

-        debug!(
+        trace!(
             "request received: llm provider hint: {}, selected llm: {}, model: {}",
             self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
                 .unwrap_or_default(),

@@ -167,7 +167,7 @@ impl StreamContext {

         // Check if rate limiting needs to be applied.
         if let Some(selector) = self.ratelimit_selector.take() {
-            log::debug!("Applying ratelimit for model: {}", model);
+            log::trace!("Applying ratelimit for model: {}", model);
             ratelimit::ratelimits(None).read().unwrap().check_limit(
                 model.to_owned(),
                 selector,

@@ -300,22 +300,45 @@ impl HttpContext for StreamContext {
             .cloned();

         let model_name = match self.llm_provider.as_ref() {
-            Some(llm_provider) => match llm_provider.model.as_ref() {
-                Some(model) => model,
-                None => "--",
-            },
-            None => "--",
+            Some(llm_provider) => llm_provider.model.as_ref(),
+            None => None,
         };

-        deserialized_body.model = model_name.to_string();
+        let use_agent_orchestrator = match self.overrides.as_ref() {
+            Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
+            None => false,
+        };
+
+        let model_requested = deserialized_body.model.clone();
+        if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
+            deserialized_body.model = match model_name {
+                Some(model_name) => model_name.clone(),
+                None => {
+                    if use_agent_orchestrator {
+                        "agent_orchestrator".to_string()
+                    } else {
+                        self.send_server_error(
+                            ServerError::BadRequest {
+                                why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
+                            },
+                            Some(StatusCode::BAD_REQUEST),
+                        );
+                        return Action::Continue;
+                    }
+                }
+            }
+        }
+
+        debug!(
+            "provider: {:?}, model requested: {}, model selected: {:?}",
+            self.llm_provider().name,
+            model_requested,
+            model_name,
+        );

         let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();

-        trace!(
-            "arch => {:?}, body: {}",
-            deserialized_body.model,
-            chat_completion_request_str
-        );
+        trace!("request body: {}", chat_completion_request_str);

         if deserialized_body.stream {
             self.streaming_response = true;
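
The model-selection rule introduced above, restated as a small Python sketch (function and variable names here are illustrative, not the gateway's API):

from typing import Optional

def resolve_model(requested: str, configured: Optional[str], use_agent_orchestrator: bool) -> str:
    # An explicit model in the request now wins and is forwarded unchanged.
    if requested and requested.lower() != "none":
        return requested
    # Otherwise fall back to the model configured for the selected llm_provider.
    if configured:
        return configured
    # With the agent orchestrator enabled there is no single provider model to pick.
    if use_agent_orchestrator:
        return "agent_orchestrator"
    raise ValueError("No model specified in request and no model configured for the provider")

assert resolve_model("o1-mini", "gpt-4", False) == "o1-mini"  # explicit model passes through
assert resolve_model("none", "gpt-4", False) == "gpt-4"       # "none"/empty falls back to config
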
@@ -528,22 +551,13 @@ impl HttpContext for StreamContext {
                 return Action::Continue;
             }

-            let mut model = chat_completions_chunk_response_events
+            let model = chat_completions_chunk_response_events
                 .events
                 .first()
                 .unwrap()
                 .model
                 .clone();
             let tokens_str = chat_completions_chunk_response_events.to_string();
-            //HACK: add support for tokenizing mistral and other models
-            //filed issue https://github.com/katanemo/arch/issues/222
-            if !model.as_ref().unwrap().starts_with("gpt") {
-                trace!(
-                    "tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
-                    model.as_ref().unwrap()
-                );
-            }
-            model = Some("gpt-4".to_string());

             let token_count =
                 match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str())

@@ -30,7 +30,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
             Some("x-arch-llm-provider-hint"),
         )
         .returning(None)
-        .expect_log(Some(LogLevel::Debug), Some("request received: llm provider hint: default, selected llm: open-ai-gpt-4, model: gpt-4"))
+        .expect_log(Some(LogLevel::Trace), Some("request received: llm provider hint: default, selected llm: open-ai-gpt-4, model: gpt-4"))
         .expect_add_header_map_value(
             Some(MapType::HttpRequestHeaders),
             Some("x-arch-llm-provider"),

@@ -225,12 +225,13 @@ fn llm_gateway_successful_request_to_open_ai_chat_completions() {
        )
        .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
        .returning(Some(chat_completions_request_body))
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_metric_record("input_sequence_length", 21)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
        .execute_and_expect(ReturnType::Action(Action::Continue))
        .unwrap();

@@ -287,7 +288,7 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
        )
        .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
        .returning(Some(incomplete_chat_completions_request_body))
-       .expect_log(Some(LogLevel::Trace), None)
+       .expect_log(Some(LogLevel::Debug), None)
        .expect_send_local_response(
            Some(StatusCode::BAD_REQUEST.as_u16().into()),
            None,

@@ -296,9 +297,10 @@
        )
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_metric_record("input_sequence_length", 14)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .execute_and_expect(ReturnType::Action(Action::Continue))
        .unwrap();
}

@@ -351,13 +353,14 @@ fn llm_gateway_request_ratelimited() {
        .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
        .returning(Some(chat_completions_request_body))
        // The actual call is not important in this test, we just need to grab the token_id
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_metric_record("input_sequence_length", 107)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Debug), Some("server error occurred: exceeded limit provider=gpt-4, selector=Header { key: \"selector-key\", value: \"selector-value\" }, tokens_used=107"))
        .expect_send_local_response(
            Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()),
            None,

@@ -417,12 +420,196 @@ fn llm_gateway_request_not_ratelimited() {
        .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
        .returning(Some(chat_completions_request_body))
        // The actual call is not important in this test, we just need to grab the token_id
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_metric_record("input_sequence_length", 29)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
        .execute_and_expect(ReturnType::Action(Action::Continue))
        .unwrap();
}

#[test]
#[serial]
fn llm_gateway_override_model_name() {
    let args = tester::MockSettings {
        wasm_path: wasm_module(),
        quiet: false,
        allow_unexpected: false,
    };
    let mut module = tester::mock(args).unwrap();

    module
        .call_start()
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Setup Filter
    let filter_context = setup_filter(&mut module, default_config());

    // Setup HTTP Stream
    let http_context = 2;

    normal_flow(&mut module, filter_context, http_context);

    // give shorter body to avoid rate limiting
    let chat_completions_request_body = "\
        {\
            \"model\": \"o1-mini\",\
            \"messages\": [\
            {\
                \"role\": \"system\",\
                \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
            },\
            {\
                \"role\": \"user\",\
                \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
            }\
            ]
        }";

    module
        .call_proxy_on_request_body(
            http_context,
            chat_completions_request_body.len() as i32,
            true,
        )
        .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
        .returning(Some(chat_completions_request_body))
        // The actual call is not important in this test, we just need to grab the token_id
        .expect_log(Some(LogLevel::Debug), Some("provider: \"open-ai-gpt-4\", model requested: o1-mini, model selected: Some(\"gpt-4\")"))
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_metric_record("input_sequence_length", 29)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
        .execute_and_expect(ReturnType::Action(Action::Continue))
        .unwrap();
}

#[test]
#[serial]
fn llm_gateway_override_use_default_model() {
    let args = tester::MockSettings {
        wasm_path: wasm_module(),
        quiet: false,
        allow_unexpected: false,
    };
    let mut module = tester::mock(args).unwrap();

    module
        .call_start()
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Setup Filter
    let filter_context = setup_filter(&mut module, default_config());

    // Setup HTTP Stream
    let http_context = 2;

    normal_flow(&mut module, filter_context, http_context);

    // give shorter body to avoid rate limiting
    let chat_completions_request_body = "\
        {\
            \"messages\": [\
            {\
                \"role\": \"system\",\
                \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
            },\
            {\
                \"role\": \"user\",\
                \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
            }\
            ]
        }";

    module
        .call_proxy_on_request_body(
            http_context,
            chat_completions_request_body.len() as i32,
            true,
        )
        .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
        .returning(Some(chat_completions_request_body))
        // The actual call is not important in this test, we just need to grab the token_id
        .expect_log(
            Some(LogLevel::Debug),
            Some("provider: \"open-ai-gpt-4\", model requested: , model selected: Some(\"gpt-4\")"),
        )
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_metric_record("input_sequence_length", 29)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
        .execute_and_expect(ReturnType::Action(Action::Continue))
        .unwrap();
}

#[test]
#[serial]
fn llm_gateway_override_use_model_name_none() {
    let args = tester::MockSettings {
        wasm_path: wasm_module(),
        quiet: false,
        allow_unexpected: false,
    };
    let mut module = tester::mock(args).unwrap();

    module
        .call_start()
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Setup Filter
    let filter_context = setup_filter(&mut module, default_config());

    // Setup HTTP Stream
    let http_context = 2;

    normal_flow(&mut module, filter_context, http_context);

    // give shorter body to avoid rate limiting
    let chat_completions_request_body = "\
        {\
            \"model\": \"none\",\
            \"messages\": [\
            {\
                \"role\": \"system\",\
                \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
            },\
            {\
                \"role\": \"user\",\
                \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
            }\
            ]
        }";

    module
        .call_proxy_on_request_body(
            http_context,
            chat_completions_request_body.len() as i32,
            true,
        )
        .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
        .returning(Some(chat_completions_request_body))
        // The actual call is not important in this test, we just need to grab the token_id
        .expect_log(Some(LogLevel::Debug), Some("provider: \"open-ai-gpt-4\", model requested: none, model selected: Some(\"gpt-4\")"))
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_metric_record("input_sequence_length", 29)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_log(Some(LogLevel::Trace), None)
        .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
        .execute_and_expect(ReturnType::Action(Action::Continue))
        .unwrap();

@@ -45,7 +45,8 @@ impl HttpContext for StreamContext {
             warn!("Need single endpoint when use_agent_orchestrator is set");
             self.send_server_error(
                 ServerError::LogicError(
-                    "Need single endpoint when use_agent_orchestrator is set".to_string(),
+                    "Need single endpoint when use_agent_orchestrator is set"
+                        .to_string(),
                 ),
                 None,
             );

@@ -190,7 +191,7 @@ impl HttpContext for StreamContext {
             messages: deserialized_body.messages.clone(),
             metadata,
             stream: deserialized_body.stream,
-            model: "--".to_string(),
+            model: deserialized_body.model.clone(),
             stream_options: deserialized_body.stream_options.clone(),
             tools: Some(tool_calls),
         };

@@ -427,7 +427,6 @@ impl StreamContext {
             headers.insert(key.as_str(), value.as_str());
         }

-
         let call_args = CallArgs::new(
             ARCH_INTERNAL_CLUSTER_NAME,
             &path,

@@ -499,10 +498,7 @@ impl StreamContext {
             }
         };

-        if !prompt_target
-            .auto_llm_dispatch_on_response
-            .unwrap_or(true)
-        {
+        if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(true) {
             let tool_call_response = self.tool_call_response.as_ref().unwrap().clone();

             let direct_response_str = if self.streaming_response {

@@ -655,10 +651,7 @@ impl StreamContext {
             .clone();

         // check if the default target should be dispatched to the LLM provider
-        if !prompt_target
-            .auto_llm_dispatch_on_response
-            .unwrap_or(true)
-        {
+        if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(true) {
             let default_target_response_str = if self.streaming_response {
                 let chat_completion_response =
                     match serde_json::from_slice::<ChatCompletionsResponse>(&body) {

@@ -22,12 +22,12 @@
   </natures>
   <filteredResources>
     <filter>
-      <id>1739479945236</id>
+      <id>1742579142020</id>
       <name></name>
       <type>30</type>
       <matcher>
         <id>org.eclipse.core.resources.regexFilterMatcher</id>
-        <arguments>node_modules|.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__</arguments>
+        <arguments>node_modules|\.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__</arguments>
       </matcher>
     </filter>
   </filteredResources>

@@ -38,7 +38,7 @@ def chat(
     try:
         response = client.chat.completions.create(
             # we select model from arch_config file
-            model="--",
+            model="None",
             messages=history,
             temperature=1.0,
             stream=True,

@@ -54,13 +54,13 @@ def chat(
     if model_selector and model_selector != "":
         headers["x-arch-llm-provider-hint"] = model_selector
     client = OpenAI(
-        api_key="--",
+        api_key="None",
         base_url=CHAT_COMPLETION_ENDPOINT,
         default_headers=headers,
     )
     response = client.chat.completions.create(
         # we select model from arch_config file
-        model="--",
+        model="None",
         messages=history,
         temperature=1.0,
         stream=True,

@@ -14,11 +14,6 @@ llm_providers:
     model: gpt-4o-mini
     default: true

-  - name: gpt-3.5-turbo-0125
-    access_key: $OPENAI_API_KEY
-    provider_interface: openai
-    model: gpt-3.5-turbo-0125
-
   - name: gpt-4o
     access_key: $OPENAI_API_KEY
     provider_interface: openai

@@ -0,0 +1,19 @@
+POST http://localhost:10000/v1/chat/completions
+Content-Type: application/json
+
+{
+  "messages": [
+    {
+      "role": "user",
+      "content": "I bought a package recently and it not working properly"
+    }
+  ]
+}
+HTTP 200
+[Asserts]
+header "content-type" == "application/json"
+jsonpath "$.model" matches /^gpt-4o-2/
+jsonpath "$.metadata.x-arch-state" != null
+jsonpath "$.usage" != null
+jsonpath "$.choices[0].message.content" != null
+jsonpath "$.choices[0].message.role" == "assistant"

@@ -31,9 +31,10 @@ openai_client = openai.OpenAI(
 )


-def call_openai(messages: List[Dict[str, str]], stream: bool):
+def call_openai(messages: List[Dict[str, str]], stream: bool, model: str):
+    logger.info(f"llm agent model: {model}")
     completion = openai_client.chat.completions.create(
-        model="None",  # archgw picks the default LLM configured in the config file
+        model=model,
         messages=messages,
         stream=stream,
     )

@@ -53,14 +54,19 @@ def call_openai(messages: List[Dict[str, str]], stream: bool):


 class Agent:
-    def __init__(self, role: str, instructions: str):
+    def __init__(self, role: str, instructions: str, model: str = ""):
+        self.model = model
         self.system_prompt = f"You are a {role}.\n{instructions}"

     def handle(self, req: ChatCompletionsRequest):
         messages = [{"role": "system", "content": self.get_system_prompt()}] + [
             message.model_dump() for message in req.messages
         ]
-        return call_openai(messages, req.stream)
+
+        model = req.model
+        if self.model:
+            model = self.model
+        return call_openai(messages, req.stream, model)

     def get_system_prompt(self) -> str:
         return self.system_prompt

@@ -77,13 +83,17 @@ AGENTS = {
             "2. Quote ridiculous price\n"
             "3. Reveal caveat if user agrees."
         ),
+        model="gpt-4o-mini",
     ),
     "issues_and_repairs": Agent(
         role="issues and repairs agent",
         instructions="Propose a solution, offer refund if necessary.",
+        model="gpt-4o",
     ),
     "escalate_to_human": Agent(
-        role="human escalation agent", instructions="Escalate issues to a human."
+        role="human escalation agent",
+        instructions="Escalate issues to a human.",
+        # skipping model name here as arch gateway will pick the default model from the config file
     ),
     "unknown_agent": Agent(
         role="general assistant", instructions="Assist the user in general queries."
@@ -2,7 +2,6 @@ POST http://localhost:12000/v1/chat/completions
 Content-Type: application/json

 {
-  "model": "--",
   "messages": [
     {
       "role": "system",

@@ -13,5 +12,13 @@ Content-Type: application/json
       "content": "I want to sell red shoes"
     }
   ],
-  "stream": true
+  "stream": false
 }
+
+HTTP 200
+[Asserts]
+header "content-type" == "application/json"
+jsonpath "$.model" matches /^gpt-4o-mini/
+jsonpath "$.usage" != null
+jsonpath "$.choices[0].message.content" != null
+jsonpath "$.choices[0].message.role" == "assistant"

tests/hurl/llm_gateway_model_explicit_model.hurl (new file, 25 lines)

@@ -0,0 +1,25 @@
+POST http://localhost:12000/v1/chat/completions
+Content-Type: application/json
+
+{
+  "model": "gpt-3.5-turbo",
+  "messages": [
+    {
+      "role": "system",
+      "content": "You are a helpful assistant.\n"
+    },
+    {
+      "role": "user",
+      "content": "I want to sell red shoes"
+    }
+  ],
+  "stream": false
+}
+
+HTTP 200
+[Asserts]
+header "content-type" == "application/json"
+jsonpath "$.model" matches /^gpt-3.5-turbo/
+jsonpath "$.usage" != null
+jsonpath "$.choices[0].message.content" != null
+jsonpath "$.choices[0].message.role" == "assistant"

tests/hurl/llm_gateway_model_hint.hurl (new file, 25 lines)

@@ -0,0 +1,25 @@
+POST http://localhost:12000/v1/chat/completions
+Content-Type: application/json
+x-arch-llm-provider-hint: gpt-4o
+
+{
+  "messages": [
+    {
+      "role": "system",
+      "content": "You are a helpful assistant.\n"
+    },
+    {
+      "role": "user",
+      "content": "I want to sell red shoes"
+    }
+  ],
+  "stream": false
+}
+
+HTTP 200
+[Asserts]
+header "content-type" == "application/json"
+jsonpath "$.model" matches /^gpt-4o-2/
+jsonpath "$.usage" != null
+jsonpath "$.choices[0].message.content" != null
+jsonpath "$.choices[0].message.role" == "assistant"
@@ -238,7 +238,7 @@ POST {{model_server_endpoint}}/function_calling HTTP/1.1
 Content-Type: application/json

 {
-  "model": "--",
+  "model": "None",
   "messages": [
     {
       "role": "user",

@@ -82,7 +82,7 @@ POST {{prompt_endpoint}}/v1/chat/completions HTTP/1.1
 Content-Type: application/json

 {
-  "model": "--",
+  "model": "None",
   "messages": [
     {
       "role": "user",