mirror of
https://github.com/katanemo/plano.git
synced 2026-05-07 23:02:43 +02:00
Add separate util for hallucination and add tests for it (#203)
This commit is contained in:
parent
faf64960df
commit
dced8a5708
3 changed files with 69 additions and 36 deletions
39
crates/prompt_gateway/src/hallucination.rs
Normal file
39
crates/prompt_gateway/src/hallucination.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
use common::{common_types::open_ai::Message, consts::USER_ROLE};
|
||||||
|
|
||||||
|
pub fn extract_messages_for_hallucination(messages: &Vec<Message>) -> Vec<String> {
|
||||||
|
let all_user_messages = messages
|
||||||
|
.iter()
|
||||||
|
.filter(|m| m.role == USER_ROLE)
|
||||||
|
.map(|m| m.content.as_ref().unwrap().clone())
|
||||||
|
.collect::<Vec<String>>();
|
||||||
|
return all_user_messages;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod test {
    use common::common_types::open_ai::Message;

    use super::extract_messages_for_hallucination;

    /// A system / user / assistant / user conversation must yield exactly the
    /// two user utterances, in the order they were sent. Asserting on the
    /// contents (not just the count) catches wrong-role or reordering bugs.
    #[test]
    fn test_hallucination_message() {
        let test_str = r#"
          [
            {
              "role": "system",
              "content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"headcount\", \"description\": \"Get headcount data for a region by staffing type\", \"parameters\": {\"properties\": {\"staffing_type\": {\"type\": \"str\", \"description\": \"The staffing type like contract, fte or agency\"}, \"region\": {\"type\": \"str\", \"description\": \"the geographical region for which you want headcount data.\"}}, \"required\": [\"staffing_type\", \"region\"]}}}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>"
            },
            { "role": "user", "content": "tell me about headcount data" },
            {
              "role": "assistant",
              "content": "The \"headcount\" tool provides information about the number of employees in a specific region based on the type of staffing used. It requires two parameters: \"staffing_type\" and \"region\". The \"staffing_type\" parameter specifies the type of staffing, such as contract, full-time equivalent (fte), or agency. The \"region\" parameter specifies the geographical region for which you want headcount data."
            },
            { "role": "user", "content": "europe and for fte" }
          ]
        "#;

        let messages: Vec<Message> = serde_json::from_str(test_str).unwrap();
        let messages_for_hallucination = extract_messages_for_hallucination(&messages);
        assert_eq!(
            messages_for_hallucination,
            vec!["tell me about headcount data", "europe and for fte"]
        );
    }
}
|
||||||
|
|
@ -4,6 +4,7 @@ use proxy_wasm::types::*;
|
||||||
|
|
||||||
mod filter_context;
|
mod filter_context;
|
||||||
mod stream_context;
|
mod stream_context;
|
||||||
|
mod hallucination;
|
||||||
|
|
||||||
proxy_wasm::main! {{
|
proxy_wasm::main! {{
|
||||||
proxy_wasm::set_log_level(LogLevel::Trace);
|
proxy_wasm::set_log_level(LogLevel::Trace);
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
|
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
|
||||||
|
use crate::hallucination::extract_messages_for_hallucination;
|
||||||
use acap::cos;
|
use acap::cos;
|
||||||
use common::common_types::open_ai::{
|
use common::common_types::open_ai::{
|
||||||
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
|
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
|
||||||
|
|
@ -12,7 +13,12 @@ use common::common_types::{
|
||||||
};
|
};
|
||||||
use common::configuration::{Overrides, PromptGuards, PromptTarget};
|
use common::configuration::{Overrides, PromptGuards, PromptTarget};
|
||||||
use common::consts::{
|
use common::consts::{
|
||||||
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, GUARD_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST
|
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
|
||||||
|
ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER,
|
||||||
|
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL,
|
||||||
|
DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD,
|
||||||
|
EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, GUARD_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST,
|
||||||
|
REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST,
|
||||||
};
|
};
|
||||||
use common::embeddings::{
|
use common::embeddings::{
|
||||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||||
|
|
@ -20,6 +26,7 @@ use common::embeddings::{
|
||||||
use common::errors::ClientError;
|
use common::errors::ClientError;
|
||||||
use common::http::{CallArgs, Client};
|
use common::http::{CallArgs, Client};
|
||||||
use common::stats::Gauge;
|
use common::stats::Gauge;
|
||||||
|
use derivative::Derivative;
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use log::{debug, info, warn};
|
use log::{debug, info, warn};
|
||||||
use proxy_wasm::traits::*;
|
use proxy_wasm::traits::*;
|
||||||
|
|
@ -30,7 +37,6 @@ use std::collections::HashMap;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use derivative::Derivative;
|
|
||||||
|
|
||||||
use common::stats::IncrementingMetric;
|
use common::stats::IncrementingMetric;
|
||||||
|
|
||||||
|
|
@ -234,7 +240,10 @@ impl StreamContext {
|
||||||
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
|
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
|
||||||
Ok(json_data) => json_data,
|
Ok(json_data) => json_data,
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
debug!("error serializing zero shot classification request: {}", error);
|
debug!(
|
||||||
|
"error serializing zero shot classification request: {}",
|
||||||
|
error
|
||||||
|
);
|
||||||
return self.send_server_error(ServerError::Serialization(error), None);
|
return self.send_server_error(ServerError::Serialization(error), None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -343,7 +352,10 @@ impl StreamContext {
|
||||||
match serde_json::from_slice(&body) {
|
match serde_json::from_slice(&body) {
|
||||||
Ok(zeroshot_response) => zeroshot_response,
|
Ok(zeroshot_response) => zeroshot_response,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("error deserializing zero shot classification response: {}", e);
|
debug!(
|
||||||
|
"error deserializing zero shot classification response: {}",
|
||||||
|
e
|
||||||
|
);
|
||||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -662,35 +674,9 @@ impl StreamContext {
|
||||||
None => HashMap::new(), // Return an empty HashMap if v is not an object
|
None => HashMap::new(), // Return an empty HashMap if v is not an object
|
||||||
};
|
};
|
||||||
|
|
||||||
let messages = &callout_context.request_body.messages;
|
let all_user_messages =
|
||||||
let mut arch_assistant = false;
|
extract_messages_for_hallucination(&callout_context.request_body.messages);
|
||||||
let mut user_messages = Vec::new();
|
let user_messages_str = all_user_messages.join(", ");
|
||||||
|
|
||||||
if messages.len() >= 2 {
|
|
||||||
let latest_assistant_message = &messages[messages.len() - 2];
|
|
||||||
if let Some(model) = latest_assistant_message.model.as_ref() {
|
|
||||||
if model.starts_with(ARCH_MODEL_PREFIX) {
|
|
||||||
arch_assistant = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if arch_assistant {
|
|
||||||
for message in messages.iter() {
|
|
||||||
if let Some(model) = message.model.as_ref() {
|
|
||||||
if !model.starts_with(ARCH_MODEL_PREFIX) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if message.role == "user" {
|
|
||||||
if let Some(content) = &message.content {
|
|
||||||
user_messages.push(content.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if let Some(user_message) = callout_context.user_message.as_ref() {
|
|
||||||
user_messages.push(user_message.clone());
|
|
||||||
}
|
|
||||||
let user_messages_str = user_messages.join(", ");
|
|
||||||
debug!("user messages: {}", user_messages_str);
|
debug!("user messages: {}", user_messages_str);
|
||||||
|
|
||||||
let hallucination_classification_request = HallucinationClassificationRequest {
|
let hallucination_classification_request = HallucinationClassificationRequest {
|
||||||
|
|
@ -703,7 +689,10 @@ impl StreamContext {
|
||||||
match serde_json::to_string(&hallucination_classification_request) {
|
match serde_json::to_string(&hallucination_classification_request) {
|
||||||
Ok(json_data) => json_data,
|
Ok(json_data) => json_data,
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
debug!("error serializing hallucination classification request: {}", error);
|
debug!(
|
||||||
|
"error serializing hallucination classification request: {}",
|
||||||
|
error
|
||||||
|
);
|
||||||
return self.send_server_error(ServerError::Serialization(error), None);
|
return self.send_server_error(ServerError::Serialization(error), None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -848,7 +837,7 @@ impl StreamContext {
|
||||||
// don't send tools message and api response to chat gpt
|
// don't send tools message and api response to chat gpt
|
||||||
for m in callout_context.request_body.messages.iter() {
|
for m in callout_context.request_body.messages.iter() {
|
||||||
if m.role == TOOL_ROLE || m.content.is_none() {
|
if m.role == TOOL_ROLE || m.content.is_none() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
messages.push(m.clone());
|
messages.push(m.clone());
|
||||||
}
|
}
|
||||||
|
|
@ -1286,7 +1275,11 @@ impl HttpContext for StreamContext {
|
||||||
match serde_json::from_slice(&body) {
|
match serde_json::from_slice(&body) {
|
||||||
Ok(de) => de,
|
Ok(de) => de,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("invalid response: {}, {}", String::from_utf8_lossy(&body), e);
|
debug!(
|
||||||
|
"invalid response: {}, {}",
|
||||||
|
String::from_utf8_lossy(&body),
|
||||||
|
e
|
||||||
|
);
|
||||||
return Action::Continue;
|
return Action::Continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue