Merge branch 'main' into debug-streaming-v2

This commit is contained in:
Adil Hafeez 2024-10-24 15:33:40 -07:00
commit 8e098fb5c0
29 changed files with 662 additions and 2974 deletions

View file

@ -15,7 +15,7 @@ pub const HALLUCINATION_INTERNAL_HOST: &str = "hallucination";
// Internal host names. NOTE(review): how these hosts are resolved is not visible here — confirm
// against the code that dispatches the call-outs.
pub const EMBEDDINGS_INTERNAL_HOST: &str = "embeddings";
pub const GUARD_INTERNAL_HOST: &str = "guard";
// NOTE(review): the `x-arch-*` values in this group look like HTTP header names — confirm
// against callers.
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
// NOTE(review): every visible use of this key sits next to an equivalent `MESSAGES_KEY` use;
// presumably this is the superseded name — confirm before relying on it.
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
/// Map key under which `callout_context.request_body.messages` is inserted when
/// `StreamContext` builds the parameter maps it serializes to JSON.
pub const MESSAGES_KEY: &str = "messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
// NOTE(review): presumably the request path of the chat-completions endpoint — confirm.
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
pub const ARCH_STATE_HEADER: &str = "x-arch-state";
@ -25,3 +25,4 @@ pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal";
// NOTE(review): presumably the header naming the upstream host — confirm against callers.
pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream";
pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";
/// Prefix tested with `starts_with` against a message's `model` field in
/// `extract_messages_for_hallucination`.
pub const ARCH_MODEL_PREFIX: &str = "Arch";
/// Leading text of the reply that asks for missing details.
///
/// `StreamContext` builds the full reply as this template, then the entries of
/// `keys_with_low_score` joined with ", ", then " ?".
/// `extract_messages_for_hallucination` tests message content with
/// `starts_with(HALLUCINATION_TEMPLATE)`, so the text is a match key as well as
/// user-facing copy.
///
/// NOTE(review): "Im" looks like a typo for "I'm"; changing it would stop the prefix
/// check from matching replies already present in a conversation — confirm before editing.
pub const HALLUCINATION_TEMPLATE: &str = "It seems Im missing some information. Could you provide the following details ";

View file

@ -1,6 +1,6 @@
use common::{
common_types::open_ai::Message,
consts::{ARCH_MODEL_PREFIX, ASSISTANT_ROLE, USER_ROLE},
consts::{ARCH_MODEL_PREFIX, USER_ROLE, HALLUCINATION_TEMPLATE},
};
pub fn extract_messages_for_hallucination(messages: &Vec<Message>) -> Vec<String> {
@ -18,9 +18,11 @@ pub fn extract_messages_for_hallucination(messages: &Vec<Message>) -> Vec<String
for message in messages.iter().rev() {
if let Some(model) = message.model.as_ref() {
if !model.starts_with(ARCH_MODEL_PREFIX) {
if message.role == ASSISTANT_ROLE {
break;
}
if let Some(content) = &message.content {
if !content.starts_with(HALLUCINATION_TEMPLATE) {
break;
}
}
}
}
if message.role == USER_ROLE {

View file

@ -12,12 +12,7 @@ use common::common_types::{
};
use common::configuration::{Overrides, PromptGuards, PromptTarget};
use common::consts::{
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER,
ARCH_UPSTREAM_HOST_HEADER, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST,
HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
ZEROSHOT_INTERNAL_HOST,
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_EMBEDDING_MODEL, HALLUCINATION_TEMPLATE, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST
};
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@ -328,12 +323,11 @@ impl StreamContext {
if !keys_with_low_score.is_empty() {
let response =
"It seems Im missing some information. Could you provide the following details: "
.to_string()
HALLUCINATION_TEMPLATE.to_string()
+ &keys_with_low_score.join(", ")
+ " ?";
let message = Message {
role: SYSTEM_ROLE.to_string(),
role: ASSISTANT_ROLE.to_string(),
content: Some(response),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
@ -461,7 +455,7 @@ impl StreamContext {
let upstream_endpoint = endpoint.name;
let mut params = HashMap::new();
params.insert(
ARCH_MESSAGES_KEY.to_string(),
MESSAGES_KEY.to_string(),
callout_context.request_body.messages.clone(),
);
let arch_messages_json = serde_json::to_string(&params).unwrap();
@ -688,7 +682,7 @@ impl StreamContext {
tool_params_json_str
);
tool_params.insert(
String::from(ARCH_MESSAGES_KEY),
String::from(MESSAGES_KEY),
serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
);
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
@ -772,7 +766,7 @@ impl StreamContext {
.arguments
.clone();
tool_params.insert(
String::from(ARCH_MESSAGES_KEY),
String::from(MESSAGES_KEY),
serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
);