This commit is contained in:
Shuguang Chen 2024-12-09 11:19:09 -08:00
parent e0d4ee7357
commit 1a3d33409b
2 changed files with 26 additions and 18 deletions

View file

@ -172,15 +172,13 @@ class ArchFunctionConfig:
"""
).strip()
GENERATION_PARAMS = (
{
"temperature": 0.2,
"top_p": 1.0,
"top_k": 50,
"max_tokens": 512,
"stop_token_ids": [151645],
},
)
GENERATION_PARAMS = {
"temperature": 0.2,
"top_p": 1.0,
"top_k": 50,
"max_tokens": 512,
"stop_token_ids": [151645],
}
PREFILL_CONFIG = {
"prefill_params": {
@ -486,17 +484,17 @@ class ArchFunctionHandler(ArchBaseHandler):
# Extract tool calls from model response
extracted = self._extract_tool_calls(model_response)
if extracted["tool_calls"]:
if extracted["result"]:
# [TODO] Review: define the behavior in the case that tool call extraction fails
# if not extracted["status"]:
verified = self._verify_tool_calls(
tools=req.tools, tool_calls=extracted["tool_calls"]
tools=req.tools, tool_calls=extracted["result"]
)
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
if verified["status"]:
model_response = Message(content="", tool_calls=extracted["tool_calls"])
model_response = Message(content="", tool_calls=extracted["result"])
# else:
else: