update retry_parse_json_text

This commit is contained in:
better629 2023-11-22 13:01:16 +08:00
parent c49b832dee
commit fc4ec5a944
4 changed files with 158 additions and 40 deletions

View file

@ -4,8 +4,9 @@
import copy
from enum import Enum
from typing import Union
from typing import Union, Callable
import regex as re
from tenacity import retry, stop_after_attempt, wait_fixed, after_log, RetryCallState
from metagpt.logs import logger
from metagpt.config import CONFIG
@ -14,8 +15,8 @@ from metagpt.utils.custom_decoder import CustomDecoder
class RepairType(Enum):
CS = "case sensitivity"
SCM = "special character missing" # Usually the req_key appear in pairs like `[key] xx [/key]`
RKPM = "required key pair missing" # condition like `[key] xx` which lacks `[/key]`
SCM = "special character missing" # Usually the req_key appear in pairs like `[key] xx [/key]`
JSON = "json format"
@ -39,9 +40,11 @@ def repair_case_sensitivity(output: str, req_key: str) -> str:
return output
def repair_special_character_missing(output: str, req_key: str) -> str:
def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str:
"""
fix target string `[CONTENT]xxx[/CONTENT]` lacks [/CONTENT]
fix
1. target string `[CONTENT] xx [CONTENT] xxx [CONTENT]` lacks `/` in the last `[CONTENT]`
2. target string `xx [CONTENT] xxx [CONTENT] xxxx` lacks `/` in the last `[CONTENT]`
"""
sc_arr = ["/"]
@ -55,30 +58,48 @@ def repair_special_character_missing(output: str, req_key: str) -> str:
# req_key with special_character usually in the tail side
ridx = output.rfind(req_key_pure)
output = f"{output[:ridx]}{req_key}{output[ridx + len(req_key_pure):]}"
logger.info(f"repair_special_character_missing: {req_key}")
logger.info(f"repair_special_character_missing: {sc} in {req_key_pure} as position {ridx}")
return output
def repair_required_key_pair_missing(output: str, req_key: str) -> str:
def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str:
"""
implement the req_key pair in the begin or end of the content
req_key format
1. `[req_key]`, and its pair `[/req_key]`
2. `[/req_key]`, and its pair `[req_key]`
"""
sc = "/" # special char
if req_key.startswith("[") and req_key.endswith("]"):
if "/" in req_key:
left_key = req_key.replace("/", "") # `[/req_key]` -> `[req_key]`
if sc in req_key:
left_key = req_key.replace(sc, "") # `[/req_key]` -> `[req_key]`
right_key = req_key
else:
left_key = req_key
right_key = f"{req_key[0]}/{req_key[1:]}" # `[req_key]` -> `[/req_key]`
right_key = f"{req_key[0]}{sc}{req_key[1:]}" # `[req_key]` -> `[/req_key]`
if left_key not in output:
output = left_key + output
output = left_key + "\n" + output
if right_key not in output:
output = output + right_key
def judge_potential_json(routput: str, left_key: str) -> Union[str, bool]:
routput = copy.deepcopy(routput)
ridx = routput.rfind(left_key)
if ridx < 0:
return None
sub_output = routput[ridx:]
idx1 = sub_output.rfind("}")
idx2 = sub_output.rindex("]")
idx = idx1 if idx1 >= idx2 else idx2
sub_output = sub_output[: idx]
return sub_output
if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)):
# # avoid [req_key]xx[req_key] case to append [/req_key]
output = output + "\n" + right_key
elif judge_potential_json(output, left_key):
sub_content = judge_potential_json(output, left_key)
output = sub_content + "\n" + right_key
return output
@ -106,12 +127,12 @@ def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType =
for repair_type in repair_types:
if repair_type == RepairType.CS:
output = repair_case_sensitivity(output, req_key)
elif repair_type == RepairType.RKPM:
output = repair_required_key_pair_missing(output, req_key)
elif repair_type == RepairType.SCM:
output = repair_special_character_missing(output, req_key)
elif repair_type == RepairType.JSON:
output = repair_json_format(output)
elif repair_type == RepairType.RKPM:
output = repair_required_key_pair_missing(output, req_key)
return output
@ -178,25 +199,58 @@ def repair_invalid_json(output: str, error: str) -> str:
return output
def retry_parse_json_text(output: str, retry: int = 5) -> Union[list, dict]:
def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]:
def run_and_passon(retry_state: RetryCallState) -> None:
"""
RetryCallState example
{
"start_time":143.098322024,
"retry_object":"<Retrying object at 0x7fabcaca25e0 (stop=<tenacity.stop.stop_after_attempt ... >)>",
"fn":"<function retry_parse_json_text_v2 at 0x7fabcac80ee0>",
"args":"(\"tag:[/CONTENT]\",)", # function input args
"kwargs":{}, # function input kwargs
"attempt_number":1, # retry number
"outcome":"<Future at xxx>", # type(outcome.result()) = "str", type(outcome.exception()) = "class"
"outcome_timestamp":143.098416904,
"idle_for":0,
"next_action":"None"
}
"""
if retry_state.outcome.failed:
if len(retry_state.args) > 0:
# # can't used as args=retry_state.args
func_param_output = retry_state.args[0]
elif len(retry_state.kwargs) > 0:
func_param_output = retry_state.kwargs.get("output", "")
# import pdb; pdb.set_trace()
exp_str = str(retry_state.outcome.exception())
logger.warning(f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}")
repaired_output = repair_invalid_json(func_param_output, exp_str)
retry_state.kwargs["output"] = repaired_output
return run_and_passon
@retry(
stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0),
wait=wait_fixed(1),
after=run_after_exp_and_passon_next_retry(logger),
)
def retry_parse_json_text(output: str) -> Union[list, dict]:
"""
repair the json-text situation like there are extra chars like [']', '}']
Warning
if CONFIG.repair_llm_output is False, retry _aask_v1 {x=3} times, and the retry_parse_json_text's retry not work
if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times.
it's a two-layer retry cycle
"""
parsed_data = {}
for idx in range(retry):
raw_output = copy.deepcopy(output)
logger.debug(f"output to json decode:\n{output}")
try:
parsed_data = CustomDecoder(strict=False).decode(output)
break
except Exception as exp:
if not CONFIG.repair_llm_output:
# if repair_llm_output is False, break from the retry loop
break
logger.warning(f"decode content into json failed, try to fix it. exp: {exp}")
error = str(exp)
output = repair_invalid_json(output, error)
# if CONFIG.repair_llm_output is True, it will try to fix output until the retry break
parsed_data = CustomDecoder(strict=False).decode(output)
return parsed_data