This commit is contained in:
Shuguang Chen 2024-12-09 15:40:57 -08:00
parent 21fc0b5624
commit 3859e8eb43

View file

@ -43,7 +43,11 @@ class ArchIntentConfig:
EXTRA_INSTRUCTION = "Are there any tools can help?"
GENERATION_PARAMS = {"max_tokens": 1, "stop_token_ids": [151645]}
GENERATION_PARAMS = {
"temperature": 0.01,
"max_tokens": 1,
"stop_token_ids": [151645],
}
class ArchIntentHandler(ArchBaseHandler):
@ -318,6 +322,9 @@ class ArchFunctionHandler(ArchBaseHandler):
flag = False
for line in content.split("\n"):
if not is_valid:
break
if "<tool_call>" == line:
flag = True
elif "</tool_call>" == line:
@ -332,7 +339,7 @@ class ArchFunctionHandler(ArchBaseHandler):
tool_content = json.loads(fixed_content)
except Exception:
tool_calls, is_valid, error_message = [], False, e
return tool_calls, is_valid, error_message
break
tool_calls.append(
{
@ -347,7 +354,7 @@ class ArchFunctionHandler(ArchBaseHandler):
flag = False
return {"result": tool_calls, "status": is_valid, "message": "error_message"}
return {"result": tool_calls, "status": is_valid, "message": error_message}
def _verify_tool_calls(
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
@ -374,16 +381,19 @@ class ArchFunctionHandler(ArchBaseHandler):
functions[tool["function"]["name"]] = tool["function"]["parameters"]
for tool_call in tool_calls:
func_name, func_args = (
tool_call["function"]["name"],
tool_call["function"]["arguments"],
)
if not is_valid:
break
func_name = tool_call["function"]["name"]
func_args = tool_call["function"]["arguments"]
# Check whether the function is available or not
if func_name not in functions:
is_valid = False
invalid_tool_call = tool_call
error_message = f"{func_name} is not defined!"
return is_valid, error_message
break
else:
# Check if all the requried parameters can be found in the tool calls
for required_param in functions[func_name].get("required", []):
@ -391,7 +401,7 @@ class ArchFunctionHandler(ArchBaseHandler):
is_valid = False
invalid_tool_call = tool_call
error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!"
return is_valid, invalid_tool_call, error_message
break
# Verify the data type of each parameter in the tool calls
for param_name in func_args:
@ -405,7 +415,7 @@ class ArchFunctionHandler(ArchBaseHandler):
is_valid = False
invalid_tool_call = tool_call
error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`."
return is_valid, invalid_tool_call, error_message
break
return {
"status": is_valid,