mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
Merge branch 'main' into feature-openai-v1
This commit is contained in:
commit
9a4f0d555c
260 changed files with 10576 additions and 3191 deletions
49
metagpt/utils/ahttp_client.py
Normal file
49
metagpt/utils/ahttp_client.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : pure async http_client
|
||||
|
||||
from typing import Any, Mapping, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
from aiohttp.client import DEFAULT_TIMEOUT
|
||||
|
||||
|
||||
async def apost(
    url: str,
    params: Optional[Mapping[str, str]] = None,
    json: Any = None,
    data: Any = None,
    headers: Optional[dict] = None,
    as_json: bool = False,
    encoding: str = "utf-8",
    timeout: int = DEFAULT_TIMEOUT.total,
) -> Union[str, dict]:
    """Send a POST request with a throwaway aiohttp session and return the body.

    :param url: Target URL.
    :param params: Optional query-string parameters.
    :param json: Optional payload that aiohttp serializes as JSON.
    :param data: Optional raw request body / form data.
    :param headers: Optional request headers.
    :param as_json: When true, return the response parsed by ``resp.json()``;
        otherwise return the raw body decoded with ``encoding``.
    :param encoding: Text encoding used to decode the raw body.
    :param timeout: Total timeout in seconds (``None`` disables it). An
        ``aiohttp.ClientTimeout`` instance is also accepted and used as is.
    :return: The parsed JSON value or the decoded response text.
    """
    # aiohttp documents ``timeout`` as a ClientTimeout structure; a bare number is
    # only accepted for backward compatibility, so normalize it explicitly.
    if isinstance(timeout, aiohttp.ClientTimeout):
        request_timeout = timeout
    else:
        request_timeout = aiohttp.ClientTimeout(total=timeout)

    async with aiohttp.ClientSession() as session:
        async with session.post(
            url=url, params=params, json=json, data=data, headers=headers, timeout=request_timeout
        ) as resp:
            if as_json:
                return await resp.json()
            # Keep the response in its own name instead of rebinding the ``data`` parameter.
            raw_body = await resp.read()
            return raw_body.decode(encoding)
|
||||
|
||||
|
||||
async def apost_stream(
    url: str,
    params: Optional[Mapping[str, str]] = None,
    json: Any = None,
    data: Any = None,
    headers: Optional[dict] = None,
    encoding: str = "utf-8",
    timeout: int = DEFAULT_TIMEOUT.total,
) -> Any:
    """Send a POST request and yield the response body line by line.

    This is an async generator; each chunk read from ``resp.content`` is decoded
    with ``encoding`` before being yielded.

    usage:
        result = apost_stream(url="xx")
        async for line in result:
            deal_with(line)

    :param url: Target URL.
    :param params: Optional query-string parameters.
    :param json: Optional payload that aiohttp serializes as JSON.
    :param data: Optional raw request body / form data.
    :param headers: Optional request headers.
    :param encoding: Text encoding used to decode each line.
    :param timeout: Total timeout in seconds (``None`` disables it). An
        ``aiohttp.ClientTimeout`` instance is also accepted and used as is.
    """
    # aiohttp documents ``timeout`` as a ClientTimeout structure; a bare number is
    # only accepted for backward compatibility, so normalize it explicitly.
    if isinstance(timeout, aiohttp.ClientTimeout):
        request_timeout = timeout
    else:
        request_timeout = aiohttp.ClientTimeout(total=timeout)

    async with aiohttp.ClientSession() as session:
        async with session.post(
            url=url, params=params, json=json, data=data, headers=headers, timeout=request_timeout
        ) as resp:
            async for line in resp.content:
                yield line.decode(encoding)
|
||||
|
|
@ -4,16 +4,34 @@
|
|||
@Time : 2023/4/29 16:07
|
||||
@Author : alexanderwu
|
||||
@File : common.py
|
||||
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.2 of RFC 116:
|
||||
Add generic class-to-string and object-to-string conversion functionality.
|
||||
@Modified By: mashenquan, 2023/11/27. Bug fix: `parse_recipient` failed to parse the recipient in certain GPT-3.5
|
||||
responses.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import contextlib
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
from typing import List, Tuple, Union
|
||||
import traceback
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Tuple, Union, get_args, get_origin
|
||||
|
||||
import aiofiles
|
||||
import loguru
|
||||
from pydantic.json import pydantic_encoder
|
||||
from tenacity import RetryCallState, _utils
|
||||
|
||||
from metagpt.const import MESSAGE_ROUTE_TO_ALL
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
def check_cmd_exists(command) -> int:
|
||||
|
|
@ -85,10 +103,7 @@ class OutputParser:
|
|||
|
||||
@staticmethod
|
||||
def parse_python_code(text: str) -> str:
|
||||
for pattern in (
|
||||
r"(.*?```python.*?\s+)?(?P<code>.*)(```.*?)",
|
||||
r"(.*?```python.*?\s+)?(?P<code>.*)",
|
||||
):
|
||||
for pattern in (r"(.*?```python.*?\s+)?(?P<code>.*)(```.*?)", r"(.*?```python.*?\s+)?(?P<code>.*)"):
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
if not match:
|
||||
continue
|
||||
|
|
@ -119,8 +134,32 @@ class OutputParser:
|
|||
parsed_data[block] = content
|
||||
return parsed_data
|
||||
|
||||
@staticmethod
|
||||
def extract_content(text, tag="CONTENT"):
|
||||
# Use regular expression to extract content between [CONTENT] and [/CONTENT]
|
||||
extracted_content = re.search(rf"\[{tag}\](.*?)\[/{tag}\]", text, re.DOTALL)
|
||||
|
||||
if extracted_content:
|
||||
return extracted_content.group(1).strip()
|
||||
else:
|
||||
return "No content found between [CONTENT] and [/CONTENT] tags."
|
||||
|
||||
@staticmethod
|
||||
def is_supported_list_type(i):
|
||||
origin = get_origin(i)
|
||||
if origin is not List:
|
||||
return False
|
||||
|
||||
args = get_args(i)
|
||||
if args == (str,) or args == (Tuple[str, str],) or args == (List[str],):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def parse_data_with_mapping(cls, data, mapping):
|
||||
if "[CONTENT]" in data:
|
||||
data = cls.extract_content(text=data)
|
||||
block_dict = cls.parse_blocks(data)
|
||||
parsed_data = {}
|
||||
for block, content in block_dict.items():
|
||||
|
|
@ -187,7 +226,7 @@ class OutputParser:
|
|||
result = ast.literal_eval(structure_text)
|
||||
|
||||
# Ensure the result matches the specified data type
|
||||
if isinstance(result, list) or isinstance(result, dict):
|
||||
if isinstance(result, (list, dict)):
|
||||
return result
|
||||
|
||||
raise ValueError(f"The extracted structure is not a {data_type}.")
|
||||
|
|
@ -219,10 +258,15 @@ class CodeParser:
|
|||
# 遍历所有的block
|
||||
for block in blocks:
|
||||
# 如果block不为空,则继续处理
|
||||
if block.strip() != "":
|
||||
if block.strip() == "":
|
||||
continue
|
||||
if "\n" not in block:
|
||||
block_title = block
|
||||
block_content = ""
|
||||
else:
|
||||
# 将block的标题和内容分开,并分别去掉前后的空白字符
|
||||
block_title, block_content = block.split("\n", 1)
|
||||
block_dict[block_title.strip()] = block_content.strip()
|
||||
block_dict[block_title.strip()] = block_content.strip()
|
||||
|
||||
return block_dict
|
||||
|
||||
|
|
@ -282,9 +326,6 @@ class NoMoneyException(Exception):
|
|||
def print_members(module, indent=0):
|
||||
"""
|
||||
https://stackoverflow.com/questions/1796180/how-can-i-get-a-list-of-all-classes-within-current-module-in-python
|
||||
:param module:
|
||||
:param indent:
|
||||
:return:
|
||||
"""
|
||||
prefix = " " * indent
|
||||
for name, obj in inspect.getmembers(module):
|
||||
|
|
@ -302,6 +343,173 @@ def print_members(module, indent=0):
|
|||
|
||||
|
||||
def parse_recipient(text):
    """Extract the recipient name announced by a ``Send To:`` marker.

    The markdown-heading form ``## Send To: Name`` is tried first; if it is
    absent, a plain ``Send To: Name`` is accepted as a fallback.

    :param text: Text to search.
    :return: The recipient name (ASCII letters only), or ``""`` when no marker is found.
    """
    # FIXME: use ActionNode instead.
    pattern = r"## Send To:\s*([A-Za-z]+)\s*?"  # hard code for now
    recipient = re.search(pattern, text)
    if recipient:
        return recipient.group(1)
    # No early unconditional return here: the fallback below must stay reachable.
    pattern = r"Send To:\s*([A-Za-z]+)\s*?"
    recipient = re.search(pattern, text)
    if recipient:
        return recipient.group(1)
    return ""
|
||||
|
||||
|
||||
def get_class_name(cls) -> str:
    """Build the dotted ``module.name`` identifier of ``cls``."""
    module_name = cls.__module__
    short_name = cls.__name__
    return f"{module_name}.{short_name}"
|
||||
|
||||
|
||||
def any_to_str(val: str | typing.Callable) -> str:
    """Convert ``val`` to a string identifier.

    Strings are returned unchanged. A callable is named via ``get_class_name``
    applied to the callable itself; any other object is named after its type.
    """
    if isinstance(val, str):
        return val

    # Callables name themselves; everything else is described by its type.
    target = val if callable(val) else type(val)
    return get_class_name(target)
|
||||
|
||||
|
||||
def any_to_str_set(val) -> set:
    """Convert ``val`` into a set of strings produced by ``any_to_str``.

    dict, list, set and tuple inputs are expanded element by element (a dict
    contributes its values); any other input yields a one-element set.
    """
    # Strings are iterable too, so only the explicit container types are expanded.
    if not isinstance(val, (dict, list, set, tuple)):
        return {any_to_str(val)}

    members = val.values() if isinstance(val, dict) else val
    return {any_to_str(member) for member in members}
|
||||
|
||||
|
||||
def is_subscribed(message: "Message", tags: set):
    """Tell whether a message is addressed to a holder of any of ``tags``.

    :param message: Object exposing a ``send_to`` container.
    :param tags: Identifiers the consumer listens for.
    :return: True when the message is routed to everyone
        (``MESSAGE_ROUTE_TO_ALL`` in ``send_to``) or when any tag appears in ``send_to``.
    """
    if MESSAGE_ROUTE_TO_ALL in message.send_to:
        return True
    return any(tag in message.send_to for tag in tags)
|
||||
|
||||
|
||||
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
    """Build an ``after`` callback that logs every retried call as an error.

    The produced callback reports the name of the retried function, the time
    elapsed since the first attempt (rendered with ``sec_format``), the ordinal
    of the attempt, and the exception held by the attempt's outcome.

    :param i: Logger whose ``error`` method receives the message.
    :param sec_format: ``%``-style format applied to the elapsed seconds.
    :return: A callable taking a ``RetryCallState`` and returning None.
    """

    def log_it(retry_state: "RetryCallState") -> None:
        # Use a placeholder when the retry state carries no function.
        fn_name = "<unknown>" if retry_state.fn is None else _utils.get_callback_name(retry_state.fn)

        message = (
            f"Finished call to '{fn_name}' after {sec_format % retry_state.seconds_since_start}(s), "
            f"this was the {_utils.to_ordinal(retry_state.attempt_number)} time calling it. "
            f"exp: {retry_state.outcome.exception()}"
        )
        i.error(message)

    return log_it
|
||||
|
||||
|
||||
def read_json_file(json_file: str, encoding=None) -> list[Any]:
    """Load and return the JSON document stored in ``json_file``.

    :param json_file: Path of the file to read.
    :param encoding: Text encoding passed to ``open``; ``None`` uses the platform default.
    :return: The decoded JSON value.
    :raises FileNotFoundError: If ``json_file`` does not exist.
    :raises ValueError: If the file cannot be read or decoded as JSON; the
        underlying error is attached as ``__cause__``.
    """
    if not Path(json_file).exists():
        # This function raises rather than returning an empty list, so the message says so.
        raise FileNotFoundError(f"json_file: {json_file} not exist")

    with open(json_file, "r", encoding=encoding) as fin:
        try:
            return json.load(fin)
        except (ValueError, OSError) as exc:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
            # Chain the cause so the parse position / reason is not lost.
            raise ValueError(f"read json file: {json_file} failed") from exc
|
||||
|
||||
|
||||
def write_json_file(json_file: str, data: list, encoding=None):
    """Serialize ``data`` as indented JSON into ``json_file``.

    The parent directory is created when missing. Non-ASCII characters are
    written verbatim, and objects the json module cannot encode are handed to
    ``pydantic_encoder``.

    :param json_file: Destination path.
    :param data: Value to serialize.
    :param encoding: Text encoding passed to ``open``; ``None`` uses the platform default.
    """
    parent_dir = Path(json_file).parent
    if not parent_dir.exists():
        parent_dir.mkdir(parents=True, exist_ok=True)

    dump_options = {"ensure_ascii": False, "indent": 4, "default": pydantic_encoder}
    with open(json_file, "w", encoding=encoding) as fout:
        json.dump(data, fout, **dump_options)
|
||||
|
||||
|
||||
def import_class(class_name: str, module_name: str) -> type:
    """Import ``module_name`` and return its attribute named ``class_name``."""
    return getattr(importlib.import_module(module_name), class_name)
|
||||
|
||||
|
||||
def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> object:
    """Import a class by name and instantiate it.

    :param class_name: Attribute name of the class inside the module.
    :param module_name: Dotted module path to import.
    :param args: Positional arguments forwarded to the class.
    :param kwargs: Keyword arguments forwarded to the class.
    :return: The newly created instance.
    """
    target_class = getattr(importlib.import_module(module_name), class_name)
    return target_class(*args, **kwargs)
|
||||
|
||||
|
||||
def format_trackback_info(limit: int = 2):
    """Return the traceback of the exception currently being handled, as text.

    :param limit: Forwarded to ``traceback.format_exc``.
    """
    formatted = traceback.format_exc(limit=limit)
    return formatted
|
||||
|
||||
|
||||
def serialize_decorator(func):
    """Wrap an async method so that a failure triggers ``self.serialize()``.

    On success the wrapped coroutine's result is returned. When it raises
    ``KeyboardInterrupt`` or any ``Exception``, the error is logged together
    with a short traceback, ``self.serialize()`` is called, and the wrapper
    returns None instead of re-raising.
    """

    async def wrapper(self, *args, **kwargs):
        try:
            outcome = await func(self, *args, **kwargs)
        except KeyboardInterrupt:
            logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}")
        except Exception:
            logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}")
        else:
            return outcome
        # Only reached after one of the handlers above has run.
        self.serialize()  # Team.serialize

    return wrapper
|
||||
|
||||
|
||||
def role_raise_decorator(func):
    """Wrap an async role method so that a failure rolls back the latest observed message.

    On success the wrapped coroutine's result is returned. On
    ``KeyboardInterrupt`` or any ``Exception``, ``self.latest_observed_msg``
    (when truthy) is deleted from ``self._rc.memory`` so it can be observed
    again, and a plain ``Exception`` carrying the full formatted traceback is
    raised for the caller to handle.
    """

    async def wrapper(self, *args, **kwargs):
        try:
            outcome = await func(self, *args, **kwargs)
        except KeyboardInterrupt as kbi:
            logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project")
            if self.latest_observed_msg:
                self._rc.memory.delete(self.latest_observed_msg)
            # Re-raise as a regular Exception so outer handlers capture it.
            raise Exception(format_trackback_info(limit=None))
        except Exception:
            if self.latest_observed_msg:
                logger.warning(
                    "There is a exception in role's execution, in order to resume, "
                    "we delete the newest role communication message in the role's memory."
                )
                # Drop the newest observed message so it is observed again on resume.
                self._rc.memory.delete(self.latest_observed_msg)
            # Re-raise as a regular Exception so outer handlers capture it.
            raise Exception(format_trackback_info(limit=None))
        else:
            return outcome

    return wrapper
|
||||
|
||||
|
||||
@handle_exception
async def aread(file_path: str) -> str:
    """Asynchronously read a text file and return its whole content.

    The function is wrapped by ``handle_exception`` with its defaults.

    :param file_path: Path of the file; it is converted with ``str`` before opening.
    """
    async with aiofiles.open(str(file_path), mode="r") as reader:
        return await reader.read()
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def py_make_scanner(context):
|
|||
except IndexError:
|
||||
raise StopIteration(idx) from None
|
||||
|
||||
if nextchar == '"' or nextchar == "'":
|
||||
if nextchar in ("'", '"'):
|
||||
if idx + 2 < len(string) and string[idx + 1] == nextchar and string[idx + 2] == nextchar:
|
||||
# Handle the case where the next two characters are the same as nextchar
|
||||
return parse_string(string, idx + 3, strict, delimiter=nextchar * 3) # triple quote
|
||||
|
|
|
|||
103
metagpt/utils/dependency_file.py
Normal file
103
metagpt/utils/dependency_file.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/11/22
|
||||
@Author : mashenquan
|
||||
@File : dependency_file.py
|
||||
@Desc: Implementation of the dependency file described in Section 2.2.3.2 of RFC 135.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Set
|
||||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
class DependencyFile:
    """Manage the ``.dependencies.json`` file that maps a file to the files it depends on.

    Keys and dependency entries are stored as paths relative to the directory
    that holds the dependency file whenever that is possible; paths outside of
    it are stored unchanged.

    :param workdir: The working directory path for the DependencyFile.
    """

    def __init__(self, workdir: Path | str):
        """Initialize a DependencyFile instance.

        :param workdir: The working directory path for the DependencyFile.
        """
        self._dependencies = {}
        self._filename = Path(workdir) / ".dependencies.json"

    def _to_key(self, path: Path | str) -> str:
        """Return ``path`` relative to the dependency file's directory, or unchanged if it is outside."""
        try:
            return str(Path(path).relative_to(self._filename.parent))
        except ValueError:
            return str(path)

    async def load(self):
        """Load dependencies from the file asynchronously; a missing file leaves the current state untouched."""
        if not self._filename.exists():
            return
        self._dependencies = json.loads(await aread(self._filename))

    @handle_exception
    async def save(self):
        """Save dependencies to the file asynchronously."""
        data = json.dumps(self._dependencies)
        async with aiofiles.open(str(self._filename), mode="w") as writer:
            await writer.write(data)

    async def update(self, filename: Path | str, dependencies: Set[Path | str], persist=True):
        """Update dependencies for a file asynchronously.

        :param filename: The filename or path.
        :param dependencies: The set of dependencies. An empty or ``None`` value
            removes the entry for ``filename``.
        :param persist: Whether to reload before, and save after, the change.
        """
        if persist:
            await self.load()

        key = self._to_key(filename)
        if dependencies:
            self._dependencies[key] = [self._to_key(i) for i in dependencies]
        elif key in self._dependencies:
            del self._dependencies[key]

        if persist:
            await self.save()

    async def get(self, filename: Path | str, persist=True):
        """Get dependencies for a file asynchronously.

        :param filename: The filename or path.
        :param persist: Whether to load dependencies from the file immediately.
        :return: A set of dependencies (empty when the file has no entry).
        """
        if persist:
            await self.load()

        # Use the same root as ``update`` so keys written by one are found by the other.
        key = self._to_key(filename)
        return set(self._dependencies.get(key, ()))

    def delete_file(self):
        """Delete the dependency file."""
        self._filename.unlink(missing_ok=True)

    @property
    def exists(self):
        """Check if the dependency file exists."""
        return self._filename.exists()
