mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
Merge branch 'huzixia' of github.com:HuZixia/MetaGPT into huzixia
merge
This commit is contained in:
commit
b9a03c380a
3 changed files with 714 additions and 0 deletions
12
PR/# remove comments in output json str, after js
Normal file
12
PR/# remove comments in output json str, after js
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
|
||||
git commit -m "To avoid JSONDecodeError: " -m "Remove comments in output json str, after json value content, maybe start with #, maybe start with //, particularly, it is not inside the string value" -m "Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py"
|
||||
|
||||
|
||||
|
||||
git commit -m "Addtionly, if you do not want JSONDecodeError to occur, you can add 'Delete comments in json' after FORMAT_CONSTRAINT in action_node.py"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
351
PR/action_node.py
Normal file
351
PR/action_node.py
Normal file
|
|
@ -0,0 +1,351 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/11 18:45
|
||||
@Author : alexanderwu
|
||||
@File : action_node.py
|
||||
|
||||
NOTE: You should use typing.List instead of list to do type annotation. Because in the markdown extraction process,
|
||||
we can use typing to extract the type of the node, but we cannot use built-in list to extract.
|
||||
"""
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from pydantic import BaseModel, create_model, model_validator
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
|
||||
from metagpt.utils.common import OutputParser, general_after_log
|
||||
|
||||
TAG = "CONTENT"
|
||||
|
||||
LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT."
|
||||
FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else."
|
||||
# Delete comments in json
|
||||
# If you don't want JSONDecodeError to occur, you can add Delete comments in json after FORMAT_CONSTRAINT
|
||||
|
||||
|
||||
SIMPLE_TEMPLATE = """
|
||||
## context
|
||||
{context}
|
||||
|
||||
-----
|
||||
|
||||
## format example
|
||||
{example}
|
||||
|
||||
## nodes: "<node>: <type> # <instruction>"
|
||||
{instruction}
|
||||
|
||||
## constraint
|
||||
{constraint}
|
||||
|
||||
## action
|
||||
Follow instructions of nodes, generate output and make sure it follows the format example.
|
||||
"""
|
||||
|
||||
|
||||
def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"):
|
||||
markdown_str = ""
|
||||
for key, value in d.items():
|
||||
markdown_str += f"{prefix}{key}{kv_sep}{value}{postfix}"
|
||||
return markdown_str
|
||||
|
||||
|
||||
class ActionNode:
|
||||
"""ActionNode is a tree of nodes."""
|
||||
|
||||
schema: str # raw/json/markdown, default: ""
|
||||
|
||||
# Action Context
|
||||
context: str # all the context, including all necessary info
|
||||
llm: BaseLLM # LLM with aask interface
|
||||
children: dict[str, "ActionNode"]
|
||||
|
||||
# Action Input
|
||||
key: str # Product Requirement / File list / Code
|
||||
expected_type: Type # such as str / int / float etc.
|
||||
# context: str # everything in the history.
|
||||
instruction: str # the instructions should be followed.
|
||||
example: Any # example for In Context-Learning.
|
||||
|
||||
# Action Output
|
||||
content: str
|
||||
instruct_content: BaseModel
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
expected_type: Type,
|
||||
instruction: str,
|
||||
example: Any,
|
||||
content: str = "",
|
||||
children: dict[str, "ActionNode"] = None,
|
||||
schema: str = "",
|
||||
):
|
||||
self.key = key
|
||||
self.expected_type = expected_type
|
||||
self.instruction = instruction
|
||||
self.example = example
|
||||
self.content = content
|
||||
self.children = children if children is not None else {}
|
||||
self.schema = schema
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"{self.key}, {repr(self.expected_type)}, {self.instruction}, {self.example}"
|
||||
f", {self.content}, {self.children}"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def add_child(self, node: "ActionNode"):
|
||||
"""增加子ActionNode"""
|
||||
self.children[node.key] = node
|
||||
|
||||
def add_children(self, nodes: List["ActionNode"]):
|
||||
"""批量增加子ActionNode"""
|
||||
for node in nodes:
|
||||
self.add_child(node)
|
||||
|
||||
@classmethod
|
||||
def from_children(cls, key, nodes: List["ActionNode"]):
|
||||
"""直接从一系列的子nodes初始化"""
|
||||
obj = cls(key, str, "", "")
|
||||
obj.add_children(nodes)
|
||||
return obj
|
||||
|
||||
def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""获得子ActionNode的字典,以key索引"""
|
||||
exclude = exclude or []
|
||||
return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude}
|
||||
|
||||
def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""get self key: type mapping"""
|
||||
return {self.key: (self.expected_type, ...)}
|
||||
|
||||
def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""get key: type mapping under mode"""
|
||||
if mode == "children" or (mode == "auto" and self.children):
|
||||
return self.get_children_mapping(exclude=exclude)
|
||||
return {} if exclude and self.key in exclude else self.get_self_mapping()
|
||||
|
||||
@classmethod
|
||||
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
|
||||
"""基于pydantic v1的模型动态生成,用来检验结果类型正确性"""
|
||||
|
||||
def check_fields(cls, values):
|
||||
required_fields = set(mapping.keys())
|
||||
missing_fields = required_fields - set(values.keys())
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing fields: {missing_fields}")
|
||||
|
||||
unrecognized_fields = set(values.keys()) - required_fields
|
||||
if unrecognized_fields:
|
||||
logger.warning(f"Unrecognized fields: {unrecognized_fields}")
|
||||
return values
|
||||
|
||||
validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)}
|
||||
|
||||
new_class = create_model(class_name, __validators__=validators, **mapping)
|
||||
return new_class
|
||||
|
||||
def create_children_class(self, exclude=None):
|
||||
"""使用object内有的字段直接生成model_class"""
|
||||
class_name = f"{self.key}_AN"
|
||||
mapping = self.get_children_mapping(exclude=exclude)
|
||||
return self.create_model_class(class_name, mapping)
|
||||
|
||||
def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict:
|
||||
"""将当前节点与子节点都按照node: format的格式组织成字典"""
|
||||
|
||||
# 如果没有提供格式化函数,使用默认的格式化方式
|
||||
if format_func is None:
|
||||
format_func = lambda node: f"{node.instruction}"
|
||||
|
||||
# 使用提供的格式化函数来格式化当前节点的值
|
||||
formatted_value = format_func(self)
|
||||
|
||||
# 创建当前节点的键值对
|
||||
if mode == "children" or (mode == "auto" and self.children):
|
||||
node_dict = {}
|
||||
else:
|
||||
node_dict = {self.key: formatted_value}
|
||||
|
||||
if mode == "root":
|
||||
return node_dict
|
||||
|
||||
# 遍历子节点并递归调用 to_dict 方法
|
||||
exclude = exclude or []
|
||||
for _, child_node in self.children.items():
|
||||
if child_node.key in exclude:
|
||||
continue
|
||||
node_dict.update(child_node.to_dict(format_func))
|
||||
|
||||
return node_dict
|
||||
|
||||
def compile_to(self, i: Dict, schema, kv_sep) -> str:
|
||||
if schema == "json":
|
||||
return json.dumps(i, indent=4)
|
||||
elif schema == "markdown":
|
||||
return dict_to_markdown(i, kv_sep=kv_sep)
|
||||
else:
|
||||
return str(i)
|
||||
|
||||
def tagging(self, text, schema, tag="") -> str:
|
||||
if not tag:
|
||||
return text
|
||||
if schema == "json":
|
||||
return f"[{tag}]\n" + text + f"\n[/{tag}]"
|
||||
else: # markdown
|
||||
return f"[{tag}]\n" + text + f"\n[/{tag}]"
|
||||
|
||||
def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str:
|
||||
nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude)
|
||||
text = self.compile_to(nodes, schema, kv_sep)
|
||||
return self.tagging(text, schema, tag)
|
||||
|
||||
def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str:
|
||||
"""compile to raw/json/markdown template with all/root/children nodes"""
|
||||
format_func = lambda i: f"{i.expected_type} # {i.instruction}"
|
||||
return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude)
|
||||
|
||||
def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str:
|
||||
"""compile to raw/json/markdown examples with all/root/children nodes"""
|
||||
|
||||
# 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example
|
||||
# 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str
|
||||
format_func = lambda i: i.example
|
||||
return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude)
|
||||
|
||||
def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=[]) -> str:
|
||||
"""
|
||||
mode: all/root/children
|
||||
mode="children": 编译所有子节点为一个统一模板,包括instruction与example
|
||||
mode="all": NotImplemented
|
||||
mode="root": NotImplemented
|
||||
schmea: raw/json/markdown
|
||||
schema="raw": 不编译,context, lang_constaint, instruction
|
||||
schema="json":编译context, example(json), instruction(markdown), constraint, action
|
||||
schema="markdown": 编译context, example(markdown), instruction(markdown), constraint, action
|
||||
"""
|
||||
if schema == "raw":
|
||||
return context + "\n\n## Actions\n" + LANGUAGE_CONSTRAINT + "\n" + self.instruction
|
||||
|
||||
# FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线",
|
||||
# compile example暂时不支持markdown
|
||||
instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude)
|
||||
example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude)
|
||||
# nodes = ", ".join(self.to_dict(mode=mode).keys())
|
||||
constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT]
|
||||
constraint = "\n".join(constraints)
|
||||
|
||||
prompt = template.format(
|
||||
context=context,
|
||||
example=example,
|
||||
instruction=instruction,
|
||||
constraint=constraint,
|
||||
)
|
||||
return prompt
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=20),
|
||||
stop=stop_after_attempt(6),
|
||||
after=general_after_log(logger),
|
||||
)
|
||||
async def _aask_v1(
|
||||
self,
|
||||
prompt: str,
|
||||
output_class_name: str,
|
||||
output_data_mapping: dict,
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
schema="markdown", # compatible to original format
|
||||
timeout=CONFIG.timeout,
|
||||
) -> (str, BaseModel):
|
||||
"""Use ActionOutput to wrap the output of aask"""
|
||||
content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
|
||||
logger.debug(f"llm raw output:\n{content}")
|
||||
output_class = self.create_model_class(output_class_name, output_data_mapping)
|
||||
|
||||
if schema == "json":
|
||||
parsed_data = llm_output_postprocess(
|
||||
output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
|
||||
)
|
||||
else: # using markdown parser
|
||||
parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
|
||||
|
||||
logger.debug(f"parsed_data:\n{parsed_data}")
|
||||
instruct_content = output_class(**parsed_data)
|
||||
return content, instruct_content
|
||||
|
||||
def get(self, key):
|
||||
return self.instruct_content.model_dump()[key]
|
||||
|
||||
def set_recursive(self, name, value):
|
||||
setattr(self, name, value)
|
||||
for _, i in self.children.items():
|
||||
i.set_recursive(name, value)
|
||||
|
||||
def set_llm(self, llm):
|
||||
self.set_recursive("llm", llm)
|
||||
|
||||
def set_context(self, context):
|
||||
self.set_recursive("context", context)
|
||||
|
||||
async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None):
|
||||
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
|
||||
|
||||
if schema != "raw":
|
||||
mapping = self.get_mapping(mode, exclude=exclude)
|
||||
class_name = f"{self.key}_AN"
|
||||
content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout)
|
||||
self.content = content
|
||||
self.instruct_content = scontent
|
||||
else:
|
||||
self.content = await self.llm.aask(prompt)
|
||||
self.instruct_content = None
|
||||
|
||||
return self
|
||||
|
||||
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]):
|
||||
"""Fill the node(s) with mode.
|
||||
|
||||
:param context: Everything we should know when filling node.
|
||||
:param llm: Large Language Model with pre-defined system message.
|
||||
:param schema: json/markdown, determine example and output format.
|
||||
- raw: free form text
|
||||
- json: it's easy to open source LLM with json format
|
||||
- markdown: when generating code, markdown is always better
|
||||
:param mode: auto/children/root
|
||||
- auto: automated fill children's nodes and gather outputs, if no children, fill itself
|
||||
- children: fill children's nodes and gather outputs
|
||||
- root: fill root's node and gather output
|
||||
:param strgy: simple/complex
|
||||
- simple: run only once
|
||||
- complex: run each node
|
||||
:param timeout: Timeout for llm invocation.
|
||||
:param exclude: The keys of ActionNode to exclude.
|
||||
:return: self
|
||||
"""
|
||||
self.set_llm(llm)
|
||||
self.set_context(context)
|
||||
if self.schema:
|
||||
schema = self.schema
|
||||
|
||||
if strgy == "simple":
|
||||
return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
|
||||
elif strgy == "complex":
|
||||
# 这里隐式假设了拥有children
|
||||
tmp = {}
|
||||
for _, i in self.children.items():
|
||||
if exclude and i.key in exclude:
|
||||
continue
|
||||
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
|
||||
tmp.update(child.instruct_content.dict())
|
||||
cls = self.create_children_class()
|
||||
self.instruct_content = cls(**tmp)
|
||||
return self
|
||||
351
PR/repair_llm_raw_output.py
Normal file
351
PR/repair_llm_raw_output.py
Normal file
|
|
@ -0,0 +1,351 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : repair llm raw output with particular conditions
|
||||
|
||||
import copy
|
||||
from enum import Enum
|
||||
from typing import Callable, Union
|
||||
|
||||
import regex as re
|
||||
from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.custom_decoder import CustomDecoder
|
||||
|
||||
|
||||
class RepairType(Enum):
|
||||
CS = "case sensitivity"
|
||||
RKPM = "required key pair missing" # condition like `[key] xx` which lacks `[/key]`
|
||||
SCM = "special character missing" # Usually the req_key appear in pairs like `[key] xx [/key]`
|
||||
JSON = "json format"
|
||||
|
||||
|
||||
def repair_case_sensitivity(output: str, req_key: str) -> str:
|
||||
"""
|
||||
usually, req_key is the key name of expected json or markdown content, it won't appear in the value part.
|
||||
fix target string `"Shared Knowledge": ""` but `"Shared knowledge": ""` actually
|
||||
"""
|
||||
if req_key in output:
|
||||
return output
|
||||
|
||||
output_lower = output.lower()
|
||||
req_key_lower = req_key.lower()
|
||||
if req_key_lower in output_lower:
|
||||
# find the sub-part index, and replace it with raw req_key
|
||||
lidx = output_lower.find(req_key_lower)
|
||||
source = output[lidx : lidx + len(req_key_lower)]
|
||||
output = output.replace(source, req_key)
|
||||
logger.info(f"repair_case_sensitivity: {req_key}")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str:
|
||||
"""
|
||||
fix
|
||||
1. target string `[CONTENT] xx [CONTENT] xxx [CONTENT]` lacks `/` in the last `[CONTENT]`
|
||||
2. target string `xx [CONTENT] xxx [CONTENT] xxxx` lacks `/` in the last `[CONTENT]`
|
||||
"""
|
||||
sc_arr = ["/"]
|
||||
|
||||
if req_key in output:
|
||||
return output
|
||||
|
||||
for sc in sc_arr:
|
||||
req_key_pure = req_key.replace(sc, "")
|
||||
appear_cnt = output.count(req_key_pure)
|
||||
if req_key_pure in output and appear_cnt > 1:
|
||||
# req_key with special_character usually in the tail side
|
||||
ridx = output.rfind(req_key_pure)
|
||||
output = f"{output[:ridx]}{req_key}{output[ridx + len(req_key_pure):]}"
|
||||
logger.info(f"repair_special_character_missing: {sc} in {req_key_pure} as position {ridx}")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str:
|
||||
"""
|
||||
implement the req_key pair in the begin or end of the content
|
||||
req_key format
|
||||
1. `[req_key]`, and its pair `[/req_key]`
|
||||
2. `[/req_key]`, and its pair `[req_key]`
|
||||
"""
|
||||
sc = "/" # special char
|
||||
if req_key.startswith("[") and req_key.endswith("]"):
|
||||
if sc in req_key:
|
||||
left_key = req_key.replace(sc, "") # `[/req_key]` -> `[req_key]`
|
||||
right_key = req_key
|
||||
else:
|
||||
left_key = req_key
|
||||
right_key = f"{req_key[0]}{sc}{req_key[1:]}" # `[req_key]` -> `[/req_key]`
|
||||
|
||||
if left_key not in output:
|
||||
output = left_key + "\n" + output
|
||||
if right_key not in output:
|
||||
|
||||
def judge_potential_json(routput: str, left_key: str) -> Union[str, None]:
|
||||
ridx = routput.rfind(left_key)
|
||||
if ridx < 0:
|
||||
return None
|
||||
sub_output = routput[ridx:]
|
||||
idx1 = sub_output.rfind("}")
|
||||
idx2 = sub_output.rindex("]")
|
||||
idx = idx1 if idx1 >= idx2 else idx2
|
||||
sub_output = sub_output[: idx + 1]
|
||||
return sub_output
|
||||
|
||||
if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)):
|
||||
# # avoid [req_key]xx[req_key] case to append [/req_key]
|
||||
output = output + "\n" + right_key
|
||||
elif judge_potential_json(output, left_key) and (not output.strip().endswith(left_key)):
|
||||
sub_content = judge_potential_json(output, left_key)
|
||||
output = sub_content + "\n" + right_key
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def repair_json_format(output: str) -> str:
|
||||
"""
|
||||
fix extra `[` or `}` in the end
|
||||
"""
|
||||
output = output.strip()
|
||||
|
||||
if output.startswith("[{"):
|
||||
output = output[1:]
|
||||
logger.info(f"repair_json_format: {'[{'}")
|
||||
elif output.endswith("}]"):
|
||||
output = output[:-1]
|
||||
logger.info(f"repair_json_format: {'}]'}")
|
||||
elif output.startswith("{") and output.endswith("]"):
|
||||
output = output[:-1] + "}"
|
||||
|
||||
# remove comments in output json str, after json value content, maybe start with #, maybe start with //
|
||||
arr = output.split("\n")
|
||||
new_arr = []
|
||||
for line in arr:
|
||||
# look for # or // comments and make sure they are not inside the string value
|
||||
comment_index = -1
|
||||
for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line):
|
||||
if match.group(1): # if the string value
|
||||
continue
|
||||
if match.group(2): # if comments
|
||||
comment_index = match.start(2)
|
||||
break
|
||||
# if comments, then delete them
|
||||
if comment_index != -1:
|
||||
line = line[:comment_index].rstrip()
|
||||
new_arr.append(line)
|
||||
output = "\n".join(new_arr)
|
||||
return output
|
||||
|
||||
|
||||
def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = None) -> str:
|
||||
repair_types = [repair_type] if repair_type else [item for item in RepairType if item not in [RepairType.JSON]]
|
||||
for repair_type in repair_types:
|
||||
if repair_type == RepairType.CS:
|
||||
output = repair_case_sensitivity(output, req_key)
|
||||
elif repair_type == RepairType.RKPM:
|
||||
output = repair_required_key_pair_missing(output, req_key)
|
||||
elif repair_type == RepairType.SCM:
|
||||
output = repair_special_character_missing(output, req_key)
|
||||
elif repair_type == RepairType.JSON:
|
||||
output = repair_json_format(output)
|
||||
return output
|
||||
|
||||
|
||||
def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str:
|
||||
"""
|
||||
in open-source llm model, it usually can't follow the instruction well, the output may be incomplete,
|
||||
so here we try to repair it and use all repair methods by default.
|
||||
typical case
|
||||
1. case sensitivity
|
||||
target: "Original Requirements"
|
||||
output: "Original requirements"
|
||||
2. special character missing
|
||||
target: [/CONTENT]
|
||||
output: [CONTENT]
|
||||
3. json format
|
||||
target: { xxx }
|
||||
output: { xxx }]
|
||||
"""
|
||||
if not CONFIG.repair_llm_output:
|
||||
return output
|
||||
|
||||
# do the repairation usually for non-openai models
|
||||
for req_key in req_keys:
|
||||
output = _repair_llm_raw_output(output=output, req_key=req_key, repair_type=repair_type)
|
||||
return output
|
||||
|
||||
|
||||
def repair_invalid_json(output: str, error: str) -> str:
|
||||
"""
|
||||
repair the situation like there are extra chars like
|
||||
error examples
|
||||
example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765)
|
||||
example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266)
|
||||
"""
|
||||
pattern = r"line ([0-9]+) column ([0-9]+)"
|
||||
|
||||
matches = re.findall(pattern, error, re.DOTALL)
|
||||
if len(matches) > 0:
|
||||
line_no = int(matches[0][0]) - 1
|
||||
col_no = int(matches[0][1]) - 1
|
||||
|
||||
# due to CustomDecoder can handle `"": ''` or `'': ""`, so convert `"""` -> `"`, `'''` -> `'`
|
||||
output = output.replace('"""', '"').replace("'''", '"')
|
||||
arr = output.split("\n")
|
||||
rline = arr[line_no] # raw line
|
||||
line = arr[line_no].strip()
|
||||
# different general problems
|
||||
if line.endswith("],"):
|
||||
# problem, redundant char `]`
|
||||
new_line = line.replace("]", "")
|
||||
elif line.endswith("},") and not output.endswith("},"):
|
||||
# problem, redundant char `}`
|
||||
new_line = line.replace("}", "")
|
||||
elif line.endswith("},") and output.endswith("},"):
|
||||
new_line = line[:-1]
|
||||
# remove comments in output json str, after json value content, maybe start with #, maybe start with //
|
||||
elif rline[col_no] == "#" or rline[col_no] == "/":
|
||||
new_line = rline[:col_no]
|
||||
for i in range(line_no + 1, len(arr)):
|
||||
# look for # or // comments and make sure they are not inside the string value
|
||||
comment_index = -1
|
||||
for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", line):
|
||||
if match.group(1): # if the string value
|
||||
continue
|
||||
if match.group(2): # if comments
|
||||
comment_index = match.start(2)
|
||||
break
|
||||
# if comments, then delete them
|
||||
if comment_index != -1:
|
||||
arr[i] = arr[i][:comment_index].rstrip()
|
||||
elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line:
|
||||
# problem, `"""` or `'''` without `,`
|
||||
new_line = f",{line}"
|
||||
elif '",' not in line and "," not in line and '"' not in line:
|
||||
new_line = f'{line}",'
|
||||
elif not line.endswith(","):
|
||||
# problem, miss char `,` at the end.
|
||||
new_line = f"{line},"
|
||||
elif "," in line and len(line) == 1:
|
||||
new_line = f'"{line}'
|
||||
elif '",' in line:
|
||||
new_line = line[:-2] + "',"
|
||||
else:
|
||||
new_line = line
|
||||
|
||||
arr[line_no] = new_line
|
||||
output = "\n".join(arr)
|
||||
logger.info(f"repair_invalid_json, raw error: {error}")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]:
|
||||
def run_and_passon(retry_state: RetryCallState) -> None:
|
||||
"""
|
||||
RetryCallState example
|
||||
{
|
||||
"start_time":143.098322024,
|
||||
"retry_object":"<Retrying object at 0x7fabcaca25e0 (stop=<tenacity.stop.stop_after_attempt ... >)>",
|
||||
"fn":"<function retry_parse_json_text_v2 at 0x7fabcac80ee0>",
|
||||
"args":"(\"tag:[/CONTENT]\",)", # function input args
|
||||
"kwargs":{}, # function input kwargs
|
||||
"attempt_number":1, # retry number
|
||||
"outcome":"<Future at xxx>", # type(outcome.result()) = "str", type(outcome.exception()) = "class"
|
||||
"outcome_timestamp":143.098416904,
|
||||
"idle_for":0,
|
||||
"next_action":"None"
|
||||
}
|
||||
"""
|
||||
if retry_state.outcome.failed:
|
||||
if retry_state.args:
|
||||
# # can't be used as args=retry_state.args
|
||||
func_param_output = retry_state.args[0]
|
||||
elif retry_state.kwargs:
|
||||
func_param_output = retry_state.kwargs.get("output", "")
|
||||
exp_str = str(retry_state.outcome.exception())
|
||||
|
||||
fix_str = "try to fix it, " if CONFIG.repair_llm_output else ""
|
||||
logger.warning(
|
||||
f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
|
||||
f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}"
|
||||
)
|
||||
|
||||
repaired_output = repair_invalid_json(func_param_output, exp_str)
|
||||
retry_state.kwargs["output"] = repaired_output
|
||||
|
||||
return run_and_passon
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0),
|
||||
wait=wait_fixed(1),
|
||||
after=run_after_exp_and_passon_next_retry(logger),
|
||||
)
|
||||
def retry_parse_json_text(output: str) -> Union[list, dict]:
|
||||
"""
|
||||
repair the json-text situation like there are extra chars like [']', '}']
|
||||
|
||||
Warning
|
||||
if CONFIG.repair_llm_output is False, retry _aask_v1 {x=3} times, and the retry_parse_json_text's retry not work
|
||||
if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times.
|
||||
it's a two-layer retry cycle
|
||||
"""
|
||||
# logger.debug(f"output to json decode:\n{output}")
|
||||
|
||||
# if CONFIG.repair_llm_output is True, it will try to fix output until the retry break
|
||||
parsed_data = CustomDecoder(strict=False).decode(output)
|
||||
|
||||
return parsed_data
|
||||
|
||||
|
||||
def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"):
|
||||
"""extract xxx from [CONTENT](xxx)[/CONTENT] using regex pattern"""
|
||||
|
||||
def re_extract_content(cont: str, pattern: str) -> str:
|
||||
matches = re.findall(pattern, cont, re.DOTALL)
|
||||
for match in matches:
|
||||
if match:
|
||||
cont = match
|
||||
break
|
||||
return cont.strip()
|
||||
|
||||
# TODO construct the extract pattern with the `right_key`
|
||||
raw_content = copy.deepcopy(content)
|
||||
pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]"
|
||||
new_content = re_extract_content(raw_content, pattern)
|
||||
|
||||
if not new_content.startswith("{"):
|
||||
# TODO find a more general pattern
|
||||
# # for `[CONTENT]xxx[CONTENT]xxxx[/CONTENT] situation
|
||||
logger.warning(f"extract_content try another pattern: {pattern}")
|
||||
if right_key not in new_content:
|
||||
raw_content = copy.deepcopy(new_content + "\n" + right_key)
|
||||
# # pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]"
|
||||
new_content = re_extract_content(raw_content, pattern)
|
||||
else:
|
||||
if right_key in new_content:
|
||||
idx = new_content.find(right_key)
|
||||
new_content = new_content[:idx]
|
||||
new_content = new_content.strip()
|
||||
|
||||
return new_content
|
||||
|
||||
|
||||
def extract_state_value_from_output(content: str) -> str:
|
||||
"""
|
||||
For openai models, they will always return state number. But for open llm models, the instruction result maybe a
|
||||
long text contain target number, so here add a extraction to improve success rate.
|
||||
|
||||
Args:
|
||||
content (str): llm's output from `Role._think`
|
||||
"""
|
||||
content = content.strip() # deal the output cases like " 0", "0\n" and so on.
|
||||
pattern = r"([0-9])" # TODO find the number using a more proper method not just extract from content using pattern
|
||||
matches = re.findall(pattern, content, re.DOTALL)
|
||||
matches = list(set(matches))
|
||||
state = matches[0] if len(matches) > 0 else "-1"
|
||||
return state
|
||||
Loading…
Add table
Add a link
Reference in a new issue