diff --git a/crates/prompt_gateway/src/hallucination.rs b/crates/prompt_gateway/src/hallucination.rs new file mode 100644 index 00000000..99d487ab --- /dev/null +++ b/crates/prompt_gateway/src/hallucination.rs @@ -0,0 +1,39 @@ +use common::{common_types::open_ai::Message, consts::USER_ROLE}; + +pub fn extract_messages_for_hallucination(messages: &Vec) -> Vec { + let all_user_messages = messages + .iter() + .filter(|m| m.role == USER_ROLE) + .map(|m| m.content.as_ref().unwrap().clone()) + .collect::>(); + return all_user_messages; +} + +#[cfg(test)] +mod test { + use common::common_types::open_ai::Message; + + use super::extract_messages_for_hallucination; + + #[test] + fn test_hallucination_message() { + let test_str = r#" + [ + { + "role": "system", + "content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n" + }, + { "role": "user", "content": "tell me about headcount data" }, + { + "role": "assistant", + "content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data." + }, + { "role": "user", "content": "europe and for fte" } + ] + "#; + + let messages: Vec = serde_json::from_str(test_str).unwrap(); + let messages_for_halluncination = extract_messages_for_hallucination(&messages); + assert_eq!(messages_for_halluncination.len(), 2); + } +} diff --git a/crates/prompt_gateway/src/lib.rs b/crates/prompt_gateway/src/lib.rs index e2ad9025..7ca26e44 100644 --- a/crates/prompt_gateway/src/lib.rs +++ b/crates/prompt_gateway/src/lib.rs @@ -4,6 +4,7 @@ use proxy_wasm::types::*; mod filter_context; mod stream_context; +mod hallucination; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 33938deb..0364cd87 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -1,4 +1,5 @@ use crate::filter_context::{EmbeddingsStore, WasmMetrics}; +use crate::hallucination::extract_messages_for_hallucination; use acap::cos; use common::common_types::open_ai::{ ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice, @@ -12,7 +13,12 @@ use common::common_types::{ }; use common::configuration::{Overrides, PromptGuards, PromptTarget}; use common::consts::{ - ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, GUARD_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST + ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, + ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, + ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, + DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, + EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, GUARD_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, + REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST, }; use common::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, @@ -20,6 +26,7 @@ use common::embeddings::{ use common::errors::ClientError; use common::http::{CallArgs, Client}; use common::stats::Gauge; +use derivative::Derivative; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::traits::*; @@ -30,7 +37,6 @@ use std::collections::HashMap; use std::rc::Rc; use std::str::FromStr; use std::time::Duration; -use derivative::Derivative; use common::stats::IncrementingMetric; @@ -234,7 +240,10 @@ impl StreamContext { let json_data: String = match serde_json::to_string(&zero_shot_classification_request) { Ok(json_data) => json_data, Err(error) => { - debug!("error serializing zero shot classification request: {}", error); + debug!( + "error serializing zero shot classification request: {}", + error + ); return self.send_server_error(ServerError::Serialization(error), None); } }; @@ -343,7 +352,10 @@ impl StreamContext { match serde_json::from_slice(&body) { Ok(zeroshot_response) => zeroshot_response, Err(e) => { - debug!("error deserializing zero shot classification response: {}", e); + debug!( + "error deserializing zero shot classification response: {}", + e + ); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -662,35 +674,9 @@ impl StreamContext { None => HashMap::new(), // Return an empty HashMap if v is not an object }; - let messages = &callout_context.request_body.messages; - let mut arch_assistant = false; - let mut user_messages = Vec::new(); - - if messages.len() >= 2 { - let latest_assistant_message = &messages[messages.len() - 2]; - if let Some(model) = latest_assistant_message.model.as_ref() { - if model.starts_with(ARCH_MODEL_PREFIX) { - arch_assistant = true; - } - } - } - if arch_assistant { - for message in messages.iter() { - if let Some(model) = message.model.as_ref() { - if !model.starts_with(ARCH_MODEL_PREFIX) { - break; - } - } - if message.role == "user" { - if let Some(content) = &message.content { - user_messages.push(content.clone()); - } - } - } - } else if let Some(user_message) = callout_context.user_message.as_ref() { - user_messages.push(user_message.clone()); - } - let user_messages_str = user_messages.join(", "); + let all_user_messages = + extract_messages_for_hallucination(&callout_context.request_body.messages); + let user_messages_str = all_user_messages.join(", "); debug!("user messages: {}", user_messages_str); let hallucination_classification_request = HallucinationClassificationRequest { @@ -703,7 +689,10 @@ impl StreamContext { match serde_json::to_string(&hallucination_classification_request) { Ok(json_data) => json_data, Err(error) => { - debug!("error serializing hallucination classification request: {}", error); + debug!( + "error serializing hallucination classification request: {}", + error + ); return self.send_server_error(ServerError::Serialization(error), None); } }; @@ -848,7 +837,7 @@ impl StreamContext { // don't send tools message and api response to chat gpt for m in callout_context.request_body.messages.iter() { if m.role == TOOL_ROLE || m.content.is_none() { - continue; + continue; } messages.push(m.clone()); } @@ -1286,7 +1275,11 @@ impl HttpContext for StreamContext { match serde_json::from_slice(&body) { Ok(de) => de, Err(e) => { - debug!("invalid response: {}, {}", String::from_utf8_lossy(&body), e); + debug!( + "invalid response: {}, {}", + String::from_utf8_lossy(&body), + e + ); return Action::Continue; } };