refactor: format

This commit is contained in:
莘权 马 2023-12-14 20:44:27 +08:00
parent 829dfd8997
commit 290fb8b8d0
21 changed files with 262 additions and 284 deletions

View file

@@ -2,29 +2,24 @@
# -*- coding: utf-8 -*-
# @Desc : pure async http_client
from typing import Optional, Any, Mapping, Union
from typing import Any, Mapping, Optional, Union
from aiohttp.client import DEFAULT_TIMEOUT
import aiohttp
from aiohttp.client import DEFAULT_TIMEOUT
async def apost(url: str,
params: Optional[Mapping[str, str]] = None,
json: Any = None,
data: Any = None,
headers: Optional[dict] = None,
as_json: bool = False,
encoding: str = "utf-8",
timeout: int = DEFAULT_TIMEOUT.total) -> Union[str, dict]:
async def apost(
url: str,
params: Optional[Mapping[str, str]] = None,
json: Any = None,
data: Any = None,
headers: Optional[dict] = None,
as_json: bool = False,
encoding: str = "utf-8",
timeout: int = DEFAULT_TIMEOUT.total,
) -> Union[str, dict]:
async with aiohttp.ClientSession() as session:
async with session.post(
url=url,
params=params,
json=json,
data=data,
headers=headers,
timeout=timeout
) as resp:
async with session.post(url=url, params=params, json=json, data=data, headers=headers, timeout=timeout) as resp:
if as_json:
data = await resp.json()
else:
@@ -33,13 +28,15 @@ async def apost(url: str,
return data
async def apost_stream(url: str,
params: Optional[Mapping[str, str]] = None,
json: Any = None,
data: Any = None,
headers: Optional[dict] = None,
encoding: str = "utf-8",
timeout: int = DEFAULT_TIMEOUT.total) -> Any:
async def apost_stream(
url: str,
params: Optional[Mapping[str, str]] = None,
json: Any = None,
data: Any = None,
headers: Optional[dict] = None,
encoding: str = "utf-8",
timeout: int = DEFAULT_TIMEOUT.total,
) -> Any:
"""
usage:
result = astream(url="xx")
@@ -47,13 +44,6 @@ async def apost_stream(url: str,
deal_with(line)
"""
async with aiohttp.ClientSession() as session:
async with session.post(
url=url,
params=params,
json=json,
data=data,
headers=headers,
timeout=timeout
) as resp:
async with session.post(url=url, params=params, json=json, data=data, headers=headers, timeout=timeout) as resp:
async for line in resp.content:
yield line.decode(encoding)

View file

@@ -8,13 +8,15 @@
"""
from __future__ import annotations
from gitignore_parser import parse_gitignore, rule_from_pattern, handle_negation
import shutil
from enum import Enum
from pathlib import Path
from typing import Dict, List
from git.repo import Repo
from git.repo.fun import is_git_dir
from gitignore_parser import parse_gitignore
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.utils.dependency_file import DependencyFile
@@ -236,8 +238,9 @@ class GitRepository:
rpath = file_path.relative_to(root_relative_path)
files.append(str(rpath))
else:
subfolder_files = self.get_files(relative_path=file_path, root_relative_path=root_relative_path,
filter_ignored=False)
subfolder_files = self.get_files(
relative_path=file_path, root_relative_path=root_relative_path, filter_ignored=False
)
files.extend(subfolder_files)
except Exception as e:
logger.error(f"Error: {e}")

View file

@@ -4,12 +4,13 @@
import copy
from enum import Enum
from typing import Union, Callable
import regex as re
from tenacity import retry, stop_after_attempt, wait_fixed, after_log, RetryCallState
from typing import Callable, Union
import regex as re
from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed
from metagpt.logs import logger
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.utils.custom_decoder import CustomDecoder
@@ -33,7 +34,7 @@ def repair_case_sensitivity(output: str, req_key: str) -> str:
if req_key_lower in output_lower:
# find the sub-part index, and replace it with raw req_key
lidx = output_lower.find(req_key_lower)
source = output[lidx: lidx + len(req_key_lower)]
source = output[lidx : lidx + len(req_key_lower)]
output = output.replace(source, req_key)
logger.info(f"repair_case_sensitivity: {req_key}")
@@ -73,7 +74,7 @@ def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -
sc = "/" # special char
if req_key.startswith("[") and req_key.endswith("]"):
if sc in req_key:
left_key = req_key.replace(sc, "") # `[/req_key]` -> `[req_key]`
left_key = req_key.replace(sc, "") # `[/req_key]` -> `[req_key]`
right_key = req_key
else:
left_key = req_key
@@ -82,6 +83,7 @@ def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -
if left_key not in output:
output = left_key + "\n" + output
if right_key not in output:
def judge_potential_json(routput: str, left_key: str) -> Union[str, None]:
ridx = routput.rfind(left_key)
if ridx < 0:
@@ -90,7 +92,7 @@ def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -
idx1 = sub_output.rfind("}")
idx2 = sub_output.rindex("]")
idx = idx1 if idx1 >= idx2 else idx2
sub_output = sub_output[: idx+1]
sub_output = sub_output[: idx + 1]
return sub_output
if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)):
@@ -155,9 +157,7 @@ def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairT
# do the repairation usually for non-openai models
for req_key in req_keys:
output = _repair_llm_raw_output(output=output,
req_key=req_key,
repair_type=repair_type)
output = _repair_llm_raw_output(output=output, req_key=req_key, repair_type=repair_type)
return output
@@ -187,7 +187,7 @@ def repair_invalid_json(output: str, error: str) -> str:
new_line = line.replace("}", "")
elif line.endswith("},") and output.endswith("},"):
new_line = line[:-1]
elif '",' not in line and ',' not in line:
elif '",' not in line and "," not in line:
new_line = f'{line}",'
elif "," not in line:
# problem, miss char `,` at the end.
@@ -228,8 +228,10 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
elif retry_state.kwargs:
func_param_output = retry_state.kwargs.get("output", "")
exp_str = str(retry_state.outcome.exception())
logger.warning(f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}")
logger.warning(
f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}"
)
repaired_output = repair_invalid_json(func_param_output, exp_str)
retry_state.kwargs["output"] = repaired_output
@@ -260,7 +262,8 @@ def retry_parse_json_text(output: str) -> Union[list, dict]:
def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"):
""" extract xxx from [CONTENT](xxx)[/CONTENT] using regex pattern """
"""extract xxx from [CONTENT](xxx)[/CONTENT] using regex pattern"""
def re_extract_content(cont: str, pattern: str) -> str:
matches = re.findall(pattern, cont, re.DOTALL)
for match in matches:

View file

@@ -4,7 +4,7 @@
import typing
from tenacity import after_log, _utils
from tenacity import _utils
def general_after_log(logger: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
@@ -13,7 +13,10 @@ def general_after_log(logger: "loguru.Logger", sec_format: str = "%0.3f") -> typ
fn_name = "<unknown>"
else:
fn_name = _utils.get_callback_name(retry_state.fn)
logger.error(f"Finished call to '{fn_name}' after {sec_format % retry_state.seconds_since_start}(s), "
f"this was the {_utils.to_ordinal(retry_state.attempt_number)} time calling it. "
f"exp: {retry_state.outcome.exception()}")
logger.error(
f"Finished call to '{fn_name}' after {sec_format % retry_state.seconds_since_start}(s), "
f"this was the {_utils.to_ordinal(retry_state.attempt_number)} time calling it. "
f"exp: {retry_state.outcome.exception()}"
)
return log_it