Update Rust side to handle default targets

This commit is contained in:
Adil Hafeez 2025-03-03 15:36:31 -08:00
parent 6be6cc6346
commit 866494da27
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
5 changed files with 113 additions and 76 deletions

View file

@ -138,7 +138,7 @@ impl From<String> for ParameterType {
_ => {
log::warn!("Unknown parameter type: {}, assuming type str", s);
ParameterType::String
},
}
}
}
}
@ -205,13 +205,6 @@ pub struct ToolCallState {
/// State variants for the gateway; the only variant visible here carries
/// tool-call state.
pub enum ArchState {
/// A list of [`ToolCallState`] entries.
ToolCall(Vec<ToolCallState>),
}
/// Either of the two response shapes represented here: a chat-completions
/// response or an error response.
///
/// Because the enum is `#[serde(untagged)]`, no discriminator field is read or
/// written; on deserialization serde tries the variants in declaration order
/// and keeps the first one that matches, so the order of the variants below is
/// significant.
#[derive(Deserialize, Serialize)]
#[serde(untagged)]
pub enum ModelServerResponse {
/// Payload deserialized as a [`ChatCompletionsResponse`]; tried first.
ChatCompletionsResponse(ChatCompletionsResponse),
/// Payload deserialized as a [`ModelServerErrorResponse`]; tried only if the
/// first variant does not match.
ModelServerErrorResponse(ModelServerErrorResponse),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelServerErrorResponse {
pub result: String,

View file

@ -2,7 +2,7 @@ use crate::metrics::Metrics;
use crate::tools::compute_request_path_body;
use common::api::open_ai::{
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
ChatCompletionsResponse, Message, ModelServerResponse, ToolCall,
ChatCompletionsResponse, Message, ToolCall,
};
use common::configuration::{Overrides, PromptTarget, Tracing};
use common::consts::{
@ -128,7 +128,7 @@ impl StreamContext {
debug!("model server response received");
trace!("response body: {}", body_str);
let model_server_response: ModelServerResponse = match serde_json::from_str(&body_str) {
let model_server_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
Ok(arch_fc_response) => arch_fc_response,
Err(e) => {
warn!(
@ -139,77 +139,121 @@ impl StreamContext {
}
};
let arch_fc_response = match model_server_response {
ModelServerResponse::ChatCompletionsResponse(response) => response,
ModelServerResponse::ModelServerErrorResponse(response) => {
debug!("archgw <= modelserver error response: {}", response.result);
if response.result == "No intent matched" {
if let Some(default_prompt_target) = self
.prompt_targets
.values()
.find(|pt| pt.default.unwrap_or(false))
{
debug!("default prompt target found, forwarding request to default prompt target");
let endpoint = default_prompt_target.endpoint.clone().unwrap();
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
// intent was matched if we see function_latency in metadata
let intent_matched = model_server_response
.metadata
.as_ref()
.and_then(|metadata| metadata.get("function_latency"))
.is_some();
let upstream_endpoint = endpoint.name;
let mut params = HashMap::new();
params.insert(
MESSAGES_KEY.to_string(),
callout_context.request_body.messages.clone(),
);
let arch_messages_json = serde_json::to_string(&params).unwrap();
let timeout_str = DEFAULT_TARGET_REQUEST_TIMEOUT_MS.to_string();
if !intent_matched {
debug!("intent not matched");
// check if we have a default prompt target
if let Some(default_prompt_target) = self
.prompt_targets
.values()
.find(|pt| pt.default.unwrap_or(false))
{
debug!("default prompt target found, forwarding request to default prompt target");
let endpoint = default_prompt_target.endpoint.clone().unwrap();
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
let mut headers = vec![
(":method", "POST"),
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
(":path", &upstream_path),
(":authority", &upstream_endpoint),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
];
let upstream_endpoint = endpoint.name;
let mut params = HashMap::new();
params.insert(
MESSAGES_KEY.to_string(),
callout_context.request_body.messages.clone(),
);
let arch_messages_json = serde_json::to_string(&params).unwrap();
let timeout_str = DEFAULT_TARGET_REQUEST_TIMEOUT_MS.to_string();
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
let mut headers = vec![
(":method", "POST"),
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
(":path", &upstream_path),
(":authority", &upstream_endpoint),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
];
// if self.trace_arch_internal() && self.traceparent.is_some() {
// headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
// }
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&upstream_path,
headers,
Some(arch_messages_json.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
callout_context.prompt_target_name =
Some(default_prompt_target.name.clone());
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&upstream_path,
headers,
Some(arch_messages_json.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
if let Err(e) = self.http_call(call_args, callout_context) {
warn!("error dispatching default prompt target request: {}", e);
return self.send_server_error(
ServerError::HttpDispatch(e),
Some(StatusCode::BAD_REQUEST),
);
}
return;
if let Err(e) = self.http_call(call_args, callout_context) {
warn!("error dispatching default prompt target request: {}", e);
return self.send_server_error(
ServerError::HttpDispatch(e),
Some(StatusCode::BAD_REQUEST),
);
}
return;
} else {
debug!("no default prompt target found, forwarding request to upstream llm");
let mut messages = Vec::new();
// add system prompt
match self.system_prompt.as_ref() {
None => {}
Some(system_prompt) => {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: Some(system_prompt.clone()),
model: None,
tool_calls: None,
tool_call_id: None,
};
messages.push(system_prompt_message);
}
}
return self.send_server_error(
ServerError::LogicError(response.result),
Some(StatusCode::BAD_REQUEST),
);
}
};
arch_fc_response.choices[0]
messages.append(
&mut self
.filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
);
let chat_completion_request = ChatCompletionsRequest {
model: self
.chat_completions_request
.as_ref()
.unwrap()
.model
.clone(),
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let chat_completion_request_json =
serde_json::to_string(&chat_completion_request).unwrap();
debug!(
"archgw => upstream llm request: {}",
chat_completion_request_json
);
self.set_http_request_body(
0,
self.request_body_size,
chat_completion_request_json.as_bytes(),
);
self.resume_http_request();
return;
}
}
model_server_response.choices[0]
.message
.tool_calls
.clone_into(&mut self.tool_calls);
@ -238,7 +282,7 @@ impl StreamContext {
),
ChatCompletionStreamResponse::new(
Some(
arch_fc_response.choices[0]
model_server_response.choices[0]
.message
.content
.as_ref()

View file

@ -19,7 +19,7 @@ endpoints:
protocol: https
system_prompt: |
You are a helpful assistant.
You are a helpful assistant. Only respond to queries related to currency exchange. If there are any other questions, I can't help you.
prompt_guards:
input_guards:

View file

@ -171,7 +171,7 @@ class ArchBaseHandler:
assert processed_messages[-1]["role"] == "user"
if extra_instruction:
processed_messages[-1]["content"] += extra_instruction
processed_messages[-1]["content"] += "\n" + extra_instruction
# keep the first system message and shift conversation if the total token length exceeds the limit
def truncate_messages(messages: List[Dict[str, Any]]):

View file

@ -104,7 +104,7 @@ async def function_calling(req: ChatMessage, res: Response):
res.status_code = 500
error_messages = f"[Arch-Function] - Error in ChatCompletion: {e}"
else:
# TODO: make a call to default LLM to get responses
# no intent matched
intent_response.metadata = {
"intent_latency": str(round(intent_latency * 1000, 3)),
}