This commit is contained in:
cotran 2024-10-10 15:57:45 -07:00
parent 45aaaf09be
commit 0774f6d7f7

View file

@ -728,14 +728,14 @@ impl StreamContext {
None => HashMap::new(), // Return an empty HashMap if v is not an object None => HashMap::new(), // Return an empty HashMap if v is not an object
}; };
let mut user_messages = String::new(); let mut user_messages = Vec::new();
let messages = &callout_context.request_body.messages; let messages = &callout_context.request_body.messages;
let mut arch_assistant = false; let mut arch_assistant = false;
if messages.len() >= 2 { if messages.len() >= 2 {
let latest_assistant_message = &messages[messages.len() - 2]; let latest_assistant_message = &messages[messages.len() - 2];
if let Some(model) = latest_assistant_message.model.as_ref() { if let Some(model) = latest_assistant_message.model.as_ref() {
if model.contains(ARCH_MODEL_PREFIX) { if model.starts_with(ARCH_MODEL_PREFIX) {
arch_assistant = true; arch_assistant = true;
} }
} }
@ -746,23 +746,26 @@ impl StreamContext {
if arch_assistant { if arch_assistant {
for message in messages.iter() { for message in messages.iter() {
if let Some(model) = message.model.as_ref() { if let Some(model) = message.model.as_ref() {
if !model.contains(ARCH_MODEL_PREFIX) { if !model.starts_with(ARCH_MODEL_PREFIX) {
break; break;
} }
} }
if message.role == "user" { if message.role == "user" {
if let Some(content) = &message.content { if let Some(content) = &message.content {
user_messages = format!("{} , {}", user_messages, content); user_messages.push(content.clone());
} }
} }
} }
} else { } else {
user_messages = callout_context.user_message.as_ref().unwrap().clone(); if let Some(user_message) = callout_context.user_message.as_ref() {
user_messages.push(user_message.clone());
} }
info!("user messages: {}", user_messages); }
let user_messages_str = user_messages.join(", ");
info!("user messages: {}", user_messages_str);
let hallucination_classification_request = HallucinationClassificationRequest { let hallucination_classification_request = HallucinationClassificationRequest {
prompt: user_messages, prompt: user_messages_str,
model: String::from(DEFAULT_INTENT_MODEL), model: String::from(DEFAULT_INTENT_MODEL),
parameters: tool_params_dict, parameters: tool_params_dict,
}; };