From 8335f0c3dee7ceba823f2890115bbbb28a2708e6 Mon Sep 17 00:00:00 2001 From: Shuguang Chen <54548843+nehcgs@users.noreply.github.com> Date: Thu, 27 Mar 2025 10:26:47 -0700 Subject: [PATCH] minor update --- model_server/src/core/function_calling.py | 16 ++++++++-------- model_server/src/main.py | 13 ++++++++----- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 6714da24..f5b2cc44 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -396,7 +396,8 @@ class ArchFunctionHandler(ArchBaseHandler): model_response += chunk.choices[0].delta.content logger.info(f"[Agent Orchestrator]: response received: {model_response}") else: - # *********************************************************************************************\ + # ********************************************************************************************* + # TODO: # Update the following logic for hallucination check # 1. If the model response starts wtth `tool_calls`, continue halluciantion check: # - If hallucination detected, start prompt prefilling @@ -441,7 +442,7 @@ class ArchFunctionHandler(ArchBaseHandler): # model_response = prefill_response.choices[0].message.content # *********************************************************************************************\ - # Remove the following for loop after updating hallucination check + # TODO: Remove the following for loop after updating hallucination check # ********************************************************************************************* for chunk in response: if len(chunk.choices) > 0 and chunk.choices[0].delta.content: @@ -450,18 +451,18 @@ class ArchFunctionHandler(ArchBaseHandler): # Extract tool calls from model response response_dict = self._parse_model_resonse(model_response) + # General model response if response_dict.get("response", ""): - # General model response model_message = Message(content="", tool_calls=[]) + # Parameter gathering elif response_dict.get("required_functions", []): - # Model response for parameter gathering if not use_agent_orchestrator: clarification = response_dict.get("clarification", "") model_message = Message(content=clarification, tool_calls=[]) else: model_message = Message(content="", tool_calls=[]) + # Function Calling elif response_dict.get("tool_calls", []): - # Response with tool calls if response_dict["is_valid"]: if not use_agent_orchestrator: verification_dict = self._verify_tool_calls( @@ -490,12 +491,11 @@ class ArchFunctionHandler(ArchBaseHandler): ) else: - # Response with tool calls but contain errors + # Response with tool calls but invalid model_message = Message(content="", tool_calls=[]) + # Response not in the desired format else: logger.error(f"Invalid model response - {model_response}") - - # Response with tool calls but contain errors model_message = Message(content="", tool_calls=[]) chat_completion_response = ChatCompletionResponse( diff --git a/model_server/src/main.py b/model_server/src/main.py index 648c5b89..b0434222 100644 --- a/model_server/src/main.py +++ b/model_server/src/main.py @@ -78,11 +78,6 @@ async def function_calling(req: ChatMessage, res: Response): use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False) logger.info(f"Use agent orchestrator: {use_agent_orchestrator}") - # if not use_agent_orchestrator: - # intent_start_time = time.perf_counter() - # intent_response = await handler_map["Arch-Intent"].chat_completion(req) - # intent_latency = time.perf_counter() - intent_start_time - # intent_detected = handler_map["Arch-Intent"].detect_intent(intent_response) try: handler_name = "Arch-Agent" if use_agent_orchestrator else "Arch-Function" @@ -102,6 +97,10 @@ async def function_calling(req: ChatMessage, res: Response): final_response.metadata = { "function_latency": str(round(latency * 1000, 3)), } + + # ********************************************************************************************* + # TODO: Put the following code back when hallucination check is ready + # ********************************************************************************************* # if not use_agent_orchestrator: # final_response.metadata["hallucination"] = str( # model_handler.hallucination_state.hallucination @@ -114,6 +113,10 @@ async def function_calling(req: ChatMessage, res: Response): if not use_agent_orchestrator: final_response.metadata["intent_latency"] = str(round(latency * 1000, 3)) + + # ********************************************************************************************* + # TODO: Put the following code back when hallucination check is ready + # ********************************************************************************************* # final_response.metadata["hallucination"] = str( # model_handler.hallucination_state.hallucination # )