From 91c5c7208936bacfbeac5938b0a01e1f7ee47c12 Mon Sep 17 00:00:00 2001 From: Arnaud Gelas Date: Mon, 22 Jan 2024 20:10:11 +0100 Subject: [PATCH 1/9] Fix prompt logic when defining to whom the message should be sent. With the previous logic, it was possible to reach an undefined state where the message was meant to be sent to neither Engineer, QaEngineer, nor NoOne. --- metagpt/actions/run_code.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index 30b06f1a6..885f4e12c 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -42,8 +42,8 @@ Determine the ONE file to rewrite in order to fix the error, for example, xyz.py Determine if all of the code works fine, if so write PASS, else FAIL, WRITE ONLY ONE WORD, PASS OR FAIL, IN THIS SECTION ## Send To: -Please write Engineer if the errors are due to problematic development codes, and QaEngineer to problematic test codes, and NoOne if there are no errors, -WRITE ONLY ONE WORD, Engineer OR QaEngineer OR NoOne, IN THIS SECTION. +Please write NoOne if there are no errors, Engineer if the errors are due to problematic development codes, else QaEngineer, +WRITE ONLY ONE WORD, NoOne OR Engineer OR QaEngineer, IN THIS SECTION. --- You should fill in necessary instruction, status, send to, and finally return all content between the --- segment line. 
""" From ed54f6b86a7000de1487472a8900138933ac6c42 Mon Sep 17 00:00:00 2001 From: huzixia <528543747@qq.com> Date: Fri, 26 Jan 2024 22:59:10 +0800 Subject: [PATCH 2/9] To avoid JSONDecodeError: Remove comments in output json str, after json value content, maybe start with #, maybe start with //, particularly, it is not inside the string value Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py --- metagpt/actions/action_node.py | 2 ++ metagpt/utils/repair_llm_raw_output.py | 31 ++++++++++++++++++++++---- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 6c65b33ef..0f441cfee 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -24,6 +24,8 @@ TAG = "CONTENT" LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT." FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else." 
+# Delete comments in json +# If you don't want JSONDecodeError to occur, you can add Delete comments in json after FORMAT_CONSTRAINT SIMPLE_TEMPLATE = """ diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index b71def136..4995918c2 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -120,13 +120,21 @@ def repair_json_format(output: str) -> str: elif output.startswith("{") and output.endswith("]"): output = output[:-1] + "}" - # remove `#` in output json str, usually appeared in `glm-4` + # remove comments in output json str, after json value content, maybe start with #, maybe start with // arr = output.split("\n") new_arr = [] for line in arr: - idx = line.find("#") - if idx >= 0: - line = line[:idx] + # look for # or // comments and make sure they are not inside the string value + comment_index = -1 + for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): + if match.group(1): # if the string value + continue + if match.group(2): # if comments + comment_index = match.start(2) + break + # if comments, then delete them + if comment_index != -1: + line = line[:comment_index].rstrip() new_arr.append(line) output = "\n".join(new_arr) return output @@ -198,6 +206,21 @@ def repair_invalid_json(output: str, error: str) -> str: new_line = line.replace("}", "") elif line.endswith("},") and output.endswith("},"): new_line = line[:-1] + # remove comments in output json str, after json value content, maybe start with #, maybe start with // + elif rline[col_no] == "#" or rline[col_no] == "/": + new_line = rline[:col_no] + for i in range(line_no + 1, len(arr)): + # look for # or // comments and make sure they are not inside the string value + comment_index = -1 + for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): + if match.group(1): # if the string value + continue + if match.group(2): # if comments + comment_index = match.start(2) + break + # if comments, then delete them 
+ if comment_index != -1: + arr[i] = arr[i][:comment_index].rstrip() elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: # problem, `"""` or `'''` without `,` new_line = f",{line}" From 43b069f453d0ea351ba31b918b4fcb8bae5863e0 Mon Sep 17 00:00:00 2001 From: huzixia <528543747@qq.com> Date: Fri, 26 Jan 2024 23:20:16 +0800 Subject: [PATCH 3/9] Additionally, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py --- metagpt/actions/action_node.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 0f441cfee..ed0e27869 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -23,7 +23,8 @@ from metagpt.utils.common import OutputParser, general_after_log TAG = "CONTENT" LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT." -FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else." +FORMAT_CONSTRAINT = (f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else. 
" + f"Delete comments in json") # Delete comments in json # If you don't want JSONDecodeError to occur, you can add Delete comments in json after FORMAT_CONSTRAINT From f16b24758692bd10d89069c13a1260f9d15c968c Mon Sep 17 00:00:00 2001 From: huzixia <528543747@qq.com> Date: Sat, 27 Jan 2024 15:32:12 +0800 Subject: [PATCH 4/9] merge code with similar logic to avoid duplication --- ...move comments in output json str, after js | 12 + PR/action_node.py | 351 ++++++++++++++++++ PR/repair_llm_raw_output.py | 351 ++++++++++++++++++ metagpt/utils/repair_llm_raw_output.py | 43 ++- 4 files changed, 735 insertions(+), 22 deletions(-) create mode 100644 PR/# remove comments in output json str, after js create mode 100644 PR/action_node.py create mode 100644 PR/repair_llm_raw_output.py diff --git a/PR/# remove comments in output json str, after js b/PR/# remove comments in output json str, after js new file mode 100644 index 000000000..f795fefdb --- /dev/null +++ b/PR/# remove comments in output json str, after js @@ -0,0 +1,12 @@ + +git commit -m "To avoid JSONDecodeError: " -m "Remove comments in output json str, after json value content, maybe start with #, maybe start with //, particularly, it is not inside the string value" -m "Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py" + + + +git commit -m "Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py" + + + + + + diff --git a/PR/action_node.py b/PR/action_node.py new file mode 100644 index 000000000..0f441cfee --- /dev/null +++ b/PR/action_node.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/11 18:45 +@Author : alexanderwu +@File : action_node.py + +NOTE: You should use typing.List instead of list to do type annotation. 
Because in the markdown extraction process, + we can use typing to extract the type of the node, but we cannot use built-in list to extract. +""" +import json +from typing import Any, Dict, List, Optional, Tuple, Type + +from pydantic import BaseModel, create_model, model_validator +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.config import CONFIG +from metagpt.llm import BaseLLM +from metagpt.logs import logger +from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess +from metagpt.utils.common import OutputParser, general_after_log + +TAG = "CONTENT" + +LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT." +FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else." +# Delete comments in json +# If you don't want JSONDecodeError to occur, you can add Delete comments in json after FORMAT_CONSTRAINT + + +SIMPLE_TEMPLATE = """ +## context +{context} + +----- + +## format example +{example} + +## nodes: ": # " +{instruction} + +## constraint +{constraint} + +## action +Follow instructions of nodes, generate output and make sure it follows the format example. +""" + + +def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"): + markdown_str = "" + for key, value in d.items(): + markdown_str += f"{prefix}{key}{kv_sep}{value}{postfix}" + return markdown_str + + +class ActionNode: + """ActionNode is a tree of nodes.""" + + schema: str # raw/json/markdown, default: "" + + # Action Context + context: str # all the context, including all necessary info + llm: BaseLLM # LLM with aask interface + children: dict[str, "ActionNode"] + + # Action Input + key: str # Product Requirement / File list / Code + expected_type: Type # such as str / int / float etc. + # context: str # everything in the history. + instruction: str # the instructions should be followed. + example: Any # example for In Context-Learning. 
+ + # Action Output + content: str + instruct_content: BaseModel + + def __init__( + self, + key: str, + expected_type: Type, + instruction: str, + example: Any, + content: str = "", + children: dict[str, "ActionNode"] = None, + schema: str = "", + ): + self.key = key + self.expected_type = expected_type + self.instruction = instruction + self.example = example + self.content = content + self.children = children if children is not None else {} + self.schema = schema + + def __str__(self): + return ( + f"{self.key}, {repr(self.expected_type)}, {self.instruction}, {self.example}" + f", {self.content}, {self.children}" + ) + + def __repr__(self): + return self.__str__() + + def add_child(self, node: "ActionNode"): + """增加子ActionNode""" + self.children[node.key] = node + + def add_children(self, nodes: List["ActionNode"]): + """批量增加子ActionNode""" + for node in nodes: + self.add_child(node) + + @classmethod + def from_children(cls, key, nodes: List["ActionNode"]): + """直接从一系列的子nodes初始化""" + obj = cls(key, str, "", "") + obj.add_children(nodes) + return obj + + def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]: + """获得子ActionNode的字典,以key索引""" + exclude = exclude or [] + return {k: (v.expected_type, ...) 
for k, v in self.children.items() if k not in exclude} + + def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]: + """get self key: type mapping""" + return {self.key: (self.expected_type, ...)} + + def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]: + """get key: type mapping under mode""" + if mode == "children" or (mode == "auto" and self.children): + return self.get_children_mapping(exclude=exclude) + return {} if exclude and self.key in exclude else self.get_self_mapping() + + @classmethod + def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]): + """基于pydantic v1的模型动态生成,用来检验结果类型正确性""" + + def check_fields(cls, values): + required_fields = set(mapping.keys()) + missing_fields = required_fields - set(values.keys()) + if missing_fields: + raise ValueError(f"Missing fields: {missing_fields}") + + unrecognized_fields = set(values.keys()) - required_fields + if unrecognized_fields: + logger.warning(f"Unrecognized fields: {unrecognized_fields}") + return values + + validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)} + + new_class = create_model(class_name, __validators__=validators, **mapping) + return new_class + + def create_children_class(self, exclude=None): + """使用object内有的字段直接生成model_class""" + class_name = f"{self.key}_AN" + mapping = self.get_children_mapping(exclude=exclude) + return self.create_model_class(class_name, mapping) + + def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict: + """将当前节点与子节点都按照node: format的格式组织成字典""" + + # 如果没有提供格式化函数,使用默认的格式化方式 + if format_func is None: + format_func = lambda node: f"{node.instruction}" + + # 使用提供的格式化函数来格式化当前节点的值 + formatted_value = format_func(self) + + # 创建当前节点的键值对 + if mode == "children" or (mode == "auto" and self.children): + node_dict = {} + else: + node_dict = {self.key: formatted_value} + + if mode == "root": + return node_dict + + # 遍历子节点并递归调用 to_dict 方法 + exclude = exclude or [] + 
for _, child_node in self.children.items(): + if child_node.key in exclude: + continue + node_dict.update(child_node.to_dict(format_func)) + + return node_dict + + def compile_to(self, i: Dict, schema, kv_sep) -> str: + if schema == "json": + return json.dumps(i, indent=4) + elif schema == "markdown": + return dict_to_markdown(i, kv_sep=kv_sep) + else: + return str(i) + + def tagging(self, text, schema, tag="") -> str: + if not tag: + return text + if schema == "json": + return f"[{tag}]\n" + text + f"\n[/{tag}]" + else: # markdown + return f"[{tag}]\n" + text + f"\n[/{tag}]" + + def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str: + nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude) + text = self.compile_to(nodes, schema, kv_sep) + return self.tagging(text, schema, tag) + + def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str: + """compile to raw/json/markdown template with all/root/children nodes""" + format_func = lambda i: f"{i.expected_type} # {i.instruction}" + return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude) + + def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str: + """compile to raw/json/markdown examples with all/root/children nodes""" + + # 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example + # 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str + format_func = lambda i: i.example + return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude) + + def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=[]) -> str: + """ + mode: all/root/children + mode="children": 编译所有子节点为一个统一模板,包括instruction与example + mode="all": NotImplemented + mode="root": NotImplemented + schmea: raw/json/markdown + schema="raw": 不编译,context, lang_constaint, instruction + schema="json":编译context, example(json), instruction(markdown), 
constraint, action + schema="markdown": 编译context, example(markdown), instruction(markdown), constraint, action + """ + if schema == "raw": + return context + "\n\n## Actions\n" + LANGUAGE_CONSTRAINT + "\n" + self.instruction + + # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线", + # compile example暂时不支持markdown + instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude) + example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude) + # nodes = ", ".join(self.to_dict(mode=mode).keys()) + constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT] + constraint = "\n".join(constraints) + + prompt = template.format( + context=context, + example=example, + instruction=instruction, + constraint=constraint, + ) + return prompt + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _aask_v1( + self, + prompt: str, + output_class_name: str, + output_data_mapping: dict, + system_msgs: Optional[list[str]] = None, + schema="markdown", # compatible to original format + timeout=CONFIG.timeout, + ) -> (str, BaseModel): + """Use ActionOutput to wrap the output of aask""" + content = await self.llm.aask(prompt, system_msgs, timeout=timeout) + logger.debug(f"llm raw output:\n{content}") + output_class = self.create_model_class(output_class_name, output_data_mapping) + + if schema == "json": + parsed_data = llm_output_postprocess( + output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]" + ) + else: # using markdown parser + parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) + + logger.debug(f"parsed_data:\n{parsed_data}") + instruct_content = output_class(**parsed_data) + return content, instruct_content + + def get(self, key): + return self.instruct_content.model_dump()[key] + + def set_recursive(self, name, value): + setattr(self, name, value) + for _, i in 
self.children.items(): + i.set_recursive(name, value) + + def set_llm(self, llm): + self.set_recursive("llm", llm) + + def set_context(self, context): + self.set_recursive("context", context) + + async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None): + prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude) + + if schema != "raw": + mapping = self.get_mapping(mode, exclude=exclude) + class_name = f"{self.key}_AN" + content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout) + self.content = content + self.instruct_content = scontent + else: + self.content = await self.llm.aask(prompt) + self.instruct_content = None + + return self + + async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]): + """Fill the node(s) with mode. + + :param context: Everything we should know when filling node. + :param llm: Large Language Model with pre-defined system message. + :param schema: json/markdown, determine example and output format. + - raw: free form text + - json: it's easy to open source LLM with json format + - markdown: when generating code, markdown is always better + :param mode: auto/children/root + - auto: automated fill children's nodes and gather outputs, if no children, fill itself + - children: fill children's nodes and gather outputs + - root: fill root's node and gather output + :param strgy: simple/complex + - simple: run only once + - complex: run each node + :param timeout: Timeout for llm invocation. + :param exclude: The keys of ActionNode to exclude. 
+ :return: self + """ + self.set_llm(llm) + self.set_context(context) + if self.schema: + schema = self.schema + + if strgy == "simple": + return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) + elif strgy == "complex": + # 这里隐式假设了拥有children + tmp = {} + for _, i in self.children.items(): + if exclude and i.key in exclude: + continue + child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) + tmp.update(child.instruct_content.dict()) + cls = self.create_children_class() + self.instruct_content = cls(**tmp) + return self diff --git a/PR/repair_llm_raw_output.py b/PR/repair_llm_raw_output.py new file mode 100644 index 000000000..4995918c2 --- /dev/null +++ b/PR/repair_llm_raw_output.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : repair llm raw output with particular conditions + +import copy +from enum import Enum +from typing import Callable, Union + +import regex as re +from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed + +from metagpt.config import CONFIG +from metagpt.logs import logger +from metagpt.utils.custom_decoder import CustomDecoder + + +class RepairType(Enum): + CS = "case sensitivity" + RKPM = "required key pair missing" # condition like `[key] xx` which lacks `[/key]` + SCM = "special character missing" # Usually the req_key appear in pairs like `[key] xx [/key]` + JSON = "json format" + + +def repair_case_sensitivity(output: str, req_key: str) -> str: + """ + usually, req_key is the key name of expected json or markdown content, it won't appear in the value part. 
+ fix target string `"Shared Knowledge": ""` but `"Shared knowledge": ""` actually + """ + if req_key in output: + return output + + output_lower = output.lower() + req_key_lower = req_key.lower() + if req_key_lower in output_lower: + # find the sub-part index, and replace it with raw req_key + lidx = output_lower.find(req_key_lower) + source = output[lidx : lidx + len(req_key_lower)] + output = output.replace(source, req_key) + logger.info(f"repair_case_sensitivity: {req_key}") + + return output + + +def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str: + """ + fix + 1. target string `[CONTENT] xx [CONTENT] xxx [CONTENT]` lacks `/` in the last `[CONTENT]` + 2. target string `xx [CONTENT] xxx [CONTENT] xxxx` lacks `/` in the last `[CONTENT]` + """ + sc_arr = ["/"] + + if req_key in output: + return output + + for sc in sc_arr: + req_key_pure = req_key.replace(sc, "") + appear_cnt = output.count(req_key_pure) + if req_key_pure in output and appear_cnt > 1: + # req_key with special_character usually in the tail side + ridx = output.rfind(req_key_pure) + output = f"{output[:ridx]}{req_key}{output[ridx + len(req_key_pure):]}" + logger.info(f"repair_special_character_missing: {sc} in {req_key_pure} as position {ridx}") + + return output + + +def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str: + """ + implement the req_key pair in the begin or end of the content + req_key format + 1. `[req_key]`, and its pair `[/req_key]` + 2. 
`[/req_key]`, and its pair `[req_key]` + """ + sc = "/" # special char + if req_key.startswith("[") and req_key.endswith("]"): + if sc in req_key: + left_key = req_key.replace(sc, "") # `[/req_key]` -> `[req_key]` + right_key = req_key + else: + left_key = req_key + right_key = f"{req_key[0]}{sc}{req_key[1:]}" # `[req_key]` -> `[/req_key]` + + if left_key not in output: + output = left_key + "\n" + output + if right_key not in output: + + def judge_potential_json(routput: str, left_key: str) -> Union[str, None]: + ridx = routput.rfind(left_key) + if ridx < 0: + return None + sub_output = routput[ridx:] + idx1 = sub_output.rfind("}") + idx2 = sub_output.rindex("]") + idx = idx1 if idx1 >= idx2 else idx2 + sub_output = sub_output[: idx + 1] + return sub_output + + if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)): + # # avoid [req_key]xx[req_key] case to append [/req_key] + output = output + "\n" + right_key + elif judge_potential_json(output, left_key) and (not output.strip().endswith(left_key)): + sub_content = judge_potential_json(output, left_key) + output = sub_content + "\n" + right_key + + return output + + +def repair_json_format(output: str) -> str: + """ + fix extra `[` or `}` in the end + """ + output = output.strip() + + if output.startswith("[{"): + output = output[1:] + logger.info(f"repair_json_format: {'[{'}") + elif output.endswith("}]"): + output = output[:-1] + logger.info(f"repair_json_format: {'}]'}") + elif output.startswith("{") and output.endswith("]"): + output = output[:-1] + "}" + + # remove comments in output json str, after json value content, maybe start with #, maybe start with // + arr = output.split("\n") + new_arr = [] + for line in arr: + # look for # or // comments and make sure they are not inside the string value + comment_index = -1 + for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): + if match.group(1): # if the string value + continue + if match.group(2): # 
if comments + comment_index = match.start(2) + break + # if comments, then delete them + if comment_index != -1: + line = line[:comment_index].rstrip() + new_arr.append(line) + output = "\n".join(new_arr) + return output + + +def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = None) -> str: + repair_types = [repair_type] if repair_type else [item for item in RepairType if item not in [RepairType.JSON]] + for repair_type in repair_types: + if repair_type == RepairType.CS: + output = repair_case_sensitivity(output, req_key) + elif repair_type == RepairType.RKPM: + output = repair_required_key_pair_missing(output, req_key) + elif repair_type == RepairType.SCM: + output = repair_special_character_missing(output, req_key) + elif repair_type == RepairType.JSON: + output = repair_json_format(output) + return output + + +def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str: + """ + in open-source llm model, it usually can't follow the instruction well, the output may be incomplete, + so here we try to repair it and use all repair methods by default. + typical case + 1. case sensitivity + target: "Original Requirements" + output: "Original requirements" + 2. special character missing + target: [/CONTENT] + output: [CONTENT] + 3. json format + target: { xxx } + output: { xxx }] + """ + if not CONFIG.repair_llm_output: + return output + + # do the repairation usually for non-openai models + for req_key in req_keys: + output = _repair_llm_raw_output(output=output, req_key=req_key, repair_type=repair_type) + return output + + +def repair_invalid_json(output: str, error: str) -> str: + """ + repair the situation like there are extra chars like + error examples + example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765) + example 2. 
xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266) + """ + pattern = r"line ([0-9]+) column ([0-9]+)" + + matches = re.findall(pattern, error, re.DOTALL) + if len(matches) > 0: + line_no = int(matches[0][0]) - 1 + col_no = int(matches[0][1]) - 1 + + # due to CustomDecoder can handle `"": ''` or `'': ""`, so convert `"""` -> `"`, `'''` -> `'` + output = output.replace('"""', '"').replace("'''", '"') + arr = output.split("\n") + rline = arr[line_no] # raw line + line = arr[line_no].strip() + # different general problems + if line.endswith("],"): + # problem, redundant char `]` + new_line = line.replace("]", "") + elif line.endswith("},") and not output.endswith("},"): + # problem, redundant char `}` + new_line = line.replace("}", "") + elif line.endswith("},") and output.endswith("},"): + new_line = line[:-1] + # remove comments in output json str, after json value content, maybe start with #, maybe start with // + elif rline[col_no] == "#" or rline[col_no] == "/": + new_line = rline[:col_no] + for i in range(line_no + 1, len(arr)): + # look for # or // comments and make sure they are not inside the string value + comment_index = -1 + for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): + if match.group(1): # if the string value + continue + if match.group(2): # if comments + comment_index = match.start(2) + break + # if comments, then delete them + if comment_index != -1: + arr[i] = arr[i][:comment_index].rstrip() + elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: + # problem, `"""` or `'''` without `,` + new_line = f",{line}" + elif '",' not in line and "," not in line and '"' not in line: + new_line = f'{line}",' + elif not line.endswith(","): + # problem, miss char `,` at the end. 
+ new_line = f"{line}," + elif "," in line and len(line) == 1: + new_line = f'"{line}' + elif '",' in line: + new_line = line[:-2] + "'," + else: + new_line = line + + arr[line_no] = new_line + output = "\n".join(arr) + logger.info(f"repair_invalid_json, raw error: {error}") + + return output + + +def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]: + def run_and_passon(retry_state: RetryCallState) -> None: + """ + RetryCallState example + { + "start_time":143.098322024, + "retry_object":")>", + "fn":"", + "args":"(\"tag:[/CONTENT]\",)", # function input args + "kwargs":{}, # function input kwargs + "attempt_number":1, # retry number + "outcome":"", # type(outcome.result()) = "str", type(outcome.exception()) = "class" + "outcome_timestamp":143.098416904, + "idle_for":0, + "next_action":"None" + } + """ + if retry_state.outcome.failed: + if retry_state.args: + # # can't be used as args=retry_state.args + func_param_output = retry_state.args[0] + elif retry_state.kwargs: + func_param_output = retry_state.kwargs.get("output", "") + exp_str = str(retry_state.outcome.exception()) + + fix_str = "try to fix it, " if CONFIG.repair_llm_output else "" + logger.warning( + f"parse json from content inside [CONTENT][/CONTENT] failed at retry " + f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}" + ) + + repaired_output = repair_invalid_json(func_param_output, exp_str) + retry_state.kwargs["output"] = repaired_output + + return run_and_passon + + +@retry( + stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0), + wait=wait_fixed(1), + after=run_after_exp_and_passon_next_retry(logger), +) +def retry_parse_json_text(output: str) -> Union[list, dict]: + """ + repair the json-text situation like there are extra chars like [']', '}'] + + Warning + if CONFIG.repair_llm_output is False, retry _aask_v1 {x=3} times, and the retry_parse_json_text's retry not work + if CONFIG.repair_llm_output is True, the _aask_v1 and the 
retry_parse_json_text will loop for {x=3*3} times. + it's a two-layer retry cycle + """ + # logger.debug(f"output to json decode:\n{output}") + + # if CONFIG.repair_llm_output is True, it will try to fix output until the retry break + parsed_data = CustomDecoder(strict=False).decode(output) + + return parsed_data + + +def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"): + """extract xxx from [CONTENT](xxx)[/CONTENT] using regex pattern""" + + def re_extract_content(cont: str, pattern: str) -> str: + matches = re.findall(pattern, cont, re.DOTALL) + for match in matches: + if match: + cont = match + break + return cont.strip() + + # TODO construct the extract pattern with the `right_key` + raw_content = copy.deepcopy(content) + pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]" + new_content = re_extract_content(raw_content, pattern) + + if not new_content.startswith("{"): + # TODO find a more general pattern + # # for `[CONTENT]xxx[CONTENT]xxxx[/CONTENT] situation + logger.warning(f"extract_content try another pattern: {pattern}") + if right_key not in new_content: + raw_content = copy.deepcopy(new_content + "\n" + right_key) + # # pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]" + new_content = re_extract_content(raw_content, pattern) + else: + if right_key in new_content: + idx = new_content.find(right_key) + new_content = new_content[:idx] + new_content = new_content.strip() + + return new_content + + +def extract_state_value_from_output(content: str) -> str: + """ + For openai models, they will always return state number. But for open llm models, the instruction result maybe a + long text contain target number, so here add a extraction to improve success rate. + + Args: + content (str): llm's output from `Role._think` + """ + content = content.strip() # deal the output cases like " 0", "0\n" and so on. 
+ pattern = r"([0-9])" # TODO find the number using a more proper method not just extract from content using pattern + matches = re.findall(pattern, content, re.DOTALL) + matches = list(set(matches)) + state = matches[0] if len(matches) > 0 else "-1" + return state diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index 4995918c2..ef3580750 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -105,6 +105,23 @@ def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") - return output +def remove_comments_from_line(line): + """ + Remove comments from a single line of string. + Comments are assumed to start with '#' or '//' and are not inside string values. + """ + comment_index = -1 + for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): + if match.group(1): # if the string value + continue + if match.group(2): # if comments + comment_index = match.start(2) + break + if comment_index != -1: # if comments, then delete them + return line[:comment_index].rstrip() + return line + + def repair_json_format(output: str) -> str: """ fix extra `[` or `}` in the end @@ -125,17 +142,8 @@ def repair_json_format(output: str) -> str: new_arr = [] for line in arr: # look for # or // comments and make sure they are not inside the string value - comment_index = -1 - for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): - if match.group(1): # if the string value - continue - if match.group(2): # if comments - comment_index = match.start(2) - break - # if comments, then delete them - if comment_index != -1: - line = line[:comment_index].rstrip() - new_arr.append(line) + new_line = remove_comments_from_line(line) + new_arr.append(new_line) output = "\n".join(new_arr) return output @@ -209,18 +217,9 @@ def repair_invalid_json(output: str, error: str) -> str: # remove comments in output json str, after json value content, maybe start with #, maybe start with // elif 
rline[col_no] == "#" or rline[col_no] == "/": new_line = rline[:col_no] + # check the next line and remove the comments for i in range(line_no + 1, len(arr)): - # look for # or // comments and make sure they are not inside the string value - comment_index = -1 - for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): - if match.group(1): # if the string value - continue - if match.group(2): # if comments - comment_index = match.start(2) - break - # if comments, then delete them - if comment_index != -1: - arr[i] = arr[i][:comment_index].rstrip() + arr[i] = remove_comments_from_line(arr[i]) elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: # problem, `"""` or `'''` without `,` new_line = f",{line}" From 8b5f7848fa6aa8c9e0703dcdbf6760ef1efa87eb Mon Sep 17 00:00:00 2001 From: huzixia <528543747@qq.com> Date: Sat, 27 Jan 2024 17:00:59 +0800 Subject: [PATCH 5/9] delete PR dir --- metagpt/utils/repair_llm_raw_output.py | 43 +++++++++++++------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index 4995918c2..ef3580750 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -105,6 +105,23 @@ def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") - return output +def remove_comments_from_line(line): + """ + Remove comments from a single line of string. + Comments are assumed to start with '#' or '//' and are not inside string values. 
+ """ + comment_index = -1 + for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): + if match.group(1): # if the string value + continue + if match.group(2): # if comments + comment_index = match.start(2) + break + if comment_index != -1: # if comments, then delete them + return line[:comment_index].rstrip() + return line + + def repair_json_format(output: str) -> str: """ fix extra `[` or `}` in the end @@ -125,17 +142,8 @@ def repair_json_format(output: str) -> str: new_arr = [] for line in arr: # look for # or // comments and make sure they are not inside the string value - comment_index = -1 - for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): - if match.group(1): # if the string value - continue - if match.group(2): # if comments - comment_index = match.start(2) - break - # if comments, then delete them - if comment_index != -1: - line = line[:comment_index].rstrip() - new_arr.append(line) + new_line = remove_comments_from_line(line) + new_arr.append(new_line) output = "\n".join(new_arr) return output @@ -209,18 +217,9 @@ def repair_invalid_json(output: str, error: str) -> str: # remove comments in output json str, after json value content, maybe start with #, maybe start with // elif rline[col_no] == "#" or rline[col_no] == "/": new_line = rline[:col_no] + # check the next line and remove the comments for i in range(line_no + 1, len(arr)): - # look for # or // comments and make sure they are not inside the string value - comment_index = -1 - for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): - if match.group(1): # if the string value - continue - if match.group(2): # if comments - comment_index = match.start(2) - break - # if comments, then delete them - if comment_index != -1: - arr[i] = arr[i][:comment_index].rstrip() + arr[i] = remove_comments_from_line(arr[i]) elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: # problem, `"""` or `'''` without `,` new_line = f",{line}" From 
2361c7e8aa2df55bd3562243e95b4d5f538188ff Mon Sep 17 00:00:00 2001 From: huzixia <528543747@qq.com> Date: Sat, 27 Jan 2024 17:07:21 +0800 Subject: [PATCH 6/9] delete PR dir --- ...move comments in output json str, after js | 12 - PR/action_node.py | 351 ------------------ PR/repair_llm_raw_output.py | 351 ------------------ 3 files changed, 714 deletions(-) delete mode 100644 PR/# remove comments in output json str, after js delete mode 100644 PR/action_node.py delete mode 100644 PR/repair_llm_raw_output.py diff --git a/PR/# remove comments in output json str, after js b/PR/# remove comments in output json str, after js deleted file mode 100644 index f795fefdb..000000000 --- a/PR/# remove comments in output json str, after js +++ /dev/null @@ -1,12 +0,0 @@ - -git commit -m "To avoid JSONDecodeError: " -m "Remove comments in output json str, after json value content, maybe start with #, maybe start with //, particularly, it is not inside the string value" -m "Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py" - - - -git commit -m "Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py" - - - - - - diff --git a/PR/action_node.py b/PR/action_node.py deleted file mode 100644 index 0f441cfee..000000000 --- a/PR/action_node.py +++ /dev/null @@ -1,351 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/12/11 18:45 -@Author : alexanderwu -@File : action_node.py - -NOTE: You should use typing.List instead of list to do type annotation. Because in the markdown extraction process, - we can use typing to extract the type of the node, but we cannot use built-in list to extract. 
-""" -import json -from typing import Any, Dict, List, Optional, Tuple, Type - -from pydantic import BaseModel, create_model, model_validator -from tenacity import retry, stop_after_attempt, wait_random_exponential - -from metagpt.config import CONFIG -from metagpt.llm import BaseLLM -from metagpt.logs import logger -from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess -from metagpt.utils.common import OutputParser, general_after_log - -TAG = "CONTENT" - -LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT." -FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else." -# Delete comments in json -# If you don't want JSONDecodeError to occur, you can add Delete comments in json after FORMAT_CONSTRAINT - - -SIMPLE_TEMPLATE = """ -## context -{context} - ------ - -## format example -{example} - -## nodes: ": # " -{instruction} - -## constraint -{constraint} - -## action -Follow instructions of nodes, generate output and make sure it follows the format example. -""" - - -def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"): - markdown_str = "" - for key, value in d.items(): - markdown_str += f"{prefix}{key}{kv_sep}{value}{postfix}" - return markdown_str - - -class ActionNode: - """ActionNode is a tree of nodes.""" - - schema: str # raw/json/markdown, default: "" - - # Action Context - context: str # all the context, including all necessary info - llm: BaseLLM # LLM with aask interface - children: dict[str, "ActionNode"] - - # Action Input - key: str # Product Requirement / File list / Code - expected_type: Type # such as str / int / float etc. - # context: str # everything in the history. - instruction: str # the instructions should be followed. - example: Any # example for In Context-Learning. 
- - # Action Output - content: str - instruct_content: BaseModel - - def __init__( - self, - key: str, - expected_type: Type, - instruction: str, - example: Any, - content: str = "", - children: dict[str, "ActionNode"] = None, - schema: str = "", - ): - self.key = key - self.expected_type = expected_type - self.instruction = instruction - self.example = example - self.content = content - self.children = children if children is not None else {} - self.schema = schema - - def __str__(self): - return ( - f"{self.key}, {repr(self.expected_type)}, {self.instruction}, {self.example}" - f", {self.content}, {self.children}" - ) - - def __repr__(self): - return self.__str__() - - def add_child(self, node: "ActionNode"): - """增加子ActionNode""" - self.children[node.key] = node - - def add_children(self, nodes: List["ActionNode"]): - """批量增加子ActionNode""" - for node in nodes: - self.add_child(node) - - @classmethod - def from_children(cls, key, nodes: List["ActionNode"]): - """直接从一系列的子nodes初始化""" - obj = cls(key, str, "", "") - obj.add_children(nodes) - return obj - - def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]: - """获得子ActionNode的字典,以key索引""" - exclude = exclude or [] - return {k: (v.expected_type, ...) 
for k, v in self.children.items() if k not in exclude} - - def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]: - """get self key: type mapping""" - return {self.key: (self.expected_type, ...)} - - def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]: - """get key: type mapping under mode""" - if mode == "children" or (mode == "auto" and self.children): - return self.get_children_mapping(exclude=exclude) - return {} if exclude and self.key in exclude else self.get_self_mapping() - - @classmethod - def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]): - """基于pydantic v1的模型动态生成,用来检验结果类型正确性""" - - def check_fields(cls, values): - required_fields = set(mapping.keys()) - missing_fields = required_fields - set(values.keys()) - if missing_fields: - raise ValueError(f"Missing fields: {missing_fields}") - - unrecognized_fields = set(values.keys()) - required_fields - if unrecognized_fields: - logger.warning(f"Unrecognized fields: {unrecognized_fields}") - return values - - validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)} - - new_class = create_model(class_name, __validators__=validators, **mapping) - return new_class - - def create_children_class(self, exclude=None): - """使用object内有的字段直接生成model_class""" - class_name = f"{self.key}_AN" - mapping = self.get_children_mapping(exclude=exclude) - return self.create_model_class(class_name, mapping) - - def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict: - """将当前节点与子节点都按照node: format的格式组织成字典""" - - # 如果没有提供格式化函数,使用默认的格式化方式 - if format_func is None: - format_func = lambda node: f"{node.instruction}" - - # 使用提供的格式化函数来格式化当前节点的值 - formatted_value = format_func(self) - - # 创建当前节点的键值对 - if mode == "children" or (mode == "auto" and self.children): - node_dict = {} - else: - node_dict = {self.key: formatted_value} - - if mode == "root": - return node_dict - - # 遍历子节点并递归调用 to_dict 方法 - exclude = exclude or [] - 
for _, child_node in self.children.items(): - if child_node.key in exclude: - continue - node_dict.update(child_node.to_dict(format_func)) - - return node_dict - - def compile_to(self, i: Dict, schema, kv_sep) -> str: - if schema == "json": - return json.dumps(i, indent=4) - elif schema == "markdown": - return dict_to_markdown(i, kv_sep=kv_sep) - else: - return str(i) - - def tagging(self, text, schema, tag="") -> str: - if not tag: - return text - if schema == "json": - return f"[{tag}]\n" + text + f"\n[/{tag}]" - else: # markdown - return f"[{tag}]\n" + text + f"\n[/{tag}]" - - def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str: - nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude) - text = self.compile_to(nodes, schema, kv_sep) - return self.tagging(text, schema, tag) - - def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str: - """compile to raw/json/markdown template with all/root/children nodes""" - format_func = lambda i: f"{i.expected_type} # {i.instruction}" - return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude) - - def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str: - """compile to raw/json/markdown examples with all/root/children nodes""" - - # 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example - # 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str - format_func = lambda i: i.example - return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude) - - def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=[]) -> str: - """ - mode: all/root/children - mode="children": 编译所有子节点为一个统一模板,包括instruction与example - mode="all": NotImplemented - mode="root": NotImplemented - schmea: raw/json/markdown - schema="raw": 不编译,context, lang_constaint, instruction - schema="json":编译context, example(json), instruction(markdown), 
constraint, action - schema="markdown": 编译context, example(markdown), instruction(markdown), constraint, action - """ - if schema == "raw": - return context + "\n\n## Actions\n" + LANGUAGE_CONSTRAINT + "\n" + self.instruction - - # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线", - # compile example暂时不支持markdown - instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude) - example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude) - # nodes = ", ".join(self.to_dict(mode=mode).keys()) - constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT] - constraint = "\n".join(constraints) - - prompt = template.format( - context=context, - example=example, - instruction=instruction, - constraint=constraint, - ) - return prompt - - @retry( - wait=wait_random_exponential(min=1, max=20), - stop=stop_after_attempt(6), - after=general_after_log(logger), - ) - async def _aask_v1( - self, - prompt: str, - output_class_name: str, - output_data_mapping: dict, - system_msgs: Optional[list[str]] = None, - schema="markdown", # compatible to original format - timeout=CONFIG.timeout, - ) -> (str, BaseModel): - """Use ActionOutput to wrap the output of aask""" - content = await self.llm.aask(prompt, system_msgs, timeout=timeout) - logger.debug(f"llm raw output:\n{content}") - output_class = self.create_model_class(output_class_name, output_data_mapping) - - if schema == "json": - parsed_data = llm_output_postprocess( - output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]" - ) - else: # using markdown parser - parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) - - logger.debug(f"parsed_data:\n{parsed_data}") - instruct_content = output_class(**parsed_data) - return content, instruct_content - - def get(self, key): - return self.instruct_content.model_dump()[key] - - def set_recursive(self, name, value): - setattr(self, name, value) - for _, i in 
self.children.items(): - i.set_recursive(name, value) - - def set_llm(self, llm): - self.set_recursive("llm", llm) - - def set_context(self, context): - self.set_recursive("context", context) - - async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None): - prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude) - - if schema != "raw": - mapping = self.get_mapping(mode, exclude=exclude) - class_name = f"{self.key}_AN" - content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout) - self.content = content - self.instruct_content = scontent - else: - self.content = await self.llm.aask(prompt) - self.instruct_content = None - - return self - - async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]): - """Fill the node(s) with mode. - - :param context: Everything we should know when filling node. - :param llm: Large Language Model with pre-defined system message. - :param schema: json/markdown, determine example and output format. - - raw: free form text - - json: it's easy to open source LLM with json format - - markdown: when generating code, markdown is always better - :param mode: auto/children/root - - auto: automated fill children's nodes and gather outputs, if no children, fill itself - - children: fill children's nodes and gather outputs - - root: fill root's node and gather output - :param strgy: simple/complex - - simple: run only once - - complex: run each node - :param timeout: Timeout for llm invocation. - :param exclude: The keys of ActionNode to exclude. 
- :return: self - """ - self.set_llm(llm) - self.set_context(context) - if self.schema: - schema = self.schema - - if strgy == "simple": - return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) - elif strgy == "complex": - # 这里隐式假设了拥有children - tmp = {} - for _, i in self.children.items(): - if exclude and i.key in exclude: - continue - child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) - tmp.update(child.instruct_content.dict()) - cls = self.create_children_class() - self.instruct_content = cls(**tmp) - return self diff --git a/PR/repair_llm_raw_output.py b/PR/repair_llm_raw_output.py deleted file mode 100644 index 4995918c2..000000000 --- a/PR/repair_llm_raw_output.py +++ /dev/null @@ -1,351 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Desc : repair llm raw output with particular conditions - -import copy -from enum import Enum -from typing import Callable, Union - -import regex as re -from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed - -from metagpt.config import CONFIG -from metagpt.logs import logger -from metagpt.utils.custom_decoder import CustomDecoder - - -class RepairType(Enum): - CS = "case sensitivity" - RKPM = "required key pair missing" # condition like `[key] xx` which lacks `[/key]` - SCM = "special character missing" # Usually the req_key appear in pairs like `[key] xx [/key]` - JSON = "json format" - - -def repair_case_sensitivity(output: str, req_key: str) -> str: - """ - usually, req_key is the key name of expected json or markdown content, it won't appear in the value part. 
- fix target string `"Shared Knowledge": ""` but `"Shared knowledge": ""` actually - """ - if req_key in output: - return output - - output_lower = output.lower() - req_key_lower = req_key.lower() - if req_key_lower in output_lower: - # find the sub-part index, and replace it with raw req_key - lidx = output_lower.find(req_key_lower) - source = output[lidx : lidx + len(req_key_lower)] - output = output.replace(source, req_key) - logger.info(f"repair_case_sensitivity: {req_key}") - - return output - - -def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str: - """ - fix - 1. target string `[CONTENT] xx [CONTENT] xxx [CONTENT]` lacks `/` in the last `[CONTENT]` - 2. target string `xx [CONTENT] xxx [CONTENT] xxxx` lacks `/` in the last `[CONTENT]` - """ - sc_arr = ["/"] - - if req_key in output: - return output - - for sc in sc_arr: - req_key_pure = req_key.replace(sc, "") - appear_cnt = output.count(req_key_pure) - if req_key_pure in output and appear_cnt > 1: - # req_key with special_character usually in the tail side - ridx = output.rfind(req_key_pure) - output = f"{output[:ridx]}{req_key}{output[ridx + len(req_key_pure):]}" - logger.info(f"repair_special_character_missing: {sc} in {req_key_pure} as position {ridx}") - - return output - - -def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str: - """ - implement the req_key pair in the begin or end of the content - req_key format - 1. `[req_key]`, and its pair `[/req_key]` - 2. 
`[/req_key]`, and its pair `[req_key]` - """ - sc = "/" # special char - if req_key.startswith("[") and req_key.endswith("]"): - if sc in req_key: - left_key = req_key.replace(sc, "") # `[/req_key]` -> `[req_key]` - right_key = req_key - else: - left_key = req_key - right_key = f"{req_key[0]}{sc}{req_key[1:]}" # `[req_key]` -> `[/req_key]` - - if left_key not in output: - output = left_key + "\n" + output - if right_key not in output: - - def judge_potential_json(routput: str, left_key: str) -> Union[str, None]: - ridx = routput.rfind(left_key) - if ridx < 0: - return None - sub_output = routput[ridx:] - idx1 = sub_output.rfind("}") - idx2 = sub_output.rindex("]") - idx = idx1 if idx1 >= idx2 else idx2 - sub_output = sub_output[: idx + 1] - return sub_output - - if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)): - # # avoid [req_key]xx[req_key] case to append [/req_key] - output = output + "\n" + right_key - elif judge_potential_json(output, left_key) and (not output.strip().endswith(left_key)): - sub_content = judge_potential_json(output, left_key) - output = sub_content + "\n" + right_key - - return output - - -def repair_json_format(output: str) -> str: - """ - fix extra `[` or `}` in the end - """ - output = output.strip() - - if output.startswith("[{"): - output = output[1:] - logger.info(f"repair_json_format: {'[{'}") - elif output.endswith("}]"): - output = output[:-1] - logger.info(f"repair_json_format: {'}]'}") - elif output.startswith("{") and output.endswith("]"): - output = output[:-1] + "}" - - # remove comments in output json str, after json value content, maybe start with #, maybe start with // - arr = output.split("\n") - new_arr = [] - for line in arr: - # look for # or // comments and make sure they are not inside the string value - comment_index = -1 - for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): - if match.group(1): # if the string value - continue - if match.group(2): # 
if comments - comment_index = match.start(2) - break - # if comments, then delete them - if comment_index != -1: - line = line[:comment_index].rstrip() - new_arr.append(line) - output = "\n".join(new_arr) - return output - - -def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = None) -> str: - repair_types = [repair_type] if repair_type else [item for item in RepairType if item not in [RepairType.JSON]] - for repair_type in repair_types: - if repair_type == RepairType.CS: - output = repair_case_sensitivity(output, req_key) - elif repair_type == RepairType.RKPM: - output = repair_required_key_pair_missing(output, req_key) - elif repair_type == RepairType.SCM: - output = repair_special_character_missing(output, req_key) - elif repair_type == RepairType.JSON: - output = repair_json_format(output) - return output - - -def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str: - """ - in open-source llm model, it usually can't follow the instruction well, the output may be incomplete, - so here we try to repair it and use all repair methods by default. - typical case - 1. case sensitivity - target: "Original Requirements" - output: "Original requirements" - 2. special character missing - target: [/CONTENT] - output: [CONTENT] - 3. json format - target: { xxx } - output: { xxx }] - """ - if not CONFIG.repair_llm_output: - return output - - # do the repairation usually for non-openai models - for req_key in req_keys: - output = _repair_llm_raw_output(output=output, req_key=req_key, repair_type=repair_type) - return output - - -def repair_invalid_json(output: str, error: str) -> str: - """ - repair the situation like there are extra chars like - error examples - example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765) - example 2. 
xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266) - """ - pattern = r"line ([0-9]+) column ([0-9]+)" - - matches = re.findall(pattern, error, re.DOTALL) - if len(matches) > 0: - line_no = int(matches[0][0]) - 1 - col_no = int(matches[0][1]) - 1 - - # due to CustomDecoder can handle `"": ''` or `'': ""`, so convert `"""` -> `"`, `'''` -> `'` - output = output.replace('"""', '"').replace("'''", '"') - arr = output.split("\n") - rline = arr[line_no] # raw line - line = arr[line_no].strip() - # different general problems - if line.endswith("],"): - # problem, redundant char `]` - new_line = line.replace("]", "") - elif line.endswith("},") and not output.endswith("},"): - # problem, redundant char `}` - new_line = line.replace("}", "") - elif line.endswith("},") and output.endswith("},"): - new_line = line[:-1] - # remove comments in output json str, after json value content, maybe start with #, maybe start with // - elif rline[col_no] == "#" or rline[col_no] == "/": - new_line = rline[:col_no] - for i in range(line_no + 1, len(arr)): - # look for # or // comments and make sure they are not inside the string value - comment_index = -1 - for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): - if match.group(1): # if the string value - continue - if match.group(2): # if comments - comment_index = match.start(2) - break - # if comments, then delete them - if comment_index != -1: - arr[i] = arr[i][:comment_index].rstrip() - elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: - # problem, `"""` or `'''` without `,` - new_line = f",{line}" - elif '",' not in line and "," not in line and '"' not in line: - new_line = f'{line}",' - elif not line.endswith(","): - # problem, miss char `,` at the end. 
- new_line = f"{line}," - elif "," in line and len(line) == 1: - new_line = f'"{line}' - elif '",' in line: - new_line = line[:-2] + "'," - else: - new_line = line - - arr[line_no] = new_line - output = "\n".join(arr) - logger.info(f"repair_invalid_json, raw error: {error}") - - return output - - -def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]: - def run_and_passon(retry_state: RetryCallState) -> None: - """ - RetryCallState example - { - "start_time":143.098322024, - "retry_object":")>", - "fn":"", - "args":"(\"tag:[/CONTENT]\",)", # function input args - "kwargs":{}, # function input kwargs - "attempt_number":1, # retry number - "outcome":"", # type(outcome.result()) = "str", type(outcome.exception()) = "class" - "outcome_timestamp":143.098416904, - "idle_for":0, - "next_action":"None" - } - """ - if retry_state.outcome.failed: - if retry_state.args: - # # can't be used as args=retry_state.args - func_param_output = retry_state.args[0] - elif retry_state.kwargs: - func_param_output = retry_state.kwargs.get("output", "") - exp_str = str(retry_state.outcome.exception()) - - fix_str = "try to fix it, " if CONFIG.repair_llm_output else "" - logger.warning( - f"parse json from content inside [CONTENT][/CONTENT] failed at retry " - f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}" - ) - - repaired_output = repair_invalid_json(func_param_output, exp_str) - retry_state.kwargs["output"] = repaired_output - - return run_and_passon - - -@retry( - stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0), - wait=wait_fixed(1), - after=run_after_exp_and_passon_next_retry(logger), -) -def retry_parse_json_text(output: str) -> Union[list, dict]: - """ - repair the json-text situation like there are extra chars like [']', '}'] - - Warning - if CONFIG.repair_llm_output is False, retry _aask_v1 {x=3} times, and the retry_parse_json_text's retry not work - if CONFIG.repair_llm_output is True, the _aask_v1 and the 
retry_parse_json_text will loop for {x=3*3} times. - it's a two-layer retry cycle - """ - # logger.debug(f"output to json decode:\n{output}") - - # if CONFIG.repair_llm_output is True, it will try to fix output until the retry break - parsed_data = CustomDecoder(strict=False).decode(output) - - return parsed_data - - -def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"): - """extract xxx from [CONTENT](xxx)[/CONTENT] using regex pattern""" - - def re_extract_content(cont: str, pattern: str) -> str: - matches = re.findall(pattern, cont, re.DOTALL) - for match in matches: - if match: - cont = match - break - return cont.strip() - - # TODO construct the extract pattern with the `right_key` - raw_content = copy.deepcopy(content) - pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]" - new_content = re_extract_content(raw_content, pattern) - - if not new_content.startswith("{"): - # TODO find a more general pattern - # # for `[CONTENT]xxx[CONTENT]xxxx[/CONTENT] situation - logger.warning(f"extract_content try another pattern: {pattern}") - if right_key not in new_content: - raw_content = copy.deepcopy(new_content + "\n" + right_key) - # # pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]" - new_content = re_extract_content(raw_content, pattern) - else: - if right_key in new_content: - idx = new_content.find(right_key) - new_content = new_content[:idx] - new_content = new_content.strip() - - return new_content - - -def extract_state_value_from_output(content: str) -> str: - """ - For openai models, they will always return state number. But for open llm models, the instruction result maybe a - long text contain target number, so here add a extraction to improve success rate. - - Args: - content (str): llm's output from `Role._think` - """ - content = content.strip() # deal the output cases like " 0", "0\n" and so on. 
- pattern = r"([0-9])" # TODO find the number using a more proper method not just extract from content using pattern - matches = re.findall(pattern, content, re.DOTALL) - matches = list(set(matches)) - state = matches[0] if len(matches) > 0 else "-1" - return state From 11f70ca9b1714dc312a847505c9940f0c60a24b1 Mon Sep 17 00:00:00 2001 From: huzixia <528543747@qq.com> Date: Sat, 27 Jan 2024 18:06:52 +0800 Subject: [PATCH 7/9] modify code based on feedback of action_node.py and repair_llm_raw_output.py, add code in test_repair_llm_raw_output.py --- metagpt/actions/action_node.py | 5 +--- metagpt/utils/repair_llm_raw_output.py | 9 +------ .../utils/test_repair_llm_raw_output.py | 26 +++++++++++++++++++ 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index ed0e27869..6c65b33ef 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -23,10 +23,7 @@ from metagpt.utils.common import OutputParser, general_after_log TAG = "CONTENT" LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT." -FORMAT_CONSTRAINT = (f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else. " - f"Delete comments in json") -# Delete comments in json -# If you don't want JSONDecodeError to occur, you can add Delete comments in json after FORMAT_CONSTRAINT +FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else." 
SIMPLE_TEMPLATE = """ diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index ef3580750..973cffb8a 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -137,11 +137,10 @@ def repair_json_format(output: str) -> str: elif output.startswith("{") and output.endswith("]"): output = output[:-1] + "}" - # remove comments in output json str, after json value content, maybe start with #, maybe start with // + # remove comments in output json string arr = output.split("\n") new_arr = [] for line in arr: - # look for # or // comments and make sure they are not inside the string value new_line = remove_comments_from_line(line) new_arr.append(new_line) output = "\n".join(new_arr) @@ -214,12 +213,6 @@ def repair_invalid_json(output: str, error: str) -> str: new_line = line.replace("}", "") elif line.endswith("},") and output.endswith("},"): new_line = line[:-1] - # remove comments in output json str, after json value content, maybe start with #, maybe start with // - elif rline[col_no] == "#" or rline[col_no] == "/": - new_line = rline[:col_no] - # check the next line and remove the comments - for i in range(line_no + 1, len(arr)): - arr[i] = remove_comments_from_line(arr[i]) elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: # problem, `"""` or `'''` without `,` new_line = f",{line}" diff --git a/tests/metagpt/utils/test_repair_llm_raw_output.py b/tests/metagpt/utils/test_repair_llm_raw_output.py index 1f809a081..9d53b8243 100644 --- a/tests/metagpt/utils/test_repair_llm_raw_output.py +++ b/tests/metagpt/utils/test_repair_llm_raw_output.py @@ -141,6 +141,32 @@ def test_repair_json_format(): output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON) assert output == target_output + raw_output = """ +{ + "Language": "en_us", // define language + "Programming Language": "Python" # define code language 
+} +""" + target_output = """{ + "Language": "en_us", + "Programming Language": "Python" +}""" + output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON) + assert output == target_output + + raw_output = """ + { + "Language": "#en_us#", // define language + "Programming Language": "//Python # Code // Language//" # define code language + } + """ + target_output = """{ + "Language": "#en_us#", + "Programming Language": "//Python # Code // Language//" + }""" + output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON) + assert output == target_output + def test_repair_invalid_json(): from metagpt.utils.repair_llm_raw_output import repair_invalid_json From c3b4c698d80cba70e446cd6a97f375459c8c5595 Mon Sep 17 00:00:00 2001 From: huzixia <528543747@qq.com> Date: Sat, 27 Jan 2024 18:23:57 +0800 Subject: [PATCH 8/9] update repair_llm_raw_output.py --- metagpt/utils/repair_llm_raw_output.py | 36 ++++++++++---------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index 973cffb8a..6da974d96 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -105,23 +105,6 @@ def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") - return output -def remove_comments_from_line(line): - """ - Remove comments from a single line of string. - Comments are assumed to start with '#' or '//' and are not inside string values. 
- """ - comment_index = -1 - for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line): - if match.group(1): # if the string value - continue - if match.group(2): # if comments - comment_index = match.start(2) - break - if comment_index != -1: # if comments, then delete them - return line[:comment_index].rstrip() - return line - - def repair_json_format(output: str) -> str: """ fix extra `[` or `}` in the end @@ -136,13 +119,22 @@ def repair_json_format(output: str) -> str: logger.info(f"repair_json_format: {'}]'}") elif output.startswith("{") and output.endswith("]"): output = output[:-1] + "}" - - # remove comments in output json string + # remove comments in output json string, after json value content, maybe start with #, maybe start with // arr = output.split("\n") new_arr = [] - for line in arr: - new_line = remove_comments_from_line(line) - new_arr.append(new_line) + for json_line in arr: + # look for # or // comments and make sure they are not inside the string value + comment_index = -1 + for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", json_line): + if match.group(1): # if the string value + continue + if match.group(2): # if comments + comment_index = match.start(2) + break + # if comments, then delete them + if comment_index != -1: + json_line = json_line[:comment_index].rstrip() + new_arr.append(json_line) output = "\n".join(new_arr) return output From a68c3442bcff09864409ed47a02bbcc476657d23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A8=8B=E5=85=81=E6=9D=83?= Date: Sun, 28 Jan 2024 10:39:00 +0800 Subject: [PATCH 9/9] Refactor get_choice_delta_text for safer dict access --- metagpt/provider/base_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index a50cdacd9..47c527b97 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -91,7 +91,7 @@ class BaseLLM(ABC): def get_choice_delta_text(self, rsp: dict) -> str: """Required to provide the 
first text of stream choice""" - return rsp.get("choices")[0]["delta"]["content"] + return rsp.get("choices", [{}])[0].get("delta", {}).get("content", "") def get_choice_function(self, rsp: dict) -> dict: """Required to provide the first function of choice