From cbd181a092d76198e543803930bf0d078eeecd41 Mon Sep 17 00:00:00 2001 From: Shuguang Chen <54548843+nehcgs@users.noreply.github.com> Date: Fri, 4 Apr 2025 09:53:54 -0700 Subject: [PATCH] Fix a bug in message formatting --- model_server/src/core/function_calling.py | 50 +++++++++++------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index ffb15351..28ea2e93 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -152,7 +152,12 @@ class ArchFunctionHandler(ArchBaseHandler): unmatched_opening = stack.pop() fixed_str += opening_bracket[unmatched_opening] - return fixed_str + try: + fixed_str = json.loads(fixed_str) + except Exception: + fixed_str = json.loads(fixed_str.replace("'", '"')) + + return json.dumps(fixed_str) def _parse_model_response(self, content: str) -> Dict[str, any]: """ @@ -171,6 +176,7 @@ class ArchFunctionHandler(ArchBaseHandler): """ response_dict = { + "raw_response": [], "response": [], "required_functions": [], "clarification": "", @@ -186,11 +192,9 @@ class ArchFunctionHandler(ArchBaseHandler): content = content[4:].strip() content = self._fix_json_string(content) - try: - model_response = json.loads(content) - except Exception: - model_response = json.loads(content.replace("'", '"')) + response_dict["raw_response"] = f"```json\n{content}```" + model_response = json.loads(content) response_dict["response"] = model_response.get("response", "") response_dict["required_functions"] = model_response.get( "required_functions", [] @@ -212,7 +216,7 @@ class ArchFunctionHandler(ArchBaseHandler): response_dict["is_valid"] = False response_dict["error_message"] = f"Fail to parse model responses: {e}" - return content, response_dict + return response_dict def _convert_data_type(self, value: str, target_type: str): # TODO: Add more conversion rules as needed @@ -272,9 +276,9 @@ class ArchFunctionHandler(ArchBaseHandler): if required_param not in func_args: verification_dict["is_valid"] = False verification_dict["invalid_tool_call"] = tool_call - verification_dict["error_message"] = ( - f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!" - ) + verification_dict[ + "error_message" + ] = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!" break # Verify the data type of each parameter in the tool calls @@ -286,9 +290,9 @@ class ArchFunctionHandler(ArchBaseHandler): if param_name not in function_properties: verification_dict["is_valid"] = False verification_dict["invalid_tool_call"] = tool_call - verification_dict["error_message"] = ( - f"Parameter `{param_name}` is not defined in the function `{func_name}`." - ) + verification_dict[ + "error_message" + ] = f"Parameter `{param_name}` is not defined in the function `{func_name}`." break else: param_value = func_args[param_name] @@ -304,16 +308,16 @@ class ArchFunctionHandler(ArchBaseHandler): if not isinstance(param_value, data_type): verification_dict["is_valid"] = False verification_dict["invalid_tool_call"] = tool_call - verification_dict["error_message"] = ( - f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`." - ) + verification_dict[ + "error_message" + ] = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`." break else: verification_dict["is_valid"] = False verification_dict["invalid_tool_call"] = tool_call - verification_dict["error_message"] = ( - f"Data type `{target_type}` is not supported." - ) + verification_dict[ + "error_message" + ] = f"Data type `{target_type}` is not supported." return verification_dict @@ -405,12 +409,8 @@ class ArchFunctionHandler(ArchBaseHandler): model_response = "".join(self.hallucination_state.tokens) # Extract tool calls from model response - raw_model_response_json_fixed, response_dict = self._parse_model_response( - model_response - ) - logger.info( - f"[arch-fc]: raw model response (json fixed): {raw_model_response_json_fixed}" - ) + response_dict = self._parse_model_response(model_response) + logger.info(f"[arch-fc]: raw model response: {response_dict['raw_response']}") # General model response if response_dict.get("response", ""): @@ -462,7 +462,7 @@ class ArchFunctionHandler(ArchBaseHandler): chat_completion_response = ChatCompletionResponse( choices=[Choice(message=model_message)], model=self.model_name, - metadata={"x-arch-fc-model-response": raw_model_response_json_fixed}, + metadata={"x-arch-fc-model-response": response_dict["raw_response"]}, role="assistant", )