Merge branch 'main' into feature-openai-v1

This commit is contained in:
seehi 2023-12-21 12:06:12 +08:00 committed by GitHub
commit 9a4f0d555c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
260 changed files with 10576 additions and 3191 deletions

View file

@ -0,0 +1,49 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : pure async http_client
from typing import Any, Mapping, Optional, Union
import aiohttp
from aiohttp.client import DEFAULT_TIMEOUT
async def apost(
    url: str,
    params: Optional[Mapping[str, str]] = None,
    json: Any = None,
    data: Any = None,
    headers: Optional[dict] = None,
    as_json: bool = False,
    encoding: str = "utf-8",
    timeout: int = DEFAULT_TIMEOUT.total,
) -> Union[str, dict]:
    """Send one asynchronous HTTP POST request and return the response body.

    A fresh ``aiohttp.ClientSession`` is opened for the call and closed afterwards.

    :param url: Target URL.
    :param params: Optional query-string parameters.
    :param json: Optional payload forwarded to aiohttp's ``json`` argument.
    :param data: Optional payload forwarded to aiohttp's ``data`` argument.
    :param headers: Optional request headers.
    :param as_json: When true, return the result of ``resp.json()``; otherwise return decoded text.
    :param encoding: Codec used to decode the raw body when ``as_json`` is false.
    :param timeout: Timeout value forwarded to ``session.post``.
    :return: The value of ``resp.json()`` or the decoded body text.
    """
    request_kwargs = dict(url=url, params=params, json=json, data=data, headers=headers, timeout=timeout)
    async with aiohttp.ClientSession() as session, session.post(**request_kwargs) as resp:
        if as_json:
            return await resp.json()
        raw_body = await resp.read()
        return raw_body.decode(encoding)
async def apost_stream(
    url: str,
    params: Optional[Mapping[str, str]] = None,
    json: Any = None,
    data: Any = None,
    headers: Optional[dict] = None,
    encoding: str = "utf-8",
    timeout: int = DEFAULT_TIMEOUT.total,
) -> Any:
    """Send an asynchronous HTTP POST request and yield the response line by line.

    This is an async generator; each item of ``resp.content`` is decoded with
    ``encoding`` before being yielded.

    usage:
        result = apost_stream(url="xx")
        async for line in result:
            deal_with(line)
    """
    request_kwargs = dict(url=url, params=params, json=json, data=data, headers=headers, timeout=timeout)
    async with aiohttp.ClientSession() as session, session.post(**request_kwargs) as resp:
        async for raw_line in resp.content:
            yield raw_line.decode(encoding)

View file

@ -4,16 +4,34 @@
@Time : 2023/4/29 16:07
@Author : alexanderwu
@File : common.py
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.2 of RFC 116:
Add generic class-to-string and object-to-string conversion functionality.
@Modified By: mashenquan, 2023/11/27. Bug fix: `parse_recipient` failed to parse the recipient in certain GPT-3.5
responses.
"""
from __future__ import annotations
import ast
import contextlib
import importlib
import inspect
import json
import os
import platform
import re
from typing import List, Tuple, Union
import traceback
import typing
from pathlib import Path
from typing import Any, List, Tuple, Union, get_args, get_origin
import aiofiles
import loguru
from pydantic.json import pydantic_encoder
from tenacity import RetryCallState, _utils
from metagpt.const import MESSAGE_ROUTE_TO_ALL
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
def check_cmd_exists(command) -> int:
@ -85,10 +103,7 @@ class OutputParser:
@staticmethod
def parse_python_code(text: str) -> str:
for pattern in (
r"(.*?```python.*?\s+)?(?P<code>.*)(```.*?)",
r"(.*?```python.*?\s+)?(?P<code>.*)",
):
for pattern in (r"(.*?```python.*?\s+)?(?P<code>.*)(```.*?)", r"(.*?```python.*?\s+)?(?P<code>.*)"):
match = re.search(pattern, text, re.DOTALL)
if not match:
continue
@ -119,8 +134,32 @@ class OutputParser:
parsed_data[block] = content
return parsed_data
@staticmethod
def extract_content(text, tag="CONTENT"):
# Use regular expression to extract content between [CONTENT] and [/CONTENT]
extracted_content = re.search(rf"\[{tag}\](.*?)\[/{tag}\]", text, re.DOTALL)
if extracted_content:
return extracted_content.group(1).strip()
else:
return "No content found between [CONTENT] and [/CONTENT] tags."
@staticmethod
def is_supported_list_type(i):
origin = get_origin(i)
if origin is not List:
return False
args = get_args(i)
if args == (str,) or args == (Tuple[str, str],) or args == (List[str],):
return True
return False
@classmethod
def parse_data_with_mapping(cls, data, mapping):
if "[CONTENT]" in data:
data = cls.extract_content(text=data)
block_dict = cls.parse_blocks(data)
parsed_data = {}
for block, content in block_dict.items():
@ -187,7 +226,7 @@ class OutputParser:
result = ast.literal_eval(structure_text)
# Ensure the result matches the specified data type
if isinstance(result, list) or isinstance(result, dict):
if isinstance(result, (list, dict)):
return result
raise ValueError(f"The extracted structure is not a {data_type}.")
@ -219,10 +258,15 @@ class CodeParser:
# 遍历所有的block
for block in blocks:
# 如果block不为空则继续处理
if block.strip() != "":
if block.strip() == "":
continue
if "\n" not in block:
block_title = block
block_content = ""
else:
# 将block的标题和内容分开并分别去掉前后的空白字符
block_title, block_content = block.split("\n", 1)
block_dict[block_title.strip()] = block_content.strip()
block_dict[block_title.strip()] = block_content.strip()
return block_dict
@ -282,9 +326,6 @@ class NoMoneyException(Exception):
def print_members(module, indent=0):
"""
https://stackoverflow.com/questions/1796180/how-can-i-get-a-list-of-all-classes-within-current-module-in-python
:param module:
:param indent:
:return:
"""
prefix = " " * indent
for name, obj in inspect.getmembers(module):
@ -302,6 +343,173 @@ def print_members(module, indent=0):
def parse_recipient(text):
    """Extract the recipient name that follows a "Send To:" marker.

    The markdown-heading form ``## Send To:`` is tried first; the bare
    ``Send To:`` form is the fallback.

    :param text: The text to search.
    :return: The recipient (letters only), or an empty string when no marker is found.
    """
    # FIXME: use ActionNode instead.
    patterns = (
        r"## Send To:\s*([A-Za-z]+)\s*?",  # hard code for now
        r"Send To:\s*([A-Za-z]+)\s*?",
    )
    for pattern in patterns:
        recipient = re.search(pattern, text)
        if recipient:
            return recipient.group(1)
    return ""
def get_class_name(cls) -> str:
    """Build the dotted ``<module>.<name>`` identifier of ``cls``."""
    return "{}.{}".format(cls.__module__, cls.__name__)
def any_to_str(val: str | typing.Callable) -> str:
    """Convert ``val`` to a string identifier.

    Strings are returned unchanged. A callable is named via ``get_class_name``
    directly; any other object is named after its type.
    """
    if isinstance(val, str):
        return val
    target = val if callable(val) else type(val)
    return get_class_name(target)
def any_to_str_set(val) -> set:
    """Convert ``val`` to a set of string identifiers.

    For a dict, list, set or tuple each element (dict *values*, not keys) is
    converted with ``any_to_str``; any other value yields a one-element set.
    """
    # Strings are iterable too, but they deliberately take the single-value path.
    if not isinstance(val, (dict, list, set, tuple)):
        return {any_to_str(val)}
    items = val.values() if isinstance(val, dict) else val
    return {any_to_str(item) for item in items}
def is_subscribed(message: "Message", tags: set):
    """Tell whether a message is addressed to a consumer holding ``tags``.

    :param message: Object whose ``send_to`` supports membership tests.
    :param tags: Identifiers the consumer listens for.
    :return: True when the message is routed to everyone or ``send_to``
        contains at least one of ``tags``; otherwise False.
    """
    if MESSAGE_ROUTE_TO_ALL in message.send_to:
        return True
    return any(tag in message.send_to for tag in tags)
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
    """Build a callback that logs the outcome of a retried call at error level.

    The logged message contains the name of the called function, the seconds
    elapsed since the first attempt (rendered with ``sec_format``), the ordinal
    of the attempt and the exception stored in the retry outcome.

    :param i: Logger whose ``error`` method receives the message.
    :param sec_format: %-style format applied to ``seconds_since_start``.
    :return: A callable taking a ``RetryCallState`` and returning None.
    """

    def log_it(retry_state: "RetryCallState") -> None:
        # Fall back to a placeholder when the retried callable is not recorded.
        fn_name = "<unknown>" if retry_state.fn is None else _utils.get_callback_name(retry_state.fn)

        i.error(
            f"Finished call to '{fn_name}' after {sec_format % retry_state.seconds_since_start}(s), "
            f"this was the {_utils.to_ordinal(retry_state.attempt_number)} time calling it. "
            f"exp: {retry_state.outcome.exception()}"
        )

    return log_it
def read_json_file(json_file: str, encoding=None) -> list[Any]:
    """Load and return the JSON document stored in ``json_file``.

    :param json_file: Path of the file to read.
    :param encoding: Text encoding passed to ``open``; None uses the platform default.
    :return: The decoded JSON value.
    :raises FileNotFoundError: If ``json_file`` does not exist.
    :raises ValueError: If the file cannot be decoded as JSON; the underlying
        error is attached as ``__cause__``.
    """
    if not Path(json_file).exists():
        # The function raises here; the message no longer claims that [] is returned.
        raise FileNotFoundError(f"json_file: {json_file} not exist")

    with open(json_file, "r", encoding=encoding) as fin:
        try:
            return json.load(fin)
        except Exception as exc:
            # Chain explicitly so the real decoding error stays visible to callers.
            raise ValueError(f"read json file: {json_file} failed") from exc
def write_json_file(json_file: str, data: list, encoding=None):
    """Serialize ``data`` as indented JSON into ``json_file``.

    The parent directory is created when it does not exist yet. Non-ASCII
    characters are written as-is, and objects the json module cannot encode
    are handed to ``pydantic_encoder``.

    :param json_file: Destination path.
    :param data: Value to serialize.
    :param encoding: Text encoding passed to ``open``; None uses the platform default.
    """
    parent_dir = Path(json_file).parent
    if not parent_dir.exists():
        parent_dir.mkdir(parents=True, exist_ok=True)

    with open(json_file, "w", encoding=encoding) as out_stream:
        json.dump(data, out_stream, ensure_ascii=False, indent=4, default=pydantic_encoder)
def import_class(class_name: str, module_name: str) -> type:
    """Import ``module_name`` and return its attribute named ``class_name``."""
    return getattr(importlib.import_module(module_name), class_name)
def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> object:
    """Import a class via ``import_class`` and instantiate it with the given arguments."""
    return import_class(class_name, module_name)(*args, **kwargs)
def format_trackback_info(limit: int = 2):
    """Return the formatted traceback of the exception currently being handled.

    :param limit: Forwarded to ``traceback.format_exc`` as its ``limit`` argument.
    :return: The traceback text.
    """
    formatted = traceback.format_exc(limit=limit)
    return formatted
def serialize_decorator(func):
    """Wrap an async method so that a failure triggers ``self.serialize()``.

    When the wrapped coroutine raises ``KeyboardInterrupt`` or any ``Exception``,
    the traceback is logged, ``self.serialize()`` is called and the wrapper
    returns None; on success the coroutine's result is returned unchanged.
    """

    async def wrapper(self, *args, **kwargs):
        try:
            return await func(self, *args, **kwargs)
        except KeyboardInterrupt:
            logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}")
        except Exception:
            logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}")
        # Only reached after a failure: the success path has already returned.
        self.serialize()  # Team.serialize

    return wrapper
def role_raise_decorator(func):
    """Wrap an async role method so a failure forgets the latest observed message.

    On ``KeyboardInterrupt`` or ``Exception`` the message held in
    ``self.latest_observed_msg`` (if any) is deleted from ``self._rc.memory``,
    then a plain ``Exception`` carrying the full formatted traceback is raised.
    """

    async def wrapper(self, *args, **kwargs):
        try:
            return await func(self, *args, **kwargs)
        except KeyboardInterrupt as kbi:
            logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project")
            if self.latest_observed_msg:
                self._rc.memory.delete(self.latest_observed_msg)
            # Re-raise as a regular Exception so the outer caller can capture it.
            raise Exception(format_trackback_info(limit=None))
        except Exception:
            if self.latest_observed_msg:
                logger.warning(
                    "There is a exception in role's execution, in order to resume, "
                    "we delete the newest role communication message in the role's memory."
                )
                # Drop the newest observed message so that it is observed again on resume.
                self._rc.memory.delete(self.latest_observed_msg)
            # Re-raise with the full traceback text so the outer caller can capture it.
            raise Exception(format_trackback_info(limit=None))

    return wrapper
@handle_exception
async def aread(file_path: str) -> str:
    """Asynchronously read a text file and return its whole content.

    :param file_path: Path of the file; it is converted with ``str`` before opening.
    :return: The file content.
    """
    async with aiofiles.open(str(file_path), mode="r") as reader:
        return await reader.read()

View file

@ -25,7 +25,7 @@ def py_make_scanner(context):
except IndexError:
raise StopIteration(idx) from None
if nextchar == '"' or nextchar == "'":
if nextchar in ("'", '"'):
if idx + 2 < len(string) and string[idx + 1] == nextchar and string[idx + 2] == nextchar:
# Handle the case where the next two characters are the same as nextchar
return parse_string(string, idx + 3, strict, delimiter=nextchar * 3) # triple quote

View file

@ -0,0 +1,103 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/11/22
@Author : mashenquan
@File : dependency_file.py
@Desc: Implementation of the dependency file described in Section 2.2.3.2 of RFC 135.
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Set
import aiofiles
from metagpt.config import CONFIG
from metagpt.utils.common import aread
from metagpt.utils.exceptions import handle_exception
class DependencyFile:
    """Manage the ``.dependencies.json`` file stored in a working directory.

    The file maps a file path to the list of paths it depends on. Paths located
    under the working directory are stored relative to it.

    :param workdir: The working directory path for the DependencyFile.
    """

    def __init__(self, workdir: Path | str):
        """Initialize a DependencyFile instance.

        :param workdir: The working directory path for the DependencyFile.
        """
        self._dependencies = {}
        self._filename = Path(workdir) / ".dependencies.json"

    def _to_key(self, path: Path | str) -> str:
        """Return ``path`` relative to the dependency file's directory, or unchanged if outside it."""
        try:
            return str(Path(path).relative_to(self._filename.parent))
        except ValueError:
            return str(path)

    async def load(self):
        """Load dependencies from the file asynchronously; a missing file leaves the state untouched."""
        if not self._filename.exists():
            return
        self._dependencies = json.loads(await aread(self._filename))

    @handle_exception
    async def save(self):
        """Save dependencies to the file asynchronously."""
        data = json.dumps(self._dependencies)
        async with aiofiles.open(str(self._filename), mode="w") as writer:
            await writer.write(data)

    async def update(self, filename: Path | str, dependencies: Set[Path | str], persist=True):
        """Update dependencies for a file asynchronously.

        An empty or None ``dependencies`` removes the entry of ``filename``.

        :param filename: The filename or path.
        :param dependencies: The set of dependencies.
        :param persist: Whether to reload from and write back to the file around the change.
        """
        if persist:
            await self.load()

        key = self._to_key(filename)
        if dependencies:
            self._dependencies[key] = [self._to_key(dep) for dep in dependencies]
        elif key in self._dependencies:
            del self._dependencies[key]

        if persist:
            await self.save()

    async def get(self, filename: Path | str, persist=True):
        """Get dependencies for a file asynchronously.

        :param filename: The filename or path.
        :param persist: Whether to load dependencies from the file immediately.
        :return: A set of dependencies.
        """
        if persist:
            await self.load()

        # Use the same root as update() so that keys written there are found here,
        # independent of whatever repository the global configuration points at.
        return set(self._dependencies.get(self._to_key(filename), {}))

    def delete_file(self):
        """Delete the dependency file."""
        self._filename.unlink(missing_ok=True)

    @property
    def exists(self):
        """Check if the dependency file exists."""
        return self._filename.exists()

View file

@ -0,0 +1,59 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19 14:46
@Author : alexanderwu
@File : exceptions.py
"""
import asyncio
import functools
import traceback
from typing import Any, Callable, Tuple, Type, TypeVar, Union
from metagpt.logs import logger
ReturnType = TypeVar("ReturnType")
def handle_exception(
    _func: Callable[..., ReturnType] = None,
    *,
    exception_type: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception,
    default_return: Any = None,
) -> Callable[..., ReturnType]:
    """Decorator that logs matching exceptions and returns a default value instead.

    Usable bare (``@handle_exception``) or with keyword options
    (``@handle_exception(exception_type=..., default_return=...)``). Coroutine
    functions get an async wrapper, all other callables a synchronous one.

    :param _func: The decorated callable when used without parentheses.
    :param exception_type: Exception class (or tuple of classes) to intercept.
    :param default_return: Value returned when an intercepted exception occurs.
    :return: The wrapped callable, or the decorator itself when options were given.
    """

    def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
        @functools.wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
            try:
                return await func(*args, **kwargs)
            except exception_type as e:
                logger.opt(depth=1).error(
                    f"Calling {func.__name__} with args: {args}, kwargs: {kwargs} failed: {e}, "
                    f"stack: {traceback.format_exc()}"
                )
                return default_return

        @functools.wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
            try:
                return func(*args, **kwargs)
            except exception_type as e:
                logger.opt(depth=1).error(
                    f"Calling {func.__name__} with args: {args}, kwargs: {kwargs} failed: {e}, "
                    f"stack: {traceback.format_exc()}"
                )
                return default_return

        # Pick the wrapper matching the kind of callable being decorated.
        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper

    # Bare usage passes the function directly; parametrized usage needs the decorator back.
    return decorator if _func is None else decorator(_func)

View file

@ -6,10 +6,12 @@
@File : file.py
@Describe : General file operations.
"""
import aiofiles
from pathlib import Path
import aiofiles
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
class File:
@ -18,6 +20,7 @@ class File:
CHUNK_SIZE = 64 * 1024
@classmethod
@handle_exception
async def write(cls, root_path: Path, filename: str, content: bytes) -> Path:
"""Write the file content to the local specified path.
@ -32,18 +35,15 @@ class File:
Raises:
Exception: If an unexpected error occurs during the file writing process.
"""
try:
root_path.mkdir(parents=True, exist_ok=True)
full_path = root_path / filename
async with aiofiles.open(full_path, mode="wb") as writer:
await writer.write(content)
logger.debug(f"Successfully write file: {full_path}")
return full_path
except Exception as e:
logger.error(f"Error writing file: {e}")
raise e
root_path.mkdir(parents=True, exist_ok=True)
full_path = root_path / filename
async with aiofiles.open(full_path, mode="wb") as writer:
await writer.write(content)
logger.debug(f"Successfully write file: {full_path}")
return full_path
@classmethod
@handle_exception
async def read(cls, file_path: Path, chunk_size: int = None) -> bytes:
"""Partitioning read the file content from the local specified path.
@ -57,19 +57,14 @@ class File:
Raises:
Exception: If an unexpected error occurs during the file reading process.
"""
try:
chunk_size = chunk_size or cls.CHUNK_SIZE
async with aiofiles.open(file_path, mode="rb") as reader:
chunks = list()
while True:
chunk = await reader.read(chunk_size)
if not chunk:
break
chunks.append(chunk)
content = b''.join(chunks)
logger.debug(f"Successfully read file, the path of file: {file_path}")
return content
except Exception as e:
logger.error(f"Error reading file: {e}")
raise e
chunk_size = chunk_size or cls.CHUNK_SIZE
async with aiofiles.open(file_path, mode="rb") as reader:
chunks = list()
while True:
chunk = await reader.read(chunk_size)
if not chunk:
break
chunks.append(chunk)
content = b"".join(chunks)
logger.debug(f"Successfully read file, the path of file: {file_path}")
return content

View file

@ -0,0 +1,287 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/11/20
@Author : mashenquan
@File : file_repository.py
@Desc: File repository management. RFC 135 2.2.3.2, 2.2.3.4 and 2.2.3.13.
"""
from __future__ import annotations
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Set
import aiofiles
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.schema import Document
from metagpt.utils.common import aread
from metagpt.utils.json_to_markdown import json_to_markdown
class FileRepository:
    """Read and write the files of one sub-directory of a Git repository.

    :param git_repo: The associated GitRepository instance.
    :param relative_path: The relative path within the Git repository.

    Attributes:
        _relative_path (Path): The relative path within the Git repository.
        _git_repo (GitRepository): The associated GitRepository instance.
    """

    def __init__(self, git_repo, relative_path: Path = Path(".")):
        """Initialize a FileRepository instance and create its working directory.

        :param git_repo: The associated GitRepository instance.
        :param relative_path: The relative path within the Git repository.
        """
        self._relative_path = relative_path
        self._git_repo = git_repo

        # Make sure the directory exists before any file is read or written.
        self.workdir.mkdir(parents=True, exist_ok=True)

    async def save(self, filename: Path | str, content, dependencies: List[str] = None):
        """Save content to a file and update its dependencies.

        :param filename: The filename or path within the repository.
        :param content: The text to be written.
        :param dependencies: Dependency filenames or paths; None leaves the recorded
            dependencies untouched.
        """
        pathname = self.workdir / filename
        pathname.parent.mkdir(parents=True, exist_ok=True)
        async with aiofiles.open(str(pathname), mode="w") as writer:
            await writer.write(content)
        logger.info(f"save to: {str(pathname)}")

        if dependencies is not None:
            dependency_file = await self._git_repo.get_dependency()
            await dependency_file.update(pathname, set(dependencies))
            logger.info(f"update dependency: {str(pathname)}:{dependencies}")

    async def get_dependency(self, filename: Path | str) -> Set[str]:
        """Get the dependencies of a file.

        :param filename: The filename or path within the repository.
        :return: Set of dependency filenames or paths.
        """
        pathname = self.workdir / filename
        dependency_file = await self._git_repo.get_dependency()
        return await dependency_file.get(pathname)

    async def get_changed_dependency(self, filename: Path | str) -> Set[str]:
        """Get the dependencies of a file that have changed.

        :param filename: The filename or path within the repository.
        :return: Set of dependencies that also appear in ``changed_files``.
        """
        dependencies = await self.get_dependency(filename=filename)
        changed_files = self.changed_files
        return {dep for dep in dependencies if dep in changed_files}

    async def get(self, filename: Path | str) -> Document | None:
        """Read the content of a file.

        :param filename: The filename or path within the repository.
        :return: A Document holding the file content, or None if the file does not exist.
        """
        path_name = self.workdir / filename
        if not path_name.exists():
            return None
        doc = Document(root_path=str(self.root_path), filename=str(filename))
        doc.content = await aread(path_name)
        return doc

    async def get_all(self) -> List[Document]:
        """Get the content of all files below the working directory.

        :return: List of Document instances representing files.
        """
        docs = []
        for root, _, files in os.walk(str(self.workdir)):
            for file in files:
                relative_path = (Path(root) / file).relative_to(self.workdir)
                docs.append(await self.get(relative_path))
        return docs

    @property
    def workdir(self):
        """Return the working directory: the Git working directory joined with the relative path."""
        return self._git_repo.workdir / self._relative_path

    @property
    def root_path(self):
        """Return the relative path from git repository root"""
        return self._relative_path

    @property
    def changed_files(self) -> Dict[str, str]:
        """Return the changed files located under this repository's relative path.

        :return: A dictionary where keys are file paths relative to this
            FileRepository and values are change types.
        """
        relative_files = {}
        for path, change_type in self._git_repo.changed_files.items():
            try:
                relative_file = Path(path).relative_to(self._relative_path)
            except ValueError:
                # Changed file lives outside this sub-directory.
                continue
            relative_files[str(relative_file)] = change_type
        return relative_files

    @property
    def all_files(self) -> List:
        """Return the result of the Git repository's ``get_files`` for this relative path.

        :rtype: List
        """
        return self._git_repo.get_files(relative_path=self._relative_path)

    def get_change_dir_files(self, dir: Path | str) -> List:
        """Get the files in a directory that have changed.

        :param dir: The directory path within the repository.
        :return: List of changed filenames or paths within the directory.
        """
        parent = Path(dir)
        children = []
        for changed in self.changed_files:
            try:
                Path(changed).relative_to(parent)
            except ValueError:
                continue
            children.append(str(changed))
        return children

    @staticmethod
    def new_filename():
        """Generate a new filename from the current local time.

        :return: A timestamp string formatted as ``%Y%m%d%H%M%S``.
        """
        return datetime.now().strftime("%Y%m%d%H%M%S")

    async def save_doc(self, doc: Document, with_suffix: str = None, dependencies: List[str] = None):
        """Save a Document instance as a Markdown file.

        The document content is parsed as JSON, converted to Markdown and saved
        under the document's filename, optionally with a different suffix.

        :param doc: The Document instance to be saved.
        :type doc: Document
        :param with_suffix: Optional file suffix that replaces the one in ``doc.filename``.
        :type with_suffix: str, optional
        :param dependencies: A list of dependencies for the saved file.
        :type dependencies: List[str], optional
        """
        m = json.loads(doc.content)
        filename = Path(doc.filename)
        if with_suffix is not None:
            filename = filename.with_suffix(with_suffix)
        await self.save(filename=str(filename), content=json_to_markdown(m), dependencies=dependencies)
        logger.debug(f"File Saved: {str(filename)}")

    @staticmethod
    async def get_file(filename: Path | str, relative_path: Path | str = ".") -> Document | None:
        """Retrieve a specific file from the file repository of the configured Git repository.

        :param filename: The name or path of the file to retrieve.
        :type filename: Path or str
        :param relative_path: The relative path within the file repository.
        :type relative_path: Path or str, optional
        :return: The document representing the file, or None if not found.
        :rtype: Document or None
        """
        file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
        return await file_repo.get(filename=filename)

    @staticmethod
    async def get_all_files(relative_path: Path | str = ".") -> List[Document]:
        """Retrieve all files from the file repository of the configured Git repository.

        :param relative_path: The relative path within the file repository.
        :type relative_path: Path or str, optional
        :return: A list of documents representing all files in the repository.
        :rtype: List[Document]
        """
        file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
        return await file_repo.get_all()

    @staticmethod
    async def save_file(filename: Path | str, content, dependencies: List[str] = None, relative_path: Path | str = "."):
        """Save a file to the file repository of the configured Git repository.

        :param filename: The name or path of the file to save.
        :type filename: Path or str
        :param content: The content of the file.
        :param dependencies: A list of dependencies for the file.
        :type dependencies: List[str], optional
        :param relative_path: The relative path within the file repository.
        :type relative_path: Path or str, optional
        """
        file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
        return await file_repo.save(filename=filename, content=content, dependencies=dependencies)

    @staticmethod
    async def save_as(
        doc: Document, with_suffix: str = None, dependencies: List[str] = None, relative_path: Path | str = "."
    ):
        """Save a Document through a new FileRepository of the configured Git repository.

        Delegates to ``save_doc``; nothing is returned.

        :param doc: The Document instance to be saved.
        :type doc: Document
        :param with_suffix: Optional file suffix that replaces the one in ``doc.filename``.
        :type with_suffix: str, optional
        :param dependencies: A list of dependencies for the saved file.
        :type dependencies: List[str], optional
        :param relative_path: The relative path within the file repository.
        :type relative_path: Path or str, optional
        """
        file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
        return await file_repo.save_doc(doc=doc, with_suffix=with_suffix, dependencies=dependencies)

    async def delete(self, filename: Path | str):
        """Delete a file from the file repository and drop its dependency entry.

        Does nothing when the file does not exist.

        :param filename: The name or path of the file to be deleted.
        :type filename: Path or str
        """
        pathname = self.workdir / filename
        if not pathname.exists():
            return
        pathname.unlink(missing_ok=True)

        dependency_file = await self._git_repo.get_dependency()
        # Passing no dependencies removes the entry of this file.
        await dependency_file.update(filename=pathname, dependencies=None)
        logger.info(f"remove dependency key: {str(pathname)}")

    @staticmethod
    async def delete_file(filename: Path | str, relative_path: Path | str = "."):
        """Delete a file from the file repository of the configured Git repository.

        :param filename: The name or path of the file to be deleted.
        :param relative_path: The relative path within the file repository.
        """
        file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
        await file_repo.delete(filename=filename)

View file

@ -8,10 +8,10 @@
from metagpt.config import CONFIG
def get_template(templates, format=CONFIG.prompt_format):
selected_templates = templates.get(format)
def get_template(templates, schema=CONFIG.prompt_schema):
selected_templates = templates.get(schema)
if selected_templates is None:
raise ValueError(f"Can't find {format} in passed in templates")
raise ValueError(f"Can't find {schema} in passed in templates")
# Extract the selected templates
prompt_template = selected_templates["PROMPT_TEMPLATE"]

View file

@ -0,0 +1,290 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/11/20
@Author : mashenquan
@File : git_repository.py
@Desc: Git repository management. RFC 135 2.2.3.3.
"""
from __future__ import annotations
import shutil
from enum import Enum
from pathlib import Path
from typing import Dict, List
from git.repo import Repo
from git.repo.fun import is_git_dir
from gitignore_parser import parse_gitignore
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.utils.dependency_file import DependencyFile
from metagpt.utils.file_repository import FileRepository
class ChangeType(Enum):
    """Kinds of change a file in the Git working tree can have."""

    ADDED = "A"  # File was added
    COPIED = "C"  # File was copied
    DELETED = "D"  # File was deleted
    RENAMED = "R"  # File was renamed
    MODIFIED = "M"  # File was modified
    TYPE_CHANGED = "T"  # Type of the file was changed
    UNTRACTED = "U"  # File is untracked (not added to version control); misspelled name kept for compatibility
    UNTRACKED = "U"  # Correctly spelled alias of UNTRACTED (same member, same value)
class GitRepository:
"""A class representing a Git repository.
:param local_path: The local path to the Git repository.
:param auto_init: If True, automatically initializes a new Git repository if the provided path is not a Git repository.
Attributes:
_repository (Repo): The GitPython `Repo` object representing the Git repository.
"""
def __init__(self, local_path=None, auto_init=True):
"""Initialize a GitRepository instance.
:param local_path: The local path to the Git repository.
:param auto_init: If True, automatically initializes a new Git repository if the provided path is not a Git repository.
"""
self._repository = None
self._dependency = None
self._gitignore_rules = None
if local_path:
self.open(local_path=local_path, auto_init=auto_init)
def open(self, local_path: Path, auto_init=False):
    """Open an existing Git repository or initialize a new one if auto_init is True.

    :param local_path: The local path to the Git repository.
    :param auto_init: If True, automatically initializes a new Git repository if the provided path is not a Git repository.
    """
    local_path = Path(local_path)
    if not self.is_git_dir(local_path):
        if auto_init:
            # Not a repository yet: create the directory and initialize one in it.
            local_path.mkdir(parents=True, exist_ok=True)
            return self._init(local_path)
        return
    self._repository = Repo(local_path)
    self._gitignore_rules = parse_gitignore(full_path=str(local_path / ".gitignore"))
def _init(self, local_path: Path):
    """Initialize a new Git repository at the specified path.

    A ``.gitignore`` ignoring ``__pycache__`` and ``*.pyc`` is written and
    committed as the first commit, then parsed into the ignore rules.

    :param local_path: The local path where the new Git repository will be initialized.
    """
    repo_root = Path(local_path)
    self._repository = Repo.init(path=repo_root)

    gitignore_filename = repo_root / ".gitignore"
    gitignore_filename.write_text("\n".join(["__pycache__", "*.pyc"]))

    index = self._repository.index
    index.add([".gitignore"])
    index.commit("Add .gitignore")
    self._gitignore_rules = parse_gitignore(full_path=gitignore_filename)
def add_change(self, files: Dict):
    """Add or remove files from the staging area based on the provided changes.

    Does nothing when the repository is not valid or ``files`` is empty.

    :param files: A dictionary where keys are file paths and values are instances of ChangeType.
    """
    if not self.is_valid or not files:
        return

    index = self._repository.index
    for path, change_type in files.items():
        # Deleted files are removed from the index; everything else is staged.
        if change_type is ChangeType.DELETED:
            index.remove(path)
        else:
            index.add([path])
def commit(self, comments):
"""Commit the staged changes with the given comments.
:param comments: Comments for the commit.
"""
if self.is_valid:
self._repository.index.commit(comments)
def delete_repository(self):
"""Delete the entire repository directory."""
if self.is_valid:
shutil.rmtree(self._repository.working_dir)
@property
def changed_files(self) -> Dict[str, str]:
    """Return a dictionary of changed files and their change types.

    Untracked files are reported first; entries of the index-vs-working-tree
    diff override them when the same path appears in both.

    :return: A dictionary where keys are file paths and values are change types;
        empty when the repository is not valid.
    """
    # Guard like status/workdir/commit do, instead of failing on a None repository.
    if not self.is_valid:
        return {}
    files = {path: ChangeType.UNTRACTED for path in self._repository.untracked_files}
    files.update({diff.a_path: ChangeType(diff.change_type) for diff in self._repository.index.diff(None)})
    return files
@staticmethod
def is_git_dir(local_path):
    """Check if the specified directory is a Git repository.

    :param local_path: The local path to check.
    :return: True if ``<local_path>/.git`` exists and is a Git directory, False otherwise.
    """
    git_dir = Path(local_path) / ".git"
    return bool(git_dir.exists() and is_git_dir(git_dir))
@property
def is_valid(self):
"""Check if the Git repository is valid (exists and is initialized).
:return: True if the repository is valid, False otherwise.
"""
return bool(self._repository)
@property
def status(self) -> str:
"""Return the Git repository's status as a string."""
if not self.is_valid:
return ""
return self._repository.git.status()
@property
def workdir(self) -> Path | None:
"""Return the path to the working directory of the Git repository.
:return: The path to the working directory or None if the repository is not valid.
"""
if not self.is_valid:
return None
return Path(self._repository.working_dir)
def archive(self, comments="Archive"):
    """Stage every currently changed file and commit the result.

    :param comments: Commit message for the archive commit.
    """
    pending_paths = list(self.changed_files.keys())
    logger.info(f"Archive: {pending_paths}")
    pending = self.changed_files
    self.add_change(pending)
    self.commit(comments)
def new_file_repository(self, relative_path: Path | str = ".") -> FileRepository:
    """Build a ``FileRepository`` bound to this Git repository.

    :param relative_path: Location of the file repository. A path inside the
        working directory is converted to a path relative to it; any other
        path is used as given.
    :return: A new instance of FileRepository.
    """
    candidate = Path(relative_path)
    try:
        candidate = candidate.relative_to(self.workdir)
    except ValueError:
        # Not located under the working directory: keep the path as supplied.
        pass
    return FileRepository(git_repo=self, relative_path=candidate)
async def get_dependency(self) -> DependencyFile:
    """Return the dependency file of this repository, creating it on first use.

    The instance is cached on ``self._dependency`` and reused while it is truthy.

    :return: An instance of DependencyFile.
    """
    if self._dependency:
        return self._dependency
    self._dependency = DependencyFile(workdir=self.workdir)
    return self._dependency
def rename_root(self, new_dir_name):
    """Rename the root directory of the Git repository.

    An existing sibling directory with the target name is deleted first. If the
    move fails and nothing exists at the destination afterwards, the method
    returns and keeps the current repository handle, instead of re-opening a
    path that does not exist.

    :param new_dir_name: The new name for the root directory.
    """
    old_path = self.workdir
    if old_path.name == new_dir_name:
        return
    new_path = old_path.parent / new_dir_name
    if new_path.exists():
        logger.info(f"Delete directory {str(new_path)}")
        shutil.rmtree(new_path)
    try:
        shutil.move(src=str(old_path), dst=str(new_path))
    except Exception as e:
        logger.warning(f"Move {str(old_path)} to {str(new_path)} error: {e}")
        if not new_path.exists():
            # Nothing at the destination: the repository is still at its old
            # location, so keep the existing handle and ignore rules.
            return
    logger.info(f"Rename directory {str(old_path)} to {str(new_path)}")
    # Re-open the repository and reload its ignore rules from the new location.
    self._repository = Repo(new_path)
    self._gitignore_rules = parse_gitignore(full_path=str(new_path / ".gitignore"))
def get_files(self, relative_path: Path | str, root_relative_path: Path | str = None, filter_ignored=True) -> List:
    """
    List the files below a directory of the repository, recursing into subdirectories.

    Returned paths are strings relative to ``root_relative_path``. Errors raised
    while walking the directory are logged and the files collected so far are kept.

    :param relative_path: Directory to scan; a path inside the working directory
        is first made relative to it.
    :type relative_path: Path or str
    :param root_relative_path: Base that returned paths are relative to. Defaults
        to the scanned directory itself.
    :type root_relative_path: Path or str
    :param filter_ignored: Whether to drop files matched by the .gitignore rules.
    :type filter_ignored: bool
    :return: A list of file paths in the specified directory; empty if it does not exist.
    :rtype: List[str]
    """
    try:
        relative_path = Path(relative_path).relative_to(self.workdir)
    except ValueError:
        relative_path = Path(relative_path)

    if not root_relative_path:
        root_relative_path = Path(self.workdir) / relative_path

    collected = []
    try:
        target_dir = Path(self.workdir) / relative_path
        if not target_dir.exists():
            return []
        for entry in target_dir.iterdir():
            if not entry.is_file():
                # Recurse without filtering; ignore rules are applied once at the top level.
                collected.extend(
                    self.get_files(relative_path=entry, root_relative_path=root_relative_path, filter_ignored=False)
                )
                continue
            collected.append(str(entry.relative_to(root_relative_path)))
    except Exception as e:
        logger.error(f"Error: {e}")

    if filter_ignored:
        return self.filter_gitignore(filenames=collected, root_relative_path=root_relative_path)
    return collected
def filter_gitignore(self, filenames: List[str], root_relative_path: Path | str = None) -> List[str]:
    """
    Filter a list of filenames based on .gitignore rules.

    Each filename is joined onto ``root_relative_path`` and the resulting path
    string is passed to ``self._gitignore_rules``; a truthy result means the
    file is ignored and dropped from the output.

    :param filenames: A list of filenames to be filtered.
    :type filenames: List[str]
    :param root_relative_path: Base path the filenames are relative to. Defaults
        to the working directory.
    :type root_relative_path: Path or str
    :return: The filenames that are not ignored, in their original order.
    :rtype: List[str]
    """
    if root_relative_path is None:
        root_relative_path = self.workdir
    # The signature allows a plain string, which does not support `/`; coerce it once.
    base = Path(root_relative_path)
    return [filename for filename in filenames if not self._gitignore_rules(str(base / filename))]
if __name__ == "__main__":
    # Manual smoke test: open (or initialise) a repository under the default
    # workspace, stage and commit whatever changed, then delete the repository.
    workspace = DEFAULT_WORKSPACE_ROOT / "git"
    workspace.mkdir(exist_ok=True, parents=True)

    demo_repo = GitRepository()
    demo_repo.open(workspace, auto_init=True)
    # Exercises the ignore filter; the result is not used.
    demo_repo.filter_gitignore(filenames=["snake_game/snake_game/__pycache__", "snake_game/snake_game/game.py"])

    pending = demo_repo.changed_files
    print(pending)
    demo_repo.add_change(pending)
    print(demo_repo.status)
    demo_repo.commit("test")
    print(demo_repo.status)
    demo_repo.delete_repository()

View file

@ -1,22 +1,22 @@
# 添加代码语法高亮显示
from pygments import highlight as highlight_
from pygments.formatters import HtmlFormatter, TerminalFormatter
from pygments.lexers import PythonLexer, SqlLexer
from pygments.formatters import TerminalFormatter, HtmlFormatter
def highlight(code: str, language: str = 'python', formatter: str = 'terminal'):
def highlight(code: str, language: str = "python", formatter: str = "terminal"):
# 指定要高亮的语言
if language.lower() == 'python':
if language.lower() == "python":
lexer = PythonLexer()
elif language.lower() == 'sql':
elif language.lower() == "sql":
lexer = SqlLexer()
else:
raise ValueError(f"Unsupported language: {language}")
# 指定输出格式
if formatter.lower() == 'terminal':
if formatter.lower() == "terminal":
formatter = TerminalFormatter()
elif formatter.lower() == 'html':
elif formatter.lower() == "html":
formatter = HtmlFormatter()
else:
raise ValueError(f"Unsupported formatter: {formatter}")

View file

@ -10,7 +10,7 @@ import os
from pathlib import Path
from metagpt.config import CONFIG
from metagpt.const import PROJECT_ROOT
from metagpt.const import METAGPT_ROOT
from metagpt.logs import logger
from metagpt.utils.common import check_cmd_exists
@ -69,7 +69,7 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
if stdout:
logger.info(stdout.decode())
if stderr:
logger.error(stderr.decode())
logger.warning(stderr.decode())
else:
if engine == "playwright":
from metagpt.utils.mmdc_playwright import mermaid_to_file
@ -141,6 +141,6 @@ MMC2 = """sequenceDiagram
if __name__ == "__main__":
loop = asyncio.new_event_loop()
result = loop.run_until_complete(mermaid_to_file(MMC1, PROJECT_ROOT / f"{CONFIG.mermaid_engine}/1"))
result = loop.run_until_complete(mermaid_to_file(MMC2, PROJECT_ROOT / f"{CONFIG.mermaid_engine}/1"))
result = loop.run_until_complete(mermaid_to_file(MMC1, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/1"))
result = loop.run_until_complete(mermaid_to_file(MMC2, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/1"))
loop.close()

View file

@ -6,9 +6,9 @@
@File : mermaid.py
"""
import base64
import os
from aiohttp import ClientSession,ClientError
from aiohttp import ClientError, ClientSession
from metagpt.logs import logger
@ -29,7 +29,7 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix):
async with session.get(url) as response:
if response.status == 200:
text = await response.content.read()
with open(output_file, 'wb') as f:
with open(output_file, "wb") as f:
f.write(text)
logger.info(f"Generating {output_file}..")
else:

View file

@ -8,10 +8,13 @@
import os
from urllib.parse import urljoin
from playwright.async_api import async_playwright
from metagpt.logs import logger
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048)-> int:
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
"""
Converts the given Mermaid code to various output formats and saves them to files.
@ -24,66 +27,72 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
Returns:
int: Returns 1 if the conversion and saving were successful, -1 otherwise.
"""
suffixes=['png', 'svg', 'pdf']
suffixes = ["png", "svg", "pdf"]
__dirname = os.path.dirname(os.path.abspath(__file__))
async with async_playwright() as p:
browser = await p.chromium.launch()
device_scale_factor = 1.0
context = await browser.new_context(
viewport={'width': width, 'height': height},
device_scale_factor=device_scale_factor,
)
viewport={"width": width, "height": height},
device_scale_factor=device_scale_factor,
)
page = await context.new_page()
async def console_message(msg):
logger.info(msg.text)
page.on('console', console_message)
page.on("console", console_message)
try:
await page.set_viewport_size({'width': width, 'height': height})
await page.set_viewport_size({"width": width, "height": height})
mermaid_html_path = os.path.abspath(
os.path.join(__dirname, 'index.html'))
mermaid_html_url = urljoin('file:', mermaid_html_path)
mermaid_html_path = os.path.abspath(os.path.join(__dirname, "index.html"))
mermaid_html_url = urljoin("file:", mermaid_html_path)
await page.goto(mermaid_html_url)
await page.wait_for_load_state("networkidle")
await page.wait_for_selector("div#container", state="attached")
mermaid_config = {}
# mermaid_config = {}
background_color = "#ffffff"
my_css = ""
# my_css = ""
await page.evaluate(f'document.body.style.background = "{background_color}";')
metadata = await page.evaluate('''async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
const { mermaid, zenuml } = globalThis;
await mermaid.registerExternalDiagrams([zenuml]);
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
document.getElementById('container').innerHTML = svg;
const svgElement = document.querySelector('svg');
svgElement.style.backgroundColor = backgroundColor;
# metadata = await page.evaluate(
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
# const { mermaid, zenuml } = globalThis;
# await mermaid.registerExternalDiagrams([zenuml]);
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
# document.getElementById('container').innerHTML = svg;
# const svgElement = document.querySelector('svg');
# svgElement.style.backgroundColor = backgroundColor;
#
# if (myCSS) {
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
# style.appendChild(document.createTextNode(myCSS));
# svgElement.appendChild(style);
# }
#
# }""",
# [mermaid_code, mermaid_config, my_css, background_color],
# )
if (myCSS) {
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
style.appendChild(document.createTextNode(myCSS));
svgElement.appendChild(style);
}
}''', [mermaid_code, mermaid_config, my_css, background_color])
if 'svg' in suffixes :
svg_xml = await page.evaluate('''() => {
if "svg" in suffixes:
svg_xml = await page.evaluate(
"""() => {
const svg = document.querySelector('svg');
const xmlSerializer = new XMLSerializer();
return xmlSerializer.serializeToString(svg);
}''')
}"""
)
logger.info(f"Generating {output_file_without_suffix}.svg..")
with open(f'{output_file_without_suffix}.svg', 'wb') as f:
f.write(svg_xml.encode('utf-8'))
with open(f"{output_file_without_suffix}.svg", "wb") as f:
f.write(svg_xml.encode("utf-8"))
if 'png' in suffixes:
clip = await page.evaluate('''() => {
if "png" in suffixes:
clip = await page.evaluate(
"""() => {
const svg = document.querySelector('svg');
const rect = svg.getBoundingClientRect();
return {
@ -92,16 +101,17 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
width: Math.ceil(rect.width),
height: Math.ceil(rect.height)
};
}''')
await page.set_viewport_size({'width': clip['x'] + clip['width'], 'height': clip['y'] + clip['height']})
screenshot = await page.screenshot(clip=clip, omit_background=True, scale='device')
}"""
)
await page.set_viewport_size({"width": clip["x"] + clip["width"], "height": clip["y"] + clip["height"]})
screenshot = await page.screenshot(clip=clip, omit_background=True, scale="device")
logger.info(f"Generating {output_file_without_suffix}.png..")
with open(f'{output_file_without_suffix}.png', 'wb') as f:
with open(f"{output_file_without_suffix}.png", "wb") as f:
f.write(screenshot)
if 'pdf' in suffixes:
if "pdf" in suffixes:
pdf_data = await page.pdf(scale=device_scale_factor)
logger.info(f"Generating {output_file_without_suffix}.pdf..")
with open(f'{output_file_without_suffix}.pdf', 'wb') as f:
with open(f"{output_file_without_suffix}.pdf", "wb") as f:
f.write(pdf_data)
return 0
except Exception as e:

View file

@ -7,11 +7,14 @@
"""
import os
from urllib.parse import urljoin
from pyppeteer import launch
from metagpt.logs import logger
from metagpt.config import CONFIG
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048)-> int:
from pyppeteer import launch
from metagpt.config import CONFIG
from metagpt.logs import logger
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
"""
Converts the given Mermaid code to various output formats and saves them to files.
@ -24,15 +27,15 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
Returns:
int: Returns 1 if the conversion and saving were successful, -1 otherwise.
"""
suffixes = ['png', 'svg', 'pdf']
suffixes = ["png", "svg", "pdf"]
__dirname = os.path.dirname(os.path.abspath(__file__))
if CONFIG.pyppeteer_executable_path:
browser = await launch(headless=True,
executablePath=CONFIG.pyppeteer_executable_path,
args=['--disable-extensions',"--no-sandbox"]
)
browser = await launch(
headless=True,
executablePath=CONFIG.pyppeteer_executable_path,
args=["--disable-extensions", "--no-sandbox"],
)
else:
logger.error("Please set the environment variable:PYPPETEER_EXECUTABLE_PATH.")
return -1
@ -41,50 +44,56 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
async def console_message(msg):
logger.info(msg.text)
page.on('console', console_message)
page.on("console", console_message)
try:
await page.setViewport(viewport={'width': width, 'height': height, 'deviceScaleFactor': device_scale_factor})
await page.setViewport(viewport={"width": width, "height": height, "deviceScaleFactor": device_scale_factor})
mermaid_html_path = os.path.abspath(
os.path.join(__dirname, 'index.html'))
mermaid_html_url = urljoin('file:', mermaid_html_path)
mermaid_html_path = os.path.abspath(os.path.join(__dirname, "index.html"))
mermaid_html_url = urljoin("file:", mermaid_html_path)
await page.goto(mermaid_html_url)
await page.querySelector("div#container")
mermaid_config = {}
# mermaid_config = {}
background_color = "#ffffff"
my_css = ""
# my_css = ""
await page.evaluate(f'document.body.style.background = "{background_color}";')
metadata = await page.evaluate('''async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
const { mermaid, zenuml } = globalThis;
await mermaid.registerExternalDiagrams([zenuml]);
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
document.getElementById('container').innerHTML = svg;
const svgElement = document.querySelector('svg');
svgElement.style.backgroundColor = backgroundColor;
# metadata = await page.evaluate(
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
# const { mermaid, zenuml } = globalThis;
# await mermaid.registerExternalDiagrams([zenuml]);
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
# document.getElementById('container').innerHTML = svg;
# const svgElement = document.querySelector('svg');
# svgElement.style.backgroundColor = backgroundColor;
#
# if (myCSS) {
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
# style.appendChild(document.createTextNode(myCSS));
# svgElement.appendChild(style);
# }
# }""",
# [mermaid_code, mermaid_config, my_css, background_color],
# )
if (myCSS) {
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
style.appendChild(document.createTextNode(myCSS));
svgElement.appendChild(style);
}
}''', [mermaid_code, mermaid_config, my_css, background_color])
if 'svg' in suffixes :
svg_xml = await page.evaluate('''() => {
if "svg" in suffixes:
svg_xml = await page.evaluate(
"""() => {
const svg = document.querySelector('svg');
const xmlSerializer = new XMLSerializer();
return xmlSerializer.serializeToString(svg);
}''')
}"""
)
logger.info(f"Generating {output_file_without_suffix}.svg..")
with open(f'{output_file_without_suffix}.svg', 'wb') as f:
f.write(svg_xml.encode('utf-8'))
with open(f"{output_file_without_suffix}.svg", "wb") as f:
f.write(svg_xml.encode("utf-8"))
if 'png' in suffixes:
clip = await page.evaluate('''() => {
if "png" in suffixes:
clip = await page.evaluate(
"""() => {
const svg = document.querySelector('svg');
const rect = svg.getBoundingClientRect();
return {
@ -93,16 +102,23 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
width: Math.ceil(rect.width),
height: Math.ceil(rect.height)
};
}''')
await page.setViewport({'width': clip['x'] + clip['width'], 'height': clip['y'] + clip['height'], 'deviceScaleFactor': device_scale_factor})
screenshot = await page.screenshot(clip=clip, omit_background=True, scale='device')
}"""
)
await page.setViewport(
{
"width": clip["x"] + clip["width"],
"height": clip["y"] + clip["height"],
"deviceScaleFactor": device_scale_factor,
}
)
screenshot = await page.screenshot(clip=clip, omit_background=True, scale="device")
logger.info(f"Generating {output_file_without_suffix}.png..")
with open(f'{output_file_without_suffix}.png', 'wb') as f:
with open(f"{output_file_without_suffix}.png", "wb") as f:
f.write(screenshot)
if 'pdf' in suffixes:
if "pdf" in suffixes:
pdf_data = await page.pdf(scale=device_scale_factor)
logger.info(f"Generating {output_file_without_suffix}.pdf..")
with open(f'{output_file_without_suffix}.pdf', 'wb') as f:
with open(f"{output_file_without_suffix}.pdf", "wb") as f:
f.write(pdf_data)
return 0
except Exception as e:
@ -110,4 +126,3 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
return -1
finally:
await browser.close()

View file

@ -16,7 +16,7 @@ class WebPage(BaseModel):
class Config:
underscore_attrs_are_private = True
_soup : Optional[BeautifulSoup] = None
_soup: Optional[BeautifulSoup] = None
_title: Optional[str] = None
@property
@ -24,7 +24,7 @@ class WebPage(BaseModel):
if self._soup is None:
self._soup = BeautifulSoup(self.html, "html.parser")
return self._soup
@property
def title(self):
if self._title is None:

View file

@ -37,12 +37,12 @@ def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine:
if not isinstance(expr, cst.Expr):
return None
val = expr.value
if not isinstance(val, (cst.SimpleString, cst.ConcatenatedString)):
return None
evaluated_value = val.evaluated_value
evaluated_value = val.evaluated_value
if isinstance(evaluated_value, bytes):
return None
@ -56,6 +56,7 @@ class DocstringCollector(cst.CSTVisitor):
stack: A list to keep track of the current path in the CST.
docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
"""
def __init__(self):
self.stack: list[str] = []
self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}
@ -96,6 +97,7 @@ class DocstringTransformer(cst.CSTTransformer):
stack: A list to keep track of the current path in the CST.
docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
"""
def __init__(
self,
docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
@ -125,7 +127,9 @@ class DocstringTransformer(cst.CSTTransformer):
key = tuple(self.stack)
self.stack.pop()
if hasattr(updated_node, "decorators") and any((i.decorator.value == "overload") for i in updated_node.decorators):
if hasattr(updated_node, "decorators") and any(
(i.decorator.value == "overload") for i in updated_node.decorators
):
return updated_node
statement = self.docstrings.get(key)

View file

@ -8,6 +8,7 @@
import docx
def read_docx(file_path: str) -> list:
"""Open a docx file"""
doc = docx.Document(file_path)

View file

@ -0,0 +1,310 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : repair llm raw output with particular conditions
import copy
from enum import Enum
from typing import Callable, Union
import regex as re
from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.utils.custom_decoder import CustomDecoder
class RepairType(Enum):
    """Kinds of repair that can be applied to raw LLM output."""

    CS = "case sensitivity"
    # Condition like `[key] xx`, which lacks the closing `[/key]`.
    RKPM = "required key pair missing"
    # Usually the req_key appears in pairs like `[key] xx [/key]`.
    SCM = "special character missing"
    JSON = "json format"
def repair_case_sensitivity(output: str, req_key: str) -> str:
    """
    Restore the expected letter case of `req_key` inside `output`.

    req_key is usually the key name of the expected json or markdown content and is not
    expected to appear in the value part.
    example: the target is `"Shared Knowledge": ""` but the output has `"Shared knowledge": ""`.
    """
    if req_key in output:
        return output

    lowered_key = req_key.lower()
    start = output.lower().find(lowered_key)
    if start < 0:
        # the key does not occur in any casing, nothing to repair
        return output

    # take the spelling found at the first case-insensitive hit and replace every occurrence of it
    miscased = output[start : start + len(lowered_key)]
    output = output.replace(miscased, req_key)
    logger.info(f"repair_case_sensitivity: {req_key}")
    return output
def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str:
    """
    Put back a special character (`/`) that is missing from the last occurrence of the key.

    fixes
        1. `[CONTENT] xx [CONTENT] xxx [CONTENT]`, where the last `[CONTENT]` lacks `/`
        2. `xx [CONTENT] xxx [CONTENT] xxxx`, where the last `[CONTENT]` lacks `/`
    Only applied when the key without the special character occurs more than once.
    """
    if req_key in output:
        return output

    for special_char in ("/",):
        bare_key = req_key.replace(special_char, "")
        if output.count(bare_key) <= 1:
            continue
        # the key carrying the special character usually sits at the tail, so patch the last hit
        last_pos = output.rfind(bare_key)
        output = output[:last_pos] + req_key + output[last_pos + len(bare_key) :]
        logger.info(f"repair_special_character_missing: {special_char} in {bare_key} as position {last_pos}")

    return output
def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str:
    """
    Add the missing half of a `[req_key]` / `[/req_key]` pair at the beginning or end of the content.

    req_key format
        1. `[req_key]`, and its pair `[/req_key]`
        2. `[/req_key]`, and its pair `[req_key]`
    Keys that are not wrapped in `[` `]` leave the output untouched.
    """
    slash = "/"  # special char
    if not (req_key.startswith("[") and req_key.endswith("]")):
        return output

    if slash in req_key:
        opening, closing = req_key.replace(slash, ""), req_key  # `[/req_key]` -> `[req_key]`
    else:
        opening, closing = req_key, f"{req_key[0]}{slash}{req_key[1:]}"  # `[req_key]` -> `[/req_key]`

    if opening not in output:
        output = opening + "\n" + output
    if closing in output:
        return output

    def tail_until_last_bracket(text: str, left_key: str) -> Union[str, None]:
        # slice from the last opening key up to the last `}` or `]` that follows it
        start = text.rfind(left_key)
        if start < 0:
            return None
        tail = text[start:]
        cut = max(tail.rfind("}"), tail.rindex("]"))
        return tail[: cut + 1]

    trimmed = output.strip()
    ends_with_opening = trimmed.endswith(opening)
    if trimmed.endswith("}") or (trimmed.endswith("]") and not ends_with_opening):
        # avoid appending [/req_key] in the `[req_key]xx[req_key]` case
        return output + "\n" + closing
    if ends_with_opening:
        return output

    candidate = tail_until_last_bracket(output, opening)
    if candidate:
        output = candidate + "\n" + closing
    return output
def repair_json_format(output: str) -> str:
    """
    Trim a stray bracket at either end of a JSON-like string.

    The text is stripped first, then exactly one of these rules is applied:
        1. starts with `[{`  -> drop the leading `[`
        2. ends with `}]`    -> drop the trailing `]`
        3. starts with `{` and ends with `]` -> replace the trailing `]` with `}`
    """
    text = output.strip()

    if text.startswith("[{"):
        logger.info(f"repair_json_format: {'[{'}")
        return text[1:]
    if text.endswith("}]"):
        logger.info(f"repair_json_format: {'}]'}")
        return text[:-1]
    if text.startswith("{") and text.endswith("]"):
        return text[:-1] + "}"
    return text
def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = None) -> str:
    """
    Apply one repair, or every repair except the JSON one, to `output` for a single `req_key`.

    When `repair_type` is given only that repair runs; otherwise all RepairType members
    other than `RepairType.JSON` run in declaration order.
    """
    if repair_type:
        selected = [repair_type]
    else:
        selected = [kind for kind in RepairType if kind not in [RepairType.JSON]]

    # (repair kind, handler) pairs; the JSON repair does not need the key
    dispatch = (
        (RepairType.CS, lambda text: repair_case_sensitivity(text, req_key)),
        (RepairType.RKPM, lambda text: repair_required_key_pair_missing(text, req_key)),
        (RepairType.SCM, lambda text: repair_special_character_missing(text, req_key)),
        (RepairType.JSON, lambda text: repair_json_format(text)),
    )
    for wanted in selected:
        for kind, handler in dispatch:
            if wanted == kind:
                output = handler(output)
                break
    return output
def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str:
    """
    Repair raw LLM output for every required key, if `CONFIG.repair_llm_output` is enabled.

    Open-source llm models often do not follow the instruction well, so the output may be
    incomplete. All repair methods are tried by default.

    typical cases
        1. case sensitivity
            target: "Original Requirements"
            output: "Original requirements"
        2. special character missing
            target: [/CONTENT]
            output: [CONTENT]
        3. json format
            target: { xxx }
            output: { xxx }]
    """
    if CONFIG.repair_llm_output:
        # the repair is usually needed for non-openai models
        for key in req_keys:
            output = _repair_llm_raw_output(output=output, req_key=key, repair_type=repair_type)
    return output
def repair_invalid_json(output: str, error: str) -> str:
    """
    Patch the single line of `output` that a JSON decode error points at.

    The line number is read from the first `line N` in `error`. If there is none, or the
    number is outside the text, `output` is returned unchanged.

    error examples
        example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765)
        example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266)

    Args:
        output (str): the text that failed to decode
        error (str): string form of the decode exception

    Returns:
        str: the text with the reported line rewritten; the rewritten line is stored stripped
    """
    pattern = r"line ([0-9]+)"
    matches = re.findall(pattern, error, re.DOTALL)
    if not matches:
        return output

    line_no = int(matches[0]) - 1  # the error message is 1-based

    # CustomDecoder can handle `"": ''` or `'': ""`, so collapse `"""` and `'''` into a single `"`
    normalized = output.replace('"""', '"').replace("'''", '"')
    arr = normalized.split("\n")
    if not 0 <= line_no < len(arr):
        # the reported line does not exist in this text, so there is nothing to patch
        return output

    line = arr[line_no].strip()
    # different general problems
    if line.endswith("],"):
        # problem, redundant char `]`
        new_line = line.replace("]", "")
    elif line.endswith("},") and not normalized.endswith("},"):
        # problem, redundant char `}`
        new_line = line.replace("}", "")
    elif line.endswith("},") and normalized.endswith("},"):
        # the whole text ends here, so drop the trailing `,`
        new_line = line[:-1]
    elif "," not in line:
        # problem, the line has no `,` at all: close the string and add it
        new_line = f'{line}",'
    elif len(line) == 1:
        # the line is a lone `,`
        new_line = f'"{line}'
    elif '",' in line:
        new_line = line[:-2] + "',"
    else:
        # no known pattern matched, keep the line as it is instead of failing
        new_line = line

    arr[line_no] = new_line
    output = "\n".join(arr)
    logger.info(f"repair_invalid_json, raw error: {error}")
    return output
