mirror of
https://github.com/katanemo/plano.git
synced 2026-05-03 21:02:56 +02:00
add support for default target (#111)
* add support for default target * add more fixes
This commit is contained in:
parent
c8d0dbec26
commit
1b57a49c9d
8 changed files with 215 additions and 88 deletions
|
|
@ -39,6 +39,7 @@ enum ResponseHandlerType {
|
|||
FunctionCall,
|
||||
ZeroShotIntent,
|
||||
ArchGuard,
|
||||
DefaultTarget,
|
||||
}
|
||||
|
||||
pub struct CallContext {
|
||||
|
|
@ -179,12 +180,16 @@ impl StreamContext {
|
|||
|
||||
let prompt_target_names = prompt_targets
|
||||
.iter()
|
||||
// exclude default target
|
||||
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
|
||||
.map(|(name, _)| name.clone())
|
||||
.collect();
|
||||
|
||||
let similarity_scores: Vec<(String, f64)> = prompt_targets
|
||||
.iter()
|
||||
.map(|(prompt_name, _prompt_target)| {
|
||||
// exclude default prompt target
|
||||
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
|
||||
.map(|(prompt_name, _)| {
|
||||
let default_embeddings = HashMap::new();
|
||||
let pte = prompt_target_embeddings
|
||||
.get(prompt_name)
|
||||
|
|
@ -331,34 +336,84 @@ impl StreamContext {
|
|||
|
||||
// check to ensure that the prompt target similarity score is above the threshold
|
||||
if prompt_target_similarity_score < prompt_target_intent_matching_threshold
|
||||
&& !arch_assistant
|
||||
|| arch_assistant
|
||||
{
|
||||
debug!("intent score is low or arch assistant is handling the conversation");
|
||||
// if arch fc responded to the user message, then we don't need to check the similarity score
|
||||
// it may be that arch fc is handling the conversation for parameter collection
|
||||
if arch_assistant {
|
||||
info!("arch assistant is handling the conversation");
|
||||
} else {
|
||||
info!(
|
||||
"prompt target below limit: {:.3}, threshold: {:.3}, continue conversation with user",
|
||||
prompt_target_similarity_score,
|
||||
prompt_target_intent_matching_threshold
|
||||
);
|
||||
debug!("checking for default prompt target");
|
||||
if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
.read()
|
||||
.unwrap()
|
||||
.values()
|
||||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
debug!("default prompt target found");
|
||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
|
||||
let upstream_endpoint = endpoint.name;
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
ARCH_MESSAGES_KEY.to_string(),
|
||||
callout_context.request_body.messages.clone(),
|
||||
);
|
||||
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||
debug!("no prompt target found with similarity score above threshold, using default prompt target");
|
||||
let token_id = match self.dispatch_http_call(
|
||||
&upstream_endpoint,
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", &upstream_path),
|
||||
(":authority", &upstream_endpoint),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
(
|
||||
"x-envoy-upstream-rq-timeout-ms",
|
||||
ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
|
||||
),
|
||||
],
|
||||
Some(arch_messages_json.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
let error_msg =
|
||||
format!("Error dispatching HTTP call for default-target: {:?}", e);
|
||||
return self
|
||||
.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
return;
|
||||
}
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&prompt_target_name)
|
||||
.unwrap()
|
||||
.clone();
|
||||
let prompt_target = match self.prompt_targets.read().unwrap().get(&prompt_target_name) {
|
||||
Some(prompt_target) => prompt_target.clone(),
|
||||
None => {
|
||||
return self.send_server_error(
|
||||
format!("Prompt target not found: {}", prompt_target_name),
|
||||
None,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
info!("prompt_target name: {:?}", prompt_target_name);
|
||||
|
||||
//TODO: handle default function resolver type
|
||||
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
|
||||
for pt in self.prompt_targets.read().unwrap().values() {
|
||||
// only extract entity names
|
||||
|
|
@ -761,6 +816,83 @@ impl StreamContext {
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
||||
.unwrap()
|
||||
.clone();
|
||||
debug!(
|
||||
"response received for default target: {}",
|
||||
prompt_target.name
|
||||
);
|
||||
// check if the default target should be dispatched to the LLM provider
|
||||
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) {
|
||||
let default_target_response_str = String::from_utf8(body).unwrap();
|
||||
debug!(
|
||||
"sending response back to developer: {}",
|
||||
default_target_response_str
|
||||
);
|
||||
self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![("Powered-By", "Katanemo")],
|
||||
Some(default_target_response_str.as_bytes()),
|
||||
);
|
||||
// self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
debug!("default_target: sending api response to default llm");
|
||||
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
|
||||
Ok(chat_completions_resp) => chat_completions_resp,
|
||||
Err(e) => {
|
||||
return self.send_server_error(
|
||||
format!("Error deserializing default target response: {:?}", e),
|
||||
None,
|
||||
);
|
||||
}
|
||||
};
|
||||
let api_resp = chat_completions_resp.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
let mut messages = callout_context.request_body.messages;
|
||||
|
||||
// add system prompt
|
||||
match prompt_target.system_prompt.as_ref() {
|
||||
None => {}
|
||||
Some(system_prompt) => {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(system_prompt.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: USER_ROLE.to_string(),
|
||||
content: Some(api_resp.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
let chat_completion_request = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options,
|
||||
};
|
||||
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
|
||||
debug!("sending response back to default llm: {}", json_resp);
|
||||
self.set_http_request_body(0, json_resp.len(), json_resp.as_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
}
|
||||
|
||||
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
|
||||
|
|
@ -1067,6 +1199,9 @@ impl Context for StreamContext {
|
|||
self.function_call_response_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
|
||||
ResponseHandlerType::DefaultTarget => {
|
||||
self.default_target_handler(body, callout_context)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.send_server_error(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue