mirror of
https://github.com/katanemo/plano.git
synced 2026-05-15 11:02:39 +02:00
Cotran/hallucination (#208)
This commit is contained in:
parent
ea76d85b43
commit
8495f89fda
3 changed files with 136 additions and 9 deletions
1
crates/Cargo.lock
generated
1
crates/Cargo.lock
generated
|
|
@ -1120,6 +1120,7 @@ dependencies = [
|
||||||
"http",
|
"http",
|
||||||
"log",
|
"log",
|
||||||
"md5",
|
"md5",
|
||||||
|
"pretty_assertions",
|
||||||
"proxy-wasm",
|
"proxy-wasm",
|
||||||
"proxy-wasm-test-framework",
|
"proxy-wasm-test-framework",
|
||||||
"rand",
|
"rand",
|
||||||
|
|
|
||||||
|
|
@ -26,3 +26,4 @@ sha2 = "0.10.8"
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
|
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
|
||||||
serial_test = "3.1.1"
|
serial_test = "3.1.1"
|
||||||
|
pretty_assertions = "1.4.1"
|
||||||
|
|
|
||||||
|
|
@ -1,31 +1,63 @@
|
||||||
use common::{common_types::open_ai::Message, consts::USER_ROLE};
|
use common::{
|
||||||
|
common_types::open_ai::Message,
|
||||||
|
consts::{ARCH_MODEL_PREFIX, ASSISTANT_ROLE, USER_ROLE},
|
||||||
|
};
|
||||||
|
|
||||||
pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
|
pub fn extract_messages_for_hallucination(messages: &Vec<Message>) -> Vec<String> {
|
||||||
let all_user_messages = messages
|
let mut arch_assistant = false;
|
||||||
.iter()
|
let mut user_messages = Vec::new();
|
||||||
.filter(|m| m.role == USER_ROLE)
|
if messages.len() >= 2 {
|
||||||
.map(|m| m.content.as_ref().unwrap().clone())
|
let latest_assistant_message = &messages[messages.len() - 2];
|
||||||
.collect::<Vec<String>>();
|
if let Some(model) = latest_assistant_message.model.as_ref() {
|
||||||
all_user_messages
|
if model.starts_with(ARCH_MODEL_PREFIX) {
|
||||||
|
arch_assistant = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if arch_assistant {
|
||||||
|
for message in messages.iter().rev() {
|
||||||
|
if let Some(model) = message.model.as_ref() {
|
||||||
|
if !model.starts_with(ARCH_MODEL_PREFIX) {
|
||||||
|
if message.role == ASSISTANT_ROLE {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if message.role == USER_ROLE {
|
||||||
|
if let Some(content) = &message.content {
|
||||||
|
user_messages.push(content.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if let Some(message) = messages.last() {
|
||||||
|
if let Some(content) = &message.content {
|
||||||
|
user_messages.push(content.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
user_messages.reverse(); // Reverse to maintain the original order
|
||||||
|
return user_messages;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
|
use pretty_assertions::assert_eq;
|
||||||
use common::common_types::open_ai::Message;
|
use common::common_types::open_ai::Message;
|
||||||
|
|
||||||
use super::extract_messages_for_hallucination;
|
use super::extract_messages_for_hallucination;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_hallucination_message() {
|
fn test_hallucination_message_simple() {
|
||||||
let test_str = r#"
|
let test_str = r#"
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
"model" : "gpt-3.5-turbo",
|
||||||
"content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>"
|
"content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>"
|
||||||
},
|
},
|
||||||
{ "role": "user", "content": "tell me about headcount data" },
|
{ "role": "user", "content": "tell me about headcount data" },
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
|
"model": "Arch-Function-1.5B",
|
||||||
"content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data."
|
"content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data."
|
||||||
},
|
},
|
||||||
{ "role": "user", "content": "europe and for fte" }
|
{ "role": "user", "content": "europe and for fte" }
|
||||||
|
|
@ -36,4 +68,97 @@ mod test {
|
||||||
let messages_for_halluncination = extract_messages_for_hallucination(&messages);
|
let messages_for_halluncination = extract_messages_for_hallucination(&messages);
|
||||||
assert_eq!(messages_for_halluncination.len(), 2);
|
assert_eq!(messages_for_halluncination.len(), 2);
|
||||||
}
|
}
|
||||||
|
// The previous turn comes from an Arch model, so extraction must walk back
// through the Arch exchange and stop at the earlier gpt-3.5-turbo assistant
// reply: "Hello" is excluded, the three later user messages are returned in
// conversation order.
#[test]
fn test_hallucination_message_medium() {
    let test_str = r#"
    [
      {
        "role": "system",
        "model" : "gpt-3.5-turbo",
        "content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>"
      },
      { "role": "user", "content": "Hello" },
      {
        "role": "assistant",
        "model": "gpt-3.5-turbo",
        "content": "Hi there!"
      },
      { "role": "user", "content": "tell me about headcount data" },
      {
        "role": "assistant",
        "model": "Arch-Function-1.5B",
        "content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data."
      },
      { "role": "user", "content": "europe" },
      {
        "role": "system",
        "model": "Arch-Function-1.5B",
        "content": "It seems like you are asking for headcount data for Europe. Could you please specify the staffing type?"
      },
      { "role": "user", "content": "fte" }
    ]
    "#;

    let messages: Vec<Message> = serde_json::from_str(test_str).unwrap();
    let messages_for_hallucination = extract_messages_for_hallucination(&messages);
    // Pin the exact messages, not just how many there are: a wrong selection
    // of three messages must fail this test.
    assert_eq!(
        ["tell me about headcount data", "europe", "fte"],
        messages_for_hallucination.as_slice()
    );
}
|
||||||
|
// Two Arch exchanges separated by a gpt-3.5-turbo assistant reply: only the
// user messages of the most recent exchange (the weather one) are expected.
#[test]
fn test_hallucination_message_long() {
    let conversation_json = r#"
    [
      {
        "role": "system",
        "model" : "gpt-3.5-turbo",
        "content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>"
      },
      { "role": "user", "content": "Hello" },
      {
        "role": "assistant",
        "model": "gpt-3.5-turbo",
        "content": "Hi there!"
      },
      { "role": "user", "content": "tell me about headcount data" },
      {
        "role": "assistant",
        "model": "Arch-Function-1.5B",
        "content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data."
      },
      { "role": "user", "content": "europe" },
      {
        "role": "system",
        "model": "Arch-Function-1.5B",
        "content": "It seems like you are asking for headcount data for Europe. Could you please specify the staffing type?"
      },
      { "role": "user", "content": "fte" },
      {
        "role": "assistant",
        "model": "gpt-3.5-turbo",
        "content": "The headcount is 50000"
      },
      { "role": "user", "content": "tell me about the weather" },
      {
        "role": "assistant",
        "model": "Arch-Function-1.5B",
        "content" : "The weather forcast tools requires 2 parameters: city and days. Please specify"
      },
      { "role": "user", "content": "Seattle" },
      {
        "role": "system",
        "model": "Arch-Function-1.5B",
        "content": "It seems like you are asking for weather data for Seattle. Could you please specify the days?"
      },
      { "role": "user", "content": "7 days" }
    ]
    "#;

    let conversation = serde_json::from_str::<Vec<Message>>(conversation_json).unwrap();
    let extracted = extract_messages_for_hallucination(&conversation);
    println!("{:?}", extracted);

    let expected = ["tell me about the weather", "Seattle", "7 days"];
    assert_eq!(extracted.len(), 3);
    assert_eq!(expected, extracted.as_slice());
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue