merging from main

This commit is contained in:
Salman Paracha 2024-10-22 17:03:10 -07:00
commit 6a72cb45b7
13 changed files with 901 additions and 705 deletions

1
crates/Cargo.lock generated
View file

@ -1120,6 +1120,7 @@ dependencies = [
"http",
"log",
"md5",
"pretty_assertions",
"proxy-wasm",
"proxy-wasm-test-framework",
"rand",

View file

@ -7,7 +7,6 @@ pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";
pub const TOOL_ROLE: &str = "tool";
pub const ASSISTANT_ROLE: &str = "assistant";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const MODEL_SERVER_NAME: &str = "model_server";
pub const ZEROSHOT_INTERNAL_HOST: &str = "zeroshot";

View file

@ -22,11 +22,12 @@ pub enum ServerError {
Serialization(serde_json::Error),
#[error("{0}")]
LogicError(String),
#[error("upstream error response authority={authority}, path={path}, status={status}")]
#[error("upstream application error host={host}, path={path}, status={status}, body={body}")]
Upstream {
authority: String,
host: String,
path: String,
status: String,
body: String,
},
#[error("jailbreak detected: {0}")]
Jailbreak(String),

View file

@ -1,6 +1,9 @@
use crate::{errors::ClientError, stats::{Gauge, IncrementingMetric}};
use crate::{
errors::ClientError,
stats::{Gauge, IncrementingMetric},
};
use derivative::Derivative;
use log::debug;
use log::{debug, trace};
use proxy_wasm::{traits::Context, types::Status};
use serde::Serialize;
use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration};
@ -45,9 +48,10 @@ pub trait Client: Context {
call_args: CallArgs,
call_context: Self::CallContext,
) -> Result<u32, ClientError> {
debug!(
trace!(
"dispatching http call with args={:?} context={:?}",
call_args, call_context
call_args,
call_context
);
match self.dispatch_http_call(

View file

@ -4,10 +4,10 @@ pub mod common_types;
pub mod configuration;
pub mod consts;
pub mod embeddings;
pub mod errors;
pub mod http;
pub mod llm_providers;
pub mod ratelimit;
pub mod routing;
pub mod stats;
pub mod tokenizer;
pub mod errors;

View file

@ -26,3 +26,4 @@ sha2 = "0.10.8"
[dev-dependencies]
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
serial_test = "3.1.1"
pretty_assertions = "1.4.1"

View file

@ -0,0 +1,94 @@
use common::errors::ServerError;
use common::stats::IncrementingMetric;
use proxy_wasm::traits::Context;
use crate::stream_context::{ResponseHandlerType, StreamContext};
// `Context` implementation for `StreamContext`: routes the responses of inline HTTP callouts
// (dispatched while processing a request) to the handler recorded for each callout.
impl Context for StreamContext {
    // Invoked when the response of an inline HTTP call identified by `token_id` arrives.
    //
    // Takes the pending callout context out of `self.callouts` (panicking with
    // "invalid token_id" when nothing is registered under that token), decrements the
    // `active_http_calls` metric and passes the response body to the handler selected by the
    // callout's `response_handler_type`. When no body can be read, a `LogicError` is reported
    // through `send_server_error`.
    fn on_http_call_response(
        &mut self,
        token_id: u32,
        _num_headers: usize,
        body_size: usize,
        _num_trailers: usize,
    ) {
        let pending_call = self
            .callouts
            .get_mut()
            .remove(&token_id)
            .expect("invalid token_id");
        self.metrics.active_http_calls.increment(-1);

        // State transitions driven by these callouts, as mermaid source
        // (rendered with https://mermaid-ascii.art/):
        //
        //   graph LR
        //     on_http_request_body --> prompt received
        //     prompt received --> get embeddings & arch guard
        //     arch guard --> get embeddings
        //     get embeddings --> zeroshot intent
        //
        // continuing from zeroshot intent:
        //
        //   graph LR
        //     zeroshot intent --> arch_fc
        //     zeroshot intent --> default prompt target
        //     arch_fc --> developer api call & hallucination check
        //     hallucination check --> parameter gathering & developer api call
        //     developer api call --> resume request to llm
        match self.get_http_call_response_body(0, body_size) {
            None => {
                self.send_server_error(
                    ServerError::LogicError(String::from("No response body in inline HTTP request")),
                    None,
                );
            }
            Some(body) => {
                // One arm per handler type; kept on single lines so the dispatch table stays scannable.
                #[cfg_attr(any(), rustfmt::skip)]
                match pending_call.response_handler_type {
                    ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, pending_call),
                    ResponseHandlerType::Embeddings => self.embeddings_handler(body, pending_call),
                    ResponseHandlerType::ZeroShotIntent => self.zero_shot_intent_detection_resp_handler(body, pending_call),
                    ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, pending_call),
                    ResponseHandlerType::Hallucination => self.hallucination_classification_resp_handler(body, pending_call),
                    ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, pending_call),
                    ResponseHandlerType::DefaultTarget => self.default_target_handler(body, pending_call),
                }
            }
        }
    }
}

View file

@ -1,9 +1,9 @@
use crate::stream_context::StreamContext;
use common::common_types::EmbeddingType;
use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST};
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget};
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use common::consts::DEFAULT_EMBEDDING_MODEL;
use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST};
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};

View file

@ -1,39 +1,164 @@
use common::{common_types::open_ai::Message, consts::USER_ROLE};
use common::{
common_types::open_ai::Message,
consts::{ARCH_MODEL_PREFIX, ASSISTANT_ROLE, USER_ROLE},
};
pub fn extract_messages_for_hallucination(messages: &Vec<Message>) -> Vec<String> {
let all_user_messages = messages
.iter()
.filter(|m| m.role == USER_ROLE)
.map(|m| m.content.as_ref().unwrap().clone())
.collect::<Vec<String>>();
return all_user_messages;
let mut arch_assistant = false;
let mut user_messages = Vec::new();
if messages.len() >= 2 {
let latest_assistant_message = &messages[messages.len() - 2];
if let Some(model) = latest_assistant_message.model.as_ref() {
if model.starts_with(ARCH_MODEL_PREFIX) {
arch_assistant = true;
}
}
}
if arch_assistant {
for message in messages.iter().rev() {
if let Some(model) = message.model.as_ref() {
if !model.starts_with(ARCH_MODEL_PREFIX) {
if message.role == ASSISTANT_ROLE {
break;
}
}
}
if message.role == USER_ROLE {
if let Some(content) = &message.content {
user_messages.push(content.clone());
}
}
}
} else if let Some(message) = messages.last() {
if let Some(content) = &message.content {
user_messages.push(content.clone());
}
}
user_messages.reverse(); // Reverse to maintain the original order
return user_messages;
}
#[cfg(test)]
mod test {
use pretty_assertions::assert_eq;
use common::common_types::open_ai::Message;
use super::extract_messages_for_hallucination;
// A single Arch-assisted exchange: both user turns must be extracted, in order.
#[test]
fn test_hallucination_message_simple() {
    let test_str = r#"
[
{
"role": "system",
"model" : "gpt-3.5-turbo",
"content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>"
},
{ "role": "user", "content": "tell me about headcount data" },
{
"role": "assistant",
"model": "Arch-Function-1.5B",
"content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data."
},
{ "role": "user", "content": "europe and for fte" }
]
"#;
    let messages: Vec<Message> = serde_json::from_str(test_str).unwrap();
    let messages_for_hallucination = extract_messages_for_hallucination(&messages);
    assert_eq!(messages_for_hallucination.len(), 2);
    assert_eq!(
        ["tell me about headcount data", "europe and for fte"],
        messages_for_hallucination.as_slice()
    );
}
// An earlier exchange answered by a non-Arch assistant precedes the Arch-assisted one:
// only the user turns after that earlier assistant reply must be extracted.
#[test]
fn test_hallucination_message_medium() {
    let test_str = r#"
[
{
"role": "system",
"model" : "gpt-3.5-turbo",
"content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>"
},
{ "role": "user", "content": "Hello" },
{
"role": "assistant",
"model": "gpt-3.5-turbo",
"content": "Hi there!"
},
{ "role": "user", "content": "tell me about headcount data" },
{
"role": "assistant",
"model": "Arch-Function-1.5B",
"content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data."
},
{ "role": "user", "content": "europe" },
{
"role": "system",
"model": "Arch-Function-1.5B",
"content": "It seems like you are asking for headcount data for Europe. Could you please specify the staffing type?"
},
{ "role": "user", "content": "fte" }
]
"#;
    let messages: Vec<Message> = serde_json::from_str(test_str).unwrap();
    let messages_for_hallucination = extract_messages_for_hallucination(&messages);
    assert_eq!(messages_for_hallucination.len(), 3);
    assert_eq!(
        ["tell me about headcount data", "europe", "fte"],
        messages_for_hallucination.as_slice()
    );
}
// Two Arch-assisted exchanges separated by a non-Arch assistant reply: only the user turns
// of the most recent exchange are expected.
#[test]
fn test_hallucination_message_long() {
    let conversation_json = r#"
[
{
"role": "system",
"model" : "gpt-3.5-turbo",
"content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>"
},
{ "role": "user", "content": "Hello" },
{
"role": "assistant",
"model": "gpt-3.5-turbo",
"content": "Hi there!"
},
{ "role": "user", "content": "tell me about headcount data" },
{
"role": "assistant",
"model": "Arch-Function-1.5B",
"content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data."
},
{ "role": "user", "content": "europe" },
{
"role": "system",
"model": "Arch-Function-1.5B",
"content": "It seems like you are asking for headcount data for Europe. Could you please specify the staffing type?"
},
{ "role": "user", "content": "fte" },
{
"role": "assistant",
"model": "gpt-3.5-turbo",
"content": "The headcount is 50000"
},
{ "role": "user", "content": "tell me about the weather" },
{
"role": "assistant",
"model": "Arch-Function-1.5B",
"content" : "The weather forcast tools requires 2 parameters: city and days. Please specify"
},
{ "role": "user", "content": "Seattle" },
{
"role": "system",
"model": "Arch-Function-1.5B",
"content": "It seems like you are asking for weather data for Seattle. Could you please specify the days?"
},
{ "role": "user", "content": "7 days" }
]
"#;
    let conversation: Vec<Message> = serde_json::from_str(conversation_json).unwrap();
    let extracted = extract_messages_for_hallucination(&conversation);
    println!("{:?}", extracted);

    let expected = ["tell me about the weather", "Seattle", "7 days"];
    assert_eq!(extracted.len(), 3);
    assert_eq!(expected, extracted.as_slice());
}
}

View file

@ -0,0 +1,340 @@
use std::{collections::HashMap, time::Duration};
use common::{
common_types::{
open_ai::{
ArchState, ChatCompletionsRequest, ChatCompletionsResponse, Message, StreamOptions,
},
PromptGuardRequest, PromptGuardTask,
},
consts::{
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, GUARD_INTERNAL_HOST,
REQUEST_ID_HEADER, TOOL_ROLE, USER_ROLE,
},
errors::ServerError,
http::{CallArgs, Client},
};
use http::StatusCode;
use log::{debug, trace, warn};
use proxy_wasm::{traits::HttpContext, types::Action};
use serde_json::Value;
use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext};
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
impl HttpContext for StreamContext {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response.
// Request-headers hook: drops `content-length`, records whether the request targets the chat
// completions path, and remembers the request id header (if any). Always continues.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
    // The gateway may rewrite the body later on, which would invalidate Content-Length.
    // Servers generally reject requests whose body length does not match that header, whereas
    // a missing Content-Length is acceptable because intermediary hops may legitimately alter
    // the body (e.g. compression).
    self.set_http_request_header("content-length", None);

    let request_path = self.get_http_request_header(":path").unwrap_or_default();
    self.is_chat_completions_request = request_path == CHAT_COMPLETIONS_PATH;

    trace!(
        "on_http_request_headers S[{}] req_headers={:?}",
        self.context_id,
        self.get_http_request_headers()
    );

    let request_id = self.get_http_request_header(REQUEST_ID_HEADER);
    self.request_id = request_id;

    Action::Continue
}
// Request-body hook.
//
// Waits for the complete body, parses it as a chat completions request, restores any arch
// state found in the request metadata, and records the last user message. It then pauses the
// request while an inline call is made: to the embeddings flow when no jailbreak input guard
// is configured, otherwise to the prompt guard endpoint.
//
// Returns `Action::Continue` for an empty body or when the request has no user message, and
// `Action::Pause` in every other case (including after an error response has been sent).
// Malformed client input (body, arch state metadata, or a user message without content when
// the guard must run) is answered with 400 instead of panicking.
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
    // Let the client send the gateway all the data before sending to the LLM_provider.
    // TODO: consider a streaming API.
    if !end_of_stream {
        return Action::Pause;
    }

    if body_size == 0 {
        return Action::Continue;
    }

    self.request_body_size = body_size;

    trace!(
        "on_http_request_body S[{}] body_size={}",
        self.context_id,
        body_size
    );

    let body_bytes = match self.get_http_request_body(0, body_size) {
        Some(body_bytes) => body_bytes,
        None => {
            self.send_server_error(
                ServerError::LogicError(format!(
                    "Failed to obtain body bytes even though body_size is {}",
                    body_size
                )),
                None,
            );
            return Action::Pause;
        }
    };

    debug!("developer => archgw: {}", String::from_utf8_lossy(&body_bytes));

    // Deserialize body into spec.
    // Currently OpenAI API.
    let mut deserialized_body: ChatCompletionsRequest =
        match serde_json::from_slice(&body_bytes) {
            Ok(deserialized) => deserialized,
            Err(e) => {
                self.send_server_error(
                    ServerError::Deserialization(e),
                    Some(StatusCode::BAD_REQUEST),
                );
                return Action::Pause;
            }
        };

    // The arch state travels in client-supplied metadata, so a malformed value is a bad
    // request rather than a reason to panic.
    let arch_state = match deserialized_body
        .metadata
        .as_ref()
        .and_then(|metadata| metadata.get(ARCH_STATE_HEADER))
    {
        Some(raw_state) => match serde_json::from_str::<Vec<ArchState>>(raw_state) {
            Ok(state) => Some(state),
            Err(e) => {
                self.send_server_error(
                    ServerError::Deserialization(e),
                    Some(StatusCode::BAD_REQUEST),
                );
                return Action::Pause;
            }
        },
        None => None,
    };
    self.arch_state = arch_state;

    self.streaming_response = deserialized_body.stream;
    if deserialized_body.stream && deserialized_body.stream_options.is_none() {
        deserialized_body.stream_options = Some(StreamOptions {
            include_usage: true,
        });
    }

    // Most recent message sent by the user.
    let last_user_prompt = match deserialized_body
        .messages
        .iter()
        .rev()
        .find(|msg| msg.role == USER_ROLE)
    {
        Some(message) => message,
        None => {
            warn!("No messages in the request body");
            return Action::Continue;
        }
    };

    let user_message = last_user_prompt.content.clone();
    self.user_prompt = Some(last_user_prompt.clone());

    let prompt_guard_jailbreak_task = self
        .prompt_guards
        .input_guards
        .contains_key(&common::configuration::GuardType::Jailbreak);

    let request_body = deserialized_body.clone();
    self.chat_completions_request = Some(deserialized_body);

    // Both the embeddings path and the guard path carry the same callout context.
    let call_context = StreamCallContext {
        response_handler_type: ResponseHandlerType::ArchGuard,
        user_message: user_message.clone(),
        prompt_target_name: None,
        request_body,
        similarity_scores: None,
        upstream_cluster: None,
        upstream_cluster_path: None,
    };

    if !prompt_guard_jailbreak_task {
        debug!("Missing input guard. Making inline call to retrieve embeddings");
        self.get_embeddings(call_context);
        return Action::Pause;
    }

    // The guard needs actual text to classify; a user message without content is rejected.
    let guard_input = match user_message {
        Some(text) => text,
        None => {
            self.send_server_error(
                ServerError::LogicError(String::from(
                    "Last user message has no content to run the jailbreak guard on",
                )),
                Some(StatusCode::BAD_REQUEST),
            );
            return Action::Pause;
        }
    };

    let get_prompt_guards_request = PromptGuardRequest {
        input: guard_input,
        task: PromptGuardTask::Jailbreak,
    };

    let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
        Ok(json_data) => json_data,
        Err(error) => {
            self.send_server_error(ServerError::Serialization(error), None);
            return Action::Pause;
        }
    };

    // NOTE(review): the upstream timeout header says 60000 ms while the call timeout below is
    // 5 s - confirm which one is intended.
    let mut headers = vec![
        (ARCH_UPSTREAM_HOST_HEADER, GUARD_INTERNAL_HOST),
        (":method", "POST"),
        (":path", "/guard"),
        (":authority", GUARD_INTERNAL_HOST),
        ("content-type", "application/json"),
        ("x-envoy-max-retries", "3"),
        ("x-envoy-upstream-rq-timeout-ms", "60000"),
    ];

    // Forward the request id header when the client supplied one.
    if let Some(request_id) = self.request_id.as_deref() {
        headers.push((REQUEST_ID_HEADER, request_id));
    }

    let call_args = CallArgs::new(
        ARCH_INTERNAL_CLUSTER_NAME,
        "/guard",
        headers,
        Some(json_data.as_bytes()),
        vec![],
        Duration::from_secs(5),
    );

    if let Err(e) = self.http_call(call_args, call_context) {
        self.send_server_error(ServerError::HttpDispatch(e), None);
    }

    Action::Pause
}
// Response-headers hook: logs the response headers at trace level, removes the
// `content-length` header and lets the response continue.
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
trace!(
"on_http_response_headers recv [S={}] headers={:?}",
self.context_id,
self.get_http_response_headers()
);
// Delete the content-length header and let Envoy calculate it: the response body may be
// modified by this filter, which would result in a different content length.
self.set_http_response_header("content-length", None);
Action::Continue
}
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
trace!(
"recv [S={}] bytes={} end_stream={}",
self.context_id,
body_size,
end_of_stream
);
if !self.is_chat_completions_request {
if let Some(body_str) = self
.get_http_response_body(0, body_size)
.and_then(|bytes| String::from_utf8(bytes).ok())
{
debug!("recv [S={}] body_str={}", self.context_id, body_str);
}
return Action::Continue;
}
if !end_of_stream {
return Action::Pause;
}
let body = self
.get_http_response_body(0, body_size)
.expect("cant get response body");
if self.streaming_response {
trace!("streaming response");
} else {
trace!("non streaming response");
let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_slice(&body) {
Ok(de) => de,
Err(e) => {
trace!(
"invalid response: {}, {}",
String::from_utf8_lossy(&body),
e
);
return Action::Continue;
}
};
if chat_completions_response.usage.is_some() {
self.response_tokens += chat_completions_response
.usage
.as_ref()
.unwrap()
.completion_tokens;
}
if let Some(tool_calls) = self.tool_calls.as_ref() {
if !tool_calls.is_empty() {
if self.arch_state.is_none() {
self.arch_state = Some(Vec::new());
}
let mut data = serde_json::from_slice(&body).unwrap();
// use serde::Value to manipulate the json object and ensure that we don't lose any data
if let Value::Object(ref mut map) = data {
// serialize arch state and add to metadata
let metadata = map
.entry("metadata")
.or_insert(Value::Object(serde_json::Map::new()));
if metadata == &Value::Null {
*metadata = Value::Object(serde_json::Map::new());
}
// since arch gateway generates tool calls (using arch-fc) and calls upstream api to
// get response, we will send these back to developer so they can see the api response
// and tool call arch-fc generated
let fc_messages = vec![
Message {
role: ASSISTANT_ROLE.to_string(),
content: None,
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: self.tool_calls.clone(),
tool_call_id: None,
},
Message {
role: TOOL_ROLE.to_string(),
content: self.tool_call_response.clone(),
model: None,
tool_calls: None,
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
},
];
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
let arch_state_str = serde_json::to_string(&arch_state).unwrap();
metadata.as_object_mut().unwrap().insert(
ARCH_STATE_HEADER.to_string(),
serde_json::Value::String(arch_state_str),
);
let data_serialized = serde_json::to_string(&data).unwrap();
debug!("archgw <= developer: {}", data_serialized);
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
};
}
}
}
trace!(
"recv [S={}] total_tokens={} end_stream={}",
self.context_id,
self.response_tokens,
end_of_stream
);
Action::Continue
}
}

View file

@ -2,9 +2,11 @@ use filter_context::FilterContext;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod context;
mod filter_context;
mod stream_context;
mod hallucination;
mod http_context;
mod stream_context;
proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);

File diff suppressed because it is too large Load diff

View file

@ -33,7 +33,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
.returning(Some("/v1/chat/completions"))
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
.returning(None)
.execute_and_expect(ReturnType::Action(Action::Continue))
@ -74,7 +74,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -92,6 +92,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
)
.returning(Some(1))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
@ -116,6 +117,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.returning(Some(&prompt_guard_response_buffer))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -133,7 +135,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
)
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -159,8 +160,9 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embeddings_response_buffer))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -178,7 +180,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
)
.returning(Some(3))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -200,9 +201,10 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&zeroshot_intent_detection_buffer))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -219,8 +221,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
None,
)
.returning(Some(4))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -245,7 +245,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
module
.call_proxy_on_tick(filter_context)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -426,8 +426,9 @@ fn successful_request_to_open_ai_chat_completions() {
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(Some("arch_internal"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
@ -486,13 +487,14 @@ fn bad_request_to_open_ai_chat_completions() {
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(incomplete_chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_send_local_response(
Some(StatusCode::BAD_REQUEST.as_u16().into()),
None,
None,
None,
)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
@ -564,7 +566,7 @@ fn request_to_llm_gateway() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -603,6 +605,8 @@ fn request_to_llm_gateway() {
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -628,10 +632,10 @@ fn request_to_llm_gateway() {
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
.returning(Some("200"))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -664,11 +668,11 @@ fn request_to_llm_gateway() {
)
.expect_get_buffer_bytes(Some(BufferType::HttpResponseBody))
.returning(Some(chat_completion_response_str.as_str()))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}