From ed54f6b86a7000de1487472a8900138933ac6c42 Mon Sep 17 00:00:00 2001 From: huzixia <528543747@qq.com> Date: Fri, 26 Jan 2024 22:59:10 +0800 Subject: [PATCH 1/7] To avoid JSONDecodeError: Remove comments in output json str, after json value content, maybe start with #, maybe start with //, particularly, it is not inside the string value Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py --- metagpt/actions/action_node.py | 2 ++ metagpt/utils/repair_llm_raw_output.py | 31 ++++++++++++++++++++++---- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 6c65b33ef..0f441cfee 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -24,6 +24,8 @@ TAG = "CONTENT" LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT." FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else." +# Delete comments in json +# If you don't want JSONDecodeError to occur, you can add Delete comments in json after FORMAT_CONSTRAINT SIMPLE_TEMPLATE = """ diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index b71def136..4995918c2 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -120,13 +120,21 @@ def repair_json_format(output: str) -> str: elif output.startswith("{") and output.endswith("]"): output = output[:-1] + "}" - # remove `#` in output json str, usually appeared in `glm-4` + # remove comments in output json str, after json value content, maybe start with #, maybe start with // arr = output.split("\n") new_arr = [] for line in arr: - idx = line.find("#") - if idx >= 0: - line = line[:idx] + # look for # or // comments and make sure they are not inside the string value + comment_index = -1 + for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): + if match.group(1): # if the string value + continue + if match.group(2): # if comments + comment_index = match.start(2) + break + # if comments, then delete them + if comment_index != -1: + line = line[:comment_index].rstrip() new_arr.append(line) output = "\n".join(new_arr) return output @@ -198,6 +206,21 @@ def repair_invalid_json(output: str, error: str) -> str: new_line = line.replace("}", "") elif line.endswith("},") and output.endswith("},"): new_line = line[:-1] + # remove comments in output json str, after json value content, maybe start with #, maybe start with // + elif rline[col_no] == "#" or rline[col_no] == "/": + new_line = rline[:col_no] + for i in range(line_no + 1, len(arr)): + # look for # or // comments and make sure they are not inside the string value + comment_index = -1 + for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): + if match.group(1): # if the string value + continue + if match.group(2): # if comments + comment_index = match.start(2) + break + # if comments, then delete them + if comment_index != -1: + arr[i] = arr[i][:comment_index].rstrip() elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: # problem, `"""` or `'''` without `,` new_line = f",{line}" From 43b069f453d0ea351ba31b918b4fcb8bae5863e0 Mon Sep 17 00:00:00 2001 From: huzixia <528543747@qq.com> Date: Fri, 26 Jan 2024 23:20:16 +0800 Subject: [PATCH 2/7] Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py --- metagpt/actions/action_node.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 0f441cfee..ed0e27869 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -23,7 +23,8 @@ from metagpt.utils.common import OutputParser, general_after_log TAG = "CONTENT" LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT." -FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else." +FORMAT_CONSTRAINT = (f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else. " + f"Delete comments in json") # Delete comments in json # If you don't want JSONDecodeError to occur, you can add Delete comments in json after FORMAT_CONSTRAINT From f16b24758692bd10d89069c13a1260f9d15c968c Mon Sep 17 00:00:00 2001 From: huzixia <528543747@qq.com> Date: Sat, 27 Jan 2024 15:32:12 +0800 Subject: [PATCH 3/7] merge code with similar logic to avoid duplication --- ...move comments in output json str, after js | 12 + PR/action_node.py | 351 ++++++++++++++++++ PR/repair_llm_raw_output.py | 351 ++++++++++++++++++ metagpt/utils/repair_llm_raw_output.py | 43 ++- 4 files changed, 735 insertions(+), 22 deletions(-) create mode 100644 PR/# remove comments in output json str, after js create mode 100644 PR/action_node.py create mode 100644 PR/repair_llm_raw_output.py diff --git a/PR/# remove comments in output json str, after js b/PR/# remove comments in output json str, after js new file mode 100644 index 000000000..f795fefdb --- /dev/null +++ b/PR/# remove comments in output json str, after js @@ -0,0 +1,12 @@ + +git commit -m "To avoid JSONDecodeError: " -m "Remove comments in output json str, after json value content, maybe start with #, maybe start with //, particularly, it is not inside the string value" -m "Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py" + + + +git commit -m "Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py" + + + + + + diff --git a/PR/action_node.py b/PR/action_node.py new file mode 100644 index 000000000..0f441cfee --- /dev/null +++ b/PR/action_node.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/11 18:45 +@Author : alexanderwu +@File : action_node.py + +NOTE: You should use typing.List instead of list to do type annotation. Because in the markdown extraction process, + we can use typing to extract the type of the node, but we cannot use built-in list to extract. +""" +import json +from typing import Any, Dict, List, Optional, Tuple, Type + +from pydantic import BaseModel, create_model, model_validator +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.config import CONFIG +from metagpt.llm import BaseLLM +from metagpt.logs import logger +from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess +from metagpt.utils.common import OutputParser, general_after_log + +TAG = "CONTENT" + +LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT." +FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else." +# Delete comments in json +# If you don't want JSONDecodeError to occur, you can add Delete comments in json after FORMAT_CONSTRAINT + + +SIMPLE_TEMPLATE = """ +## context +{context} + +----- + +## format example +{example} + +## nodes: ": # " +{instruction} + +## constraint +{constraint} + +## action +Follow instructions of nodes, generate output and make sure it follows the format example. +""" + + +def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"): + markdown_str = "" + for key, value in d.items(): + markdown_str += f"{prefix}{key}{kv_sep}{value}{postfix}" + return markdown_str + + +class ActionNode: + """ActionNode is a tree of nodes.""" + + schema: str # raw/json/markdown, default: "" + + # Action Context + context: str # all the context, including all necessary info + llm: BaseLLM # LLM with aask interface + children: dict[str, "ActionNode"] + + # Action Input + key: str # Product Requirement / File list / Code + expected_type: Type # such as str / int / float etc. + # context: str # everything in the history. + instruction: str # the instructions should be followed. + example: Any # example for In Context-Learning. + + # Action Output + content: str + instruct_content: BaseModel + + def __init__( + self, + key: str, + expected_type: Type, + instruction: str, + example: Any, + content: str = "", + children: dict[str, "ActionNode"] = None, + schema: str = "", + ): + self.key = key + self.expected_type = expected_type + self.instruction = instruction + self.example = example + self.content = content + self.children = children if children is not None else {} + self.schema = schema + + def __str__(self): + return ( + f"{self.key}, {repr(self.expected_type)}, {self.instruction}, {self.example}" + f", {self.content}, {self.children}" + ) + + def __repr__(self): + return self.__str__() + + def add_child(self, node: "ActionNode"): + """增加子ActionNode""" + self.children[node.key] = node + + def add_children(self, nodes: List["ActionNode"]): + """批量增加子ActionNode""" + for node in nodes: + self.add_child(node) + + @classmethod + def from_children(cls, key, nodes: List["ActionNode"]): + """直接从一系列的子nodes初始化""" + obj = cls(key, str, "", "") + obj.add_children(nodes) + return obj + + def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]: + """获得子ActionNode的字典,以key索引""" + exclude = exclude or [] + return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude} + + def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]: + """get self key: type mapping""" + return {self.key: (self.expected_type, ...)} + + def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]: + """get key: type mapping under mode""" + if mode == "children" or (mode == "auto" and self.children): + return self.get_children_mapping(exclude=exclude) + return {} if exclude and self.key in exclude else self.get_self_mapping() + + @classmethod + def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]): + """基于pydantic v1的模型动态生成,用来检验结果类型正确性""" + + def check_fields(cls, values): + required_fields = set(mapping.keys()) + missing_fields = required_fields - set(values.keys()) + if missing_fields: + raise ValueError(f"Missing fields: {missing_fields}") + + unrecognized_fields = set(values.keys()) - required_fields + if unrecognized_fields: + logger.warning(f"Unrecognized fields: {unrecognized_fields}") + return values + + validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)} + + new_class = create_model(class_name, __validators__=validators, **mapping) + return new_class + + def create_children_class(self, exclude=None): + """使用object内有的字段直接生成model_class""" + class_name = f"{self.key}_AN" + mapping = self.get_children_mapping(exclude=exclude) + return self.create_model_class(class_name, mapping) + + def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict: + """将当前节点与子节点都按照node: format的格式组织成字典""" + + # 如果没有提供格式化函数,使用默认的格式化方式 + if format_func is None: + format_func = lambda node: f"{node.instruction}" + + # 使用提供的格式化函数来格式化当前节点的值 + formatted_value = format_func(self) + + # 创建当前节点的键值对 + if mode == "children" or (mode == "auto" and self.children): + node_dict = {} + else: + node_dict = {self.key: formatted_value} + + if mode == "root": + return node_dict + + # 遍历子节点并递归调用 to_dict 方法 + exclude = exclude or [] + for _, child_node in self.children.items(): + if child_node.key in exclude: + continue + node_dict.update(child_node.to_dict(format_func)) + + return node_dict + + def compile_to(self, i: Dict, schema, kv_sep) -> str: + if schema == "json": + return json.dumps(i, indent=4) + elif schema == "markdown": + return dict_to_markdown(i, kv_sep=kv_sep) + else: + return str(i) + + def tagging(self, text, schema, tag="") -> str: + if not tag: + return text + if schema == "json": + return f"[{tag}]\n" + text + f"\n[/{tag}]" + else: # markdown + return f"[{tag}]\n" + text + f"\n[/{tag}]" + + def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str: + nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude) + text = self.compile_to(nodes, schema, kv_sep) + return self.tagging(text, schema, tag) + + def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str: + """compile to raw/json/markdown template with all/root/children nodes""" + format_func = lambda i: f"{i.expected_type} # {i.instruction}" + return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude) + + def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str: + """compile to raw/json/markdown examples with all/root/children nodes""" + + # 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example + # 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str + format_func = lambda i: i.example + return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude) + + def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=[]) -> str: + """ + mode: all/root/children + mode="children": 编译所有子节点为一个统一模板,包括instruction与example + mode="all": NotImplemented + mode="root": NotImplemented + schmea: raw/json/markdown + schema="raw": 不编译,context, lang_constaint, instruction + schema="json":编译context, example(json), instruction(markdown), constraint, action + schema="markdown": 编译context, example(markdown), instruction(markdown), constraint, action + """ + if schema == "raw": + return context + "\n\n## Actions\n" + LANGUAGE_CONSTRAINT + "\n" + self.instruction + + # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线", + # compile example暂时不支持markdown + instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude) + example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude) + # nodes = ", ".join(self.to_dict(mode=mode).keys()) + constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT] + constraint = "\n".join(constraints) + + prompt = template.format( + context=context, + example=example, + instruction=instruction, + constraint=constraint, + ) + return prompt + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _aask_v1( + self, + prompt: str, + output_class_name: str, + output_data_mapping: dict, + system_msgs: Optional[list[str]] = None, + schema="markdown", # compatible to original format + timeout=CONFIG.timeout, + ) -> (str, BaseModel): + """Use ActionOutput to wrap the output of aask""" + content = await self.llm.aask(prompt, system_msgs, timeout=timeout) + logger.debug(f"llm raw output:\n{content}") + output_class = self.create_model_class(output_class_name, output_data_mapping) + + if schema == "json": + parsed_data = llm_output_postprocess( + output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]" + ) + else: # using markdown parser + parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) + + logger.debug(f"parsed_data:\n{parsed_data}") + instruct_content = output_class(**parsed_data) + return content, instruct_content + + def get(self, key): + return self.instruct_content.model_dump()[key] + + def set_recursive(self, name, value): + setattr(self, name, value) + for _, i in self.children.items(): + i.set_recursive(name, value) + + def set_llm(self, llm): + self.set_recursive("llm", llm) + + def set_context(self, context): + self.set_recursive("context", context) + + async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None): + prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude) + + if schema != "raw": + mapping = self.get_mapping(mode, exclude=exclude) + class_name = f"{self.key}_AN" + content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout) + self.content = content + self.instruct_content = scontent + else: + self.content = await self.llm.aask(prompt) + self.instruct_content = None + + return self + + async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]): + """Fill the node(s) with mode. + + :param context: Everything we should know when filling node. + :param llm: Large Language Model with pre-defined system message. + :param schema: json/markdown, determine example and output format. + - raw: free form text + - json: it's easy to open source LLM with json format + - markdown: when generating code, markdown is always better + :param mode: auto/children/root + - auto: automated fill children's nodes and gather outputs, if no children, fill itself + - children: fill children's nodes and gather outputs + - root: fill root's node and gather output + :param strgy: simple/complex + - simple: run only once + - complex: run each node + :param timeout: Timeout for llm invocation. + :param exclude: The keys of ActionNode to exclude. + :return: self + """ + self.set_llm(llm) + self.set_context(context) + if self.schema: + schema = self.schema + + if strgy == "simple": + return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) + elif strgy == "complex": + # 这里隐式假设了拥有children + tmp = {} + for _, i in self.children.items(): + if exclude and i.key in exclude: + continue + child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) + tmp.update(child.instruct_content.dict()) + cls = self.create_children_class() + self.instruct_content = cls(**tmp) + return self diff --git a/PR/repair_llm_raw_output.py b/PR/repair_llm_raw_output.py new file mode 100644 index 000000000..4995918c2 --- /dev/null +++ b/PR/repair_llm_raw_output.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : repair llm raw output with particular conditions + +import copy +from enum import Enum +from typing import Callable, Union + +import regex as re +from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed + +from metagpt.config import CONFIG +from metagpt.logs import logger +from metagpt.utils.custom_decoder import CustomDecoder + + +class RepairType(Enum): + CS = "case sensitivity" + RKPM = "required key pair missing" # condition like `[key] xx` which lacks `[/key]` + SCM = "special character missing" # Usually the req_key appear in pairs like `[key] xx [/key]` + JSON = "json format" + + +def repair_case_sensitivity(output: str, req_key: str) -> str: + """ + usually, req_key is the key name of expected json or markdown content, it won't appear in the value part. + fix target string `"Shared Knowledge": ""` but `"Shared knowledge": ""` actually + """ + if req_key in output: + return output + + output_lower = output.lower() + req_key_lower = req_key.lower() + if req_key_lower in output_lower: + # find the sub-part index, and replace it with raw req_key + lidx = output_lower.find(req_key_lower) + source = output[lidx : lidx + len(req_key_lower)] + output = output.replace(source, req_key) + logger.info(f"repair_case_sensitivity: {req_key}") + + return output + + +def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str: + """ + fix + 1. target string `[CONTENT] xx [CONTENT] xxx [CONTENT]` lacks `/` in the last `[CONTENT]` + 2. target string `xx [CONTENT] xxx [CONTENT] xxxx` lacks `/` in the last `[CONTENT]` + """ + sc_arr = ["/"] + + if req_key in output: + return output + + for sc in sc_arr: + req_key_pure = req_key.replace(sc, "") + appear_cnt = output.count(req_key_pure) + if req_key_pure in output and appear_cnt > 1: + # req_key with special_character usually in the tail side + ridx = output.rfind(req_key_pure) + output = f"{output[:ridx]}{req_key}{output[ridx + len(req_key_pure):]}" + logger.info(f"repair_special_character_missing: {sc} in {req_key_pure} as position {ridx}") + + return output + + +def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str: + """ + implement the req_key pair in the begin or end of the content + req_key format + 1. `[req_key]`, and its pair `[/req_key]` + 2. `[/req_key]`, and its pair `[req_key]` + """ + sc = "/" # special char + if req_key.startswith("[") and req_key.endswith("]"): + if sc in req_key: + left_key = req_key.replace(sc, "") # `[/req_key]` -> `[req_key]` + right_key = req_key + else: + left_key = req_key + right_key = f"{req_key[0]}{sc}{req_key[1:]}" # `[req_key]` -> `[/req_key]` + + if left_key not in output: + output = left_key + "\n" + output + if right_key not in output: + + def judge_potential_json(routput: str, left_key: str) -> Union[str, None]: + ridx = routput.rfind(left_key) + if ridx < 0: + return None + sub_output = routput[ridx:] + idx1 = sub_output.rfind("}") + idx2 = sub_output.rindex("]") + idx = idx1 if idx1 >= idx2 else idx2 + sub_output = sub_output[: idx + 1] + return sub_output + + if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)): + # # avoid [req_key]xx[req_key] case to append [/req_key] + output = output + "\n" + right_key + elif judge_potential_json(output, left_key) and (not output.strip().endswith(left_key)): + sub_content = judge_potential_json(output, left_key) + output = sub_content + "\n" + right_key + + return output + + +def repair_json_format(output: str) -> str: + """ + fix extra `[` or `}` in the end + """ + output = output.strip() + + if output.startswith("[{"): + output = output[1:] + logger.info(f"repair_json_format: {'[{'}") + elif output.endswith("}]"): + output = output[:-1] + logger.info(f"repair_json_format: {'}]'}") + elif output.startswith("{") and output.endswith("]"): + output = output[:-1] + "}" + + # remove comments in output json str, after json value content, maybe start with #, maybe start with // + arr = output.split("\n") + new_arr = [] + for line in arr: + # look for # or // comments and make sure they are not inside the string value + comment_index = -1 + for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): + if match.group(1): # if the string value + continue + if match.group(2): # if comments + comment_index = match.start(2) + break + # if comments, then delete them + if comment_index != -1: + line = line[:comment_index].rstrip() + new_arr.append(line) + output = "\n".join(new_arr) + return output + + +def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = None) -> str: + repair_types = [repair_type] if repair_type else [item for item in RepairType if item not in [RepairType.JSON]] + for repair_type in repair_types: + if repair_type == RepairType.CS: + output = repair_case_sensitivity(output, req_key) + elif repair_type == RepairType.RKPM: + output = repair_required_key_pair_missing(output, req_key) + elif repair_type == RepairType.SCM: + output = repair_special_character_missing(output, req_key) + elif repair_type == RepairType.JSON: + output = repair_json_format(output) + return output + + +def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str: + """ + in open-source llm model, it usually can't follow the instruction well, the output may be incomplete, + so here we try to repair it and use all repair methods by default. + typical case + 1. case sensitivity + target: "Original Requirements" + output: "Original requirements" + 2. special character missing + target: [/CONTENT] + output: [CONTENT] + 3. json format + target: { xxx } + output: { xxx }] + """ + if not CONFIG.repair_llm_output: + return output + + # do the repairation usually for non-openai models + for req_key in req_keys: + output = _repair_llm_raw_output(output=output, req_key=req_key, repair_type=repair_type) + return output + + +def repair_invalid_json(output: str, error: str) -> str: + """ + repair the situation like there are extra chars like + error examples + example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765) + example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266) + """ + pattern = r"line ([0-9]+) column ([0-9]+)" + + matches = re.findall(pattern, error, re.DOTALL) + if len(matches) > 0: + line_no = int(matches[0][0]) - 1 + col_no = int(matches[0][1]) - 1 + + # due to CustomDecoder can handle `"": ''` or `'': ""`, so convert `"""` -> `"`, `'''` -> `'` + output = output.replace('"""', '"').replace("'''", '"') + arr = output.split("\n") + rline = arr[line_no] # raw line + line = arr[line_no].strip() + # different general problems + if line.endswith("],"): + # problem, redundant char `]` + new_line = line.replace("]", "") + elif line.endswith("},") and not output.endswith("},"): + # problem, redundant char `}` + new_line = line.replace("}", "") + elif line.endswith("},") and output.endswith("},"): + new_line = line[:-1] + # remove comments in output json str, after json value content, maybe start with #, maybe start with // + elif rline[col_no] == "#" or rline[col_no] == "/": + new_line = rline[:col_no] + for i in range(line_no + 1, len(arr)): + # look for # or // comments and make sure they are not inside the string value + comment_index = -1 + for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): + if match.group(1): # if the string value + continue + if match.group(2): # if comments + comment_index = match.start(2) + break + # if comments, then delete them + if comment_index != -1: + arr[i] = arr[i][:comment_index].rstrip() + elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: + # problem, `"""` or `'''` without `,` + new_line = f",{line}" + elif '",' not in line and "," not in line and '"' not in line: + new_line = f'{line}",' + elif not line.endswith(","): + # problem, miss char `,` at the end. + new_line = f"{line}," + elif "," in line and len(line) == 1: + new_line = f'"{line}' + elif '",' in line: + new_line = line[:-2] + "'," + else: + new_line = line + + arr[line_no] = new_line + output = "\n".join(arr) + logger.info(f"repair_invalid_json, raw error: {error}") + + return output + + +def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]: + def run_and_passon(retry_state: RetryCallState) -> None: + """ + RetryCallState example + { + "start_time":143.098322024, + "retry_object":")>", + "fn":"", + "args":"(\"tag:[/CONTENT]\",)", # function input args + "kwargs":{}, # function input kwargs + "attempt_number":1, # retry number + "outcome":"", # type(outcome.result()) = "str", type(outcome.exception()) = "class" + "outcome_timestamp":143.098416904, + "idle_for":0, + "next_action":"None" + } + """ + if retry_state.outcome.failed: + if retry_state.args: + # # can't be used as args=retry_state.args + func_param_output = retry_state.args[0] + elif retry_state.kwargs: + func_param_output = retry_state.kwargs.get("output", "") + exp_str = str(retry_state.outcome.exception()) + + fix_str = "try to fix it, " if CONFIG.repair_llm_output else "" + logger.warning( + f"parse json from content inside [CONTENT][/CONTENT] failed at retry " + f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}" + ) + + repaired_output = repair_invalid_json(func_param_output, exp_str) + retry_state.kwargs["output"] = repaired_output + + return run_and_passon + + +@retry( + stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0), + wait=wait_fixed(1), + after=run_after_exp_and_passon_next_retry(logger), +) +def retry_parse_json_text(output: str) -> Union[list, dict]: + """ + repair the json-text situation like there are extra chars like [']', '}'] + + Warning + if CONFIG.repair_llm_output is False, retry _aask_v1 {x=3} times, and the retry_parse_json_text's retry not work + if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times. + it's a two-layer retry cycle + """ + # logger.debug(f"output to json decode:\n{output}") + + # if CONFIG.repair_llm_output is True, it will try to fix output until the retry break + parsed_data = CustomDecoder(strict=False).decode(output) + + return parsed_data + + +def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"): + """extract xxx from [CONTENT](xxx)[/CONTENT] using regex pattern""" + + def re_extract_content(cont: str, pattern: str) -> str: + matches = re.findall(pattern, cont, re.DOTALL) + for match in matches: + if match: + cont = match + break + return cont.strip() + + # TODO construct the extract pattern with the `right_key` + raw_content = copy.deepcopy(content) + pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]" + new_content = re_extract_content(raw_content, pattern) + + if not new_content.startswith("{"): + # TODO find a more general pattern + # # for `[CONTENT]xxx[CONTENT]xxxx[/CONTENT] situation + logger.warning(f"extract_content try another pattern: {pattern}") + if right_key not in new_content: + raw_content = copy.deepcopy(new_content + "\n" + right_key) + # # pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]" + new_content = re_extract_content(raw_content, pattern) + else: + if right_key in new_content: + idx = new_content.find(right_key) + new_content = new_content[:idx] + new_content = new_content.strip() + + return new_content + + +def extract_state_value_from_output(content: str) -> str: + """ + For openai models, they will always return state number. But for open llm models, the instruction result maybe a + long text contain target number, so here add a extraction to improve success rate. + + Args: + content (str): llm's output from `Role._think` + """ + content = content.strip() # deal the output cases like " 0", "0\n" and so on. + pattern = r"([0-9])" # TODO find the number using a more proper method not just extract from content using pattern + matches = re.findall(pattern, content, re.DOTALL) + matches = list(set(matches)) + state = matches[0] if len(matches) > 0 else "-1" + return state diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index 4995918c2..ef3580750 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -105,6 +105,23 @@ def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") - return output +def remove_comments_from_line(line): + """ + Remove comments from a single line of string. + Comments are assumed to start with '#' or '//' and are not inside string values. + """ + comment_index = -1 + for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): + if match.group(1): # if the string value + continue + if match.group(2): # if comments + comment_index = match.start(2) + break + if comment_index != -1: # if comments, then delete them + return line[:comment_index].rstrip() + return line + + def repair_json_format(output: str) -> str: """ fix extra `[` or `}` in the end @@ -125,17 +142,8 @@ def repair_json_format(output: str) -> str: new_arr = [] for line in arr: # look for # or // comments and make sure they are not inside the string value - comment_index = -1 - for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): - if match.group(1): # if the string value - continue - if match.group(2): # if comments - comment_index = match.start(2) - break - # if comments, then delete them - if comment_index != -1: - line = line[:comment_index].rstrip() - new_arr.append(line) + new_line = remove_comments_from_line(line) + new_arr.append(new_line) output = "\n".join(new_arr) return output @@ -209,18 +217,9 @@ def repair_invalid_json(output: str, error: str) -> str: # remove comments in output json str, after json value content, maybe start with #, maybe start with // elif rline[col_no] == "#" or rline[col_no] == "/": new_line = rline[:col_no] + # check the next line and remove the comments for i in range(line_no + 1, len(arr)): - # look for # or // comments and make sure they are not inside the string value - comment_index = -1 - for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): - if match.group(1): # if the string value - continue - if match.group(2): # if comments - comment_index = match.start(2) - break - # if comments, then delete them - if comment_index != -1: - arr[i] = arr[i][:comment_index].rstrip() + arr[i] = remove_comments_from_line(arr[i]) elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: # problem, `"""` or `'''` without `,` new_line = f",{line}" From 8b5f7848fa6aa8c9e0703dcdbf6760ef1efa87eb Mon Sep 17 00:00:00 2001 From: huzixia <528543747@qq.com> Date: Sat, 27 Jan 2024 17:00:59 +0800 Subject: [PATCH 4/7] delete PR dir --- metagpt/utils/repair_llm_raw_output.py | 43 +++++++++++++------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index 4995918c2..ef3580750 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -105,6 +105,23 @@ def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") - return output +def remove_comments_from_line(line): + """ + Remove comments from a single line of string. + Comments are assumed to start with '#' or '//' and are not inside string values. + """ + comment_index = -1 + for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): + if match.group(1): # if the string value + continue + if match.group(2): # if comments + comment_index = match.start(2) + break + if comment_index != -1: # if comments, then delete them + return line[:comment_index].rstrip() + return line + + def repair_json_format(output: str) -> str: """ fix extra `[` or `}` in the end @@ -125,17 +142,8 @@ def repair_json_format(output: str) -> str: new_arr = [] for line in arr: # look for # or // comments and make sure they are not inside the string value - comment_index = -1 - for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): - if match.group(1): # if the string value - continue - if match.group(2): # if comments - comment_index = match.start(2) - break - # if comments, then delete them - if comment_index != -1: - line = line[:comment_index].rstrip() - new_arr.append(line) + new_line = remove_comments_from_line(line) + new_arr.append(new_line) output = "\n".join(new_arr) return output @@ -209,18 +217,9 @@ def repair_invalid_json(output: str, error: str) -> str: # remove comments in output json str, after json value content, maybe start with #, maybe start with // elif rline[col_no] == "#" or rline[col_no] == "/": new_line = rline[:col_no] + # check the next line and remove the comments for i in range(line_no + 1, len(arr)): - # look for # or // comments and make sure they are not inside the string value - comment_index = -1 - for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): - if match.group(1): # if the string value - continue - if match.group(2): # if comments - comment_index = match.start(2) - break - # if comments, then delete them - if comment_index != -1: - arr[i] = arr[i][:comment_index].rstrip() + arr[i] = remove_comments_from_line(arr[i]) elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: # problem, `"""` or `'''` without `,` new_line = f",{line}" From 2361c7e8aa2df55bd3562243e95b4d5f538188ff Mon Sep 17 00:00:00 2001 From: huzixia <528543747@qq.com> Date: Sat, 27 Jan 2024 17:07:21 +0800 Subject: [PATCH 5/7] delete PR dir --- ...move comments in output json str, after js | 12 - PR/action_node.py | 351 ------------------ PR/repair_llm_raw_output.py | 351 ------------------ 3 files changed, 714 deletions(-) delete mode 100644 PR/# remove comments in output json str, after js delete mode 100644 PR/action_node.py delete mode 100644 PR/repair_llm_raw_output.py diff --git a/PR/# remove comments in output json str, after js b/PR/# remove comments in output json str, after js deleted file mode 100644 index f795fefdb..000000000 --- a/PR/# remove comments in output json str, after js +++ /dev/null @@ -1,12 +0,0 @@ - -git commit -m "To avoid JSONDecodeError: " -m "Remove comments in output json str, after json value content, maybe start with #, maybe start with //, particularly, it is not inside the string value" -m "Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py" - - - -git commit -m "Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py" - - - - - - diff --git a/PR/action_node.py b/PR/action_node.py deleted file mode 100644 index 0f441cfee..000000000 --- a/PR/action_node.py +++ /dev/null @@ -1,351 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/12/11 18:45 -@Author : alexanderwu -@File : action_node.py - -NOTE: You should use typing.List instead of list to do type annotation. Because in the markdown extraction process, - we can use typing to extract the type of the node, but we cannot use built-in list to extract. -""" -import json -from typing import Any, Dict, List, Optional, Tuple, Type - -from pydantic import BaseModel, create_model, model_validator -from tenacity import retry, stop_after_attempt, wait_random_exponential - -from metagpt.config import CONFIG -from metagpt.llm import BaseLLM -from metagpt.logs import logger -from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess -from metagpt.utils.common import OutputParser, general_after_log - -TAG = "CONTENT" - -LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT." -FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else." -# Delete comments in json -# If you don't want JSONDecodeError to occur, you can add Delete comments in json after FORMAT_CONSTRAINT - - -SIMPLE_TEMPLATE = """ -## context -{context} - ------ - -## format example -{example} - -## nodes: ": # " -{instruction} - -## constraint -{constraint} - -## action -Follow instructions of nodes, generate output and make sure it follows the format example. -""" - - -def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"): - markdown_str = "" - for key, value in d.items(): - markdown_str += f"{prefix}{key}{kv_sep}{value}{postfix}" - return markdown_str - - -class ActionNode: - """ActionNode is a tree of nodes.""" - - schema: str # raw/json/markdown, default: "" - - # Action Context - context: str # all the context, including all necessary info - llm: BaseLLM # LLM with aask interface - children: dict[str, "ActionNode"] - - # Action Input - key: str # Product Requirement / File list / Code - expected_type: Type # such as str / int / float etc. - # context: str # everything in the history. - instruction: str # the instructions should be followed. - example: Any # example for In Context-Learning. - - # Action Output - content: str - instruct_content: BaseModel - - def __init__( - self, - key: str, - expected_type: Type, - instruction: str, - example: Any, - content: str = "", - children: dict[str, "ActionNode"] = None, - schema: str = "", - ): - self.key = key - self.expected_type = expected_type - self.instruction = instruction - self.example = example - self.content = content - self.children = children if children is not None else {} - self.schema = schema - - def __str__(self): - return ( - f"{self.key}, {repr(self.expected_type)}, {self.instruction}, {self.example}" - f", {self.content}, {self.children}" - ) - - def __repr__(self): - return self.__str__() - - def add_child(self, node: "ActionNode"): - """增加子ActionNode""" - self.children[node.key] = node - - def add_children(self, nodes: List["ActionNode"]): - """批量增加子ActionNode""" - for node in nodes: - self.add_child(node) - - @classmethod - def from_children(cls, key, nodes: List["ActionNode"]): - """直接从一系列的子nodes初始化""" - obj = cls(key, str, "", "") - obj.add_children(nodes) - return obj - - def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]: - """获得子ActionNode的字典,以key索引""" - exclude = exclude or [] - return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude} - - def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]: - """get self key: type mapping""" - return {self.key: (self.expected_type, ...)} - - def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]: - """get key: type mapping under mode""" - if mode == "children" or (mode == "auto" and self.children): - return self.get_children_mapping(exclude=exclude) - return {} if exclude and self.key in exclude else self.get_self_mapping() - - @classmethod - def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]): - """基于pydantic v1的模型动态生成,用来检验结果类型正确性""" - - def check_fields(cls, values): - required_fields = set(mapping.keys()) - missing_fields = required_fields - set(values.keys()) - if missing_fields: - raise ValueError(f"Missing fields: {missing_fields}") - - unrecognized_fields = set(values.keys()) - required_fields - if unrecognized_fields: - logger.warning(f"Unrecognized fields: {unrecognized_fields}") - return values - - validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)} - - new_class = create_model(class_name, __validators__=validators, **mapping) - return new_class - - def create_children_class(self, exclude=None): - """使用object内有的字段直接生成model_class""" - class_name = f"{self.key}_AN" - mapping = self.get_children_mapping(exclude=exclude) - return self.create_model_class(class_name, mapping) - - def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict: - """将当前节点与子节点都按照node: format的格式组织成字典""" - - # 如果没有提供格式化函数,使用默认的格式化方式 - if format_func is None: - format_func = lambda node: f"{node.instruction}" - - # 使用提供的格式化函数来格式化当前节点的值 - formatted_value = format_func(self) - - # 创建当前节点的键值对 - if mode == "children" or (mode == "auto" and self.children): - node_dict = {} - else: - node_dict = {self.key: formatted_value} - - if mode == "root": - return node_dict - - # 遍历子节点并递归调用 to_dict 方法 - exclude = exclude or [] - for _, child_node in self.children.items(): - if child_node.key in exclude: - continue - node_dict.update(child_node.to_dict(format_func)) - - return node_dict - - def compile_to(self, i: Dict, schema, kv_sep) -> str: - if schema == "json": - return json.dumps(i, indent=4) - elif schema == "markdown": - return dict_to_markdown(i, kv_sep=kv_sep) - else: - return str(i) - - def tagging(self, text, schema, tag="") -> str: - if not tag: - return text - if schema == "json": - return f"[{tag}]\n" + text + f"\n[/{tag}]" - else: # markdown - return f"[{tag}]\n" + text + f"\n[/{tag}]" - - def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str: - nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude) - text = self.compile_to(nodes, schema, kv_sep) - return self.tagging(text, schema, tag) - - def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str: - """compile to raw/json/markdown template with all/root/children nodes""" - format_func = lambda i: f"{i.expected_type} # {i.instruction}" - return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude) - - def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str: - """compile to raw/json/markdown examples with all/root/children nodes""" - - # 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example - # 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str - format_func = lambda i: i.example - return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude) - - def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=[]) -> str: - """ - mode: all/root/children - mode="children": 编译所有子节点为一个统一模板,包括instruction与example - mode="all": NotImplemented - mode="root": NotImplemented - schmea: raw/json/markdown - schema="raw": 不编译,context, lang_constaint, instruction - schema="json":编译context, example(json), instruction(markdown), constraint, action - schema="markdown": 编译context, example(markdown), instruction(markdown), constraint, action - """ - if schema == "raw": - return context + "\n\n## Actions\n" + LANGUAGE_CONSTRAINT + "\n" + self.instruction - - # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线", - # compile example暂时不支持markdown - instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude) - example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude) - # nodes = ", ".join(self.to_dict(mode=mode).keys()) - constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT] - constraint = "\n".join(constraints) - - prompt = template.format( - context=context, - example=example, - instruction=instruction, - constraint=constraint, - ) - return prompt - - @retry( - wait=wait_random_exponential(min=1, max=20), - stop=stop_after_attempt(6), - after=general_after_log(logger), - ) - async def _aask_v1( - self, - prompt: str, - output_class_name: str, - output_data_mapping: dict, - system_msgs: Optional[list[str]] = None, - schema="markdown", # compatible to original format - timeout=CONFIG.timeout, - ) -> (str, BaseModel): - """Use ActionOutput to wrap the output of aask""" - content = await self.llm.aask(prompt, system_msgs, timeout=timeout) - logger.debug(f"llm raw output:\n{content}") - output_class = self.create_model_class(output_class_name, output_data_mapping) - - if schema == "json": - parsed_data = llm_output_postprocess( - output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]" - ) - else: # using markdown parser - parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) - - logger.debug(f"parsed_data:\n{parsed_data}") - instruct_content = output_class(**parsed_data) - return content, instruct_content - - def get(self, key): - return self.instruct_content.model_dump()[key] - - def set_recursive(self, name, value): - setattr(self, name, value) - for _, i in self.children.items(): - i.set_recursive(name, value) - - def set_llm(self, llm): - self.set_recursive("llm", llm) - - def set_context(self, context): - self.set_recursive("context", context) - - async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None): - prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude) - - if schema != "raw": - mapping = self.get_mapping(mode, exclude=exclude) - class_name = f"{self.key}_AN" - content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout) - self.content = content - self.instruct_content = scontent - else: - self.content = await self.llm.aask(prompt) - self.instruct_content = None - - return self - - async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]): - """Fill the node(s) with mode. - - :param context: Everything we should know when filling node. - :param llm: Large Language Model with pre-defined system message. - :param schema: json/markdown, determine example and output format. - - raw: free form text - - json: it's easy to open source LLM with json format - - markdown: when generating code, markdown is always better - :param mode: auto/children/root - - auto: automated fill children's nodes and gather outputs, if no children, fill itself - - children: fill children's nodes and gather outputs - - root: fill root's node and gather output - :param strgy: simple/complex - - simple: run only once - - complex: run each node - :param timeout: Timeout for llm invocation. - :param exclude: The keys of ActionNode to exclude. - :return: self - """ - self.set_llm(llm) - self.set_context(context) - if self.schema: - schema = self.schema - - if strgy == "simple": - return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) - elif strgy == "complex": - # 这里隐式假设了拥有children - tmp = {} - for _, i in self.children.items(): - if exclude and i.key in exclude: - continue - child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) - tmp.update(child.instruct_content.dict()) - cls = self.create_children_class() - self.instruct_content = cls(**tmp) - return self diff --git a/PR/repair_llm_raw_output.py b/PR/repair_llm_raw_output.py deleted file mode 100644 index 4995918c2..000000000 --- a/PR/repair_llm_raw_output.py +++ /dev/null @@ -1,351 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Desc : repair llm raw output with particular conditions - -import copy -from enum import Enum -from typing import Callable, Union - -import regex as re -from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed - -from metagpt.config import CONFIG -from metagpt.logs import logger -from metagpt.utils.custom_decoder import CustomDecoder - - -class RepairType(Enum): - CS = "case sensitivity" - RKPM = "required key pair missing" # condition like `[key] xx` which lacks `[/key]` - SCM = "special character missing" # Usually the req_key appear in pairs like `[key] xx [/key]` - JSON = "json format" - - -def repair_case_sensitivity(output: str, req_key: str) -> str: - """ - usually, req_key is the key name of expected json or markdown content, it won't appear in the value part. - fix target string `"Shared Knowledge": ""` but `"Shared knowledge": ""` actually - """ - if req_key in output: - return output - - output_lower = output.lower() - req_key_lower = req_key.lower() - if req_key_lower in output_lower: - # find the sub-part index, and replace it with raw req_key - lidx = output_lower.find(req_key_lower) - source = output[lidx : lidx + len(req_key_lower)] - output = output.replace(source, req_key) - logger.info(f"repair_case_sensitivity: {req_key}") - - return output - - -def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str: - """ - fix - 1. target string `[CONTENT] xx [CONTENT] xxx [CONTENT]` lacks `/` in the last `[CONTENT]` - 2. target string `xx [CONTENT] xxx [CONTENT] xxxx` lacks `/` in the last `[CONTENT]` - """ - sc_arr = ["/"] - - if req_key in output: - return output - - for sc in sc_arr: - req_key_pure = req_key.replace(sc, "") - appear_cnt = output.count(req_key_pure) - if req_key_pure in output and appear_cnt > 1: - # req_key with special_character usually in the tail side - ridx = output.rfind(req_key_pure) - output = f"{output[:ridx]}{req_key}{output[ridx + len(req_key_pure):]}" - logger.info(f"repair_special_character_missing: {sc} in {req_key_pure} as position {ridx}") - - return output - - -def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str: - """ - implement the req_key pair in the begin or end of the content - req_key format - 1. `[req_key]`, and its pair `[/req_key]` - 2. `[/req_key]`, and its pair `[req_key]` - """ - sc = "/" # special char - if req_key.startswith("[") and req_key.endswith("]"): - if sc in req_key: - left_key = req_key.replace(sc, "") # `[/req_key]` -> `[req_key]` - right_key = req_key - else: - left_key = req_key - right_key = f"{req_key[0]}{sc}{req_key[1:]}" # `[req_key]` -> `[/req_key]` - - if left_key not in output: - output = left_key + "\n" + output - if right_key not in output: - - def judge_potential_json(routput: str, left_key: str) -> Union[str, None]: - ridx = routput.rfind(left_key) - if ridx < 0: - return None - sub_output = routput[ridx:] - idx1 = sub_output.rfind("}") - idx2 = sub_output.rindex("]") - idx = idx1 if idx1 >= idx2 else idx2 - sub_output = sub_output[: idx + 1] - return sub_output - - if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)): - # # avoid [req_key]xx[req_key] case to append [/req_key] - output = output + "\n" + right_key - elif judge_potential_json(output, left_key) and (not output.strip().endswith(left_key)): - sub_content = judge_potential_json(output, left_key) - output = sub_content + "\n" + right_key - - return output - - -def repair_json_format(output: str) -> str: - """ - fix extra `[` or `}` in the end - """ - output = output.strip() - - if output.startswith("[{"): - output = output[1:] - logger.info(f"repair_json_format: {'[{'}") - elif output.endswith("}]"): - output = output[:-1] - logger.info(f"repair_json_format: {'}]'}") - elif output.startswith("{") and output.endswith("]"): - output = output[:-1] + "}" - - # remove comments in output json str, after json value content, maybe start with #, maybe start with // - arr = output.split("\n") - new_arr = [] - for line in arr: - # look for # or // comments and make sure they are not inside the string value - comment_index = -1 - for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): - if match.group(1): # if the string value - continue - if match.group(2): # if comments - comment_index = match.start(2) - break - # if comments, then delete them - if comment_index != -1: - line = line[:comment_index].rstrip() - new_arr.append(line) - output = "\n".join(new_arr) - return output - - -def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = None) -> str: - repair_types = [repair_type] if repair_type else [item for item in RepairType if item not in [RepairType.JSON]] - for repair_type in repair_types: - if repair_type == RepairType.CS: - output = repair_case_sensitivity(output, req_key) - elif repair_type == RepairType.RKPM: - output = repair_required_key_pair_missing(output, req_key) - elif repair_type == RepairType.SCM: - output = repair_special_character_missing(output, req_key) - elif repair_type == RepairType.JSON: - output = repair_json_format(output) - return output - - -def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str: - """ - in open-source llm model, it usually can't follow the instruction well, the output may be incomplete, - so here we try to repair it and use all repair methods by default. - typical case - 1. case sensitivity - target: "Original Requirements" - output: "Original requirements" - 2. special character missing - target: [/CONTENT] - output: [CONTENT] - 3. json format - target: { xxx } - output: { xxx }] - """ - if not CONFIG.repair_llm_output: - return output - - # do the repairation usually for non-openai models - for req_key in req_keys: - output = _repair_llm_raw_output(output=output, req_key=req_key, repair_type=repair_type) - return output - - -def repair_invalid_json(output: str, error: str) -> str: - """ - repair the situation like there are extra chars like - error examples - example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765) - example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266) - """ - pattern = r"line ([0-9]+) column ([0-9]+)" - - matches = re.findall(pattern, error, re.DOTALL) - if len(matches) > 0: - line_no = int(matches[0][0]) - 1 - col_no = int(matches[0][1]) - 1 - - # due to CustomDecoder can handle `"": ''` or `'': ""`, so convert `"""` -> `"`, `'''` -> `'` - output = output.replace('"""', '"').replace("'''", '"') - arr = output.split("\n") - rline = arr[line_no] # raw line - line = arr[line_no].strip() - # different general problems - if line.endswith("],"): - # problem, redundant char `]` - new_line = line.replace("]", "") - elif line.endswith("},") and not output.endswith("},"): - # problem, redundant char `}` - new_line = line.replace("}", "") - elif line.endswith("},") and output.endswith("},"): - new_line = line[:-1] - # remove comments in output json str, after json value content, maybe start with #, maybe start with // - elif rline[col_no] == "#" or rline[col_no] == "/": - new_line = rline[:col_no] - for i in range(line_no + 1, len(arr)): - # look for # or // comments and make sure they are not inside the string value - comment_index = -1 - for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): - if match.group(1): # if the string value - continue - if match.group(2): # if comments - comment_index = match.start(2) - break - # if comments, then delete them - if comment_index != -1: - arr[i] = arr[i][:comment_index].rstrip() - elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: - # problem, `"""` or `'''` without `,` - new_line = f",{line}" - elif '",' not in line and "," not in line and '"' not in line: - new_line = f'{line}",' - elif not line.endswith(","): - # problem, miss char `,` at the end. - new_line = f"{line}," - elif "," in line and len(line) == 1: - new_line = f'"{line}' - elif '",' in line: - new_line = line[:-2] + "'," - else: - new_line = line - - arr[line_no] = new_line - output = "\n".join(arr) - logger.info(f"repair_invalid_json, raw error: {error}") - - return output - - -def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]: - def run_and_passon(retry_state: RetryCallState) -> None: - """ - RetryCallState example - { - "start_time":143.098322024, - "retry_object":")>", - "fn":"", - "args":"(\"tag:[/CONTENT]\",)", # function input args - "kwargs":{}, # function input kwargs - "attempt_number":1, # retry number - "outcome":"", # type(outcome.result()) = "str", type(outcome.exception()) = "class" - "outcome_timestamp":143.098416904, - "idle_for":0, - "next_action":"None" - } - """ - if retry_state.outcome.failed: - if retry_state.args: - # # can't be used as args=retry_state.args - func_param_output = retry_state.args[0] - elif retry_state.kwargs: - func_param_output = retry_state.kwargs.get("output", "") - exp_str = str(retry_state.outcome.exception()) - - fix_str = "try to fix it, " if CONFIG.repair_llm_output else "" - logger.warning( - f"parse json from content inside [CONTENT][/CONTENT] failed at retry " - f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}" - ) - - repaired_output = repair_invalid_json(func_param_output, exp_str) - retry_state.kwargs["output"] = repaired_output - - return run_and_passon - - -@retry( - stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0), - wait=wait_fixed(1), - after=run_after_exp_and_passon_next_retry(logger), -) -def retry_parse_json_text(output: str) -> Union[list, dict]: - """ - repair the json-text situation like there are extra chars like [']', '}'] - - Warning - if CONFIG.repair_llm_output is False, retry _aask_v1 {x=3} times, and the retry_parse_json_text's retry not work - if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times. - it's a two-layer retry cycle - """ - # logger.debug(f"output to json decode:\n{output}") - - # if CONFIG.repair_llm_output is True, it will try to fix output until the retry break - parsed_data = CustomDecoder(strict=False).decode(output) - - return parsed_data - - -def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"): - """extract xxx from [CONTENT](xxx)[/CONTENT] using regex pattern""" - - def re_extract_content(cont: str, pattern: str) -> str: - matches = re.findall(pattern, cont, re.DOTALL) - for match in matches: - if match: - cont = match - break - return cont.strip() - - # TODO construct the extract pattern with the `right_key` - raw_content = copy.deepcopy(content) - pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]" - new_content = re_extract_content(raw_content, pattern) - - if not new_content.startswith("{"): - # TODO find a more general pattern - # # for `[CONTENT]xxx[CONTENT]xxxx[/CONTENT] situation - logger.warning(f"extract_content try another pattern: {pattern}") - if right_key not in new_content: - raw_content = copy.deepcopy(new_content + "\n" + right_key) - # # pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]" - new_content = re_extract_content(raw_content, pattern) - else: - if right_key in new_content: - idx = new_content.find(right_key) - new_content = new_content[:idx] - new_content = new_content.strip() - - return new_content - - -def extract_state_value_from_output(content: str) -> str: - """ - For openai models, they will always return state number. But for open llm models, the instruction result maybe a - long text contain target number, so here add a extraction to improve success rate. - - Args: - content (str): llm's output from `Role._think` - """ - content = content.strip() # deal the output cases like " 0", "0\n" and so on. - pattern = r"([0-9])" # TODO find the number using a more proper method not just extract from content using pattern - matches = re.findall(pattern, content, re.DOTALL) - matches = list(set(matches)) - state = matches[0] if len(matches) > 0 else "-1" - return state From 11f70ca9b1714dc312a847505c9940f0c60a24b1 Mon Sep 17 00:00:00 2001 From: huzixia <528543747@qq.com> Date: Sat, 27 Jan 2024 18:06:52 +0800 Subject: [PATCH 6/7] modify code based on feedback of action_node.py and repair_llm_raw_output.py, add code in test_repair_llm_raw_output.py --- metagpt/actions/action_node.py | 5 +--- metagpt/utils/repair_llm_raw_output.py | 9 +------ .../utils/test_repair_llm_raw_output.py | 26 +++++++++++++++++++ 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index ed0e27869..6c65b33ef 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -23,10 +23,7 @@ from metagpt.utils.common import OutputParser, general_after_log TAG = "CONTENT" LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT." -FORMAT_CONSTRAINT = (f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else. " - f"Delete comments in json") -# Delete comments in json -# If you don't want JSONDecodeError to occur, you can add Delete comments in json after FORMAT_CONSTRAINT +FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else." SIMPLE_TEMPLATE = """ diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index ef3580750..973cffb8a 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -137,11 +137,10 @@ def repair_json_format(output: str) -> str: elif output.startswith("{") and output.endswith("]"): output = output[:-1] + "}" - # remove comments in output json str, after json value content, maybe start with #, maybe start with // + # remove comments in output json string arr = output.split("\n") new_arr = [] for line in arr: - # look for # or // comments and make sure they are not inside the string value new_line = remove_comments_from_line(line) new_arr.append(new_line) output = "\n".join(new_arr) @@ -214,12 +213,6 @@ def repair_invalid_json(output: str, error: str) -> str: new_line = line.replace("}", "") elif line.endswith("},") and output.endswith("},"): new_line = line[:-1] - # remove comments in output json str, after json value content, maybe start with #, maybe start with // - elif rline[col_no] == "#" or rline[col_no] == "/": - new_line = rline[:col_no] - # check the next line and remove the comments - for i in range(line_no + 1, len(arr)): - arr[i] = remove_comments_from_line(arr[i]) elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: # problem, `"""` or `'''` without `,` new_line = f",{line}" diff --git a/tests/metagpt/utils/test_repair_llm_raw_output.py b/tests/metagpt/utils/test_repair_llm_raw_output.py index 1f809a081..9d53b8243 100644 --- a/tests/metagpt/utils/test_repair_llm_raw_output.py +++ b/tests/metagpt/utils/test_repair_llm_raw_output.py @@ -141,6 +141,32 @@ def test_repair_json_format(): output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON) assert output == target_output + raw_output = """ +{ + "Language": "en_us", // define language + "Programming Language": "Python" # define code language +} +""" + target_output = """{ + "Language": "en_us", + "Programming Language": "Python" +}""" + output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON) + assert output == target_output + + raw_output = """ + { + "Language": "#en_us#", // define language + "Programming Language": "//Python # Code // Language//" # define code language + } + """ + target_output = """{ + "Language": "#en_us#", + "Programming Language": "//Python # Code // Language//" + }""" + output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON) + assert output == target_output + def test_repair_invalid_json(): from metagpt.utils.repair_llm_raw_output import repair_invalid_json From c3b4c698d80cba70e446cd6a97f375459c8c5595 Mon Sep 17 00:00:00 2001 From: huzixia <528543747@qq.com> Date: Sat, 27 Jan 2024 18:23:57 +0800 Subject: [PATCH 7/7] update repair_llm_raw_output.py --- metagpt/utils/repair_llm_raw_output.py | 36 ++++++++++---------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index 973cffb8a..6da974d96 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -105,23 +105,6 @@ def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") - return output -def remove_comments_from_line(line): - """ - Remove comments from a single line of string. - Comments are assumed to start with '#' or '//' and are not inside string values. - """ - comment_index = -1 - for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): - if match.group(1): # if the string value - continue - if match.group(2): # if comments - comment_index = match.start(2) - break - if comment_index != -1: # if comments, then delete them - return line[:comment_index].rstrip() - return line - - def repair_json_format(output: str) -> str: """ fix extra `[` or `}` in the end @@ -136,13 +119,22 @@ def repair_json_format(output: str) -> str: logger.info(f"repair_json_format: {'}]'}") elif output.startswith("{") and output.endswith("]"): output = output[:-1] + "}" - - # remove comments in output json string + # remove comments in output json string, after json value content, maybe start with #, maybe start with // arr = output.split("\n") new_arr = [] - for line in arr: - new_line = remove_comments_from_line(line) - new_arr.append(new_line) + for json_line in arr: + # look for # or // comments and make sure they are not inside the string value + comment_index = -1 + for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", json_line): + if match.group(1): # if the string value + continue + if match.group(2): # if comments + comment_index = match.start(2) + break + # if comments, then delete them + if comment_index != -1: + json_line = json_line[:comment_index].rstrip() + new_arr.append(json_line) output = "\n".join(new_arr) return output