add repair_escape_error function to parse commands

This commit is contained in:
黄伟韬 2024-08-12 13:52:32 +08:00
parent fa06a67a64
commit 43ffa3558b
2 changed files with 36 additions and 3 deletions

View file

@ -100,6 +100,9 @@ JSON_REPAIR_PROMPT = """
## json data
{json_data}
## json decode error
{json_decode_error}
## Output Format
```json

View file

@ -311,10 +311,20 @@ class RoleZero(Role):
if commands.endswith("]") and not commands.startswith("["):
commands = "[" + commands
commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON))
except json.JSONDecodeError:
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON for: {self.command_rsp}. Trying to repair...")
commands = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp))
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
commands = await self.llm.aask(
msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp, json_decode_error=str(e))
)
try:
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
except json.JSONDecodeError:
# repair escape error of code and math
commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp)
new_command = self.repair_escape_error(commands)
commands = json.loads(
repair_llm_raw_output(output=new_command, req_keys=[None], repair_type=RepairType.JSON)
)
except Exception as e:
tb = traceback.format_exc()
print(tb)
@ -327,6 +337,26 @@ class RoleZero(Role):
commands = commands["commands"] if "commands" in commands else [commands]
return commands, True
def repair_escape_error(self, commands):
"""Repaires escape errors in command responses"""
escape_repair_map = {
"\a": "\\\\a",
"\b": "\\\\b",
"\f": "\\\\f",
"\r": "\\\\r",
"\t": "\\\\t",
"\v": "\\\\v",
}
new_command = ""
for index, ch in enumerate(commands):
if ch == "\\" and index + 1 < len(commands):
if commands[index + 1] not in ["n", '"', " "]:
new_command += "\\"
elif ch in escape_repair_map:
ch = escape_repair_map[ch]
new_command += ch
return commands
async def _run_commands(self, commands) -> str:
outputs = []
for cmd in commands: