mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-21 14:05:17 +02:00
feat: merge geekan:env_refactor
This commit is contained in:
commit
be832c9995
47 changed files with 1928 additions and 843 deletions
49
metagpt/utils/ahttp_client.py
Normal file
49
metagpt/utils/ahttp_client.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : pure async http_client
|
||||
|
||||
from typing import Any, Mapping, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
from aiohttp.client import DEFAULT_TIMEOUT
|
||||
|
||||
|
||||
async def apost(
    url: str,
    params: Optional[Mapping[str, str]] = None,
    json: Any = None,
    data: Any = None,
    headers: Optional[dict] = None,
    as_json: bool = False,
    encoding: str = "utf-8",
    timeout: int = DEFAULT_TIMEOUT.total,
) -> Union[str, dict]:
    """Send a single POST request and return the complete response body.

    Args:
        url: Target URL.
        params: Forwarded to ``session.post`` as ``params``.
        json: Forwarded to ``session.post`` as ``json``.
        data: Forwarded to ``session.post`` as ``data``.
        headers: Forwarded to ``session.post`` as ``headers``.
        as_json: When true, the value of ``resp.json()`` is returned; otherwise
            the body is read as bytes and decoded with ``encoding``.
        encoding: Codec used to decode the body when ``as_json`` is false.
        timeout: Forwarded to ``session.post`` as ``timeout``.

    Returns:
        The decoded body text, or whatever ``resp.json()`` produced.
    """
    request_kwargs = dict(url=url, params=params, json=json, data=data, headers=headers, timeout=timeout)
    # a fresh session is opened for this one request and closed on exit
    async with aiohttp.ClientSession() as session, session.post(**request_kwargs) as resp:
        if as_json:
            return await resp.json()
        raw_body = await resp.read()
        return raw_body.decode(encoding)
|
||||
|
||||
|
||||
async def apost_stream(
    url: str,
    params: Optional[Mapping[str, str]] = None,
    json: Any = None,
    data: Any = None,
    headers: Optional[dict] = None,
    encoding: str = "utf-8",
    timeout: int = DEFAULT_TIMEOUT.total,
) -> Any:
    """POST to ``url`` and yield the response body piece by piece as text.

    Every item obtained by iterating ``resp.content`` is decoded with
    ``encoding`` and yielded to the caller. All request arguments are
    forwarded unchanged to ``session.post``.

    usage:
        result = apost_stream(url="xx")
        async for line in result:
            deal_with(line)
    """
    request_kwargs = dict(url=url, params=params, json=json, data=data, headers=headers, timeout=timeout)
    # session and response stay open for as long as the caller keeps iterating
    async with aiohttp.ClientSession() as session, session.post(**request_kwargs) as resp:
        async for raw_line in resp.content:
            yield raw_line.decode(encoding)
|
||||
|
|
@ -224,10 +224,15 @@ class CodeParser:
|
|||
# 遍历所有的block
|
||||
for block in blocks:
|
||||
# 如果block不为空,则继续处理
|
||||
if block.strip() != "":
|
||||
if block.strip() == "":
|
||||
continue
|
||||
if "\n" not in block:
|
||||
block_title = block
|
||||
block_content = ""
|
||||
else:
|
||||
# 将block的标题和内容分开,并分别去掉前后的空白字符
|
||||
block_title, block_content = block.split("\n", 1)
|
||||
block_dict[block_title.strip()] = block_content.strip()
|
||||
block_dict[block_title.strip()] = block_content.strip()
|
||||
|
||||
return block_dict
|
||||
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ class FileRepository:
|
|||
m = json.loads(doc.content)
|
||||
filename = Path(doc.filename).with_suffix(with_suffix) if with_suffix is not None else Path(doc.filename)
|
||||
await self.save(filename=str(filename), content=json_to_markdown(m), dependencies=dependencies)
|
||||
logger.info(f"File Saved: {str(filename)}")
|
||||
logger.debug(f"File Saved: {str(filename)}")
|
||||
|
||||
@staticmethod
|
||||
async def get_file(filename: Path | str, relative_path: Path | str = ".") -> Document | None:
|
||||
|
|
|
|||
|
|
@ -233,6 +233,8 @@ class GitRepository:
|
|||
files = []
|
||||
try:
|
||||
directory_path = Path(self.workdir) / relative_path
|
||||
if not directory_path.exists():
|
||||
return []
|
||||
for file_path in directory_path.iterdir():
|
||||
if file_path.is_file():
|
||||
rpath = file_path.relative_to(root_relative_path)
|
||||
|
|
|
|||
310
metagpt/utils/repair_llm_raw_output.py
Normal file
310
metagpt/utils/repair_llm_raw_output.py
Normal file
|
|
@ -0,0 +1,310 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : repair llm raw output with particular conditions
|
||||
|
||||
import copy
|
||||
from enum import Enum
|
||||
from typing import Callable, Union
|
||||
|
||||
import regex as re
|
||||
from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.custom_decoder import CustomDecoder
|
||||
|
||||
|
||||
class RepairType(Enum):
    """Kinds of repair that can be applied to raw llm output.

    Member order matters: `_repair_llm_raw_output` iterates over this enum when
    no explicit repair type is requested.
    """

    # the required key is present but with different letter case
    CS = "case sensitivity"
    # condition like `[key] xx` which lacks `[/key]`
    RKPM = "required key pair missing"
    # usually the req_key appears in pairs like `[key] xx [/key]`
    SCM = "special character missing"
    # stray brackets / braces around the json text
    JSON = "json format"
|
||||
|
||||
|
||||
def repair_case_sensitivity(output: str, req_key: str) -> str:
    """Restore the expected letter case of `req_key` inside `output`.

    usually, req_key is the key name of expected json or markdown content, it won't appear in the value part.
    fix target string `"Shared Knowledge": ""` but `"Shared knowledge": ""` actually

    Args:
        output: raw llm output.
        req_key: the key with the letter case that is expected.

    Returns:
        `output` unchanged when `req_key` is already present or no
        case-insensitive match exists; otherwise `output` with every occurrence
        of the first differently-cased spelling replaced by `req_key`.
    """
    if req_key in output:
        return output

    # Search the original text case-insensitively instead of slicing it with an
    # index computed on `output.lower()`: lower-casing can change the length of
    # some Unicode text, which would make the slice point at the wrong span.
    found = re.search(re.escape(req_key), output, flags=re.IGNORECASE)
    if found is None:
        return output

    output = output.replace(found.group(0), req_key)
    logger.info(f"repair_case_sensitivity: {req_key}")
    return output
|
||||
|
||||
|
||||
def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str:
    """Re-insert a special character that is missing from the last `req_key` tag.

    fix
        1. target string `[CONTENT] xx [CONTENT] xxx [CONTENT]` lacks `/` in the last `[CONTENT]`
        2. target string `xx [CONTENT] xxx [CONTENT] xxxx` lacks `/` in the last `[CONTENT]`

    Args:
        output: raw llm output.
        req_key: the tag that must be present, including its special character.

    Returns:
        `output` unchanged when `req_key` is already present or its stripped
        form appears at most once; otherwise `output` with the last stripped
        occurrence replaced by `req_key`.
    """
    if req_key in output:
        return output

    for special_char in ("/",):
        bare_key = req_key.replace(special_char, "")
        if output.count(bare_key) <= 1:
            # a single (or no) occurrence is not treated as a broken closing tag
            continue
        # req_key with special_character usually in the tail side
        last_pos = output.rfind(bare_key)
        head = output[:last_pos]
        tail = output[last_pos + len(bare_key):]
        output = f"{head}{req_key}{tail}"
        logger.info(f"repair_special_character_missing: {special_char} in {bare_key} as position {last_pos}")

    return output
|
||||
|
||||
|
||||
def _cut_after_last_closer(text: str, left_key: str) -> str:
    """Return `text` from its last `left_key` up to its last `}` or `]` (inclusive).

    The caller guarantees that `left_key` occurs in `text`; since `left_key`
    ends with `]`, a closing character is always found in the kept part.
    """
    tail = text[text.rfind(left_key):]
    last_closer = max(tail.rfind("}"), tail.rfind("]"))
    return tail[: last_closer + 1]


def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str:
    """Add the missing half of a `[req_key]` / `[/req_key]` tag pair.

    implement the req_key pair in the begin or end of the content
        req_key format
            1. `[req_key]`, and its pair `[/req_key]`
            2. `[/req_key]`, and its pair `[req_key]`

    Args:
        output: raw llm output.
        req_key: either half of the tag pair; anything that is not wrapped in
            `[` ... `]` leaves `output` untouched.

    Returns:
        `output` with the opening tag prepended when it is absent and, when the
        closing tag is absent, the closing tag appended. If the text does not
        end with `}` or `]`, everything before the last opening tag and after
        the last `}` / `]` is dropped before the closing tag is appended.
    """
    sc = "/"  # special char
    if not (req_key.startswith("[") and req_key.endswith("]")):
        return output

    if sc in req_key:
        left_key = req_key.replace(sc, "")  # `[/req_key]` -> `[req_key]`
        right_key = req_key
    else:
        left_key = req_key
        right_key = f"{req_key[0]}{sc}{req_key[1:]}"  # `[req_key]` -> `[/req_key]`

    if left_key not in output:
        output = left_key + "\n" + output
    if right_key in output:
        return output

    stripped = output.strip()
    if stripped.endswith(left_key):
        # avoid [req_key]xx[req_key] case to append [/req_key]
        return output
    if stripped.endswith(("}", "]")):
        return output + "\n" + right_key
    # trailing text after the potential json: keep only the part that runs from
    # the last opening tag to the last closing brace / bracket
    return _cut_after_last_closer(output, left_key) + "\n" + right_key
|
||||
|
||||
|
||||
def repair_json_format(output: str) -> str:
    """Remove a stray bracket around json text.

    fix extra `[` or `}` in the end

    Args:
        output: json-like text, surrounding whitespace is stripped first.

    Returns:
        The stripped text with one of these repairs applied:
        a leading `[` before `{` is dropped, a trailing `]` after `}` is
        dropped, or a trailing `]` on text starting with `{` is turned into `}`.
        Text that both starts with `[{` and ends with `}]` is returned as is.
    """
    output = output.strip()

    if output.startswith("[{") and output.endswith("}]"):
        # Brackets are balanced on both ends (an array of objects). Dropping only
        # the leading `[` would leave `{...}]`, which is never valid json.
        return output

    if output.startswith("[{"):
        output = output[1:]
        logger.info(f"repair_json_format: {'[{'}")
    elif output.endswith("}]"):
        output = output[:-1]
        logger.info(f"repair_json_format: {'}]'}")
    elif output.startswith("{") and output.endswith("]"):
        output = output[:-1] + "}"

    return output
|
||||
|
||||
|
||||
def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = None) -> str:
    """Apply one repair, or every non-JSON repair, to `output` for a single key.

    Args:
        output: raw llm output.
        req_key: the key / tag the repairs look for.
        repair_type: the only repair to run; when falsy, every `RepairType`
            except `RepairType.JSON` is run in enum order.

    Returns:
        The output after all selected repairs were applied in sequence.
    """
    if repair_type:
        selected = [repair_type]
    else:
        selected = [kind for kind in RepairType if kind not in [RepairType.JSON]]

    # repairs that need the key; the json repair only takes the text
    keyed_repairs = {
        RepairType.CS: repair_case_sensitivity,
        RepairType.RKPM: repair_required_key_pair_missing,
        RepairType.SCM: repair_special_character_missing,
    }
    for kind in selected:
        if kind in keyed_repairs:
            output = keyed_repairs[kind](output, req_key)
        elif kind == RepairType.JSON:
            output = repair_json_format(output)
    return output
|
||||
|
||||
|
||||
def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str:
    """Repair raw llm output for every required key.

    in open-source llm model, it usually can't follow the instruction well, the output may be incomplete,
    so here we try to repair it and use all repair methods by default.
    typical case
        1. case sensitivity
            target: "Original Requirements"
            output: "Original requirements"
        2. special character missing
            target: [/CONTENT]
            output: [CONTENT]
        3. json format
            target: { xxx }
            output: { xxx }]

    Args:
        output: raw llm output.
        req_keys: keys / tags to repair, handled one after another.
        repair_type: passed through to `_repair_llm_raw_output` for each key.

    Returns:
        `output` unchanged when `CONFIG.repair_llm_output` is falsy, otherwise
        the text after the repairs for all keys were applied.
    """
    if CONFIG.repair_llm_output:
        # do the repair, usually needed for non-openai models
        for key in req_keys:
            output = _repair_llm_raw_output(output=output, req_key=key, repair_type=repair_type)
    return output
|
||||
|
||||
|
||||
def repair_invalid_json(output: str, error: str) -> str:
    """Patch the line of `output` that a json decode error points at.

    repair the situation like there are extra chars like
    error examples
        example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765)
        example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266)

    Args:
        output: the json text that failed to decode.
        error: the decode error message; the first `line <number>` in it
            selects the line to patch.

    Returns:
        `output` unchanged when `error` carries no line number. Otherwise the
        text with triple quotes collapsed and, when the reported line exists,
        that line replaced by its patched (whitespace-stripped) form. A line
        that matches none of the known problems is left as it is.
    """
    located = re.search(r"line ([0-9]+)", error)
    if located is None:
        return output

    line_no = int(located.group(1)) - 1

    # Original note: CustomDecoder can handle `"": ''` or `'': ""`. Both `"""`
    # and `'''` are collapsed to a single `"` here.
    output = output.replace('"""', '"').replace("'''", '"')
    arr = output.split("\n")
    if not 0 <= line_no < len(arr):
        # the reported line does not exist in this text, nothing to patch
        return output

    line = arr[line_no].strip()
    # different general problems
    if line.endswith("],"):
        # problem, redundant char `]`
        new_line = line.replace("]", "")
    elif line.endswith("},"):
        if output.endswith("},"):
            # the whole text ends with `},`: only the trailing comma is dropped
            new_line = line[:-1]
        else:
            # problem, redundant char `}`
            new_line = line.replace("}", "")
    elif "," not in line:
        # problem, miss `",` at the end
        new_line = f'{line}",'
    elif line == ",":
        new_line = f'"{line}'
    elif '",' in line:
        new_line = line[:-2] + "',"
    else:
        # no known problem matched: keep the line instead of failing
        new_line = arr[line_no]

    arr[line_no] = new_line
    output = "\n".join(arr)
    logger.info(f"repair_invalid_json, raw error: {error}")

    return output
|
||||
|
||||
|
||||
def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]:
    """Build an `after` hook that repairs the failed json text for the next retry.

    Args:
        logger: logger used to report the failed attempt.

    Returns:
        A callable taking the retry state. When the attempt failed it logs a
        warning, repairs the text with `repair_invalid_json` using the
        exception message, and stores the result in
        `retry_state.kwargs["output"]`.
    """

    def run_and_passon(retry_state: "RetryCallState") -> None:
        """
        RetryCallState example
            {
                "start_time":143.098322024,
                "retry_object":"<Retrying object at 0x7fabcaca25e0 (stop=<tenacity.stop.stop_after_attempt ... >)>",
                "fn":"<function retry_parse_json_text_v2 at 0x7fabcac80ee0>",
                "args":"(\"tag:[/CONTENT]\",)",  # function input args
                "kwargs":{},                     # function input kwargs
                "attempt_number":1,              # retry number
                "outcome":"<Future at xxx>",  # type(outcome.result()) = "str", type(outcome.exception()) = "class"
                "outcome_timestamp":143.098416904,
                "idle_for":0,
                "next_action":"None"
            }
        """
        if not retry_state.outcome.failed:
            return

        if retry_state.args:
            # # can't be used as args=retry_state.args
            func_param_output = retry_state.args[0]
        else:
            # also covers a call with neither args nor kwargs, which previously
            # left the variable unbound
            func_param_output = retry_state.kwargs.get("output", "")

        exp_str = str(retry_state.outcome.exception())
        logger.warning(
            f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
            f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}"
        )

        # NOTE(review): the repaired text is handed over through kwargs only;
        # presumably the wrapped function must be called with `output=` as a
        # keyword for the next attempt to pick it up - confirm against callers.
        retry_state.kwargs["output"] = repair_invalid_json(func_param_output, exp_str)

    return run_and_passon
