mirror of
https://github.com/katanemo/plano.git
synced 2026-05-10 16:22:42 +02:00
Handle intent matching better in arch gateway (#391)
This commit is contained in:
parent
10cad4d0b7
commit
e77fc47225
10 changed files with 653 additions and 309 deletions
|
|
@ -2,7 +2,7 @@ use crate::metrics::Metrics;
|
|||
use crate::tools::compute_request_path_body;
|
||||
use common::api::open_ai::{
|
||||
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
|
||||
ChatCompletionsResponse, Message, ModelServerResponse, ToolCall,
|
||||
ChatCompletionsResponse, Message, ToolCall,
|
||||
};
|
||||
use common::configuration::{Overrides, PromptTarget, Tracing};
|
||||
use common::consts::{
|
||||
|
|
@ -128,7 +128,7 @@ impl StreamContext {
|
|||
debug!("model server response received");
|
||||
trace!("response body: {}", body_str);
|
||||
|
||||
let model_server_response: ModelServerResponse = match serde_json::from_str(&body_str) {
|
||||
let model_server_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
|
||||
Ok(arch_fc_response) => arch_fc_response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
|
|
@ -139,77 +139,121 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let arch_fc_response = match model_server_response {
|
||||
ModelServerResponse::ChatCompletionsResponse(response) => response,
|
||||
ModelServerResponse::ModelServerErrorResponse(response) => {
|
||||
debug!("archgw <= modelserver error response: {}", response.result);
|
||||
if response.result == "No intent matched" {
|
||||
if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
.values()
|
||||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
debug!("default prompt target found, forwarding request to default prompt target");
|
||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
// intent was matched if we see function_latency in metadata
|
||||
let intent_matched = model_server_response
|
||||
.metadata
|
||||
.as_ref()
|
||||
.and_then(|metadata| metadata.get("function_latency"))
|
||||
.is_some();
|
||||
|
||||
let upstream_endpoint = endpoint.name;
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
MESSAGES_KEY.to_string(),
|
||||
callout_context.request_body.messages.clone(),
|
||||
);
|
||||
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||
let timeout_str = DEFAULT_TARGET_REQUEST_TIMEOUT_MS.to_string();
|
||||
if !intent_matched {
|
||||
debug!("intent not matched");
|
||||
// check if we have a default prompt target
|
||||
if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
.values()
|
||||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
debug!("default prompt target found, forwarding request to default prompt target");
|
||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
|
||||
let mut headers = vec![
|
||||
(":method", "POST"),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
|
||||
(":path", &upstream_path),
|
||||
(":authority", &upstream_endpoint),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
];
|
||||
let upstream_endpoint = endpoint.name;
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
MESSAGES_KEY.to_string(),
|
||||
callout_context.request_body.messages.clone(),
|
||||
);
|
||||
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||
let timeout_str = DEFAULT_TARGET_REQUEST_TIMEOUT_MS.to_string();
|
||||
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
let mut headers = vec![
|
||||
(":method", "POST"),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
|
||||
(":path", &upstream_path),
|
||||
(":authority", &upstream_endpoint),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
];
|
||||
|
||||
// if self.trace_arch_internal() && self.traceparent.is_some() {
|
||||
// headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
|
||||
// }
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&upstream_path,
|
||||
headers,
|
||||
Some(arch_messages_json.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||
callout_context.prompt_target_name =
|
||||
Some(default_prompt_target.name.clone());
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&upstream_path,
|
||||
headers,
|
||||
Some(arch_messages_json.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
warn!("error dispatching default prompt target request: {}", e);
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
return;
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
warn!("error dispatching default prompt target request: {}", e);
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
debug!("no default prompt target found, forwarding request to upstream llm");
|
||||
let mut messages = Vec::new();
|
||||
// add system prompt
|
||||
match self.system_prompt.as_ref() {
|
||||
None => {}
|
||||
Some(system_prompt) => {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(system_prompt.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
}
|
||||
return self.send_server_error(
|
||||
ServerError::LogicError(response.result),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
arch_fc_response.choices[0]
|
||||
messages.append(
|
||||
&mut self
|
||||
.filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
|
||||
);
|
||||
|
||||
let chat_completion_request = ChatCompletionsRequest {
|
||||
model: self
|
||||
.chat_completions_request
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.model
|
||||
.clone(),
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let chat_completion_request_json =
|
||||
serde_json::to_string(&chat_completion_request).unwrap();
|
||||
debug!(
|
||||
"archgw => upstream llm request: {}",
|
||||
chat_completion_request_json
|
||||
);
|
||||
self.set_http_request_body(
|
||||
0,
|
||||
self.request_body_size,
|
||||
chat_completion_request_json.as_bytes(),
|
||||
);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
model_server_response.choices[0]
|
||||
.message
|
||||
.tool_calls
|
||||
.clone_into(&mut self.tool_calls);
|
||||
|
|
@ -238,7 +282,7 @@ impl StreamContext {
|
|||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
Some(
|
||||
arch_fc_response.choices[0]
|
||||
model_server_response.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue