use passed in model name in chat completion request (#445)

This commit is contained in:
Adil Hafeez 2025-03-21 15:56:17 -07:00 committed by GitHub
parent bd8004d1ae
commit eb48f3d5bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 364 additions and 89 deletions

View file

@ -249,7 +249,7 @@ client = OpenAI(
response = client.chat.completions.create( response = client.chat.completions.create(
# we select model from arch_config file # we select model from arch_config file
model="--", model="None",
messages=[{"role": "user", "content": "What is the capital of France?"}], messages=[{"role": "user", "content": "What is the capital of France?"}],
) )

View file

@ -56,6 +56,7 @@ def docker_start_archgw_detached(
volume_mappings = [ volume_mappings = [
f"{logs_path_abs}:/var/log:rw", f"{logs_path_abs}:/var/log:rw",
f"{arch_config_file}:/app/arch_config.yaml:ro", f"{arch_config_file}:/app/arch_config.yaml:ro",
# "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro",
] ]
volume_mappings_args = [ volume_mappings_args = [
item for volume in volume_mappings for item in ("-v", volume) item for volume in volume_mappings for item in ("-v", volume)

View file

@ -1,7 +1,7 @@
use crate::configuration; use crate::configuration;
use configuration::{Limit, Ratelimit, TimeUnit}; use configuration::{Limit, Ratelimit, TimeUnit};
use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota}; use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota};
use log::debug; use log::trace;
use std::fmt::Display; use std::fmt::Display;
use std::num::{NonZero, NonZeroU32}; use std::num::{NonZero, NonZeroU32};
use std::sync::RwLock; use std::sync::RwLock;
@ -99,9 +99,11 @@ impl RatelimitMap {
selector: Header, selector: Header,
tokens_used: NonZeroU32, tokens_used: NonZeroU32,
) -> Result<(), Error> { ) -> Result<(), Error> {
debug!( trace!(
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}", "Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
provider, selector, tokens_used provider,
selector,
tokens_used
); );
let provider_limits = match self.datastore.get(&provider) { let provider_limits = match self.datastore.get(&provider) {

View file

@ -1,19 +1,25 @@
use log::trace; use log::trace;
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
#[allow(dead_code)] #[allow(dead_code)]
pub enum Error { pub fn token_count(model_name: &str, text: &str) -> Result<usize, String> {
#[error("Unknown model: {model_name}")]
UnknownModel { model_name: String },
}
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result<usize, Error> {
trace!("getting token count model={}", model_name); trace!("getting token count model={}", model_name);
//HACK: add support for tokenizing mistral and other models
//filed issue https://github.com/katanemo/arch/issues/222
let updated_model = match model_name.starts_with("gpt") {
false => {
trace!(
"tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
model_name
);
"gpt-4"
}
true => model_name,
};
// Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton? // Consideration: is it more expensive to instantiate the BPE object every time, or to contend the singleton?
let bpe = tiktoken_rs::get_bpe_from_model(model_name).map_err(|_| Error::UnknownModel { let bpe = tiktoken_rs::get_bpe_from_model(updated_model).map_err(|e| e.to_string())?;
model_name: model_name.to_string(),
})?;
Ok(bpe.encode_ordinary(text).len()) Ok(bpe.encode_ordinary(text).len())
} }
@ -30,14 +36,4 @@ mod test {
token_count(model_name, text).expect("correct tokenization") token_count(model_name, text).expect("correct tokenization")
); );
} }
#[test]
fn unrecognized_model() {
assert_eq!(
Error::UnknownModel {
model_name: "unknown".to_string()
},
token_count("unknown", "").expect_err("unknown model")
)
}
} }

View file

@ -89,7 +89,7 @@ impl StreamContext {
provider_hint, provider_hint,
)); ));
debug!( trace!(
"request received: llm provider hint: {}, selected llm: {}, model: {}", "request received: llm provider hint: {}, selected llm: {}, model: {}",
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER) self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
.unwrap_or_default(), .unwrap_or_default(),
@ -167,7 +167,7 @@ impl StreamContext {
// Check if rate limiting needs to be applied. // Check if rate limiting needs to be applied.
if let Some(selector) = self.ratelimit_selector.take() { if let Some(selector) = self.ratelimit_selector.take() {
log::debug!("Applying ratelimit for model: {}", model); log::trace!("Applying ratelimit for model: {}", model);
ratelimit::ratelimits(None).read().unwrap().check_limit( ratelimit::ratelimits(None).read().unwrap().check_limit(
model.to_owned(), model.to_owned(),
selector, selector,
@ -300,22 +300,45 @@ impl HttpContext for StreamContext {
.cloned(); .cloned();
let model_name = match self.llm_provider.as_ref() { let model_name = match self.llm_provider.as_ref() {
Some(llm_provider) => match llm_provider.model.as_ref() { Some(llm_provider) => llm_provider.model.as_ref(),
Some(model) => model, None => None,
None => "--",
},
None => "--",
}; };
deserialized_body.model = model_name.to_string(); let use_agent_orchestrator = match self.overrides.as_ref() {
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
None => false,
};
let model_requested = deserialized_body.model.clone();
if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
deserialized_body.model = match model_name {
Some(model_name) => model_name.clone(),
None => {
if use_agent_orchestrator {
"agent_orchestrator".to_string()
} else {
self.send_server_error(
ServerError::BadRequest {
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
},
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
}
}
}
debug!(
"provider: {:?}, model requested: {}, model selected: {:?}",
self.llm_provider().name,
model_requested,
model_name,
);
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap(); let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
trace!( trace!("request body: {}", chat_completion_request_str);
"arch => {:?}, body: {}",
deserialized_body.model,
chat_completion_request_str
);
if deserialized_body.stream { if deserialized_body.stream {
self.streaming_response = true; self.streaming_response = true;
@ -528,22 +551,13 @@ impl HttpContext for StreamContext {
return Action::Continue; return Action::Continue;
} }
let mut model = chat_completions_chunk_response_events let model = chat_completions_chunk_response_events
.events .events
.first() .first()
.unwrap() .unwrap()
.model .model
.clone(); .clone();
let tokens_str = chat_completions_chunk_response_events.to_string(); let tokens_str = chat_completions_chunk_response_events.to_string();
//HACK: add support for tokenizing mistral and other models
//filed issue https://github.com/katanemo/arch/issues/222
if !model.as_ref().unwrap().starts_with("gpt") {
trace!(
"tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
model.as_ref().unwrap()
);
}
model = Some("gpt-4".to_string());
let token_count = let token_count =
match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str()) match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str())

View file

@ -30,7 +30,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
Some("x-arch-llm-provider-hint"), Some("x-arch-llm-provider-hint"),
) )
.returning(None) .returning(None)
.expect_log(Some(LogLevel::Debug), Some("request received: llm provider hint: default, selected llm: open-ai-gpt-4, model: gpt-4")) .expect_log(Some(LogLevel::Trace), Some("request received: llm provider hint: default, selected llm: open-ai-gpt-4, model: gpt-4"))
.expect_add_header_map_value( .expect_add_header_map_value(
Some(MapType::HttpRequestHeaders), Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider"), Some("x-arch-llm-provider"),
@ -225,12 +225,13 @@ fn llm_gateway_successful_request_to_open_ai_chat_completions() {
) )
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body)) .returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 21) .expect_metric_record("input_sequence_length", 21)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue)) .execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap(); .unwrap();
@ -287,7 +288,7 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
) )
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(incomplete_chat_completions_request_body)) .returning(Some(incomplete_chat_completions_request_body))
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response( .expect_send_local_response(
Some(StatusCode::BAD_REQUEST.as_u16().into()), Some(StatusCode::BAD_REQUEST.as_u16().into()),
None, None,
@ -296,9 +297,10 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
) )
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 14) .expect_metric_record("input_sequence_length", 14)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::Action(Action::Continue)) .execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap(); .unwrap();
} }
@ -351,13 +353,14 @@ fn llm_gateway_request_ratelimited() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body)) .returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id // The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 107) .expect_metric_record("input_sequence_length", 107)
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), Some("server error occurred: exceeded limit provider=gpt-4, selector=Header { key: \"selector-key\", value: \"selector-value\" }, tokens_used=107"))
.expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response( .expect_send_local_response(
Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()), Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()),
None, None,
@ -417,12 +420,196 @@ fn llm_gateway_request_not_ratelimited() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body)) .returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id // The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 29) .expect_metric_record("input_sequence_length", 29)
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
#[test]
#[serial]
fn llm_gateway_override_model_name() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
// give shorter body to avoid rate limiting
let chat_completions_request_body = "\
{\
\"model\": \"o1-mini\",\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
]
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), Some("provider: \"open-ai-gpt-4\", model requested: o1-mini, model selected: Some(\"gpt-4\")"))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 29)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
#[test]
#[serial]
fn llm_gateway_override_use_default_model() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
// give shorter body to avoid rate limiting
let chat_completions_request_body = "\
{\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
]
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(
Some(LogLevel::Debug),
Some("provider: \"open-ai-gpt-4\", model requested: , model selected: Some(\"gpt-4\")"),
)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 29)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}
#[test]
#[serial]
fn llm_gateway_override_use_model_name_none() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
// give shorter body to avoid rate limiting
let chat_completions_request_body = "\
{\
\"model\": \"none\",\
\"messages\": [\
{\
\"role\": \"system\",\
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
},\
{\
\"role\": \"user\",\
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
}\
]
}";
module
.call_proxy_on_request_body(
http_context,
chat_completions_request_body.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), Some("provider: \"open-ai-gpt-4\", model requested: none, model selected: Some(\"gpt-4\")"))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_record("input_sequence_length", 29)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue)) .execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap(); .unwrap();

