Merge branch 'huzixia' of github.com:HuZixia/MetaGPT into huzixia

merge
This commit is contained in:
huzixia 2024-01-27 17:05:29 +08:00
commit b9a03c380a
3 changed files with 714 additions and 0 deletions

View file

@ -0,0 +1,12 @@
git commit -m "To avoid JSONDecodeError: " -m "Remove comments in output json str, after json value content, maybe start with #, maybe start with //, particularly, it is not inside the string value" -m "Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py"
git commit -m "Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py"

351
PR/action_node.py Normal file
View file

@ -0,0 +1,351 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/11 18:45
@Author : alexanderwu
@File : action_node.py
NOTE: You should use typing.List instead of list to do type annotation. Because in the markdown extraction process,
we can use typing to extract the type of the node, but we cannot use built-in list to extract.
"""
import json
from typing import Any, Dict, List, Optional, Tuple, Type
from pydantic import BaseModel, create_model, model_validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.config import CONFIG
from metagpt.llm import BaseLLM
from metagpt.logs import logger
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
from metagpt.utils.common import OutputParser, general_after_log
TAG = "CONTENT"
LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT."
FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else."
# Delete comments in json
# If you don't want JSONDecodeError to occur, you can add Delete comments in json after FORMAT_CONSTRAINT
SIMPLE_TEMPLATE = """
## context
{context}
-----
## format example
{example}
## nodes: "<node>: <type> # <instruction>"
{instruction}
## constraint
{constraint}
## action
Follow instructions of nodes, generate output and make sure it follows the format example.
"""
def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"):
markdown_str = ""
for key, value in d.items():
markdown_str += f"{prefix}{key}{kv_sep}{value}{postfix}"
return markdown_str
class ActionNode:
"""ActionNode is a tree of nodes."""
schema: str # raw/json/markdown, default: ""
# Action Context
context: str # all the context, including all necessary info
llm: BaseLLM # LLM with aask interface
children: dict[str, "ActionNode"]
# Action Input
key: str # Product Requirement / File list / Code
expected_type: Type # such as str / int / float etc.
# context: str # everything in the history.
instruction: str # the instructions should be followed.
example: Any # example for In Context-Learning.
# Action Output
content: str
instruct_content: BaseModel
def __init__(
self,
key: str,
expected_type: Type,
instruction: str,
example: Any,
content: str = "",
children: dict[str, "ActionNode"] = None,
schema: str = "",
):
self.key = key
self.expected_type = expected_type
self.instruction = instruction
self.example = example
self.content = content
self.children = children if children is not None else {}
self.schema = schema
def __str__(self):
return (
f"{self.key}, {repr(self.expected_type)}, {self.instruction}, {self.example}"
f", {self.content}, {self.children}"
)
def __repr__(self):
return self.__str__()
def add_child(self, node: "ActionNode"):
"""增加子ActionNode"""
self.children[node.key] = node
def add_children(self, nodes: List["ActionNode"]):
"""批量增加子ActionNode"""
for node in nodes:
self.add_child(node)
@classmethod
def from_children(cls, key, nodes: List["ActionNode"]):
"""直接从一系列的子nodes初始化"""
obj = cls(key, str, "", "")
obj.add_children(nodes)
return obj
def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
"""获得子ActionNode的字典以key索引"""
exclude = exclude or []
return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude}
def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
"""get self key: type mapping"""
return {self.key: (self.expected_type, ...)}
def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]:
"""get key: type mapping under mode"""
if mode == "children" or (mode == "auto" and self.children):
return self.get_children_mapping(exclude=exclude)
return {} if exclude and self.key in exclude else self.get_self_mapping()
@classmethod
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
"""基于pydantic v1的模型动态生成用来检验结果类型正确性"""
def check_fields(cls, values):
required_fields = set(mapping.keys())
missing_fields = required_fields - set(values.keys())
if missing_fields:
raise ValueError(f"Missing fields: {missing_fields}")
unrecognized_fields = set(values.keys()) - required_fields
if unrecognized_fields:
logger.warning(f"Unrecognized fields: {unrecognized_fields}")
return values
validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)}
new_class = create_model(class_name, __validators__=validators, **mapping)
return new_class
def create_children_class(self, exclude=None):
"""使用object内有的字段直接生成model_class"""
class_name = f"{self.key}_AN"
mapping = self.get_children_mapping(exclude=exclude)
return self.create_model_class(class_name, mapping)
def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict:
"""将当前节点与子节点都按照node: format的格式组织成字典"""
# 如果没有提供格式化函数,使用默认的格式化方式
if format_func is None:
format_func = lambda node: f"{node.instruction}"
# 使用提供的格式化函数来格式化当前节点的值
formatted_value = format_func(self)
# 创建当前节点的键值对
if mode == "children" or (mode == "auto" and self.children):
node_dict = {}
else:
node_dict = {self.key: formatted_value}
if mode == "root":
return node_dict
# 遍历子节点并递归调用 to_dict 方法
exclude = exclude or []
for _, child_node in self.children.items():
if child_node.key in exclude:
continue
node_dict.update(child_node.to_dict(format_func))
return node_dict
def compile_to(self, i: Dict, schema, kv_sep) -> str:
if schema == "json":
return json.dumps(i, indent=4)
elif schema == "markdown":
return dict_to_markdown(i, kv_sep=kv_sep)
else:
return str(i)
def tagging(self, text, schema, tag="") -> str:
if not tag:
return text
if schema == "json":
return f"[{tag}]\n" + text + f"\n[/{tag}]"
else: # markdown
return f"[{tag}]\n" + text + f"\n[/{tag}]"
def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str:
nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude)
text = self.compile_to(nodes, schema, kv_sep)
return self.tagging(text, schema, tag)
def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str:
"""compile to raw/json/markdown template with all/root/children nodes"""
format_func = lambda i: f"{i.expected_type} # {i.instruction}"
return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude)
def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str:
"""compile to raw/json/markdown examples with all/root/children nodes"""
# 这里不能使用f-string因为转译为str后再json.dumps会额外加上引号无法作为有效的example
# 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list而是str
format_func = lambda i: i.example
return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude)
def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=[]) -> str:
"""
mode: all/root/children
mode="children": 编译所有子节点为一个统一模板包括instruction与example
mode="all": NotImplemented
mode="root": NotImplemented
schmea: raw/json/markdown
schema="raw": 不编译context, lang_constaint, instruction
schema="json"编译context, example(json), instruction(markdown), constraint, action
schema="markdown": 编译context, example(markdown), instruction(markdown), constraint, action
"""
if schema == "raw":
return context + "\n\n## Actions\n" + LANGUAGE_CONSTRAINT + "\n" + self.instruction
# FIXME: json instruction会带来格式问题"Project name": "web_2048 # 项目名称使用下划线",
# compile example暂时不支持markdown
instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude)
example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude)
# nodes = ", ".join(self.to_dict(mode=mode).keys())
constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT]
constraint = "\n".join(constraints)
prompt = template.format(
context=context,
example=example,
instruction=instruction,
constraint=constraint,
)
return prompt
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _aask_v1(
self,
prompt: str,
output_class_name: str,
output_data_mapping: dict,
system_msgs: Optional[list[str]] = None,
schema="markdown", # compatible to original format
timeout=CONFIG.timeout,
) -> (str, BaseModel):
"""Use ActionOutput to wrap the output of aask"""
content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
logger.debug(f"llm raw output:\n{content}")
output_class = self.create_model_class(output_class_name, output_data_mapping)
if schema == "json":
parsed_data = llm_output_postprocess(
output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
)
else: # using markdown parser
parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
logger.debug(f"parsed_data:\n{parsed_data}")
instruct_content = output_class(**parsed_data)
return content, instruct_content
def get(self, key):
return self.instruct_content.model_dump()[key]
def set_recursive(self, name, value):
setattr(self, name, value)
for _, i in self.children.items():
i.set_recursive(name, value)
def set_llm(self, llm):
self.set_recursive("llm", llm)
def set_context(self, context):
self.set_recursive("context", context)
async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None):
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
if schema != "raw":
mapping = self.get_mapping(mode, exclude=exclude)
class_name = f"{self.key}_AN"
content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout)
self.content = content
self.instruct_content = scontent
else:
self.content = await self.llm.aask(prompt)
self.instruct_content = None
return self
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]):
"""Fill the node(s) with mode.
:param context: Everything we should know when filling node.
:param llm: Large Language Model with pre-defined system message.
:param schema: json/markdown, determine example and output format.
- raw: free form text
- json: it's easy to open source LLM with json format
- markdown: when generating code, markdown is always better
:param mode: auto/children/root
- auto: automated fill children's nodes and gather outputs, if no children, fill itself
- children: fill children's nodes and gather outputs
- root: fill root's node and gather output
:param strgy: simple/complex
- simple: run only once
- complex: run each node
:param timeout: Timeout for llm invocation.
:param exclude: The keys of ActionNode to exclude.
:return: self
"""
self.set_llm(llm)
self.set_context(context)
if self.schema:
schema = self.schema
if strgy == "simple":
return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
elif strgy == "complex":
# 这里隐式假设了拥有children
tmp = {}
for _, i in self.children.items():
if exclude and i.key in exclude:
continue
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
tmp.update(child.instruct_content.dict())
cls = self.create_children_class()
self.instruct_content = cls(**tmp)
return self

351
PR/repair_llm_raw_output.py Normal file
View file

@ -0,0 +1,351 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : repair llm raw output with particular conditions
import copy
from enum import Enum
from typing import Callable, Union
import regex as re
from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.utils.custom_decoder import CustomDecoder
class RepairType(Enum):
CS = "case sensitivity"
RKPM = "required key pair missing" # condition like `[key] xx` which lacks `[/key]`
SCM = "special character missing" # Usually the req_key appear in pairs like `[key] xx [/key]`
JSON = "json format"
def repair_case_sensitivity(output: str, req_key: str) -> str:
"""
usually, req_key is the key name of expected json or markdown content, it won't appear in the value part.
fix target string `"Shared Knowledge": ""` but `"Shared knowledge": ""` actually
"""
if req_key in output:
return output
output_lower = output.lower()
req_key_lower = req_key.lower()
if req_key_lower in output_lower:
# find the sub-part index, and replace it with raw req_key
lidx = output_lower.find(req_key_lower)
source = output[lidx : lidx + len(req_key_lower)]
output = output.replace(source, req_key)
logger.info(f"repair_case_sensitivity: {req_key}")
return output
def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str:
"""
fix
1. target string `[CONTENT] xx [CONTENT] xxx [CONTENT]` lacks `/` in the last `[CONTENT]`
2. target string `xx [CONTENT] xxx [CONTENT] xxxx` lacks `/` in the last `[CONTENT]`
"""
sc_arr = ["/"]
if req_key in output:
return output
for sc in sc_arr:
req_key_pure = req_key.replace(sc, "")
appear_cnt = output.count(req_key_pure)
if req_key_pure in output and appear_cnt > 1:
# req_key with special_character usually in the tail side
ridx = output.rfind(req_key_pure)
output = f"{output[:ridx]}{req_key}{output[ridx + len(req_key_pure):]}"
logger.info(f"repair_special_character_missing: {sc} in {req_key_pure} as position {ridx}")
return output
def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str:
"""
implement the req_key pair in the begin or end of the content
req_key format
1. `[req_key]`, and its pair `[/req_key]`
2. `[/req_key]`, and its pair `[req_key]`
"""
sc = "/" # special char
if req_key.startswith("[") and req_key.endswith("]"):
if sc in req_key:
left_key = req_key.replace(sc, "") # `[/req_key]` -> `[req_key]`
right_key = req_key
else:
left_key = req_key
right_key = f"{req_key[0]}{sc}{req_key[1:]}" # `[req_key]` -> `[/req_key]`
if left_key not in output:
output = left_key + "\n" + output
if right_key not in output:
def judge_potential_json(routput: str, left_key: str) -> Union[str, None]:
ridx = routput.rfind(left_key)
if ridx < 0:
return None
sub_output = routput[ridx:]
idx1 = sub_output.rfind("}")
idx2 = sub_output.rindex("]")
idx = idx1 if idx1 >= idx2 else idx2
sub_output = sub_output[: idx + 1]
return sub_output
if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)):
# # avoid [req_key]xx[req_key] case to append [/req_key]
output = output + "\n" + right_key
elif judge_potential_json(output, left_key) and (not output.strip().endswith(left_key)):
sub_content = judge_potential_json(output, left_key)
output = sub_content + "\n" + right_key
return output
def repair_json_format(output: str) -> str:
"""
fix extra `[` or `}` in the end
"""
output = output.strip()
if output.startswith("[{"):
output = output[1:]
logger.info(f"repair_json_format: {'[{'}")
elif output.endswith("}]"):
output = output[:-1]
logger.info(f"repair_json_format: {'}]'}")
elif output.startswith("{") and output.endswith("]"):
output = output[:-1] + "}"
# remove comments in output json str, after json value content, maybe start with #, maybe start with //
arr = output.split("\n")
new_arr = []
for line in arr:
# look for # or // comments and make sure they are not inside the string value
comment_index = -1
for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line):
if match.group(1): # if the string value
continue
if match.group(2): # if comments
comment_index = match.start(2)
break
# if comments, then delete them
if comment_index != -1:
line = line[:comment_index].rstrip()
new_arr.append(line)
output = "\n".join(new_arr)
return output
def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = None) -> str:
repair_types = [repair_type] if repair_type else [item for item in RepairType if item not in [RepairType.JSON]]
for repair_type in repair_types:
if repair_type == RepairType.CS:
output = repair_case_sensitivity(output, req_key)
elif repair_type == RepairType.RKPM:
output = repair_required_key_pair_missing(output, req_key)
elif repair_type == RepairType.SCM:
output = repair_special_character_missing(output, req_key)
elif repair_type == RepairType.JSON:
output = repair_json_format(output)
return output
def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str:
"""
in open-source llm model, it usually can't follow the instruction well, the output may be incomplete,
so here we try to repair it and use all repair methods by default.
typical case
1. case sensitivity
target: "Original Requirements"
output: "Original requirements"
2. special character missing
target: [/CONTENT]
output: [CONTENT]
3. json format
target: { xxx }
output: { xxx }]
"""
if not CONFIG.repair_llm_output:
return output
# do the repairation usually for non-openai models
for req_key in req_keys:
output = _repair_llm_raw_output(output=output, req_key=req_key, repair_type=repair_type)
return output
def repair_invalid_json(output: str, error: str) -> str:
"""
repair the situation like there are extra chars like
error examples
example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765)
example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266)
"""
pattern = r"line ([0-9]+) column ([0-9]+)"
matches = re.findall(pattern, error, re.DOTALL)
if len(matches) > 0:
line_no = int(matches[0][0]) - 1
col_no = int(matches[0][1]) - 1
# due to CustomDecoder can handle `"": ''` or `'': ""`, so convert `"""` -> `"`, `'''` -> `'`
output = output.replace('"""', '"').replace("'''", '"')
arr = output.split("\n")
rline = arr[line_no] # raw line
line = arr[line_no].strip()
# different general problems
if line.endswith("],"):
# problem, redundant char `]`
new_line = line.replace("]", "")
elif line.endswith("},") and not output.endswith("},"):
# problem, redundant char `}`
new_line = line.replace("}", "")
elif line.endswith("},") and output.endswith("},"):
new_line = line[:-1]
# remove comments in output json str, after json value content, maybe start with #, maybe start with //
elif rline[col_no] == "#" or rline[col_no] == "/":
new_line = rline[:col_no]
for i in range(line_no + 1, len(arr)):
# look for # or // comments and make sure they are not inside the string value
comment_index = -1
for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line):
if match.group(1): # if the string value
continue
if match.group(2): # if comments
comment_index = match.start(2)
break
# if comments, then delete them
if comment_index != -1:
arr[i] = arr[i][:comment_index].rstrip()
elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line:
# problem, `"""` or `'''` without `,`
new_line = f",{line}"
elif '",' not in line and "," not in line and '"' not in line:
new_line = f'{line}",'
elif not line.endswith(","):
# problem, miss char `,` at the end.
new_line = f"{line},"
elif "," in line and len(line) == 1:
new_line = f'"{line}'
elif '",' in line:
new_line = line[:-2] + "',"
else:
new_line = line
arr[line_no] = new_line
output = "\n".join(arr)
logger.info(f"repair_invalid_json, raw error: {error}")
return output
def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]:
def run_and_passon(retry_state: RetryCallState) -> None:
"""
RetryCallState example
{
"start_time":143.098322024,
"retry_object":"<Retrying object at 0x7fabcaca25e0 (stop=<tenacity.stop.stop_after_attempt ... >)>",
"fn":"<function retry_parse_json_text_v2 at 0x7fabcac80ee0>",
"args":"(\"tag:[/CONTENT]\",)", # function input args
"kwargs":{}, # function input kwargs
"attempt_number":1, # retry number
"outcome":"<Future at xxx>", # type(outcome.result()) = "str", type(outcome.exception()) = "class"
"outcome_timestamp":143.098416904,
"idle_for":0,
"next_action":"None"
}
"""
if retry_state.outcome.failed:
if retry_state.args:
# # can't be used as args=retry_state.args
func_param_output = retry_state.args[0]
elif retry_state.kwargs:
func_param_output = retry_state.kwargs.get("output", "")
exp_str = str(retry_state.outcome.exception())
fix_str = "try to fix it, " if CONFIG.repair_llm_output else ""
logger.warning(
f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}"
)
repaired_output = repair_invalid_json(func_param_output, exp_str)
retry_state.kwargs["output"] = repaired_output
return run_and_passon
@retry(
stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0),
wait=wait_fixed(1),
after=run_after_exp_and_passon_next_retry(logger),
)
def retry_parse_json_text(output: str) -> Union[list, dict]:
"""
repair the json-text situation like there are extra chars like [']', '}']
Warning
if CONFIG.repair_llm_output is False, retry _aask_v1 {x=3} times, and the retry_parse_json_text's retry not work
if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times.
it's a two-layer retry cycle
"""
# logger.debug(f"output to json decode:\n{output}")
# if CONFIG.repair_llm_output is True, it will try to fix output until the retry break
parsed_data = CustomDecoder(strict=False).decode(output)
return parsed_data
def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"):
"""extract xxx from [CONTENT](xxx)[/CONTENT] using regex pattern"""
def re_extract_content(cont: str, pattern: str) -> str:
matches = re.findall(pattern, cont, re.DOTALL)
for match in matches:
if match:
cont = match
break
return cont.strip()
# TODO construct the extract pattern with the `right_key`
raw_content = copy.deepcopy(content)
pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]"
new_content = re_extract_content(raw_content, pattern)
if not new_content.startswith("{"):
# TODO find a more general pattern
# # for `[CONTENT]xxx[CONTENT]xxxx[/CONTENT] situation
logger.warning(f"extract_content try another pattern: {pattern}")
if right_key not in new_content:
raw_content = copy.deepcopy(new_content + "\n" + right_key)
# # pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]"
new_content = re_extract_content(raw_content, pattern)
else:
if right_key in new_content:
idx = new_content.find(right_key)
new_content = new_content[:idx]
new_content = new_content.strip()
return new_content
def extract_state_value_from_output(content: str) -> str:
"""
For openai models, they will always return state number. But for open llm models, the instruction result maybe a
long text contain target number, so here add a extraction to improve success rate.
Args:
content (str): llm's output from `Role._think`
"""
content = content.strip() # deal the output cases like " 0", "0\n" and so on.
pattern = r"([0-9])" # TODO find the number using a more proper method not just extract from content using pattern
matches = re.findall(pattern, content, re.DOTALL)
matches = list(set(matches))
state = matches[0] if len(matches) > 0 else "-1"
return state