Handle intent matching better in arch gateway (#391)

This commit is contained in:
Shuguang Chen 2025-03-04 12:49:13 -08:00 committed by GitHub
parent 10cad4d0b7
commit e77fc47225
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 653 additions and 309 deletions

View file

@ -355,7 +355,7 @@ class ArchFunctionHandler(ArchBaseHandler):
try:
tool_content = json.loads(fixed_content)
except Exception:
tool_calls, is_valid, error_message = [], False, e
is_valid, error_message = False, e
break
tool_calls.append(
@ -573,23 +573,26 @@ class ArchFunctionHandler(ArchBaseHandler):
# Extract tool calls from model response
extracted = self._extract_tool_calls(model_response)
if len(extracted["result"]) and extracted["status"]:
verified = self._verify_tool_calls(
tools=req.tools, tool_calls=extracted["result"]
)
if verified["status"]:
logger.info(
f"[Tool calls]: {json.dumps([tool_call['function'] for tool_call in extracted['result']])}"
if extracted["status"]:
# Response with tool calls
if len(extracted["result"]):
verified = self._verify_tool_calls(
tools=req.tools, tool_calls=extracted["result"]
)
model_response = Message(content="", tool_calls=extracted["result"])
if verified["status"]:
logger.info(
f"[Tool calls]: {json.dumps([tool_call['function'] for tool_call in extracted['result']])}"
)
model_response = Message(content="", tool_calls=extracted["result"])
else:
logger.error(f"Invalid tool call - {verified['message']}")
# Response without tool calls
else:
logger.error(f"Invalid tool call - {verified['message']}")
# raise ValueError(
# f"[Arch-Function]: Invalid tool call - {verified['message']}"
# )
model_response = Message(content=model_response, tool_calls=[])
# Response with tool calls but contain errors
else:
model_response = Message(content=model_response, tool_calls=[])
logger.error(f"Tool call extraction error - {extracted['message']}")
chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_response)], model=self.model_name

View file

@ -171,7 +171,7 @@ class ArchBaseHandler:
assert processed_messages[-1]["role"] == "user"
if extra_instruction:
processed_messages[-1]["content"] += extra_instruction
processed_messages[-1]["content"] += "\n" + extra_instruction
# keep the first system message and shift conversation if the total token length exceeds the limit
def truncate_messages(messages: List[Dict[str, Any]]):

View file

@ -104,6 +104,7 @@ async def function_calling(req: ChatMessage, res: Response):
res.status_code = 500
error_messages = f"[Arch-Function] - Error in ChatCompletion: {e}"
else:
# no intent matched
intent_response.metadata = {
"intent_latency": str(round(intent_latency * 1000, 3)),
}