From 866494da271f3ceedaca03014aa9124ea9e12670 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Mon, 3 Mar 2025 15:36:31 -0800 Subject: [PATCH] update rust side to handle default targets --- crates/common/src/api/open_ai.rs | 9 +- crates/prompt_gateway/src/stream_context.rs | 174 +++++++++++------- .../currency_exchange/arch_config.yaml | 2 +- model_server/src/core/utils/model_utils.py | 2 +- model_server/src/main.py | 2 +- 5 files changed, 113 insertions(+), 76 deletions(-) diff --git a/crates/common/src/api/open_ai.rs b/crates/common/src/api/open_ai.rs index 2a07ce3f..f318ebc4 100644 --- a/crates/common/src/api/open_ai.rs +++ b/crates/common/src/api/open_ai.rs @@ -138,7 +138,7 @@ impl From for ParameterType { _ => { log::warn!("Unknown parameter type: {}, assuming type str", s); ParameterType::String - }, + } } } } @@ -205,13 +205,6 @@ pub struct ToolCallState { pub enum ArchState { ToolCall(Vec), } -#[derive(Deserialize, Serialize)] -#[serde(untagged)] -pub enum ModelServerResponse { - ChatCompletionsResponse(ChatCompletionsResponse), - ModelServerErrorResponse(ModelServerErrorResponse), -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelServerErrorResponse { pub result: String, diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index d197b3e0..87cf3ee9 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -2,7 +2,7 @@ use crate::metrics::Metrics; use crate::tools::compute_request_path_body; use common::api::open_ai::{ to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest, - ChatCompletionsResponse, Message, ModelServerResponse, ToolCall, + ChatCompletionsResponse, Message, ToolCall, }; use common::configuration::{Overrides, PromptTarget, Tracing}; use common::consts::{ @@ -128,7 +128,7 @@ impl StreamContext { debug!("model server response received"); trace!("response body: {}", body_str); - let 
model_server_response: ModelServerResponse = match serde_json::from_str(&body_str) { + let model_server_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) { Ok(arch_fc_response) => arch_fc_response, Err(e) => { warn!( @@ -139,77 +139,121 @@ impl StreamContext { } }; - let arch_fc_response = match model_server_response { - ModelServerResponse::ChatCompletionsResponse(response) => response, - ModelServerResponse::ModelServerErrorResponse(response) => { - debug!("archgw <= modelserver error response: {}", response.result); - if response.result == "No intent matched" { - if let Some(default_prompt_target) = self - .prompt_targets - .values() - .find(|pt| pt.default.unwrap_or(false)) - { - debug!("default prompt target found, forwarding request to default prompt target"); - let endpoint = default_prompt_target.endpoint.clone().unwrap(); - let upstream_path: String = endpoint.path.unwrap_or(String::from("/")); - - let upstream_endpoint = endpoint.name; - let mut params = HashMap::new(); - params.insert( - MESSAGES_KEY.to_string(), - callout_context.request_body.messages.clone(), - ); - let arch_messages_json = serde_json::to_string(&params).unwrap(); - let timeout_str = DEFAULT_TARGET_REQUEST_TIMEOUT_MS.to_string(); - - let mut headers = vec![ - (":method", "POST"), - (ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint), - 
(":path", &upstream_path), - (":authority", &upstream_endpoint), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()), - ]; + let upstream_endpoint = endpoint.name; + let mut params = HashMap::new(); + params.insert( + MESSAGES_KEY.to_string(), + callout_context.request_body.messages.clone(), + ); + let arch_messages_json = serde_json::to_string(&params).unwrap(); + let timeout_str = DEFAULT_TARGET_REQUEST_TIMEOUT_MS.to_string(); - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } + let mut headers = vec![ + (":method", "POST"), + (ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint), + (":path", &upstream_path), + (":authority", &upstream_endpoint), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()), + ]; - // if self.trace_arch_internal() && self.traceparent.is_some() { - // headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap())); - // } + if self.request_id.is_some() { + headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); + } - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - &upstream_path, - headers, - Some(arch_messages_json.as_bytes()), - vec![], - Duration::from_secs(5), - ); - callout_context.response_handler_type = ResponseHandlerType::DefaultTarget; - callout_context.prompt_target_name = - Some(default_prompt_target.name.clone()); + let call_args = CallArgs::new( + ARCH_INTERNAL_CLUSTER_NAME, + &upstream_path, + headers, + Some(arch_messages_json.as_bytes()), + vec![], + Duration::from_secs(5), + ); + callout_context.response_handler_type = ResponseHandlerType::DefaultTarget; + callout_context.prompt_target_name = Some(default_prompt_target.name.clone()); - if let Err(e) = self.http_call(call_args, callout_context) { - warn!("error dispatching default prompt target request: {}", e); - return 
self.send_server_error( - ServerError::HttpDispatch(e), - Some(StatusCode::BAD_REQUEST), - ); - } - return; + if let Err(e) = self.http_call(call_args, callout_context) { + warn!("error dispatching default prompt target request: {}", e); + return self.send_server_error( + ServerError::HttpDispatch(e), + Some(StatusCode::BAD_REQUEST), + ); + } + return; + } else { + debug!("no default prompt target found, forwarding request to upstream llm"); + let mut messages = Vec::new(); + // add system prompt + match self.system_prompt.as_ref() { + None => {} + Some(system_prompt) => { + let system_prompt_message = Message { + role: SYSTEM_ROLE.to_string(), + content: Some(system_prompt.clone()), + model: None, + tool_calls: None, + tool_call_id: None, + }; + messages.push(system_prompt_message); } } - return self.send_server_error( - ServerError::LogicError(response.result), - Some(StatusCode::BAD_REQUEST), - ); - } - }; - arch_fc_response.choices[0] + messages.append( + &mut self + .filter_out_arch_messages(callout_context.request_body.messages.as_ref()), + ); + + let chat_completion_request = ChatCompletionsRequest { + model: self + .chat_completions_request + .as_ref() + .unwrap() + .model + .clone(), + messages, + tools: None, + stream: callout_context.request_body.stream, + stream_options: callout_context.request_body.stream_options, + metadata: None, + }; + + let chat_completion_request_json = + serde_json::to_string(&chat_completion_request).unwrap(); + debug!( + "archgw => upstream llm request: {}", + chat_completion_request_json + ); + self.set_http_request_body( + 0, + self.request_body_size, + chat_completion_request_json.as_bytes(), + ); + self.resume_http_request(); + return; + } + } + + model_server_response.choices[0] .message .tool_calls .clone_into(&mut self.tool_calls); @@ -238,7 +282,7 @@ impl StreamContext { ), ChatCompletionStreamResponse::new( Some( - arch_fc_response.choices[0] + model_server_response.choices[0] .message .content .as_ref() diff --git 
a/demos/samples_python/currency_exchange/arch_config.yaml b/demos/samples_python/currency_exchange/arch_config.yaml index 1475abca..03e5a01d 100644 --- a/demos/samples_python/currency_exchange/arch_config.yaml +++ b/demos/samples_python/currency_exchange/arch_config.yaml @@ -19,7 +19,7 @@ endpoints: protocol: https system_prompt: | - You are a helpful assistant. + You are a helpful assistant. Only respond to queries related to currency exchange. If there are any other questions, I can't help you. prompt_guards: input_guards: diff --git a/model_server/src/core/utils/model_utils.py b/model_server/src/core/utils/model_utils.py index 7dc71acf..73dc6fef 100644 --- a/model_server/src/core/utils/model_utils.py +++ b/model_server/src/core/utils/model_utils.py @@ -171,7 +171,7 @@ class ArchBaseHandler: assert processed_messages[-1]["role"] == "user" if extra_instruction: - processed_messages[-1]["content"] += extra_instruction + processed_messages[-1]["content"] += "\n" + extra_instruction # keep the first system message and shift conversation if the total token length exceeds the limit def truncate_messages(messages: List[Dict[str, Any]]): diff --git a/model_server/src/main.py b/model_server/src/main.py index c9ad04b6..d1916eaa 100644 --- a/model_server/src/main.py +++ b/model_server/src/main.py @@ -104,7 +104,7 @@ async def function_calling(req: ChatMessage, res: Response): res.status_code = 500 error_messages = f"[Arch-Function] - Error in ChatCompletion: {e}" else: - # TODO: make a call to default LLM to get responses + # no intent matched intent_response.metadata = { "intent_latency": str(round(intent_latency * 1000, 3)), }