From 3859e8eb43d73277ec7df51fabc82ed1a6efd49d Mon Sep 17 00:00:00 2001 From: Shuguang Chen <54548843+nehcgs@users.noreply.github.com> Date: Mon, 9 Dec 2024 15:40:57 -0800 Subject: [PATCH] Fix bugs --- model_server/src/core/function_calling.py | 30 +++++++++++++++-------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 57428235..d3fee240 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -43,7 +43,11 @@ class ArchIntentConfig: EXTRA_INSTRUCTION = "Are there any tools can help?" - GENERATION_PARAMS = {"max_tokens": 1, "stop_token_ids": [151645]} + GENERATION_PARAMS = { + "temperature": 0.01, + "max_tokens": 1, + "stop_token_ids": [151645], + } class ArchIntentHandler(ArchBaseHandler): @@ -318,6 +322,9 @@ class ArchFunctionHandler(ArchBaseHandler): flag = False for line in content.split("\n"): + if not is_valid: + break + if "" == line: flag = True elif "" == line: @@ -332,7 +339,7 @@ class ArchFunctionHandler(ArchBaseHandler): tool_content = json.loads(fixed_content) except Exception: tool_calls, is_valid, error_message = [], False, e - return tool_calls, is_valid, error_message + break tool_calls.append( { @@ -347,7 +354,7 @@ class ArchFunctionHandler(ArchBaseHandler): flag = False - return {"result": tool_calls, "status": is_valid, "message": "error_message"} + return {"result": tool_calls, "status": is_valid, "message": error_message} def _verify_tool_calls( self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]] @@ -374,16 +381,19 @@ class ArchFunctionHandler(ArchBaseHandler): functions[tool["function"]["name"]] = tool["function"]["parameters"] for tool_call in tool_calls: - func_name, func_args = ( - tool_call["function"]["name"], - tool_call["function"]["arguments"], - ) + if not is_valid: + break + + func_name = tool_call["function"]["name"] + func_args = tool_call["function"]["arguments"] # Check whether the function is available or not if func_name not in functions: is_valid = False + invalid_tool_call = tool_call error_message = f"{func_name} is not defined!" - return is_valid, error_message + break + else: # Check if all the requried parameters can be found in the tool calls for required_param in functions[func_name].get("required", []): @@ -391,7 +401,7 @@ class ArchFunctionHandler(ArchBaseHandler): is_valid = False invalid_tool_call = tool_call error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!" - return is_valid, invalid_tool_call, error_message + break # Verify the data type of each parameter in the tool calls for param_name in func_args: @@ -405,7 +415,7 @@ class ArchFunctionHandler(ArchBaseHandler): is_valid = False invalid_tool_call = tool_call error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`." - return is_valid, invalid_tool_call, error_message + break return { "status": is_valid,