Improve logging (#209)

* improve logging

* fix int tests

* better

* fix more logs

* fix more

* fix int
This commit is contained in:
Adil Hafeez 2024-10-22 12:07:40 -07:00 committed by GitHub
parent 2f374df034
commit ea76d85b43
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 319 additions and 309 deletions

View file

@ -7,7 +7,6 @@ pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";
pub const TOOL_ROLE: &str = "tool";
pub const ASSISTANT_ROLE: &str = "assistant";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const MODEL_SERVER_NAME: &str = "model_server";
pub const ZEROSHOT_INTERNAL_HOST: &str = "zeroshot";

View file

@ -3,7 +3,7 @@ use crate::{
stats::{Gauge, IncrementingMetric},
};
use derivative::Derivative;
use log::debug;
use log::{debug, trace};
use proxy_wasm::{traits::Context, types::Status};
use serde::Serialize;
use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration};
@ -48,9 +48,10 @@ pub trait Client: Context {
call_args: CallArgs,
call_context: Self::CallContext,
) -> Result<u32, ClientError> {
debug!(
trace!(
"dispatching http call with args={:?} context={:?}",
call_args, call_context
call_args,
call_context
);
match self.dispatch_http_call(

View file

@ -74,24 +74,15 @@ impl Context for StreamContext {
*/
if let Some(body) = self.get_http_call_response_body(0, body_size) {
#[cfg_attr(any(), rustfmt::skip)]
match callout_context.response_handler_type {
ResponseHandlerType::GetEmbeddings => {
self.embeddings_handler(body, callout_context)
}
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
ResponseHandlerType::ZeroShotIntent => {
self.zero_shot_intent_detection_resp_handler(body, callout_context)
}
ResponseHandlerType::Embeddings => self.embeddings_handler(body, callout_context),
ResponseHandlerType::ZeroShotIntent => self.zero_shot_intent_detection_resp_handler(body, callout_context),
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
ResponseHandlerType::HallucinationDetect => {
self.hallucination_classification_resp_handler(body, callout_context)
}
ResponseHandlerType::FunctionCall => {
self.function_call_response_handler(body, callout_context)
}
ResponseHandlerType::DefaultTarget => {
self.default_target_handler(body, callout_context)
}
ResponseHandlerType::Hallucination => self.hallucination_classification_resp_handler(body, callout_context),
ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
ResponseHandlerType::DefaultTarget =>self.default_target_handler(body, callout_context),
}
} else {
self.send_server_error(

View file

@ -16,7 +16,7 @@ use common::{
http::{CallArgs, Client},
};
use http::StatusCode;
use log::{debug, warn};
use log::{debug, trace, warn};
use proxy_wasm::{traits::HttpContext, types::Action};
use serde_json::Value;
@ -36,7 +36,7 @@ impl HttpContext for StreamContext {
self.is_chat_completions_request =
self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH;
debug!(
trace!(
"on_http_request_headers S[{}] req_headers={:?}",
self.context_id,
self.get_http_request_headers()
@ -60,25 +60,14 @@ impl HttpContext for StreamContext {
self.request_body_size = body_size;
debug!(
trace!(
"on_http_request_body S[{}] body_size={}",
self.context_id, body_size
self.context_id,
body_size
);
// Deserialize body into spec.
// Currently OpenAI API.
let mut deserialized_body: ChatCompletionsRequest =
match self.get_http_request_body(0, body_size) {
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
self.send_server_error(
ServerError::Deserialization(e),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
},
let body_bytes = match self.get_http_request_body(0, body_size) {
Some(body_bytes) => body_bytes,
None => {
self.send_server_error(
ServerError::LogicError(format!(
@ -91,6 +80,22 @@ impl HttpContext for StreamContext {
}
};
debug!("developer => archgw: {}", String::from_utf8_lossy(&body_bytes));
// Deserialize body into spec.
// Currently OpenAI API.
let mut deserialized_body: ChatCompletionsRequest =
match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
self.send_server_error(
ServerError::Deserialization(e),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
};
self.arch_state = match deserialized_body.metadata {
Some(ref metadata) => {
if metadata.contains_key(ARCH_STATE_HEADER) {
@ -145,7 +150,6 @@ impl HttpContext for StreamContext {
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
tool_calls: None,
};
self.get_embeddings(callout_context);
return Action::Pause;
@ -201,7 +205,6 @@ impl HttpContext for StreamContext {
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
tool_calls: None,
};
if let Err(e) = self.http_call(call_args, call_context) {
@ -212,7 +215,7 @@ impl HttpContext for StreamContext {
}
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
debug!(
trace!(
"on_http_response_headers recv [S={}] headers={:?}",
self.context_id,
self.get_http_response_headers()
@ -224,9 +227,11 @@ impl HttpContext for StreamContext {
}
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
debug!(
trace!(
"recv [S={}] bytes={} end_stream={}",
self.context_id, body_size, end_of_stream
self.context_id,
body_size,
end_of_stream
);
if !self.is_chat_completions_request {
@ -248,14 +253,14 @@ impl HttpContext for StreamContext {
.expect("cant get response body");
if self.streaming_response {
debug!("streaming response");
trace!("streaming response");
} else {
debug!("non streaming response");
trace!("non streaming response");
let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_slice(&body) {
Ok(de) => de,
Err(e) => {
debug!(
trace!(
"invalid response: {}, {}",
String::from_utf8_lossy(&body),
e
@ -316,16 +321,18 @@ impl HttpContext for StreamContext {
serde_json::Value::String(arch_state_str),
);
let data_serialized = serde_json::to_string(&data).unwrap();
debug!("arch => user: {}", data_serialized);
debug!("archgw <= developer: {}", data_serialized);
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
};
}
}
}
debug!(
trace!(
"recv [S={}] total_tokens={} end_stream={}",
self.context_id, self.response_tokens, end_of_stream
self.context_id,
self.response_tokens,
end_of_stream
);
Action::Continue

View file

@ -15,7 +15,7 @@ use common::consts::{
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER,
ARCH_UPSTREAM_HOST_HEADER, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST,
HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
ZEROSHOT_INTERNAL_HOST,
};
@ -27,7 +27,7 @@ use common::http::{CallArgs, Client};
use common::stats::Gauge;
use derivative::Derivative;
use http::StatusCode;
use log::{debug, info, warn};
use log::{debug, info, trace, warn};
use proxy_wasm::traits::*;
use std::cell::RefCell;
use std::collections::HashMap;
@ -37,11 +37,11 @@ use std::time::Duration;
#[derive(Debug, Clone)]
pub enum ResponseHandlerType {
GetEmbeddings,
Embeddings,
ArchFC,
FunctionCall,
ZeroShotIntent,
HallucinationDetect,
Hallucination,
ArchGuard,
DefaultTarget,
}
@ -54,7 +54,6 @@ pub struct StreamCallContext {
pub prompt_target_name: Option<String>,
#[derivative(Debug = "ignore")]
pub request_body: ChatCompletionsRequest,
pub tool_calls: Option<Vec<ToolCall>>,
pub similarity_scores: Option<Vec<(String, f64)>>,
pub upstream_cluster: Option<String>,
pub upstream_cluster_path: Option<String>,
@ -129,18 +128,77 @@ impl StreamContext {
);
}
pub fn get_embeddings(&mut self, callout_context: StreamCallContext) {
let user_message = callout_context.user_message.unwrap();
let get_embeddings_input = CreateEmbeddingRequest {
// Need to clone into input because user_message is used below.
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
let embeddings_request_str: String = match serde_json::to_string(&get_embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
warn!("error serializing get embeddings request: {}", error);
return self.send_server_error(ServerError::Deserialization(error), None);
}
};
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST),
(":method", "POST"),
(":path", "/embeddings"),
(":authority", EMBEDDINGS_INTERNAL_HOST),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/embeddings",
headers,
Some(embeddings_request_str.as_bytes()),
vec![],
Duration::from_secs(5),
);
let call_context = StreamCallContext {
response_handler_type: ResponseHandlerType::Embeddings,
user_message: Some(user_message),
prompt_target_name: None,
request_body: callout_context.request_body,
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
};
debug!(
"archgw => get embeddings request: {}",
embeddings_request_str
);
if let Err(e) = self.http_call(call_args, call_context) {
warn!("error dispatching get embeddings request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
pub fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) {
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
Ok(embedding_response) => embedding_response,
Err(e) => {
debug!("error deserializing embedding response: {}", e);
warn!("error deserializing embedding response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
let prompt_embeddings_vector = &embedding_response.data[0].embedding;
debug!(
trace!(
"embedding model: {}, vector length: {:?}",
embedding_response.model,
prompt_embeddings_vector.len()
@ -237,7 +295,7 @@ impl StreamContext {
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching zero shot classification request: {}", e);
warn!("error dispatching zero shot classification request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
@ -247,11 +305,13 @@ impl StreamContext {
body: Vec<u8>,
callout_context: StreamCallContext,
) {
let boyd_str = String::from_utf8(body).expect("could not convert body to string");
debug!("archgw <= hallucination response: {}", boyd_str);
let hallucination_response: HallucinationClassificationResponse =
match serde_json::from_slice(&body) {
match serde_json::from_str(boyd_str.as_str()) {
Ok(hallucination_response) => hallucination_response,
Err(e) => {
debug!("error deserializing hallucination response: {}", e);
warn!("error deserializing hallucination response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -291,7 +351,7 @@ impl StreamContext {
metadata: None,
};
debug!("hallucination response: {:?}", chat_completion_response);
trace!("hallucination response: {:?}", chat_completion_response);
self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
@ -316,7 +376,7 @@ impl StreamContext {
match serde_json::from_slice(&body) {
Ok(zeroshot_response) => zeroshot_response,
Err(e) => {
debug!(
warn!(
"error deserializing zero shot classification response: {}",
e
);
@ -324,7 +384,10 @@ impl StreamContext {
}
};
debug!("zeroshot intent response: {:?}", zeroshot_intent_response);
trace!(
"zeroshot intent response: {}",
serde_json::to_string(&zeroshot_intent_response).unwrap()
);
let desc_emb_similarity_map: HashMap<String, f64> = callout_context
.similarity_scores
@ -362,7 +425,7 @@ impl StreamContext {
}
}
} else {
info!("no assistant message found, probably first interaction");
debug!("no assistant message found, probably first interaction");
}
// get prompt target similarity threshold from overrides
@ -382,15 +445,16 @@ impl StreamContext {
// if arch fc responded to the user message, then we don't need to check the similarity score
// it may be that arch fc is handling the conversation for parameter collection
if arch_assistant {
info!("arch assistant is handling the conversation");
info!("arch fc is engaged in parameter collection");
} else {
debug!("checking for default prompt target");
if let Some(default_prompt_target) = self
.prompt_targets
.values()
.find(|pt| pt.default.unwrap_or(false))
{
debug!("default prompt target found");
debug!(
"default prompt target found, forwarding request to default prompt target"
);
let endpoint = default_prompt_target.endpoint.clone().unwrap();
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
@ -401,8 +465,6 @@ impl StreamContext {
callout_context.request_body.messages.clone(),
);
let arch_messages_json = serde_json::to_string(&params).unwrap();
debug!("no prompt target found with similarity score above threshold, using default prompt target");
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
let mut headers = vec![
@ -431,7 +493,7 @@ impl StreamContext {
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching default prompt target request: {}", e);
warn!("error dispatching default prompt target request: {}", e);
return self.send_server_error(
ServerError::HttpDispatch(e),
Some(StatusCode::BAD_REQUEST),
@ -444,20 +506,12 @@ impl StreamContext {
}
}
let prompt_target = match self.prompt_targets.get(&prompt_target_name) {
Some(prompt_target) => prompt_target.clone(),
None => {
debug!("prompt target not found: {}", prompt_target_name);
return self.send_server_error(
ServerError::LogicError(format!(
"Prompt target not found: {prompt_target_name}"
)),
None,
);
}
};
let prompt_target = self
.prompt_targets
.get(&prompt_target_name)
.expect("prompt target not found")
.clone();
info!("prompt_target name: {:?}", prompt_target_name);
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
for pt in self.prompt_targets.values() {
if pt.default.unwrap_or_default() {
@ -506,7 +560,12 @@ impl StreamContext {
);
let chat_completions = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
model: self
.chat_completions_request
.as_ref()
.unwrap()
.model
.clone(),
messages: callout_context.request_body.messages.clone(),
tools: Some(chat_completion_tools),
stream: false,
@ -515,12 +574,9 @@ impl StreamContext {
};
let msg_body = match serde_json::to_string(&chat_completions) {
Ok(msg_body) => {
debug!("arch_fc request body content: {}", msg_body);
msg_body
}
Ok(msg_body) => msg_body,
Err(e) => {
debug!("error serializing arch_fc request body: {}", e);
warn!("error serializing arch_fc request body: {}", e);
return self.send_server_error(ServerError::Serialization(e), None);
}
};
@ -552,6 +608,7 @@ impl StreamContext {
callout_context.response_handler_type = ResponseHandlerType::ArchFC;
callout_context.prompt_target_name = Some(prompt_target.name);
debug!("archgw => archfc request: {}", msg_body);
if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching arch_fc request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
@ -564,21 +621,28 @@ impl StreamContext {
mut callout_context: StreamCallContext,
) {
let body_str = String::from_utf8(body).unwrap();
debug!("arch <= app response body: {}", body_str);
debug!("archgw <= archfc response: {}", body_str);
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
Ok(arch_fc_response) => arch_fc_response,
Err(e) => {
debug!("error deserializing arch_fc response: {}", e);
warn!("error deserializing archfc response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
let model_resp = &arch_fc_response.choices[0];
arch_fc_response.choices[0]
.message
.tool_calls
.clone_into(&mut self.tool_calls);
if self.tool_calls.as_ref().unwrap().len() > 1 {
warn!(
"multiple tool calls not supported yet, tool_calls count found: {}",
self.tool_calls.as_ref().unwrap().len()
);
}
if model_resp.message.tool_calls.is_none()
|| model_resp.message.tool_calls.as_ref().unwrap().is_empty()
{
if self.tool_calls.is_none() || self.tool_calls.as_ref().unwrap().is_empty() {
// This means that Arch FC did not have enough information to resolve the function call
// Arch FC probably responded with a message asking for more information.
// Let's send the response back to the user to initialize lightweight dialog for parameter collection
@ -592,43 +656,43 @@ impl StreamContext {
);
}
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
self.tool_calls = Some(tool_calls.clone());
// TODO CO: pass nli check
// If hallucination, pass chat template to check parameters
// extract all tool names
let tool_names: Vec<String> = tool_calls
.iter()
.map(|tool_call| tool_call.function.name.clone())
.collect();
let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
let prompt_target = self
.prompt_targets
.get(&tools_call_name)
.expect("prompt target not found for tool call")
.clone();
debug!(
"call context similarity score: {:?}",
callout_context.similarity_scores
"prompt_target_name: {}, tool_name(s): {:?}",
prompt_target.name,
self.tool_calls
.as_ref()
.unwrap()
.iter()
.map(|tc| tc.function.name.clone())
.collect::<Vec<String>>(),
);
// If hallucination, pass chat template to check parameters
//HACK: for now we only support one tool call, we will support multiple tool calls in the future
let mut tool_params = tool_calls[0].function.arguments.clone();
let mut tool_params = self.tool_calls.as_ref().unwrap()[0]
.function
.arguments
.clone();
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
debug!(
"tool_params (without messages history): {}",
tool_params_json_str
);
tool_params.insert(
String::from(ARCH_MESSAGES_KEY),
serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
);
let tools_call_name = tool_calls[0].function.name.clone();
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
callout_context.tool_calls = Some(tool_calls.clone());
debug!(
"prompt_target_name: {}, tool_name(s): {:?}",
prompt_target.name, tool_names
);
debug!("tool_params: {}", tool_params_json_str);
if model_resp.message.tool_calls.is_some()
&& !model_resp.message.tool_calls.as_ref().unwrap().is_empty()
{
use serde_json::Value;
let v: Value = serde_json::from_str(&tool_params_json_str).unwrap();
let tool_params_dict: HashMap<String, String> = match v.as_object() {
@ -653,7 +717,7 @@ impl StreamContext {
parameters: tool_params_dict,
};
let json_data: String =
let hallucination_request_str: String =
match serde_json::to_string(&hallucination_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
@ -683,30 +747,27 @@ impl StreamContext {
ARCH_INTERNAL_CLUSTER_NAME,
"/hallucination",
headers,
Some(json_data.as_bytes()),
Some(hallucination_request_str.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::HallucinationDetect;
callout_context.response_handler_type = ResponseHandlerType::Hallucination;
debug!(
"archgw => hallucination request: {}",
hallucination_request_str
);
if let Err(e) = self.http_call(call_args, callout_context) {
self.send_server_error(ServerError::HttpDispatch(e), None);
}
} else {
self.schedule_api_call_request(callout_context);
}
}
fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
let tools_call_name = callout_context.tool_calls.as_ref().unwrap()[0]
.function
.name
.clone();
let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
//HACK: for now we only support one tool call, we will support multiple tool calls in the future
let mut tool_params = callout_context.tool_calls.as_ref().unwrap()[0]
let mut tool_params = self.tool_calls.as_ref().unwrap()[0]
.function
.arguments
.clone();
@ -741,8 +802,16 @@ impl StreamContext {
vec![],
Duration::from_secs(5),
);
callout_context.upstream_cluster = Some(endpoint.name.clone());
callout_context.upstream_cluster_path = Some(path.clone());
debug!(
"archgw => api call, endpoint: {}/{}, body: {}",
endpoint.name.as_str(),
path,
tool_params_json_str
);
callout_context.upstream_cluster = Some(endpoint.name.to_owned());
callout_context.upstream_cluster_path = Some(path.to_owned());
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
if let Err(e) = self.http_call(call_args, callout_context) {
@ -750,14 +819,15 @@ impl StreamContext {
}
}
pub fn function_call_response_handler(
&mut self,
body: Vec<u8>,
callout_context: StreamCallContext,
) {
if let Some(http_status) = self.get_http_call_response_header(":status") {
pub fn api_call_response_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
let http_status = self
.get_http_call_response_header(":status")
.expect("http status code not found");
if http_status != StatusCode::OK.as_str() {
debug!("upstream error response: {}", http_status);
warn!(
"api server responded with non 2xx status code: {}",
http_status
);
return self.send_server_error(
ServerError::Upstream {
host: callout_context.upstream_cluster.unwrap(),
@ -768,14 +838,10 @@ impl StreamContext {
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
);
}
} else {
warn!("http status code not found in api response");
}
let app_function_call_response_str: String = String::from_utf8(body).unwrap();
self.tool_call_response = Some(app_function_call_response_str.clone());
self.tool_call_response = Some(String::from_utf8(body).unwrap());
debug!(
"arch <= app response body: {}",
app_function_call_response_str
"archgw <= api call response: {}",
self.tool_call_response.as_ref().unwrap()
);
let prompt_target_name = callout_context.prompt_target_name.unwrap();
let prompt_target = self
@ -825,7 +891,7 @@ impl StreamContext {
let final_prompt = format!(
"{}\ncontext: {}",
user_message.content.unwrap(),
app_function_call_response_str
self.tool_call_response.as_ref().unwrap()
);
// add original user prompt
@ -848,22 +914,24 @@ impl StreamContext {
metadata: None,
};
let json_string = match serde_json::to_string(&chat_completions_request) {
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
Ok(json_string) => json_string,
Err(e) => {
return self.send_server_error(ServerError::Serialization(e), None);
}
};
debug!("arch => upstream llm request body: {}", json_string);
debug!("archgw => llm request: {}", llm_request_str);
self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes());
self.set_http_request_body(0, self.request_body_size, &llm_request_str.into_bytes());
self.resume_http_request();
}
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
debug!("response received for arch guard");
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
debug!(
"archgw <= archguard response: {:?}",
serde_json::to_string(&prompt_guard_resp)
);
if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
//TODO: handle other scenarios like forward to error target
@ -871,7 +939,7 @@ impl StreamContext {
.prompt_guards
.jailbreak_on_exception_message()
.unwrap_or("refrain from discussing jailbreaking.");
debug!("jailbreak detected: {}", msg);
warn!("jailbreak detected: {}", msg);
return self.send_server_error(
ServerError::Jailbreak(String::from(msg)),
Some(StatusCode::BAD_REQUEST),
@ -881,92 +949,27 @@ impl StreamContext {
self.get_embeddings(callout_context);
}
pub fn get_embeddings(&mut self, callout_context: StreamCallContext) {
let user_message = callout_context.user_message.unwrap();
let get_embeddings_input = CreateEmbeddingRequest {
// Need to clone into input because user_message is used below.
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
debug!("error serializing get embeddings request: {}", error);
return self.send_server_error(ServerError::Deserialization(error), None);
}
};
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST),
(":method", "POST"),
(":path", "/embeddings"),
(":authority", EMBEDDINGS_INTERNAL_HOST),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/embeddings",
headers,
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
);
let call_context = StreamCallContext {
response_handler_type: ResponseHandlerType::GetEmbeddings,
user_message: Some(user_message),
prompt_target_name: None,
request_body: callout_context.request_body,
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
tool_calls: None,
};
if let Err(e) = self.http_call(call_args, call_context) {
debug!("error dispatching get embeddings request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
pub fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
let prompt_target = self
.prompt_targets
.get(callout_context.prompt_target_name.as_ref().unwrap())
.unwrap()
.clone();
debug!(
"response received for default target: {}",
prompt_target.name
);
// check if the default target should be dispatched to the LLM provider
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) {
let default_target_response_str = String::from_utf8(body).unwrap();
debug!(
"sending response back to developer: {}",
default_target_response_str
);
self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
Some(default_target_response_str.as_bytes()),
);
// self.resume_http_request();
return;
}
debug!("default_target: sending api response to default llm");
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
Ok(chat_completions_resp) => chat_completions_resp,
Err(e) => {
debug!("error deserializing default target response: {}", e);
warn!("error deserializing default target response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -1000,7 +1003,12 @@ impl StreamContext {
tool_call_id: None,
});
let chat_completion_request = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
model: self
.chat_completions_request
.as_ref()
.unwrap()
.model
.clone(),
messages,
tools: None,
stream: callout_context.request_body.stream,
@ -1008,7 +1016,7 @@ impl StreamContext {
metadata: None,
};
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
debug!("sending response back to default llm: {}", json_resp);
debug!("archgw => (default target) llm request: {}", json_resp);
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
self.resume_http_request();
}

View file

@ -33,7 +33,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
.returning(Some("/v1/chat/completions"))
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
.returning(None)
.execute_and_expect(ReturnType::Action(Action::Continue))
@ -74,7 +74,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -92,6 +92,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
)
.returning(Some(1))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
@ -116,6 +117,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.returning(Some(&prompt_guard_response_buffer))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -133,7 +135,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
)
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -159,8 +160,9 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embeddings_response_buffer))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -178,7 +180,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
)
.returning(Some(3))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -200,9 +201,10 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&zeroshot_intent_detection_buffer))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -219,8 +221,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
None,
)
.returning(Some(4))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -245,7 +245,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
module
.call_proxy_on_tick(filter_context)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -426,8 +426,9 @@ fn successful_request_to_open_ai_chat_completions() {
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(Some("arch_internal"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
@ -486,13 +487,14 @@ fn bad_request_to_open_ai_chat_completions() {
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(incomplete_chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_send_local_response(
Some(StatusCode::BAD_REQUEST.as_u16().into()),
None,
None,
None,
)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}
@ -564,7 +566,7 @@ fn request_to_llm_gateway() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -603,6 +605,8 @@ fn request_to_llm_gateway() {
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -628,10 +632,10 @@ fn request_to_llm_gateway() {
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
.returning(Some("200"))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -664,11 +668,11 @@ fn request_to_llm_gateway() {
)
.expect_get_buffer_bytes(Some(BufferType::HttpResponseBody))
.returning(Some(chat_completion_response_str.as_str()))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}