View file

@ -45,7 +45,8 @@ impl HttpContext for StreamContext {
warn!("Need single endpoint when use_agent_orchestrator is set"); warn!("Need single endpoint when use_agent_orchestrator is set");
self.send_server_error( self.send_server_error(
ServerError::LogicError( ServerError::LogicError(
"Need single endpoint when use_agent_orchestrator is set".to_string(), "Need single endpoint when use_agent_orchestrator is set"
.to_string(),
), ),
None, None,
); );
@ -190,7 +191,7 @@ impl HttpContext for StreamContext {
messages: deserialized_body.messages.clone(), messages: deserialized_body.messages.clone(),
metadata, metadata,
stream: deserialized_body.stream, stream: deserialized_body.stream,
model: "--".to_string(), model: deserialized_body.model.clone(),
stream_options: deserialized_body.stream_options.clone(), stream_options: deserialized_body.stream_options.clone(),
tools: Some(tool_calls), tools: Some(tool_calls),
}; };

View file

@ -427,7 +427,6 @@ impl StreamContext {
headers.insert(key.as_str(), value.as_str()); headers.insert(key.as_str(), value.as_str());
} }
let call_args = CallArgs::new( let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME, ARCH_INTERNAL_CLUSTER_NAME,
&path, &path,
@ -499,10 +498,7 @@ impl StreamContext {
} }
}; };
if !prompt_target if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(true) {
.auto_llm_dispatch_on_response
.unwrap_or(true)
{
let tool_call_response = self.tool_call_response.as_ref().unwrap().clone(); let tool_call_response = self.tool_call_response.as_ref().unwrap().clone();
let direct_response_str = if self.streaming_response { let direct_response_str = if self.streaming_response {
@ -655,10 +651,7 @@ impl StreamContext {
.clone(); .clone();
// check if the default target should be dispatched to the LLM provider // check if the default target should be dispatched to the LLM provider
if !prompt_target if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(true) {
.auto_llm_dispatch_on_response
.unwrap_or(true)
{
let default_target_response_str = if self.streaming_response { let default_target_response_str = if self.streaming_response {
let chat_completion_response = let chat_completion_response =
match serde_json::from_slice::<ChatCompletionsResponse>(&body) { match serde_json::from_slice::<ChatCompletionsResponse>(&body) {

View file

@ -22,12 +22,12 @@
</natures> </natures>
<filteredResources> <filteredResources>
<filter> <filter>
<id>1739479945236</id> <id>1742579142020</id>
<name></name> <name></name>
<type>30</type> <type>30</type>
<matcher> <matcher>
<id>org.eclipse.core.resources.regexFilterMatcher</id> <id>org.eclipse.core.resources.regexFilterMatcher</id>
<arguments>node_modules|.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__</arguments> <arguments>node_modules|\.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__</arguments>
</matcher> </matcher>
</filter> </filter>
</filteredResources> </filteredResources>

View file

@ -38,7 +38,7 @@ def chat(
try: try:
response = client.chat.completions.create( response = client.chat.completions.create(
# we select model from arch_config file # we select model from arch_config file
model="--", model="None",
messages=history, messages=history,
temperature=1.0, temperature=1.0,
stream=True, stream=True,

View file

@ -54,13 +54,13 @@ def chat(
if model_selector and model_selector != "": if model_selector and model_selector != "":
headers["x-arch-llm-provider-hint"] = model_selector headers["x-arch-llm-provider-hint"] = model_selector
client = OpenAI( client = OpenAI(
api_key="--", api_key="None",
base_url=CHAT_COMPLETION_ENDPOINT, base_url=CHAT_COMPLETION_ENDPOINT,
default_headers=headers, default_headers=headers,
) )
response = client.chat.completions.create( response = client.chat.completions.create(
# we select model from arch_config file # we select model from arch_config file
model="--", model="None",
messages=history, messages=history,
temperature=1.0, temperature=1.0,
stream=True, stream=True,

View file

@ -14,11 +14,6 @@ llm_providers:
model: gpt-4o-mini model: gpt-4o-mini
default: true default: true
- name: gpt-3.5-turbo-0125
access_key: $OPENAI_API_KEY
provider_interface: openai
model: gpt-3.5-turbo-0125
- name: gpt-4o - name: gpt-4o
access_key: $OPENAI_API_KEY access_key: $OPENAI_API_KEY
provider_interface: openai provider_interface: openai

View file

@ -0,0 +1,19 @@
POST http://localhost:10000/v1/chat/completions
Content-Type: application/json
{
"messages": [
{
"role": "user",
"content": "I bought a package recently and it not working properly"
}
]
}
HTTP 200
[Asserts]
header "content-type" == "application/json"
jsonpath "$.model" matches /^gpt-4o-2/
jsonpath "$.metadata.x-arch-state" != null
jsonpath "$.usage" != null
jsonpath "$.choices[0].message.content" != null
jsonpath "$.choices[0].message.role" == "assistant"

View file

@ -31,9 +31,10 @@ openai_client = openai.OpenAI(
) )
def call_openai(messages: List[Dict[str, str]], stream: bool): def call_openai(messages: List[Dict[str, str]], stream: bool, model: str):
logger.info(f"llm agent model: {model}")
completion = openai_client.chat.completions.create( completion = openai_client.chat.completions.create(
model="None", # archgw picks the default LLM configured in the config file model=model,
messages=messages, messages=messages,
stream=stream, stream=stream,
) )
@ -53,14 +54,19 @@ def call_openai(messages: List[Dict[str, str]], stream: bool):
class Agent: class Agent:
def __init__(self, role: str, instructions: str): def __init__(self, role: str, instructions: str, model: str = ""):
self.model = model
self.system_prompt = f"You are a {role}.\n{instructions}" self.system_prompt = f"You are a {role}.\n{instructions}"
def handle(self, req: ChatCompletionsRequest): def handle(self, req: ChatCompletionsRequest):
messages = [{"role": "system", "content": self.get_system_prompt()}] + [ messages = [{"role": "system", "content": self.get_system_prompt()}] + [
message.model_dump() for message in req.messages message.model_dump() for message in req.messages
] ]
return call_openai(messages, req.stream)
model = req.model
if self.model:
model = self.model
return call_openai(messages, req.stream, model)
def get_system_prompt(self) -> str: def get_system_prompt(self) -> str:
return self.system_prompt return self.system_prompt
@ -77,13 +83,17 @@ AGENTS = {
"2. Quote ridiculous price\n" "2. Quote ridiculous price\n"
"3. Reveal caveat if user agrees." "3. Reveal caveat if user agrees."
), ),
model="gpt-4o-mini",
), ),
"issues_and_repairs": Agent( "issues_and_repairs": Agent(
role="issues and repairs agent", role="issues and repairs agent",
instructions="Propose a solution, offer refund if necessary.", instructions="Propose a solution, offer refund if necessary.",
model="gpt-4o",
), ),
"escalate_to_human": Agent( "escalate_to_human": Agent(
role="human escalation agent", instructions="Escalate issues to a human." role="human escalation agent",
instructions="Escalate issues to a human.",
# skipping model name here as arch gateway will pick the default model from the config file
), ),
"unknown_agent": Agent( "unknown_agent": Agent(
role="general assistant", instructions="Assist the user in general queries." role="general assistant", instructions="Assist the user in general queries."

View file

@ -2,7 +2,6 @@ POST http://localhost:12000/v1/chat/completions
Content-Type: application/json Content-Type: application/json
{ {
"model": "--",
"messages": [ "messages": [
{ {
"role": "system", "role": "system",
@ -13,5 +12,13 @@ Content-Type: application/json
"content": "I want to sell red shoes" "content": "I want to sell red shoes"
} }
], ],
"stream": true "stream": false
} }
HTTP 200
[Asserts]
header "content-type" == "application/json"
jsonpath "$.model" matches /^gpt-4o-mini/
jsonpath "$.usage" != null
jsonpath "$.choices[0].message.content" != null
jsonpath "$.choices[0].message.role" == "assistant"

View file

@ -0,0 +1,25 @@
POST http://localhost:12000/v1/chat/completions
Content-Type: application/json
{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant.\n"
},
{
"role": "user",
"content": "I want to sell red shoes"
}
],
"stream": false
}
HTTP 200
[Asserts]
header "content-type" == "application/json"
jsonpath "$.model" matches /^gpt-3.5-turbo/
jsonpath "$.usage" != null
jsonpath "$.choices[0].message.content" != null
jsonpath "$.choices[0].message.role" == "assistant"

View file

@ -0,0 +1,25 @@
POST http://localhost:12000/v1/chat/completions
Content-Type: application/json
x-arch-llm-provider-hint: gpt-4o
{
"messages": [
{
"role": "system",
"content": "You are a helpful assistant.\n"
},
{
"role": "user",
"content": "I want to sell red shoes"
}
],
"stream": false
}
HTTP 200
[Asserts]
header "content-type" == "application/json"
jsonpath "$.model" matches /^gpt-4o-2/
jsonpath "$.usage" != null
jsonpath "$.choices[0].message.content" != null
jsonpath "$.choices[0].message.role" == "assistant"

View file

@ -238,7 +238,7 @@ POST {{model_server_endpoint}}/function_calling HTTP/1.1
Content-Type: application/json Content-Type: application/json
{ {
"model": "--", "model": "None",
"messages": [ "messages": [
{ {
"role": "user", "role": "user",

View file

@ -82,7 +82,7 @@ POST {{prompt_endpoint}}/v1/chat/completions HTTP/1.1
Content-Type: application/json Content-Type: application/json
{ {
"model": "--", "model": "None",
"messages": [ "messages": [
{ {
"role": "user", "role": "user",