mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
update rust side to handle default targets
This commit is contained in:
parent
6be6cc6346
commit
866494da27
5 changed files with 113 additions and 76 deletions
|
|
@ -138,7 +138,7 @@ impl From<String> for ParameterType {
|
|||
_ => {
|
||||
log::warn!("Unknown parameter type: {}, assuming type str", s);
|
||||
ParameterType::String
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -205,13 +205,6 @@ pub struct ToolCallState {
|
|||
pub enum ArchState {
|
||||
ToolCall(Vec<ToolCallState>),
|
||||
}
|
||||
#[derive(Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ModelServerResponse {
|
||||
ChatCompletionsResponse(ChatCompletionsResponse),
|
||||
ModelServerErrorResponse(ModelServerErrorResponse),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelServerErrorResponse {
|
||||
pub result: String,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use crate::metrics::Metrics;
|
|||
use crate::tools::compute_request_path_body;
|
||||
use common::api::open_ai::{
|
||||
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
|
||||
ChatCompletionsResponse, Message, ModelServerResponse, ToolCall,
|
||||
ChatCompletionsResponse, Message, ToolCall,
|
||||
};
|
||||
use common::configuration::{Overrides, PromptTarget, Tracing};
|
||||
use common::consts::{
|
||||
|
|
@ -128,7 +128,7 @@ impl StreamContext {
|
|||
debug!("model server response received");
|
||||
trace!("response body: {}", body_str);
|
||||
|
||||
let model_server_response: ModelServerResponse = match serde_json::from_str(&body_str) {
|
||||
let model_server_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
|
||||
Ok(arch_fc_response) => arch_fc_response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
|
|
@ -139,77 +139,121 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let arch_fc_response = match model_server_response {
|
||||
ModelServerResponse::ChatCompletionsResponse(response) => response,
|
||||
ModelServerResponse::ModelServerErrorResponse(response) => {
|
||||
debug!("archgw <= modelserver error response: {}", response.result);
|
||||
if response.result == "No intent matched" {
|
||||
if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
.values()
|
||||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
debug!("default prompt target found, forwarding request to default prompt target");
|
||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
// intent was matched if we see function_latency in metadata
|
||||
let intent_matched = model_server_response
|
||||
.metadata
|
||||
.as_ref()
|
||||
.and_then(|metadata| metadata.get("function_latency"))
|
||||
.is_some();
|
||||
|
||||
let upstream_endpoint = endpoint.name;
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
MESSAGES_KEY.to_string(),
|
||||
callout_context.request_body.messages.clone(),
|
||||
);
|
||||
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||
let timeout_str = DEFAULT_TARGET_REQUEST_TIMEOUT_MS.to_string();
|
||||
if !intent_matched {
|
||||
debug!("intent not matched");
|
||||
// check if we have a default prompt target
|
||||
if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
.values()
|
||||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
debug!("default prompt target found, forwarding request to default prompt target");
|
||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
|
||||
let mut headers = vec![
|
||||
(":method", "POST"),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
|
||||
(":path", &upstream_path),
|
||||
(":authority", &upstream_endpoint),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
];
|
||||
let upstream_endpoint = endpoint.name;
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
MESSAGES_KEY.to_string(),
|
||||
callout_context.request_body.messages.clone(),
|
||||
);
|
||||
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||
let timeout_str = DEFAULT_TARGET_REQUEST_TIMEOUT_MS.to_string();
|
||||
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
let mut headers = vec![
|
||||
(":method", "POST"),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
|
||||
(":path", &upstream_path),
|
||||
(":authority", &upstream_endpoint),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
];
|
||||
|
||||
// if self.trace_arch_internal() && self.traceparent.is_some() {
|
||||
// headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
|
||||
// }
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&upstream_path,
|
||||
headers,
|
||||
Some(arch_messages_json.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||
callout_context.prompt_target_name =
|
||||
Some(default_prompt_target.name.clone());
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&upstream_path,
|
||||
headers,
|
||||
Some(arch_messages_json.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
warn!("error dispatching default prompt target request: {}", e);
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
return;
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
warn!("error dispatching default prompt target request: {}", e);
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
debug!("no default prompt target found, forwarding request to upstream llm");
|
||||
let mut messages = Vec::new();
|
||||
// add system prompt
|
||||
match self.system_prompt.as_ref() {
|
||||
None => {}
|
||||
Some(system_prompt) => {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(system_prompt.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
}
|
||||
return self.send_server_error(
|
||||
ServerError::LogicError(response.result),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
arch_fc_response.choices[0]
|
||||
messages.append(
|
||||
&mut self
|
||||
.filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
|
||||
);
|
||||
|
||||
let chat_completion_request = ChatCompletionsRequest {
|
||||
model: self
|
||||
.chat_completions_request
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.model
|
||||
.clone(),
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let chat_completion_request_json =
|
||||
serde_json::to_string(&chat_completion_request).unwrap();
|
||||
debug!(
|
||||
"archgw => upstream llm request: {}",
|
||||
chat_completion_request_json
|
||||
);
|
||||
self.set_http_request_body(
|
||||
0,
|
||||
self.request_body_size,
|
||||
chat_completion_request_json.as_bytes(),
|
||||
);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
model_server_response.choices[0]
|
||||
.message
|
||||
.tool_calls
|
||||
.clone_into(&mut self.tool_calls);
|
||||
|
|
@ -238,7 +282,7 @@ impl StreamContext {
|
|||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
Some(
|
||||
arch_fc_response.choices[0]
|
||||
model_server_response.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ endpoints:
|
|||
protocol: https
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
You are a helpful assistant. Only respond to queries related to currency exchange. If there are any other questions, I can't help you.
|
||||
|
||||
prompt_guards:
|
||||
input_guards:
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ class ArchBaseHandler:
|
|||
assert processed_messages[-1]["role"] == "user"
|
||||
|
||||
if extra_instruction:
|
||||
processed_messages[-1]["content"] += extra_instruction
|
||||
processed_messages[-1]["content"] += "\n" + extra_instruction
|
||||
|
||||
# keep the first system message and shift conversation if the total token length exceeds the limit
|
||||
def truncate_messages(messages: List[Dict[str, Any]]):
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
res.status_code = 500
|
||||
error_messages = f"[Arch-Function] - Error in ChatCompletion: {e}"
|
||||
else:
|
||||
# TODO: make a call to default LLM to get responses
|
||||
# no intent matched
|
||||
intent_response.metadata = {
|
||||
"intent_latency": str(round(intent_latency * 1000, 3)),
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue