From ea76d85b436865c5ce60c43ba8b3d1c85a0bd81a Mon Sep 17 00:00:00 2001
From: Adil Hafeez
Date: Tue, 22 Oct 2024 12:07:40 -0700
Subject: [PATCH] Improve logging (#209)

* improve logging
* fix int tests
* better
* fix more logs
* fix more
* fix int
---
 crates/common/src/consts.rs                 |   1 -
 crates/common/src/http.rs                   |   7 +-
 crates/prompt_gateway/src/context.rs        |  21 +-
 crates/prompt_gateway/src/http_context.rs   |  71 +--
 crates/prompt_gateway/src/stream_context.rs | 492 ++++++++++----------
 crates/prompt_gateway/tests/integration.rs  |  36 +-
 6 files changed, 319 insertions(+), 309 deletions(-)

diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs
index fe67f6a8..8f59e981 100644
--- a/crates/common/src/consts.rs
+++ b/crates/common/src/consts.rs
@@ -7,7 +7,6 @@ pub const SYSTEM_ROLE: &str = "system";
 pub const USER_ROLE: &str = "user";
 pub const TOOL_ROLE: &str = "tool";
 pub const ASSISTANT_ROLE: &str = "assistant";
-pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
 pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
 pub const MODEL_SERVER_NAME: &str = "model_server";
 pub const ZEROSHOT_INTERNAL_HOST: &str = "zeroshot";
diff --git a/crates/common/src/http.rs b/crates/common/src/http.rs
index f5d66a65..0e035624 100644
--- a/crates/common/src/http.rs
+++ b/crates/common/src/http.rs
@@ -3,7 +3,7 @@ use crate::{
     stats::{Gauge, IncrementingMetric},
 };
 use derivative::Derivative;
-use log::debug;
+use log::{debug, trace};
 use proxy_wasm::{traits::Context, types::Status};
 use serde::Serialize;
 use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration};
@@ -48,9 +48,10 @@ pub trait Client: Context {
         call_args: CallArgs,
         call_context: Self::CallContext,
     ) -> Result<u32, Status> {
-        debug!(
+        trace!(
             "dispatching http call with args={:?} context={:?}",
-            call_args, call_context
+            call_args,
+            call_context
         );

         match self.dispatch_http_call(
diff --git a/crates/prompt_gateway/src/context.rs b/crates/prompt_gateway/src/context.rs
index f33d6238..2df70374 100644
--- a/crates/prompt_gateway/src/context.rs
+++ b/crates/prompt_gateway/src/context.rs
@@ -74,24 +74,15 @@ impl Context for StreamContext {
      */

         if let Some(body) = self.get_http_call_response_body(0, body_size) {
+            #[cfg_attr(any(), rustfmt::skip)]
             match callout_context.response_handler_type {
-                ResponseHandlerType::GetEmbeddings => {
-                    self.embeddings_handler(body, callout_context)
-                }
                 ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
-                ResponseHandlerType::ZeroShotIntent => {
-                    self.zero_shot_intent_detection_resp_handler(body, callout_context)
-                }
+                ResponseHandlerType::Embeddings => self.embeddings_handler(body, callout_context),
+                ResponseHandlerType::ZeroShotIntent => self.zero_shot_intent_detection_resp_handler(body, callout_context),
                 ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
-                ResponseHandlerType::HallucinationDetect => {
-                    self.hallucination_classification_resp_handler(body, callout_context)
-                }
-                ResponseHandlerType::FunctionCall => {
-                    self.function_call_response_handler(body, callout_context)
-                }
-                ResponseHandlerType::DefaultTarget => {
-                    self.default_target_handler(body, callout_context)
-                }
+                ResponseHandlerType::Hallucination => self.hallucination_classification_resp_handler(body, callout_context),
+                ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
+                ResponseHandlerType::DefaultTarget => self.default_target_handler(body, callout_context),
             }
         } else {
             self.send_server_error(
diff --git 
a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index 04fc976e..da0e69c9 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -16,7 +16,7 @@ use common::{ http::{CallArgs, Client}, }; use http::StatusCode; -use log::{debug, warn}; +use log::{debug, trace, warn}; use proxy_wasm::{traits::HttpContext, types::Action}; use serde_json::Value; @@ -36,7 +36,7 @@ impl HttpContext for StreamContext { self.is_chat_completions_request = self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH; - debug!( + trace!( "on_http_request_headers S[{}] req_headers={:?}", self.context_id, self.get_http_request_headers() @@ -60,32 +60,37 @@ impl HttpContext for StreamContext { self.request_body_size = body_size; - debug!( + trace!( "on_http_request_body S[{}] body_size={}", - self.context_id, body_size + self.context_id, + body_size ); + let body_bytes = match self.get_http_request_body(0, body_size) { + Some(body_bytes) => body_bytes, + None => { + self.send_server_error( + ServerError::LogicError(format!( + "Failed to obtain body bytes even though body_size is {}", + body_size + )), + None, + ); + return Action::Pause; + } + }; + + debug!("developer => archgw: {}", String::from_utf8_lossy(&body_bytes)); + // Deserialize body into spec. // Currently OpenAI API. let mut deserialized_body: ChatCompletionsRequest = - match self.get_http_request_body(0, body_size) { - Some(body_bytes) => match serde_json::from_slice(&body_bytes) { - Ok(deserialized) => deserialized, - Err(e) => { - self.send_server_error( - ServerError::Deserialization(e), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Pause; - } - }, - None => { + match serde_json::from_slice(&body_bytes) { + Ok(deserialized) => deserialized, + Err(e) => { self.send_server_error( - ServerError::LogicError(format!( - "Failed to obtain body bytes even though body_size is {}", - body_size - )), - None, + ServerError::Deserialization(e), + Some(StatusCode::BAD_REQUEST), ); return Action::Pause; } @@ -145,7 +150,6 @@ impl HttpContext for StreamContext { similarity_scores: None, upstream_cluster: None, upstream_cluster_path: None, - tool_calls: None, }; self.get_embeddings(callout_context); return Action::Pause; @@ -201,7 +205,6 @@ impl HttpContext for StreamContext { similarity_scores: None, upstream_cluster: None, upstream_cluster_path: None, - tool_calls: None, }; if let Err(e) = self.http_call(call_args, call_context) { @@ -212,7 +215,7 @@ impl HttpContext for StreamContext { } fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { - debug!( + trace!( "on_http_response_headers recv [S={}] headers={:?}", self.context_id, self.get_http_response_headers() @@ -224,9 +227,11 @@ impl HttpContext for StreamContext { } fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { - debug!( + trace!( "recv [S={}] bytes={} end_stream={}", - self.context_id, body_size, end_of_stream + self.context_id, + body_size, + end_of_stream ); if !self.is_chat_completions_request { @@ -248,14 +253,14 @@ impl HttpContext for StreamContext { .expect("cant get response body"); if self.streaming_response { - debug!("streaming response"); + trace!("streaming response"); } else { - debug!("non streaming response"); + trace!("non streaming response"); let chat_completions_response: ChatCompletionsResponse = match serde_json::from_slice(&body) { Ok(de) => de, Err(e) => { - debug!( + trace!( "invalid 
response: {}, {}",
                            String::from_utf8_lossy(&body),
                            e
@@ -316,16 +321,18 @@ impl HttpContext for StreamContext {
                     serde_json::Value::String(arch_state_str),
                 );
                 let data_serialized = serde_json::to_string(&data).unwrap();
-                debug!("arch => user: {}", data_serialized);
+                debug!("developer <= archgw: {}", data_serialized);
                 self.set_http_response_body(0, body_size, data_serialized.as_bytes());
             };
         }
     }
 }

-        debug!(
+        trace!(
             "recv [S={}] total_tokens={} end_stream={}",
-            self.context_id, self.response_tokens, end_of_stream
+            self.context_id,
+            self.response_tokens,
+            end_of_stream
         );

         Action::Continue
diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs
index 9b99409c..dbc991a6 100644
--- a/crates/prompt_gateway/src/stream_context.rs
+++ b/crates/prompt_gateway/src/stream_context.rs
@@ -15,7 +15,7 @@ use common::consts::{
     ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
     ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER,
     ARCH_UPSTREAM_HOST_HEADER, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
-    DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO,
+    DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST,
     HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
     ZEROSHOT_INTERNAL_HOST,
 };
@@ -27,7 +27,7 @@ use common::http::{CallArgs, Client};
 use common::stats::Gauge;
 use derivative::Derivative;
 use http::StatusCode;
-use log::{debug, info, warn};
+use log::{debug, info, trace, warn};
 use proxy_wasm::traits::*;
 use std::cell::RefCell;
 use std::collections::HashMap;
@@ -37,11 +37,11 @@ use std::time::Duration;

 #[derive(Debug, Clone)]
 pub enum ResponseHandlerType {
-    GetEmbeddings,
+    Embeddings,
     ArchFC,
     FunctionCall,
     ZeroShotIntent,
-    HallucinationDetect,
+    Hallucination,
     ArchGuard,
     DefaultTarget,
 }
@@ -54,7 +54,6 @@ pub struct StreamCallContext {
     pub prompt_target_name: Option<String>,
     #[derivative(Debug = "ignore")]
     pub request_body: ChatCompletionsRequest,
-    pub tool_calls: Option<Vec<ToolCall>>,
     pub similarity_scores: Option<Vec<(String, f64)>>,
     pub upstream_cluster: Option<String>,
     pub upstream_cluster_path: Option<String>,
@@ -129,18 +128,77 @@ impl StreamContext {
         );
     }

+    pub fn get_embeddings(&mut self, callout_context: StreamCallContext) {
+        let user_message = callout_context.user_message.unwrap();
+        let get_embeddings_input = CreateEmbeddingRequest {
+            // Need to clone into input because user_message is used below.
+            input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
+            model: String::from(DEFAULT_EMBEDDING_MODEL),
+            encoding_format: None,
+            dimensions: None,
+            user: None,
+        };
+
+        let embeddings_request_str: String = match serde_json::to_string(&get_embeddings_input) {
+            Ok(json_data) => json_data,
+            Err(error) => {
+                warn!("error serializing get embeddings request: {}", error);
+                return self.send_server_error(ServerError::Deserialization(error), None);
+            }
+        };
+
+        let mut headers = vec![
+            (ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST),
+            (":method", "POST"),
+            (":path", "/embeddings"),
+            (":authority", EMBEDDINGS_INTERNAL_HOST),
+            ("content-type", "application/json"),
+            ("x-envoy-max-retries", "3"),
+            ("x-envoy-upstream-rq-timeout-ms", "60000"),
+        ];
+        if self.request_id.is_some() {
+            headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
+        }
+        let call_args = CallArgs::new(
+            ARCH_INTERNAL_CLUSTER_NAME,
+            "/embeddings",
+            headers,
+            Some(embeddings_request_str.as_bytes()),
+            vec![],
+            Duration::from_secs(5),
+        );
+        let call_context = StreamCallContext {
+            response_handler_type: ResponseHandlerType::Embeddings,
+            user_message: Some(user_message),
+            prompt_target_name: None,
+            request_body: callout_context.request_body,
+            similarity_scores: None,
+            upstream_cluster: None,
+            upstream_cluster_path: None,
+        };
+
+        debug!(
+            "archgw => get embeddings request: {}",
+            embeddings_request_str
+        );
+        if let Err(e) = self.http_call(call_args, call_context) {
+            warn!("error dispatching get embeddings request: {}", e);
+            self.send_server_error(ServerError::HttpDispatch(e), None);
+        }
+    }
+
     pub fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) {
         let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
             Ok(embedding_response) => embedding_response,
             Err(e) => {
-                debug!("error deserializing embedding response: {}", e);
+                warn!("error deserializing embedding response: {}", e);
                 return self.send_server_error(ServerError::Deserialization(e), None);
             }
         };

         let prompt_embeddings_vector = &embedding_response.data[0].embedding;

-        debug!(
+        trace!(
             "embedding model: {}, vector length: {:?}",
             embedding_response.model,
             prompt_embeddings_vector.len()
         );
@@ -237,7 +295,7 @@
         callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;

         if let Err(e) = self.http_call(call_args, callout_context) {
-            debug!("error dispatching zero shot classification request: {}", e);
+            warn!("error dispatching zero shot classification request: {}", e);
             self.send_server_error(ServerError::HttpDispatch(e), None);
         }
     }
@@ -247,11 +305,13 @@
         body: Vec<u8>,
         callout_context: StreamCallContext,
     ) {
+        let body_str = String::from_utf8(body).expect("could not convert body to string");
+        debug!("archgw <= hallucination response: {}", body_str);
         let hallucination_response: HallucinationClassificationResponse =
-            match serde_json::from_slice(&body) {
+            match serde_json::from_str(body_str.as_str()) {
                 Ok(hallucination_response) => hallucination_response,
                 Err(e) => {
-                    debug!("error deserializing hallucination response: {}", e);
+                    warn!("error deserializing hallucination response: {}", e);
                     return self.send_server_error(ServerError::Deserialization(e), None);
                 }
             };
@@ -291,7 +351,7 @@
             metadata: None,
         };

-        debug!("hallucination response: {:?}", chat_completion_response);
+        trace!("hallucination response: {:?}", chat_completion_response);
         self.send_http_response(
            StatusCode::OK.as_u16().into(),
            vec![("Powered-By", "Katanemo")],
@@ -316,7 +376,7 @@
             match serde_json::from_slice(&body) {
                 Ok(zeroshot_response) => zeroshot_response,
                 Err(e) => {
-                    debug!(
+                    warn!(
                         "error deserializing zero shot classification response: {}",
                         e
                     );
@@ -324,7 +384,10 @@
             }
         };

-        debug!("zeroshot intent response: {:?}", zeroshot_intent_response);
+        trace!(
+            "zeroshot intent response: {}",
+            serde_json::to_string(&zeroshot_intent_response).unwrap()
+        );

         let desc_emb_similarity_map: HashMap<String, f64> = callout_context
             .similarity_scores
@@ -362,7 +425,7 @@
                 }
             }
         } else {
-            info!("no assistant message found, probably first interaction");
+            debug!("no assistant message found, probably first interaction");
         }

         // get prompt target similarity threshold from overrides
@@ -382,15 +445,16 @@
         // if arch fc responded to the user message, then we don't need to check the similarity score
         // it may be that arch fc is handling the conversation for parameter collection
         if arch_assistant {
-            info!("arch assistant is handling the conversation");
+            info!("arch fc is engaged in parameter collection");
         } else {
-            debug!("checking for default prompt target");
             if let Some(default_prompt_target) = self
                 .prompt_targets
                 .values()
                 .find(|pt| pt.default.unwrap_or(false))
             {
-                debug!("default prompt target found");
+                debug!(
+                    "default prompt target found, forwarding request to default prompt target"
+                );
                 let endpoint = default_prompt_target.endpoint.clone().unwrap();

                 let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
@@ -401,8 +465,6 @@
                     callout_context.request_body.messages.clone(),
                 );
                 let arch_messages_json = serde_json::to_string(&params).unwrap();
-                debug!("no prompt target found with similarity score above threshold, using default prompt target");
-
                 let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();

                 let mut headers = vec![
@@ -431,7 +493,7 @@
                 callout_context.prompt_target_name = Some(default_prompt_target.name.clone());

                 if let Err(e) = self.http_call(call_args, callout_context) {
-                    debug!("error dispatching default prompt target request: {}", e);
+                    warn!("error dispatching default prompt target request: {}", e);
                     return self.send_server_error(
                         ServerError::HttpDispatch(e),
                         Some(StatusCode::BAD_REQUEST),
@@ -444,20 +506,12 @@
             }
         }

-        let prompt_target = match self.prompt_targets.get(&prompt_target_name) {
-            Some(prompt_target) => prompt_target.clone(),
-            None => {
-                debug!("prompt target not found: {}", prompt_target_name);
-                return self.send_server_error(
-                    ServerError::LogicError(format!(
-                        "Prompt target not found: {prompt_target_name}"
-                    )),
-                    None,
-                );
-            }
-        };
+        let prompt_target = self
+            .prompt_targets
+            .get(&prompt_target_name)
+            .expect("prompt target not found")
+            .clone();

-        info!("prompt_target name: {:?}", prompt_target_name);
         let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
         for pt in self.prompt_targets.values() {
             if pt.default.unwrap_or_default() {
@@ -506,7 +560,12 @@
         );

         let chat_completions = ChatCompletionsRequest {
-            model: GPT_35_TURBO.to_string(),
+            model: self
+                .chat_completions_request
+                .as_ref()
+                .unwrap()
+                .model
+                .clone(),
             messages: callout_context.request_body.messages.clone(),
             tools: Some(chat_completion_tools),
             stream: false,
@@ -515,12 +574,9 @@
         };

         let msg_body = match serde_json::to_string(&chat_completions) {
-            Ok(msg_body) => {
-                debug!("arch_fc request body content: 
{}", msg_body);
-                msg_body
-            }
+            Ok(msg_body) => msg_body,
             Err(e) => {
-                debug!("error serializing arch_fc request body: {}", e);
+                warn!("error serializing arch_fc request body: {}", e);
                 return self.send_server_error(ServerError::Serialization(e), None);
             }
         };
@@ -552,6 +608,7 @@
         callout_context.response_handler_type = ResponseHandlerType::ArchFC;
         callout_context.prompt_target_name = Some(prompt_target.name);

+        debug!("archgw => archfc request: {}", msg_body);
         if let Err(e) = self.http_call(call_args, callout_context) {
             debug!("error dispatching arch_fc request: {}", e);
             self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
         }
     }
@@ -564,21 +621,28 @@
         mut callout_context: StreamCallContext,
     ) {
         let body_str = String::from_utf8(body).unwrap();
-        debug!("arch <= app response body: {}", body_str);
+        debug!("archgw <= archfc response: {}", body_str);
         let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
             Ok(arch_fc_response) => arch_fc_response,
             Err(e) => {
-                debug!("error deserializing arch_fc response: {}", e);
+                warn!("error deserializing archfc response: {}", e);
                 return self.send_server_error(ServerError::Deserialization(e), None);
             }
         };

-        let model_resp = &arch_fc_response.choices[0];
+        arch_fc_response.choices[0]
+            .message
+            .tool_calls
+            .clone_into(&mut self.tool_calls);
+        if self.tool_calls.as_ref().map_or(false, |tc| tc.len() > 1) {
+            warn!(
+                "multiple tool calls not supported yet, tool_calls count found: {}",
+                self.tool_calls.as_ref().unwrap().len()
+            );
+        }

-        if model_resp.message.tool_calls.is_none()
-            || model_resp.message.tool_calls.as_ref().unwrap().is_empty()
-        {
+        if self.tool_calls.is_none() || self.tool_calls.as_ref().unwrap().is_empty() {
             // This means that Arch FC did not have enough information to resolve the function call
             // Arch FC probably responded with a message asking for more information.
            // Let's send the response back to the user to initialize lightweight dialog for parameter collection
@@ -592,121 +656,118 @@
             );
         }

-        let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
-        self.tool_calls = Some(tool_calls.clone());
-        // TODO CO: pass nli check
-        // If hallucination, pass chat template to check parameters
-
-        // extract all tool names
-        let tool_names: Vec<String> = tool_calls
-            .iter()
-            .map(|tool_call| tool_call.function.name.clone())
-            .collect();
+        let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
+        let prompt_target = self
+            .prompt_targets
+            .get(&tools_call_name)
+            .expect("prompt target not found for tool call")
+            .clone();

         debug!(
-            "call context similarity score: {:?}",
-            callout_context.similarity_scores
+            "prompt_target_name: {}, tool_name(s): {:?}",
+            prompt_target.name,
+            self.tool_calls
+                .as_ref()
+                .unwrap()
+                .iter()
+                .map(|tc| tc.function.name.clone())
+                .collect::<Vec<String>>(),
         );
+
+        // If hallucination, pass chat template to check parameters
         //HACK: for now we only support one tool call, we will support multiple tool calls in the future
-        let mut tool_params = tool_calls[0].function.arguments.clone();
+
+        let mut tool_params = self.tool_calls.as_ref().unwrap()[0]
+            .function
+            .arguments
+            .clone();
+        let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
+        debug!(
+            "tool_params (without messages history): {}",
+            tool_params_json_str
+        );
         tool_params.insert(
             String::from(ARCH_MESSAGES_KEY),
             serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
         );
-
-        let tools_call_name = tool_calls[0].function.name.clone();
         let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
-        let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
-        callout_context.tool_calls = Some(tool_calls.clone());
+
+        use serde_json::Value;
+        let v: Value = serde_json::from_str(&tool_params_json_str).unwrap();
+        let tool_params_dict: HashMap<String, String> = match v.as_object() {
+            Some(obj) => obj
+                .iter()
+                .map(|(key, value)| {
+                    // Convert each value to a string, regardless of its type
+                    (key.clone(), value.to_string())
+                })
+                .collect(),
+            None => HashMap::new(), // Return an empty HashMap if v is not an object
+        };
+
+        let all_user_messages =
+            extract_messages_for_hallucination(&callout_context.request_body.messages);
+        let user_messages_str = all_user_messages.join(", ");
+        debug!("user messages: {}", user_messages_str);
+
+        let hallucination_classification_request = HallucinationClassificationRequest {
+            prompt: user_messages_str,
+            model: String::from(DEFAULT_INTENT_MODEL),
+            parameters: tool_params_dict,
+        };
+
+        let hallucination_request_str: String =
+            match serde_json::to_string(&hallucination_classification_request) {
+                Ok(json_data) => json_data,
+                Err(error) => {
+                    warn!(
+                        "error serializing hallucination classification request: {}",
+                        error
+                    );
+                    return self.send_server_error(ServerError::Serialization(error), None);
+                }
+            };
+
+        let mut headers = vec![
+            (ARCH_UPSTREAM_HOST_HEADER, HALLUCINATION_INTERNAL_HOST),
+            (":method", "POST"),
+            (":path", "/hallucination"),
+            (":authority", HALLUCINATION_INTERNAL_HOST),
+            ("content-type", "application/json"),
+            ("x-envoy-max-retries", "3"),
+            ("x-envoy-upstream-rq-timeout-ms", "60000"),
+        ];
+
+        if self.request_id.is_some() {
+            headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
+        }
+
+        let call_args = CallArgs::new(
+            ARCH_INTERNAL_CLUSTER_NAME,
+            "/hallucination",
+            headers,
Some(hallucination_request_str.as_bytes()), + vec![], + Duration::from_secs(5), + ); + callout_context.response_handler_type = ResponseHandlerType::Hallucination; debug!( - "prompt_target_name: {}, tool_name(s): {:?}", - prompt_target.name, tool_names + "archgw => hallucination request: {}", + hallucination_request_str ); - debug!("tool_params: {}", tool_params_json_str); - - if model_resp.message.tool_calls.is_some() - && !model_resp.message.tool_calls.as_ref().unwrap().is_empty() - { - use serde_json::Value; - let v: Value = serde_json::from_str(&tool_params_json_str).unwrap(); - let tool_params_dict: HashMap = match v.as_object() { - Some(obj) => obj - .iter() - .map(|(key, value)| { - // Convert each value to a string, regardless of its type - (key.clone(), value.to_string()) - }) - .collect(), - None => HashMap::new(), // Return an empty HashMap if v is not an object - }; - - let all_user_messages = - extract_messages_for_hallucination(&callout_context.request_body.messages); - let user_messages_str = all_user_messages.join(", "); - debug!("user messages: {}", user_messages_str); - - let hallucination_classification_request = HallucinationClassificationRequest { - prompt: user_messages_str, - model: String::from(DEFAULT_INTENT_MODEL), - parameters: tool_params_dict, - }; - - let json_data: String = - match serde_json::to_string(&hallucination_classification_request) { - Ok(json_data) => json_data, - Err(error) => { - debug!( - "error serializing hallucination classification request: {}", - error - ); - return self.send_server_error(ServerError::Serialization(error), None); - } - }; - - let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, HALLUCINATION_INTERNAL_HOST), - (":method", "POST"), - (":path", "/hallucination"), - (":authority", HALLUCINATION_INTERNAL_HOST), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/hallucination", - headers, - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ); - callout_context.response_handler_type = ResponseHandlerType::HallucinationDetect; - - if let Err(e) = self.http_call(call_args, callout_context) { - self.send_server_error(ServerError::HttpDispatch(e), None); - } - } else { - self.schedule_api_call_request(callout_context); + if let Err(e) = self.http_call(call_args, callout_context) { + self.send_server_error(ServerError::HttpDispatch(e), None); } } fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) { - let tools_call_name = callout_context.tool_calls.as_ref().unwrap()[0] - .function - .name - .clone(); + let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone(); let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone(); - //HACK: for now we only support one tool call, we will support multiple tool calls in the future - let mut tool_params = callout_context.tool_calls.as_ref().unwrap()[0] + let mut tool_params = self.tool_calls.as_ref().unwrap()[0] .function .arguments .clone(); @@ -741,8 +802,16 @@ impl StreamContext { vec![], Duration::from_secs(5), ); - callout_context.upstream_cluster = Some(endpoint.name.clone()); - callout_context.upstream_cluster_path = Some(path.clone()); + + debug!( + "archgw => api call, endpoint: {}/{}, body: {}", + endpoint.name.as_str(), + path, + tool_params_json_str + ); 
+
+        callout_context.upstream_cluster = Some(endpoint.name.to_owned());
+        callout_context.upstream_cluster_path = Some(path.to_owned());
         callout_context.response_handler_type = ResponseHandlerType::FunctionCall;

         if let Err(e) = self.http_call(call_args, callout_context) {
@@ -750,32 +819,29 @@
         }
     }

-    pub fn function_call_response_handler(
-        &mut self,
-        body: Vec<u8>,
-        callout_context: StreamCallContext,
-    ) {
-        if let Some(http_status) = self.get_http_call_response_header(":status") {
-            if http_status != StatusCode::OK.as_str() {
-                debug!("upstream error response: {}", http_status);
-                return self.send_server_error(
-                    ServerError::Upstream {
-                        host: callout_context.upstream_cluster.unwrap(),
-                        path: callout_context.upstream_cluster_path.unwrap(),
-                        status: http_status.clone(),
-                        body: String::from_utf8(body).unwrap(),
-                    },
-                    Some(StatusCode::from_str(http_status.as_str()).unwrap()),
-                );
-            }
-        } else {
-            warn!("http status code not found in api response");
+    pub fn api_call_response_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
+        let http_status = self
+            .get_http_call_response_header(":status")
+            .expect("http status code not found");
+        if http_status != StatusCode::OK.as_str() {
+            warn!(
+                "api server responded with non-200 status code: {}",
+                http_status
+            );
+            return self.send_server_error(
+                ServerError::Upstream {
+                    host: callout_context.upstream_cluster.unwrap(),
+                    path: callout_context.upstream_cluster_path.unwrap(),
+                    status: http_status.clone(),
+                    body: String::from_utf8(body).unwrap(),
+                },
+                Some(StatusCode::from_str(http_status.as_str()).unwrap()),
+            );
         }
-        let app_function_call_response_str: String = String::from_utf8(body).unwrap();
-        self.tool_call_response = Some(app_function_call_response_str.clone());
+        self.tool_call_response = Some(String::from_utf8(body).unwrap());
         debug!(
-            "arch <= app response body: {}",
-            app_function_call_response_str
+            "archgw <= api call response: {}",
+            self.tool_call_response.as_ref().unwrap()
         );
         let prompt_target_name = callout_context.prompt_target_name.unwrap();
         let prompt_target = self
@@ -825,7 +891,7 @@
         let final_prompt = format!(
             "{}\ncontext: {}",
             user_message.content.unwrap(),
-            app_function_call_response_str
+            self.tool_call_response.as_ref().unwrap()
         );

         // add original user prompt
@@ -848,22 +914,24 @@
             metadata: None,
         };

-        let json_string = match serde_json::to_string(&chat_completions_request) {
+        let llm_request_str = match serde_json::to_string(&chat_completions_request) {
             Ok(json_string) => json_string,
             Err(e) => {
                 return self.send_server_error(ServerError::Serialization(e), None);
             }
         };

-        debug!("arch => upstream llm request body: {}", json_string);
+        debug!("archgw => llm request: {}", llm_request_str);

-        self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes());
+        self.set_http_request_body(0, self.request_body_size, &llm_request_str.into_bytes());
         self.resume_http_request();
     }

     pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
-        debug!("response received for arch guard");
         let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
-        debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
+        debug!(
+            "archgw <= archguard response: {}",
+            serde_json::to_string(&prompt_guard_resp).unwrap()
+        );

         if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
             //TODO: handle other scenarios like forward to error target
             let msg = self
                 .prompt_guards
.jailbreak_on_exception_message() .unwrap_or("refrain from discussing jailbreaking."); - debug!("jailbreak detected: {}", msg); + warn!("jailbreak detected: {}", msg); return self.send_server_error( ServerError::Jailbreak(String::from(msg)), Some(StatusCode::BAD_REQUEST), @@ -881,92 +949,27 @@ impl StreamContext { self.get_embeddings(callout_context); } - pub fn get_embeddings(&mut self, callout_context: StreamCallContext) { - let user_message = callout_context.user_message.unwrap(); - let get_embeddings_input = CreateEmbeddingRequest { - // Need to clone into input because user_message is used below. - input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())), - model: String::from(DEFAULT_EMBEDDING_MODEL), - encoding_format: None, - dimensions: None, - user: None, - }; - - let json_data: String = match serde_json::to_string(&get_embeddings_input) { - Ok(json_data) => json_data, - Err(error) => { - debug!("error serializing get embeddings request: {}", error); - return self.send_server_error(ServerError::Deserialization(error), None); - } - }; - - let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST), - (":method", "POST"), - (":path", "/embeddings"), - (":authority", EMBEDDINGS_INTERNAL_HOST), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]; - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/embeddings", - headers, - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ); - let call_context = StreamCallContext { - response_handler_type: ResponseHandlerType::GetEmbeddings, - user_message: Some(user_message), - prompt_target_name: None, - request_body: callout_context.request_body, - similarity_scores: None, - upstream_cluster: None, - upstream_cluster_path: None, - tool_calls: None, - }; - - if let Err(e) = self.http_call(call_args, call_context) { - debug!("error dispatching get embeddings request: {}", e); - self.send_server_error(ServerError::HttpDispatch(e), None); - } - } - pub fn default_target_handler(&self, body: Vec, callout_context: StreamCallContext) { let prompt_target = self .prompt_targets .get(callout_context.prompt_target_name.as_ref().unwrap()) .unwrap() .clone(); - debug!( - "response received for default target: {}", - prompt_target.name - ); + // check if the default target should be dispatched to the LLM provider if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) { let default_target_response_str = String::from_utf8(body).unwrap(); - debug!( - "sending response back to developer: {}", - default_target_response_str - ); self.send_http_response( StatusCode::OK.as_u16().into(), vec![("Powered-By", "Katanemo")], Some(default_target_response_str.as_bytes()), ); - // self.resume_http_request(); return; } - debug!("default_target: sending api response to default llm"); let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) { Ok(chat_completions_resp) => chat_completions_resp, Err(e) => { - debug!("error deserializing default target response: {}", e); + warn!("error deserializing default target response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -1000,7 +1003,12 @@ impl StreamContext { tool_call_id: None, }); let chat_completion_request = ChatCompletionsRequest { - model: GPT_35_TURBO.to_string(), + model: self + 
.chat_completions_request + .as_ref() + .unwrap() + .model + .clone(), messages, tools: None, stream: callout_context.request_body.stream, @@ -1008,7 +1016,7 @@ impl StreamContext { metadata: None, }; let json_resp = serde_json::to_string(&chat_completion_request).unwrap(); - debug!("sending response back to default llm: {}", json_resp); + debug!("archgw => (default target) llm request: {}", json_resp); self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes()); self.resume_http_request(); } diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index 7f7322f8..27eac427 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -33,7 +33,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) { .returning(Some("/v1/chat/completions")) .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) .returning(None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id")) .returning(None) .execute_and_expect(ReturnType::Action(Action::Continue)) @@ -74,7 +74,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) // The actual call is not important in this test, we just need to grab the token_id - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -92,6 +92,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { ) .returning(Some(1)) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_metric_increment("active_http_calls", 1) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -116,6 +117,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .returning(Some(&prompt_guard_response_buffer)) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -133,7 +135,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { ) .returning(Some(2)) .expect_metric_increment("active_http_calls", 1) - .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::None) .unwrap(); @@ -159,8 +160,9 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&embeddings_response_buffer)) + .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -178,7 +180,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { ) .returning(Some(3)) .expect_metric_increment("active_http_calls", 1) - .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::None) .unwrap(); @@ -200,9 +201,10 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&zeroshot_intent_detection_buffer)) + 
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -219,8 +221,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { None, ) .returning(Some(4)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) .expect_metric_increment("active_http_calls", 1) .execute_and_expect(ReturnType::None) .unwrap(); @@ -245,7 +245,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { module .call_proxy_on_tick(filter_context) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -426,8 +426,9 @@ fn successful_request_to_open_ai_chat_completions() { ) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) + .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call(Some("arch_internal"), None, None, None, None) .returning(Some(4)) .expect_metric_increment("active_http_calls", 1) @@ -486,13 +487,14 @@ fn bad_request_to_open_ai_chat_completions() { ) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(incomplete_chat_completions_request_body)) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_send_local_response( Some(StatusCode::BAD_REQUEST.as_u16().into()), None, None, None, ) + .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); } @@ -564,7 +566,7 @@ fn request_to_llm_gateway() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -603,6 +605,8 @@ fn request_to_llm_gateway() { .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&body_text)) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -628,10 +632,10 @@ fn request_to_llm_gateway() { .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&body_text)) + .expect_log(Some(LogLevel::Debug), None) .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status")) .returning(Some("200")) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .execute_and_expect(ReturnType::None) .unwrap(); @@ -664,11 +668,11 @@ fn request_to_llm_gateway() { ) .expect_get_buffer_bytes(Some(BufferType::HttpResponseBody)) .returning(Some(chat_completion_response_str.as_str())) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None) + .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) + 
.expect_log(Some(LogLevel::Trace), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); }
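
Reviewer note, not part of the patch: the level convention this diff converges on is debug! for directional payload logs between archgw and a peer component ("archgw => X" on send, "archgw <= X" on receive), trace! for chatty internals such as headers and byte counts, and warn! for failure paths. A minimal, self-contained sketch of that convention follows; it assumes only the log and env_logger crates, and the peer name, session id, and payloads are illustrative, not taken from this codebase.

    // Sketch of the debug/trace/warn convention used throughout the patch.
    use log::{debug, trace, warn};

    fn log_exchange(peer: &str, request: &str, response: Result<&str, &str>) {
        // debug: directional payload log on the way out to a peer.
        debug!("archgw => {}: {}", peer, request);
        match response {
            // debug: directional payload log on the way back in.
            Ok(body) => debug!("archgw <= {}: {}", peer, body),
            // warn: failure paths log at warn, never at debug.
            Err(e) => warn!("error dispatching {} request: {}", peer, e),
        }
        // trace: high-volume internals such as byte counts and stream state.
        trace!("recv [S=1] bytes={} end_stream={}", request.len(), true);
    }

    fn main() {
        env_logger::init(); // run with RUST_LOG=trace to see all three levels
        log_exchange("archfc", r#"{"messages":[]}"#, Ok(r#"{"choices":[]}"#));
    }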