Add separate util for hallucination and add tests for it (#203)

This commit is contained in:
Adil Hafeez 2024-10-18 19:34:17 -07:00 committed by GitHub
parent faf64960df
commit dced8a5708
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 69 additions and 36 deletions

View file

@ -1,4 +1,5 @@
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use crate::hallucination::extract_messages_for_hallucination;
use acap::cos;
use common::common_types::open_ai::{
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
@ -12,7 +13,12 @@ use common::common_types::{
};
use common::configuration::{Overrides, PromptGuards, PromptTarget};
use common::consts::{
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, GUARD_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL,
DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD,
EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, GUARD_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST,
REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST,
};
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@ -20,6 +26,7 @@ use common::embeddings::{
use common::errors::ClientError;
use common::http::{CallArgs, Client};
use common::stats::Gauge;
use derivative::Derivative;
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::traits::*;
@ -30,7 +37,6 @@ use std::collections::HashMap;
use std::rc::Rc;
use std::str::FromStr;
use std::time::Duration;
use derivative::Derivative;
use common::stats::IncrementingMetric;
@ -234,7 +240,10 @@ impl StreamContext {
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
debug!("error serializing zero shot classification request: {}", error);
debug!(
"error serializing zero shot classification request: {}",
error
);
return self.send_server_error(ServerError::Serialization(error), None);
}
};
@ -343,7 +352,10 @@ impl StreamContext {
match serde_json::from_slice(&body) {
Ok(zeroshot_response) => zeroshot_response,
Err(e) => {
debug!("error deserializing zero shot classification response: {}", e);
debug!(
"error deserializing zero shot classification response: {}",
e
);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -662,35 +674,9 @@ impl StreamContext {
None => HashMap::new(), // Return an empty HashMap if v is not an object
};
let messages = &callout_context.request_body.messages;
let mut arch_assistant = false;
let mut user_messages = Vec::new();
if messages.len() >= 2 {
let latest_assistant_message = &messages[messages.len() - 2];
if let Some(model) = latest_assistant_message.model.as_ref() {
if model.starts_with(ARCH_MODEL_PREFIX) {
arch_assistant = true;
}
}
}
if arch_assistant {
for message in messages.iter() {
if let Some(model) = message.model.as_ref() {
if !model.starts_with(ARCH_MODEL_PREFIX) {
break;
}
}
if message.role == "user" {
if let Some(content) = &message.content {
user_messages.push(content.clone());
}
}
}
} else if let Some(user_message) = callout_context.user_message.as_ref() {
user_messages.push(user_message.clone());
}
let user_messages_str = user_messages.join(", ");
let all_user_messages =
extract_messages_for_hallucination(&callout_context.request_body.messages);
let user_messages_str = all_user_messages.join(", ");
debug!("user messages: {}", user_messages_str);
let hallucination_classification_request = HallucinationClassificationRequest {
@ -703,7 +689,10 @@ impl StreamContext {
match serde_json::to_string(&hallucination_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
debug!("error serializing hallucination classification request: {}", error);
debug!(
"error serializing hallucination classification request: {}",
error
);
return self.send_server_error(ServerError::Serialization(error), None);
}
};
@ -848,7 +837,7 @@ impl StreamContext {
// don't send tools message and api response to chat gpt
for m in callout_context.request_body.messages.iter() {
if m.role == TOOL_ROLE || m.content.is_none() {
continue;
continue;
}
messages.push(m.clone());
}
@ -1286,7 +1275,11 @@ impl HttpContext for StreamContext {
match serde_json::from_slice(&body) {
Ok(de) => de,
Err(e) => {
debug!("invalid response: {}, {}", String::from_utf8_lossy(&body), e);
debug!(
"invalid response: {}, {}",
String::from_utf8_lossy(&body),
e
);
return Action::Continue;
}
};