diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index cd52220e..8a12f1d1 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -13,8 +13,11 @@ pub const MESSAGES_KEY: &str = "messages"; pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint"; pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; pub const HEALTHZ_PATH: &str = "/healthz"; -pub const ARCH_STATE_HEADER: &str = "x-arch-state"; -pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B"; +pub const X_ARCH_STATE_HEADER: &str = "x-arch-state"; +pub const X_ARCH_API_RESPONSE: &str = "x-arch-api-response-message"; +pub const X_ARCH_TOOL_CALL: &str = "x-arch-tool-call-message"; +pub const X_ARCH_FC_MODEL_RESPONSE: &str = "x-arch-fc-model-response"; +pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function"; pub const REQUEST_ID_HEADER: &str = "x-request-id"; pub const TRACE_PARENT_HEADER: &str = "traceparent"; pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal"; diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 78d7e21e..9a6e3985 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -411,7 +411,7 @@ impl HttpContext for StreamContext { ); if self.request_body_sent_time.is_none() { - debug!("on_http_response_body: request body not sent, no doing any processing in llm filter"); + debug!("on_http_response_body: request body not sent, not doing any processing in llm filter"); return Action::Continue; } diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index 3a7dc7d9..ac747a6b 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -4,10 +4,11 @@ use common::{ self, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest, }, consts::{ - ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, + ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH, MODEL_SERVER_NAME, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER, TOOL_ROLE, - TRACE_PARENT_HEADER, USER_ROLE, + TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_API_RESPONSE, X_ARCH_FC_MODEL_RESPONSE, + X_ARCH_STATE_HEADER, X_ARCH_TOOL_CALL, }, errors::ServerError, http::{CallArgs, Client}, @@ -125,8 +126,8 @@ impl HttpContext for StreamContext { self.arch_state = match deserialized_body.metadata { Some(ref metadata) => { - if metadata.contains_key(ARCH_STATE_HEADER) { - let arch_state_str = metadata[ARCH_STATE_HEADER].clone(); + if metadata.contains_key(X_ARCH_STATE_HEADER) { + let arch_state_str = metadata[X_ARCH_STATE_HEADER].clone(); let arch_state: Vec = serde_json::from_str(&arch_state_str).unwrap(); Some(arch_state) } else { @@ -336,10 +337,10 @@ impl HttpContext for StreamContext { if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() { let chunks = vec![ ChatCompletionStreamResponse::new( - None, + self.arch_fc_response.clone(), Some(ASSISTANT_ROLE.to_string()), Some(ARCH_FC_MODEL_NAME.to_string()), - self.tool_calls.to_owned(), + None, ), ChatCompletionStreamResponse::new( self.tool_call_response.clone(), @@ -381,17 +382,39 @@ impl HttpContext for StreamContext { *metadata = Value::Object(serde_json::Map::new()); } - let fc_messages = vec![ - self.generate_toll_call_message(), - self.generate_api_response_message(), - ]; + let tool_call_message = self.generate_toll_call_message(); + let tool_call_message_str = serde_json::to_string(&tool_call_message).unwrap(); + metadata.as_object_mut().unwrap().insert( + X_ARCH_TOOL_CALL.to_string(), + serde_json::Value::String(tool_call_message_str), + ); + + let api_response_message = self.generate_api_response_message(); + let api_response_message_str = + serde_json::to_string(&api_response_message).unwrap(); + metadata.as_object_mut().unwrap().insert( + X_ARCH_API_RESPONSE.to_string(), + serde_json::Value::String(api_response_message_str), + ); + + let fc_messages = vec![tool_call_message, api_response_message]; + let fc_messages_str = serde_json::to_string(&fc_messages).unwrap(); let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]); let arch_state_str = serde_json::to_string(&arch_state).unwrap(); metadata.as_object_mut().unwrap().insert( - ARCH_STATE_HEADER.to_string(), + X_ARCH_STATE_HEADER.to_string(), serde_json::Value::String(arch_state_str), ); + + if let Some(arch_fc_response) = self.arch_fc_response.as_ref() { + metadata.as_object_mut().unwrap().insert( + X_ARCH_FC_MODEL_RESPONSE.to_string(), + serde_json::Value::String( + serde_json::to_string(arch_fc_response).unwrap(), + ), + ); + } let data_serialized = serde_json::to_string(&data).unwrap(); info!("archgw <= developer: {}", data_serialized); self.set_http_response_body(0, body_size, data_serialized.as_bytes()); diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index ec68c104..b2950d2f 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -9,6 +9,7 @@ use common::consts::{ API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE, + X_ARCH_FC_MODEL_RESPONSE, }; use common::errors::ServerError; use common::http::{CallArgs, Client}; @@ -64,10 +65,10 @@ pub struct StreamContext { pub time_to_first_token: Option, pub traceparent: Option, pub _tracing: Rc>, + pub arch_fc_response: Option, } impl StreamContext { - #[allow(clippy::too_many_arguments)] pub fn new( context_id: u32, metrics: Rc, @@ -98,6 +99,7 @@ impl StreamContext { _tracing: tracing, start_upstream_llm_request_time: 0, time_to_first_token: None, + arch_fc_response: None, } } @@ -142,15 +144,17 @@ impl StreamContext { } }; - // intent was matched if we see function_latency in metadata - let intent_matched = model_server_response + let intent_matched = check_intent_matched(&model_server_response); + info!("intent matched: {}", intent_matched); + + self.arch_fc_response = model_server_response .metadata .as_ref() - .and_then(|metadata| metadata.get("function_latency")) - .is_some(); + .and_then(|metadata| metadata.get(X_ARCH_FC_MODEL_RESPONSE)) + .cloned(); + if !intent_matched { - info!("intent not matched"); // check if we have a default prompt target if let Some(default_prompt_target) = self .prompt_targets @@ -278,9 +282,9 @@ impl StreamContext { let direct_response_str = if self.streaming_response { let chunks = vec![ ChatCompletionStreamResponse::new( - None, + self.arch_fc_response.clone(), Some(ASSISTANT_ROLE.to_string()), - Some(ARCH_FC_MODEL_NAME.to_owned()), + Some(ARCH_FC_MODEL_NAME.to_string()), None, ), ChatCompletionStreamResponse::new( @@ -293,7 +297,7 @@ impl StreamContext { .clone(), ), None, - Some(ARCH_FC_MODEL_NAME.to_owned()), + Some(format!("{}-Chat", ARCH_FC_MODEL_NAME.to_owned())), None, ), ]; @@ -624,12 +628,23 @@ impl StreamContext { } pub fn generate_toll_call_message(&mut self) -> Message { - Message { - role: ASSISTANT_ROLE.to_string(), - content: None, - model: Some(ARCH_FC_MODEL_NAME.to_string()), - tool_calls: self.tool_calls.clone(), - tool_call_id: None, + if self.arch_fc_response.is_none() { + info!("arch_fc_response is none, generating tool call message"); + Message { + role: ASSISTANT_ROLE.to_string(), + content: None, + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: self.tool_calls.clone(), + tool_call_id: None, + } + } else { + Message { + role: ASSISTANT_ROLE.to_string(), + content: self.arch_fc_response.as_ref().cloned(), + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: None, + tool_call_id: None, + } } } @@ -761,6 +776,26 @@ impl StreamContext { } } +fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool { + let content = model_server_response + .choices + .get(0) + .and_then(|choice| choice.message.content.as_ref()); + + let content_has_value = content.is_some() && !content.unwrap().is_empty(); + + let tool_calls = model_server_response + .choices + .get(0) + .and_then(|choice| choice.message.tool_calls.as_ref()); + + // intent was matched if content has some value or tool_calls is empty + let intent_matched = + content_has_value || (tool_calls.is_some() && !tool_calls.unwrap().is_empty()); + + return intent_matched; +} + impl Client for StreamContext { type CallContext = StreamCallContext; @@ -772,3 +807,77 @@ impl Client for StreamContext { &self.metrics.active_http_calls } } + +#[cfg(test)] +mod test { + use common::api::open_ai::{ChatCompletionsResponse, Choice, Message, ToolCall}; + + use crate::stream_context::check_intent_matched; + + #[test] + fn test_intent_matched() { + let model_server_response = ChatCompletionsResponse { + choices: vec![Choice { + message: Message { + content: Some("".to_string()), + tool_calls: Some(vec![]), + role: "assistant".to_string(), + model: None, + tool_call_id: None, + }, + finish_reason: None, + index: None, + }], + usage: None, + model: "arch-fc".to_string(), + metadata: None, + }; + + assert_eq!(check_intent_matched(&model_server_response), false); + + let model_server_response = ChatCompletionsResponse { + choices: vec![Choice { + message: Message { + content: Some("hello".to_string()), + tool_calls: Some(vec![]), + role: "assistant".to_string(), + model: None, + tool_call_id: None, + }, + finish_reason: None, + index: None, + }], + usage: None, + model: "arch-fc".to_string(), + metadata: None, + }; + + assert_eq!(check_intent_matched(&model_server_response), true); + + let model_server_response = ChatCompletionsResponse { + choices: vec![Choice { + message: Message { + content: Some("".to_string()), + tool_calls: Some(vec![ToolCall { + id: "1".to_string(), + function: common::api::open_ai::FunctionCallDetail { + name: "test".to_string(), + arguments: None, + }, + tool_type: common::api::open_ai::ToolType::Function, + }]), + role: "assistant".to_string(), + model: None, + tool_call_id: None, + }, + finish_reason: None, + index: None, + }], + usage: None, + model: "arch-fc".to_string(), + metadata: None, + }; + + assert_eq!(check_intent_matched(&model_server_response), true); + } +} diff --git a/demos/shared/chatbot_ui/common.py b/demos/shared/chatbot_ui/common.py index 42e50bd4..1de8f94c 100644 --- a/demos/shared/chatbot_ui/common.py +++ b/demos/shared/chatbot_ui/common.py @@ -120,8 +120,11 @@ def process_stream_chunk(chunk, history): if delta.content: # append content to the last history item - history[-1]["content"] = history[-1].get("content", "") + delta.content + if history[-1]["model"] != "Arch-Function-Chat": + history[-1]["content"] = history[-1].get("content", "") + delta.content # yield content if it is from assistant + if history[-1]["model"] == "Arch-Function": + return None if history[-1]["role"] == "assistant": return delta.content diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index f5b2cc44..75673f66 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -197,12 +197,12 @@ class ArchFunctionHandler(ArchBaseHandler): response_dict["response"] = model_response.get("response", "") response_dict["required_functions"] = model_response.get( - "required_functions", "" + "required_functions", [] ) response_dict["clarification"] = model_response.get("clarification", "") for tool_call in model_response.get("tool_calls", []): - response_dict["tool_call"].append( + response_dict["tool_calls"].append( { "id": f"call_{random.randint(1000, 10000)}", "type": "function", @@ -448,6 +448,7 @@ class ArchFunctionHandler(ArchBaseHandler): if len(chunk.choices) > 0 and chunk.choices[0].delta.content: model_response += chunk.choices[0].delta.content + logger.info(f"[arch-fc]: raw model response: {model_response}") # Extract tool calls from model response response_dict = self._parse_model_resonse(model_response) @@ -499,10 +500,15 @@ class ArchFunctionHandler(ArchBaseHandler): model_message = Message(content="", tool_calls=[]) chat_completion_response = ChatCompletionResponse( - choices=[Choice(message=model_message)], model=self.model_name + choices=[Choice(message=model_message)], + model=self.model_name, + metadata={"x-arch-fc-model-response": model_response}, + role="assistant", ) - logger.info(f"[response]: {json.dumps(chat_completion_response.model_dump())}") + logger.info( + f"[response arch-fc]: {json.dumps(chat_completion_response.model_dump())}" + ) return chat_completion_response diff --git a/model_server/src/core/utils/model_utils.py b/model_server/src/core/utils/model_utils.py index 0c42c796..2dfd2237 100644 --- a/model_server/src/core/utils/model_utils.py +++ b/model_server/src/core/utils/model_utils.py @@ -142,7 +142,7 @@ class ArchBaseHandler: {"role": "system", "content": self._format_system_prompt(tools)} ) - for message in messages: + for idx, message in enumerate(messages): role, content, tool_calls = ( message.role, message.content, @@ -158,9 +158,17 @@ class ArchBaseHandler: if metadata.get("optimize_context_window", "false").lower() == "true": content = f"\n\n" else: - content = ( - f"\n{json.dumps(content)}\n" - ) + # sample response below + # "content": "\n{'name': 'get_stock_price', 'result': '$196.66'}\n" + # msg[idx-1] contains tool call = '{"tool_calls": [{"name": "currency_exchange", "arguments": {"currency_symbol": "NZD"}}]}' + func_name = json.loads(messages[idx - 1].content)["tool_calls"][ + 0 + ].get("name", "no_name") + tool_response = { + "name": func_name, + "result": content, + } + content = f"\n{json.dumps(tool_response)}\n" processed_messages.append({"role": role, "content": content}) diff --git a/model_server/src/main.py b/model_server/src/main.py index b0434222..c3070392 100644 --- a/model_server/src/main.py +++ b/model_server/src/main.py @@ -87,16 +87,15 @@ async def function_calling(req: ChatMessage, res: Response): final_response = await model_handler.chat_completion(req) latency = time.perf_counter() - start_time + if not final_response.metadata: + final_response.metadata = {} + # Parameter gathering for detected intents if final_response.choices[0].message.content: - final_response.metadata = { - "function_latency": str(round(latency * 1000, 3)), - } + final_response.metadata["function_latency"] = str(round(latency * 1000, 3)) # Function Calling elif final_response.choices[0].message.tool_calls: - final_response.metadata = { - "function_latency": str(round(latency * 1000, 3)), - } + final_response.metadata["function_latency"] = str(round(latency * 1000, 3)) # ********************************************************************************************* # TODO: Put the following code back when hallucination check is ready @@ -107,9 +106,7 @@ async def function_calling(req: ChatMessage, res: Response): # ) # No intent detected else: - final_response.metadata = { - "intent_latency": str(round(latency * 1000, 3)), - } + final_response.metadata["intent_latency"] = str(round(latency * 1000, 3)) if not use_agent_orchestrator: final_response.metadata["intent_latency"] = str(round(latency * 1000, 3))