|
||||
|
||||
|
||||
@retry(
    stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0),
    wait=wait_fixed(1),
    after=run_after_exp_and_passon_next_retry(logger),
)
def retry_parse_json_text(output: str) -> Union[list, dict]:
    """Decode json text, retrying with a repaired text after a failure.

    repair the json-text situation like there are extra chars like [']', '}']

    Warning
        if CONFIG.repair_llm_output is False, retry _aask_v1 {x=3} times, and the retry_parse_json_text's retry not work
        if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times.
            it's a two-layer retry cycle

    Args:
        output: the json text to decode.

    Returns:
        The value produced by `CustomDecoder(strict=False).decode(output)`.
    """
    logger.debug(f"output to json decode:\n{output}")

    # if CONFIG.repair_llm_output is True, it will try to fix output until the retry break
    decoder = CustomDecoder(strict=False)
    return decoder.decode(output)
|
||||
|
||||
|
||||
def extract_content_from_output(content: str, right_key: str = "[/CONTENT]") -> str:
    """extract xxx from [CONTENT](xxx)[/CONTENT] using regex pattern

    Args:
        content: llm output that is expected to hold `[CONTENT]...[/CONTENT]`.
        right_key: the closing tag used to cut off or complete the content.

    Returns:
        The stripped text found between the tags. When that text starts with
        `{`, everything from the first `right_key` onwards is removed.
        Otherwise a second extraction is attempted, on the first result with
        `right_key` appended if it lacks one, else on the original content.
        Text without a match is returned stripped.
    """
    # TODO construct the extract pattern with the `right_key`
    pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]"

    def re_extract_content(cont: str) -> str:
        # first non-empty capture wins, the text itself is the fallback
        matches = re.findall(pattern, cont, re.DOTALL)
        return next((match for match in matches if match), cont).strip()

    # strings are immutable, so no copy of `content` is needed
    new_content = re_extract_content(content)

    if new_content.startswith("{"):
        cut_at = new_content.find(right_key)
        if cut_at >= 0:
            new_content = new_content[:cut_at]
        return new_content.strip()

    # TODO find a more general pattern
    # # for `[CONTENT]xxx[CONTENT]xxxx[/CONTENT] situation
    logger.warning(f"extract_content try another pattern: {pattern}")
    if right_key not in new_content:
        retry_source = new_content + "\n" + right_key
    else:
        retry_source = content
    # # pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]"
    return re_extract_content(retry_source)
|
||||
|
||||
|
||||
def extract_state_value_from_output(content: str) -> str:
    """Pull the state number out of an llm answer.

    For openai models, they will always return state number. But for open llm models, the instruction result maybe a
    long text contain target number, so here add a extraction to improve success rate.

    Args:
        content (str): llm's output from `Role._think`

    Returns:
        The first digit character found in `content`, or "-1" when it holds
        no digit.
    """
    content = content.strip()  # deal the output cases like " 0", "0\n" and so on.
    # TODO find the number using a more proper method not just extract from content using pattern
    # Take the first digit in text order. De-duplicating through a set and
    # taking element 0 made the result arbitrary when several distinct digits
    # were present, because set iteration order is not stable for strings.
    first_digit = re.search(r"[0-9]", content)
    return first_digit.group(0) if first_digit else "-1"
|
||||
22
metagpt/utils/utils.py
Normal file
22
metagpt/utils/utils.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import typing
|
||||
|
||||
from tenacity import _utils
|
||||
|
||||
|
||||
def general_after_log(logger: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
    """Build an `after` hook that logs each finished call attempt as an error.

    Args:
        logger: logger whose `error` method receives the message.
        sec_format: %-style format applied to `retry_state.seconds_since_start`.

    Returns:
        A callable taking the retry state and logging the function name, the
        elapsed seconds, the ordinal of the attempt and the outcome's exception.
    """

    def log_it(retry_state: "RetryCallState") -> None:
        fn_name = "<unknown>" if retry_state.fn is None else _utils.get_callback_name(retry_state.fn)
        elapsed = sec_format % retry_state.seconds_since_start
        attempt_ordinal = _utils.to_ordinal(retry_state.attempt_number)
        logger.error(
            f"Finished call to '{fn_name}' after {elapsed}(s), "
            f"this was the {attempt_ordinal} time calling it. "
            f"exp: {retry_state.outcome.exception()}"
        )

    return log_it
|
||||
Loading…
Add table
Add a link
Reference in a new issue