From 2f374df0345e61cd4917923f65c7fe942c8a2702 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Mon, 21 Oct 2024 15:04:15 -0700 Subject: [PATCH 1/3] refactor prompt gateway (#204) --- crates/common/src/errors.rs | 5 +- crates/common/src/http.rs | 5 +- crates/common/src/lib.rs | 2 +- crates/prompt_gateway/src/context.rs | 103 +++++ crates/prompt_gateway/src/filter_context.rs | 2 +- crates/prompt_gateway/src/hallucination.rs | 12 +- crates/prompt_gateway/src/http_context.rs | 333 ++++++++++++++ crates/prompt_gateway/src/lib.rs | 4 +- crates/prompt_gateway/src/stream_context.rs | 475 ++------------------ 9 files changed, 500 insertions(+), 441 deletions(-) create mode 100644 crates/prompt_gateway/src/context.rs create mode 100644 crates/prompt_gateway/src/http_context.rs diff --git a/crates/common/src/errors.rs b/crates/common/src/errors.rs index fd634915..27b0341e 100644 --- a/crates/common/src/errors.rs +++ b/crates/common/src/errors.rs @@ -22,11 +22,12 @@ pub enum ServerError { Serialization(serde_json::Error), #[error("{0}")] LogicError(String), - #[error("upstream error response authority={authority}, path={path}, status={status}")] + #[error("upstream application error host={host}, path={path}, status={status}, body={body}")] Upstream { - authority: String, + host: String, path: String, status: String, + body: String, }, #[error("jailbreak detected: {0}")] Jailbreak(String), diff --git a/crates/common/src/http.rs b/crates/common/src/http.rs index 842818e2..f5d66a65 100644 --- a/crates/common/src/http.rs +++ b/crates/common/src/http.rs @@ -1,4 +1,7 @@ -use crate::{errors::ClientError, stats::{Gauge, IncrementingMetric}}; +use crate::{ + errors::ClientError, + stats::{Gauge, IncrementingMetric}, +}; use derivative::Derivative; use log::debug; use proxy_wasm::{traits::Context, types::Status}; diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index c23443ca..f2c95bc5 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -4,10 +4,10 @@ pub mod common_types; pub mod configuration; pub mod consts; pub mod embeddings; +pub mod errors; pub mod http; pub mod llm_providers; pub mod ratelimit; pub mod routing; pub mod stats; pub mod tokenizer; -pub mod errors; diff --git a/crates/prompt_gateway/src/context.rs b/crates/prompt_gateway/src/context.rs new file mode 100644 index 00000000..f33d6238 --- /dev/null +++ b/crates/prompt_gateway/src/context.rs @@ -0,0 +1,103 @@ +use common::errors::ServerError; +use common::stats::IncrementingMetric; +use proxy_wasm::traits::Context; + +use crate::stream_context::{ResponseHandlerType, StreamContext}; + +impl Context for StreamContext { + fn on_http_call_response( + &mut self, + token_id: u32, + _num_headers: usize, + body_size: usize, + _num_trailers: usize, + ) { + let callout_context = self + .callouts + .get_mut() + .remove(&token_id) + .expect("invalid token_id"); + self.metrics.active_http_calls.increment(-1); + + /* + state transition + + graph LR + + on_http_request_body --> prompt received + prompt received --> get embeddings & arch guard + arch guard --> get embeddings + get embeddings --> zeroshot intent + + ┌──────────────────────┐ ┌─────────────────┐ ┌────────────────┐ ┌─────────────────┐ + │ │ │ │ │ │ │ │ + │ on_http_request_body ├──►│ prompt received ├──►│ get embeddings ├──►│ zeroshot intent │ + │ │ │ │ │ │ │ │ + └──────────────────────┘ └────────┬────────┘ └────────────────┘ └─────────────────┘ + │ ▲ + │ │ + │ │ + │ ┌────────┴───────┐ + │ │ │ + └───────────►│ arch guard │ + │ │ + └────────────────┘ + + + continue from zeroshot intent + + graph LR + + zeroshot intent --> arch_fc + zeroshot intent --> default prompt target + arch_fc --> developer api call & hallucination check + hallucination check --> parameter gathering & developer api call + developer api call --> resume request to llm + + + ┌─────────────────┐ ┌───────────────────────┐ ┌─────────────────────┐ ┌───────────────────────┐ + │ │ │ │ │ │ │ │ + │ zeroshot intent ├──►│ arch_fc ├──►│ developer api call ├──►│ resume request to llm │ + │ │ │ │ │ │ │ │ + └────────┬────────┘ └───────────┬───────────┘ └─────────────────────┘ └───────────────────────┘ + │ │ ▲ + │ └─────────────┐ │ + │ │ │ + │ ┌───────────────────────┐ │ ┌──────────┴──────────┐ ┌───────────────────────┐ + │ │ │ │ │ │ │ │ + └───────────►│ default prompt target │ └▲│ hallucination check ├──►│ parameter gathering │ + │ │ │ │ │ │ + └───────────────────────┘ └─────────────────────┘ └───────────────────────┘ + + + using https://mermaid-ascii.art/ + */ + + if let Some(body) = self.get_http_call_response_body(0, body_size) { + match callout_context.response_handler_type { + ResponseHandlerType::GetEmbeddings => { + self.embeddings_handler(body, callout_context) + } + ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context), + ResponseHandlerType::ZeroShotIntent => { + self.zero_shot_intent_detection_resp_handler(body, callout_context) + } + ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context), + ResponseHandlerType::HallucinationDetect => { + self.hallucination_classification_resp_handler(body, callout_context) + } + ResponseHandlerType::FunctionCall => { + self.function_call_response_handler(body, callout_context) + } + ResponseHandlerType::DefaultTarget => { + self.default_target_handler(body, callout_context) + } + } + } else { + self.send_server_error( + ServerError::LogicError(String::from("No response body in inline HTTP request")), + None, + ); + } + } +} diff --git a/crates/prompt_gateway/src/filter_context.rs b/crates/prompt_gateway/src/filter_context.rs index b60191a5..3f1d3f0d 100644 --- a/crates/prompt_gateway/src/filter_context.rs +++ b/crates/prompt_gateway/src/filter_context.rs @@ -1,9 +1,9 @@ use crate::stream_context::StreamContext; use common::common_types::EmbeddingType; -use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST}; use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget}; use common::consts::ARCH_UPSTREAM_HOST_HEADER; use common::consts::DEFAULT_EMBEDDING_MODEL; +use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST}; use common::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; diff --git a/crates/prompt_gateway/src/hallucination.rs b/crates/prompt_gateway/src/hallucination.rs index 99d487ab..71d1e7cf 100644 --- a/crates/prompt_gateway/src/hallucination.rs +++ b/crates/prompt_gateway/src/hallucination.rs @@ -1,12 +1,12 @@ use common::{common_types::open_ai::Message, consts::USER_ROLE}; -pub fn extract_messages_for_hallucination(messages: &Vec) -> Vec { +pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec { let all_user_messages = messages .iter() .filter(|m| m.role == USER_ROLE) .map(|m| m.content.as_ref().unwrap().clone()) .collect::>(); - return all_user_messages; + all_user_messages } #[cfg(test)] @@ -17,7 +17,7 @@ mod test { #[test] fn test_hallucination_message() { - let test_str = r#" + let test_str = r#" [ { "role": "system", @@ -32,8 +32,8 @@ mod test { ] "#; - let messages: Vec = serde_json::from_str(test_str).unwrap(); - let messages_for_halluncination = extract_messages_for_hallucination(&messages); - assert_eq!(messages_for_halluncination.len(), 2); + let messages: Vec = serde_json::from_str(test_str).unwrap(); + let messages_for_halluncination = extract_messages_for_hallucination(&messages); + assert_eq!(messages_for_halluncination.len(), 2); } } diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs new file mode 100644 index 00000000..04fc976e --- /dev/null +++ b/crates/prompt_gateway/src/http_context.rs @@ -0,0 +1,333 @@ +use std::{collections::HashMap, time::Duration}; + +use common::{ + common_types::{ + open_ai::{ + ArchState, ChatCompletionsRequest, ChatCompletionsResponse, Message, StreamOptions, + }, + PromptGuardRequest, PromptGuardTask, + }, + consts::{ + ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER, + ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, GUARD_INTERNAL_HOST, + REQUEST_ID_HEADER, TOOL_ROLE, USER_ROLE, + }, + errors::ServerError, + http::{CallArgs, Client}, +}; +use http::StatusCode; +use log::{debug, warn}; +use proxy_wasm::{traits::HttpContext, types::Action}; +use serde_json::Value; + +use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext}; + +// HttpContext is the trait that allows the Rust code to interact with HTTP objects. +impl HttpContext for StreamContext { + // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto + // the lifecycle of the http request and response. + fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { + // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it. + // Server's generally throw away requests whose body length do not match the Content-Length header. + // However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could + // manipulate the body in benign ways e.g., compression. + self.set_http_request_header("content-length", None); + + self.is_chat_completions_request = + self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH; + + debug!( + "on_http_request_headers S[{}] req_headers={:?}", + self.context_id, + self.get_http_request_headers() + ); + + self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); + + Action::Continue + } + + fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { + // Let the client send the gateway all the data before sending to the LLM_provider. + // TODO: consider a streaming API. + if !end_of_stream { + return Action::Pause; + } + + if body_size == 0 { + return Action::Continue; + } + + self.request_body_size = body_size; + + debug!( + "on_http_request_body S[{}] body_size={}", + self.context_id, body_size + ); + + // Deserialize body into spec. + // Currently OpenAI API. + let mut deserialized_body: ChatCompletionsRequest = + match self.get_http_request_body(0, body_size) { + Some(body_bytes) => match serde_json::from_slice(&body_bytes) { + Ok(deserialized) => deserialized, + Err(e) => { + self.send_server_error( + ServerError::Deserialization(e), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + }, + None => { + self.send_server_error( + ServerError::LogicError(format!( + "Failed to obtain body bytes even though body_size is {}", + body_size + )), + None, + ); + return Action::Pause; + } + }; + + self.arch_state = match deserialized_body.metadata { + Some(ref metadata) => { + if metadata.contains_key(ARCH_STATE_HEADER) { + let arch_state_str = metadata[ARCH_STATE_HEADER].clone(); + let arch_state: Vec = serde_json::from_str(&arch_state_str).unwrap(); + Some(arch_state) + } else { + None + } + } + None => None, + }; + + self.streaming_response = deserialized_body.stream; + if deserialized_body.stream && deserialized_body.stream_options.is_none() { + deserialized_body.stream_options = Some(StreamOptions { + include_usage: true, + }); + } + + let last_user_prompt = match deserialized_body + .messages + .iter() + .filter(|msg| msg.role == USER_ROLE) + .last() + { + Some(content) => content, + None => { + warn!("No messages in the request body"); + return Action::Continue; + } + }; + + self.user_prompt = Some(last_user_prompt.clone()); + + let user_message_str = self.user_prompt.as_ref().unwrap().content.clone(); + + let prompt_guard_jailbreak_task = self + .prompt_guards + .input_guards + .contains_key(&common::configuration::GuardType::Jailbreak); + + self.chat_completions_request = Some(deserialized_body); + + if !prompt_guard_jailbreak_task { + debug!("Missing input guard. Making inline call to retrieve embeddings"); + let callout_context = StreamCallContext { + response_handler_type: ResponseHandlerType::ArchGuard, + user_message: user_message_str.clone(), + prompt_target_name: None, + request_body: self.chat_completions_request.as_ref().unwrap().clone(), + similarity_scores: None, + upstream_cluster: None, + upstream_cluster_path: None, + tool_calls: None, + }; + self.get_embeddings(callout_context); + return Action::Pause; + } + + let get_prompt_guards_request = PromptGuardRequest { + input: self + .user_prompt + .as_ref() + .unwrap() + .content + .as_ref() + .unwrap() + .clone(), + task: PromptGuardTask::Jailbreak, + }; + + let json_data: String = match serde_json::to_string(&get_prompt_guards_request) { + Ok(json_data) => json_data, + Err(error) => { + self.send_server_error(ServerError::Serialization(error), None); + return Action::Pause; + } + }; + + let mut headers = vec![ + (ARCH_UPSTREAM_HOST_HEADER, GUARD_INTERNAL_HOST), + (":method", "POST"), + (":path", "/guard"), + (":authority", GUARD_INTERNAL_HOST), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", "60000"), + ]; + + if self.request_id.is_some() { + headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); + } + + let call_args = CallArgs::new( + ARCH_INTERNAL_CLUSTER_NAME, + "/guard", + headers, + Some(json_data.as_bytes()), + vec![], + Duration::from_secs(5), + ); + let call_context = StreamCallContext { + response_handler_type: ResponseHandlerType::ArchGuard, + user_message: self.user_prompt.as_ref().unwrap().content.clone(), + prompt_target_name: None, + request_body: self.chat_completions_request.as_ref().unwrap().clone(), + similarity_scores: None, + upstream_cluster: None, + upstream_cluster_path: None, + tool_calls: None, + }; + + if let Err(e) = self.http_call(call_args, call_context) { + self.send_server_error(ServerError::HttpDispatch(e), None); + } + + Action::Pause + } + + fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { + debug!( + "on_http_response_headers recv [S={}] headers={:?}", + self.context_id, + self.get_http_response_headers() + ); + // delete content-lenght header let envoy calculate it, because we modify the response body + // that would result in a different content-length + self.set_http_response_header("content-length", None); + Action::Continue + } + + fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { + debug!( + "recv [S={}] bytes={} end_stream={}", + self.context_id, body_size, end_of_stream + ); + + if !self.is_chat_completions_request { + if let Some(body_str) = self + .get_http_response_body(0, body_size) + .and_then(|bytes| String::from_utf8(bytes).ok()) + { + debug!("recv [S={}] body_str={}", self.context_id, body_str); + } + return Action::Continue; + } + + if !end_of_stream { + return Action::Pause; + } + + let body = self + .get_http_response_body(0, body_size) + .expect("cant get response body"); + + if self.streaming_response { + debug!("streaming response"); + } else { + debug!("non streaming response"); + let chat_completions_response: ChatCompletionsResponse = + match serde_json::from_slice(&body) { + Ok(de) => de, + Err(e) => { + debug!( + "invalid response: {}, {}", + String::from_utf8_lossy(&body), + e + ); + return Action::Continue; + } + }; + + if chat_completions_response.usage.is_some() { + self.response_tokens += chat_completions_response + .usage + .as_ref() + .unwrap() + .completion_tokens; + } + + if let Some(tool_calls) = self.tool_calls.as_ref() { + if !tool_calls.is_empty() { + if self.arch_state.is_none() { + self.arch_state = Some(Vec::new()); + } + + let mut data = serde_json::from_slice(&body).unwrap(); + // use serde::Value to manipulate the json object and ensure that we don't lose any data + if let Value::Object(ref mut map) = data { + // serialize arch state and add to metadata + let metadata = map + .entry("metadata") + .or_insert(Value::Object(serde_json::Map::new())); + if metadata == &Value::Null { + *metadata = Value::Object(serde_json::Map::new()); + } + + // since arch gateway generates tool calls (using arch-fc) and calls upstream api to + // get response, we will send these back to developer so they can see the api response + // and tool call arch-fc generated + let fc_messages = vec![ + Message { + role: ASSISTANT_ROLE.to_string(), + content: None, + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: self.tool_calls.clone(), + tool_call_id: None, + }, + Message { + role: TOOL_ROLE.to_string(), + content: self.tool_call_response.clone(), + model: None, + tool_calls: None, + tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()), + }, + ]; + let fc_messages_str = serde_json::to_string(&fc_messages).unwrap(); + let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]); + let arch_state_str = serde_json::to_string(&arch_state).unwrap(); + metadata.as_object_mut().unwrap().insert( + ARCH_STATE_HEADER.to_string(), + serde_json::Value::String(arch_state_str), + ); + let data_serialized = serde_json::to_string(&data).unwrap(); + debug!("arch => user: {}", data_serialized); + self.set_http_response_body(0, body_size, data_serialized.as_bytes()); + }; + } + } + } + + debug!( + "recv [S={}] total_tokens={} end_stream={}", + self.context_id, self.response_tokens, end_of_stream + ); + + Action::Continue + } +} diff --git a/crates/prompt_gateway/src/lib.rs b/crates/prompt_gateway/src/lib.rs index 7ca26e44..f873b9bf 100644 --- a/crates/prompt_gateway/src/lib.rs +++ b/crates/prompt_gateway/src/lib.rs @@ -2,9 +2,11 @@ use filter_context::FilterContext; use proxy_wasm::traits::*; use proxy_wasm::types::*; +mod context; mod filter_context; -mod stream_context; mod hallucination; +mod http_context; +mod stream_context; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 0364cd87..9b99409c 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -3,47 +3,42 @@ use crate::hallucination::extract_messages_for_hallucination; use acap::cos; use common::common_types::open_ai::{ ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice, - FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, - StreamOptions, ToolCall, ToolType, + FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, ToolCall, + ToolType, }; use common::common_types::{ EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse, - PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest, - ZeroShotClassificationResponse, + PromptGuardResponse, ZeroShotClassificationRequest, ZeroShotClassificationResponse, }; use common::configuration::{Overrides, PromptGuards, PromptTarget}; use common::consts::{ ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, - ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, - DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, - EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, GUARD_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, - REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST, + ARCH_UPSTREAM_HOST_HEADER, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, + DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, + HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, + ZEROSHOT_INTERNAL_HOST, }; use common::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; -use common::errors::ClientError; +use common::errors::ServerError; use common::http::{CallArgs, Client}; use common::stats::Gauge; use derivative::Derivative; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::traits::*; -use proxy_wasm::types::*; -use serde_json::Value; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; use std::str::FromStr; use std::time::Duration; -use common::stats::IncrementingMetric; - #[derive(Debug, Clone)] -enum ResponseHandlerType { +pub enum ResponseHandlerType { GetEmbeddings, - FunctionResolver, + ArchFC, FunctionCall, ZeroShotIntent, HallucinationDetect, @@ -54,59 +49,36 @@ enum ResponseHandlerType { #[derive(Clone, Derivative)] #[derivative(Debug)] pub struct StreamCallContext { - response_handler_type: ResponseHandlerType, - user_message: Option, - prompt_target_name: Option, + pub response_handler_type: ResponseHandlerType, + pub user_message: Option, + pub prompt_target_name: Option, #[derivative(Debug = "ignore")] - request_body: ChatCompletionsRequest, - tool_calls: Option>, - similarity_scores: Option>, - upstream_cluster: Option, - upstream_cluster_path: Option, -} - -#[derive(thiserror::Error, Debug)] -pub enum ServerError { - #[error(transparent)] - HttpDispatch(ClientError), - #[error(transparent)] - Deserialization(serde_json::Error), - #[error(transparent)] - Serialization(serde_json::Error), - #[error("{0}")] - LogicError(String), - #[error("upstream application error host={host}, path={path}, status={status}, body={body}")] - Upstream { - host: String, - path: String, - status: String, - body: String, - }, - #[error("jailbreak detected: {0}")] - Jailbreak(String), - #[error("{why}")] - NoMessagesFound { why: String }, + pub request_body: ChatCompletionsRequest, + pub tool_calls: Option>, + pub similarity_scores: Option>, + pub upstream_cluster: Option, + pub upstream_cluster_path: Option, } pub struct StreamContext { - context_id: u32, - metrics: Rc, system_prompt: Rc>, prompt_targets: Rc>, embeddings_store: Option>, overrides: Rc>, - callouts: RefCell>, - tool_calls: Option>, - tool_call_response: Option, - arch_state: Option>, - request_body_size: usize, - streaming_response: bool, - user_prompt: Option, - response_tokens: usize, - is_chat_completions_request: bool, - chat_completions_request: Option, - prompt_guards: Rc, - request_id: Option, + pub metrics: Rc, + pub callouts: RefCell>, + pub context_id: u32, + pub tool_calls: Option>, + pub tool_call_response: Option, + pub arch_state: Option>, + pub request_body_size: usize, + pub streaming_response: bool, + pub user_prompt: Option, + pub response_tokens: usize, + pub is_chat_completions_request: bool, + pub chat_completions_request: Option, + pub prompt_guards: Rc, + pub request_id: Option, } impl StreamContext { @@ -146,15 +118,7 @@ impl StreamContext { .expect("embeddings store is not set") } - fn delete_content_length_header(&mut self) { - // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it. - // Server's generally throw away requests whose body length do not match the Content-Length header. - // However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could - // manipulate the body in benign ways e.g., compression. - self.set_http_request_header("content-length", None); - } - - fn send_server_error(&self, error: ServerError, override_status_code: Option) { + pub fn send_server_error(&self, error: ServerError, override_status_code: Option) { self.send_http_response( override_status_code .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) @@ -165,7 +129,7 @@ impl StreamContext { ); } - fn embeddings_handler(&mut self, body: Vec, mut callout_context: StreamCallContext) { + pub fn embeddings_handler(&mut self, body: Vec, mut callout_context: StreamCallContext) { let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) { Ok(embedding_response) => embedding_response, Err(e) => { @@ -278,7 +242,7 @@ impl StreamContext { } } - fn hallucination_classification_resp_handler( + pub fn hallucination_classification_resp_handler( &mut self, body: Vec, callout_context: StreamCallContext, @@ -343,7 +307,7 @@ impl StreamContext { } } - fn zero_shot_intent_detection_resp_handler( + pub fn zero_shot_intent_detection_resp_handler( &mut self, body: Vec, mut callout_context: StreamCallContext, @@ -585,7 +549,7 @@ impl StreamContext { vec![], Duration::from_secs(5), ); - callout_context.response_handler_type = ResponseHandlerType::FunctionResolver; + callout_context.response_handler_type = ResponseHandlerType::ArchFC; callout_context.prompt_target_name = Some(prompt_target.name); if let Err(e) = self.http_call(call_args, callout_context) { @@ -594,7 +558,11 @@ impl StreamContext { } } - fn function_resolver_handler(&mut self, body: Vec, mut callout_context: StreamCallContext) { + pub fn arch_fc_response_handler( + &mut self, + body: Vec, + mut callout_context: StreamCallContext, + ) { let body_str = String::from_utf8(body).unwrap(); debug!("arch <= app response body: {}", body_str); @@ -782,7 +750,7 @@ impl StreamContext { } } - fn function_call_response_handler( + pub fn function_call_response_handler( &mut self, body: Vec, callout_context: StreamCallContext, @@ -892,7 +860,7 @@ impl StreamContext { self.resume_http_request(); } - fn arch_guard_handler(&mut self, body: Vec, callout_context: StreamCallContext) { + pub fn arch_guard_handler(&mut self, body: Vec, callout_context: StreamCallContext) { debug!("response received for arch guard"); let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap(); debug!("prompt_guard_resp: {:?}", prompt_guard_resp); @@ -913,7 +881,7 @@ impl StreamContext { self.get_embeddings(callout_context); } - fn get_embeddings(&mut self, callout_context: StreamCallContext) { + pub fn get_embeddings(&mut self, callout_context: StreamCallContext) { let user_message = callout_context.user_message.unwrap(); let get_embeddings_input = CreateEmbeddingRequest { // Need to clone into input because user_message is used below. @@ -969,7 +937,7 @@ impl StreamContext { } } - fn default_target_handler(&self, body: Vec, callout_context: StreamCallContext) { + pub fn default_target_handler(&self, body: Vec, callout_context: StreamCallContext) { let prompt_target = self .prompt_targets .get(callout_context.prompt_target_name.as_ref().unwrap()) @@ -1046,357 +1014,6 @@ impl StreamContext { } } -// HttpContext is the trait that allows the Rust code to interact with HTTP objects. -impl HttpContext for StreamContext { - // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto - // the lifecycle of the http request and response. - fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { - self.delete_content_length_header(); - - self.is_chat_completions_request = - self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH; - - debug!( - "on_http_request_headers S[{}] req_headers={:?}", - self.context_id, - self.get_http_request_headers() - ); - - self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); - - Action::Continue - } - - fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { - // Let the client send the gateway all the data before sending to the LLM_provider. - // TODO: consider a streaming API. - if !end_of_stream { - return Action::Pause; - } - - if body_size == 0 { - return Action::Continue; - } - - self.request_body_size = body_size; - - debug!( - "on_http_request_body S[{}] body_size={}", - self.context_id, body_size - ); - - // Deserialize body into spec. - // Currently OpenAI API. - let mut deserialized_body: ChatCompletionsRequest = - match self.get_http_request_body(0, body_size) { - Some(body_bytes) => match serde_json::from_slice(&body_bytes) { - Ok(deserialized) => deserialized, - Err(e) => { - self.send_server_error( - ServerError::Deserialization(e), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Pause; - } - }, - None => { - self.send_server_error( - ServerError::LogicError(format!( - "Failed to obtain body bytes even though body_size is {}", - body_size - )), - None, - ); - return Action::Pause; - } - }; - - self.arch_state = match deserialized_body.metadata { - Some(ref metadata) => { - if metadata.contains_key(ARCH_STATE_HEADER) { - let arch_state_str = metadata[ARCH_STATE_HEADER].clone(); - let arch_state: Vec = serde_json::from_str(&arch_state_str).unwrap(); - Some(arch_state) - } else { - None - } - } - None => None, - }; - - self.streaming_response = deserialized_body.stream; - if deserialized_body.stream && deserialized_body.stream_options.is_none() { - deserialized_body.stream_options = Some(StreamOptions { - include_usage: true, - }); - } - - let last_user_prompt = match deserialized_body - .messages - .iter() - .filter(|msg| msg.role == USER_ROLE) - .last() - { - Some(content) => content, - None => { - warn!("No messages in the request body"); - return Action::Continue; - } - }; - - self.user_prompt = Some(last_user_prompt.clone()); - - let user_message_str = self.user_prompt.as_ref().unwrap().content.clone(); - - let prompt_guard_jailbreak_task = self - .prompt_guards - .input_guards - .contains_key(&common::configuration::GuardType::Jailbreak); - - self.chat_completions_request = Some(deserialized_body); - - if !prompt_guard_jailbreak_task { - debug!("Missing input guard. Making inline call to retrieve embeddings"); - let callout_context = StreamCallContext { - response_handler_type: ResponseHandlerType::ArchGuard, - user_message: user_message_str.clone(), - prompt_target_name: None, - request_body: self.chat_completions_request.as_ref().unwrap().clone(), - similarity_scores: None, - upstream_cluster: None, - upstream_cluster_path: None, - tool_calls: None, - }; - self.get_embeddings(callout_context); - return Action::Pause; - } - - let get_prompt_guards_request = PromptGuardRequest { - input: self - .user_prompt - .as_ref() - .unwrap() - .content - .as_ref() - .unwrap() - .clone(), - task: PromptGuardTask::Jailbreak, - }; - - let json_data: String = match serde_json::to_string(&get_prompt_guards_request) { - Ok(json_data) => json_data, - Err(error) => { - self.send_server_error(ServerError::Serialization(error), None); - return Action::Pause; - } - }; - - let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, GUARD_INTERNAL_HOST), - (":method", "POST"), - (":path", "/guard"), - (":authority", GUARD_INTERNAL_HOST), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/guard", - headers, - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ); - let call_context = StreamCallContext { - response_handler_type: ResponseHandlerType::ArchGuard, - user_message: self.user_prompt.as_ref().unwrap().content.clone(), - prompt_target_name: None, - request_body: self.chat_completions_request.as_ref().unwrap().clone(), - similarity_scores: None, - upstream_cluster: None, - upstream_cluster_path: None, - tool_calls: None, - }; - - if let Err(e) = self.http_call(call_args, call_context) { - self.send_server_error(ServerError::HttpDispatch(e), None); - } - - Action::Pause - } - - fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { - debug!( - "on_http_response_headers recv [S={}] headers={:?}", - self.context_id, - self.get_http_response_headers() - ); - // delete content-lenght header let envoy calculate it, because we modify the response body - // that would result in a different content-length - self.set_http_response_header("content-length", None); - Action::Continue - } - - fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { - debug!( - "recv [S={}] bytes={} end_stream={}", - self.context_id, body_size, end_of_stream - ); - - if !self.is_chat_completions_request { - if let Some(body_str) = self - .get_http_response_body(0, body_size) - .and_then(|bytes| String::from_utf8(bytes).ok()) - { - debug!("recv [S={}] body_str={}", self.context_id, body_str); - } - return Action::Continue; - } - - if !end_of_stream { - return Action::Pause; - } - - let body = self - .get_http_response_body(0, body_size) - .expect("cant get response body"); - - if self.streaming_response { - debug!("streaming response"); - } else { - debug!("non streaming response"); - let chat_completions_response: ChatCompletionsResponse = - match serde_json::from_slice(&body) { - Ok(de) => de, - Err(e) => { - debug!( - "invalid response: {}, {}", - String::from_utf8_lossy(&body), - e - ); - return Action::Continue; - } - }; - - if chat_completions_response.usage.is_some() { - self.response_tokens += chat_completions_response - .usage - .as_ref() - .unwrap() - .completion_tokens; - } - - if let Some(tool_calls) = self.tool_calls.as_ref() { - if !tool_calls.is_empty() { - if self.arch_state.is_none() { - self.arch_state = Some(Vec::new()); - } - - let mut data: Value = serde_json::from_slice(&body).unwrap(); - // use serde::Value to manipulate the json object and ensure that we don't lose any data - if let Value::Object(ref mut map) = data { - // serialize arch state and add to metadata - let metadata = map - .entry("metadata") - .or_insert(Value::Object(serde_json::Map::new())); - if metadata == &Value::Null { - *metadata = Value::Object(serde_json::Map::new()); - } - - // since arch gateway generates tool calls (using arch-fc) and calls upstream api to - // get response, we will send these back to developer so they can see the api response - // and tool call arch-fc generated - let mut fc_messages = Vec::new(); - fc_messages.push(Message { - role: ASSISTANT_ROLE.to_string(), - content: None, - model: Some(ARCH_FC_MODEL_NAME.to_string()), - tool_calls: self.tool_calls.clone(), - tool_call_id: None, - }); - fc_messages.push(Message { - role: TOOL_ROLE.to_string(), - content: self.tool_call_response.clone(), - model: None, - tool_calls: None, - tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()), - }); - let fc_messages_str = serde_json::to_string(&fc_messages).unwrap(); - let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]); - let arch_state_str = serde_json::to_string(&arch_state).unwrap(); - metadata.as_object_mut().unwrap().insert( - ARCH_STATE_HEADER.to_string(), - serde_json::Value::String(arch_state_str), - ); - let data_serialized = serde_json::to_string(&data).unwrap(); - debug!("arch => user: {}", data_serialized); - self.set_http_response_body(0, body_size, data_serialized.as_bytes()); - }; - } - } - } - - debug!( - "recv [S={}] total_tokens={} end_stream={}", - self.context_id, self.response_tokens, end_of_stream - ); - - Action::Continue - } -} - -impl Context for StreamContext { - fn on_http_call_response( - &mut self, - token_id: u32, - _num_headers: usize, - body_size: usize, - _num_trailers: usize, - ) { - let callout_context = self - .callouts - .get_mut() - .remove(&token_id) - .expect("invalid token_id"); - self.metrics.active_http_calls.increment(-1); - - if let Some(body) = self.get_http_call_response_body(0, body_size) { - match callout_context.response_handler_type { - ResponseHandlerType::GetEmbeddings => { - self.embeddings_handler(body, callout_context) - } - ResponseHandlerType::ZeroShotIntent => { - self.zero_shot_intent_detection_resp_handler(body, callout_context) - } - ResponseHandlerType::HallucinationDetect => { - self.hallucination_classification_resp_handler(body, callout_context) - } - ResponseHandlerType::FunctionResolver => { - self.function_resolver_handler(body, callout_context) - } - ResponseHandlerType::FunctionCall => { - self.function_call_response_handler(body, callout_context) - } - ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context), - ResponseHandlerType::DefaultTarget => { - self.default_target_handler(body, callout_context) - } - } - } else { - self.send_server_error( - ServerError::LogicError(String::from("No response body in inline HTTP request")), - None, - ); - } - } -} - impl Client for StreamContext { type CallContext = StreamCallContext; From ea76d85b436865c5ce60c43ba8b3d1c85a0bd81a Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 22 Oct 2024 12:07:40 -0700 Subject: [PATCH 2/3] Improve logging (#209) * improve logging * fix int tests * better * fix more logs * fix more * fix int --- crates/common/src/consts.rs | 1 - crates/common/src/http.rs | 7 +- crates/prompt_gateway/src/context.rs | 21 +- crates/prompt_gateway/src/http_context.rs | 71 +-- crates/prompt_gateway/src/stream_context.rs | 492 ++++++++++---------- crates/prompt_gateway/tests/integration.rs | 36 +- 6 files changed, 319 insertions(+), 309 deletions(-) diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index fe67f6a8..8f59e981 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -7,7 +7,6 @@ pub const SYSTEM_ROLE: &str = "system"; pub const USER_ROLE: &str = "user"; pub const TOOL_ROLE: &str = "tool"; pub const ASSISTANT_ROLE: &str = "assistant"; -pub const GPT_35_TURBO: &str = "gpt-3.5-turbo"; pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes pub const MODEL_SERVER_NAME: &str = "model_server"; pub const ZEROSHOT_INTERNAL_HOST: &str = "zeroshot"; diff --git a/crates/common/src/http.rs b/crates/common/src/http.rs index f5d66a65..0e035624 100644 --- a/crates/common/src/http.rs +++ b/crates/common/src/http.rs @@ -3,7 +3,7 @@ use crate::{ stats::{Gauge, IncrementingMetric}, }; use derivative::Derivative; -use log::debug; +use log::{debug, trace}; use proxy_wasm::{traits::Context, types::Status}; use serde::Serialize; use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration}; @@ -48,9 +48,10 @@ pub trait Client: Context { call_args: CallArgs, call_context: Self::CallContext, ) -> Result { - debug!( + trace!( "dispatching http call with args={:?} context={:?}", - call_args, call_context + call_args, + call_context ); match self.dispatch_http_call( diff --git a/crates/prompt_gateway/src/context.rs b/crates/prompt_gateway/src/context.rs index f33d6238..2df70374 100644 --- a/crates/prompt_gateway/src/context.rs +++ b/crates/prompt_gateway/src/context.rs @@ -74,24 +74,15 @@ impl Context for StreamContext { */ if let Some(body) = self.get_http_call_response_body(0, body_size) { + #[cfg_attr(any(), rustfmt::skip)] match callout_context.response_handler_type { - ResponseHandlerType::GetEmbeddings => { - self.embeddings_handler(body, callout_context) - } ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context), - ResponseHandlerType::ZeroShotIntent => { - self.zero_shot_intent_detection_resp_handler(body, callout_context) - } + ResponseHandlerType::Embeddings => self.embeddings_handler(body, callout_context), + ResponseHandlerType::ZeroShotIntent => self.zero_shot_intent_detection_resp_handler(body, callout_context), ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context), - ResponseHandlerType::HallucinationDetect => { - self.hallucination_classification_resp_handler(body, callout_context) - } - ResponseHandlerType::FunctionCall => { - self.function_call_response_handler(body, callout_context) - } - ResponseHandlerType::DefaultTarget => { - self.default_target_handler(body, callout_context) - } + ResponseHandlerType::Hallucination => self.hallucination_classification_resp_handler(body, callout_context), + ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context), + ResponseHandlerType::DefaultTarget =>self.default_target_handler(body, callout_context), } } else { self.send_server_error( diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index 04fc976e..da0e69c9 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -16,7 +16,7 @@ use common::{ http::{CallArgs, Client}, }; use http::StatusCode; -use log::{debug, warn}; +use log::{debug, trace, warn}; use proxy_wasm::{traits::HttpContext, types::Action}; use serde_json::Value; @@ -36,7 +36,7 @@ impl HttpContext for StreamContext { self.is_chat_completions_request = self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH; - debug!( + trace!( "on_http_request_headers S[{}] req_headers={:?}", self.context_id, self.get_http_request_headers() @@ -60,32 +60,37 @@ impl HttpContext for StreamContext { self.request_body_size = body_size; - debug!( + trace!( "on_http_request_body S[{}] body_size={}", - self.context_id, body_size + self.context_id, + body_size ); + let body_bytes = match self.get_http_request_body(0, body_size) { + Some(body_bytes) => body_bytes, + None => { + self.send_server_error( + ServerError::LogicError(format!( + "Failed to obtain body bytes even though body_size is {}", + body_size + )), + None, + ); + return Action::Pause; + } + }; + + debug!("developer => archgw: {}", String::from_utf8_lossy(&body_bytes)); + // Deserialize body into spec. // Currently OpenAI API. let mut deserialized_body: ChatCompletionsRequest = - match self.get_http_request_body(0, body_size) { - Some(body_bytes) => match serde_json::from_slice(&body_bytes) { - Ok(deserialized) => deserialized, - Err(e) => { - self.send_server_error( - ServerError::Deserialization(e), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Pause; - } - }, - None => { + match serde_json::from_slice(&body_bytes) { + Ok(deserialized) => deserialized, + Err(e) => { self.send_server_error( - ServerError::LogicError(format!( - "Failed to obtain body bytes even though body_size is {}", - body_size - )), - None, + ServerError::Deserialization(e), + Some(StatusCode::BAD_REQUEST), ); return Action::Pause; } @@ -145,7 +150,6 @@ impl HttpContext for StreamContext { similarity_scores: None, upstream_cluster: None, upstream_cluster_path: None, - tool_calls: None, }; self.get_embeddings(callout_context); return Action::Pause; @@ -201,7 +205,6 @@ impl HttpContext for StreamContext { similarity_scores: None, upstream_cluster: None, upstream_cluster_path: None, - tool_calls: None, }; if let Err(e) = self.http_call(call_args, call_context) { @@ -212,7 +215,7 @@ impl HttpContext for StreamContext { } fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { - debug!( + trace!( "on_http_response_headers recv [S={}] headers={:?}", self.context_id, self.get_http_response_headers() @@ -224,9 +227,11 @@ impl HttpContext for StreamContext { } fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { - debug!( + trace!( "recv [S={}] bytes={} end_stream={}", - self.context_id, body_size, end_of_stream + self.context_id, + body_size, + end_of_stream ); if !self.is_chat_completions_request { @@ -248,14 +253,14 @@ impl HttpContext for StreamContext { .expect("cant get response body"); if self.streaming_response { - debug!("streaming response"); + trace!("streaming response"); } else { - debug!("non streaming response"); + trace!("non streaming response"); let chat_completions_response: ChatCompletionsResponse = match serde_json::from_slice(&body) { Ok(de) => de, Err(e) => { - debug!( + trace!( "invalid response: {}, {}", String::from_utf8_lossy(&body), e @@ -316,16 +321,18 @@ impl HttpContext for StreamContext { serde_json::Value::String(arch_state_str), ); let data_serialized = serde_json::to_string(&data).unwrap(); - debug!("arch => user: {}", data_serialized); + debug!("archgw <= developer: {}", data_serialized); self.set_http_response_body(0, body_size, data_serialized.as_bytes()); }; } } } - debug!( + trace!( "recv [S={}] total_tokens={} end_stream={}", - self.context_id, self.response_tokens, end_of_stream + self.context_id, + self.response_tokens, + end_of_stream ); Action::Continue diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 9b99409c..dbc991a6 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -15,7 +15,7 @@ use common::consts::{ ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, - DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, + DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST, }; @@ -27,7 +27,7 @@ use common::http::{CallArgs, Client}; use common::stats::Gauge; use derivative::Derivative; use http::StatusCode; -use log::{debug, info, warn}; +use log::{debug, info, trace, warn}; use proxy_wasm::traits::*; use std::cell::RefCell; use std::collections::HashMap; @@ -37,11 +37,11 @@ use std::time::Duration; #[derive(Debug, Clone)] pub enum ResponseHandlerType { - GetEmbeddings, + Embeddings, ArchFC, FunctionCall, ZeroShotIntent, - HallucinationDetect, + Hallucination, ArchGuard, DefaultTarget, } @@ -54,7 +54,6 @@ pub struct StreamCallContext { pub prompt_target_name: Option, #[derivative(Debug = "ignore")] pub request_body: ChatCompletionsRequest, - pub tool_calls: Option>, pub similarity_scores: Option>, pub upstream_cluster: Option, pub upstream_cluster_path: Option, @@ -129,18 +128,77 @@ impl StreamContext { ); } + pub fn get_embeddings(&mut self, callout_context: StreamCallContext) { + let user_message = callout_context.user_message.unwrap(); + let get_embeddings_input = CreateEmbeddingRequest { + // Need to clone into input because user_message is used below. + input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())), + model: String::from(DEFAULT_EMBEDDING_MODEL), + encoding_format: None, + dimensions: None, + user: None, + }; + + let embeddings_request_str: String = match serde_json::to_string(&get_embeddings_input) { + Ok(json_data) => json_data, + Err(error) => { + warn!("error serializing get embeddings request: {}", error); + return self.send_server_error(ServerError::Deserialization(error), None); + } + }; + + let mut headers = vec![ + (ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST), + (":method", "POST"), + (":path", "/embeddings"), + (":authority", EMBEDDINGS_INTERNAL_HOST), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", "60000"), + ]; + if self.request_id.is_some() { + headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); + } + let call_args = CallArgs::new( + ARCH_INTERNAL_CLUSTER_NAME, + "/embeddings", + headers, + Some(embeddings_request_str.as_bytes()), + vec![], + Duration::from_secs(5), + ); + let call_context = StreamCallContext { + response_handler_type: ResponseHandlerType::Embeddings, + user_message: Some(user_message), + prompt_target_name: None, + request_body: callout_context.request_body, + similarity_scores: None, + upstream_cluster: None, + upstream_cluster_path: None, + }; + + debug!( + "archgw => get embeddings request: {}", + embeddings_request_str + ); + if let Err(e) = self.http_call(call_args, call_context) { + warn!("error dispatching get embeddings request: {}", e); + self.send_server_error(ServerError::HttpDispatch(e), None); + } + } + pub fn embeddings_handler(&mut self, body: Vec, mut callout_context: StreamCallContext) { let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) { Ok(embedding_response) => embedding_response, Err(e) => { - debug!("error deserializing embedding response: {}", e); + warn!("error deserializing embedding response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; let prompt_embeddings_vector = &embedding_response.data[0].embedding; - debug!( + trace!( "embedding model: {}, vector length: {:?}", embedding_response.model, prompt_embeddings_vector.len() @@ -237,7 +295,7 @@ impl StreamContext { callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent; if let Err(e) = self.http_call(call_args, callout_context) { - debug!("error dispatching zero shot classification request: {}", e); + warn!("error dispatching zero shot classification request: {}", e); self.send_server_error(ServerError::HttpDispatch(e), None); } } @@ -247,11 +305,13 @@ impl StreamContext { body: Vec, callout_context: StreamCallContext, ) { + let boyd_str = String::from_utf8(body).expect("could not convert body to string"); + debug!("archgw <= hallucination response: {}", boyd_str); let hallucination_response: HallucinationClassificationResponse = - match serde_json::from_slice(&body) { + match serde_json::from_str(boyd_str.as_str()) { Ok(hallucination_response) => hallucination_response, Err(e) => { - debug!("error deserializing hallucination response: {}", e); + warn!("error deserializing hallucination response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -291,7 +351,7 @@ impl StreamContext { metadata: None, }; - debug!("hallucination response: {:?}", chat_completion_response); + trace!("hallucination response: {:?}", chat_completion_response); self.send_http_response( StatusCode::OK.as_u16().into(), vec![("Powered-By", "Katanemo")], @@ -316,7 +376,7 @@ impl StreamContext { match serde_json::from_slice(&body) { Ok(zeroshot_response) => zeroshot_response, Err(e) => { - debug!( + warn!( "error deserializing zero shot classification response: {}", e ); @@ -324,7 +384,10 @@ impl StreamContext { } }; - debug!("zeroshot intent response: {:?}", zeroshot_intent_response); + trace!( + "zeroshot intent response: {}", + serde_json::to_string(&zeroshot_intent_response).unwrap() + ); let desc_emb_similarity_map: HashMap = callout_context .similarity_scores @@ -362,7 +425,7 @@ impl StreamContext { } } } else { - info!("no assistant message found, probably first interaction"); + debug!("no assistant message found, probably first interaction"); } // get prompt target similarity thresold from overrides @@ -382,15 +445,16 @@ impl StreamContext { // if arch fc responded to the user message, then we don't need to check the similarity score // it may be that arch fc is handling the conversation for parameter collection if arch_assistant { - info!("arch assistant is handling the conversation"); + info!("arch fc is engaged in parameter collection"); } else { - debug!("checking for default prompt target"); if let Some(default_prompt_target) = self .prompt_targets .values() .find(|pt| pt.default.unwrap_or(false)) { - debug!("default prompt target found"); + debug!( + "default prompt target found, forwarding request to default prompt target" + ); let endpoint = default_prompt_target.endpoint.clone().unwrap(); let upstream_path: String = endpoint.path.unwrap_or(String::from("/")); @@ -401,8 +465,6 @@ impl StreamContext { callout_context.request_body.messages.clone(), ); let arch_messages_json = serde_json::to_string(¶ms).unwrap(); - debug!("no prompt target found with similarity score above threshold, using default prompt target"); - let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string(); let mut headers = vec![ @@ -431,7 +493,7 @@ impl StreamContext { callout_context.prompt_target_name = Some(default_prompt_target.name.clone()); if let Err(e) = self.http_call(call_args, callout_context) { - debug!("error dispatching default prompt target request: {}", e); + warn!("error dispatching default prompt target request: {}", e); return self.send_server_error( ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST), @@ -444,20 +506,12 @@ impl StreamContext { } } - let prompt_target = match self.prompt_targets.get(&prompt_target_name) { - Some(prompt_target) => prompt_target.clone(), - None => { - debug!("prompt target not found: {}", prompt_target_name); - return self.send_server_error( - ServerError::LogicError(format!( - "Prompt target not found: {prompt_target_name}" - )), - None, - ); - } - }; + let prompt_target = self + .prompt_targets + .get(&prompt_target_name) + .expect("prompt target not found") + .clone(); - info!("prompt_target name: {:?}", prompt_target_name); let mut chat_completion_tools: Vec = Vec::new(); for pt in self.prompt_targets.values() { if pt.default.unwrap_or_default() { @@ -506,7 +560,12 @@ impl StreamContext { ); let chat_completions = ChatCompletionsRequest { - model: GPT_35_TURBO.to_string(), + model: self + .chat_completions_request + .as_ref() + .unwrap() + .model + .clone(), messages: callout_context.request_body.messages.clone(), tools: Some(chat_completion_tools), stream: false, @@ -515,12 +574,9 @@ impl StreamContext { }; let msg_body = match serde_json::to_string(&chat_completions) { - Ok(msg_body) => { - debug!("arch_fc request body content: {}", msg_body); - msg_body - } + Ok(msg_body) => msg_body, Err(e) => { - debug!("error serializing arch_fc request body: {}", e); + warn!("error serializing arch_fc request body: {}", e); return self.send_server_error(ServerError::Serialization(e), None); } }; @@ -552,6 +608,7 @@ impl StreamContext { callout_context.response_handler_type = ResponseHandlerType::ArchFC; callout_context.prompt_target_name = Some(prompt_target.name); + debug!("archgw => archfc request: {}", msg_body); if let Err(e) = self.http_call(call_args, callout_context) { debug!("error dispatching arch_fc request: {}", e); self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST)); @@ -564,21 +621,28 @@ impl StreamContext { mut callout_context: StreamCallContext, ) { let body_str = String::from_utf8(body).unwrap(); - debug!("arch <= app response body: {}", body_str); + debug!("archgw <= archfc response: {}", body_str); let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) { Ok(arch_fc_response) => arch_fc_response, Err(e) => { - debug!("error deserializing arch_fc response: {}", e); + warn!("error deserializing archfc response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; - let model_resp = &arch_fc_response.choices[0]; + arch_fc_response.choices[0] + .message + .tool_calls + .clone_into(&mut self.tool_calls); + if self.tool_calls.as_ref().unwrap().len() > 1 { + warn!( + "multiple tool calls not supported yet, tool_calls count found: {}", + self.tool_calls.as_ref().unwrap().len() + ); + } - if model_resp.message.tool_calls.is_none() - || model_resp.message.tool_calls.as_ref().unwrap().is_empty() - { + if self.tool_calls.is_none() || self.tool_calls.as_ref().unwrap().is_empty() { // This means that Arch FC did not have enough information to resolve the function call // Arch FC probably responded with a message asking for more information. // Let's send the response back to the user to initalize lightweight dialog for parameter collection @@ -592,121 +656,118 @@ impl StreamContext { ); } - let tool_calls = model_resp.message.tool_calls.as_ref().unwrap(); - self.tool_calls = Some(tool_calls.clone()); - // TODO CO: pass nli check - // If hallucination, pass chat template to check parameters - - // extract all tool names - let tool_names: Vec = tool_calls - .iter() - .map(|tool_call| tool_call.function.name.clone()) - .collect(); + let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone(); + let prompt_target = self + .prompt_targets + .get(&tools_call_name) + .expect("prompt target not found for tool call") + .clone(); debug!( - "call context similarity score: {:?}", - callout_context.similarity_scores + "prompt_target_name: {}, tool_name(s): {:?}", + prompt_target.name, + self.tool_calls + .as_ref() + .unwrap() + .iter() + .map(|tc| tc.function.name.clone()) + .collect::>(), ); + + // If hallucination, pass chat template to check parameters //HACK: for now we only support one tool call, we will support multiple tool calls in the future - let mut tool_params = tool_calls[0].function.arguments.clone(); + + let mut tool_params = self.tool_calls.as_ref().unwrap()[0] + .function + .arguments + .clone(); + let tool_params_json_str = serde_json::to_string(&tool_params).unwrap(); + debug!( + "tool_params (without messages history): {}", + tool_params_json_str + ); tool_params.insert( String::from(ARCH_MESSAGES_KEY), serde_yaml::to_value(&callout_context.request_body.messages).unwrap(), ); - - let tools_call_name = tool_calls[0].function.name.clone(); let tool_params_json_str = serde_json::to_string(&tool_params).unwrap(); - let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone(); - callout_context.tool_calls = Some(tool_calls.clone()); + + use serde_json::Value; + let v: Value = serde_json::from_str(&tool_params_json_str).unwrap(); + let tool_params_dict: HashMap = match v.as_object() { + Some(obj) => obj + .iter() + .map(|(key, value)| { + // Convert each value to a string, regardless of its type + (key.clone(), value.to_string()) + }) + .collect(), + None => HashMap::new(), // Return an empty HashMap if v is not an object + }; + + let all_user_messages = + extract_messages_for_hallucination(&callout_context.request_body.messages); + let user_messages_str = all_user_messages.join(", "); + debug!("user messages: {}", user_messages_str); + + let hallucination_classification_request = HallucinationClassificationRequest { + prompt: user_messages_str, + model: String::from(DEFAULT_INTENT_MODEL), + parameters: tool_params_dict, + }; + + let hallucination_request_str: String = + match serde_json::to_string(&hallucination_classification_request) { + Ok(json_data) => json_data, + Err(error) => { + debug!( + "error serializing hallucination classification request: {}", + error + ); + return self.send_server_error(ServerError::Serialization(error), None); + } + }; + + let mut headers = vec![ + (ARCH_UPSTREAM_HOST_HEADER, HALLUCINATION_INTERNAL_HOST), + (":method", "POST"), + (":path", "/hallucination"), + (":authority", HALLUCINATION_INTERNAL_HOST), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", "60000"), + ]; + + if self.request_id.is_some() { + headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); + } + + let call_args = CallArgs::new( + ARCH_INTERNAL_CLUSTER_NAME, + "/hallucination", + headers, + Some(hallucination_request_str.as_bytes()), + vec![], + Duration::from_secs(5), + ); + callout_context.response_handler_type = ResponseHandlerType::Hallucination; debug!( - "prompt_target_name: {}, tool_name(s): {:?}", - prompt_target.name, tool_names + "archgw => hallucination request: {}", + hallucination_request_str ); - debug!("tool_params: {}", tool_params_json_str); - - if model_resp.message.tool_calls.is_some() - && !model_resp.message.tool_calls.as_ref().unwrap().is_empty() - { - use serde_json::Value; - let v: Value = serde_json::from_str(&tool_params_json_str).unwrap(); - let tool_params_dict: HashMap = match v.as_object() { - Some(obj) => obj - .iter() - .map(|(key, value)| { - // Convert each value to a string, regardless of its type - (key.clone(), value.to_string()) - }) - .collect(), - None => HashMap::new(), // Return an empty HashMap if v is not an object - }; - - let all_user_messages = - extract_messages_for_hallucination(&callout_context.request_body.messages); - let user_messages_str = all_user_messages.join(", "); - debug!("user messages: {}", user_messages_str); - - let hallucination_classification_request = HallucinationClassificationRequest { - prompt: user_messages_str, - model: String::from(DEFAULT_INTENT_MODEL), - parameters: tool_params_dict, - }; - - let json_data: String = - match serde_json::to_string(&hallucination_classification_request) { - Ok(json_data) => json_data, - Err(error) => { - debug!( - "error serializing hallucination classification request: {}", - error - ); - return self.send_server_error(ServerError::Serialization(error), None); - } - }; - - let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, HALLUCINATION_INTERNAL_HOST), - (":method", "POST"), - (":path", "/hallucination"), - (":authority", HALLUCINATION_INTERNAL_HOST), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/hallucination", - headers, - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ); - callout_context.response_handler_type = ResponseHandlerType::HallucinationDetect; - - if let Err(e) = self.http_call(call_args, callout_context) { - self.send_server_error(ServerError::HttpDispatch(e), None); - } - } else { - self.schedule_api_call_request(callout_context); + if let Err(e) = self.http_call(call_args, callout_context) { + self.send_server_error(ServerError::HttpDispatch(e), None); } } fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) { - let tools_call_name = callout_context.tool_calls.as_ref().unwrap()[0] - .function - .name - .clone(); + let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone(); let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone(); - //HACK: for now we only support one tool call, we will support multiple tool calls in the future - let mut tool_params = callout_context.tool_calls.as_ref().unwrap()[0] + let mut tool_params = self.tool_calls.as_ref().unwrap()[0] .function .arguments .clone(); @@ -741,8 +802,16 @@ impl StreamContext { vec![], Duration::from_secs(5), ); - callout_context.upstream_cluster = Some(endpoint.name.clone()); - callout_context.upstream_cluster_path = Some(path.clone()); + + debug!( + "archgw => api call, endpoint: {}/{}, body: {}", + endpoint.name.as_str(), + path, + tool_params_json_str + ); + + callout_context.upstream_cluster = Some(endpoint.name.to_owned()); + callout_context.upstream_cluster_path = Some(path.to_owned()); callout_context.response_handler_type = ResponseHandlerType::FunctionCall; if let Err(e) = self.http_call(call_args, callout_context) { @@ -750,32 +819,29 @@ impl StreamContext { } } - pub fn function_call_response_handler( - &mut self, - body: Vec, - callout_context: StreamCallContext, - ) { - if let Some(http_status) = self.get_http_call_response_header(":status") { - if http_status != StatusCode::OK.as_str() { - debug!("upstream error response: {}", http_status); - return self.send_server_error( - ServerError::Upstream { - host: callout_context.upstream_cluster.unwrap(), - path: callout_context.upstream_cluster_path.unwrap(), - status: http_status.clone(), - body: String::from_utf8(body).unwrap(), - }, - Some(StatusCode::from_str(http_status.as_str()).unwrap()), - ); - } - } else { - warn!("http status code not found in api response"); + pub fn api_call_response_handler(&mut self, body: Vec, callout_context: StreamCallContext) { + let http_status = self + .get_http_call_response_header(":status") + .expect("http status code not found"); + if http_status != StatusCode::OK.as_str() { + warn!( + "api server responded with non 2xx status code: {}", + http_status + ); + return self.send_server_error( + ServerError::Upstream { + host: callout_context.upstream_cluster.unwrap(), + path: callout_context.upstream_cluster_path.unwrap(), + status: http_status.clone(), + body: String::from_utf8(body).unwrap(), + }, + Some(StatusCode::from_str(http_status.as_str()).unwrap()), + ); } - let app_function_call_response_str: String = String::from_utf8(body).unwrap(); - self.tool_call_response = Some(app_function_call_response_str.clone()); + self.tool_call_response = Some(String::from_utf8(body).unwrap()); debug!( - "arch <= app response body: {}", - app_function_call_response_str + "archgw <= api call response: {}", + self.tool_call_response.as_ref().unwrap() ); let prompt_target_name = callout_context.prompt_target_name.unwrap(); let prompt_target = self @@ -825,7 +891,7 @@ impl StreamContext { let final_prompt = format!( "{}\ncontext: {}", user_message.content.unwrap(), - app_function_call_response_str + self.tool_call_response.as_ref().unwrap() ); // add original user prompt @@ -848,22 +914,24 @@ impl StreamContext { metadata: None, }; - let json_string = match serde_json::to_string(&chat_completions_request) { + let llm_request_str = match serde_json::to_string(&chat_completions_request) { Ok(json_string) => json_string, Err(e) => { return self.send_server_error(ServerError::Serialization(e), None); } }; - debug!("arch => upstream llm request body: {}", json_string); + debug!("archgw => llm request: {}", llm_request_str); - self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes()); + self.set_http_request_body(0, self.request_body_size, &llm_request_str.into_bytes()); self.resume_http_request(); } pub fn arch_guard_handler(&mut self, body: Vec, callout_context: StreamCallContext) { - debug!("response received for arch guard"); let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap(); - debug!("prompt_guard_resp: {:?}", prompt_guard_resp); + debug!( + "archgw <= archguard response: {:?}", + serde_json::to_string(&prompt_guard_resp) + ); if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() { //TODO: handle other scenarios like forward to error target @@ -871,7 +939,7 @@ impl StreamContext { .prompt_guards .jailbreak_on_exception_message() .unwrap_or("refrain from discussing jailbreaking."); - debug!("jailbreak detected: {}", msg); + warn!("jailbreak detected: {}", msg); return self.send_server_error( ServerError::Jailbreak(String::from(msg)), Some(StatusCode::BAD_REQUEST), @@ -881,92 +949,27 @@ impl StreamContext { self.get_embeddings(callout_context); } - pub fn get_embeddings(&mut self, callout_context: StreamCallContext) { - let user_message = callout_context.user_message.unwrap(); - let get_embeddings_input = CreateEmbeddingRequest { - // Need to clone into input because user_message is used below. - input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())), - model: String::from(DEFAULT_EMBEDDING_MODEL), - encoding_format: None, - dimensions: None, - user: None, - }; - - let json_data: String = match serde_json::to_string(&get_embeddings_input) { - Ok(json_data) => json_data, - Err(error) => { - debug!("error serializing get embeddings request: {}", error); - return self.send_server_error(ServerError::Deserialization(error), None); - } - }; - - let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST), - (":method", "POST"), - (":path", "/embeddings"), - (":authority", EMBEDDINGS_INTERNAL_HOST), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]; - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/embeddings", - headers, - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ); - let call_context = StreamCallContext { - response_handler_type: ResponseHandlerType::GetEmbeddings, - user_message: Some(user_message), - prompt_target_name: None, - request_body: callout_context.request_body, - similarity_scores: None, - upstream_cluster: None, - upstream_cluster_path: None, - tool_calls: None, - }; - - if let Err(e) = self.http_call(call_args, call_context) { - debug!("error dispatching get embeddings request: {}", e); - self.send_server_error(ServerError::HttpDispatch(e), None); - } - } - pub fn default_target_handler(&self, body: Vec, callout_context: StreamCallContext) { let prompt_target = self .prompt_targets .get(callout_context.prompt_target_name.as_ref().unwrap()) .unwrap() .clone(); - debug!( - "response received for default target: {}", - prompt_target.name - ); + // check if the default target should be dispatched to the LLM provider if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) { let default_target_response_str = String::from_utf8(body).unwrap(); - debug!( - "sending response back to developer: {}", - default_target_response_str - ); self.send_http_response( StatusCode::OK.as_u16().into(), vec![("Powered-By", "Katanemo")], Some(default_target_response_str.as_bytes()), ); - // self.resume_http_request(); return; } - debug!("default_target: sending api response to default llm"); let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) { Ok(chat_completions_resp) => chat_completions_resp, Err(e) => { - debug!("error deserializing default target response: {}", e); + warn!("error deserializing default target response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -1000,7 +1003,12 @@ impl StreamContext { tool_call_id: None, }); let chat_completion_request = ChatCompletionsRequest { - model: GPT_35_TURBO.to_string(), + model: self + .chat_completions_request + .as_ref() + .unwrap() + .model + .clone(), messages, tools: None, stream: callout_context.request_body.stream, @@ -1008,7 +1016,7 @@ impl StreamContext { metadata: None, }; let json_resp = serde_json::to_string(&chat_completion_request).unwrap(); - debug!("sending response back to default llm: {}", json_resp); + debug!("archgw => (default target) llm request: {}", json_resp); self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes()); self.resume_http_request(); } diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index 7f7322f8..27eac427 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -33,7 +33,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) { .returning(Some("/v1/chat/completions")) .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) .returning(None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id")) .returning(None) .execute_and_expect(ReturnType::Action(Action::Continue)) @@ -74,7 +74,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) // The actual call is not important in this test, we just need to grab the token_id - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -92,6 +92,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { ) .returning(Some(1)) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_metric_increment("active_http_calls", 1) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); @@ -116,6 +117,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .returning(Some(&prompt_guard_response_buffer)) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -133,7 +135,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { ) .returning(Some(2)) .expect_metric_increment("active_http_calls", 1) - .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::None) .unwrap(); @@ -159,8 +160,9 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&embeddings_response_buffer)) + .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -178,7 +180,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { ) .returning(Some(3)) .expect_metric_increment("active_http_calls", 1) - .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::None) .unwrap(); @@ -200,9 +201,10 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&zeroshot_intent_detection_buffer)) + .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Info), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -219,8 +221,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { None, ) .returning(Some(4)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) .expect_metric_increment("active_http_calls", 1) .execute_and_expect(ReturnType::None) .unwrap(); @@ -245,7 +245,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { module .call_proxy_on_tick(filter_context) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -426,8 +426,9 @@ fn successful_request_to_open_ai_chat_completions() { ) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) + .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call(Some("arch_internal"), None, None, None, None) .returning(Some(4)) .expect_metric_increment("active_http_calls", 1) @@ -486,13 +487,14 @@ fn bad_request_to_open_ai_chat_completions() { ) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(incomplete_chat_completions_request_body)) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_send_local_response( Some(StatusCode::BAD_REQUEST.as_u16().into()), None, None, None, ) + .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); } @@ -564,7 +566,7 @@ fn request_to_llm_gateway() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -603,6 +605,8 @@ fn request_to_llm_gateway() { .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&body_text)) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -628,10 +632,10 @@ fn request_to_llm_gateway() { .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&body_text)) + .expect_log(Some(LogLevel::Debug), None) .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status")) .returning(Some("200")) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .execute_and_expect(ReturnType::None) .unwrap(); @@ -664,11 +668,11 @@ fn request_to_llm_gateway() { ) .expect_get_buffer_bytes(Some(BufferType::HttpResponseBody)) .returning(Some(chat_completion_response_str.as_str())) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None) + .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); } From 8495f89fda6ecee9adc70c481091e1d88069d991 Mon Sep 17 00:00:00 2001 From: CTran Date: Tue, 22 Oct 2024 12:52:01 -0700 Subject: [PATCH 3/3] Cotran/hallucination (#208) --- crates/Cargo.lock | 1 + crates/prompt_gateway/Cargo.toml | 1 + crates/prompt_gateway/src/hallucination.rs | 143 +++++++++++++++++++-- 3 files changed, 136 insertions(+), 9 deletions(-) diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 4a76be14..a9c06a49 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -1120,6 +1120,7 @@ dependencies = [ "http", "log", "md5", + "pretty_assertions", "proxy-wasm", "proxy-wasm-test-framework", "rand", diff --git a/crates/prompt_gateway/Cargo.toml b/crates/prompt_gateway/Cargo.toml index 29d385b7..e8a166f8 100644 --- a/crates/prompt_gateway/Cargo.toml +++ b/crates/prompt_gateway/Cargo.toml @@ -26,3 +26,4 @@ sha2 = "0.10.8" [dev-dependencies] proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" } serial_test = "3.1.1" +pretty_assertions = "1.4.1" diff --git a/crates/prompt_gateway/src/hallucination.rs b/crates/prompt_gateway/src/hallucination.rs index 71d1e7cf..62b119ac 100644 --- a/crates/prompt_gateway/src/hallucination.rs +++ b/crates/prompt_gateway/src/hallucination.rs @@ -1,31 +1,63 @@ -use common::{common_types::open_ai::Message, consts::USER_ROLE}; +use common::{ + common_types::open_ai::Message, + consts::{ARCH_MODEL_PREFIX, ASSISTANT_ROLE, USER_ROLE}, +}; -pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec { - let all_user_messages = messages - .iter() - .filter(|m| m.role == USER_ROLE) - .map(|m| m.content.as_ref().unwrap().clone()) - .collect::>(); - all_user_messages +pub fn extract_messages_for_hallucination(messages: &Vec) -> Vec { + let mut arch_assistant = false; + let mut user_messages = Vec::new(); + if messages.len() >= 2 { + let latest_assistant_message = &messages[messages.len() - 2]; + if let Some(model) = latest_assistant_message.model.as_ref() { + if model.starts_with(ARCH_MODEL_PREFIX) { + arch_assistant = true; + } + } + } + if arch_assistant { + for message in messages.iter().rev() { + if let Some(model) = message.model.as_ref() { + if !model.starts_with(ARCH_MODEL_PREFIX) { + if message.role == ASSISTANT_ROLE { + break; + } + } + } + if message.role == USER_ROLE { + if let Some(content) = &message.content { + user_messages.push(content.clone()); + } + } + } + } else if let Some(message) = messages.last() { + if let Some(content) = &message.content { + user_messages.push(content.clone()); + } + } + user_messages.reverse(); // Reverse to maintain the original order + return user_messages; } #[cfg(test)] mod test { + use pretty_assertions::assert_eq; use common::common_types::open_ai::Message; use super::extract_messages_for_hallucination; #[test] - fn test_hallucination_message() { + fn test_hallucination_message_simple() { let test_str = r#" [ { "role": "system", + "model" : "gpt-3.5-turbo", "content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n" }, { "role": "user", "content": "tell me about headcount data" }, { "role": "assistant", + "model": "Arch-Function-1.5B", "content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data." }, { "role": "user", "content": "europe and for fte" } @@ -36,4 +68,97 @@ mod test { let messages_for_halluncination = extract_messages_for_hallucination(&messages); assert_eq!(messages_for_halluncination.len(), 2); } + #[test] + fn test_hallucination_message_medium() { + let test_str = r#" + [ + { + "role": "system", + "model" : "gpt-3.5-turbo", + "content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n" + }, + { "role": "user", "content": "Hello" }, + { + "role": "assistant", + "model": "gpt-3.5-turbo", + "content": "Hi there!" + }, + { "role": "user", "content": "tell me about headcount data" }, + { + "role": "assistant", + "model": "Arch-Function-1.5B", + "content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data." + }, + { "role": "user", "content": "europe" } + , + { + "role": "system", + "model": "Arch-Function-1.5B", + "content": "It seems like you are asking for headcount data for Europe. Could you please specify the staffing type?" + }, + { "role": "user", "content": "fte" } + ] + "#; + + let messages: Vec = serde_json::from_str(test_str).unwrap(); + let messages_for_halluncination = extract_messages_for_hallucination(&messages); + println!("{:?}", messages_for_halluncination); + assert_eq!(messages_for_halluncination.len(), 3); + } + #[test] + fn test_hallucination_message_long() { + let test_str = r#" + [ + { + "role": "system", + "model" : "gpt-3.5-turbo", + "content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n" + }, + { "role": "user", "content": "Hello" }, + { + "role": "assistant", + "model": "gpt-3.5-turbo", + "content": "Hi there!" + }, + { "role": "user", "content": "tell me about headcount data" }, + { + "role": "assistant", + "model": "Arch-Function-1.5B", + "content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data." + }, + { "role": "user", "content": "europe" }, + { + "role": "system", + "model": "Arch-Function-1.5B", + "content": "It seems like you are asking for headcount data for Europe. Could you please specify the staffing type?" + }, + { "role": "user", "content": "fte" }, + { + "role": "assistant", + "model": "gpt-3.5-turbo", + "content": "The headcount is 50000" + }, + { "role": "user", "content": "tell me about the weather" }, + { + "role": "assistant", + "model": "Arch-Function-1.5B", + "content" : "The weather forcast tools requires 2 parameters: city and days. Please specify" + }, + { "role": "user", "content": "Seattle" }, + { + "role": "system", + "model": "Arch-Function-1.5B", + "content": "It seems like you are asking for weather data for Seattle. Could you please specify the days?" + }, + { "role": "user", "content": "7 days" } + ] + "#; + + let messages: Vec = serde_json::from_str(test_str).unwrap(); + let messages_for_halluncination = extract_messages_for_hallucination(&messages); + println!("{:?}", messages_for_halluncination); + assert_eq!(messages_for_halluncination.len(), 3); + assert_eq!(["tell me about the weather", "Seattle", "7 days"], messages_for_halluncination.as_slice()); + } + }