update config (#93)

This commit is contained in:
Adil Hafeez 2024-09-30 17:49:05 -07:00 committed by GitHub
parent 4182879717
commit cc35eb0cd7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 575 additions and 329 deletions

View file

@ -23,7 +23,7 @@ use public_types::common_types::{
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
};
use public_types::configuration::{Overrides, PromptGuards, PromptTarget, PromptType};
use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
use public_types::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
@ -358,103 +358,97 @@ impl StreamContext {
info!("prompt_target name: {:?}", prompt_target_name);
match prompt_target.prompt_type {
PromptType::FunctionResolver => {
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
for pt in self.prompt_targets.read().unwrap().values() {
// only extract entity names
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
Some(ref entities) => {
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
for entity in entities.iter() {
let param = FunctionParameter {
parameter_type: ParameterType::from(
entity.parameter_type.clone().unwrap_or("str".to_string()),
),
description: entity.description.clone(),
required: entity.required,
enum_values: entity.enum_values.clone(),
default: entity.default.clone(),
};
properties.insert(entity.name.clone(), param);
}
properties
}
None => HashMap::new(),
};
let tools_parameters = FunctionParameters { properties };
chat_completion_tools.push({
ChatCompletionTool {
tool_type: ToolType::Function,
function: FunctionDefinition {
name: pt.name.clone(),
description: pt.description.clone(),
parameters: tools_parameters,
},
}
});
//TODO: handle default function resolver type
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
for pt in self.prompt_targets.read().unwrap().values() {
// only extract entity names
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
Some(ref entities) => {
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
for entity in entities.iter() {
let param = FunctionParameter {
parameter_type: ParameterType::from(
entity.parameter_type.clone().unwrap_or("str".to_string()),
),
description: entity.description.clone(),
required: entity.required,
enum_values: entity.enum_values.clone(),
default: entity.default.clone(),
};
properties.insert(entity.name.clone(), param);
}
properties
}
None => HashMap::new(),
};
let tools_parameters = FunctionParameters { properties };
let chat_completions = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
messages: callout_context.request_body.messages.clone(),
tools: Some(chat_completion_tools),
stream: false,
stream_options: None,
};
let msg_body = match serde_json::to_string(&chat_completions) {
Ok(msg_body) => {
debug!("arch_fc request body content: {}", msg_body);
msg_body
}
Err(e) => {
return self.send_server_error(
format!("Error serializing request_params: {:?}", e),
None,
);
}
};
let token_id = match self.dispatch_http_call(
ARC_FC_CLUSTER,
vec![
(":method", "POST"),
(":path", "/v1/chat/completions"),
(":authority", ARC_FC_CLUSTER),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
(
"x-envoy-upstream-rq-timeout-ms",
ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
),
],
Some(msg_body.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
let error_msg =
format!("Error dispatching HTTP call for function-call: {:?}", e);
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
}
};
debug!(
"dispatched call to function {} token_id={}",
ARC_FC_CLUSTER, token_id
);
self.metrics.active_http_calls.increment(1);
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
callout_context.prompt_target_name = Some(prompt_target.name);
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
chat_completion_tools.push({
ChatCompletionTool {
tool_type: ToolType::Function,
function: FunctionDefinition {
name: pt.name.clone(),
description: pt.description.clone(),
parameters: tools_parameters,
},
}
});
}
let chat_completions = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
messages: callout_context.request_body.messages.clone(),
tools: Some(chat_completion_tools),
stream: false,
stream_options: None,
};
let msg_body = match serde_json::to_string(&chat_completions) {
Ok(msg_body) => {
debug!("arch_fc request body content: {}", msg_body);
msg_body
}
Err(e) => {
return self
.send_server_error(format!("Error serializing request_params: {:?}", e), None);
}
};
let token_id = match self.dispatch_http_call(
ARC_FC_CLUSTER,
vec![
(":method", "POST"),
(":path", "/v1/chat/completions"),
(":authority", ARC_FC_CLUSTER),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
(
"x-envoy-upstream-rq-timeout-ms",
ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
),
],
Some(msg_body.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
let error_msg = format!("Error dispatching HTTP call for function-call: {:?}", e);
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
}
};
debug!(
"dispatched call to function {} token_id={}",
ARC_FC_CLUSTER, token_id
);
self.metrics.active_http_calls.increment(1);
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
callout_context.prompt_target_name = Some(prompt_target.name);
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
}
}
@ -530,17 +524,32 @@ impl StreamContext {
debug!("tool_params: {}", tool_params_json_str);
let endpoint = prompt_target.endpoint.unwrap();
let path = endpoint.path.unwrap_or(String::from("/"));
// Resolve the request path and HTTP method for the prompt target's endpoint:
// the path falls back to "/" and the method falls back to `Method::Post` when unset.
let mut path = endpoint.path.unwrap_or(String::from("/"));
let method = endpoint
.method
.unwrap_or(public_types::configuration::Method::Post);
// `body` starts out as the bytes of `tool_params_json_str`.
let mut body = Some(tool_params_json_str.as_bytes());
// NOTE(review): as written, a POST request carries the tool parameters twice
// (the query string built below AND the body set above), while every other
// method carries neither (the body is cleared and the path is left untouched).
// This looks inverted relative to the usual "GET uses the query string, POST
// uses the body" split -- confirm the intended condition.
if method == public_types::configuration::Method::Post {
// Build `key=value` pairs from `tool_params` and append them to the path as a
// `?`-prefixed, `&`-joined query string.
let mut query_params = vec![];
for (key, value) in tool_params {
// NOTE(review): the value is rendered with its `Debug` formatting and is not
// percent-encoded; presumably string values end up wrapped in quotes -- verify
// against the type of `tool_params` and what the upstream endpoint expects.
query_params.push(format!("{}={}", key, format!("{:?}", value)));
}
let path_args = &query_params.join("&");
path.push_str("?");
path.push_str(path_args);
} else {
// Every non-POST method is given no body.
body = None;
}
let token_id = match self.dispatch_http_call(
&endpoint.cluster,
&endpoint.name,
vec![
(":method", "POST"),
(":method", method.to_string().as_str()),
(":path", path.as_ref()),
(":authority", endpoint.cluster.as_str()),
(":authority", endpoint.name.as_str()),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
],
Some(tool_params_json_str.as_bytes()),
body,
vec![],
Duration::from_secs(5),
) {
@ -548,14 +557,14 @@ impl StreamContext {
Err(e) => {
let error_msg = format!(
"Error dispatching call to cluster: {}, path: {}, err: {:?}",
&endpoint.cluster, path, e
&endpoint.name, path, e
);
debug!("{}", error_msg);
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
}
};
callout_context.up_stream_cluster = Some(endpoint.cluster);
callout_context.up_stream_cluster = Some(endpoint.name);
callout_context.up_stream_cluster_path = Some(path);
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
if self.callouts.insert(token_id, callout_context).is_some() {
@ -682,27 +691,18 @@ impl StreamContext {
if prompt_guard_resp.jailbreak_verdict.is_some()
&& prompt_guard_resp.jailbreak_verdict.unwrap()
{
//TODO: handle other scenarios like forward to error target
let default_err = "Jailbreak detected. Please refrain from discussing jailbreaking.";
let error_msg = match self.prompt_guards.as_ref() {
Some(prompt_guards) => match prompt_guards.input_guards.jailbreak.as_ref() {
Some(jailbreak) => match jailbreak.on_exception_message.as_ref() {
Some(error_msg) => error_msg,
None => default_err,
},
None => default_err,
},
None => default_err,
};
return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
}
if prompt_guard_resp.toxic_verdict.is_some() && prompt_guard_resp.toxic_verdict.unwrap() {
let default_err = "Toxicity detected. Please refrain from using toxic language.";
let error_msg = match self.prompt_guards.as_ref() {
Some(prompt_guards) => match prompt_guards.input_guards.toxicity.as_ref() {
Some(toxicity) => match toxicity.on_exception_message.as_ref() {
Some(error_msg) => error_msg,
Some(prompt_guards) => match prompt_guards
.input_guards
.get(&public_types::configuration::GuardType::Jailbreak)
{
Some(jailbreak) => match jailbreak.on_exception.as_ref() {
Some(on_exception_details) => match on_exception_details.message.as_ref() {
Some(error_msg) => error_msg,
None => default_err,
},
None => default_err,
},
None => default_err,
@ -883,32 +883,27 @@ impl HttpContext for StreamContext {
}
};
let prompt_guard_task = match (
prompt_guards.input_guards.toxicity.is_some(),
prompt_guards.input_guards.jailbreak.is_some(),
) {
(true, true) => PromptGuardTask::Both,
(true, false) => PromptGuardTask::Toxicity,
(false, true) => PromptGuardTask::Jailbreak,
(false, false) => {
info!("Input guards set but no prompt guards were found");
let callout_context = CallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: Some(user_message),
prompt_target_name: None,
request_body: deserialized_body,
similarity_scores: None,
up_stream_cluster: None,
up_stream_cluster_path: None,
};
self.get_embeddings(callout_context);
return Action::Pause;
}
};
// True when the configured input guards contain an entry keyed by `GuardType::Jailbreak`.
let prompt_guard_jailbreak_task = prompt_guards
.input_guards
.contains_key(&public_types::configuration::GuardType::Jailbreak);
// No jailbreak guard configured: log it, build a call context for the current
// user message and request body, pass it to `get_embeddings`, and return
// `Action::Pause` from here.
if !prompt_guard_jailbreak_task {
info!("Input guards set but no prompt guards were found");
// NOTE(review): the handler type is still `ArchGuard` on this path even though
// the context goes to `get_embeddings` -- confirm that `get_embeddings` sets the
// handler type it needs before registering the callout.
let callout_context = CallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: Some(user_message),
prompt_target_name: None,
request_body: deserialized_body,
similarity_scores: None,
up_stream_cluster: None,
up_stream_cluster_path: None,
};
self.get_embeddings(callout_context);
return Action::Pause;
}
let get_prompt_guards_request = PromptGuardRequest {
input: user_message.clone(),
task: prompt_guard_task,
task: PromptGuardTask::Jailbreak,
};
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {