mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-02 14:45:17 +02:00
update retry_parse_json_text
This commit is contained in:
parent
c49b832dee
commit
fc4ec5a944
4 changed files with 158 additions and 40 deletions
|
|
@ -94,4 +94,9 @@ MODEL_FOR_RESEARCHER_REPORT: gpt-3.5-turbo-16k
|
|||
### browser path for pyppeteer engine, support Chrome, Chromium,MS Edge
|
||||
#PYPPETEER_EXECUTABLE_PATH: "/usr/bin/google-chrome-stable"
|
||||
|
||||
### for repair non-openai LLM's output when parse json-text if PROMPT_FORMAT=json
|
||||
### due to non-openai LLM's output will not always follow the instruction, so here activate a post-process
|
||||
### repair operation on the content extracted from LLM's raw output. Warning, it improves the result but not fix all cases.
|
||||
# REPAIR_LLM_OUTPUT: false
|
||||
|
||||
PROMPT_FORMAT: json #json or markdown
|
||||
|
|
@ -5,17 +5,16 @@
|
|||
@Author : alexanderwu
|
||||
@File : action.py
|
||||
"""
|
||||
import re
|
||||
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed, after_log
|
||||
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import OutputParser
|
||||
from metagpt.utils.custom_decoder import CustomDecoder
|
||||
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, RepairType,\
|
||||
retry_parse_json_text, extract_content_from_output
|
||||
|
||||
|
|
@ -51,7 +50,11 @@ class Action(ABC):
|
|||
system_msgs.append(self.prefix)
|
||||
return await self.llm.aask(prompt, system_msgs)
|
||||
|
||||
# @retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_fixed(1),
|
||||
after=after_log(logger, logger.level("ERROR").name),
|
||||
)
|
||||
async def _aask_v1(
|
||||
self,
|
||||
prompt: str,
|
||||
|
|
@ -65,7 +68,7 @@ class Action(ABC):
|
|||
system_msgs = []
|
||||
system_msgs.append(self.prefix)
|
||||
content = await self.llm.aask(prompt, system_msgs)
|
||||
logger.debug(content)
|
||||
logger.debug(f"llm raw output:\n{content}")
|
||||
output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping)
|
||||
output_class_fields = list(output_class.schema()["properties"].keys()) # Custom ActionOutput's fields
|
||||
|
||||
|
|
@ -73,8 +76,8 @@ class Action(ABC):
|
|||
content = repair_llm_raw_output(content, req_keys=output_class_fields + ["[/CONTENT]"])
|
||||
content = extract_content_from_output(content)
|
||||
content = repair_llm_raw_output(content, req_keys=[None], repair_type=RepairType.JSON) # req_keys mocked
|
||||
logger.info(f"extracted CONTENT from content:\n{content}")
|
||||
parsed_data = retry_parse_json_text(content)
|
||||
logger.info(f"extracted json CONTENT from output:\n{content}")
|
||||
parsed_data = retry_parse_json_text(output=content) # should use output=content
|
||||
|
||||
else: # using markdown parser
|
||||
parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
|
||||
|
|
|
|||
|
|
@ -4,8 +4,9 @@
|
|||
|
||||
import copy
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
from typing import Union, Callable
|
||||
import regex as re
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed, after_log, RetryCallState
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.config import CONFIG
|
||||
|
|
@ -14,8 +15,8 @@ from metagpt.utils.custom_decoder import CustomDecoder
|
|||
|
||||
class RepairType(Enum):
|
||||
CS = "case sensitivity"
|
||||
SCM = "special character missing" # Usually the req_key appear in pairs like `[key] xx [/key]`
|
||||
RKPM = "required key pair missing" # condition like `[key] xx` which lacks `[/key]`
|
||||
SCM = "special character missing" # Usually the req_key appear in pairs like `[key] xx [/key]`
|
||||
JSON = "json format"
|
||||
|
||||
|
||||
|
|
@ -39,9 +40,11 @@ def repair_case_sensitivity(output: str, req_key: str) -> str:
|
|||
return output
|
||||
|
||||
|
||||
def repair_special_character_missing(output: str, req_key: str) -> str:
|
||||
def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str:
|
||||
"""
|
||||
fix target string `[CONTENT]xxx[/CONTENT]` lacks [/CONTENT]
|
||||
fix
|
||||
1. target string `[CONTENT] xx [CONTENT] xxx [CONTENT]` lacks `/` in the last `[CONTENT]`
|
||||
2. target string `xx [CONTENT] xxx [CONTENT] xxxx` lacks `/` in the last `[CONTENT]`
|
||||
"""
|
||||
sc_arr = ["/"]
|
||||
|
||||
|
|
@ -55,30 +58,48 @@ def repair_special_character_missing(output: str, req_key: str) -> str:
|
|||
# req_key with special_character usually in the tail side
|
||||
ridx = output.rfind(req_key_pure)
|
||||
output = f"{output[:ridx]}{req_key}{output[ridx + len(req_key_pure):]}"
|
||||
logger.info(f"repair_special_character_missing: {req_key}")
|
||||
logger.info(f"repair_special_character_missing: {sc} in {req_key_pure} as position {ridx}")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def repair_required_key_pair_missing(output: str, req_key: str) -> str:
|
||||
def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str:
|
||||
"""
|
||||
implement the req_key pair in the begin or end of the content
|
||||
req_key format
|
||||
1. `[req_key]`, and its pair `[/req_key]`
|
||||
2. `[/req_key]`, and its pair `[req_key]`
|
||||
"""
|
||||
sc = "/" # special char
|
||||
if req_key.startswith("[") and req_key.endswith("]"):
|
||||
if "/" in req_key:
|
||||
left_key = req_key.replace("/", "") # `[/req_key]` -> `[req_key]`
|
||||
if sc in req_key:
|
||||
left_key = req_key.replace(sc, "") # `[/req_key]` -> `[req_key]`
|
||||
right_key = req_key
|
||||
else:
|
||||
left_key = req_key
|
||||
right_key = f"{req_key[0]}/{req_key[1:]}" # `[req_key]` -> `[/req_key]`
|
||||
right_key = f"{req_key[0]}{sc}{req_key[1:]}" # `[req_key]` -> `[/req_key]`
|
||||
|
||||
if left_key not in output:
|
||||
output = left_key + output
|
||||
output = left_key + "\n" + output
|
||||
if right_key not in output:
|
||||
output = output + right_key
|
||||
def judge_potential_json(routput: str, left_key: str) -> Union[str, bool]:
|
||||
routput = copy.deepcopy(routput)
|
||||
ridx = routput.rfind(left_key)
|
||||
if ridx < 0:
|
||||
return None
|
||||
sub_output = routput[ridx:]
|
||||
idx1 = sub_output.rfind("}")
|
||||
idx2 = sub_output.rindex("]")
|
||||
idx = idx1 if idx1 >= idx2 else idx2
|
||||
sub_output = sub_output[: idx]
|
||||
return sub_output
|
||||
|
||||
if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)):
|
||||
# # avoid [req_key]xx[req_key] case to append [/req_key]
|
||||
output = output + "\n" + right_key
|
||||
elif judge_potential_json(output, left_key):
|
||||
sub_content = judge_potential_json(output, left_key)
|
||||
output = sub_content + "\n" + right_key
|
||||
|
||||
return output
|
||||
|
||||
|
|
@ -106,12 +127,12 @@ def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType =
|
|||
for repair_type in repair_types:
|
||||
if repair_type == RepairType.CS:
|
||||
output = repair_case_sensitivity(output, req_key)
|
||||
elif repair_type == RepairType.RKPM:
|
||||
output = repair_required_key_pair_missing(output, req_key)
|
||||
elif repair_type == RepairType.SCM:
|
||||
output = repair_special_character_missing(output, req_key)
|
||||
elif repair_type == RepairType.JSON:
|
||||
output = repair_json_format(output)
|
||||
elif repair_type == RepairType.RKPM:
|
||||
output = repair_required_key_pair_missing(output, req_key)
|
||||
return output
|
||||
|
||||
|
||||
|
|
@ -178,25 +199,58 @@ def repair_invalid_json(output: str, error: str) -> str:
|
|||
return output
|
||||
|
||||
|
||||
def retry_parse_json_text(output: str, retry: int = 5) -> Union[list, dict]:
|
||||
def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]:
|
||||
def run_and_passon(retry_state: RetryCallState) -> None:
|
||||
"""
|
||||
RetryCallState example
|
||||
{
|
||||
"start_time":143.098322024,
|
||||
"retry_object":"<Retrying object at 0x7fabcaca25e0 (stop=<tenacity.stop.stop_after_attempt ... >)>",
|
||||
"fn":"<function retry_parse_json_text_v2 at 0x7fabcac80ee0>",
|
||||
"args":"(\"tag:[/CONTENT]\",)", # function input args
|
||||
"kwargs":{}, # function input kwargs
|
||||
"attempt_number":1, # retry number
|
||||
"outcome":"<Future at xxx>", # type(outcome.result()) = "str", type(outcome.exception()) = "class"
|
||||
"outcome_timestamp":143.098416904,
|
||||
"idle_for":0,
|
||||
"next_action":"None"
|
||||
}
|
||||
"""
|
||||
if retry_state.outcome.failed:
|
||||
if len(retry_state.args) > 0:
|
||||
# # can't used as args=retry_state.args
|
||||
func_param_output = retry_state.args[0]
|
||||
elif len(retry_state.kwargs) > 0:
|
||||
func_param_output = retry_state.kwargs.get("output", "")
|
||||
# import pdb; pdb.set_trace()
|
||||
exp_str = str(retry_state.outcome.exception())
|
||||
logger.warning(f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
|
||||
f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}")
|
||||
|
||||
repaired_output = repair_invalid_json(func_param_output, exp_str)
|
||||
retry_state.kwargs["output"] = repaired_output
|
||||
|
||||
return run_and_passon
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0),
|
||||
wait=wait_fixed(1),
|
||||
after=run_after_exp_and_passon_next_retry(logger),
|
||||
)
|
||||
def retry_parse_json_text(output: str) -> Union[list, dict]:
|
||||
"""
|
||||
repair the json-text situation like there are extra chars like [']', '}']
|
||||
|
||||
Warning
|
||||
if CONFIG.repair_llm_output is False, retry _aask_v1 {x=3} times, and the retry_parse_json_text's retry not work
|
||||
if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times.
|
||||
it's a two-layer retry cycle
|
||||
"""
|
||||
parsed_data = {}
|
||||
for idx in range(retry):
|
||||
raw_output = copy.deepcopy(output)
|
||||
logger.debug(f"output to json decode:\n{output}")
|
||||
|
||||
try:
|
||||
parsed_data = CustomDecoder(strict=False).decode(output)
|
||||
break
|
||||
except Exception as exp:
|
||||
if not CONFIG.repair_llm_output:
|
||||
# if repair_llm_output is False, break from the retry loop
|
||||
break
|
||||
|
||||
logger.warning(f"decode content into json failed, try to fix it. exp: {exp}")
|
||||
error = str(exp)
|
||||
output = repair_invalid_json(output, error)
|
||||
# if CONFIG.repair_llm_output is True, it will try to fix output until the retry break
|
||||
parsed_data = CustomDecoder(strict=False).decode(output)
|
||||
|
||||
return parsed_data
|
||||
|
||||
|
|
|
|||
|
|
@ -42,20 +42,69 @@ def test_repair_special_character_missing():
|
|||
req_keys=req_keys)
|
||||
assert output == target_output
|
||||
|
||||
raw_output = """[CONTENT] tag
|
||||
[CONTENT]
|
||||
{
|
||||
"Anything UNCLEAR": "No unclear requirements or information."
|
||||
}
|
||||
[CONTENT]"""
|
||||
target_output = """[CONTENT] tag
|
||||
[CONTENT]
|
||||
{
|
||||
"Anything UNCLEAR": "No unclear requirements or information."
|
||||
}
|
||||
[/CONTENT]"""
|
||||
output = repair_llm_raw_output(output=raw_output,
|
||||
req_keys=req_keys)
|
||||
assert output == target_output
|
||||
|
||||
raw_output = '[CONTENT] {"a": "b"} [CONTENT]'
|
||||
target_output = '[CONTENT] {"a": "b"} [/CONTENT]'
|
||||
|
||||
output = repair_llm_raw_output(output=raw_output,
|
||||
req_keys=["[/CONTENT]"])
|
||||
print("output\n", output)
|
||||
assert output == target_output
|
||||
|
||||
|
||||
def test_required_key_pair_missing():
|
||||
raw_output = "[CONTENT] xxx"
|
||||
target_output = "[CONTENT] xxx[/CONTENT]"
|
||||
raw_output = '[CONTENT] {"a": "b"}'
|
||||
target_output = '[CONTENT] {"a": "b"}\n[/CONTENT]'
|
||||
|
||||
output = repair_llm_raw_output(output=raw_output,
|
||||
req_keys=["[/CONTENT]"])
|
||||
assert output == target_output
|
||||
|
||||
raw_output = "xxx[/CONTENT]"
|
||||
target_output = "[CONTENT]xxx[/CONTENT]"
|
||||
raw_output = '''[CONTENT]
|
||||
{
|
||||
"a": "b"
|
||||
]'''
|
||||
target_output = '''[CONTENT]
|
||||
{
|
||||
"a": "b"
|
||||
]
|
||||
[/CONTENT]'''
|
||||
|
||||
output = repair_llm_raw_output(output=raw_output,
|
||||
req_keys=["[CONTENT]"])
|
||||
req_keys=["[/CONTENT]"])
|
||||
assert output == target_output
|
||||
|
||||
raw_output = '''[CONTENT] tag
|
||||
[CONTENT]
|
||||
{
|
||||
"a": "b"
|
||||
}
|
||||
xxx
|
||||
'''
|
||||
target_output = '''[CONTENT] tag
|
||||
[CONTENT]
|
||||
{
|
||||
"a": "b"
|
||||
}
|
||||
[/CONTENT]
|
||||
'''
|
||||
output = repair_llm_raw_output(output=raw_output,
|
||||
req_keys=["[/CONTENT]"])
|
||||
assert output == target_output
|
||||
|
||||
|
||||
|
|
@ -100,6 +149,13 @@ def test_retry_parse_json_text():
|
|||
|
||||
|
||||
def test_extract_content_from_output():
|
||||
"""
|
||||
cases
|
||||
xxx [CONTENT] xxxx [/CONTENT]
|
||||
xxx [CONTENT] xxx [CONTENT] xxxx [/CONTENT]
|
||||
xxx [CONTENT] xxxx [/CONTENT] xxx [CONTENT][/CONTENT] xxx [CONTENT][/CONTENT] # target pair is the last one
|
||||
"""
|
||||
|
||||
output = 'Sure! Here is the properly formatted JSON output based on the given context:\n\n[CONTENT]\n{\n"' \
|
||||
'Required Python third-party packages": [\n"pygame==2.0.4",\n"pytest"\n],\n"Required Other language ' \
|
||||
'third-party packages": [\n"No third-party packages are required."\n],\n"Full API spec": "\nopenapi: ' \
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue