concatenate history of user messages for hallucination (#177)

* concatenate history of user messages for hallucination

* add history of messages

* fix gpt to not arch

* add model prefix

* fix

* correct init of user_messages

* fmt

* fix test
This commit is contained in:
Co Tran 2024-10-15 11:43:05 -07:00 committed by GitHub
parent 35c5e303b7
commit b1746b38b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 44 additions and 8 deletions

View file

@ -1,9 +1,9 @@
use crate::consts::{
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH,
DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_PROVIDER_HINT_HEADER,
ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER,
CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
@ -453,7 +453,7 @@ impl StreamContext {
if messages.len() >= 2 {
let latest_assistant_message = &messages[messages.len() - 2];
if let Some(model) = latest_assistant_message.model.as_ref() {
if model.contains("Arch") {
if model.contains(ARCH_MODEL_PREFIX) {
arch_assistant = true;
}
}
@ -728,8 +728,41 @@ impl StreamContext {
None => HashMap::new(), // Return an empty HashMap if v is not an object
};
let messages = &callout_context.request_body.messages;
let mut arch_assistant = false;
let mut user_messages = Vec::new();
if messages.len() >= 2 {
let latest_assistant_message = &messages[messages.len() - 2];
if let Some(model) = latest_assistant_message.model.as_ref() {
if model.starts_with(ARCH_MODEL_PREFIX) {
arch_assistant = true;
}
}
}
if arch_assistant {
for message in messages.iter() {
if let Some(model) = message.model.as_ref() {
if !model.starts_with(ARCH_MODEL_PREFIX) {
break;
}
}
if message.role == "user" {
if let Some(content) = &message.content {
user_messages.push(content.clone());
}
}
}
} else {
if let Some(user_message) = callout_context.user_message.as_ref() {
user_messages.push(user_message.clone());
}
}
let user_messages_str = user_messages.join(", ");
debug!("user messages: {}", user_messages_str);
let hallucination_classification_request = HallucinationClassificationRequest {
prompt: callout_context.user_message.as_ref().unwrap().clone(),
prompt: user_messages_str,
model: String::from(DEFAULT_INTENT_MODEL),
parameters: tool_params_dict,
};