diff --git a/config/config.yaml b/config/config.yaml index bed67083c..72d2c0b19 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -94,4 +94,9 @@ MODEL_FOR_RESEARCHER_REPORT: gpt-3.5-turbo-16k ### browser path for pyppeteer engine, support Chrome, Chromium,MS Edge #PYPPETEER_EXECUTABLE_PATH: "/usr/bin/google-chrome-stable" +### Repairs a non-OpenAI LLM's output when parsing JSON text; only applies if PROMPT_FORMAT=json. +### Non-OpenAI LLMs do not always follow the output instructions, so this enables a post-processing +### repair step on the content extracted from the LLM's raw output. Warning: it improves results but does not fix every case. +# REPAIR_LLM_OUTPUT: false + PROMPT_FORMAT: json #json or markdown \ No newline at end of file diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index f9e4f926b..7433c3857 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -5,17 +5,16 @@ @Author : alexanderwu @File : action.py """ -import re + from abc import ABC from typing import Optional -from tenacity import retry, stop_after_attempt, wait_fixed +from tenacity import retry, stop_after_attempt, wait_fixed, after_log from metagpt.actions.action_output import ActionOutput from metagpt.llm import LLM from metagpt.logs import logger from metagpt.utils.common import OutputParser -from metagpt.utils.custom_decoder import CustomDecoder from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, RepairType,\ retry_parse_json_text, extract_content_from_output @@ -51,7 +50,11 @@ class Action(ABC): system_msgs.append(self.prefix) return await self.llm.aask(prompt, system_msgs) - # @retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) + @retry( + stop=stop_after_attempt(3), + wait=wait_fixed(1), + after=after_log(logger, logger.level("ERROR").name), + ) async def _aask_v1( self, prompt: str, @@ -65,7 +68,7 @@ class Action(ABC): system_msgs = [] system_msgs.append(self.prefix) content = await self.llm.aask(prompt, system_msgs) - 
logger.debug(content) + logger.debug(f"llm raw output:\n{content}") output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping) output_class_fields = list(output_class.schema()["properties"].keys()) # Custom ActionOutput's fields @@ -73,8 +76,8 @@ class Action(ABC): content = repair_llm_raw_output(content, req_keys=output_class_fields + ["[/CONTENT]"]) content = extract_content_from_output(content) content = repair_llm_raw_output(content, req_keys=[None], repair_type=RepairType.JSON) # req_keys mocked - logger.info(f"extracted CONTENT from content:\n{content}") - parsed_data = retry_parse_json_text(content) + logger.info(f"extracted json CONTENT from output:\n{content}") + parsed_data = retry_parse_json_text(output=content) # should use output=content else: # using markdown parser parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index a65e4be80..c26dc838d 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -4,8 +4,9 @@ import copy from enum import Enum -from typing import Union +from typing import Union, Callable import regex as re +from tenacity import retry, stop_after_attempt, wait_fixed, after_log, RetryCallState from metagpt.logs import logger from metagpt.config import CONFIG @@ -14,8 +15,8 @@ from metagpt.utils.custom_decoder import CustomDecoder class RepairType(Enum): CS = "case sensitivity" - SCM = "special character missing" # Usually the req_key appear in pairs like `[key] xx [/key]` RKPM = "required key pair missing" # condition like `[key] xx` which lacks `[/key]` + SCM = "special character missing" # Usually the req_key appear in pairs like `[key] xx [/key]` JSON = "json format" @@ -39,9 +40,11 @@ def repair_case_sensitivity(output: str, req_key: str) -> str: return output -def repair_special_character_missing(output: str, req_key: str) -> str: +def 
repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str: """ - fix target string `[CONTENT]xxx[/CONTENT]` lacks [/CONTENT] + fix + 1. target string `[CONTENT] xx [CONTENT] xxx [CONTENT]` lacks `/` in the last `[CONTENT]` + 2. target string `xx [CONTENT] xxx [CONTENT] xxxx` lacks `/` in the last `[CONTENT]` """ sc_arr = ["/"] @@ -55,30 +58,48 @@ def repair_special_character_missing(output: str, req_key: str) -> str: # req_key with special_character usually in the tail side ridx = output.rfind(req_key_pure) output = f"{output[:ridx]}{req_key}{output[ridx + len(req_key_pure):]}" - logger.info(f"repair_special_character_missing: {req_key}") + logger.info(f"repair_special_character_missing: {sc} in {req_key_pure} as position {ridx}") return output -def repair_required_key_pair_missing(output: str, req_key: str) -> str: +def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str: """ implement the req_key pair in the begin or end of the content req_key format 1. `[req_key]`, and its pair `[/req_key]` 2. 
`[/req_key]`, and its pair `[req_key]` """ + sc = "/" # special char if req_key.startswith("[") and req_key.endswith("]"): - if "/" in req_key: - left_key = req_key.replace("/", "") # `[/req_key]` -> `[req_key]` + if sc in req_key: + left_key = req_key.replace(sc, "") # `[/req_key]` -> `[req_key]` right_key = req_key else: left_key = req_key - right_key = f"{req_key[0]}/{req_key[1:]}" # `[req_key]` -> `[/req_key]` + right_key = f"{req_key[0]}{sc}{req_key[1:]}" # `[req_key]` -> `[/req_key]` if left_key not in output: - output = left_key + output + output = left_key + "\n" + output if right_key not in output: - output = output + right_key + def judge_potential_json(routput: str, left_key: str) -> Union[str, bool]: + routput = copy.deepcopy(routput) + ridx = routput.rfind(left_key) + if ridx < 0: + return None + sub_output = routput[ridx:] + idx1 = sub_output.rfind("}") + idx2 = sub_output.rindex("]") + idx = idx1 if idx1 >= idx2 else idx2 + sub_output = sub_output[: idx] + return sub_output + + if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)): + # # avoid [req_key]xx[req_key] case to append [/req_key] + output = output + "\n" + right_key + elif judge_potential_json(output, left_key): + sub_content = judge_potential_json(output, left_key) + output = sub_content + "\n" + right_key return output @@ -106,12 +127,12 @@ def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = for repair_type in repair_types: if repair_type == RepairType.CS: output = repair_case_sensitivity(output, req_key) + elif repair_type == RepairType.RKPM: + output = repair_required_key_pair_missing(output, req_key) elif repair_type == RepairType.SCM: output = repair_special_character_missing(output, req_key) elif repair_type == RepairType.JSON: output = repair_json_format(output) - elif repair_type == RepairType.RKPM: - output = repair_required_key_pair_missing(output, req_key) return output @@ -178,25 +199,58 @@ def 
repair_invalid_json(output: str, error: str) -> str: return output -def retry_parse_json_text(output: str, retry: int = 5) -> Union[list, dict]: +def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]: + def run_and_passon(retry_state: RetryCallState) -> None: + """ + RetryCallState example + { + "start_time":143.098322024, + "retry_object":")>", + "fn":"", + "args":"(\"tag:[/CONTENT]\",)", # function input args + "kwargs":{}, # function input kwargs + "attempt_number":1, # retry number + "outcome":"", # type(outcome.result()) = "str", type(outcome.exception()) = "class" + "outcome_timestamp":143.098416904, + "idle_for":0, + "next_action":"None" + } + """ + if retry_state.outcome.failed: + if len(retry_state.args) > 0: + # # can't used as args=retry_state.args + func_param_output = retry_state.args[0] + elif len(retry_state.kwargs) > 0: + func_param_output = retry_state.kwargs.get("output", "") + # import pdb; pdb.set_trace() + exp_str = str(retry_state.outcome.exception()) + logger.warning(f"parse json from content inside [CONTENT][/CONTENT] failed at retry " + f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}") + + repaired_output = repair_invalid_json(func_param_output, exp_str) + retry_state.kwargs["output"] = repaired_output + + return run_and_passon + + +@retry( + stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0), + wait=wait_fixed(1), + after=run_after_exp_and_passon_next_retry(logger), +) +def retry_parse_json_text(output: str) -> Union[list, dict]: """ repair the json-text situation like there are extra chars like [']', '}'] + + Warning + if CONFIG.repair_llm_output is False, retry _aask_v1 {x=3} times, and the retry_parse_json_text's retry not work + if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times. 
+ it's a two-layer retry cycle """ - parsed_data = {} - for idx in range(retry): - raw_output = copy.deepcopy(output) + logger.debug(f"output to json decode:\n{output}") - try: - parsed_data = CustomDecoder(strict=False).decode(output) - break - except Exception as exp: - if not CONFIG.repair_llm_output: - # if repair_llm_output is False, break from the retry loop - break - - logger.warning(f"decode content into json failed, try to fix it. exp: {exp}") - error = str(exp) - output = repair_invalid_json(output, error) + # if CONFIG.repair_llm_output is True, it will try to fix output until the retry break + parsed_data = CustomDecoder(strict=False).decode(output) return parsed_data diff --git a/tests/metagpt/utils/test_repair_llm_raw_output.py b/tests/metagpt/utils/test_repair_llm_raw_output.py index 39a7343e7..dfcf60ad5 100644 --- a/tests/metagpt/utils/test_repair_llm_raw_output.py +++ b/tests/metagpt/utils/test_repair_llm_raw_output.py @@ -42,20 +42,69 @@ def test_repair_special_character_missing(): req_keys=req_keys) assert output == target_output + raw_output = """[CONTENT] tag +[CONTENT] +{ +"Anything UNCLEAR": "No unclear requirements or information." +} +[CONTENT]""" + target_output = """[CONTENT] tag +[CONTENT] +{ +"Anything UNCLEAR": "No unclear requirements or information." 
+} +[/CONTENT]""" + output = repair_llm_raw_output(output=raw_output, + req_keys=req_keys) + assert output == target_output + + raw_output = '[CONTENT] {"a": "b"} [CONTENT]' + target_output = '[CONTENT] {"a": "b"} [/CONTENT]' + + output = repair_llm_raw_output(output=raw_output, + req_keys=["[/CONTENT]"]) + print("output\n", output) + assert output == target_output + def test_required_key_pair_missing(): - raw_output = "[CONTENT] xxx" - target_output = "[CONTENT] xxx[/CONTENT]" + raw_output = '[CONTENT] {"a": "b"}' + target_output = '[CONTENT] {"a": "b"}\n[/CONTENT]' output = repair_llm_raw_output(output=raw_output, req_keys=["[/CONTENT]"]) assert output == target_output - raw_output = "xxx[/CONTENT]" - target_output = "[CONTENT]xxx[/CONTENT]" + raw_output = '''[CONTENT] +{ + "a": "b" +]''' + target_output = '''[CONTENT] +{ + "a": "b" +] +[/CONTENT]''' output = repair_llm_raw_output(output=raw_output, - req_keys=["[CONTENT]"]) + req_keys=["[/CONTENT]"]) + assert output == target_output + + raw_output = '''[CONTENT] tag +[CONTENT] +{ + "a": "b" +} +xxx +''' + target_output = '''[CONTENT] tag +[CONTENT] +{ + "a": "b" +} +[/CONTENT] +''' + output = repair_llm_raw_output(output=raw_output, + req_keys=["[/CONTENT]"]) assert output == target_output @@ -100,6 +149,13 @@ def test_retry_parse_json_text(): def test_extract_content_from_output(): + """ + cases + xxx [CONTENT] xxxx [/CONTENT] + xxx [CONTENT] xxx [CONTENT] xxxx [/CONTENT] + xxx [CONTENT] xxxx [/CONTENT] xxx [CONTENT][/CONTENT] xxx [CONTENT][/CONTENT] # target pair is the last one + """ + output = 'Sure! Here is the properly formatted JSON output based on the given context:\n\n[CONTENT]\n{\n"' \ 'Required Python third-party packages": [\n"pygame==2.0.4",\n"pytest"\n],\n"Required Other language ' \ 'third-party packages": [\n"No third-party packages are required."\n],\n"Full API spec": "\nopenapi: ' \