mirror of
https://github.com/katanemo/plano.git
synced 2026-05-03 21:02:56 +02:00
update config (#93)
This commit is contained in:
parent
4182879717
commit
cc35eb0cd7
13 changed files with 575 additions and 329 deletions
|
|
@ -23,7 +23,7 @@ use public_types::common_types::{
|
|||
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
|
||||
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
|
||||
};
|
||||
use public_types::configuration::{Overrides, PromptGuards, PromptTarget, PromptType};
|
||||
use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
|
||||
use public_types::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
|
|
@ -358,103 +358,97 @@ impl StreamContext {
|
|||
|
||||
info!("prompt_target name: {:?}", prompt_target_name);
|
||||
|
||||
match prompt_target.prompt_type {
|
||||
PromptType::FunctionResolver => {
|
||||
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
|
||||
for pt in self.prompt_targets.read().unwrap().values() {
|
||||
// only extract entity names
|
||||
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
|
||||
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
|
||||
Some(ref entities) => {
|
||||
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
|
||||
for entity in entities.iter() {
|
||||
let param = FunctionParameter {
|
||||
parameter_type: ParameterType::from(
|
||||
entity.parameter_type.clone().unwrap_or("str".to_string()),
|
||||
),
|
||||
description: entity.description.clone(),
|
||||
required: entity.required,
|
||||
enum_values: entity.enum_values.clone(),
|
||||
default: entity.default.clone(),
|
||||
};
|
||||
properties.insert(entity.name.clone(), param);
|
||||
}
|
||||
properties
|
||||
}
|
||||
None => HashMap::new(),
|
||||
};
|
||||
let tools_parameters = FunctionParameters { properties };
|
||||
|
||||
chat_completion_tools.push({
|
||||
ChatCompletionTool {
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionDefinition {
|
||||
name: pt.name.clone(),
|
||||
description: pt.description.clone(),
|
||||
parameters: tools_parameters,
|
||||
},
|
||||
}
|
||||
});
|
||||
//TODO: handle default function resolver type
|
||||
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
|
||||
for pt in self.prompt_targets.read().unwrap().values() {
|
||||
// only extract entity names
|
||||
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
|
||||
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
|
||||
Some(ref entities) => {
|
||||
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
|
||||
for entity in entities.iter() {
|
||||
let param = FunctionParameter {
|
||||
parameter_type: ParameterType::from(
|
||||
entity.parameter_type.clone().unwrap_or("str".to_string()),
|
||||
),
|
||||
description: entity.description.clone(),
|
||||
required: entity.required,
|
||||
enum_values: entity.enum_values.clone(),
|
||||
default: entity.default.clone(),
|
||||
};
|
||||
properties.insert(entity.name.clone(), param);
|
||||
}
|
||||
properties
|
||||
}
|
||||
None => HashMap::new(),
|
||||
};
|
||||
let tools_parameters = FunctionParameters { properties };
|
||||
|
||||
let chat_completions = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages: callout_context.request_body.messages.clone(),
|
||||
tools: Some(chat_completion_tools),
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
};
|
||||
|
||||
let msg_body = match serde_json::to_string(&chat_completions) {
|
||||
Ok(msg_body) => {
|
||||
debug!("arch_fc request body content: {}", msg_body);
|
||||
msg_body
|
||||
}
|
||||
Err(e) => {
|
||||
return self.send_server_error(
|
||||
format!("Error serializing request_params: {:?}", e),
|
||||
None,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
ARC_FC_CLUSTER,
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/v1/chat/completions"),
|
||||
(":authority", ARC_FC_CLUSTER),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
(
|
||||
"x-envoy-upstream-rq-timeout-ms",
|
||||
ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
|
||||
),
|
||||
],
|
||||
Some(msg_body.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
let error_msg =
|
||||
format!("Error dispatching HTTP call for function-call: {:?}", e);
|
||||
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"dispatched call to function {} token_id={}",
|
||||
ARC_FC_CLUSTER, token_id
|
||||
);
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
|
||||
callout_context.prompt_target_name = Some(prompt_target.name);
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
chat_completion_tools.push({
|
||||
ChatCompletionTool {
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionDefinition {
|
||||
name: pt.name.clone(),
|
||||
description: pt.description.clone(),
|
||||
parameters: tools_parameters,
|
||||
},
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let chat_completions = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages: callout_context.request_body.messages.clone(),
|
||||
tools: Some(chat_completion_tools),
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
};
|
||||
|
||||
let msg_body = match serde_json::to_string(&chat_completions) {
|
||||
Ok(msg_body) => {
|
||||
debug!("arch_fc request body content: {}", msg_body);
|
||||
msg_body
|
||||
}
|
||||
Err(e) => {
|
||||
return self
|
||||
.send_server_error(format!("Error serializing request_params: {:?}", e), None);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
ARC_FC_CLUSTER,
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/v1/chat/completions"),
|
||||
(":authority", ARC_FC_CLUSTER),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
(
|
||||
"x-envoy-upstream-rq-timeout-ms",
|
||||
ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
|
||||
),
|
||||
],
|
||||
Some(msg_body.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
let error_msg = format!("Error dispatching HTTP call for function-call: {:?}", e);
|
||||
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"dispatched call to function {} token_id={}",
|
||||
ARC_FC_CLUSTER, token_id
|
||||
);
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
|
||||
callout_context.prompt_target_name = Some(prompt_target.name);
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -530,17 +524,32 @@ impl StreamContext {
|
|||
debug!("tool_params: {}", tool_params_json_str);
|
||||
|
||||
let endpoint = prompt_target.endpoint.unwrap();
|
||||
let path = endpoint.path.unwrap_or(String::from("/"));
|
||||
let mut path = endpoint.path.unwrap_or(String::from("/"));
|
||||
let method = endpoint
|
||||
.method
|
||||
.unwrap_or(public_types::configuration::Method::Post);
|
||||
let mut body = Some(tool_params_json_str.as_bytes());
|
||||
if method == public_types::configuration::Method::Post {
|
||||
let mut query_params = vec![];
|
||||
for (key, value) in tool_params {
|
||||
query_params.push(format!("{}={}", key, format!("{:?}", value)));
|
||||
}
|
||||
let path_args = &query_params.join("&");
|
||||
path.push_str("?");
|
||||
path.push_str(path_args);
|
||||
} else {
|
||||
body = None;
|
||||
}
|
||||
let token_id = match self.dispatch_http_call(
|
||||
&endpoint.cluster,
|
||||
&endpoint.name,
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":method", method.to_string().as_str()),
|
||||
(":path", path.as_ref()),
|
||||
(":authority", endpoint.cluster.as_str()),
|
||||
(":authority", endpoint.name.as_str()),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(tool_params_json_str.as_bytes()),
|
||||
body,
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
|
|
@ -548,14 +557,14 @@ impl StreamContext {
|
|||
Err(e) => {
|
||||
let error_msg = format!(
|
||||
"Error dispatching call to cluster: {}, path: {}, err: {:?}",
|
||||
&endpoint.cluster, path, e
|
||||
&endpoint.name, path, e
|
||||
);
|
||||
debug!("{}", error_msg);
|
||||
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
callout_context.up_stream_cluster = Some(endpoint.cluster);
|
||||
callout_context.up_stream_cluster = Some(endpoint.name);
|
||||
callout_context.up_stream_cluster_path = Some(path);
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
|
|
@ -682,27 +691,18 @@ impl StreamContext {
|
|||
if prompt_guard_resp.jailbreak_verdict.is_some()
|
||||
&& prompt_guard_resp.jailbreak_verdict.unwrap()
|
||||
{
|
||||
//TODO: handle other scenarios like forward to error target
|
||||
let default_err = "Jailbreak detected. Please refrain from discussing jailbreaking.";
|
||||
let error_msg = match self.prompt_guards.as_ref() {
|
||||
Some(prompt_guards) => match prompt_guards.input_guards.jailbreak.as_ref() {
|
||||
Some(jailbreak) => match jailbreak.on_exception_message.as_ref() {
|
||||
Some(error_msg) => error_msg,
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
};
|
||||
|
||||
return self.send_server_error(error_msg.to_string(), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
|
||||
if prompt_guard_resp.toxic_verdict.is_some() && prompt_guard_resp.toxic_verdict.unwrap() {
|
||||
let default_err = "Toxicity detected. Please refrain from using toxic language.";
|
||||
let error_msg = match self.prompt_guards.as_ref() {
|
||||
Some(prompt_guards) => match prompt_guards.input_guards.toxicity.as_ref() {
|
||||
Some(toxicity) => match toxicity.on_exception_message.as_ref() {
|
||||
Some(error_msg) => error_msg,
|
||||
Some(prompt_guards) => match prompt_guards
|
||||
.input_guards
|
||||
.get(&public_types::configuration::GuardType::Jailbreak)
|
||||
{
|
||||
Some(jailbreak) => match jailbreak.on_exception.as_ref() {
|
||||
Some(on_exception_details) => match on_exception_details.message.as_ref() {
|
||||
Some(error_msg) => error_msg,
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
},
|
||||
None => default_err,
|
||||
|
|
@ -883,32 +883,27 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let prompt_guard_task = match (
|
||||
prompt_guards.input_guards.toxicity.is_some(),
|
||||
prompt_guards.input_guards.jailbreak.is_some(),
|
||||
) {
|
||||
(true, true) => PromptGuardTask::Both,
|
||||
(true, false) => PromptGuardTask::Toxicity,
|
||||
(false, true) => PromptGuardTask::Jailbreak,
|
||||
(false, false) => {
|
||||
info!("Input guards set but no prompt guards were found");
|
||||
let callout_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
up_stream_cluster: None,
|
||||
up_stream_cluster_path: None,
|
||||
};
|
||||
self.get_embeddings(callout_context);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
let prompt_guard_jailbreak_task = prompt_guards
|
||||
.input_guards
|
||||
.contains_key(&public_types::configuration::GuardType::Jailbreak);
|
||||
if !prompt_guard_jailbreak_task {
|
||||
info!("Input guards set but no prompt guards were found");
|
||||
let callout_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
up_stream_cluster: None,
|
||||
up_stream_cluster_path: None,
|
||||
};
|
||||
self.get_embeddings(callout_context);
|
||||
return Action::Pause;
|
||||
}
|
||||
|
||||
let get_prompt_guards_request = PromptGuardRequest {
|
||||
input: user_message.clone(),
|
||||
task: prompt_guard_task,
|
||||
task: PromptGuardTask::Jailbreak,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue