Improve logging (#209)

* improve logging

* fix int tests

* better

* fix more logs

* fix more

* fix int
This commit is contained in:
Adil Hafeez 2024-10-22 12:07:40 -07:00 committed by GitHub
parent 2f374df034
commit ea76d85b43
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 319 additions and 309 deletions

View file

@ -7,7 +7,6 @@ pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user"; pub const USER_ROLE: &str = "user";
pub const TOOL_ROLE: &str = "tool"; pub const TOOL_ROLE: &str = "tool";
pub const ASSISTANT_ROLE: &str = "assistant"; pub const ASSISTANT_ROLE: &str = "assistant";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const MODEL_SERVER_NAME: &str = "model_server"; pub const MODEL_SERVER_NAME: &str = "model_server";
pub const ZEROSHOT_INTERNAL_HOST: &str = "zeroshot"; pub const ZEROSHOT_INTERNAL_HOST: &str = "zeroshot";

View file

@ -3,7 +3,7 @@ use crate::{
stats::{Gauge, IncrementingMetric}, stats::{Gauge, IncrementingMetric},
}; };
use derivative::Derivative; use derivative::Derivative;
use log::debug; use log::{debug, trace};
use proxy_wasm::{traits::Context, types::Status}; use proxy_wasm::{traits::Context, types::Status};
use serde::Serialize; use serde::Serialize;
use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration}; use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration};
@ -48,9 +48,10 @@ pub trait Client: Context {
call_args: CallArgs, call_args: CallArgs,
call_context: Self::CallContext, call_context: Self::CallContext,
) -> Result<u32, ClientError> { ) -> Result<u32, ClientError> {
debug!( trace!(
"dispatching http call with args={:?} context={:?}", "dispatching http call with args={:?} context={:?}",
call_args, call_context call_args,
call_context
); );
match self.dispatch_http_call( match self.dispatch_http_call(

View file

@ -74,24 +74,15 @@ impl Context for StreamContext {
*/ */
if let Some(body) = self.get_http_call_response_body(0, body_size) { if let Some(body) = self.get_http_call_response_body(0, body_size) {
#[cfg_attr(any(), rustfmt::skip)]
match callout_context.response_handler_type { match callout_context.response_handler_type {
ResponseHandlerType::GetEmbeddings => {
self.embeddings_handler(body, callout_context)
}
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context), ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
ResponseHandlerType::ZeroShotIntent => { ResponseHandlerType::Embeddings => self.embeddings_handler(body, callout_context),
self.zero_shot_intent_detection_resp_handler(body, callout_context) ResponseHandlerType::ZeroShotIntent => self.zero_shot_intent_detection_resp_handler(body, callout_context),
}
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context), ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
ResponseHandlerType::HallucinationDetect => { ResponseHandlerType::Hallucination => self.hallucination_classification_resp_handler(body, callout_context),
self.hallucination_classification_resp_handler(body, callout_context) ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
} ResponseHandlerType::DefaultTarget =>self.default_target_handler(body, callout_context),
ResponseHandlerType::FunctionCall => {
self.function_call_response_handler(body, callout_context)
}
ResponseHandlerType::DefaultTarget => {
self.default_target_handler(body, callout_context)
}
} }
} else { } else {
self.send_server_error( self.send_server_error(

View file

@ -16,7 +16,7 @@ use common::{
http::{CallArgs, Client}, http::{CallArgs, Client},
}; };
use http::StatusCode; use http::StatusCode;
use log::{debug, warn}; use log::{debug, trace, warn};
use proxy_wasm::{traits::HttpContext, types::Action}; use proxy_wasm::{traits::HttpContext, types::Action};
use serde_json::Value; use serde_json::Value;
@ -36,7 +36,7 @@ impl HttpContext for StreamContext {
self.is_chat_completions_request = self.is_chat_completions_request =
self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH; self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH;
debug!( trace!(
"on_http_request_headers S[{}] req_headers={:?}", "on_http_request_headers S[{}] req_headers={:?}",
self.context_id, self.context_id,
self.get_http_request_headers() self.get_http_request_headers()
@ -60,32 +60,37 @@ impl HttpContext for StreamContext {
self.request_body_size = body_size; self.request_body_size = body_size;
debug!( trace!(
"on_http_request_body S[{}] body_size={}", "on_http_request_body S[{}] body_size={}",
self.context_id, body_size self.context_id,
body_size
); );
let body_bytes = match self.get_http_request_body(0, body_size) {
Some(body_bytes) => body_bytes,
None => {
self.send_server_error(
ServerError::LogicError(format!(
"Failed to obtain body bytes even though body_size is {}",
body_size
)),
None,
);
return Action::Pause;
}
};
debug!("developer => archgw: {}", String::from_utf8_lossy(&body_bytes));
// Deserialize body into spec. // Deserialize body into spec.
// Currently OpenAI API. // Currently OpenAI API.
let mut deserialized_body: ChatCompletionsRequest = let mut deserialized_body: ChatCompletionsRequest =
match self.get_http_request_body(0, body_size) { match serde_json::from_slice(&body_bytes) {
Some(body_bytes) => match serde_json::from_slice(&body_bytes) { Ok(deserialized) => deserialized,
Ok(deserialized) => deserialized, Err(e) => {
Err(e) => {
self.send_server_error(
ServerError::Deserialization(e),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
},
None => {
self.send_server_error( self.send_server_error(
ServerError::LogicError(format!( ServerError::Deserialization(e),
"Failed to obtain body bytes even though body_size is {}", Some(StatusCode::BAD_REQUEST),
body_size
)),
None,
); );
return Action::Pause; return Action::Pause;
} }
@ -145,7 +150,6 @@ impl HttpContext for StreamContext {
similarity_scores: None, similarity_scores: None,
upstream_cluster: None, upstream_cluster: None,
upstream_cluster_path: None, upstream_cluster_path: None,
tool_calls: None,
}; };
self.get_embeddings(callout_context); self.get_embeddings(callout_context);
return Action::Pause; return Action::Pause;
@ -201,7 +205,6 @@ impl HttpContext for StreamContext {
similarity_scores: None, similarity_scores: None,
upstream_cluster: None, upstream_cluster: None,
upstream_cluster_path: None, upstream_cluster_path: None,
tool_calls: None,
}; };
if let Err(e) = self.http_call(call_args, call_context) { if let Err(e) = self.http_call(call_args, call_context) {
@ -212,7 +215,7 @@ impl HttpContext for StreamContext {
} }
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
debug!( trace!(
"on_http_response_headers recv [S={}] headers={:?}", "on_http_response_headers recv [S={}] headers={:?}",
self.context_id, self.context_id,
self.get_http_response_headers() self.get_http_response_headers()
@ -224,9 +227,11 @@ impl HttpContext for StreamContext {
} }
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
debug!( trace!(
"recv [S={}] bytes={} end_stream={}", "recv [S={}] bytes={} end_stream={}",
self.context_id, body_size, end_of_stream self.context_id,
body_size,
end_of_stream
); );
if !self.is_chat_completions_request { if !self.is_chat_completions_request {
@ -248,14 +253,14 @@ impl HttpContext for StreamContext {
.expect("cant get response body"); .expect("cant get response body");
if self.streaming_response { if self.streaming_response {
debug!("streaming response"); trace!("streaming response");
} else { } else {
debug!("non streaming response"); trace!("non streaming response");
let chat_completions_response: ChatCompletionsResponse = let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_slice(&body) { match serde_json::from_slice(&body) {
Ok(de) => de, Ok(de) => de,
Err(e) => { Err(e) => {
debug!( trace!(
"invalid response: {}, {}", "invalid response: {}, {}",
String::from_utf8_lossy(&body), String::from_utf8_lossy(&body),
e e
@ -316,16 +321,18 @@ impl HttpContext for StreamContext {
serde_json::Value::String(arch_state_str), serde_json::Value::String(arch_state_str),
); );
let data_serialized = serde_json::to_string(&data).unwrap(); let data_serialized = serde_json::to_string(&data).unwrap();
debug!("arch => user: {}", data_serialized); debug!("archgw <= developer: {}", data_serialized);
self.set_http_response_body(0, body_size, data_serialized.as_bytes()); self.set_http_response_body(0, body_size, data_serialized.as_bytes());
}; };
} }
} }
} }
debug!( trace!(
"recv [S={}] total_tokens={} end_stream={}", "recv [S={}] total_tokens={} end_stream={}",
self.context_id, self.response_tokens, end_of_stream self.context_id,
self.response_tokens,
end_of_stream
); );
Action::Continue Action::Continue

View file

@ -15,7 +15,7 @@ use common::consts::{
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER,
ARCH_UPSTREAM_HOST_HEADER, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, ARCH_UPSTREAM_HOST_HEADER, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST,
HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
ZEROSHOT_INTERNAL_HOST, ZEROSHOT_INTERNAL_HOST,
}; };
@ -27,7 +27,7 @@ use common::http::{CallArgs, Client};
use common::stats::Gauge; use common::stats::Gauge;
use derivative::Derivative; use derivative::Derivative;
use http::StatusCode; use http::StatusCode;
use log::{debug, info, warn}; use log::{debug, info, trace, warn};
use proxy_wasm::traits::*; use proxy_wasm::traits::*;
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
@ -37,11 +37,11 @@ use std::time::Duration;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum ResponseHandlerType { pub enum ResponseHandlerType {
GetEmbeddings, Embeddings,
ArchFC, ArchFC,
FunctionCall, FunctionCall,
ZeroShotIntent, ZeroShotIntent,
HallucinationDetect, Hallucination,
ArchGuard, ArchGuard,
DefaultTarget, DefaultTarget,
} }
@ -54,7 +54,6 @@ pub struct StreamCallContext {
pub prompt_target_name: Option<String>, pub prompt_target_name: Option<String>,
#[derivative(Debug = "ignore")] #[derivative(Debug = "ignore")]
pub request_body: ChatCompletionsRequest, pub request_body: ChatCompletionsRequest,
pub tool_calls: Option<Vec<ToolCall>>,
pub similarity_scores: Option<Vec<(String, f64)>>, pub similarity_scores: Option<Vec<(String, f64)>>,
pub upstream_cluster: Option<String>, pub upstream_cluster: Option<String>,
pub upstream_cluster_path: Option<String>, pub upstream_cluster_path: Option<String>,
@ -129,18 +128,77 @@ impl StreamContext {
); );
} }
pub fn get_embeddings(&mut self, callout_context: StreamCallContext) {
let user_message = callout_context.user_message.unwrap();
let get_embeddings_input = CreateEmbeddingRequest {
// Need to clone into input because user_message is used below.
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
let embeddings_request_str: String = match serde_json::to_string(&get_embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
warn!("error serializing get embeddings request: {}", error);
return self.send_server_error(ServerError::Deserialization(error), None);
}
};
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST),
(":method", "POST"),
(":path", "/embeddings"),
(":authority", EMBEDDINGS_INTERNAL_HOST),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/embeddings",
headers,
Some(embeddings_request_str.as_bytes()),
vec![],
Duration::from_secs(5),
);
let call_context = StreamCallContext {
response_handler_type: ResponseHandlerType::Embeddings,
user_message: Some(user_message),
prompt_target_name: None,
request_body: callout_context.request_body,
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
};
debug!(
"archgw => get embeddings request: {}",
embeddings_request_str
);
if let Err(e) = self.http_call(call_args, call_context) {
warn!("error dispatching get embeddings request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
pub fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) { pub fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) {
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) { let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
Ok(embedding_response) => embedding_response, Ok(embedding_response) => embedding_response,
Err(e) => { Err(e) => {
debug!("error deserializing embedding response: {}", e); warn!("error deserializing embedding response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None); return self.send_server_error(ServerError::Deserialization(e), None);
} }
}; };
let prompt_embeddings_vector = &embedding_response.data[0].embedding; let prompt_embeddings_vector = &embedding_response.data[0].embedding;
debug!( trace!(
"embedding model: {}, vector length: {:?}", "embedding model: {}, vector length: {:?}",
embedding_response.model, embedding_response.model,
prompt_embeddings_vector.len() prompt_embeddings_vector.len()
@ -237,7 +295,7 @@ impl StreamContext {
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent; callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
if let Err(e) = self.http_call(call_args, callout_context) { if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching zero shot classification request: {}", e); warn!("error dispatching zero shot classification request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), None); self.send_server_error(ServerError::HttpDispatch(e), None);
} }
} }
@ -247,11 +305,13 @@ impl StreamContext {
body: Vec<u8>, body: Vec<u8>,
callout_context: StreamCallContext, callout_context: StreamCallContext,
) { ) {
let boyd_str = String::from_utf8(body).expect("could not convert body to string");
debug!("archgw <= hallucination response: {}", boyd_str);
let hallucination_response: HallucinationClassificationResponse = let hallucination_response: HallucinationClassificationResponse =
match serde_json::from_slice(&body) { match serde_json::from_str(boyd_str.as_str()) {
Ok(hallucination_response) => hallucination_response, Ok(hallucination_response) => hallucination_response,
Err(e) => { Err(e) => {
debug!("error deserializing hallucination response: {}", e); warn!("error deserializing hallucination response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None); return self.send_server_error(ServerError::Deserialization(e), None);
} }
}; };
@ -291,7 +351,7 @@ impl StreamContext {
metadata: None, metadata: None,
}; };
debug!("hallucination response: {:?}", chat_completion_response); trace!("hallucination response: {:?}", chat_completion_response);
self.send_http_response( self.send_http_response(
StatusCode::OK.as_u16().into(), StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")], vec![("Powered-By", "Katanemo")],
@ -316,7 +376,7 @@ impl StreamContext {
match serde_json::from_slice(&body) { match serde_json::from_slice(&body) {
Ok(zeroshot_response) => zeroshot_response, Ok(zeroshot_response) => zeroshot_response,
Err(e) => { Err(e) => {
debug!( warn!(
"error deserializing zero shot classification response: {}", "error deserializing zero shot classification response: {}",
e e
); );
@ -324,7 +384,10 @@ impl StreamContext {
} }
}; };
debug!("zeroshot intent response: {:?}", zeroshot_intent_response); trace!(
"zeroshot intent response: {}",
serde_json::to_string(&zeroshot_intent_response).unwrap()
);
let desc_emb_similarity_map: HashMap<String, f64> = callout_context let desc_emb_similarity_map: HashMap<String, f64> = callout_context
.similarity_scores .similarity_scores
@ -362,7 +425,7 @@ impl StreamContext {
} }
} }
} else { } else {
info!("no assistant message found, probably first interaction"); debug!("no assistant message found, probably first interaction");
} }
// get prompt target similarity threshold from overrides // get prompt target similarity threshold from overrides
@ -382,15 +445,16 @@ impl StreamContext {
// if arch fc responded to the user message, then we don't need to check the similarity score // if arch fc responded to the user message, then we don't need to check the similarity score
// it may be that arch fc is handling the conversation for parameter collection // it may be that arch fc is handling the conversation for parameter collection
if arch_assistant { if arch_assistant {
info!("arch assistant is handling the conversation"); info!("arch fc is engaged in parameter collection");
} else { } else {
debug!("checking for default prompt target");
if let Some(default_prompt_target) = self if let Some(default_prompt_target) = self
.prompt_targets .prompt_targets
.values() .values()
.find(|pt| pt.default.unwrap_or(false)) .find(|pt| pt.default.unwrap_or(false))
{ {
debug!("default prompt target found"); debug!(
"default prompt target found, forwarding request to default prompt target"
);
let endpoint = default_prompt_target.endpoint.clone().unwrap(); let endpoint = default_prompt_target.endpoint.clone().unwrap();
let upstream_path: String = endpoint.path.unwrap_or(String::from("/")); let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
@ -401,8 +465,6 @@ impl StreamContext {
callout_context.request_body.messages.clone(), callout_context.request_body.messages.clone(),
); );
let arch_messages_json = serde_json::to_string(&params).unwrap(); let arch_messages_json = serde_json::to_string(&params).unwrap();
debug!("no prompt target found with similarity score above threshold, using default prompt target");
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string(); let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
let mut headers = vec![ let mut headers = vec![
@ -431,7 +493,7 @@ impl StreamContext {
callout_context.prompt_target_name = Some(default_prompt_target.name.clone()); callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
if let Err(e) = self.http_call(call_args, callout_context) { if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching default prompt target request: {}", e); warn!("error dispatching default prompt target request: {}", e);
return self.send_server_error( return self.send_server_error(
ServerError::HttpDispatch(e), ServerError::HttpDispatch(e),
Some(StatusCode::BAD_REQUEST), Some(StatusCode::BAD_REQUEST),
@ -444,20 +506,12 @@ impl StreamContext {
} }
} }
let prompt_target = match self.prompt_targets.get(&prompt_target_name) { let prompt_target = self
Some(prompt_target) => prompt_target.clone(), .prompt_targets
None => { .get(&prompt_target_name)
debug!("prompt target not found: {}", prompt_target_name); .expect("prompt target not found")
return self.send_server_error( .clone();
ServerError::LogicError(format!(
"Prompt target not found: {prompt_target_name}"
)),
None,
);
}
};
info!("prompt_target name: {:?}", prompt_target_name);
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new(); let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
for pt in self.prompt_targets.values() { for pt in self.prompt_targets.values() {
if pt.default.unwrap_or_default() { if pt.default.unwrap_or_default() {
@ -506,7 +560,12 @@ impl StreamContext {
); );
let chat_completions = ChatCompletionsRequest { let chat_completions = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(), model: self
.chat_completions_request
.as_ref()
.unwrap()
.model
.clone(),
messages: callout_context.request_body.messages.clone(), messages: callout_context.request_body.messages.clone(),
tools: Some(chat_completion_tools), tools: Some(chat_completion_tools),
stream: false, stream: false,
@ -515,12 +574,9 @@ impl StreamContext {
}; };
let msg_body = match serde_json::to_string(&chat_completions) { let msg_body = match serde_json::to_string(&chat_completions) {
Ok(msg_body) => { Ok(msg_body) => msg_body,
debug!("arch_fc request body content: {}", msg_body);
msg_body
}
Err(e) => { Err(e) => {
debug!("error serializing arch_fc request body: {}", e); warn!("error serializing arch_fc request body: {}", e);
return self.send_server_error(ServerError::Serialization(e), None); return self.send_server_error(ServerError::Serialization(e), None);
} }
}; };
@ -552,6 +608,7 @@ impl StreamContext {
callout_context.response_handler_type = ResponseHandlerType::ArchFC; callout_context.response_handler_type = ResponseHandlerType::ArchFC;
callout_context.prompt_target_name = Some(prompt_target.name); callout_context.prompt_target_name = Some(prompt_target.name);
debug!("archgw => archfc request: {}", msg_body);
if let Err(e) = self.http_call(call_args, callout_context) { if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching arch_fc request: {}", e); debug!("error dispatching arch_fc request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST)); self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
@ -564,21 +621,28 @@ impl StreamContext {
mut callout_context: StreamCallContext, mut callout_context: StreamCallContext,
) { ) {
let body_str = String::from_utf8(body).unwrap(); let body_str = String::from_utf8(body).unwrap();
debug!("arch <= app response body: {}", body_str); debug!("archgw <= archfc response: {}", body_str);
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) { let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
Ok(arch_fc_response) => arch_fc_response, Ok(arch_fc_response) => arch_fc_response,
Err(e) => { Err(e) => {
debug!("error deserializing arch_fc response: {}", e); warn!("error deserializing archfc response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None); return self.send_server_error(ServerError::Deserialization(e), None);
} }
}; };
let model_resp = &arch_fc_response.choices[0]; arch_fc_response.choices[0]
.message
.tool_calls
.clone_into(&mut self.tool_calls);
if self.tool_calls.as_ref().unwrap().len() > 1 {
warn!(
"multiple tool calls not supported yet, tool_calls count found: {}",
self.tool_calls.as_ref().unwrap().len()
);
}
if model_resp.message.tool_calls.is_none() if self.tool_calls.is_none() || self.tool_calls.as_ref().unwrap().is_empty() {
|| model_resp.message.tool_calls.as_ref().unwrap().is_empty()
{
// This means that Arch FC did not have enough information to resolve the function call // This means that Arch FC did not have enough information to resolve the function call
// Arch FC probably responded with a message asking for more information. // Arch FC probably responded with a message asking for more information.
// Let's send the response back to the user to initialize lightweight dialog for parameter collection // Let's send the response back to the user to initialize lightweight dialog for parameter collection
@ -592,121 +656,118 @@ impl StreamContext {
); );
} }
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
self.tool_calls = Some(tool_calls.clone());
// TODO CO: pass nli check // TODO CO: pass nli check
// If hallucination, pass chat template to check parameters let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
let prompt_target = self
// extract all tool names .prompt_targets
let tool_names: Vec<String> = tool_calls .get(&tools_call_name)
.iter() .expect("prompt target not found for tool call")
.map(|tool_call| tool_call.function.name.clone()) .clone();
.collect();
debug!( debug!(
"call context similarity score: {:?}", "prompt_target_name: {}, tool_name(s): {:?}",
callout_context.similarity_scores prompt_target.name,
self.tool_calls
.as_ref()
.unwrap()
.iter()
.map(|tc| tc.function.name.clone())
.collect::<Vec<String>>(),
); );
// If hallucination, pass chat template to check parameters
//HACK: for now we only support one tool call, we will support multiple tool calls in the future //HACK: for now we only support one tool call, we will support multiple tool calls in the future
let mut tool_params = tool_calls[0].function.arguments.clone();
let mut tool_params = self.tool_calls.as_ref().unwrap()[0]
.function
.arguments
.clone();
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
debug!(
"tool_params (without messages history): {}",
tool_params_json_str
);
tool_params.insert( tool_params.insert(
String::from(ARCH_MESSAGES_KEY), String::from(ARCH_MESSAGES_KEY),
serde_yaml::to_value(&callout_context.request_body.messages).unwrap(), serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
); );
let tools_call_name = tool_calls[0].function.name.clone();
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap(); let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
callout_context.tool_calls = Some(tool_calls.clone()); use serde_json::Value;
let v: Value = serde_json::from_str(&tool_params_json_str).unwrap();
let tool_params_dict: HashMap<String, String> = match v.as_object() {
Some(obj) => obj
.iter()
.map(|(key, value)| {
// Convert each value to a string, regardless of its type
(key.clone(), value.to_string())
})
.collect(),
None => HashMap::new(), // Return an empty HashMap if v is not an object
};
let all_user_messages =
extract_messages_for_hallucination(&callout_context.request_body.messages);
let user_messages_str = all_user_messages.join(", ");
debug!("user messages: {}", user_messages_str);
let hallucination_classification_request = HallucinationClassificationRequest {
prompt: user_messages_str,
model: String::from(DEFAULT_INTENT_MODEL),
parameters: tool_params_dict,
};
let hallucination_request_str: String =
match serde_json::to_string(&hallucination_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
debug!(
"error serializing hallucination classification request: {}",
error
);
return self.send_server_error(ServerError::Serialization(error), None);
}
};
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, HALLUCINATION_INTERNAL_HOST),
(":method", "POST"),
(":path", "/hallucination"),
(":authority", HALLUCINATION_INTERNAL_HOST),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/hallucination",
headers,
Some(hallucination_request_str.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::Hallucination;
debug!( debug!(
"prompt_target_name: {}, tool_name(s): {:?}", "archgw => hallucination request: {}",
prompt_target.name, tool_names hallucination_request_str
); );
debug!("tool_params: {}", tool_params_json_str); if let Err(e) = self.http_call(call_args, callout_context) {
self.send_server_error(ServerError::HttpDispatch(e), None);
if model_resp.message.tool_calls.is_some()
&& !model_resp.message.tool_calls.as_ref().unwrap().is_empty()
{
use serde_json::Value;
let v: Value = serde_json::from_str(&tool_params_json_str).unwrap();
let tool_params_dict: HashMap<String, String> = match v.as_object() {
Some(obj) => obj
.iter()
.map(|(key, value)| {
// Convert each value to a string, regardless of its type
(key.clone(), value.to_string())
})
.collect(),
None => HashMap::new(), // Return an empty HashMap if v is not an object
};
let all_user_messages =
extract_messages_for_hallucination(&callout_context.request_body.messages);
let user_messages_str = all_user_messages.join(", ");
debug!("user messages: {}", user_messages_str);
let hallucination_classification_request = HallucinationClassificationRequest {
prompt: user_messages_str,
model: String::from(DEFAULT_INTENT_MODEL),
parameters: tool_params_dict,
};
let json_data: String =
match serde_json::to_string(&hallucination_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
debug!(
"error serializing hallucination classification request: {}",
error
);
return self.send_server_error(ServerError::Serialization(error), None);
}
};
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, HALLUCINATION_INTERNAL_HOST),
(":method", "POST"),
(":path", "/hallucination"),
(":authority", HALLUCINATION_INTERNAL_HOST),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/hallucination",
headers,
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::HallucinationDetect;
if let Err(e) = self.http_call(call_args, callout_context) {
self.send_server_error(ServerError::HttpDispatch(e), None);
}
} else {
self.schedule_api_call_request(callout_context);
} }
} }
fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) { fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
let tools_call_name = callout_context.tool_calls.as_ref().unwrap()[0] let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
.function
.name
.clone();
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone(); let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
//HACK: for now we only support one tool call, we will support multiple tool calls in the future let mut tool_params = self.tool_calls.as_ref().unwrap()[0]
let mut tool_params = callout_context.tool_calls.as_ref().unwrap()[0]
.function .function
.arguments .arguments
.clone(); .clone();
@ -741,8 +802,16 @@ impl StreamContext {
vec![], vec![],
Duration::from_secs(5), Duration::from_secs(5),
); );
callout_context.upstream_cluster = Some(endpoint.name.clone());
callout_context.upstream_cluster_path = Some(path.clone()); debug!(
"archgw => api call, endpoint: {}/{}, body: {}",
endpoint.name.as_str(),
path,
tool_params_json_str
);
callout_context.upstream_cluster = Some(endpoint.name.to_owned());
callout_context.upstream_cluster_path = Some(path.to_owned());
callout_context.response_handler_type = ResponseHandlerType::FunctionCall; callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
if let Err(e) = self.http_call(call_args, callout_context) { if let Err(e) = self.http_call(call_args, callout_context) {
@ -750,32 +819,29 @@ impl StreamContext {
} }
} }
pub fn function_call_response_handler( pub fn api_call_response_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
&mut self, let http_status = self
body: Vec<u8>, .get_http_call_response_header(":status")
callout_context: StreamCallContext, .expect("http status code not found");
) { if http_status != StatusCode::OK.as_str() {
if let Some(http_status) = self.get_http_call_response_header(":status") { warn!(
if http_status != StatusCode::OK.as_str() { "api server responded with non 2xx status code: {}",
debug!("upstream error response: {}", http_status); http_status
return self.send_server_error( );
ServerError::Upstream { return self.send_server_error(
host: callout_context.upstream_cluster.unwrap(), ServerError::Upstream {
path: callout_context.upstream_cluster_path.unwrap(), host: callout_context.upstream_cluster.unwrap(),
status: http_status.clone(), path: callout_context.upstream_cluster_path.unwrap(),
body: String::from_utf8(body).unwrap(), status: http_status.clone(),
}, body: String::from_utf8(body).unwrap(),
Some(StatusCode::from_str(http_status.as_str()).unwrap()), },
); Some(StatusCode::from_str(http_status.as_str()).unwrap()),
} );
} else {
warn!("http status code not found in api response");
} }
let app_function_call_response_str: String = String::from_utf8(body).unwrap(); self.tool_call_response = Some(String::from_utf8(body).unwrap());
self.tool_call_response = Some(app_function_call_response_str.clone());
debug!( debug!(
"arch <= app response body: {}", "archgw <= api call response: {}",
app_function_call_response_str self.tool_call_response.as_ref().unwrap()
); );
let prompt_target_name = callout_context.prompt_target_name.unwrap(); let prompt_target_name = callout_context.prompt_target_name.unwrap();
let prompt_target = self let prompt_target = self
@ -825,7 +891,7 @@ impl StreamContext {
let final_prompt = format!( let final_prompt = format!(
"{}\ncontext: {}", "{}\ncontext: {}",
user_message.content.unwrap(), user_message.content.unwrap(),
app_function_call_response_str self.tool_call_response.as_ref().unwrap()
); );
// add original user prompt // add original user prompt
@ -848,22 +914,24 @@ impl StreamContext {
metadata: None, metadata: None,
}; };
let json_string = match serde_json::to_string(&chat_completions_request) { let llm_request_str = match serde_json::to_string(&chat_completions_request) {
Ok(json_string) => json_string, Ok(json_string) => json_string,
Err(e) => { Err(e) => {
return self.send_server_error(ServerError::Serialization(e), None); return self.send_server_error(ServerError::Serialization(e), None);
} }
}; };
debug!("arch => upstream llm request body: {}", json_string); debug!("archgw => llm request: {}", llm_request_str);
self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes()); self.set_http_request_body(0, self.request_body_size, &llm_request_str.into_bytes());
self.resume_http_request(); self.resume_http_request();
} }
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) { pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
debug!("response received for arch guard");
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap(); let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
debug!("prompt_guard_resp: {:?}", prompt_guard_resp); debug!(
"archgw <= archguard response: {:?}",
serde_json::to_string(&prompt_guard_resp)
);
if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() { if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
//TODO: handle other scenarios like forward to error target //TODO: handle other scenarios like forward to error target
@ -871,7 +939,7 @@ impl StreamContext {
.prompt_guards .prompt_guards
.jailbreak_on_exception_message() .jailbreak_on_exception_message()
.unwrap_or("refrain from discussing jailbreaking."); .unwrap_or("refrain from discussing jailbreaking.");
debug!("jailbreak detected: {}", msg); warn!("jailbreak detected: {}", msg);
return self.send_server_error( return self.send_server_error(
ServerError::Jailbreak(String::from(msg)), ServerError::Jailbreak(String::from(msg)),
Some(StatusCode::BAD_REQUEST), Some(StatusCode::BAD_REQUEST),
@ -881,92 +949,27 @@ impl StreamContext {
self.get_embeddings(callout_context); self.get_embeddings(callout_context);
} }
pub fn get_embeddings(&mut self, callout_context: StreamCallContext) {
let user_message = callout_context.user_message.unwrap();
let get_embeddings_input = CreateEmbeddingRequest {
// Need to clone into input because user_message is used below.
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
debug!("error serializing get embeddings request: {}", error);
return self.send_server_error(ServerError::Deserialization(error), None);
}
};
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST),
(":method", "POST"),
(":path", "/embeddings"),
(":authority", EMBEDDINGS_INTERNAL_HOST),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/embeddings",
headers,
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
);
let call_context = StreamCallContext {
response_handler_type: ResponseHandlerType::GetEmbeddings,
user_message: Some(user_message),
prompt_target_name: None,
request_body: callout_context.request_body,
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
tool_calls: None,
};
if let Err(e) = self.http_call(call_args, call_context) {
debug!("error dispatching get embeddings request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
pub fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) { pub fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
let prompt_target = self let prompt_target = self
.prompt_targets .prompt_targets
.get(callout_context.prompt_target_name.as_ref().unwrap()) .get(callout_context.prompt_target_name.as_ref().unwrap())
.unwrap() .unwrap()
.clone(); .clone();
debug!(
"response received for default target: {}",
prompt_target.name
);
// check if the default target should be dispatched to the LLM provider // check if the default target should be dispatched to the LLM provider
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) { if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) {
let default_target_response_str = String::from_utf8(body).unwrap(); let default_target_response_str = String::from_utf8(body).unwrap();
debug!(
"sending response back to developer: {}",
default_target_response_str
);
self.send_http_response( self.send_http_response(
StatusCode::OK.as_u16().into(), StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")], vec![("Powered-By", "Katanemo")],
Some(default_target_response_str.as_bytes()), Some(default_target_response_str.as_bytes()),
); );
// self.resume_http_request();
return; return;
} }
debug!("default_target: sending api response to default llm");
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) { let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
Ok(chat_completions_resp) => chat_completions_resp, Ok(chat_completions_resp) => chat_completions_resp,
Err(e) => { Err(e) => {
debug!("error deserializing default target response: {}", e); warn!("error deserializing default target response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None); return self.send_server_error(ServerError::Deserialization(e), None);
} }
}; };
@ -1000,7 +1003,12 @@ impl StreamContext {
tool_call_id: None, tool_call_id: None,
}); });
let chat_completion_request = ChatCompletionsRequest { let chat_completion_request = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(), model: self
.chat_completions_request
.as_ref()
.unwrap()
.model
.clone(),
messages, messages,
tools: None, tools: None,
stream: callout_context.request_body.stream, stream: callout_context.request_body.stream,
@ -1008,7 +1016,7 @@ impl StreamContext {
metadata: None, metadata: None,
}; };
let json_resp = serde_json::to_string(&chat_completion_request).unwrap(); let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
debug!("sending response back to default llm: {}", json_resp); debug!("archgw => (default target) llm request: {}", json_resp);
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes()); self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
self.resume_http_request(); self.resume_http_request();
} }

View file

@ -33,7 +33,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
.returning(Some("/v1/chat/completions")) .returning(Some("/v1/chat/completions"))
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None) .returning(None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
.returning(None) .returning(None)
.execute_and_expect(ReturnType::Action(Action::Continue)) .execute_and_expect(ReturnType::Action(Action::Continue))
@ -74,7 +74,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body)) .returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id // The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None)
.expect_http_call( .expect_http_call(
Some("arch_internal"), Some("arch_internal"),
Some(vec![ Some(vec![
@ -92,6 +92,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
) )
.returning(Some(1)) .returning(Some(1))
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_increment("active_http_calls", 1) .expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause)) .execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap(); .unwrap();
@ -116,6 +117,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.returning(Some(&prompt_guard_response_buffer)) .returning(Some(&prompt_guard_response_buffer))
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call( .expect_http_call(
Some("arch_internal"), Some("arch_internal"),
Some(vec![ Some(vec![
@ -133,7 +135,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
) )
.returning(Some(2)) .returning(Some(2))
.expect_metric_increment("active_http_calls", 1) .expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None) .execute_and_expect(ReturnType::None)
.unwrap(); .unwrap();
@ -159,8 +160,9 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_metric_increment("active_http_calls", -1) .expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embeddings_response_buffer)) .returning(Some(&embeddings_response_buffer))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None)
.expect_http_call( .expect_http_call(
Some("arch_internal"), Some("arch_internal"),
Some(vec![ Some(vec![
@ -178,7 +180,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
) )
.returning(Some(3)) .returning(Some(3))
.expect_metric_increment("active_http_calls", 1) .expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None) .execute_and_expect(ReturnType::None)
.unwrap(); .unwrap();
@ -200,9 +201,10 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_metric_increment("active_http_calls", -1) .expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&zeroshot_intent_detection_buffer)) .returning(Some(&zeroshot_intent_detection_buffer))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Trace), None)
.expect_http_call( .expect_http_call(
Some("arch_internal"), Some("arch_internal"),
Some(vec![ Some(vec![
@ -219,8 +221,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
None, None,
) )
.returning(Some(4)) .returning(Some(4))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_metric_increment("active_http_calls", 1) .expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None) .execute_and_expect(ReturnType::None)
.unwrap(); .unwrap();
@ -245,7 +245,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
module module
.call_proxy_on_tick(filter_context) .call_proxy_on_tick(filter_context)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None)
.expect_http_call( .expect_http_call(
Some("arch_internal"), Some("arch_internal"),
Some(vec![ Some(vec![
@ -426,8 +426,9 @@ fn successful_request_to_open_ai_chat_completions() {
) )
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body)) .returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None)
.expect_http_call(Some("arch_internal"), None, None, None, None) .expect_http_call(Some("arch_internal"), None, None, None, None)
.returning(Some(4)) .returning(Some(4))
.expect_metric_increment("active_http_calls", 1) .expect_metric_increment("active_http_calls", 1)
@ -486,13 +487,14 @@ fn bad_request_to_open_ai_chat_completions() {
) )
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(incomplete_chat_completions_request_body)) .returning(Some(incomplete_chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None)
.expect_send_local_response( .expect_send_local_response(
Some(StatusCode::BAD_REQUEST.as_u16().into()), Some(StatusCode::BAD_REQUEST.as_u16().into()),
None, None,
None, None,
None, None,
) )
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::Action(Action::Pause)) .execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap(); .unwrap();
} }
@ -564,7 +566,7 @@ fn request_to_llm_gateway() {
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None)
.expect_http_call( .expect_http_call(
Some("arch_internal"), Some("arch_internal"),
Some(vec![ Some(vec![
@ -603,6 +605,8 @@ fn request_to_llm_gateway() {
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text)) .returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call( .expect_http_call(
Some("arch_internal"), Some("arch_internal"),
Some(vec![ Some(vec![
@ -628,10 +632,10 @@ fn request_to_llm_gateway() {
.expect_metric_increment("active_http_calls", -1) .expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text)) .returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status")) .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
.returning(Some("200")) .returning(Some("200"))
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::None) .execute_and_expect(ReturnType::None)
.unwrap(); .unwrap();
@ -664,11 +668,11 @@ fn request_to_llm_gateway() {
) )
.expect_get_buffer_bytes(Some(BufferType::HttpResponseBody)) .expect_get_buffer_bytes(Some(BufferType::HttpResponseBody))
.returning(Some(chat_completion_response_str.as_str())) .returning(Some(chat_completion_response_str.as_str()))
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None) .expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::Action(Action::Continue)) .execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap(); .unwrap();
} }