Use intent model from archfc to pick prompt gateway (#328)

This commit is contained in:
Shuguang Chen 2024-12-20 13:25:01 -08:00 committed by GitHub
parent 67b8fd635e
commit ba7279becb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
151 changed files with 8642 additions and 10932 deletions

View file

@ -1,13 +1,12 @@
use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext};
use common::{
api::{
open_ai::{self, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest},
prompt_guard::{PromptGuardRequest, PromptGuardTask},
api::open_ai::{
self, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
},
consts::{
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, GUARD_INTERNAL_HOST,
HEALTHZ_PATH, REQUEST_ID_HEADER, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
MODEL_SERVER_NAME, REQUEST_ID_HEADER, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
},
errors::ServerError,
http::{CallArgs, Client},
@ -35,11 +34,7 @@ impl HttpContext for StreamContext {
let request_path = self.get_http_request_header(":path").unwrap_or_default();
if request_path == HEALTHZ_PATH {
if self.is_embedding_store_initialized() {
self.send_http_response(200, vec![], None);
} else {
self.send_http_response(503, vec![], None);
}
self.send_http_response(200, vec![], None);
return Action::Continue;
}
@ -138,43 +133,25 @@ impl HttpContext for StreamContext {
self.user_prompt = Some(last_user_prompt.clone());
let user_message_str = self.user_prompt.as_ref().unwrap().content.clone();
// convert prompt targets to ChatCompletionTool
let tool_calls: Vec<ChatCompletionTool> = self
.prompt_targets
.iter()
.map(|(_, pt)| pt.into())
.collect();
let prompt_guard_jailbreak_task = self
.prompt_guards
.input_guards
.contains_key(&common::configuration::GuardType::Jailbreak);
let arch_fc_chat_completion_request = ChatCompletionsRequest {
messages: deserialized_body.messages.clone(),
metadata: deserialized_body.metadata.clone(),
stream: deserialized_body.stream,
model: "--".to_string(),
stream_options: deserialized_body.stream_options.clone(),
tools: Some(tool_calls),
};
self.chat_completions_request = Some(deserialized_body);
if !prompt_guard_jailbreak_task {
debug!("Missing input guard. Making inline call to retrieve embeddings");
let callout_context = StreamCallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: user_message_str.clone(),
prompt_target_name: None,
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
};
self.get_embeddings(callout_context);
return Action::Pause;
}
let get_prompt_guards_request = PromptGuardRequest {
input: self
.user_prompt
.as_ref()
.unwrap()
.content
.as_ref()
.unwrap()
.clone(),
task: PromptGuardTask::Jailbreak,
};
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
let json_data = match serde_json::to_string(&arch_fc_chat_completion_request) {
Ok(json_data) => json_data,
Err(error) => {
self.send_server_error(ServerError::Serialization(error), None);
@ -182,14 +159,14 @@ impl HttpContext for StreamContext {
}
};
debug!("archgw => archfc: {}", json_data);
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, GUARD_INTERNAL_HOST),
(ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME),
(":method", "POST"),
(":path", "/guard"),
(":authority", GUARD_INTERNAL_HOST),
(":path", "/function_calling"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
(":authority", MODEL_SERVER_NAME),
];
if self.request_id.is_some() {
@ -202,23 +179,25 @@ impl HttpContext for StreamContext {
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/guard",
"/function_calling",
headers,
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
);
let call_context = StreamCallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
response_handler_type: ResponseHandlerType::ArchFC,
user_message: self.user_prompt.as_ref().unwrap().content.clone(),
prompt_target_name: None,
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
upstream_cluster: Some(ARCH_INTERNAL_CLUSTER_NAME.to_string()),
upstream_cluster_path: Some("/function_calling".to_string()),
};
if let Err(e) = self.http_call(call_args, call_context) {
debug!("http_call failed: {:?}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
@ -337,9 +316,11 @@ impl HttpContext for StreamContext {
let mut data = match serde_json::from_str(&body_utf8) {
Ok(data) => data,
Err(e) => {
warn!("could not deserialize response: {}", e);
self.send_server_error(ServerError::Deserialization(e), None);
return Action::Pause;
warn!(
"could not deserialize response, sending data as it is: {}",
e
);
return Action::Continue;
}
};
// use serde::Value to manipulate the json object and ensure that we don't lose any data