|
||||
59
metagpt/utils/exceptions.py
Normal file
59
metagpt/utils/exceptions.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/19 14:46
|
||||
@Author : alexanderwu
|
||||
@File : exceptions.py
|
||||
"""
|
||||
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import traceback
|
||||
from typing import Any, Callable, Tuple, Type, TypeVar, Union
|
||||
|
||||
from metagpt.logs import logger
|
||||
|
||||
ReturnType = TypeVar("ReturnType")
|
||||
|
||||
|
||||
def handle_exception(
    _func: Callable[..., ReturnType] = None,
    *,
    exception_type: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception,
    default_return: Any = None,
) -> Callable[..., ReturnType]:
    """Decorator that logs matching exceptions and returns a default value instead.

    Usable bare (``@handle_exception``) or with keyword options
    (``@handle_exception(exception_type=..., default_return=...)``). Coroutine
    functions get an async wrapper, all other callables a sync one.

    :param _func: The decorated callable when the decorator is used bare.
    :param exception_type: Exception class, or tuple of classes, to intercept.
    :param default_return: Value returned when an intercepted exception occurs.
    :return: The wrapped callable, or the decorator itself when options were given.
    """

    def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
        if asyncio.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
                try:
                    return await func(*args, **kwargs)
                except exception_type as e:
                    # depth=1 requests the frame one level above this wrapper for the log record.
                    logger.opt(depth=1).error(
                        f"Calling {func.__name__} with args: {args}, kwargs: {kwargs} failed: {e}, "
                        f"stack: {traceback.format_exc()}"
                    )
                    return default_return

            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
            try:
                return func(*args, **kwargs)
            except exception_type as e:
                # depth=1 requests the frame one level above this wrapper for the log record.
                logger.opt(depth=1).error(
                    f"Calling {func.__name__} with args: {args}, kwargs: {kwargs} failed: {e}, "
                    f"stack: {traceback.format_exc()}"
                )
                return default_return

        return sync_wrapper

    # Bare usage passes the function directly; parametrized usage needs the decorator back.
    return decorator if _func is None else decorator(_func)
|
||||
|
|
@ -6,10 +6,12 @@
|
|||
@File : file.py
|
||||
@Describe : General file operations.
|
||||
"""
|
||||
import aiofiles
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
class File:
|
||||
|
|
@ -18,6 +20,7 @@ class File:
|
|||
CHUNK_SIZE = 64 * 1024
|
||||
|
||||
@classmethod
|
||||
@handle_exception
|
||||
async def write(cls, root_path: Path, filename: str, content: bytes) -> Path:
|
||||
"""Write the file content to the local specified path.
|
||||
|
||||
|
|
@ -32,18 +35,15 @@ class File:
|
|||
Raises:
|
||||
Exception: If an unexpected error occurs during the file writing process.
|
||||
"""
|
||||
try:
|
||||
root_path.mkdir(parents=True, exist_ok=True)
|
||||
full_path = root_path / filename
|
||||
async with aiofiles.open(full_path, mode="wb") as writer:
|
||||
await writer.write(content)
|
||||
logger.debug(f"Successfully write file: {full_path}")
|
||||
return full_path
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing file: {e}")
|
||||
raise e
|
||||
root_path.mkdir(parents=True, exist_ok=True)
|
||||
full_path = root_path / filename
|
||||
async with aiofiles.open(full_path, mode="wb") as writer:
|
||||
await writer.write(content)
|
||||
logger.debug(f"Successfully write file: {full_path}")
|
||||
return full_path
|
||||
|
||||
@classmethod
|
||||
@handle_exception
|
||||
async def read(cls, file_path: Path, chunk_size: int = None) -> bytes:
|
||||
"""Partitioning read the file content from the local specified path.
|
||||
|
||||
|
|
@ -57,19 +57,14 @@ class File:
|
|||
Raises:
|
||||
Exception: If an unexpected error occurs during the file reading process.
|
||||
"""
|
||||
try:
|
||||
chunk_size = chunk_size or cls.CHUNK_SIZE
|
||||
async with aiofiles.open(file_path, mode="rb") as reader:
|
||||
chunks = list()
|
||||
while True:
|
||||
chunk = await reader.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
chunks.append(chunk)
|
||||
content = b''.join(chunks)
|
||||
logger.debug(f"Successfully read file, the path of file: {file_path}")
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading file: {e}")
|
||||
raise e
|
||||
|
||||
chunk_size = chunk_size or cls.CHUNK_SIZE
|
||||
async with aiofiles.open(file_path, mode="rb") as reader:
|
||||
chunks = list()
|
||||
while True:
|
||||
chunk = await reader.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
chunks.append(chunk)
|
||||
content = b"".join(chunks)
|
||||
logger.debug(f"Successfully read file, the path of file: {file_path}")
|
||||
return content
|
||||
|
|
|
|||
287
metagpt/utils/file_repository.py
Normal file
287
metagpt/utils/file_repository.py
Normal file
|
|
@ -0,0 +1,287 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/11/20
|
||||
@Author : mashenquan
|
||||
@File : file_repository.py
|
||||
@Desc: File repository management. RFC 135 2.2.3.2, 2.2.3.4 and 2.2.3.13.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Set
|
||||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Document
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.json_to_markdown import json_to_markdown
|
||||
|
||||
|
||||
class FileRepository:
|
||||
"""A class representing a FileRepository associated with a Git repository.
|
||||
|
||||
:param git_repo: The associated GitRepository instance.
|
||||
:param relative_path: The relative path within the Git repository.
|
||||
|
||||
Attributes:
|
||||
_relative_path (Path): The relative path within the Git repository.
|
||||
_git_repo (GitRepository): The associated GitRepository instance.
|
||||
"""
|
||||
|
||||
def __init__(self, git_repo, relative_path: Path = Path(".")):
|
||||
"""Initialize a FileRepository instance.
|
||||
|
||||
:param git_repo: The associated GitRepository instance.
|
||||
:param relative_path: The relative path within the Git repository.
|
||||
"""
|
||||
self._relative_path = relative_path
|
||||
self._git_repo = git_repo
|
||||
|
||||
# Initializing
|
||||
self.workdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def save(self, filename: Path | str, content, dependencies: List[str] = None):
|
||||
"""Save content to a file and update its dependencies.
|
||||
|
||||
:param filename: The filename or path within the repository.
|
||||
:param content: The content to be saved.
|
||||
:param dependencies: List of dependency filenames or paths.
|
||||
"""
|
||||
pathname = self.workdir / filename
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
async with aiofiles.open(str(pathname), mode="w") as writer:
|
||||
await writer.write(content)
|
||||
logger.info(f"save to: {str(pathname)}")
|
||||
|
||||
if dependencies is not None:
|
||||
dependency_file = await self._git_repo.get_dependency()
|
||||
await dependency_file.update(pathname, set(dependencies))
|
||||
logger.info(f"update dependency: {str(pathname)}:{dependencies}")
|
||||
|
||||
async def get_dependency(self, filename: Path | str) -> Set[str]:
|
||||
"""Get the dependencies of a file.
|
||||
|
||||
:param filename: The filename or path within the repository.
|
||||
:return: Set of dependency filenames or paths.
|
||||
"""
|
||||
pathname = self.workdir / filename
|
||||
dependency_file = await self._git_repo.get_dependency()
|
||||
return await dependency_file.get(pathname)
|
||||
|
||||
async def get_changed_dependency(self, filename: Path | str) -> Set[str]:
|
||||
"""Get the dependencies of a file that have changed.
|
||||
|
||||
:param filename: The filename or path within the repository.
|
||||
:return: List of changed dependency filenames or paths.
|
||||
"""
|
||||
dependencies = await self.get_dependency(filename=filename)
|
||||
changed_files = self.changed_files
|
||||
changed_dependent_files = set()
|
||||
for df in dependencies:
|
||||
if df in changed_files.keys():
|
||||
changed_dependent_files.add(df)
|
||||
return changed_dependent_files
|
||||
|
||||
async def get(self, filename: Path | str) -> Document | None:
|
||||
"""Read the content of a file.
|
||||
|
||||
:param filename: The filename or path within the repository.
|
||||
:return: The content of the file.
|
||||
"""
|
||||
doc = Document(root_path=str(self.root_path), filename=str(filename))
|
||||
path_name = self.workdir / filename
|
||||
if not path_name.exists():
|
||||
return None
|
||||
doc.content = await aread(path_name)
|
||||
return doc
|
||||
|
||||
async def get_all(self) -> List[Document]:
|
||||
"""Get the content of all files in the repository.
|
||||
|
||||
:return: List of Document instances representing files.
|
||||
"""
|
||||
docs = []
|
||||
for root, dirs, files in os.walk(str(self.workdir)):
|
||||
for file in files:
|
||||
file_path = Path(root) / file
|
||||
relative_path = file_path.relative_to(self.workdir)
|
||||
doc = await self.get(relative_path)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@property
|
||||
def workdir(self):
|
||||
"""Return the absolute path to the working directory of the FileRepository.
|
||||
|
||||
:return: The absolute path to the working directory.
|
||||
"""
|
||||
return self._git_repo.workdir / self._relative_path
|
||||
|
||||
@property
|
||||
def root_path(self):
|
||||
"""Return the relative path from git repository root"""
|
||||
return self._relative_path
|
||||
|
||||
@property
|
||||
def changed_files(self) -> Dict[str, str]:
|
||||
"""Return a dictionary of changed files and their change types.
|
||||
|
||||
:return: A dictionary where keys are file paths and values are change types.
|
||||
"""
|
||||
files = self._git_repo.changed_files
|
||||
relative_files = {}
|
||||
for p, ct in files.items():
|
||||
try:
|
||||
rf = Path(p).relative_to(self._relative_path)
|
||||
except ValueError:
|
||||
continue
|
||||
relative_files[str(rf)] = ct
|
||||
return relative_files
|
||||
|
||||
@property
|
||||
def all_files(self) -> List:
|
||||
"""Get a dictionary of all files in the repository.
|
||||
|
||||
The dictionary includes file paths relative to the current FileRepository.
|
||||
|
||||
:return: A dictionary where keys are file paths and values are file information.
|
||||
:rtype: List
|
||||
"""
|
||||
return self._git_repo.get_files(relative_path=self._relative_path)
|
||||
|
||||
def get_change_dir_files(self, dir: Path | str) -> List:
|
||||
"""Get the files in a directory that have changed.
|
||||
|
||||
:param dir: The directory path within the repository.
|
||||
:return: List of changed filenames or paths within the directory.
|
||||
"""
|
||||
changed_files = self.changed_files
|
||||
children = []
|
||||
for f in changed_files:
|
||||
try:
|
||||
Path(f).relative_to(Path(dir))
|
||||
except ValueError:
|
||||
continue
|
||||
children.append(str(f))
|
||||
return children
|
||||
|
||||
@staticmethod
|
||||
def new_filename():
|
||||
"""Generate a new filename based on the current timestamp and a UUID suffix.
|
||||
|
||||
:return: A new filename string.
|
||||
"""
|
||||
current_time = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
return current_time
|
||||
# guid_suffix = str(uuid.uuid4())[:8]
|
||||
# return f"{current_time}x{guid_suffix}"
|
||||
|
||||
async def save_doc(self, doc: Document, with_suffix: str = None, dependencies: List[str] = None):
|
||||
"""Save a Document instance as a PDF file.
|
||||
|
||||
This method converts the content of the Document instance to Markdown,
|
||||
saves it to a file with an optional specified suffix, and logs the saved file.
|
||||
|
||||
:param doc: The Document instance to be saved.
|
||||
:type doc: Document
|
||||
:param with_suffix: An optional suffix to append to the saved file's name.
|
||||
:type with_suffix: str, optional
|
||||
:param dependencies: A list of dependencies for the saved file.
|
||||
:type dependencies: List[str], optional
|
||||
"""
|
||||
m = json.loads(doc.content)
|
||||
filename = Path(doc.filename).with_suffix(with_suffix) if with_suffix is not None else Path(doc.filename)
|
||||
await self.save(filename=str(filename), content=json_to_markdown(m), dependencies=dependencies)
|
||||
logger.debug(f"File Saved: {str(filename)}")
|
||||
|
||||
@staticmethod
|
||||
async def get_file(filename: Path | str, relative_path: Path | str = ".") -> Document | None:
|
||||
"""Retrieve a specific file from the file repository.
|
||||
|
||||
:param filename: The name or path of the file to retrieve.
|
||||
:type filename: Path or str
|
||||
:param relative_path: The relative path within the file repository.
|
||||
:type relative_path: Path or str, optional
|
||||
:return: The document representing the file, or None if not found.
|
||||
:rtype: Document or None
|
||||
"""
|
||||
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
|
||||
return await file_repo.get(filename=filename)
|
||||
|
||||
@staticmethod
|
||||
async def get_all_files(relative_path: Path | str = ".") -> List[Document]:
|
||||
"""Retrieve all files from the file repository.
|
||||
|
||||
:param relative_path: The relative path within the file repository.
|
||||
:type relative_path: Path or str, optional
|
||||
:return: A list of documents representing all files in the repository.
|
||||
:rtype: List[Document]
|
||||
"""
|
||||
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
|
||||
return await file_repo.get_all()
|
||||
|
||||
@staticmethod
|
||||
async def save_file(filename: Path | str, content, dependencies: List[str] = None, relative_path: Path | str = "."):
|
||||
"""Save a file to the file repository.
|
||||
|
||||
:param filename: The name or path of the file to save.
|
||||
:type filename: Path or str
|
||||
:param content: The content of the file.
|
||||
:param dependencies: A list of dependencies for the file.
|
||||
:type dependencies: List[str], optional
|
||||
:param relative_path: The relative path within the file repository.
|
||||
:type relative_path: Path or str, optional
|
||||
"""
|
||||
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
|
||||
return await file_repo.save(filename=filename, content=content, dependencies=dependencies)
|
||||
|
||||
@staticmethod
|
||||
async def save_as(
|
||||
doc: Document, with_suffix: str = None, dependencies: List[str] = None, relative_path: Path | str = "."
|
||||
):
|
||||
"""Save a Document instance with optional modifications.
|
||||
|
||||
This static method creates a new FileRepository, saves the Document instance
|
||||
with optional modifications (such as a suffix), and logs the saved file.
|
||||
|
||||
:param doc: The Document instance to be saved.
|
||||
:type doc: Document
|
||||
:param with_suffix: An optional suffix to append to the saved file's name.
|
||||
:type with_suffix: str, optional
|
||||
:param dependencies: A list of dependencies for the saved file.
|
||||
:type dependencies: List[str], optional
|
||||
:param relative_path: The relative path within the file repository.
|
||||
:type relative_path: Path or str, optional
|
||||
:return: A boolean indicating whether the save operation was successful.
|
||||
:rtype: bool
|
||||
"""
|
||||
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
|
||||
return await file_repo.save_doc(doc=doc, with_suffix=with_suffix, dependencies=dependencies)
|
||||
|
||||
async def delete(self, filename: Path | str):
    """Remove a file from the working directory and clear its dependency record.

    Nothing happens when the file does not exist under ``self.workdir``.
    Otherwise the file is unlinked and the dependency file obtained from the
    git repository is updated with ``dependencies=None`` for that path.

    :param filename: The name or path of the file to be deleted, relative to the work directory.
    :type filename: Path or str
    """
    target = self.workdir / filename
    if target.exists():
        target.unlink(missing_ok=True)

        deps = await self._git_repo.get_dependency()
        await deps.update(filename=target, dependencies=None)
        logger.info(f"remove dependency key: {str(target)}")
|
||||
|
||||
@staticmethod
async def delete_file(filename: Path | str, relative_path: Path | str = "."):
    """Delete a file through a file repository obtained from CONFIG.

    :param filename: The name or path of the file to delete.
    :type filename: Path or str
    :param relative_path: Location of the file repository inside the git repository.
    :type relative_path: Path or str, optional
    """
    repository = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
    await repository.delete(filename=filename)
|
||||
|
|
@ -8,10 +8,10 @@
|
|||
from metagpt.config import CONFIG
|
||||
|
||||
|
||||
def get_template(templates, format=CONFIG.prompt_format):
|
||||
selected_templates = templates.get(format)
|
||||
def get_template(templates, schema=CONFIG.prompt_schema):
|
||||
selected_templates = templates.get(schema)
|
||||
if selected_templates is None:
|
||||
raise ValueError(f"Can't find {format} in passed in templates")
|
||||
raise ValueError(f"Can't find {schema} in passed in templates")
|
||||
|
||||
# Extract the selected templates
|
||||
prompt_template = selected_templates["PROMPT_TEMPLATE"]
|
||||
|
|
|
|||
290
metagpt/utils/git_repository.py
Normal file
290
metagpt/utils/git_repository.py
Normal file
|
|
@ -0,0 +1,290 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/11/20
|
||||
@Author : mashenquan
|
||||
@File : git_repository.py
|
||||
@Desc: Git repository management. RFC 135 2.2.3.3.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from git.repo import Repo
|
||||
from git.repo.fun import is_git_dir
|
||||
from gitignore_parser import parse_gitignore
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.dependency_file import DependencyFile
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
||||
class ChangeType(Enum):
    """Single-letter change codes used to classify files in a repository."""

    ADDED = "A"  # File was added
    COPIED = "C"  # File was copied
    DELETED = "D"  # File was deleted
    RENAMED = "R"  # File was renamed
    MODIFIED = "M"  # File was modified
    TYPE_CHANGED = "T"  # Type of the file was changed
    UNTRACTED = "U"  # File is untracked (not added to version control)
    # Correctly spelled alias of the historical (misspelled) UNTRACTED member.
    # Because it shares the value "U" it is an enum alias: lookups by value and
    # iteration still yield UNTRACTED, so existing callers are unaffected.
    UNTRACKED = "U"
|
||||
|
||||
|
||||
class GitRepository:
    """A class representing a Git repository.

    :param local_path: The local path to the Git repository.
    :param auto_init: If True, automatically initializes a new Git repository if the provided path is not a Git repository.

    Attributes:
        _repository (Repo): The GitPython `Repo` object representing the Git repository.
        _dependency: Lazily created `DependencyFile` (see `get_dependency`).
        _gitignore_rules: Callable taking a path string and returning True when the path is ignored.
    """

    def __init__(self, local_path=None, auto_init=True):
        """Initialize a GitRepository instance.

        :param local_path: The local path to the Git repository.
        :param auto_init: If True, automatically initializes a new Git repository if the provided path is not a Git repository.
        """
        self._repository = None
        self._dependency = None
        self._gitignore_rules = None
        if local_path:
            self.open(local_path=local_path, auto_init=auto_init)

    @staticmethod
    def _load_gitignore_rules(root: Path | str):
        """Build the ignore-matcher for the `.gitignore` directly under `root`.

        :param root: Directory expected to contain the `.gitignore` file.
        :return: A callable taking a path string and returning True if it is ignored.
            When no `.gitignore` file exists the callable ignores nothing.
        """
        gitignore = Path(root) / ".gitignore"
        if not gitignore.is_file():
            # An existing repository is not required to ship a .gitignore;
            # without this fallback parse_gitignore would fail opening the file.
            return lambda _pathname: False
        return parse_gitignore(full_path=str(gitignore))

    def open(self, local_path: Path, auto_init=False):
        """Open an existing Git repository or initialize a new one if auto_init is True.

        :param local_path: The local path to the Git repository.
        :param auto_init: If True, automatically initializes a new Git repository if the provided path is not a Git repository.
        """
        local_path = Path(local_path)
        if self.is_git_dir(local_path):
            self._repository = Repo(local_path)
            self._gitignore_rules = self._load_gitignore_rules(local_path)
            return
        if not auto_init:
            return
        local_path.mkdir(parents=True, exist_ok=True)
        return self._init(local_path)

    def _init(self, local_path: Path):
        """Initialize a new Git repository at the specified path.

        A default `.gitignore` is written and committed as the first commit.

        :param local_path: The local path where the new Git repository will be initialized.
        """
        local_path = Path(local_path)
        self._repository = Repo.init(path=local_path)

        gitignore_filename = local_path / ".gitignore"
        ignores = ["__pycache__", "*.pyc"]
        with open(str(gitignore_filename), mode="w", encoding="utf-8") as writer:
            writer.write("\n".join(ignores))
        self._repository.index.add([".gitignore"])
        self._repository.index.commit("Add .gitignore")
        self._gitignore_rules = self._load_gitignore_rules(local_path)

    def add_change(self, files: Dict):
        """Add or remove files from the staging area based on the provided changes.

        :param files: A dictionary where keys are file paths and values are instances of ChangeType.
        """
        if not self.is_valid or not files:
            return

        for pathname, change in files.items():
            # Both index calls receive a list, matching the sequence-of-items
            # form documented by GitPython.
            if change is ChangeType.DELETED:
                self._repository.index.remove([pathname])
            else:
                self._repository.index.add([pathname])

    def commit(self, comments):
        """Commit the staged changes with the given comments.

        :param comments: Comments for the commit.
        """
        if self.is_valid:
            self._repository.index.commit(comments)

    def delete_repository(self):
        """Delete the entire repository directory."""
        if self.is_valid:
            shutil.rmtree(self._repository.working_dir)

    @property
    def changed_files(self) -> Dict[str, ChangeType]:
        """Return a dictionary of changed files and their change types.

        Untracked files are reported as `ChangeType.UNTRACTED`; entries from the
        index-vs-working-tree diff override them when the same path appears in both.

        :return: A dictionary where keys are file paths and values are change types.
            Empty when the repository is not valid.
        """
        if not self.is_valid:
            return {}
        files = {i: ChangeType.UNTRACTED for i in self._repository.untracked_files}
        changed_files = {f.a_path: ChangeType(f.change_type) for f in self._repository.index.diff(None)}
        files.update(changed_files)
        return files

    @staticmethod
    def is_git_dir(local_path):
        """Check if the specified directory is a Git repository.

        :param local_path: The local path to check.
        :return: True if the directory is a Git repository, False otherwise.
        """
        git_dir = Path(local_path) / ".git"
        # The bare name below resolves to the module-level helper imported from
        # git.repo.fun, not to this static method.
        return bool(git_dir.exists() and is_git_dir(git_dir))

    @property
    def is_valid(self):
        """Check if the Git repository is valid (exists and is initialized).

        :return: True if the repository is valid, False otherwise.
        """
        return bool(self._repository)

    @property
    def status(self) -> str:
        """Return the Git repository's status as a string."""
        if not self.is_valid:
            return ""
        return self._repository.git.status()

    @property
    def workdir(self) -> Path | None:
        """Return the path to the working directory of the Git repository.

        :return: The path to the working directory or None if the repository is not valid.
        """
        if not self.is_valid:
            return None
        return Path(self._repository.working_dir)

    def archive(self, comments="Archive"):
        """Archive the current state of the Git repository.

        Stages every changed/untracked file and commits them.

        :param comments: Comments for the archive commit.
        """
        # Query the repository once so the logged list and the staged set agree.
        changes = self.changed_files
        logger.info(f"Archive: {list(changes.keys())}")
        self.add_change(changes)
        self.commit(comments)

    def new_file_repository(self, relative_path: Path | str = ".") -> FileRepository:
        """Create a new instance of FileRepository associated with this Git repository.

        :param relative_path: The relative path to the file repository within the Git repository.
            A path located under the working directory is converted to a relative one.
        :return: A new instance of FileRepository.
        """
        path = Path(relative_path)
        try:
            path = path.relative_to(self.workdir)
        except ValueError:
            # Not located under the working directory: use the path as given.
            path = relative_path
        return FileRepository(git_repo=self, relative_path=Path(path))

    async def get_dependency(self) -> DependencyFile:
        """Get the dependency file associated with the Git repository.

        The instance is created on first use and cached afterwards.

        :return: An instance of DependencyFile.
        """
        if not self._dependency:
            self._dependency = DependencyFile(workdir=self.workdir)
        return self._dependency

    def rename_root(self, new_dir_name):
        """Rename the root directory of the Git repository.

        An existing directory with the target name is deleted first. If the
        move does not produce a Git repository at the new location, the
        instance keeps pointing at the original directory.

        :param new_dir_name: The new name for the root directory.
        """
        if not self.is_valid:
            return
        old_path = self.workdir
        if old_path.name == new_dir_name:
            return
        new_path = old_path.parent / new_dir_name
        if new_path.exists():
            logger.info(f"Delete directory {str(new_path)}")
            shutil.rmtree(new_path)
        try:
            shutil.move(src=str(old_path), dst=str(new_path))
        except Exception as e:
            logger.warning(f"Move {str(old_path)} to {str(new_path)} error: {e}")
        if not self.is_git_dir(new_path):
            # The move failed outright; re-opening new_path would raise, so
            # leave the repository bound to its current location.
            logger.warning(f"Failed to move {str(old_path)} to {str(new_path)}")
            return
        logger.info(f"Rename directory {str(old_path)} to {str(new_path)}")
        self._repository = Repo(new_path)
        self._gitignore_rules = self._load_gitignore_rules(new_path)
        # NOTE(review): a DependencyFile already cached in self._dependency was
        # created with the old workdir; confirm whether it must be refreshed here.

    def get_files(self, relative_path: Path | str, root_relative_path: Path | str = None, filter_ignored=True) -> List:
        """
        Retrieve a list of files in the specified relative path.

        Sub-directories are walked recursively. The returned paths are relative
        to `root_relative_path`.

        :param relative_path: The relative path within the repository.
        :type relative_path: Path or str
        :param root_relative_path: The root relative path within the repository.
            Defaults to the working directory joined with `relative_path`.
        :type root_relative_path: Path or str
        :param filter_ignored: Flag to indicate whether to filter files based on .gitignore rules.
        :type filter_ignored: bool
        :return: A list of file paths in the specified directory.
        :rtype: List[str]
        """
        if not self.is_valid:
            return []
        try:
            relative_path = Path(relative_path).relative_to(self.workdir)
        except ValueError:
            relative_path = Path(relative_path)

        if not root_relative_path:
            root_relative_path = Path(self.workdir) / relative_path
        root_relative_path = Path(root_relative_path)
        files = []
        try:
            directory_path = Path(self.workdir) / relative_path
            if not directory_path.exists():
                return []
            for file_path in directory_path.iterdir():
                if file_path.is_file():
                    rpath = file_path.relative_to(root_relative_path)
                    files.append(str(rpath))
                else:
                    # Filtering is applied once, on the combined result, so the
                    # recursive calls return unfiltered listings.
                    subfolder_files = self.get_files(
                        relative_path=file_path, root_relative_path=root_relative_path, filter_ignored=False
                    )
                    files.extend(subfolder_files)
        except Exception as e:
            # Best-effort listing: report the problem and return what was collected.
            logger.error(f"Error: {e}")
        if not filter_ignored:
            return files
        filtered_files = self.filter_gitignore(filenames=files, root_relative_path=root_relative_path)
        return filtered_files

    def filter_gitignore(self, filenames: List[str], root_relative_path: Path | str = None) -> List[str]:
        """
        Filter a list of filenames based on .gitignore rules.

        :param filenames: A list of filenames to be filtered.
        :type filenames: List[str]
        :param root_relative_path: The root relative path within the repository.
            Defaults to the working directory.
        :type root_relative_path: Path or str
        :return: A list of filenames that pass the .gitignore filtering.
        :rtype: List[str]
        """
        if self._gitignore_rules is None:
            # No repository opened yet, hence no rules: nothing is filtered out.
            return list(filenames)
        if root_relative_path is None:
            root_relative_path = self.workdir
        # Accept plain strings as well; `str / str` is not a valid path join.
        root = Path(root_relative_path)
        return [filename for filename in filenames if not self._gitignore_rules(str(root / filename))]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: create (or open) a repository under the default
    # workspace, stage and commit whatever changed, then remove it again.
    workspace = DEFAULT_WORKSPACE_ROOT / "git"
    workspace.mkdir(exist_ok=True, parents=True)

    git_repo = GitRepository()
    git_repo.open(workspace, auto_init=True)
    git_repo.filter_gitignore(filenames=["snake_game/snake_game/__pycache__", "snake_game/snake_game/game.py"])

    pending = git_repo.changed_files
    print(pending)
    git_repo.add_change(pending)
    print(git_repo.status)
    git_repo.commit("test")
    print(git_repo.status)
    git_repo.delete_repository()