def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]:
    """
    Build a retry hook that repairs the failing text and stores it for the next attempt.

    Args:
        logger: logger used to report each failed attempt

    Returns:
        a callable taking a `RetryCallState`, suitable for the `after` argument of `retry`
    """

    def run_and_passon(retry_state: RetryCallState) -> None:
        """
        On a failed attempt, repair the `output` argument and write it to `retry_state.kwargs["output"]`.

        RetryCallState example
            {
                "start_time":143.098322024,
                "retry_object":"<Retrying object at 0x7fabcaca25e0 (stop=<tenacity.stop.stop_after_attempt ... >)>",
                "fn":"<function retry_parse_json_text_v2 at 0x7fabcac80ee0>",
                "args":"(\"tag:[/CONTENT]\",)",  # function input args
                "kwargs":{},                     # function input kwargs
                "attempt_number":1,              # retry number
                "outcome":"<Future at xxx>",  # type(outcome.result()) = "str", type(outcome.exception()) = "class"
                "outcome_timestamp":143.098416904,
                "idle_for":0,
                "next_action":"None"
            }
        """
        if not retry_state.outcome.failed:
            return

        if retry_state.args:
            # positional args can only be read here, they can't be replaced on retry_state
            func_param_output = retry_state.args[0]
        else:
            # also covers empty kwargs, so the value is always bound
            func_param_output = retry_state.kwargs.get("output", "")

        exp_str = str(retry_state.outcome.exception())
        logger.warning(
            f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
            f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}"
        )

        # NOTE(review): presumably the next attempt reads this kwargs entry - confirm against tenacity
        retry_state.kwargs["output"] = repair_invalid_json(func_param_output, exp_str)

    return run_and_passon
@retry(
    stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0),
    wait=wait_fixed(1),
    after=run_after_exp_and_passon_next_retry(logger),
)
def retry_parse_json_text(output: str) -> Union[list, dict]:
    """
    Decode json text, repairing situations like extra chars such as [']', '}'] between attempts.

    Warning
        if CONFIG.repair_llm_output is False, retry _aask_v1 {x=3} times, and the retry_parse_json_text's retry not work
        if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times.
            it's a two-layer retry cycle
    """
    # if CONFIG.repair_llm_output is True, the `after` hook tries to fix output until the retry stops
    decoder = CustomDecoder(strict=False)
    return decoder.decode(output)
def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"):
    """Extract xxx from [CONTENT](xxx)[/CONTENT] using a regex pattern.

    If the first extraction does not start with `{`, the same pattern is applied once more
    (after appending `right_key` when it is missing from the extraction).
    """

    def first_capture(text: str, regex: str) -> str:
        # first non-empty capture, or the text itself when nothing matched
        found = next((hit for hit in re.findall(regex, text, re.DOTALL) if hit), text)
        return found.strip()

    # TODO construct the extract pattern with the `right_key`
    pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]"
    raw_content = copy.deepcopy(content)
    new_content = first_capture(raw_content, pattern)

    if new_content.startswith("{"):
        # the greedy pattern may leave an inner closing key behind, cut at the first one
        cut = new_content.find(right_key)
        if cut >= 0:
            new_content = new_content[:cut].strip()
        return new_content

    # TODO find a more general pattern
    # for the `[CONTENT]xxx[CONTENT]xxxx[/CONTENT]` situation
    logger.warning(f"extract_content try another pattern: {pattern}")
    if right_key not in new_content:
        raw_content = copy.deepcopy(new_content + "\n" + right_key)
    return first_capture(raw_content, pattern)
def extract_state_value_from_output(content: str) -> str:
    """
    Extract the state digit from the llm output of `Role._think`.

    For openai models, they will always return the state number. But for open llm models, the
    result may be a long text containing the target number, so the digit is extracted here to
    improve the success rate. When several digits occur, the first one in the text is returned,
    so the result is deterministic.

    Args:
        content (str): llm's output from `Role._think`

    Returns:
        str: the first digit found, or "-1" when the text contains no digit
    """
    content = content.strip()  # deal with output cases like " 0", "0\n" and so on.
    # TODO find the number using a more proper method, not just the first single digit
    first_digit = re.search(r"[0-9]", content)
    return first_digit.group(0) if first_digit else "-1"

View file

@ -4,13 +4,11 @@
import copy
import pickle
from typing import Dict, List
from metagpt.actions.action_output import ActionOutput
from metagpt.schema import Message
from metagpt.utils.common import import_class
def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
def actionoutout_schema_to_mapping(schema: dict) -> dict:
"""
directly traverse the `properties` in the first level.
schema structure likes
@ -35,14 +33,31 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
if property["type"] == "string":
mapping[field] = (str, ...)
elif property["type"] == "array" and property["items"]["type"] == "string":
mapping[field] = (List[str], ...)
mapping[field] = (list[str], ...)
elif property["type"] == "array" and property["items"]["type"] == "array":
# here only consider the `List[List[str]]` situation
mapping[field] = (List[List[str]], ...)
# here only consider the `list[list[str]]` situation
mapping[field] = (list[list[str]], ...)
return mapping
def serialize_message(message: Message):
def actionoutput_mapping_to_str(mapping: dict) -> dict:
    """Return a copy of `mapping` in which every value is replaced by its `str()` form."""
    return {key: str(value) for key, value in mapping.items()}
def actionoutput_str_to_mapping(mapping: dict) -> dict:
    """Turn stringified field definitions back into `(type, ...)` tuples.

    The string `"(<class 'str'>, Ellipsis)"` maps to `(str, ...)`; every other value is
    evaluated as a Python expression, e.g. `"(list[str], Ellipsis)"` to `(list[str], ...)`.
    """
    str_field = "(<class 'str'>, Ellipsis)"
    restored = {}
    for key, value in mapping.items():
        if value != str_field:
            # NOTE(review): eval runs arbitrary expressions - only safe if `mapping` comes from a trusted source
            restored[key] = eval(value)
            continue
        restored[key] = (str, ...)
    return restored
def serialize_message(message: "Message"):
message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference
ic = message_cp.instruct_content
if ic:
@ -56,11 +71,12 @@ def serialize_message(message: Message):
return msg_ser
def deserialize_message(message_ser: str) -> Message:
def deserialize_message(message_ser: str) -> "Message":
message = pickle.loads(message_ser)
if message.instruct_content:
ic = message.instruct_content
ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
ic_new = ic_obj(**ic["value"])
message.instruct_content = ic_new

View file

@ -20,4 +20,3 @@ class Singleton(abc.ABCMeta, type):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]

View file

@ -1,4 +1,4 @@
# token to separate different code messages in a WriteCode Message content
MSG_SEP = "#*000*#"
MSG_SEP = "#*000*#"
# token to separate file name and the actual code text in a code message
FILENAME_CODE_SEP = "#*001*#"

View file

@ -3,7 +3,12 @@ from typing import Generator, Sequence
from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens
def reduce_message_length(msgs: Generator[str, None, None], model_name: str, system_text: str, reserved: int = 0,) -> str:
def reduce_message_length(
msgs: Generator[str, None, None],
model_name: str,
system_text: str,
reserved: int = 0,
) -> str:
"""Reduce the length of concatenated message segments to fit within the maximum token size.
Args:
@ -49,9 +54,9 @@ def generate_prompt_chunk(
current_token = 0
current_lines = []
reserved = reserved + count_string_tokens(prompt_template+system_text, model_name)
reserved = reserved + count_string_tokens(prompt_template + system_text, model_name)
# 100 is a magic number to ensure the maximum context length is not exceeded
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
while paragraphs:
paragraph = paragraphs.pop(0)
@ -103,7 +108,7 @@ def decode_unicode_escape(text: str) -> str:
return text.encode("utf-8").decode("unicode_escape", "ignore")
def _split_by_count(lst: Sequence , count: int):
def _split_by_count(lst: Sequence, count: int):
avg = len(lst) // count
remainder = len(lst) % count
start = 0

View file

@ -18,11 +18,13 @@ TOKEN_COSTS = {
"gpt-3.5-turbo-16k-0613": {"prompt": 0.003, "completion": 0.004},
"gpt-35-turbo": {"prompt": 0.0015, "completion": 0.002},
"gpt-35-turbo-16k": {"prompt": 0.003, "completion": 0.004},
"gpt-3.5-turbo-1106": {"prompt": 0.001, "completion": 0.002},
"gpt-4-0314": {"prompt": 0.03, "completion": 0.06},
"gpt-4": {"prompt": 0.03, "completion": 0.06},
"gpt-4-32k": {"prompt": 0.06, "completion": 0.12},
"gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens
}
@ -36,11 +38,13 @@ TOKEN_MAX = {
"gpt-3.5-turbo-16k-0613": 16384,
"gpt-35-turbo": 4096,
"gpt-35-turbo-16k": 16384,
"gpt-3.5-turbo-1106": 16384,
"gpt-4-0314": 8192,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0314": 32768,
"gpt-4-0613": 8192,
"gpt-4-1106-preview": 128000,
"text-embedding-ada-002": 8192,
"chatglm_turbo": 32768,
}
@ -58,20 +62,23 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
"gpt-3.5-turbo-16k-0613",
"gpt-35-turbo",
"gpt-35-turbo-16k",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-1106",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-1106-preview",
}:
tokens_per_message = 3
tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|>
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" in model:
elif "gpt-3.5-turbo" == model:
print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
return count_message_tokens(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model:
elif "gpt-4" == model:
print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return count_message_tokens(messages, model="gpt-4-0613")
else: