diff --git a/PR/# remove comments in output json str, after js b/PR/# remove comments in output json str, after js new file mode 100644 index 000000000..f795fefdb --- /dev/null +++ b/PR/# remove comments in output json str, after js @@ -0,0 +1,12 @@ + +git commit -m "To avoid JSONDecodeError: " -m "Remove comments in output json str, after json value content, maybe start with #, maybe start with //, particularly, it is not inside the string value" -m "Additionally, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py" + + + +git commit -m "Additionally, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py" + + + + + + diff --git a/PR/action_node.py b/PR/action_node.py new file mode 100644 index 000000000..0f441cfee --- /dev/null +++ b/PR/action_node.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/11 18:45 +@Author : alexanderwu +@File : action_node.py + +NOTE: You should use typing.List instead of list to do type annotation. Because in the markdown extraction process, + we can use typing to extract the type of the node, but we cannot use built-in list to extract. +""" +import json +from typing import Any, Dict, List, Optional, Tuple, Type + +from pydantic import BaseModel, create_model, model_validator +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.config import CONFIG +from metagpt.llm import BaseLLM +from metagpt.logs import logger +from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess +from metagpt.utils.common import OutputParser, general_after_log + +TAG = "CONTENT" + +LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT." +FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else."
# NOTE: LLM output may carry `#` or `//` comments after JSON values, which makes JSON parsing fail.
# If you don't want JSONDecodeError to occur, you can add "Delete comments in json" after FORMAT_CONSTRAINT.


SIMPLE_TEMPLATE = """
## context
{context}

-----

## format example
{example}

## nodes: ": # "
{instruction}

## constraint
{constraint}

## action
Follow instructions of nodes, generate output and make sure it follows the format example.
"""


def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"):
    """Render a mapping as text, one `{prefix}{key}{kv_sep}{value}{postfix}` entry per item.

    :param d: mapping to render; entries are emitted in the mapping's iteration order.
    :param prefix: text placed before every key.
    :param kv_sep: text placed between a key and its value.
    :param postfix: text placed after every value.
    :return: the concatenated entries ("" for an empty mapping).
    """
    return "".join(f"{prefix}{key}{kv_sep}{value}{postfix}" for key, value in d.items())


class ActionNode:
    """ActionNode is a tree of nodes.

    Each node describes one expected output field (key, type, instruction, example). A node can
    compile itself and/or its children into a prompt, send it to an LLM and store the raw answer in
    `content` and the validated answer in `instruct_content`.
    """

    schema: str  # raw/json/markdown, default: ""

    # Action Context
    context: str  # all the context, including all necessary info
    llm: BaseLLM  # LLM with aask interface
    children: dict[str, "ActionNode"]

    # Action Input
    key: str  # Product Requirement / File list / Code
    expected_type: Type  # such as str / int / float etc.
    # context: str  # everything in the history.
    instruction: str  # the instructions should be followed.
    example: Any  # example for In Context-Learning.

    # Action Output
    content: str
    instruct_content: BaseModel

    def __init__(
        self,
        key: str,
        expected_type: Type,
        instruction: str,
        example: Any,
        content: str = "",
        children: Optional[dict[str, "ActionNode"]] = None,
        schema: str = "",
    ):
        """Create a node; `children` defaults to a fresh empty dict for every instance."""
        self.key = key
        self.expected_type = expected_type
        self.instruction = instruction
        self.example = example
        self.content = content
        self.children = children if children is not None else {}
        self.schema = schema

    def __str__(self):
        return (
            f"{self.key}, {repr(self.expected_type)}, {self.instruction}, {self.example}"
            f", {self.content}, {self.children}"
        )

    def __repr__(self):
        return self.__str__()

    def add_child(self, node: "ActionNode"):
        """Add a child ActionNode, indexed by its key (an existing child with the same key is replaced)."""
        self.children[node.key] = node

    def add_children(self, nodes: List["ActionNode"]):
        """Add several child ActionNodes at once."""
        for node in nodes:
            self.add_child(node)

    @classmethod
    def from_children(cls, key, nodes: List["ActionNode"]):
        """Build a parent node (type `str`, empty instruction and example) directly from child nodes."""
        obj = cls(key, str, "", "")
        obj.add_children(nodes)
        return obj

    def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
        """Return `{child key: (expected_type, ...)}` for every child whose key is not in `exclude`."""
        exclude = exclude or []
        return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude}

    def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
        """get self key: type mapping"""
        return {self.key: (self.expected_type, ...)}

    def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]:
        """get key: type mapping under mode

        `children` (or `auto` with at least one child) maps the children; any other case maps this
        node itself, or yields `{}` when this node's key is excluded.
        """
        if mode == "children" or (mode == "auto" and self.children):
            return self.get_children_mapping(exclude=exclude)
        return {} if exclude and self.key in exclude else self.get_self_mapping()

    @classmethod
    def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
        """Dynamically create a pydantic model used to validate the type-correctness of a result.

        The model carries a `model_validator(mode="before")` hook that raises `ValueError` when a key
        of `mapping` is missing from the input and only logs a warning for unknown keys.
        """

        def check_fields(cls, values):
            required_fields = set(mapping.keys())
            missing_fields = required_fields - set(values.keys())
            if missing_fields:
                raise ValueError(f"Missing fields: {missing_fields}")

            unrecognized_fields = set(values.keys()) - required_fields
            if unrecognized_fields:
                logger.warning(f"Unrecognized fields: {unrecognized_fields}")
            return values

        validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)}

        new_class = create_model(class_name, __validators__=validators, **mapping)
        return new_class

    def create_children_class(self, exclude=None):
        """Create the validation model `<key>_AN` from the fields of this node's children."""
        class_name = f"{self.key}_AN"
        mapping = self.get_children_mapping(exclude=exclude)
        return self.create_model_class(class_name, mapping)

    def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict:
        """Organise this node and its children into a dict of `node key: formatted value`.

        :param format_func: callable taking a node and returning its value; defaults to the instruction.
        :param mode: `root` returns only this node's entry; `children` (or `auto` with children)
            leaves this node's own entry out.
        :param exclude: keys of direct children to skip.
        """
        # fall back to the default formatting when no formatter is supplied
        if format_func is None:
            format_func = lambda node: f"{node.instruction}"

        # format the value of the current node
        formatted_value = format_func(self)

        # build the key/value pair of the current node
        if mode == "children" or (mode == "auto" and self.children):
            node_dict = {}
        else:
            node_dict = {self.key: formatted_value}

        if mode == "root":
            return node_dict

        # walk the children and merge their dicts; the recursion passes only `format_func`, so
        # descendants use the default `mode` and no `exclude`
        exclude = exclude or []
        for child_node in self.children.values():
            if child_node.key in exclude:
                continue
            node_dict.update(child_node.to_dict(format_func))

        return node_dict

    def compile_to(self, i: Dict, schema, kv_sep) -> str:
        """Serialise dict `i` as JSON (indent 4), markdown, or `str(i)` for any other schema."""
        if schema == "json":
            return json.dumps(i, indent=4)
        elif schema == "markdown":
            return dict_to_markdown(i, kv_sep=kv_sep)
        else:
            return str(i)

    def tagging(self, text, schema, tag="") -> str:
        """Wrap `text` in `[tag]` / `[/tag]` lines; return it unchanged when `tag` is empty.

        `schema` is accepted for interface compatibility; every schema is wrapped the same way.
        """
        if not tag:
            return text
        return f"[{tag}]\n" + text + f"\n[/{tag}]"

    def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str:
        """Shared pipeline of the compile helpers: `to_dict` -> `compile_to` -> `tagging`."""
        nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude)
        text = self.compile_to(nodes, schema, kv_sep)
        return self.tagging(text, schema, tag)

    def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str:
        """compile to raw/json/markdown template with all/root/children nodes"""
        format_func = lambda i: f"{i.expected_type} # {i.instruction}"
        return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude)

    def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str:
        """compile to raw/json/markdown examples with all/root/children nodes"""

        # An f-string must not be used here: converting to str and then json.dumps adds extra quotes,
        # which no longer forms a valid example.
        # Wrong result: "File list": "['main.py', 'const.py', 'game.py']" (the value is a str, not a list)
        format_func = lambda i: i.example
        return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude)

    def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=None) -> str:
        """Compile the prompt sent to the LLM.

        mode: all/root/children
            mode="children": compile all children into one template, including instruction and example
            mode="all": NotImplemented
            mode="root": NotImplemented
        schema: raw/json/markdown
            schema="raw": no template, just context, language constraint and this node's instruction
            schema="json": compile context, example(json), instruction(markdown), constraint, action
            schema="markdown": compile context, example(markdown), instruction(markdown), constraint, action

        `exclude` defaults to None (no exclusion) rather than a shared mutable list.
        """
        if schema == "raw":
            return context + "\n\n## Actions\n" + LANGUAGE_CONSTRAINT + "\n" + self.instruction

        # FIXME: a json instruction causes format problems, e.g. "Project name": "web_2048 # <instruction text>"
        # compile example does not support markdown for now
        instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude)
        example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude)
        # nodes = ", ".join(self.to_dict(mode=mode).keys())
        constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT]
        constraint = "\n".join(constraints)

        prompt = template.format(
            context=context,
            example=example,
            instruction=instruction,
            constraint=constraint,
        )
        return prompt

    @retry(
        wait=wait_random_exponential(min=1, max=20),
        stop=stop_after_attempt(6),
        after=general_after_log(logger),
    )
    async def _aask_v1(
        self,
        prompt: str,
        output_class_name: str,
        output_data_mapping: dict,
        system_msgs: Optional[list[str]] = None,
        schema="markdown",  # compatible to original format
        timeout=CONFIG.timeout,
    ) -> Tuple[str, BaseModel]:
        """Ask the LLM and wrap its answer in a validated model instance.

        The call is retried (up to 6 attempts, random exponential wait) when anything here raises,
        including validation errors of the generated model.

        :return: `(raw llm output, validated model instance)`.
        """
        content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
        logger.debug(f"llm raw output:\n{content}")
        output_class = self.create_model_class(output_class_name, output_data_mapping)

        if schema == "json":
            parsed_data = llm_output_postprocess(
                output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
            )
        else:  # using markdown parser
            parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)

        logger.debug(f"parsed_data:\n{parsed_data}")
        instruct_content = output_class(**parsed_data)
        return content, instruct_content

    def get(self, key):
        """Return field `key` of the filled `instruct_content`."""
        return self.instruct_content.model_dump()[key]

    def set_recursive(self, name, value):
        """Set attribute `name` to `value` on this node and on every descendant."""
        setattr(self, name, value)
        for child in self.children.values():
            child.set_recursive(name, value)

    def set_llm(self, llm):
        self.set_recursive("llm", llm)

    def set_context(self, context):
        self.set_recursive("context", context)

    async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None):
        """Fill this node with a single LLM call and return self.

        With schema "raw" the answer is stored as free text and `instruct_content` is set to None.
        """
        prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)

        if schema != "raw":
            mapping = self.get_mapping(mode, exclude=exclude)
            class_name = f"{self.key}_AN"
            content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout)
            self.content = content
            self.instruct_content = scontent
        else:
            # honour the caller's timeout in raw mode too, like the structured branch does
            self.content = await self.llm.aask(prompt, timeout=timeout)
            self.instruct_content = None

        return self

    async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=None):
        """Fill the node(s) with mode.

        :param context: Everything we should know when filling node.
        :param llm: Large Language Model with pre-defined system message.
        :param schema: json/markdown, determine example and output format.
         - raw: free form text
         - json: it's easy to open source LLM with json format
         - markdown: when generating code, markdown is always better
        :param mode: auto/children/root
         - auto: automated fill children's nodes and gather outputs, if no children, fill itself
         - children: fill children's nodes and gather outputs
         - root: fill root's node and gather output
        :param strgy: simple/complex
         - simple: run only once
         - complex: run each node
        :param timeout: Timeout for llm invocation.
        :param exclude: The keys of ActionNode to exclude.
        :return: self (any other `strgy` value falls through and returns None)
        """
        self.set_llm(llm)
        self.set_context(context)
        if self.schema:
            schema = self.schema

        if strgy == "simple":
            return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
        elif strgy == "complex":
            # this implicitly assumes the node has children
            tmp = {}
            for child_node in self.children.values():
                if exclude and child_node.key in exclude:
                    continue
                child = await child_node.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
                tmp.update(child.instruct_content.model_dump())
            # build the result class without the excluded children, otherwise the keys skipped
            # above would be reported as missing fields
            cls = self.create_children_class(exclude=exclude)
            self.instruct_content = cls(**tmp)
            return self


# ---------------------------------------------------------------------------------------------------------
# diff --git a/PR/repair_llm_raw_output.py b/PR/repair_llm_raw_output.py
# new file mode 100644
# index 000000000..4995918c2
# --- /dev/null
# +++ b/PR/repair_llm_raw_output.py
# ---------------------------------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc   : repair llm raw output with particular conditions

import copy
from enum import Enum
from typing import Callable, Optional, Union

import regex as re
from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed

from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.utils.custom_decoder import CustomDecoder

# Group 1: a complete single- or double-quoted string (backslash escapes honoured).
# Group 2: a comment opener (`#` or `//`) found outside such a string.
_JSON_STRING_OR_COMMENT = re.compile(r"(\"(?:\\.|[^\"\\])*\"|'(?:\\.|[^'\\])*')|(#|//)")


class RepairType(Enum):
    CS = "case sensitivity"
    RKPM = "required key pair missing"  # condition like `[key] xx` which lacks `[/key]`
    SCM = "special character missing"  # Usually the req_key appear in pairs like `[key] xx [/key]`
    JSON = "json format"


def _strip_json_line_comment(line: str) -> str:
    """Cut a trailing `#` / `//` comment off one line, ignoring openers inside quoted strings.

    Returns the line unchanged when it has no comment; otherwise the text before the comment with
    trailing whitespace removed.
    """
    for match in _JSON_STRING_OR_COMMENT.finditer(line):
        if match.group(2):  # first comment opener outside a string value
            return line[: match.start(2)].rstrip()
    return line


def _truncate_after_last_bracket(routput: str, left_key: str) -> Optional[str]:
    """Return the text from the last `left_key` up to its last `}` or `]`, or None without `left_key`."""
    ridx = routput.rfind(left_key)
    if ridx < 0:
        return None
    sub_output = routput[ridx:]
    idx = max(sub_output.rfind("}"), sub_output.rfind("]"))
    return sub_output[: idx + 1]


def repair_case_sensitivity(output: str, req_key: str) -> str:
    """
    usually, req_key is the key name of expected json or markdown content, it won't appear in the value part.
    fix target string `"Shared Knowledge": ""` but `"Shared knowledge": ""` actually

    When `req_key` only occurs with different letter case, the spelling of its first occurrence is
    replaced by `req_key` everywhere in `output`.
    """
    if req_key in output:
        return output

    lidx = output.lower().find(req_key.lower())
    if lidx >= 0:
        # take the wrongly-cased spelling at that index and replace it with the raw req_key
        source = output[lidx : lidx + len(req_key)]
        output = output.replace(source, req_key)
        logger.info(f"repair_case_sensitivity: {req_key}")

    return output


def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str:
    """
    fix
        1. target string `[CONTENT] xx [CONTENT] xxx [CONTENT]` lacks `/` in the last `[CONTENT]`
        2. target string `xx [CONTENT] xxx [CONTENT] xxxx` lacks `/` in the last `[CONTENT]`
    """
    sc_arr = ["/"]

    if req_key in output:
        return output

    for sc in sc_arr:
        req_key_pure = req_key.replace(sc, "")
        # only repair when the key without the special char shows up more than once
        if output.count(req_key_pure) > 1:
            # req_key with special_character usually in the tail side
            ridx = output.rfind(req_key_pure)
            output = f"{output[:ridx]}{req_key}{output[ridx + len(req_key_pure):]}"
            logger.info(f"repair_special_character_missing: {sc} in {req_key_pure} as position {ridx}")

    return output


def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str:
    """
    implement the req_key pair in the begin or end of the content
        req_key format
            1. `[req_key]`, and its pair `[/req_key]`
            2. `[/req_key]`, and its pair `[req_key]`

    `output` is returned unchanged when `req_key` is not wrapped in `[` and `]`.
    """
    sc = "/"  # special char
    if not (req_key.startswith("[") and req_key.endswith("]")):
        return output

    if sc in req_key:
        left_key = req_key.replace(sc, "")  # `[/req_key]` -> `[req_key]`
        right_key = req_key
    else:
        left_key = req_key
        right_key = f"{req_key[0]}{sc}{req_key[1:]}"  # `[req_key]` -> `[/req_key]`

    if left_key not in output:
        output = left_key + "\n" + output
    if right_key not in output:
        stripped = output.strip()
        ends_with_left_key = stripped.endswith(left_key)
        if stripped.endswith("}") or (stripped.endswith("]") and not ends_with_left_key):
            # # avoid [req_key]xx[req_key] case to append [/req_key]
            output = output + "\n" + right_key
        elif not ends_with_left_key:
            sub_content = _truncate_after_last_bracket(output, left_key)
            if sub_content:
                output = sub_content + "\n" + right_key

    return output


def repair_json_format(output: str) -> str:
    """
    fix extra `[` or `}` in the end, then drop `#` / `//` comments that follow json values
    (openers inside quoted string values are kept).
    """
    output = output.strip()

    if output.startswith("[{"):
        output = output[1:]
        logger.info(f"repair_json_format: {'[{'}")
    elif output.endswith("}]"):
        output = output[:-1]
        logger.info(f"repair_json_format: {'}]'}")
    elif output.startswith("{") and output.endswith("]"):
        output = output[:-1] + "}"

    # remove comments in output json str, after json value content, maybe start with #, maybe start with //
    return "\n".join(_strip_json_line_comment(line) for line in output.split("\n"))


def _repair_llm_raw_output(output: str, req_key: str, repair_type: Optional[RepairType] = None) -> str:
    """Apply one repair type, or every type except `RepairType.JSON` when none is given."""
    selected = [repair_type] if repair_type else [item for item in RepairType if item not in [RepairType.JSON]]
    for kind in selected:
        if kind == RepairType.CS:
            output = repair_case_sensitivity(output, req_key)
        elif kind == RepairType.RKPM:
            output = repair_required_key_pair_missing(output, req_key)
        elif kind == RepairType.SCM:
            output = repair_special_character_missing(output, req_key)
        elif kind == RepairType.JSON:
            output = repair_json_format(output)
    return output


def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str:
    """
    in open-source llm model, it usually can't follow the instruction well, the output may be incomplete,
    so here we try to repair it and use all repair methods by default.
    typical case
        1. case sensitivity
            target: "Original Requirements"
            output: "Original requirements"
        2. special character missing
            target: [/CONTENT]
            output: [CONTENT]
        3. json format
            target: { xxx }
            output: { xxx }]

    Nothing is repaired when `CONFIG.repair_llm_output` is falsy.
    """
    if not CONFIG.repair_llm_output:
        return output

    # do the repairation usually for non-openai models
    for req_key in req_keys:
        output = _repair_llm_raw_output(output=output, req_key=req_key, repair_type=repair_type)
    return output


def repair_invalid_json(output: str, error: str) -> str:
    """
    repair the situation like there are extra chars like
    error examples
        example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765)
        example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266)

    The line/column reported in `error` select the line to patch. `output` is returned as is when
    `error` carries no position, and only quote-normalised when the position is outside `output`.
    """
    pattern = r"line ([0-9]+) column ([0-9]+)"

    matches = re.findall(pattern, error, re.DOTALL)
    if len(matches) > 0:
        line_no = int(matches[0][0]) - 1
        col_no = int(matches[0][1]) - 1

        # CustomDecoder can handle `"": ''` or `'': ""`, so triple quotes of either kind are
        # collapsed into a single double quote
        output = output.replace('"""', '"').replace("'''", '"')
        arr = output.split("\n")
        if not 0 <= line_no < len(arr):
            # the reported position does not exist in this text; there is no line to patch
            return output
        rline = arr[line_no]  # raw line
        line = rline.strip()
        # the decoder may point just past the end of the line; treat that as "no character"
        err_char = rline[col_no] if 0 <= col_no < len(rline) else ""
        # different general problems
        if line.endswith("],"):
            # problem, redundant char `]`
            new_line = line.replace("]", "")
        elif line.endswith("},") and not output.endswith("},"):
            # problem, redundant char `}`
            new_line = line.replace("}", "")
        elif line.endswith("},") and output.endswith("},"):
            new_line = line[:-1]
        elif err_char in ("#", "/"):
            # remove comments in output json str, after json value content, maybe start with #, maybe start with //
            new_line = rline[:col_no]
            # the following lines probably carry comments as well: strip them in the same pass
            for i in range(line_no + 1, len(arr)):
                arr[i] = _strip_json_line_comment(arr[i])
        elif (err_char in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line:
            # problem, `"""` or `'''` without `,`
            new_line = f",{line}"
        elif '",' not in line and "," not in line and '"' not in line:
            new_line = f'{line}",'
        elif not line.endswith(","):
            # problem, miss char `,` at the end.
            new_line = f"{line},"
        elif "," in line and len(line) == 1:
            new_line = f'"{line}'
        elif '",' in line:
            new_line = line[:-2] + "',"
        else:
            new_line = line

        arr[line_no] = new_line
        output = "\n".join(arr)
        logger.info(f"repair_invalid_json, raw error: {error}")

    return output


def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]:
    """Build a tenacity `after` callback that repairs the failed input and passes it to the next retry."""

    def run_and_passon(retry_state: RetryCallState) -> None:
        """
        Fields of RetryCallState read here
            args / kwargs: the input of the retried function (the text to parse is `args[0]` or `kwargs["output"]`)
            attempt_number: retry number
            outcome: `outcome.failed` tells whether the attempt raised, `outcome.exception()` is the error

        On failure the text is repaired with `repair_invalid_json` and stored in
        `retry_state.kwargs["output"]`.
        """
        if retry_state.outcome.failed:
            func_param_output = ""  # fallback when the call carried neither args nor kwargs
            if retry_state.args:
                # # can't be used as args=retry_state.args
                func_param_output = retry_state.args[0]
            elif retry_state.kwargs:
                func_param_output = retry_state.kwargs.get("output", "")
            exp_str = str(retry_state.outcome.exception())

            fix_str = "try to fix it, " if CONFIG.repair_llm_output else ""
            logger.warning(
                f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
                f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}"
            )

            repaired_output = repair_invalid_json(func_param_output, exp_str)
            retry_state.kwargs["output"] = repaired_output

    return run_and_passon


@retry(
    stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0),
    wait=wait_fixed(1),
    after=run_after_exp_and_passon_next_retry(logger),
)
def retry_parse_json_text(output: str) -> Union[list, dict]:
    """
    repair the json-text situation like there are extra chars like [']', '}']

    Warning
        if CONFIG.repair_llm_output is False, retry _aask_v1 {x=3} times, and the retry_parse_json_text's retry not work
        if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times.
            it's a two-layer retry cycle
    """
    # logger.debug(f"output to json decode:\n{output}")

    # if CONFIG.repair_llm_output is True, it will try to fix output until the retry break
    parsed_data = CustomDecoder(strict=False).decode(output)

    return parsed_data


def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"):
    """extract xxx from [CONTENT](xxx)[/CONTENT] using regex pattern"""

    def re_extract_content(cont: str, pattern: str) -> str:
        # the first non-empty capture wins; without one the input itself is returned (stripped)
        matches = re.findall(pattern, cont, re.DOTALL)
        for match in matches:
            if match:
                cont = match
                break
        return cont.strip()

    # TODO construct the extract pattern with the `right_key`
    raw_content = copy.deepcopy(content)
    pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]"
    new_content = re_extract_content(raw_content, pattern)

    if not new_content.startswith("{"):
        # TODO find a more general pattern
        # # for `[CONTENT]xxx[CONTENT]xxxx[/CONTENT] situation
        logger.warning(f"extract_content try another pattern: {pattern}")
        if right_key not in new_content:
            raw_content = copy.deepcopy(new_content + "\n" + right_key)
        # # pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]"
        new_content = re_extract_content(raw_content, pattern)
    else:
        if right_key in new_content:
            idx = new_content.find(right_key)
            new_content = new_content[:idx]
        new_content = new_content.strip()

    return new_content


def extract_state_value_from_output(content: str) -> str:
    """
    For openai models, they will always return state number. But for open llm models, the instruction result maybe a
    long text contain target number, so here add a extraction to improve success rate.

    Args:
        content (str): llm's output from `Role._think`

    Returns:
        The first digit found in `content`, or "-1" when it holds no digit.
    """
    content = content.strip()  # deal the output cases like " 0", "0\n" and so on.
    # TODO find the number using a more proper method not just extract from content using pattern
    match = re.search(r"([0-9])", content)
    # take the first digit in reading order so the result is the same on every run
    return match.group(1) if match else "-1"