Fix a bug in message formatting

This commit is contained in:
Shuguang Chen 2025-04-04 09:53:54 -07:00
parent 0c3d52bfe4
commit cbd181a092

View file

@ -152,7 +152,12 @@ class ArchFunctionHandler(ArchBaseHandler):
unmatched_opening = stack.pop() unmatched_opening = stack.pop()
fixed_str += opening_bracket[unmatched_opening] fixed_str += opening_bracket[unmatched_opening]
return fixed_str try:
fixed_str = json.loads(fixed_str)
except Exception:
fixed_str = json.loads(fixed_str.replace("'", '"'))
return json.dumps(fixed_str)
def _parse_model_response(self, content: str) -> Dict[str, any]: def _parse_model_response(self, content: str) -> Dict[str, any]:
""" """
@ -171,6 +176,7 @@ class ArchFunctionHandler(ArchBaseHandler):
""" """
response_dict = { response_dict = {
"raw_response": [],
"response": [], "response": [],
"required_functions": [], "required_functions": [],
"clarification": "", "clarification": "",
@ -186,11 +192,9 @@ class ArchFunctionHandler(ArchBaseHandler):
content = content[4:].strip() content = content[4:].strip()
content = self._fix_json_string(content) content = self._fix_json_string(content)
try: response_dict["raw_response"] = f"```json\n{content}```"
model_response = json.loads(content)
except Exception:
model_response = json.loads(content.replace("'", '"'))
model_response = json.loads(content)
response_dict["response"] = model_response.get("response", "") response_dict["response"] = model_response.get("response", "")
response_dict["required_functions"] = model_response.get( response_dict["required_functions"] = model_response.get(
"required_functions", [] "required_functions", []
@ -212,7 +216,7 @@ class ArchFunctionHandler(ArchBaseHandler):
response_dict["is_valid"] = False response_dict["is_valid"] = False
response_dict["error_message"] = f"Fail to parse model responses: {e}" response_dict["error_message"] = f"Fail to parse model responses: {e}"
return content, response_dict return response_dict
def _convert_data_type(self, value: str, target_type: str): def _convert_data_type(self, value: str, target_type: str):
# TODO: Add more conversion rules as needed # TODO: Add more conversion rules as needed
@ -272,9 +276,9 @@ class ArchFunctionHandler(ArchBaseHandler):
if required_param not in func_args: if required_param not in func_args:
verification_dict["is_valid"] = False verification_dict["is_valid"] = False
verification_dict["invalid_tool_call"] = tool_call verification_dict["invalid_tool_call"] = tool_call
verification_dict["error_message"] = ( verification_dict[
f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!" "error_message"
) ] = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
break break
# Verify the data type of each parameter in the tool calls # Verify the data type of each parameter in the tool calls
@ -286,9 +290,9 @@ class ArchFunctionHandler(ArchBaseHandler):
if param_name not in function_properties: if param_name not in function_properties:
verification_dict["is_valid"] = False verification_dict["is_valid"] = False
verification_dict["invalid_tool_call"] = tool_call verification_dict["invalid_tool_call"] = tool_call
verification_dict["error_message"] = ( verification_dict[
f"Parameter `{param_name}` is not defined in the function `{func_name}`." "error_message"
) ] = f"Parameter `{param_name}` is not defined in the function `{func_name}`."
break break
else: else:
param_value = func_args[param_name] param_value = func_args[param_name]
@ -304,16 +308,16 @@ class ArchFunctionHandler(ArchBaseHandler):
if not isinstance(param_value, data_type): if not isinstance(param_value, data_type):
verification_dict["is_valid"] = False verification_dict["is_valid"] = False
verification_dict["invalid_tool_call"] = tool_call verification_dict["invalid_tool_call"] = tool_call
verification_dict["error_message"] = ( verification_dict[
f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`." "error_message"
) ] = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`."
break break
else: else:
verification_dict["is_valid"] = False verification_dict["is_valid"] = False
verification_dict["invalid_tool_call"] = tool_call verification_dict["invalid_tool_call"] = tool_call
verification_dict["error_message"] = ( verification_dict[
f"Data type `{target_type}` is not supported." "error_message"
) ] = f"Data type `{target_type}` is not supported."
return verification_dict return verification_dict
@ -405,12 +409,8 @@ class ArchFunctionHandler(ArchBaseHandler):
model_response = "".join(self.hallucination_state.tokens) model_response = "".join(self.hallucination_state.tokens)
# Extract tool calls from model response # Extract tool calls from model response
raw_model_response_json_fixed, response_dict = self._parse_model_response( response_dict = self._parse_model_response(model_response)
model_response logger.info(f"[arch-fc]: raw model response: {response_dict['raw_response']}")
)
logger.info(
f"[arch-fc]: raw model response (json fixed): {raw_model_response_json_fixed}"
)
# General model response # General model response
if response_dict.get("response", ""): if response_dict.get("response", ""):
@ -462,7 +462,7 @@ class ArchFunctionHandler(ArchBaseHandler):
chat_completion_response = ChatCompletionResponse( chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_message)], choices=[Choice(message=model_message)],
model=self.model_name, model=self.model_name,
metadata={"x-arch-fc-model-response": raw_model_response_json_fixed}, metadata={"x-arch-fc-model-response": response_dict["raw_response"]},
role="assistant", role="assistant",
) )