diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 4a76be14..a9c06a49 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -1120,6 +1120,7 @@ dependencies = [ "http", "log", "md5", + "pretty_assertions", "proxy-wasm", "proxy-wasm-test-framework", "rand", diff --git a/crates/prompt_gateway/Cargo.toml b/crates/prompt_gateway/Cargo.toml index 29d385b7..e8a166f8 100644 --- a/crates/prompt_gateway/Cargo.toml +++ b/crates/prompt_gateway/Cargo.toml @@ -26,3 +26,4 @@ sha2 = "0.10.8" [dev-dependencies] proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" } serial_test = "3.1.1" +pretty_assertions = "1.4.1" diff --git a/crates/prompt_gateway/src/hallucination.rs b/crates/prompt_gateway/src/hallucination.rs index 71d1e7cf..62b119ac 100644 --- a/crates/prompt_gateway/src/hallucination.rs +++ b/crates/prompt_gateway/src/hallucination.rs @@ -1,31 +1,63 @@ -use common::{common_types::open_ai::Message, consts::USER_ROLE}; +use common::{ + common_types::open_ai::Message, + consts::{ARCH_MODEL_PREFIX, ASSISTANT_ROLE, USER_ROLE}, +}; -pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec { - let all_user_messages = messages - .iter() - .filter(|m| m.role == USER_ROLE) - .map(|m| m.content.as_ref().unwrap().clone()) - .collect::>(); - all_user_messages +pub fn extract_messages_for_hallucination(messages: &Vec) -> Vec { + let mut arch_assistant = false; + let mut user_messages = Vec::new(); + if messages.len() >= 2 { + let latest_assistant_message = &messages[messages.len() - 2]; + if let Some(model) = latest_assistant_message.model.as_ref() { + if model.starts_with(ARCH_MODEL_PREFIX) { + arch_assistant = true; + } + } + } + if arch_assistant { + for message in messages.iter().rev() { + if let Some(model) = message.model.as_ref() { + if !model.starts_with(ARCH_MODEL_PREFIX) { + if message.role == ASSISTANT_ROLE { + break; + } + } + } + if message.role == USER_ROLE { + if let Some(content) = &message.content { + 
user_messages.push(content.clone()); + } + } + } + } else if let Some(message) = messages.last() { + if let Some(content) = &message.content { + user_messages.push(content.clone()); + } + } + user_messages.reverse(); // Reverse to maintain the original order + return user_messages; } #[cfg(test)] mod test { + use pretty_assertions::assert_eq; use common::common_types::open_ai::Message; use super::extract_messages_for_hallucination; #[test] - fn test_hallucination_message() { + fn test_hallucination_message_simple() { let test_str = r#" [ { "role": "system", + "model" : "gpt-3.5-turbo", "content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n" }, { "role": "user", "content": "tell me about headcount data" }, { "role": "assistant", + "model": "Arch-Function-1.5B", "content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data." 
/// History with one assistant turn from a non-Arch model ("gpt-3.5-turbo")
/// followed by turns whose model is "Arch-Function-1.5B": only the user
/// messages after the non-Arch assistant turn are expected, oldest first.
#[test]
fn test_hallucination_message_medium() {
    let test_str = r#"
    [
      {
        "role": "system",
        "model" : "gpt-3.5-turbo",
        "content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n"
      },
      { "role": "user", "content": "Hello" },
      {
        "role": "assistant",
        "model": "gpt-3.5-turbo",
        "content": "Hi there!"
      },
      { "role": "user", "content": "tell me about headcount data" },
      {
        "role": "assistant",
        "model": "Arch-Function-1.5B",
        "content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data."
      },
      { "role": "user", "content": "europe" }
      ,
      {
        "role": "system",
        "model": "Arch-Function-1.5B",
        "content": "It seems like you are asking for headcount data for Europe. Could you please specify the staffing type?"
      },
      { "role": "user", "content": "fte" }
    ]
    "#;

    let messages: Vec<Message> = serde_json::from_str(test_str).unwrap();
    let messages_for_hallucination = extract_messages_for_hallucination(&messages);
    assert_eq!(messages_for_hallucination.len(), 3);
    // Pin content and order, not just the count.
    assert_eq!(
        ["tell me about headcount data", "europe", "fte"],
        messages_for_hallucination.as_slice()
    );
}

/// Two separate runs of "Arch-Function-1.5B" turns divided by a non-Arch
/// assistant turn: only the user messages after the most recent non-Arch
/// assistant turn are expected, oldest first.
#[test]
fn test_hallucination_message_long() {
    let test_str = r#"
    [
      {
        "role": "system",
        "model" : "gpt-3.5-turbo",
        "content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n"
      },
      { "role": "user", "content": "Hello" },
      {
        "role": "assistant",
        "model": "gpt-3.5-turbo",
        "content": "Hi there!"
      },
      { "role": "user", "content": "tell me about headcount data" },
      {
        "role": "assistant",
        "model": "Arch-Function-1.5B",
        "content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data."
      },
      { "role": "user", "content": "europe" },
      {
        "role": "system",
        "model": "Arch-Function-1.5B",
        "content": "It seems like you are asking for headcount data for Europe. Could you please specify the staffing type?"
      },
      { "role": "user", "content": "fte" },
      {
        "role": "assistant",
        "model": "gpt-3.5-turbo",
        "content": "The headcount is 50000"
      },
      { "role": "user", "content": "tell me about the weather" },
      {
        "role": "assistant",
        "model": "Arch-Function-1.5B",
        "content" : "The weather forcast tools requires 2 parameters: city and days. Please specify"
      },
      { "role": "user", "content": "Seattle" },
      {
        "role": "system",
        "model": "Arch-Function-1.5B",
        "content": "It seems like you are asking for weather data for Seattle. Could you please specify the days?"
      },
      { "role": "user", "content": "7 days" }
    ]
    "#;

    let messages: Vec<Message> = serde_json::from_str(test_str).unwrap();
    let messages_for_hallucination = extract_messages_for_hallucination(&messages);
    assert_eq!(messages_for_hallucination.len(), 3);
    assert_eq!(
        ["tell me about the weather", "Seattle", "7 days"],
        messages_for_hallucination.as_slice()
    );
}