mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Fix a bug in message formatting
This commit is contained in:
parent
0c3d52bfe4
commit
cbd181a092
1 changed files with 25 additions and 25 deletions
|
|
@ -152,7 +152,12 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
unmatched_opening = stack.pop()
|
||||
fixed_str += opening_bracket[unmatched_opening]
|
||||
|
||||
return fixed_str
|
||||
try:
|
||||
fixed_str = json.loads(fixed_str)
|
||||
except Exception:
|
||||
fixed_str = json.loads(fixed_str.replace("'", '"'))
|
||||
|
||||
return json.dumps(fixed_str)
|
||||
|
||||
def _parse_model_response(self, content: str) -> Dict[str, any]:
|
||||
"""
|
||||
|
|
@ -171,6 +176,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
"""
|
||||
|
||||
response_dict = {
|
||||
"raw_response": [],
|
||||
"response": [],
|
||||
"required_functions": [],
|
||||
"clarification": "",
|
||||
|
|
@ -186,11 +192,9 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
content = content[4:].strip()
|
||||
|
||||
content = self._fix_json_string(content)
|
||||
try:
|
||||
model_response = json.loads(content)
|
||||
except Exception:
|
||||
model_response = json.loads(content.replace("'", '"'))
|
||||
response_dict["raw_response"] = f"```json\n{content}```"
|
||||
|
||||
model_response = json.loads(content)
|
||||
response_dict["response"] = model_response.get("response", "")
|
||||
response_dict["required_functions"] = model_response.get(
|
||||
"required_functions", []
|
||||
|
|
@ -212,7 +216,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
response_dict["is_valid"] = False
|
||||
response_dict["error_message"] = f"Fail to parse model responses: {e}"
|
||||
|
||||
return content, response_dict
|
||||
return response_dict
|
||||
|
||||
def _convert_data_type(self, value: str, target_type: str):
|
||||
# TODO: Add more conversion rules as needed
|
||||
|
|
@ -272,9 +276,9 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
if required_param not in func_args:
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict["error_message"] = (
|
||||
f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
|
||||
)
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
|
||||
break
|
||||
|
||||
# Verify the data type of each parameter in the tool calls
|
||||
|
|
@ -286,9 +290,9 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
if param_name not in function_properties:
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict["error_message"] = (
|
||||
f"Parameter `{param_name}` is not defined in the function `{func_name}`."
|
||||
)
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"Parameter `{param_name}` is not defined in the function `{func_name}`."
|
||||
break
|
||||
else:
|
||||
param_value = func_args[param_name]
|
||||
|
|
@ -304,16 +308,16 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
if not isinstance(param_value, data_type):
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict["error_message"] = (
|
||||
f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`."
|
||||
)
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`."
|
||||
break
|
||||
else:
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict["error_message"] = (
|
||||
f"Data type `{target_type}` is not supported."
|
||||
)
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"Data type `{target_type}` is not supported."
|
||||
|
||||
return verification_dict
|
||||
|
||||
|
|
@ -405,12 +409,8 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
model_response = "".join(self.hallucination_state.tokens)
|
||||
|
||||
# Extract tool calls from model response
|
||||
raw_model_response_json_fixed, response_dict = self._parse_model_response(
|
||||
model_response
|
||||
)
|
||||
logger.info(
|
||||
f"[arch-fc]: raw model response (json fixed): {raw_model_response_json_fixed}"
|
||||
)
|
||||
response_dict = self._parse_model_response(model_response)
|
||||
logger.info(f"[arch-fc]: raw model response: {response_dict['raw_response']}")
|
||||
|
||||
# General model response
|
||||
if response_dict.get("response", ""):
|
||||
|
|
@ -462,7 +462,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
chat_completion_response = ChatCompletionResponse(
|
||||
choices=[Choice(message=model_message)],
|
||||
model=self.model_name,
|
||||
metadata={"x-arch-fc-model-response": raw_model_response_json_fixed},
|
||||
metadata={"x-arch-fc-model-response": response_dict["raw_response"]},
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue