minor update

This commit is contained in:
Shuguang Chen 2025-03-27 10:26:47 -07:00
parent 820c0443ee
commit 8335f0c3de
2 changed files with 16 additions and 13 deletions

View file

@ -396,7 +396,8 @@ class ArchFunctionHandler(ArchBaseHandler):
model_response += chunk.choices[0].delta.content
logger.info(f"[Agent Orchestrator]: response received: {model_response}")
else:
# *********************************************************************************************\
# *********************************************************************************************
# TODO:
# Update the following logic for hallucination check
# 1. If the model response starts with `tool_calls`, continue hallucination check:
# - If hallucination detected, start prompt prefilling
@ -441,7 +442,7 @@ class ArchFunctionHandler(ArchBaseHandler):
# model_response = prefill_response.choices[0].message.content
# *********************************************************************************************
# Remove the following for loop after updating hallucination check
# TODO: Remove the following for loop after updating hallucination check
# *********************************************************************************************
for chunk in response:
if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
@ -450,18 +451,18 @@ class ArchFunctionHandler(ArchBaseHandler):
# Extract tool calls from model response
response_dict = self._parse_model_resonse(model_response)
# General model response
if response_dict.get("response", ""):
# General model response
model_message = Message(content="", tool_calls=[])
# Parameter gathering
elif response_dict.get("required_functions", []):
# Model response for parameter gathering
if not use_agent_orchestrator:
clarification = response_dict.get("clarification", "")
model_message = Message(content=clarification, tool_calls=[])
else:
model_message = Message(content="", tool_calls=[])
# Function Calling
elif response_dict.get("tool_calls", []):
# Response with tool calls
if response_dict["is_valid"]:
if not use_agent_orchestrator:
verification_dict = self._verify_tool_calls(
@ -490,12 +491,11 @@ class ArchFunctionHandler(ArchBaseHandler):
)
else:
# Response with tool calls but contain errors
# Response with tool calls but invalid
model_message = Message(content="", tool_calls=[])
# Response not in the desired format
else:
logger.error(f"Invalid model response - {model_response}")
# Response with tool calls but contain errors
model_message = Message(content="", tool_calls=[])
chat_completion_response = ChatCompletionResponse(

View file

@ -78,11 +78,6 @@ async def function_calling(req: ChatMessage, res: Response):
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
logger.info(f"Use agent orchestrator: {use_agent_orchestrator}")
# if not use_agent_orchestrator:
# intent_start_time = time.perf_counter()
# intent_response = await handler_map["Arch-Intent"].chat_completion(req)
# intent_latency = time.perf_counter() - intent_start_time
# intent_detected = handler_map["Arch-Intent"].detect_intent(intent_response)
try:
handler_name = "Arch-Agent" if use_agent_orchestrator else "Arch-Function"
@ -102,6 +97,10 @@ async def function_calling(req: ChatMessage, res: Response):
final_response.metadata = {
"function_latency": str(round(latency * 1000, 3)),
}
# *********************************************************************************************
# TODO: Put the following code back when hallucination check is ready
# *********************************************************************************************
# if not use_agent_orchestrator:
# final_response.metadata["hallucination"] = str(
# model_handler.hallucination_state.hallucination
@ -114,6 +113,10 @@ async def function_calling(req: ChatMessage, res: Response):
if not use_agent_orchestrator:
final_response.metadata["intent_latency"] = str(round(latency * 1000, 3))
# *********************************************************************************************
# TODO: Put the following code back when hallucination check is ready
# *********************************************************************************************
# final_response.metadata["hallucination"] = str(
# model_handler.hallucination_state.hallucination
# )