|
||||
|
|
@ -1,22 +1,22 @@
|
|||
# 添加代码语法高亮显示
|
||||
from pygments import highlight as highlight_
|
||||
from pygments.formatters import HtmlFormatter, TerminalFormatter
|
||||
from pygments.lexers import PythonLexer, SqlLexer
|
||||
from pygments.formatters import TerminalFormatter, HtmlFormatter
|
||||
|
||||
|
||||
def highlight(code: str, language: str = 'python', formatter: str = 'terminal'):
|
||||
def highlight(code: str, language: str = "python", formatter: str = "terminal"):
|
||||
# 指定要高亮的语言
|
||||
if language.lower() == 'python':
|
||||
if language.lower() == "python":
|
||||
lexer = PythonLexer()
|
||||
elif language.lower() == 'sql':
|
||||
elif language.lower() == "sql":
|
||||
lexer = SqlLexer()
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {language}")
|
||||
|
||||
# 指定输出格式
|
||||
if formatter.lower() == 'terminal':
|
||||
if formatter.lower() == "terminal":
|
||||
formatter = TerminalFormatter()
|
||||
elif formatter.lower() == 'html':
|
||||
elif formatter.lower() == "html":
|
||||
formatter = HtmlFormatter()
|
||||
else:
|
||||
raise ValueError(f"Unsupported formatter: {formatter}")
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import os
|
|||
from pathlib import Path
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import PROJECT_ROOT
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import check_cmd_exists
|
||||
|
||||
|
|
@ -69,7 +69,7 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
if stdout:
|
||||
logger.info(stdout.decode())
|
||||
if stderr:
|
||||
logger.error(stderr.decode())
|
||||
logger.warning(stderr.decode())
|
||||
else:
|
||||
if engine == "playwright":
|
||||
from metagpt.utils.mmdc_playwright import mermaid_to_file
|
||||
|
|
@ -141,6 +141,6 @@ MMC2 = """sequenceDiagram
|
|||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.new_event_loop()
|
||||
result = loop.run_until_complete(mermaid_to_file(MMC1, PROJECT_ROOT / f"{CONFIG.mermaid_engine}/1"))
|
||||
result = loop.run_until_complete(mermaid_to_file(MMC2, PROJECT_ROOT / f"{CONFIG.mermaid_engine}/1"))
|
||||
result = loop.run_until_complete(mermaid_to_file(MMC1, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/1"))
|
||||
result = loop.run_until_complete(mermaid_to_file(MMC2, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/1"))
|
||||
loop.close()
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@
|
|||
@File : mermaid.py
|
||||
"""
|
||||
import base64
|
||||
import os
|
||||
|
||||
from aiohttp import ClientSession,ClientError
|
||||
from aiohttp import ClientError, ClientSession
|
||||
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
|
|
@ -29,7 +29,7 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix):
|
|||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
text = await response.content.read()
|
||||
with open(output_file, 'wb') as f:
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(text)
|
||||
logger.info(f"Generating {output_file}..")
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -8,10 +8,13 @@
|
|||
|
||||
import os
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
from metagpt.logs import logger
|
||||
|
||||
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048)-> int:
|
||||
|
||||
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
|
||||
"""
|
||||
Converts the given Mermaid code to various output formats and saves them to files.
|
||||
|
||||
|
|
@ -24,66 +27,72 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
Returns:
|
||||
int: Returns 0 if the conversion and saving were successful, -1 otherwise.
|
||||
"""
|
||||
suffixes=['png', 'svg', 'pdf']
|
||||
suffixes = ["png", "svg", "pdf"]
|
||||
__dirname = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
async with async_playwright() as p:
|
||||
browser = await p.chromium.launch()
|
||||
device_scale_factor = 1.0
|
||||
context = await browser.new_context(
|
||||
viewport={'width': width, 'height': height},
|
||||
device_scale_factor=device_scale_factor,
|
||||
)
|
||||
viewport={"width": width, "height": height},
|
||||
device_scale_factor=device_scale_factor,
|
||||
)
|
||||
page = await context.new_page()
|
||||
|
||||
async def console_message(msg):
|
||||
logger.info(msg.text)
|
||||
page.on('console', console_message)
|
||||
|
||||
page.on("console", console_message)
|
||||
|
||||
try:
|
||||
await page.set_viewport_size({'width': width, 'height': height})
|
||||
await page.set_viewport_size({"width": width, "height": height})
|
||||
|
||||
mermaid_html_path = os.path.abspath(
|
||||
os.path.join(__dirname, 'index.html'))
|
||||
mermaid_html_url = urljoin('file:', mermaid_html_path)
|
||||
mermaid_html_path = os.path.abspath(os.path.join(__dirname, "index.html"))
|
||||
mermaid_html_url = urljoin("file:", mermaid_html_path)
|
||||
await page.goto(mermaid_html_url)
|
||||
await page.wait_for_load_state("networkidle")
|
||||
|
||||
await page.wait_for_selector("div#container", state="attached")
|
||||
mermaid_config = {}
|
||||
# mermaid_config = {}
|
||||
background_color = "#ffffff"
|
||||
my_css = ""
|
||||
# my_css = ""
|
||||
await page.evaluate(f'document.body.style.background = "{background_color}";')
|
||||
|
||||
metadata = await page.evaluate('''async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
|
||||
const { mermaid, zenuml } = globalThis;
|
||||
await mermaid.registerExternalDiagrams([zenuml]);
|
||||
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
|
||||
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
|
||||
document.getElementById('container').innerHTML = svg;
|
||||
const svgElement = document.querySelector('svg');
|
||||
svgElement.style.backgroundColor = backgroundColor;
|
||||
# metadata = await page.evaluate(
|
||||
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
|
||||
# const { mermaid, zenuml } = globalThis;
|
||||
# await mermaid.registerExternalDiagrams([zenuml]);
|
||||
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
|
||||
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
|
||||
# document.getElementById('container').innerHTML = svg;
|
||||
# const svgElement = document.querySelector('svg');
|
||||
# svgElement.style.backgroundColor = backgroundColor;
|
||||
#
|
||||
# if (myCSS) {
|
||||
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
|
||||
# style.appendChild(document.createTextNode(myCSS));
|
||||
# svgElement.appendChild(style);
|
||||
# }
|
||||
#
|
||||
# }""",
|
||||
# [mermaid_code, mermaid_config, my_css, background_color],
|
||||
# )
|
||||
|
||||
if (myCSS) {
|
||||
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
|
||||
style.appendChild(document.createTextNode(myCSS));
|
||||
svgElement.appendChild(style);
|
||||
}
|
||||
|
||||
}''', [mermaid_code, mermaid_config, my_css, background_color])
|
||||
|
||||
if 'svg' in suffixes :
|
||||
svg_xml = await page.evaluate('''() => {
|
||||
if "svg" in suffixes:
|
||||
svg_xml = await page.evaluate(
|
||||
"""() => {
|
||||
const svg = document.querySelector('svg');
|
||||
const xmlSerializer = new XMLSerializer();
|
||||
return xmlSerializer.serializeToString(svg);
|
||||
}''')
|
||||
}"""
|
||||
)
|
||||
logger.info(f"Generating {output_file_without_suffix}.svg..")
|
||||
with open(f'{output_file_without_suffix}.svg', 'wb') as f:
|
||||
f.write(svg_xml.encode('utf-8'))
|
||||
with open(f"{output_file_without_suffix}.svg", "wb") as f:
|
||||
f.write(svg_xml.encode("utf-8"))
|
||||
|
||||
if 'png' in suffixes:
|
||||
clip = await page.evaluate('''() => {
|
||||
if "png" in suffixes:
|
||||
clip = await page.evaluate(
|
||||
"""() => {
|
||||
const svg = document.querySelector('svg');
|
||||
const rect = svg.getBoundingClientRect();
|
||||
return {
|
||||
|
|
@ -92,16 +101,17 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
width: Math.ceil(rect.width),
|
||||
height: Math.ceil(rect.height)
|
||||
};
|
||||
}''')
|
||||
await page.set_viewport_size({'width': clip['x'] + clip['width'], 'height': clip['y'] + clip['height']})
|
||||
screenshot = await page.screenshot(clip=clip, omit_background=True, scale='device')
|
||||
}"""
|
||||
)
|
||||
await page.set_viewport_size({"width": clip["x"] + clip["width"], "height": clip["y"] + clip["height"]})
|
||||
screenshot = await page.screenshot(clip=clip, omit_background=True, scale="device")
|
||||
logger.info(f"Generating {output_file_without_suffix}.png..")
|
||||
with open(f'{output_file_without_suffix}.png', 'wb') as f:
|
||||
with open(f"{output_file_without_suffix}.png", "wb") as f:
|
||||
f.write(screenshot)
|
||||
if 'pdf' in suffixes:
|
||||
if "pdf" in suffixes:
|
||||
pdf_data = await page.pdf(scale=device_scale_factor)
|
||||
logger.info(f"Generating {output_file_without_suffix}.pdf..")
|
||||
with open(f'{output_file_without_suffix}.pdf', 'wb') as f:
|
||||
with open(f"{output_file_without_suffix}.pdf", "wb") as f:
|
||||
f.write(pdf_data)
|
||||
return 0
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -7,11 +7,14 @@
|
|||
"""
|
||||
import os
|
||||
from urllib.parse import urljoin
|
||||
from pyppeteer import launch
|
||||
from metagpt.logs import logger
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048)-> int:
|
||||
from pyppeteer import launch
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
|
||||
"""
|
||||
Converts the given Mermaid code to various output formats and saves them to files.
|
||||
|
||||
|
|
@ -24,15 +27,15 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
Returns:
|
||||
int: Returns 0 if the conversion and saving were successful, -1 otherwise.
|
||||
"""
|
||||
suffixes = ['png', 'svg', 'pdf']
|
||||
suffixes = ["png", "svg", "pdf"]
|
||||
__dirname = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
if CONFIG.pyppeteer_executable_path:
|
||||
browser = await launch(headless=True,
|
||||
executablePath=CONFIG.pyppeteer_executable_path,
|
||||
args=['--disable-extensions',"--no-sandbox"]
|
||||
)
|
||||
browser = await launch(
|
||||
headless=True,
|
||||
executablePath=CONFIG.pyppeteer_executable_path,
|
||||
args=["--disable-extensions", "--no-sandbox"],
|
||||
)
|
||||
else:
|
||||
logger.error("Please set the environment variable:PYPPETEER_EXECUTABLE_PATH.")
|
||||
return -1
|
||||
|
|
@ -41,50 +44,56 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
|
||||
async def console_message(msg):
|
||||
logger.info(msg.text)
|
||||
page.on('console', console_message)
|
||||
|
||||
page.on("console", console_message)
|
||||
|
||||
try:
|
||||
await page.setViewport(viewport={'width': width, 'height': height, 'deviceScaleFactor': device_scale_factor})
|
||||
await page.setViewport(viewport={"width": width, "height": height, "deviceScaleFactor": device_scale_factor})
|
||||
|
||||
mermaid_html_path = os.path.abspath(
|
||||
os.path.join(__dirname, 'index.html'))
|
||||
mermaid_html_url = urljoin('file:', mermaid_html_path)
|
||||
mermaid_html_path = os.path.abspath(os.path.join(__dirname, "index.html"))
|
||||
mermaid_html_url = urljoin("file:", mermaid_html_path)
|
||||
await page.goto(mermaid_html_url)
|
||||
|
||||
await page.querySelector("div#container")
|
||||
mermaid_config = {}
|
||||
# mermaid_config = {}
|
||||
background_color = "#ffffff"
|
||||
my_css = ""
|
||||
# my_css = ""
|
||||
await page.evaluate(f'document.body.style.background = "{background_color}";')
|
||||
|
||||
metadata = await page.evaluate('''async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
|
||||
const { mermaid, zenuml } = globalThis;
|
||||
await mermaid.registerExternalDiagrams([zenuml]);
|
||||
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
|
||||
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
|
||||
document.getElementById('container').innerHTML = svg;
|
||||
const svgElement = document.querySelector('svg');
|
||||
svgElement.style.backgroundColor = backgroundColor;
|
||||
# metadata = await page.evaluate(
|
||||
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
|
||||
# const { mermaid, zenuml } = globalThis;
|
||||
# await mermaid.registerExternalDiagrams([zenuml]);
|
||||
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
|
||||
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
|
||||
# document.getElementById('container').innerHTML = svg;
|
||||
# const svgElement = document.querySelector('svg');
|
||||
# svgElement.style.backgroundColor = backgroundColor;
|
||||
#
|
||||
# if (myCSS) {
|
||||
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
|
||||
# style.appendChild(document.createTextNode(myCSS));
|
||||
# svgElement.appendChild(style);
|
||||
# }
|
||||
# }""",
|
||||
# [mermaid_code, mermaid_config, my_css, background_color],
|
||||
# )
|
||||
|
||||
if (myCSS) {
|
||||
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
|
||||
style.appendChild(document.createTextNode(myCSS));
|
||||
svgElement.appendChild(style);
|
||||
}
|
||||
}''', [mermaid_code, mermaid_config, my_css, background_color])
|
||||
|
||||
if 'svg' in suffixes :
|
||||
svg_xml = await page.evaluate('''() => {
|
||||
if "svg" in suffixes:
|
||||
svg_xml = await page.evaluate(
|
||||
"""() => {
|
||||
const svg = document.querySelector('svg');
|
||||
const xmlSerializer = new XMLSerializer();
|
||||
return xmlSerializer.serializeToString(svg);
|
||||
}''')
|
||||
}"""
|
||||
)
|
||||
logger.info(f"Generating {output_file_without_suffix}.svg..")
|
||||
with open(f'{output_file_without_suffix}.svg', 'wb') as f:
|
||||
f.write(svg_xml.encode('utf-8'))
|
||||
with open(f"{output_file_without_suffix}.svg", "wb") as f:
|
||||
f.write(svg_xml.encode("utf-8"))
|
||||
|
||||
if 'png' in suffixes:
|
||||
clip = await page.evaluate('''() => {
|
||||
if "png" in suffixes:
|
||||
clip = await page.evaluate(
|
||||
"""() => {
|
||||
const svg = document.querySelector('svg');
|
||||
const rect = svg.getBoundingClientRect();
|
||||
return {
|
||||
|
|
@ -93,16 +102,23 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
width: Math.ceil(rect.width),
|
||||
height: Math.ceil(rect.height)
|
||||
};
|
||||
}''')
|
||||
await page.setViewport({'width': clip['x'] + clip['width'], 'height': clip['y'] + clip['height'], 'deviceScaleFactor': device_scale_factor})
|
||||
screenshot = await page.screenshot(clip=clip, omit_background=True, scale='device')
|
||||
}"""
|
||||
)
|
||||
await page.setViewport(
|
||||
{
|
||||
"width": clip["x"] + clip["width"],
|
||||
"height": clip["y"] + clip["height"],
|
||||
"deviceScaleFactor": device_scale_factor,
|
||||
}
|
||||
)
|
||||
screenshot = await page.screenshot(clip=clip, omit_background=True, scale="device")
|
||||
logger.info(f"Generating {output_file_without_suffix}.png..")
|
||||
with open(f'{output_file_without_suffix}.png', 'wb') as f:
|
||||
with open(f"{output_file_without_suffix}.png", "wb") as f:
|
||||
f.write(screenshot)
|
||||
if 'pdf' in suffixes:
|
||||
if "pdf" in suffixes:
|
||||
pdf_data = await page.pdf(scale=device_scale_factor)
|
||||
logger.info(f"Generating {output_file_without_suffix}.pdf..")
|
||||
with open(f'{output_file_without_suffix}.pdf', 'wb') as f:
|
||||
with open(f"{output_file_without_suffix}.pdf", "wb") as f:
|
||||
f.write(pdf_data)
|
||||
return 0
|
||||
except Exception as e:
|
||||
|
|
@ -110,4 +126,3 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
return -1
|
||||
finally:
|
||||
await browser.close()
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class WebPage(BaseModel):
|
|||
class Config:
|
||||
underscore_attrs_are_private = True
|
||||
|
||||
_soup : Optional[BeautifulSoup] = None
|
||||
_soup: Optional[BeautifulSoup] = None
|
||||
_title: Optional[str] = None
|
||||
|
||||
@property
|
||||
|
|
@ -24,7 +24,7 @@ class WebPage(BaseModel):
|
|||
if self._soup is None:
|
||||
self._soup = BeautifulSoup(self.html, "html.parser")
|
||||
return self._soup
|
||||
|
||||
|
||||
@property
|
||||
def title(self):
|
||||
if self._title is None:
|
||||
|
|
|
|||
|
|
@ -37,12 +37,12 @@ def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine:
|
|||
|
||||
if not isinstance(expr, cst.Expr):
|
||||
return None
|
||||
|
||||
|
||||
val = expr.value
|
||||
if not isinstance(val, (cst.SimpleString, cst.ConcatenatedString)):
|
||||
return None
|
||||
|
||||
evaluated_value = val.evaluated_value
|
||||
|
||||
evaluated_value = val.evaluated_value
|
||||
if isinstance(evaluated_value, bytes):
|
||||
return None
|
||||
|
||||
|
|
@ -56,6 +56,7 @@ class DocstringCollector(cst.CSTVisitor):
|
|||
stack: A list to keep track of the current path in the CST.
|
||||
docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.stack: list[str] = []
|
||||
self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}
|
||||
|
|
@ -96,6 +97,7 @@ class DocstringTransformer(cst.CSTTransformer):
|
|||
stack: A list to keep track of the current path in the CST.
|
||||
docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
|
||||
|
|
@ -125,7 +127,9 @@ class DocstringTransformer(cst.CSTTransformer):
|
|||
key = tuple(self.stack)
|
||||
self.stack.pop()
|
||||
|
||||
if hasattr(updated_node, "decorators") and any((i.decorator.value == "overload") for i in updated_node.decorators):
|
||||
if hasattr(updated_node, "decorators") and any(
|
||||
(i.decorator.value == "overload") for i in updated_node.decorators
|
||||
):
|
||||
return updated_node
|
||||
|
||||
statement = self.docstrings.get(key)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
import docx
|
||||
|
||||
|
||||
def read_docx(file_path: str) -> list:
|
||||
"""Open a docx file"""
|
||||
doc = docx.Document(file_path)
|
||||
|
|
|
|||
310
metagpt/utils/repair_llm_raw_output.py
Normal file
310
metagpt/utils/repair_llm_raw_output.py
Normal file
|
|
@ -0,0 +1,310 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : repair llm raw output with particular conditions
|
||||
|
||||
import copy
|
||||
from enum import Enum
|
||||
from typing import Callable, Union
|
||||
|
||||
import regex as re
|
||||
from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.custom_decoder import CustomDecoder
|
||||
|
||||
|
||||
class RepairType(Enum):
    """Kinds of repair that can be applied to raw LLM output.

    Iteration order matters: when no explicit type is requested, the repairs are applied
    in declaration order (with `JSON` excluded).
    """

    CS = "case sensitivity"  # a required key is present but with different letter casing
    RKPM = "required key pair missing"  # condition like `[key] xx` which lacks `[/key]`
    SCM = "special character missing"  # Usually the req_key appear in pairs like `[key] xx [/key]`
    JSON = "json format"  # stray `[` / `]` around the JSON text
|
||||
|
||||
|
||||
def repair_case_sensitivity(output: str, req_key: str) -> str:
    """Restore the expected letter casing of `req_key` inside `output`.

    `req_key` is normally a key name of the expected json/markdown content and is not
    expected to show up in a value. Example: the target is `"Shared Knowledge": ""`
    but the model produced `"Shared knowledge": ""`.

    Args:
        output: raw text produced by the LLM.
        req_key: the key with the casing that is required.

    Returns:
        `output` unchanged when `req_key` is already present or cannot be found even
        case-insensitively; otherwise `output` with every occurrence of the first
        wrongly-cased variant replaced by `req_key`.
    """
    if req_key in output:
        return output

    lowered_output = output.lower()
    lowered_key = req_key.lower()
    start = lowered_output.find(lowered_key)
    if start >= 0:
        # take the variant exactly as the model wrote it, then swap it for the required spelling
        wrong_cased = output[start : start + len(lowered_key)]
        output = output.replace(wrong_cased, req_key)
        logger.info(f"repair_case_sensitivity: {req_key}")

    return output
|
||||
|
||||
|
||||
def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str:
    """Re-insert a special character that the model dropped from the closing tag.

    Handles, for the default key:
        1. `[CONTENT] xx [CONTENT] xxx [CONTENT]` - the last `[CONTENT]` lacks `/`
        2. `xx [CONTENT] xxx [CONTENT] xxxx` - the last `[CONTENT]` lacks `/`

    Args:
        output: raw text produced by the LLM.
        req_key: the tag that must be present, including its special character.

    Returns:
        `output` unchanged when `req_key` is already present or the stripped tag occurs
        at most once; otherwise `output` with the last stripped tag replaced by `req_key`.
    """
    if req_key in output:
        return output

    for special_char in ("/",):
        bare_key = req_key.replace(special_char, "")
        if output.count(bare_key) <= 1:
            # a single (or no) bare tag is taken as the opening tag, so nothing to repair
            continue

        # the tag carrying the special character usually sits at the tail side
        last_pos = output.rfind(bare_key)
        head = output[:last_pos]
        tail = output[last_pos + len(bare_key) :]
        output = f"{head}{req_key}{tail}"
        logger.info(f"repair_special_character_missing: {special_char} in {bare_key} as position {last_pos}")

    return output
|
||||
|
||||
|
||||
def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str:
    """Add the missing half of a `[key]` / `[/key]` tag pair at the start or end of the content.

    Accepted `req_key` formats:
        1. `[req_key]`, whose pair is `[/req_key]`
        2. `[/req_key]`, whose pair is `[req_key]`
    Any `req_key` that is not wrapped in square brackets leaves `output` untouched.

    Args:
        output: raw text produced by the LLM.
        req_key: either the opening or the closing tag.

    Returns:
        The text with the opening tag prepended when it was absent and, when the text
        looks like it ends with json, the closing tag appended.
    """
    slash = "/"  # special char
    if not (req_key.startswith("[") and req_key.endswith("]")):
        return output

    if slash in req_key:
        opening = req_key.replace(slash, "")  # `[/req_key]` -> `[req_key]`
        closing = req_key
    else:
        opening = req_key
        closing = f"{req_key[0]}{slash}{req_key[1:]}"  # `[req_key]` -> `[/req_key]`

    if opening not in output:
        output = opening + "\n" + output
    if closing in output:
        return output

    def judge_potential_json(routput: str, left_key: str) -> Union[str, None]:
        """Cut `routput` from its last `left_key` up to the last `}` or `]` after it."""
        begin = routput.rfind(left_key)
        if begin < 0:
            return None
        candidate = routput[begin:]
        last_brace = candidate.rfind("}")
        last_bracket = candidate.rindex("]")
        end = max(last_brace, last_bracket)
        return candidate[: end + 1]

    trimmed = output.strip()
    ends_with_opening = trimmed.endswith(opening)
    if trimmed.endswith("}") or (trimmed.endswith("]") and not ends_with_opening):
        # avoid [req_key]xx[req_key] case to append [/req_key]
        output = output + "\n" + closing
    elif not ends_with_opening:
        # trailing chatter after the json: keep only the part from the last opening tag on
        json_part = judge_potential_json(output, opening)
        if json_part:
            output = json_part + "\n" + closing

    return output
|
||||
|
||||
|
||||
def repair_json_format(output: str) -> str:
    """Strip a stray `[` / `]` that wraps or trails a json object.

    Text that already parses as json is returned as is (only stripped), so a well-formed
    array such as `[{"a": 1}]` is no longer turned into the invalid `{"a": 1}]`.

    Args:
        output: json-like text produced by the LLM.

    Returns:
        The stripped text, with one of the following repairs applied when it is not
        valid json:
            - leading `[{`  -> the `[` is removed
            - trailing `}]` -> the `]` is removed
            - starts with `{` but ends with `]` -> the final `]` becomes `}`
    """
    import json  # local import: only needed for the validity pre-check

    output = output.strip()

    try:
        json.loads(output)
        return output  # already well-formed, any "repair" would only damage it
    except ValueError:
        pass  # malformed: fall through to the heuristics below

    if output.startswith("[{"):
        output = output[1:]
        logger.info(f"repair_json_format: {'[{'}")
    elif output.endswith("}]"):
        output = output[:-1]
        logger.info(f"repair_json_format: {'}]'}")
    elif output.startswith("{") and output.endswith("]"):
        output = output[:-1] + "}"

    return output
|
||||
|
||||
|
||||
def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = None) -> str:
    """Apply one repair, or every non-json repair, to `output` for a single `req_key`.

    Args:
        output: raw text produced by the LLM.
        req_key: the key/tag the repairs should look for.
        repair_type: the only repair to run; when falsy, every `RepairType` except
            `RepairType.JSON` is run in declaration order.

    Returns:
        The text after the selected repairs were applied one after another.
    """
    if repair_type:
        selected = [repair_type]
    else:
        selected = [kind for kind in RepairType if kind not in [RepairType.JSON]]

    # each repair type paired with a one-argument callable performing it
    repairers = (
        (RepairType.CS, lambda text: repair_case_sensitivity(text, req_key)),
        (RepairType.RKPM, lambda text: repair_required_key_pair_missing(text, req_key)),
        (RepairType.SCM, lambda text: repair_special_character_missing(text, req_key)),
        (RepairType.JSON, repair_json_format),
    )

    for wanted in selected:
        for kind, fix in repairers:
            if wanted == kind:
                output = fix(output)
                break

    return output
|
||||
|
||||
|
||||
def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str:
    """Repair incomplete LLM output for every required key.

    Open-source llm models often do not follow the instruction well and return incomplete
    output, so it is repaired here; all repair methods are used by default.

    Typical cases
        1. case sensitivity
            target: "Original Requirements"
            output: "Original requirements"
        2. special character missing
            target: [/CONTENT]
            output: [CONTENT]
        3. json format
            target: { xxx }
            output: { xxx }]

    Args:
        output: raw text produced by the LLM.
        req_keys: keys/tags that must appear in the text.
        repair_type: restrict the repair to a single `RepairType`; falsy means the default set.

    Returns:
        `output` untouched when `CONFIG.repair_llm_output` is falsy, otherwise the text
        after repairing it for each key in turn.
    """
    if CONFIG.repair_llm_output:
        # the repair is usually needed for non-openai models
        for key in req_keys:
            output = _repair_llm_raw_output(output=output, req_key=key, repair_type=repair_type)

    return output
|
||||
|
||||
|
||||
def repair_invalid_json(output: str, error: str) -> str:
    """Patch the line of `output` that a json decode error points at.

    The line number is taken from the first `line <n>` found in `error`, e.g.
        example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765)
        example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266)

    Args:
        output: the json text that failed to decode.
        error: the text of the decode exception.

    Returns:
        `output` unchanged when `error` carries no line number. Otherwise the text with
        triple quotes collapsed to `"` and, when the reported line exists, that line
        replaced by its stripped and heuristically patched form. A line that matches no
        heuristic is kept (stripped) instead of raising.
    """
    pattern = r"line ([0-9]+)"

    matches = re.findall(pattern, error, re.DOTALL)
    if not matches:
        return output

    line_no = int(matches[0]) - 1  # error messages count lines from 1

    # due to CustomDecoder can handle `"": ''` or `'': ""`, so convert `"""` -> `"`, `'''` -> `"`
    output = output.replace('"""', '"').replace("'''", '"')
    arr = output.split("\n")
    if not 0 <= line_no < len(arr):
        # the reported line does not exist in this text, so there is nothing to patch
        return output

    line = arr[line_no].strip()
    # different general problems
    if line.endswith("],"):
        # problem, redundant char `]`
        new_line = line.replace("]", "")
    elif line.endswith("},") and not output.endswith("},"):
        # problem, redundant char `}`
        new_line = line.replace("}", "")
    elif line.endswith("},") and output.endswith("},"):
        new_line = line[:-1]
    elif '",' not in line and "," not in line:
        new_line = f'{line}",'
    elif "," not in line:
        # problem, miss char `,` at the end.
        # NOTE(review): unreachable as written, the previous branch already takes every line without `,`
        new_line = f"{line},"
    elif "," in line and len(line) == 1:
        new_line = f'"{line}'
    elif '",' in line:
        new_line = line[:-2] + "',"
    else:
        # no heuristic applies (e.g. `"a": 1,`): keep the line rather than fail on an unbound name
        new_line = line

    arr[line_no] = new_line
    output = "\n".join(arr)
    logger.info(f"repair_invalid_json, raw error: {error}")

    return output
|
||||
|
||||
|
||||
def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]:
    """Build a tenacity `after` hook that repairs the json input before the next retry.

    Args:
        logger: logger used to report the failed attempt.

    Returns:
        A callable taking the `RetryCallState` of the attempt that just finished.
    """

    def run_and_passon(retry_state: RetryCallState) -> None:
        """Repair the `output` argument of a failed attempt and store it in the call kwargs.

        RetryCallState example
            {
                "start_time":143.098322024,
                "retry_object":"<Retrying object at 0x7fabcaca25e0 (stop=<tenacity.stop.stop_after_attempt ... >)>",
                "fn":"<function retry_parse_json_text_v2 at 0x7fabcac80ee0>",
                "args":"(\"tag:[/CONTENT]\",)",  # function input args
                "kwargs":{},                     # function input kwargs
                "attempt_number":1,              # retry number
                "outcome":"<Future at xxx>",  # type(outcome.result()) = "str", type(outcome.exception()) = "class"
                "outcome_timestamp":143.098416904,
                "idle_for":0,
                "next_action":"None"
            }
        """
        if not retry_state.outcome.failed:
            return

        if retry_state.args:
            # can't be used as args=retry_state.args
            func_param_output = retry_state.args[0]
        else:
            # also covers a call without any argument, which used to leave the name unbound
            func_param_output = retry_state.kwargs.get("output", "")

        exp_str = str(retry_state.outcome.exception())
        logger.warning(
            f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
            f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}"
        )

        repaired_output = repair_invalid_json(func_param_output, exp_str)
        # presumably tenacity passes this same kwargs dict to the next attempt - TODO confirm;
        # NOTE(review): if the wrapped function was called positionally, `output` would then be
        # supplied twice, so callers are expected to pass `output=` by keyword.
        retry_state.kwargs["output"] = repaired_output

    return run_and_passon
|
||||
|
||||
|
||||
@retry(
    stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0),
    wait=wait_fixed(1),
    after=run_after_exp_and_passon_next_retry(logger),
)
def retry_parse_json_text(output: str) -> Union[list, dict]:
    """Decode json text, retrying with a repaired text after each failure.

    Repairs situations like extra chars such as [']', '}'].

    Warning
        if CONFIG.repair_llm_output is False, retry _aask_v1 {x=3} times, and the retry_parse_json_text's retry not work
        if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times.
            it's a two-layer retry cycle

    Args:
        output: the json text to decode; pass it by keyword so the `after` hook can replace it.

    Returns:
        The decoded list or dict.
    """
    # if CONFIG.repair_llm_output is True, the `after` hook keeps fixing `output` until the retry stops
    decoder = CustomDecoder(strict=False)
    return decoder.decode(output)
|
||||
|
||||
|
||||
def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"):
    """Pull `xxx` out of `[CONTENT]xxx[/CONTENT]` with a regex.

    Args:
        content: raw text produced by the LLM.
        right_key: the closing tag.

    Returns:
        The stripped text between the tags. When the first extraction does not start with
        `{`, a second extraction is run to peel off a nested `[CONTENT]` prefix.
    """
    # TODO construct the extract pattern with the `right_key`
    pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]"

    def first_capture(text: str) -> str:
        """Return the first non-empty capture stripped, or `text` stripped when there is none."""
        captured = next((found for found in re.findall(pattern, text, re.DOTALL) if found), text)
        return captured.strip()

    extracted = first_capture(content)

    if extracted.startswith("{"):
        # the greedy pattern may have swallowed an inner closing tag: cut at the first one
        cut = extracted.find(right_key)
        if cut >= 0:
            extracted = extracted[:cut].strip()
        return extracted

    # TODO find a more general pattern
    # for `[CONTENT]xxx[CONTENT]xxxx[/CONTENT] situation
    logger.warning(f"extract_content try another pattern: {pattern}")
    if right_key in extracted:
        second_source = content
    else:
        second_source = extracted + "\n" + right_key
    return first_capture(second_source)
|
||||
|
||||
|
||||
def extract_state_value_from_output(content: str) -> str:
    """Extract the state digit from the llm's answer.

    For openai models, they will always return state number. But for open llm models, the
    instruction result maybe a long text contain target number, so here add a extraction to
    improve success rate.

    The first digit in the text is returned. Previously an arbitrary one of the distinct
    digits was picked (set ordering), so the result could differ between runs.

    Args:
        content (str): llm's output from `Role._think`

    Returns:
        The first single digit found in `content` as a string, or "-1" when there is none.
    """
    content = content.strip()  # deal the output cases like " 0", "0\n" and so on.
    # TODO find the number using a more proper method not just extract from content using pattern
    first_digit = re.search(r"([0-9])", content, re.DOTALL)
    return first_digit.group(1) if first_digit else "-1"
|
||||
|
|
@ -4,13 +4,11 @@
|
|||
|
||||
import copy
|
||||
import pickle
|
||||
from typing import Dict, List
|
||||
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import import_class
|
||||
|
||||
|
||||
def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
|
||||
def actionoutout_schema_to_mapping(schema: dict) -> dict:
|
||||
"""
|
||||
directly traverse the `properties` in the first level.
|
||||
schema structure likes
|
||||
|
|
@ -35,14 +33,31 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
|
|||
if property["type"] == "string":
|
||||
mapping[field] = (str, ...)
|
||||
elif property["type"] == "array" and property["items"]["type"] == "string":
|
||||
mapping[field] = (List[str], ...)
|
||||
mapping[field] = (list[str], ...)
|
||||
elif property["type"] == "array" and property["items"]["type"] == "array":
|
||||
# here only consider the `List[List[str]]` situation
|
||||
mapping[field] = (List[List[str]], ...)
|
||||
# here only consider the `list[list[str]]` situation
|
||||
mapping[field] = (list[list[str]], ...)
|
||||
return mapping
|
||||
|
||||
|
||||
def serialize_message(message: Message):
|
||||
def actionoutput_mapping_to_str(mapping: dict) -> dict:
    """Return a copy of `mapping` in which every value is replaced by its `str()` form."""
    return {field_name: str(field_type) for field_name, field_type in mapping.items()}
|
||||
|
||||
|
||||
def actionoutput_str_to_mapping(mapping: dict) -> dict:
    """Turn stringified field definitions back into `(type, ...)` tuples.

    Args:
        mapping: field name -> `str()` form of the field definition.

    Returns:
        A new dict with the same keys whose values are the rebuilt definitions.
    """
    str_field_repr = "(<class 'str'>, Ellipsis)"  # repr of `(str, ...)`, which eval cannot parse
    restored = {}
    for field_name, type_repr in mapping.items():
        if type_repr == str_field_repr:
            restored[field_name] = (str, ...)
            continue
        # `"'(list[str], Ellipsis)"` to `(list[str], ...)`
        # NOTE(security): eval runs arbitrary expressions - only feed it trusted, self-produced strings
        restored[field_name] = eval(type_repr)
    return restored
|
||||
|
||||
|
||||
def serialize_message(message: "Message"):
|
||||
message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference
|
||||
ic = message_cp.instruct_content
|
||||
if ic:
|
||||
|
|
@ -56,11 +71,12 @@ def serialize_message(message: Message):
|
|||
return msg_ser
|
||||
|
||||
|
||||
def deserialize_message(message_ser: str) -> Message:
|
||||
def deserialize_message(message_ser: str) -> "Message":
|
||||
message = pickle.loads(message_ser)
|
||||
if message.instruct_content:
|
||||
ic = message.instruct_content
|
||||
ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
|
||||
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
|
||||
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
|
||||
ic_new = ic_obj(**ic["value"])
|
||||
message.instruct_content = ic_new
|
||||
|
||||
|
|
|
|||
|
|
@ -20,4 +20,3 @@ class Singleton(abc.ABCMeta, type):
|
|||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
# token to separate different code messages in a WriteCode Message content
|
||||
MSG_SEP = "#*000*#"
|
||||
MSG_SEP = "#*000*#"
|
||||
# token to separate file name and the actual code text in a code message
|
||||
FILENAME_CODE_SEP = "#*001*#"
|
||||
|
|
|
|||
|
|
@ -3,7 +3,12 @@ from typing import Generator, Sequence
|
|||
from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens
|
||||
|
||||
|
||||
def reduce_message_length(msgs: Generator[str, None, None], model_name: str, system_text: str, reserved: int = 0,) -> str:
|
||||
def reduce_message_length(
|
||||
msgs: Generator[str, None, None],
|
||||
model_name: str,
|
||||
system_text: str,
|
||||
reserved: int = 0,
|
||||
) -> str:
|
||||
"""Reduce the length of concatenated message segments to fit within the maximum token size.
|
||||
|
||||
Args:
|
||||
|
|
@ -49,9 +54,9 @@ def generate_prompt_chunk(
|
|||
current_token = 0
|
||||
current_lines = []
|
||||
|
||||
reserved = reserved + count_string_tokens(prompt_template+system_text, model_name)
|
||||
reserved = reserved + count_string_tokens(prompt_template + system_text, model_name)
|
||||
# 100 is a magic number to ensure the maximum context length is not exceeded
|
||||
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
|
||||
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
|
||||
|
||||
while paragraphs:
|
||||
paragraph = paragraphs.pop(0)
|
||||
|
|
@ -103,7 +108,7 @@ def decode_unicode_escape(text: str) -> str:
|
|||
return text.encode("utf-8").decode("unicode_escape", "ignore")
|
||||
|
||||
|
||||
def _split_by_count(lst: Sequence , count: int):
|
||||
def _split_by_count(lst: Sequence, count: int):
|
||||
avg = len(lst) // count
|
||||
remainder = len(lst) % count
|
||||
start = 0
|
||||
|
|
|
|||
|
|
@ -18,11 +18,13 @@ TOKEN_COSTS = {
|
|||
"gpt-3.5-turbo-16k-0613": {"prompt": 0.003, "completion": 0.004},
|
||||
"gpt-35-turbo": {"prompt": 0.0015, "completion": 0.002},
|
||||
"gpt-35-turbo-16k": {"prompt": 0.003, "completion": 0.004},
|
||||
"gpt-3.5-turbo-1106": {"prompt": 0.001, "completion": 0.002},
|
||||
"gpt-4-0314": {"prompt": 0.03, "completion": 0.06},
|
||||
"gpt-4": {"prompt": 0.03, "completion": 0.06},
|
||||
"gpt-4-32k": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
|
||||
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
}
|
||||
|
|
@ -36,11 +38,13 @@ TOKEN_MAX = {
|
|||
"gpt-3.5-turbo-16k-0613": 16384,
|
||||
"gpt-35-turbo": 4096,
|
||||
"gpt-35-turbo-16k": 16384,
|
||||
"gpt-3.5-turbo-1106": 16384,
|
||||
"gpt-4-0314": 8192,
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-32k": 32768,
|
||||
"gpt-4-32k-0314": 32768,
|
||||
"gpt-4-0613": 8192,
|
||||
"gpt-4-1106-preview": 128000,
|
||||
"text-embedding-ada-002": 8192,
|
||||
"chatglm_turbo": 32768,
|
||||
}
|
||||
|
|
@ -58,20 +62,23 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
|
|||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-35-turbo",
|
||||
"gpt-35-turbo-16k",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-1106",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-4-1106-preview",
|
||||
}:
|
||||
tokens_per_message = 3
|
||||
tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|>
|
||||
tokens_per_name = 1
|
||||
elif model == "gpt-3.5-turbo-0301":
|
||||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||
elif "gpt-3.5-turbo" in model:
|
||||
elif "gpt-3.5-turbo" == model:
|
||||
print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
|
||||
return count_message_tokens(messages, model="gpt-3.5-turbo-0613")
|
||||
elif "gpt-4" in model:
|
||||
elif "gpt-4" == model:
|
||||
print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
|
||||
return count_message_tokens(messages, model="gpt-4-0613")
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue