mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
Fix a bug in message formatting
This commit is contained in:
parent
0c3d52bfe4
commit
cbd181a092
1 changed files with 25 additions and 25 deletions
|
|
@ -152,7 +152,12 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
unmatched_opening = stack.pop()
|
unmatched_opening = stack.pop()
|
||||||
fixed_str += opening_bracket[unmatched_opening]
|
fixed_str += opening_bracket[unmatched_opening]
|
||||||
|
|
||||||
return fixed_str
|
try:
|
||||||
|
fixed_str = json.loads(fixed_str)
|
||||||
|
except Exception:
|
||||||
|
fixed_str = json.loads(fixed_str.replace("'", '"'))
|
||||||
|
|
||||||
|
return json.dumps(fixed_str)
|
||||||
|
|
||||||
def _parse_model_response(self, content: str) -> Dict[str, any]:
|
def _parse_model_response(self, content: str) -> Dict[str, any]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -171,6 +176,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
response_dict = {
|
response_dict = {
|
||||||
|
"raw_response": [],
|
||||||
"response": [],
|
"response": [],
|
||||||
"required_functions": [],
|
"required_functions": [],
|
||||||
"clarification": "",
|
"clarification": "",
|
||||||
|
|
@ -186,11 +192,9 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
content = content[4:].strip()
|
content = content[4:].strip()
|
||||||
|
|
||||||
content = self._fix_json_string(content)
|
content = self._fix_json_string(content)
|
||||||
try:
|
response_dict["raw_response"] = f"```json\n{content}```"
|
||||||
model_response = json.loads(content)
|
|
||||||
except Exception:
|
|
||||||
model_response = json.loads(content.replace("'", '"'))
|
|
||||||
|
|
||||||
|
model_response = json.loads(content)
|
||||||
response_dict["response"] = model_response.get("response", "")
|
response_dict["response"] = model_response.get("response", "")
|
||||||
response_dict["required_functions"] = model_response.get(
|
response_dict["required_functions"] = model_response.get(
|
||||||
"required_functions", []
|
"required_functions", []
|
||||||
|
|
@ -212,7 +216,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
response_dict["is_valid"] = False
|
response_dict["is_valid"] = False
|
||||||
response_dict["error_message"] = f"Fail to parse model responses: {e}"
|
response_dict["error_message"] = f"Fail to parse model responses: {e}"
|
||||||
|
|
||||||
return content, response_dict
|
return response_dict
|
||||||
|
|
||||||
def _convert_data_type(self, value: str, target_type: str):
|
def _convert_data_type(self, value: str, target_type: str):
|
||||||
# TODO: Add more conversion rules as needed
|
# TODO: Add more conversion rules as needed
|
||||||
|
|
@ -272,9 +276,9 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
if required_param not in func_args:
|
if required_param not in func_args:
|
||||||
verification_dict["is_valid"] = False
|
verification_dict["is_valid"] = False
|
||||||
verification_dict["invalid_tool_call"] = tool_call
|
verification_dict["invalid_tool_call"] = tool_call
|
||||||
verification_dict["error_message"] = (
|
verification_dict[
|
||||||
f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
|
"error_message"
|
||||||
)
|
] = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
|
||||||
break
|
break
|
||||||
|
|
||||||
# Verify the data type of each parameter in the tool calls
|
# Verify the data type of each parameter in the tool calls
|
||||||
|
|
@ -286,9 +290,9 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
if param_name not in function_properties:
|
if param_name not in function_properties:
|
||||||
verification_dict["is_valid"] = False
|
verification_dict["is_valid"] = False
|
||||||
verification_dict["invalid_tool_call"] = tool_call
|
verification_dict["invalid_tool_call"] = tool_call
|
||||||
verification_dict["error_message"] = (
|
verification_dict[
|
||||||
f"Parameter `{param_name}` is not defined in the function `{func_name}`."
|
"error_message"
|
||||||
)
|
] = f"Parameter `{param_name}` is not defined in the function `{func_name}`."
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
param_value = func_args[param_name]
|
param_value = func_args[param_name]
|
||||||
|
|
@ -304,16 +308,16 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
if not isinstance(param_value, data_type):
|
if not isinstance(param_value, data_type):
|
||||||
verification_dict["is_valid"] = False
|
verification_dict["is_valid"] = False
|
||||||
verification_dict["invalid_tool_call"] = tool_call
|
verification_dict["invalid_tool_call"] = tool_call
|
||||||
verification_dict["error_message"] = (
|
verification_dict[
|
||||||
f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`."
|
"error_message"
|
||||||
)
|
] = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`."
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
verification_dict["is_valid"] = False
|
verification_dict["is_valid"] = False
|
||||||
verification_dict["invalid_tool_call"] = tool_call
|
verification_dict["invalid_tool_call"] = tool_call
|
||||||
verification_dict["error_message"] = (
|
verification_dict[
|
||||||
f"Data type `{target_type}` is not supported."
|
"error_message"
|
||||||
)
|
] = f"Data type `{target_type}` is not supported."
|
||||||
|
|
||||||
return verification_dict
|
return verification_dict
|
||||||
|
|
||||||
|
|
@ -405,12 +409,8 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
model_response = "".join(self.hallucination_state.tokens)
|
model_response = "".join(self.hallucination_state.tokens)
|
||||||
|
|
||||||
# Extract tool calls from model response
|
# Extract tool calls from model response
|
||||||
raw_model_response_json_fixed, response_dict = self._parse_model_response(
|
response_dict = self._parse_model_response(model_response)
|
||||||
model_response
|
logger.info(f"[arch-fc]: raw model response: {response_dict['raw_response']}")
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"[arch-fc]: raw model response (json fixed): {raw_model_response_json_fixed}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# General model response
|
# General model response
|
||||||
if response_dict.get("response", ""):
|
if response_dict.get("response", ""):
|
||||||
|
|
@ -462,7 +462,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
||||||
chat_completion_response = ChatCompletionResponse(
|
chat_completion_response = ChatCompletionResponse(
|
||||||
choices=[Choice(message=model_message)],
|
choices=[Choice(message=model_message)],
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
metadata={"x-arch-fc-model-response": raw_model_response_json_fixed},
|
metadata={"x-arch-fc-model-response": response_dict["raw_response"]},
|
||||||
role="assistant",
|
role="assistant",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue