diff --git a/arch/src/consts.rs b/arch/src/consts.rs index 32172002..76244f6b 100644 --- a/arch/src/consts.rs +++ b/arch/src/consts.rs @@ -1,7 +1,7 @@ pub const DEFAULT_EMBEDDING_MODEL: &str = "katanemo/bge-large-en-v1.5"; pub const DEFAULT_INTENT_MODEL: &str = "katanemo/bart-large-mnli"; pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8; -pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.1; +pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25; pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector"; pub const SYSTEM_ROLE: &str = "system"; pub const USER_ROLE: &str = "user"; @@ -19,3 +19,4 @@ pub const REQUEST_ID_HEADER: &str = "x-request-id"; pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal"; pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream"; pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener"; +pub const ARCH_MODEL_PREFIX: &str = "Arch"; diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs index bc9e62fa..7a65609c 100644 --- a/arch/src/stream_context.rs +++ b/arch/src/stream_context.rs @@ -1,9 +1,9 @@ use crate::consts::{ ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, - ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, - ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, - DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, - DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, + ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_PROVIDER_HINT_HEADER, + ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER, + CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, + DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE, }; use crate::filter_context::{EmbeddingsStore, WasmMetrics}; @@ -453,7 +453,7 @@ impl StreamContext { if messages.len() >= 2 { let latest_assistant_message = &messages[messages.len() - 2]; if let Some(model) = latest_assistant_message.model.as_ref() { - if model.contains("Arch") { + if model.contains(ARCH_MODEL_PREFIX) { arch_assistant = true; } } @@ -728,8 +728,41 @@ impl StreamContext { None => HashMap::new(), // Return an empty HashMap if v is not an object }; + let messages = &callout_context.request_body.messages; + let mut arch_assistant = false; + let mut user_messages = Vec::new(); + + if messages.len() >= 2 { + let latest_assistant_message = &messages[messages.len() - 2]; + if let Some(model) = latest_assistant_message.model.as_ref() { + if model.starts_with(ARCH_MODEL_PREFIX) { + arch_assistant = true; + } + } + } + if arch_assistant { + for message in messages.iter() { + if let Some(model) = message.model.as_ref() { + if !model.starts_with(ARCH_MODEL_PREFIX) { + break; + } + } + if message.role == "user" { + if let Some(content) = &message.content { + user_messages.push(content.clone()); + } + } + } + } else { + if let Some(user_message) = callout_context.user_message.as_ref() { + user_messages.push(user_message.clone()); + } + } + let user_messages_str = user_messages.join(", "); + debug!("user messages: {}", user_messages_str); + let hallucination_classification_request = HallucinationClassificationRequest { - prompt: callout_context.user_message.as_ref().unwrap().clone(), + prompt: user_messages_str, model: String::from(DEFAULT_INTENT_MODEL), parameters: tool_params_dict, }; diff --git a/arch/tests/integration.rs b/arch/tests/integration.rs index 06918bdf..c628d9c3 100644 --- a/arch/tests/integration.rs +++ b/arch/tests/integration.rs @@ -586,6 +586,7 @@ fn request_ratelimited() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -728,6 +729,7 @@ fn request_not_ratelimited() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_http_call( Some("arch_internal"), Some(vec![ diff --git a/model_server/app/main.py b/model_server/app/main.py index 6a90c0b1..c6f5752a 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -177,7 +177,7 @@ async def hallucination(req: HallucinationRequest, res: Response): """ Take input as text and return the prediction of hallucination for each parameter """ - + logger.info(f"hallucination request: {req}") if req.model != zero_shot_model["model_name"]: raise HTTPException(status_code=400, detail="unknown model: " + req.model)