Use intent model from archfc to pick prompt gateway (#328)

This commit is contained in:
Shuguang Chen 2024-12-20 13:25:01 -08:00 committed by GitHub
parent 67b8fd635e
commit ba7279becb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
151 changed files with 8642 additions and 10932 deletions

View file

@ -1,36 +1,20 @@
use crate::embeddings::EmbeddingType;
use crate::filter_context::EmbeddingsStore;
use crate::metrics::Metrics;
use acap::cos;
use common::api::hallucination::{
extract_messages_for_hallucination, HallucinationClassificationRequest,
HallucinationClassificationResponse,
};
use common::api::open_ai::{
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionTool,
ChatCompletionsRequest, ChatCompletionsResponse, FunctionDefinition, FunctionParameter,
FunctionParameters, Message, ParameterType, ToolCall, ToolType,
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
ChatCompletionsResponse, Message, ModelServerResponse, ToolCall,
};
use common::api::prompt_guard::PromptGuardResponse;
use common::api::zero_shot::{ZeroShotClassificationRequest, ZeroShotClassificationResponse};
use common::configuration::{Overrides, PromptGuards, PromptTarget, Tracing};
use common::configuration::{Overrides, PromptTarget, Tracing};
use common::consts::{
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
ARCH_INTERNAL_CLUSTER_NAME, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER,
ASSISTANT_ROLE, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST,
HALLUCINATION_TEMPLATE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE,
TRACE_PARENT_HEADER, USER_ROLE, ZEROSHOT_INTERNAL_HOST,
};
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE,
TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
};
use common::errors::ServerError;
use common::http::{CallArgs, Client};
use common::stats::Gauge;
use derivative::Derivative;
use http::StatusCode;
use log::{debug, info, trace, warn};
use log::{debug, warn};
use proxy_wasm::traits::*;
use serde_yaml::Value;
use std::cell::RefCell;
@ -41,12 +25,8 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
#[derive(Debug, Clone)]
pub enum ResponseHandlerType {
Embeddings,
ArchFC,
FunctionCall,
ZeroShotIntent,
Hallucination,
ArchGuard,
DefaultTarget,
}
@ -66,8 +46,7 @@ pub struct StreamCallContext {
pub struct StreamContext {
system_prompt: Rc<Option<String>>,
pub prompt_targets: Rc<HashMap<String, PromptTarget>>,
pub embeddings_store: Option<Rc<EmbeddingsStore>>,
overrides: Rc<Option<Overrides>>,
_overrides: Rc<Option<Overrides>>,
pub metrics: Rc<Metrics>,
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
pub context_id: u32,
@ -79,12 +58,11 @@ pub struct StreamContext {
pub streaming_response: bool,
pub is_chat_completions_request: bool,
pub chat_completions_request: Option<ChatCompletionsRequest>,
pub prompt_guards: Rc<PromptGuards>,
pub request_id: Option<String>,
pub start_upstream_llm_request_time: u128,
pub time_to_first_token: Option<u128>,
pub traceparent: Option<String>,
pub tracing: Rc<Option<Tracing>>,
pub _tracing: Rc<Option<Tracing>>,
}
impl StreamContext {
@ -94,9 +72,7 @@ impl StreamContext {
metrics: Rc<Metrics>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
prompt_guards: Rc<PromptGuards>,
overrides: Rc<Option<Overrides>>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
tracing: Rc<Option<Tracing>>,
) -> Self {
StreamContext {
@ -104,7 +80,6 @@ impl StreamContext {
metrics,
system_prompt,
prompt_targets,
embeddings_store,
callouts: RefCell::new(HashMap::new()),
chat_completions_request: None,
tool_calls: None,
@ -114,32 +89,15 @@ impl StreamContext {
streaming_response: false,
user_prompt: None,
is_chat_completions_request: false,
prompt_guards,
overrides,
_overrides: overrides,
request_id: None,
traceparent: None,
tracing,
_tracing: tracing,
start_upstream_llm_request_time: 0,
time_to_first_token: None,
}
}
fn embeddings_store(&self) -> &EmbeddingsStore {
self.embeddings_store.as_ref().unwrap()
}
pub fn is_embedding_store_initialized(&self) -> bool {
if self.embeddings_store.as_ref().is_none() {
return false;
}
if self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len() {
return true;
}
false
}
pub fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
self.send_http_response(
override_status_code
@ -151,190 +109,8 @@ impl StreamContext {
);
}
pub fn get_embeddings(&mut self, callout_context: StreamCallContext) {
let user_message = callout_context.user_message.unwrap();
let get_embeddings_input = CreateEmbeddingRequest {
// Need to clone into input because user_message is used below.
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
let embeddings_request_str: String = match serde_json::to_string(&get_embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
warn!("error serializing get embeddings request: {}", error);
return self.send_server_error(ServerError::Deserialization(error), None);
}
};
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST),
(":method", "POST"),
(":path", "/embeddings"),
(":authority", EMBEDDINGS_INTERNAL_HOST),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/embeddings",
headers,
Some(embeddings_request_str.as_bytes()),
vec![],
Duration::from_secs(5),
);
let call_context = StreamCallContext {
response_handler_type: ResponseHandlerType::Embeddings,
user_message: Some(user_message),
prompt_target_name: None,
request_body: callout_context.request_body,
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
};
debug!(
"archgw => get embeddings request: {}",
embeddings_request_str
);
if let Err(e) = self.http_call(call_args, call_context) {
warn!("error dispatching get embeddings request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
pub fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) {
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
Ok(embedding_response) => embedding_response,
Err(e) => {
warn!("error deserializing embedding response: {}", e);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
let prompt_embeddings_vector = &embedding_response.data[0].embedding;
trace!(
"embedding model: {}, vector length: {:?}",
embedding_response.model,
prompt_embeddings_vector.len()
);
let prompt_target_names = self
.prompt_targets
.iter()
// exclude default target
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
.map(|(name, _)| name.clone())
.collect();
let similarity_scores: Vec<(String, f64)> = self
.prompt_targets
.iter()
// exclude default prompt target
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
.map(|(prompt_name, _)| {
let pte = match self.embeddings_store().get(prompt_name) {
Some(embeddings) => embeddings,
None => {
warn!(
"embeddings not found for prompt target name: {}",
prompt_name
);
return (prompt_name.clone(), 0.0);
}
};
let description_embeddings = match pte.get(&EmbeddingType::Description) {
Some(embeddings) => embeddings,
None => {
warn!(
"description embeddings not found for prompt target name: {}",
prompt_name
);
return (prompt_name.clone(), 0.0);
}
};
let similarity_score_description =
cos::cosine_similarity(&prompt_embeddings_vector, &description_embeddings);
(prompt_name.clone(), similarity_score_description)
})
.collect();
debug!(
"similarity scores based on description embeddings match: {:?}",
similarity_scores
);
callout_context.similarity_scores = Some(similarity_scores);
let zero_shot_classification_request = ZeroShotClassificationRequest {
// Need to clone into input because user_message is used below.
input: callout_context.user_message.as_ref().unwrap().clone(),
model: String::from(DEFAULT_INTENT_MODEL),
labels: prompt_target_names,
};
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
debug!(
"error serializing zero shot classification request: {}",
error
);
return self.send_server_error(ServerError::Serialization(error), None);
}
};
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, ZEROSHOT_INTERNAL_HOST),
(":method", "POST"),
(":path", "/zeroshot"),
(":authority", ZEROSHOT_INTERNAL_HOST),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/zeroshot",
headers,
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
if let Err(e) = self.http_call(call_args, callout_context) {
warn!("error dispatching zero shot classification request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
fn trace_arch_internal(&self) -> bool {
match self.tracing.as_ref() {
fn _trace_arch_internal(&self) -> bool {
match self._tracing.as_ref() {
Some(tracing) => match tracing.trace_arch_internal.as_ref() {
Some(trace_arch_internal) => *trace_arch_internal,
None => false,
@ -343,359 +119,6 @@ impl StreamContext {
}
}
pub fn hallucination_classification_resp_handler(
&mut self,
body: Vec<u8>,
callout_context: StreamCallContext,
) {
let body_str = String::from_utf8(body).expect("could not convert body to string");
debug!("archgw <= hallucination response: {}", body_str);
let hallucination_response: HallucinationClassificationResponse =
match serde_json::from_str(body_str.as_str()) {
Ok(hallucination_response) => hallucination_response,
Err(e) => {
warn!(
"error deserializing hallucination response: {}, body: {}",
e,
body_str.as_str()
);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
let mut keys_with_low_score: Vec<String> = Vec::new();
for (key, value) in &hallucination_response.params_scores {
if *value < DEFAULT_HALLUCINATED_THRESHOLD {
debug!(
"hallucination detected: score for {} : {} is less than threshold {}",
key, value, DEFAULT_HALLUCINATED_THRESHOLD
);
keys_with_low_score.push(key.clone().to_string());
}
}
if !keys_with_low_score.is_empty() {
let response =
HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?";
let response_str = if self.streaming_response {
let chunks = vec![
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
ChatCompletionStreamResponse::new(
Some(response),
None,
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
];
to_server_events(chunks)
} else {
let chat_completion_response = ChatCompletionsResponse::new(response);
serde_json::to_string(&chat_completion_response).unwrap()
};
debug!("hallucination response: {:?}", response_str);
// make sure on_http_response_body does not attach tool calls and tool response to the response
self.tool_calls = None;
self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![],
Some(response_str.as_bytes()),
);
} else {
// not a hallucination, resume the flow
self.schedule_api_call_request(callout_context);
}
}
pub fn zero_shot_intent_detection_resp_handler(
&mut self,
body: Vec<u8>,
mut callout_context: StreamCallContext,
) {
let zeroshot_intent_response: ZeroShotClassificationResponse =
match serde_json::from_slice(&body) {
Ok(zeroshot_response) => zeroshot_response,
Err(e) => {
warn!(
"error deserializing zero shot classification response: {}",
e
);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
trace!(
"zeroshot intent response: {}",
serde_json::to_string(&zeroshot_intent_response).unwrap()
);
let desc_emb_similarity_map: HashMap<String, f64> = callout_context
.similarity_scores
.clone()
.unwrap()
.into_iter()
.collect();
let pred_class_desc_emb_similarity = desc_emb_similarity_map
.get(&zeroshot_intent_response.predicted_class)
.unwrap();
let prompt_target_similarity_score = zeroshot_intent_response.predicted_class_score * 0.7
+ pred_class_desc_emb_similarity * 0.3;
debug!(
"similarity score: {:.3}, intent score: {:.3}, description embedding score: {:.3}, prompt: {}",
prompt_target_similarity_score,
zeroshot_intent_response.predicted_class_score,
pred_class_desc_emb_similarity,
callout_context.user_message.as_ref().unwrap()
);
let prompt_target_name = zeroshot_intent_response.predicted_class.clone();
// Check to see who responded to user message. This will help us identify if control should be passed to Arch FC or not.
// If the last message was from Arch FC, then Arch FC is handling the conversation (possibly for parameter collection).
let mut arch_assistant = false;
let messages = &callout_context.request_body.messages;
if messages.len() >= 2 {
let latest_assistant_message = &messages[messages.len() - 2];
if let Some(model) = latest_assistant_message.model.as_ref() {
if model.contains(ARCH_MODEL_PREFIX) {
arch_assistant = true;
}
}
} else {
debug!("no assistant message found, probably first interaction");
}
// get prompt target similarity thresold from overrides
let prompt_target_intent_matching_threshold = match self.overrides.as_ref() {
Some(overrides) => match overrides.prompt_target_intent_matching_threshold {
Some(threshold) => threshold,
None => DEFAULT_PROMPT_TARGET_THRESHOLD,
},
None => DEFAULT_PROMPT_TARGET_THRESHOLD,
};
// check to ensure that the prompt target similarity score is above the threshold
if prompt_target_similarity_score < prompt_target_intent_matching_threshold
|| arch_assistant
{
debug!("intent score is low or arch assistant is handling the conversation");
// if arch fc responded to the user message, then we don't need to check the similarity score
// it may be that arch fc is handling the conversation for parameter collection
if arch_assistant {
info!("arch fc is engaged in parameter collection");
} else if let Some(default_prompt_target) = self
.prompt_targets
.values()
.find(|pt| pt.default.unwrap_or(false))
{
debug!("default prompt target found, forwarding request to default prompt target");
let endpoint = default_prompt_target.endpoint.clone().unwrap();
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
let upstream_endpoint = endpoint.name;
let mut params = HashMap::new();
params.insert(
MESSAGES_KEY.to_string(),
callout_context.request_body.messages.clone(),
);
let arch_messages_json = serde_json::to_string(&params).unwrap();
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
let mut headers = vec![
(":method", "POST"),
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
(":path", &upstream_path),
(":authority", &upstream_endpoint),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&upstream_path,
headers,
Some(arch_messages_json.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
if let Err(e) = self.http_call(call_args, callout_context) {
warn!("error dispatching default prompt target request: {}", e);
return self.send_server_error(
ServerError::HttpDispatch(e),
Some(StatusCode::BAD_REQUEST),
);
}
return;
} else {
// if no default prompt target is found and similarity score is low send response to upstream llm
// removing tool calls and tool response
let messages = self.filter_out_arch_messages(&callout_context);
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: callout_context.request_body.model,
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
Ok(json_string) => json_string,
Err(e) => {
return self.send_server_error(ServerError::Serialization(e), None);
}
};
debug!(
"archgw (low similarity score) => llm request: {}",
llm_request_str
);
self.set_http_request_body(
0,
self.request_body_size,
&llm_request_str.into_bytes(),
);
self.resume_http_request();
return;
}
}
let prompt_target = self
.prompt_targets
.get(&prompt_target_name)
.expect("prompt target not found")
.clone();
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
for pt in self.prompt_targets.values() {
if pt.default.unwrap_or_default() {
continue;
}
// only extract entity names
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
Some(ref entities) => {
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
for entity in entities.iter() {
let param = FunctionParameter {
parameter_type: ParameterType::from(
entity.parameter_type.clone().unwrap_or("str".to_string()),
),
description: entity.description.clone(),
required: entity.required,
enum_values: entity.enum_values.clone(),
default: entity.default.clone(),
};
properties.insert(entity.name.clone(), param);
}
properties
}
None => HashMap::new(),
};
let tools_parameters = FunctionParameters { properties };
chat_completion_tools.push({
ChatCompletionTool {
tool_type: ToolType::Function,
function: FunctionDefinition {
name: pt.name.clone(),
description: pt.description.clone(),
parameters: tools_parameters,
},
}
});
}
// archfc handler needs state so it can expand tool calls
let mut metadata = HashMap::new();
metadata.insert(
ARCH_STATE_HEADER.to_string(),
serde_json::to_string(&self.arch_state).unwrap(),
);
let chat_completions = ChatCompletionsRequest {
model: self
.chat_completions_request
.as_ref()
.unwrap()
.model
.clone(),
messages: callout_context.request_body.messages.clone(),
tools: Some(chat_completion_tools),
stream: false,
stream_options: None,
metadata: Some(metadata),
};
let msg_body = match serde_json::to_string(&chat_completions) {
Ok(msg_body) => msg_body,
Err(e) => {
warn!("error serializing arch_fc request body: {}", e);
return self.send_server_error(ServerError::Serialization(e), None);
}
};
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
let mut headers = vec![
(":method", "POST"),
(ARCH_UPSTREAM_HOST_HEADER, ARCH_FC_INTERNAL_HOST),
(":path", "/v1/chat/completions"),
(":authority", ARCH_FC_INTERNAL_HOST),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/v1/chat/completions",
headers,
Some(msg_body.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::ArchFC;
callout_context.prompt_target_name = Some(prompt_target.name);
debug!("archgw => archfc request: {}", msg_body);
if let Err(e) = self.http_call(call_args, callout_context) {
debug!("error dispatching arch_fc request: {}", e);
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
}
}
pub fn arch_fc_response_handler(
&mut self,
body: Vec<u8>,
@ -704,14 +127,87 @@ impl StreamContext {
let body_str = String::from_utf8(body).unwrap();
debug!("archgw <= archfc response: {}", body_str);
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
let model_server_response: ModelServerResponse = match serde_json::from_str(&body_str) {
Ok(arch_fc_response) => arch_fc_response,
Err(e) => {
warn!("error deserializing archfc response: {}", e);
warn!(
"error deserializing archfc response: {}, body: {}",
e, body_str
);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
let arch_fc_response = match model_server_response {
ModelServerResponse::ChatCompletionsResponse(response) => response,
ModelServerResponse::ModelServerErrorResponse(response) => {
debug!("archgw <= archfc error response: {}", response.result);
if response.result == "No intent matched" {
if let Some(default_prompt_target) = self
.prompt_targets
.values()
.find(|pt| pt.default.unwrap_or(false))
{
debug!("default prompt target found, forwarding request to default prompt target");
let endpoint = default_prompt_target.endpoint.clone().unwrap();
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
let upstream_endpoint = endpoint.name;
let mut params = HashMap::new();
params.insert(
MESSAGES_KEY.to_string(),
callout_context.request_body.messages.clone(),
);
let arch_messages_json = serde_json::to_string(&params).unwrap();
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
let mut headers = vec![
(":method", "POST"),
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
(":path", &upstream_path),
(":authority", &upstream_endpoint),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
// if self.trace_arch_internal() && self.traceparent.is_some() {
// headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
// }
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&upstream_path,
headers,
Some(arch_messages_json.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
callout_context.prompt_target_name =
Some(default_prompt_target.name.clone());
if let Err(e) = self.http_call(call_args, callout_context) {
warn!("error dispatching default prompt target request: {}", e);
return self.send_server_error(
ServerError::HttpDispatch(e),
Some(StatusCode::BAD_REQUEST),
);
}
return;
}
}
return self.send_server_error(
ServerError::LogicError(response.result),
Some(StatusCode::BAD_REQUEST),
);
}
};
arch_fc_response.choices[0]
.message
.tool_calls
@ -767,114 +263,7 @@ impl StreamContext {
);
}
// TODO CO: pass nli check
let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
let prompt_target = self
.prompt_targets
.get(&tools_call_name)
.expect("prompt target not found for tool call")
.clone();
debug!(
"prompt_target_name: {}, tool_name(s): {:?}",
prompt_target.name,
self.tool_calls
.as_ref()
.unwrap()
.iter()
.map(|tc| tc.function.name.clone())
.collect::<Vec<String>>(),
);
// If hallucination, pass chat template to check parameters
//HACK: for now we only support one tool call, we will support multiple tool calls in the future
let mut tool_params = self.tool_calls.as_ref().unwrap()[0]
.function
.arguments
.clone();
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
debug!(
"tool_params (without messages history): {}",
tool_params_json_str
);
tool_params.insert(
String::from(MESSAGES_KEY),
serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
);
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
use serde_json::Value;
let v: Value = serde_json::from_str(&tool_params_json_str).unwrap();
let tool_params_dict: HashMap<String, String> = match v.as_object() {
Some(obj) => obj
.iter()
.map(|(key, value)| {
// Convert each value to a string, regardless of its type
(key.clone(), value.to_string())
})
.collect(),
None => HashMap::new(), // Return an empty HashMap if v is not an object
};
let all_user_messages =
extract_messages_for_hallucination(&callout_context.request_body.messages);
let user_messages_str = all_user_messages.join(", ");
debug!("user messages: {}", user_messages_str);
let hallucination_classification_request = HallucinationClassificationRequest {
prompt: user_messages_str,
model: String::from(DEFAULT_INTENT_MODEL),
parameters: tool_params_dict,
};
let hallucination_request_str: String =
match serde_json::to_string(&hallucination_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
debug!(
"error serializing hallucination classification request: {}",
error
);
return self.send_server_error(ServerError::Serialization(error), None);
}
};
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, HALLUCINATION_INTERNAL_HOST),
(":method", "POST"),
(":path", "/hallucination"),
(":authority", HALLUCINATION_INTERNAL_HOST),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/hallucination",
headers,
Some(hallucination_request_str.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::Hallucination;
debug!(
"archgw => hallucination request: {}",
hallucination_request_str
);
if let Err(e) = self.http_call(call_args, callout_context) {
self.send_server_error(ServerError::HttpDispatch(e), None);
}
self.schedule_api_call_request(callout_context);
}
fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
@ -969,8 +358,9 @@ impl StreamContext {
pub fn api_call_response_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
let http_status = self
.get_http_call_response_header(":status")
.expect("http status code not found");
if http_status != StatusCode::OK.as_str() {
.unwrap_or(StatusCode::OK.as_str().to_string());
debug!("api_call_response_handler: http_status: {}", http_status);
if http_status != StatusCode::OK.as_str() {
warn!(
"api server responded with non 2xx status code: {}",
http_status
@ -1093,56 +483,24 @@ impl StreamContext {
messages
}
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
debug!(
"archgw <= archguard response: {:?}",
serde_json::to_string(&prompt_guard_resp)
);
if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
//TODO: handle other scenarios like forward to error target
let msg = self
.prompt_guards
.jailbreak_on_exception_message()
.unwrap_or("refrain from discussing jailbreaking.");
info!("jailbreak detected: {}", msg);
let response_str = if self.streaming_response {
let chunks = vec![
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
ChatCompletionStreamResponse::new(
Some(msg.to_string()),
None,
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
];
to_server_events(chunks)
} else {
let chat_completion_response = ChatCompletionsResponse::new(msg.to_string());
serde_json::to_string(&chat_completion_response).unwrap()
};
self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![],
Some(response_str.as_bytes()),
);
return self.send_server_error(
ServerError::Jailbreak(String::from(msg)),
Some(StatusCode::BAD_REQUEST),
);
pub fn generate_toll_call_message(&mut self) -> Message {
Message {
role: ASSISTANT_ROLE.to_string(),
content: None,
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: self.tool_calls.clone(),
tool_call_id: None,
}
}
self.get_embeddings(callout_context);
pub fn generate_api_response_message(&mut self) -> Message {
Message {
role: TOOL_ROLE.to_string(),
content: self.tool_call_response.clone(),
model: None,
tool_calls: None,
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
}
}
pub fn default_target_handler(&self, body: Vec<u8>, mut callout_context: StreamCallContext) {
@ -1264,26 +622,6 @@ impl StreamContext {
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
self.resume_http_request();
}
pub fn generate_toll_call_message(&mut self) -> Message {
Message {
role: ASSISTANT_ROLE.to_string(),
content: None,
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: self.tool_calls.clone(),
tool_call_id: None,
}
}
pub fn generate_api_response_message(&mut self) -> Message {
Message {
role: TOOL_ROLE.to_string(),
content: self.tool_call_response.clone(),
model: None,
tool_calls: None,
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
}
}
}
impl Client for StreamContext {