Merge branch 'main' into dev_updated

Commit 853086924a by yzlin, 2024-01-10 14:10:15 +08:00
429 changed files with 24237 additions and 5835 deletions

View file

@ -0,0 +1,49 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : pure async http_client
from typing import Any, Mapping, Optional, Union
import aiohttp
from aiohttp.client import DEFAULT_TIMEOUT
async def apost(
url: str,
params: Optional[Mapping[str, str]] = None,
json: Any = None,
data: Any = None,
headers: Optional[dict] = None,
as_json: bool = False,
encoding: str = "utf-8",
timeout: int = DEFAULT_TIMEOUT.total,
) -> Union[str, dict]:
async with aiohttp.ClientSession() as session:
async with session.post(url=url, params=params, json=json, data=data, headers=headers, timeout=timeout) as resp:
if as_json:
data = await resp.json()
else:
data = await resp.read()
data = data.decode(encoding)
return data
async def apost_stream(
url: str,
params: Optional[Mapping[str, str]] = None,
json: Any = None,
data: Any = None,
headers: Optional[dict] = None,
encoding: str = "utf-8",
timeout: int = DEFAULT_TIMEOUT.total,
) -> Any:
"""
usage:
result = apost_stream(url="xx")
async for line in result:
deal_with(line)
"""
async with aiohttp.ClientSession() as session:
async with session.post(url=url, params=params, json=json, data=data, headers=headers, timeout=timeout) as resp:
async for line in resp.content:
yield line.decode(encoding)
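A minimal usage sketch (the endpoint URLs are hypothetical): `apost` returns the decoded body, or a dict when `as_json=True`, while `apost_stream` yields decoded lines as they arrive.

import asyncio

async def demo():
    # Hypothetical endpoints; substitute a real service.
    body = await apost(url="http://localhost:8000/echo", json={"msg": "hi"}, as_json=True)
    print(body)
    async for line in apost_stream(url="http://localhost:8000/stream", json={"msg": "hi"}):
        print(line, end="")

asyncio.run(demo())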

View file

@ -4,16 +4,35 @@
@Time : 2023/4/29 16:07
@Author : alexanderwu
@File : common.py
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.2 of RFC 116:
Add generic class-to-string and object-to-string conversion functionality.
@Modified By: mashenquan, 2023/11/27. Bug fix: `parse_recipient` failed to parse the recipient in certain GPT-3.5
responses.
"""
from __future__ import annotations
import ast
import contextlib
import importlib
import inspect
import json
import os
import platform
import re
import sys
import traceback
import typing
from pathlib import Path
from typing import Any, List, Tuple, Union
import aiofiles
import loguru
from pydantic_core import to_jsonable_python
from tenacity import RetryCallState, _utils
from metagpt.const import MESSAGE_ROUTE_TO_ALL
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
def check_cmd_exists(command) -> int:
@ -29,6 +48,12 @@ def check_cmd_exists(command) -> int:
return result
def require_python_version(req_version: Tuple) -> bool:
if not (2 <= len(req_version) <= 3):
raise ValueError("req_version should be (3, 9) or (3, 10, 13)")
return bool(sys.version_info > req_version)
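A quick sketch of the contract, assuming a Python 3.10 interpreter:

assert require_python_version((3, 9))       # sys.version_info (3, 10, ...) > (3, 9)
assert not require_python_version((3, 11))  # (3, 10, ...) < (3, 11)
# require_python_version((3,)) raises ValueError: the tuple must have 2 or 3 parts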
class OutputParser:
@classmethod
def parse_blocks(cls, text: str):
@ -85,10 +110,7 @@ class OutputParser:
@staticmethod
def parse_python_code(text: str) -> str:
for pattern in (r"(.*?```python.*?\s+)?(?P<code>.*)(```.*?)", r"(.*?```python.*?\s+)?(?P<code>.*)"):
match = re.search(pattern, text, re.DOTALL)
if not match:
continue
@ -109,18 +131,28 @@ class OutputParser:
try:
content = cls.parse_code(text=content)
except Exception:
pass
# Try to parse as a list
try:
content = cls.parse_file_list(text=content)
except Exception:
pass
parsed_data[block] = content
return parsed_data
@staticmethod
def extract_content(text, tag="CONTENT"):
# Use regular expression to extract content between [CONTENT] and [/CONTENT]
extracted_content = re.search(rf"\[{tag}\](.*?)\[/{tag}\]", text, re.DOTALL)
if extracted_content:
return extracted_content.group(1).strip()
else:
raise ValueError(f"Could not find content between [{tag}] and [/{tag}]")
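For illustration, a hypothetical wrapped response:

text = "reply follows [CONTENT]{'k': 'v'}[/CONTENT] trailing notes"
assert OutputParser.extract_content(text) == "{'k': 'v'}"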
@classmethod
def parse_data_with_mapping(cls, data, mapping):
if "[CONTENT]" in data:
data = cls.extract_content(text=data)
block_dict = cls.parse_blocks(data)
parsed_data = {}
for block, content in block_dict.items():
@ -187,7 +219,7 @@ class OutputParser:
result = ast.literal_eval(structure_text)
# Ensure the result matches the specified data type
if isinstance(result, (list, dict)):
return result
raise ValueError(f"The extracted structure is not a {data_type}.")
@ -219,10 +251,15 @@ class CodeParser:
# Iterate over all blocks
for block in blocks:
# Skip empty blocks
if block.strip() == "":
continue
if "\n" not in block:
block_title = block
block_content = ""
else:
# Split the block into title and content, stripping surrounding whitespace from each
block_title, block_content = block.split("\n", 1)
block_dict[block_title.strip()] = block_content.strip()
return block_dict
@ -282,9 +319,6 @@ class NoMoneyException(Exception):
def print_members(module, indent=0):
"""
https://stackoverflow.com/questions/1796180/how-can-i-get-a-list-of-all-classes-within-current-module-in-python
"""
prefix = " " * indent
for name, obj in inspect.getmembers(module):
@ -302,9 +336,16 @@ def print_members(module, indent=0):
def parse_recipient(text):
# FIXME: use ActionNode instead.
pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now
recipient = re.search(pattern, text)
if recipient:
return recipient.group(1)
pattern = r"Send To:\s*([A-Za-z]+)\s*?"
recipient = re.search(pattern, text)
if recipient:
return recipient.group(1)
return ""
def create_func_config(func_schema: dict) -> dict:
@ -329,3 +370,224 @@ def remove_comments(code_str):
clean_code = re.sub(pattern, replace_func, code_str, flags=re.MULTILINE)
clean_code = os.linesep.join([s.rstrip() for s in clean_code.splitlines() if s.strip()])
return clean_code
def get_class_name(cls) -> str:
"""Return class name"""
return f"{cls.__module__}.{cls.__name__}"
def any_to_str(val: Any) -> str:
"""Return `val` unchanged if it is a string; otherwise return the fully qualified class name of `val` or of its type."""
if isinstance(val, str):
return val
elif not callable(val):
return get_class_name(type(val))
else:
return get_class_name(val)
def any_to_str_set(val) -> set:
"""Convert any type to string set."""
res = set()
# Check if the value is iterable, but not a string (since strings are technically iterable)
if isinstance(val, (dict, list, set, tuple)):
# Special handling for dictionaries to iterate over values
if isinstance(val, dict):
val = val.values()
for i in val:
res.add(any_to_str(i))
else:
res.add(any_to_str(val))
return res
def is_subscribed(message: "Message", tags: set):
"""Return whether the message is addressed to any of the given tags (i.e., whether the tags' owner should consume it)."""
if MESSAGE_ROUTE_TO_ALL in message.send_to:
return True
for i in tags:
if i in message.send_to:
return True
return False
def any_to_name(val):
"""
Convert a value to its name by extracting the last part of the dotted path.
:param val: The value to convert.
:return: The name of the value.
"""
return any_to_str(val).split(".")[-1]
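A short sketch of how these conversions relate (the Demo class is illustrative):

class Demo:
    pass

assert any_to_str("already a string") == "already a string"  # strings pass through
assert any_to_str(Demo).endswith(".Demo")                    # class -> "module.Demo"
assert any_to_str(Demo()) == any_to_str(Demo)                # instance -> its class's name
assert any_to_name(Demo) == "Demo"                           # last dotted component
assert any_to_str_set([Demo, "x"]) == {any_to_str(Demo), "x"}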
def concat_namespace(*args) -> str:
return ":".join(str(value) for value in args)
def split_namespace(ns_class_name: str) -> List[str]:
return ns_class_name.split(":")
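The two are inverses for colon-free components:

ns = concat_namespace("metagpt/utils/common.py", "OutputParser")
assert ns == "metagpt/utils/common.py:OutputParser"
assert split_namespace(ns) == ["metagpt/utils/common.py", "OutputParser"]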
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
"""
Generates a logging function to be used after a call is retried.
This generated function logs an error message with the outcome of the retried function call. It includes
the name of the function, the time taken for the call in seconds (formatted according to `sec_format`),
the number of attempts made, and the exception raised, if any.
:param i: A Logger instance from the loguru library used to log the error message.
:param sec_format: A string format specifier for how to format the number of seconds since the start of the call.
Defaults to three decimal places.
:return: A callable that accepts a RetryCallState object and returns None. This callable logs the details
of the retried call.
"""
def log_it(retry_state: "RetryCallState") -> None:
# If the function name is not known, default to "<unknown>"
if retry_state.fn is None:
fn_name = "<unknown>"
else:
# Retrieve the callable's name using a utility function
fn_name = _utils.get_callback_name(retry_state.fn)
# Log an error message with the function name, time since start, attempt number, and the exception
i.error(
f"Finished call to '{fn_name}' after {sec_format % retry_state.seconds_since_start}(s), "
f"this was the {_utils.to_ordinal(retry_state.attempt_number)} time calling it. "
f"exp: {retry_state.outcome.exception()}"
)
return log_it
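A sketch of wiring this into tenacity's standard `retry` decorator (`stop_after_attempt` and `wait_fixed` are stock tenacity helpers; the flaky function is hypothetical):

from tenacity import retry, stop_after_attempt, wait_fixed

@retry(stop=stop_after_attempt(3), wait=wait_fixed(1), after=general_after_log(logger))
async def flaky_call():
    raise RuntimeError("transient failure")  # each failed attempt is logged by log_it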
def read_json_file(json_file: str, encoding="utf-8") -> list[Any]:
if not Path(json_file).exists():
raise FileNotFoundError(f"json_file: {json_file} does not exist")
with open(json_file, "r", encoding=encoding) as fin:
try:
data = json.load(fin)
except Exception:
raise ValueError(f"failed to read json file: {json_file}")
return data
def write_json_file(json_file: str, data: list, encoding=None):
folder_path = Path(json_file).parent
if not folder_path.exists():
folder_path.mkdir(parents=True, exist_ok=True)
with open(json_file, "w", encoding=encoding) as fout:
json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python)
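Round-trip sketch using a temporary directory:

import tempfile

tmp = Path(tempfile.mkdtemp()) / "demo" / "items.json"
write_json_file(str(tmp), [{"name": "a"}, {"name": "b"}])  # parent dirs are created
assert read_json_file(str(tmp)) == [{"name": "a"}, {"name": "b"}]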
def import_class(class_name: str, module_name: str) -> type:
module = importlib.import_module(module_name)
a_class = getattr(module, class_name)
return a_class
def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> object:
a_class = import_class(class_name, module_name)
class_inst = a_class(*args, **kwargs)
return class_inst
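For example, resolving and instantiating a standard-library class by name:

OrderedDict = import_class("OrderedDict", "collections")
od = import_class_inst("OrderedDict", "collections", [("k", 1)])
assert od["k"] == 1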
def format_trackback_info(limit: int = 2):
return traceback.format_exc(limit=limit)
def serialize_decorator(func):
async def wrapper(self, *args, **kwargs):
try:
result = await func(self, *args, **kwargs)
return result
except KeyboardInterrupt:
logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}")
except Exception:
logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}")
self.serialize() # Team.serialize
return wrapper
def role_raise_decorator(func):
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except KeyboardInterrupt as kbi:
logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project")
if self.latest_observed_msg:
self.rc.memory.delete(self.latest_observed_msg)
# raise again to make it captured outside
raise Exception(format_trackback_info(limit=None))
except Exception:
if self.latest_observed_msg:
logger.warning(
"There is an exception in the role's execution; in order to resume, "
"we delete the newest role communication message in the role's memory."
)
# remove role newest observed msg to make it observed again
self.rc.memory.delete(self.latest_observed_msg)
# raise again to make it captured outside
raise Exception(format_trackback_info(limit=None))
return wrapper
@handle_exception
async def aread(filename: str | Path, encoding=None) -> str:
"""Read file asynchronously."""
async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
content = await reader.read()
return content
async def awrite(filename: str | Path, data: str, encoding=None):
"""Write file asynchronously."""
pathname = Path(filename)
pathname.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(str(pathname), mode="w", encoding=encoding) as writer:
await writer.write(data)
async def read_file_block(filename: str | Path, lineno: int, end_lineno: int):
if not Path(filename).exists():
return ""
lines = []
async with aiofiles.open(str(filename), mode="r") as reader:
ix = 0
while ix < end_lineno:
ix += 1
line = await reader.readline()
if ix < lineno:
continue
if ix > end_lineno:
break
lines.append(line)
return "".join(lines)
def list_files(root: str | Path) -> List[Path]:
files = []
try:
directory_path = Path(root)
if not directory_path.exists():
return []
for file_path in directory_path.iterdir():
if file_path.is_file():
files.append(file_path)
else:
subfolder_files = list_files(root=file_path)
files.extend(subfolder_files)
except Exception as e:
logger.error(f"Error: {e}")
return files

View file

@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/28
@Author : mashenquan
@File : openai.py
@Desc : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting.
"""
from typing import NamedTuple
from pydantic import BaseModel
from metagpt.logs import logger
from metagpt.utils.token_counter import TOKEN_COSTS
class Costs(NamedTuple):
total_prompt_tokens: int
total_completion_tokens: int
total_cost: float
total_budget: float
class CostManager(BaseModel):
"""Track the prompt/completion token usage and the accumulated cost of LLM API calls."""
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
total_budget: float = 0
max_budget: float = 10.0
total_cost: float = 0
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
cost = (
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
) / 1000
self.total_cost += cost
logger.info(
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | "
f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
def get_total_prompt_tokens(self):
"""
Get the total number of prompt tokens.
Returns:
int: The total number of prompt tokens.
"""
return self.total_prompt_tokens
def get_total_completion_tokens(self):
"""
Get the total number of completion tokens.
Returns:
int: The total number of completion tokens.
"""
return self.total_completion_tokens
def get_total_cost(self):
"""
Get the total cost of API calls.
Returns:
float: The total cost of API calls.
"""
return self.total_cost
def get_costs(self) -> Costs:
"""Get all costs"""
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
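A minimal accounting sketch, assuming "gpt-3.5-turbo" is a key in `TOKEN_COSTS` (per-1K-token prices):

cm = CostManager(max_budget=1.0)
cm.update_cost(prompt_tokens=1000, completion_tokens=500, model="gpt-3.5-turbo")
costs = cm.get_costs()
print(costs.total_prompt_tokens, costs.total_completion_tokens, costs.total_cost)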

View file

@ -25,7 +25,7 @@ def py_make_scanner(context):
except IndexError:
raise StopIteration(idx) from None
if nextchar in ("'", '"'):
if idx + 2 < len(string) and string[idx + 1] == nextchar and string[idx + 2] == nextchar:
# Handle the case where the next two characters are the same as nextchar
return parse_string(string, idx + 3, strict, delimiter=nextchar * 3) # triple quote

View file

@ -0,0 +1,102 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/11/22
@Author : mashenquan
@File : dependency_file.py
@Desc: Implementation of the dependency file described in Section 2.2.3.2 of RFC 135.
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Set
import aiofiles
from metagpt.utils.common import aread
from metagpt.utils.exceptions import handle_exception
class DependencyFile:
"""A class representing a DependencyFile for managing dependencies.
:param workdir: The working directory path for the DependencyFile.
"""
def __init__(self, workdir: Path | str):
"""Initialize a DependencyFile instance.
:param workdir: The working directory path for the DependencyFile.
"""
self._dependencies = {}
self._filename = Path(workdir) / ".dependencies.json"
async def load(self):
"""Load dependencies from the file asynchronously."""
if not self._filename.exists():
return
self._dependencies = json.loads(await aread(self._filename))
@handle_exception
async def save(self):
"""Save dependencies to the file asynchronously."""
data = json.dumps(self._dependencies)
async with aiofiles.open(str(self._filename), mode="w") as writer:
await writer.write(data)
async def update(self, filename: Path | str, dependencies: Set[Path | str], persist=True):
"""Update dependencies for a file asynchronously.
:param filename: The filename or path.
:param dependencies: The set of dependencies.
:param persist: Whether to persist the changes immediately.
"""
if persist:
await self.load()
root = self._filename.parent
try:
key = Path(filename).relative_to(root)
except ValueError:
key = filename
if dependencies:
relative_paths = []
for i in dependencies:
try:
relative_paths.append(str(Path(i).relative_to(root)))
except ValueError:
relative_paths.append(str(i))
self._dependencies[str(key)] = relative_paths
elif str(key) in self._dependencies:
del self._dependencies[str(key)]
if persist:
await self.save()
async def get(self, filename: Path | str, persist=True):
"""Get dependencies for a file asynchronously.
:param filename: The filename or path.
:param persist: Whether to load dependencies from the file immediately.
:return: A set of dependencies.
"""
if persist:
await self.load()
root = self._filename.parent
try:
key = Path(filename).relative_to(root)
except ValueError:
key = filename
return set(self._dependencies.get(str(key), {}))
def delete_file(self):
"""Delete the dependency file."""
self._filename.unlink(missing_ok=True)
@property
def exists(self):
"""Check if the dependency file exists."""
return self._filename.exists()
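Usage sketch with a temporary working directory (filenames are hypothetical):

import asyncio
import tempfile
from pathlib import Path

async def demo():
    workdir = Path(tempfile.mkdtemp())
    deps = DependencyFile(workdir=workdir)
    await deps.update(workdir / "a.py", {workdir / "b.py", workdir / "c.py"})
    print(await deps.get(workdir / "a.py"))  # {'b.py', 'c.py'}, stored relative to workdir

asyncio.run(demo())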

View file

@ -0,0 +1,82 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : di_graph_repository.py
@Desc : Graph repository based on DiGraph
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import List
import networkx
from metagpt.utils.common import aread, awrite
from metagpt.utils.graph_repository import SPO, GraphRepository
class DiGraphRepository(GraphRepository):
def __init__(self, name: str, **kwargs):
super().__init__(name=name, **kwargs)
self._repo = networkx.DiGraph()
async def insert(self, subject: str, predicate: str, object_: str):
self._repo.add_edge(subject, object_, predicate=predicate)
async def upsert(self, subject: str, predicate: str, object_: str):
pass
async def update(self, subject: str, predicate: str, object_: str):
pass
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
result = []
for s, o, p in self._repo.edges(data="predicate"):
if subject and subject != s:
continue
if predicate and predicate != p:
continue
if object_ and object_ != o:
continue
result.append(SPO(subject=s, predicate=p, object_=o))
return result
def json(self) -> str:
m = networkx.node_link_data(self._repo)
data = json.dumps(m)
return data
async def save(self, path: str | Path = None):
data = self.json()
path = Path(path or self._kwargs.get("root"))
if not path.exists():
path.mkdir(parents=True, exist_ok=True)
pathname = Path(path) / self.name
await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8")
async def load(self, pathname: str | Path):
data = await aread(filename=pathname, encoding="utf-8")
m = json.loads(data)
self._repo = networkx.node_link_graph(m)
@staticmethod
async def load_from(pathname: str | Path) -> GraphRepository:
pathname = Path(pathname)
name = pathname.with_suffix("").name
root = pathname.parent
graph = DiGraphRepository(name=name, root=root)
if pathname.exists():
await graph.load(pathname=pathname)
return graph
@property
def root(self) -> str:
return self._kwargs.get("root")
@property
def pathname(self) -> Path:
p = Path(self.root) / self.name
return p.with_suffix(".json")
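A minimal sketch of storing and querying SPO triples (names are illustrative):

import asyncio

async def demo():
    repo = DiGraphRepository(name="demo", root="/tmp")
    await repo.insert(subject="a.py", predicate="has_class", object_="a.py:Foo")
    await repo.insert(subject="a.py", predicate="is", object_="source_code")
    rows = await repo.select(subject="a.py")
    print([(r.subject, r.predicate, r.object_) for r in rows])

asyncio.run(demo())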

View file

@ -0,0 +1,61 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19 14:46
@Author : alexanderwu
@File : exceptions.py
"""
import asyncio
import functools
import traceback
from typing import Any, Callable, Tuple, Type, TypeVar, Union
from metagpt.logs import logger
ReturnType = TypeVar("ReturnType")
def handle_exception(
_func: Callable[..., ReturnType] = None,
*,
exception_type: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception,
exception_msg: str = "",
default_return: Any = None,
) -> Callable[..., ReturnType]:
"""Handle exceptions raised by the decorated function, log them, and return a default value instead."""
def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
@functools.wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
try:
return await func(*args, **kwargs)
except exception_type as e:
logger.opt(depth=1).error(
f"{e}: {exception_msg}, "
f"\nCalling {func.__name__} with args: {args}, kwargs: {kwargs} "
f"\nStack: {traceback.format_exc()}"
)
return default_return
@functools.wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
try:
return func(*args, **kwargs)
except exception_type as e:
logger.opt(depth=1).error(
f"Calling {func.__name__} with args: {args}, kwargs: {kwargs} failed: {e}, "
f"stack: {traceback.format_exc()}"
)
return default_return
if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
if _func is None:
return decorator
else:
return decorator(_func)
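Usage sketch; the same decorator serves sync and async callables:

@handle_exception(exception_type=ZeroDivisionError, default_return=0)
def safe_div(a: int, b: int) -> int:
    return a // b

assert safe_div(4, 2) == 2
assert safe_div(1, 0) == 0  # the error is logged, then the default is returned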

View file

@ -6,10 +6,12 @@
@File : file.py
@Describe : General file operations.
"""
from pathlib import Path
import aiofiles
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
class File:
@ -18,6 +20,7 @@ class File:
CHUNK_SIZE = 64 * 1024
@classmethod
@handle_exception
async def write(cls, root_path: Path, filename: str, content: bytes) -> Path:
"""Write the file content to the specified local path.
@ -32,18 +35,15 @@ class File:
Raises:
Exception: If an unexpected error occurs during the file writing process.
"""
root_path.mkdir(parents=True, exist_ok=True)
full_path = root_path / filename
async with aiofiles.open(full_path, mode="wb") as writer:
await writer.write(content)
logger.debug(f"Successfully wrote file: {full_path}")
return full_path
@classmethod
@handle_exception
async def read(cls, file_path: Path, chunk_size: int = None) -> bytes:
"""Read the file content from the specified local path in chunks.
@ -57,19 +57,14 @@ class File:
Raises:
Exception: If an unexpected error occurs during the file reading process.
"""
chunk_size = chunk_size or cls.CHUNK_SIZE
async with aiofiles.open(file_path, mode="rb") as reader:
chunks = list()
while True:
chunk = await reader.read(chunk_size)
if not chunk:
break
chunks.append(chunk)
content = b"".join(chunks)
logger.debug(f"Successfully read file: {file_path}")
return content
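A short round-trip sketch (temporary directory; bytes in, bytes out):

import asyncio
import tempfile
from pathlib import Path

async def demo():
    root = Path(tempfile.mkdtemp())
    full_path = await File.write(root_path=root, filename="demo.bin", content=b"hello")
    assert await File.read(full_path) == b"hello"

asyncio.run(demo())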

View file

@ -0,0 +1,290 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/11/20
@Author : mashenquan
@File : file_repository.py
@Desc: File repository management. RFC 135 2.2.3.2, 2.2.3.4 and 2.2.3.13.
"""
from __future__ import annotations
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Set
import aiofiles
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.schema import Document
from metagpt.utils.common import aread
from metagpt.utils.json_to_markdown import json_to_markdown
class FileRepository:
"""A class representing a FileRepository associated with a Git repository.
:param git_repo: The associated GitRepository instance.
:param relative_path: The relative path within the Git repository.
Attributes:
_relative_path (Path): The relative path within the Git repository.
_git_repo (GitRepository): The associated GitRepository instance.
"""
def __init__(self, git_repo, relative_path: Path = Path(".")):
"""Initialize a FileRepository instance.
:param git_repo: The associated GitRepository instance.
:param relative_path: The relative path within the Git repository.
"""
self._relative_path = relative_path
self._git_repo = git_repo
# Initializing
self.workdir.mkdir(parents=True, exist_ok=True)
async def save(self, filename: Path | str, content, dependencies: List[str] = None):
"""Save content to a file and update its dependencies.
:param filename: The filename or path within the repository.
:param content: The content to be saved.
:param dependencies: List of dependency filenames or paths.
"""
pathname = self.workdir / filename
pathname.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(str(pathname), mode="w") as writer:
await writer.write(content)
logger.info(f"save to: {str(pathname)}")
if dependencies is not None:
dependency_file = await self._git_repo.get_dependency()
await dependency_file.update(pathname, set(dependencies))
logger.info(f"update dependency: {str(pathname)}:{dependencies}")
async def get_dependency(self, filename: Path | str) -> Set[str]:
"""Get the dependencies of a file.
:param filename: The filename or path within the repository.
:return: Set of dependency filenames or paths.
"""
pathname = self.workdir / filename
dependency_file = await self._git_repo.get_dependency()
return await dependency_file.get(pathname)
async def get_changed_dependency(self, filename: Path | str) -> Set[str]:
"""Get the dependencies of a file that have changed.
:param filename: The filename or path within the repository.
:return: Set of changed dependency filenames or paths.
"""
dependencies = await self.get_dependency(filename=filename)
changed_files = set(self.changed_files.keys())
changed_dependent_files = set()
for df in dependencies:
rdf = Path(df).relative_to(self._relative_path)
if str(rdf) in changed_files:
changed_dependent_files.add(df)
return changed_dependent_files
async def get(self, filename: Path | str) -> Document | None:
"""Read the content of a file.
:param filename: The filename or path within the repository.
:return: The content of the file.
"""
doc = Document(root_path=str(self.root_path), filename=str(filename))
path_name = self.workdir / filename
if not path_name.exists():
return None
doc.content = await aread(path_name)
return doc
async def get_all(self) -> List[Document]:
"""Get the content of all files in the repository.
:return: List of Document instances representing files.
"""
docs = []
for root, dirs, files in os.walk(str(self.workdir)):
for file in files:
file_path = Path(root) / file
relative_path = file_path.relative_to(self.workdir)
doc = await self.get(relative_path)
docs.append(doc)
return docs
@property
def workdir(self):
"""Return the absolute path to the working directory of the FileRepository.
:return: The absolute path to the working directory.
"""
return self._git_repo.workdir / self._relative_path
@property
def root_path(self):
"""Return the relative path from git repository root"""
return self._relative_path
@property
def changed_files(self) -> Dict[str, str]:
"""Return a dictionary of changed files and their change types.
:return: A dictionary where keys are file paths and values are change types.
"""
files = self._git_repo.changed_files
relative_files = {}
for p, ct in files.items():
if ct.value == "D": # deleted
continue
try:
rf = Path(p).relative_to(self._relative_path)
except ValueError:
continue
relative_files[str(rf)] = ct
return relative_files
@property
def all_files(self) -> List:
"""Get a list of all files in the repository.
The list contains file paths relative to the current FileRepository.
:return: A list of file paths relative to the current FileRepository.
:rtype: List
"""
return self._git_repo.get_files(relative_path=self._relative_path)
def get_change_dir_files(self, dir: Path | str) -> List:
"""Get the files in a directory that have changed.
:param dir: The directory path within the repository.
:return: List of changed filenames or paths within the directory.
"""
changed_files = self.changed_files
children = []
for f in changed_files:
try:
Path(f).relative_to(Path(dir))
except ValueError:
continue
children.append(str(f))
return children
@staticmethod
def new_filename():
"""Generate a new filename from the current timestamp (a UUID-suffixed variant is kept commented out below).
:return: A new filename string.
"""
current_time = datetime.now().strftime("%Y%m%d%H%M%S")
return current_time
# guid_suffix = str(uuid.uuid4())[:8]
# return f"{current_time}x{guid_suffix}"
async def save_doc(self, doc: Document, with_suffix: str = None, dependencies: List[str] = None):
"""Save a Document instance as a Markdown file.
This method converts the content of the Document instance to Markdown,
saves it to a file with an optional specified suffix, and logs the saved file.
:param doc: The Document instance to be saved.
:type doc: Document
:param with_suffix: An optional suffix to append to the saved file's name.
:type with_suffix: str, optional
:param dependencies: A list of dependencies for the saved file.
:type dependencies: List[str], optional
"""
m = json.loads(doc.content)
filename = Path(doc.filename).with_suffix(with_suffix) if with_suffix is not None else Path(doc.filename)
await self.save(filename=str(filename), content=json_to_markdown(m), dependencies=dependencies)
logger.debug(f"File Saved: {str(filename)}")
@staticmethod
async def get_file(filename: Path | str, relative_path: Path | str = ".") -> Document | None:
"""Retrieve a specific file from the file repository.
:param filename: The name or path of the file to retrieve.
:type filename: Path or str
:param relative_path: The relative path within the file repository.
:type relative_path: Path or str, optional
:return: The document representing the file, or None if not found.
:rtype: Document or None
"""
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.get(filename=filename)
@staticmethod
async def get_all_files(relative_path: Path | str = ".") -> List[Document]:
"""Retrieve all files from the file repository.
:param relative_path: The relative path within the file repository.
:type relative_path: Path or str, optional
:return: A list of documents representing all files in the repository.
:rtype: List[Document]
"""
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.get_all()
@staticmethod
async def save_file(filename: Path | str, content, dependencies: List[str] = None, relative_path: Path | str = "."):
"""Save a file to the file repository.
:param filename: The name or path of the file to save.
:type filename: Path or str
:param content: The content of the file.
:param dependencies: A list of dependencies for the file.
:type dependencies: List[str], optional
:param relative_path: The relative path within the file repository.
:type relative_path: Path or str, optional
"""
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.save(filename=filename, content=content, dependencies=dependencies)
@staticmethod
async def save_as(
doc: Document, with_suffix: str = None, dependencies: List[str] = None, relative_path: Path | str = "."
):
"""Save a Document instance with optional modifications.
This static method creates a new FileRepository, saves the Document instance
with optional modifications (such as a suffix), and logs the saved file.
:param doc: The Document instance to be saved.
:type doc: Document
:param with_suffix: An optional suffix to append to the saved file's name.
:type with_suffix: str, optional
:param dependencies: A list of dependencies for the saved file.
:type dependencies: List[str], optional
:param relative_path: The relative path within the file repository.
:type relative_path: Path or str, optional
:return: A boolean indicating whether the save operation was successful.
:rtype: bool
"""
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.save_doc(doc=doc, with_suffix=with_suffix, dependencies=dependencies)
async def delete(self, filename: Path | str):
"""Delete a file from the file repository.
This method deletes a file from the file repository based on the provided filename.
:param filename: The name or path of the file to be deleted.
:type filename: Path or str
"""
pathname = self.workdir / filename
if not pathname.exists():
return
pathname.unlink(missing_ok=True)
dependency_file = await self._git_repo.get_dependency()
await dependency_file.update(filename=pathname, dependencies=None)
logger.info(f"remove dependency key: {str(pathname)}")
@staticmethod
async def delete_file(filename: Path | str, relative_path: Path | str = "."):
file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
await file_repo.delete(filename=filename)

View file

@ -1,20 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/19 20:39
@Author : femto Zheng
@File : get_template.py
"""
from metagpt.config import CONFIG
def get_template(templates, format=CONFIG.prompt_format):
selected_templates = templates.get(format)
if selected_templates is None:
raise ValueError(f"Can't find {format} in passed in templates")
# Extract the selected templates
prompt_template = selected_templates["PROMPT_TEMPLATE"]
format_example = selected_templates["FORMAT_EXAMPLE"]
return prompt_template, format_example

View file

@ -0,0 +1,272 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/11/20
@Author : mashenquan
@File : git_repository.py
@Desc: Git repository management. RFC 135 2.2.3.3.
"""
from __future__ import annotations
import shutil
from enum import Enum
from pathlib import Path
from typing import Dict, List
from git.repo import Repo
from git.repo.fun import is_git_dir
from gitignore_parser import parse_gitignore
from metagpt.logs import logger
from metagpt.utils.dependency_file import DependencyFile
from metagpt.utils.file_repository import FileRepository
class ChangeType(Enum):
ADDED = "A" # File was added
COPIED = "C" # File was copied
DELETED = "D" # File was deleted
RENAMED = "R" # File was renamed
MODIFIED = "M" # File was modified
TYPE_CHANGED = "T" # Type of the file was changed
UNTRACKED = "U" # File is untracked (not added to version control)
class GitRepository:
"""A class representing a Git repository.
:param local_path: The local path to the Git repository.
:param auto_init: If True, automatically initializes a new Git repository if the provided path is not a Git repository.
Attributes:
_repository (Repo): The GitPython `Repo` object representing the Git repository.
"""
def __init__(self, local_path=None, auto_init=True):
"""Initialize a GitRepository instance.
:param local_path: The local path to the Git repository.
:param auto_init: If True, automatically initializes a new Git repository if the provided path is not a Git repository.
"""
self._repository = None
self._dependency = None
self._gitignore_rules = None
if local_path:
self.open(local_path=local_path, auto_init=auto_init)
def open(self, local_path: Path, auto_init=False):
"""Open an existing Git repository or initialize a new one if auto_init is True.
:param local_path: The local path to the Git repository.
:param auto_init: If True, automatically initializes a new Git repository if the provided path is not a Git repository.
"""
local_path = Path(local_path)
if self.is_git_dir(local_path):
self._repository = Repo(local_path)
self._gitignore_rules = parse_gitignore(full_path=str(local_path / ".gitignore"))
return
if not auto_init:
return
local_path.mkdir(parents=True, exist_ok=True)
return self._init(local_path)
def _init(self, local_path: Path):
"""Initialize a new Git repository at the specified path.
:param local_path: The local path where the new Git repository will be initialized.
"""
self._repository = Repo.init(path=Path(local_path))
gitignore_filename = Path(local_path) / ".gitignore"
ignores = ["__pycache__", "*.pyc"]
with open(str(gitignore_filename), mode="w") as writer:
writer.write("\n".join(ignores))
self._repository.index.add([".gitignore"])
self._repository.index.commit("Add .gitignore")
self._gitignore_rules = parse_gitignore(full_path=gitignore_filename)
def add_change(self, files: Dict):
"""Add or remove files from the staging area based on the provided changes.
:param files: A dictionary where keys are file paths and values are instances of ChangeType.
"""
if not self.is_valid or not files:
return
for k, v in files.items():
self._repository.index.remove(k) if v is ChangeType.DELETED else self._repository.index.add([k])
def commit(self, comments):
"""Commit the staged changes with the given comments.
:param comments: Comments for the commit.
"""
if self.is_valid:
self._repository.index.commit(comments)
def delete_repository(self):
"""Delete the entire repository directory."""
if self.is_valid:
shutil.rmtree(self._repository.working_dir)
@property
def changed_files(self) -> Dict[str, str]:
"""Return a dictionary of changed files and their change types.
:return: A dictionary where keys are file paths and values are change types.
"""
files = {i: ChangeType.UNTRACKED for i in self._repository.untracked_files}
changed_files = {f.a_path: ChangeType(f.change_type) for f in self._repository.index.diff(None)}
files.update(changed_files)
return files
@staticmethod
def is_git_dir(local_path):
"""Check if the specified directory is a Git repository.
:param local_path: The local path to check.
:return: True if the directory is a Git repository, False otherwise.
"""
git_dir = Path(local_path) / ".git"
if git_dir.exists() and is_git_dir(git_dir):
return True
return False
@property
def is_valid(self):
"""Check if the Git repository is valid (exists and is initialized).
:return: True if the repository is valid, False otherwise.
"""
return bool(self._repository)
@property
def status(self) -> str:
"""Return the Git repository's status as a string."""
if not self.is_valid:
return ""
return self._repository.git.status()
@property
def workdir(self) -> Path | None:
"""Return the path to the working directory of the Git repository.
:return: The path to the working directory or None if the repository is not valid.
"""
if not self.is_valid:
return None
return Path(self._repository.working_dir)
def archive(self, comments="Archive"):
"""Archive the current state of the Git repository.
:param comments: Comments for the archive commit.
"""
logger.info(f"Archive: {list(self.changed_files.keys())}")
self.add_change(self.changed_files)
self.commit(comments)
def new_file_repository(self, relative_path: Path | str = ".") -> FileRepository:
"""Create a new instance of FileRepository associated with this Git repository.
:param relative_path: The relative path to the file repository within the Git repository.
:return: A new instance of FileRepository.
"""
path = Path(relative_path)
try:
path = path.relative_to(self.workdir)
except ValueError:
path = relative_path
return FileRepository(git_repo=self, relative_path=Path(path))
async def get_dependency(self) -> DependencyFile:
"""Get the dependency file associated with the Git repository.
:return: An instance of DependencyFile.
"""
if not self._dependency:
self._dependency = DependencyFile(workdir=self.workdir)
return self._dependency
def rename_root(self, new_dir_name):
"""Rename the root directory of the Git repository.
:param new_dir_name: The new name for the root directory.
"""
if self.workdir.name == new_dir_name:
return
new_path = self.workdir.parent / new_dir_name
if new_path.exists():
logger.info(f"Delete directory {str(new_path)}")
shutil.rmtree(new_path)
try:
shutil.move(src=str(self.workdir), dst=str(new_path))
except Exception as e:
logger.warning(f"Move {str(self.workdir)} to {str(new_path)} error: {e}")
logger.info(f"Rename directory {str(self.workdir)} to {str(new_path)}")
self._repository = Repo(new_path)
self._gitignore_rules = parse_gitignore(full_path=str(new_path / ".gitignore"))
def get_files(self, relative_path: Path | str, root_relative_path: Path | str = None, filter_ignored=True) -> List:
"""
Retrieve a list of files in the specified relative path.
The method returns a list of file paths relative to the current FileRepository.
:param relative_path: The relative path within the repository.
:type relative_path: Path or str
:param root_relative_path: The root relative path within the repository.
:type root_relative_path: Path or str
:param filter_ignored: Flag to indicate whether to filter files based on .gitignore rules.
:type filter_ignored: bool
:return: A list of file paths in the specified directory.
:rtype: List[str]
"""
try:
relative_path = Path(relative_path).relative_to(self.workdir)
except ValueError:
relative_path = Path(relative_path)
if not root_relative_path:
root_relative_path = Path(self.workdir) / relative_path
files = []
try:
directory_path = Path(self.workdir) / relative_path
if not directory_path.exists():
return []
for file_path in directory_path.iterdir():
if file_path.is_file():
rpath = file_path.relative_to(root_relative_path)
files.append(str(rpath))
else:
subfolder_files = self.get_files(
relative_path=file_path, root_relative_path=root_relative_path, filter_ignored=False
)
files.extend(subfolder_files)
except Exception as e:
logger.error(f"Error: {e}")
if not filter_ignored:
return files
filtered_files = self.filter_gitignore(filenames=files, root_relative_path=root_relative_path)
return filtered_files
def filter_gitignore(self, filenames: List[str], root_relative_path: Path | str = None) -> List[str]:
"""
Filter a list of filenames based on .gitignore rules.
:param filenames: A list of filenames to be filtered.
:type filenames: List[str]
:param root_relative_path: The root relative path within the repository.
:type root_relative_path: Path or str
:return: A list of filenames that pass the .gitignore filtering.
:rtype: List[str]
"""
if root_relative_path is None:
root_relative_path = self.workdir
files = []
for filename in filenames:
pathname = root_relative_path / filename
if self._gitignore_rules(str(pathname)):
continue
files.append(filename)
return files
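An end-to-end sketch: initialize a repository in a temporary directory, add a file, and archive the change (paths are hypothetical):

import tempfile
from pathlib import Path

repo = GitRepository(local_path=Path(tempfile.mkdtemp()) / "demo_repo", auto_init=True)
(repo.workdir / "hello.py").write_text("print('hello')\n")
print(repo.changed_files)  # e.g. {'hello.py': <untracked change type>}
repo.archive("Initial content")
print(repo.status)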

View file

@ -0,0 +1,200 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : graph_repository.py
@Desc : Superclass for graph repository.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List
from pydantic import BaseModel
from metagpt.logs import logger
from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace
class GraphKeyword:
IS = "is"
OF = "Of"
ON = "On"
CLASS = "class"
FUNCTION = "function"
HAS_FUNCTION = "has_function"
SOURCE_CODE = "source_code"
NULL = "<null>"
GLOBAL_VARIABLE = "global_variable"
CLASS_FUNCTION = "class_function"
CLASS_PROPERTY = "class_property"
HAS_CLASS_FUNCTION = "has_class_function"
HAS_CLASS_PROPERTY = "has_class_property"
HAS_CLASS = "has_class"
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
HAS_SEQUENCE_VIEW = "has_sequence_view"
HAS_ARGS_DESC = "has_args_desc"
HAS_TYPE_DESC = "has_type_desc"
class SPO(BaseModel):
subject: str
predicate: str
object_: str
class GraphRepository(ABC):
def __init__(self, name: str, **kwargs):
self._repo_name = name
self._kwargs = kwargs
@abstractmethod
async def insert(self, subject: str, predicate: str, object_: str):
pass
@abstractmethod
async def upsert(self, subject: str, predicate: str, object_: str):
pass
@abstractmethod
async def update(self, subject: str, predicate: str, object_: str):
pass
@abstractmethod
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
pass
@property
def name(self) -> str:
return self._repo_name
@staticmethod
async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo):
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
file_types = {".py": "python", ".js": "javascript"}
file_type = file_types.get(Path(file_info.file).suffix, GraphKeyword.NULL)
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type)
for c in file_info.classes:
class_name = c.get("name", "")
# file -> class
await graph_db.insert(
subject=file_info.file,
predicate=GraphKeyword.HAS_CLASS,
object_=concat_namespace(file_info.file, class_name),
)
# class detail
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS,
)
methods = c.get("methods", [])
for fn in methods:
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
object_=concat_namespace(file_info.file, class_name, fn),
)
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
)
for f in file_info.functions:
# file -> function
await graph_db.insert(
subject=file_info.file, predicate=GraphKeyword.HAS_FUNCTION, object_=concat_namespace(file_info.file, f)
)
# function detail
await graph_db.insert(
subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION
)
for g in file_info.globals:
await graph_db.insert(
subject=concat_namespace(file_info.file, g),
predicate=GraphKeyword.IS,
object_=GraphKeyword.GLOBAL_VARIABLE,
)
for code_block in file_info.page_info:
if code_block.tokens:
await graph_db.insert(
subject=concat_namespace(file_info.file, *code_block.tokens),
predicate=GraphKeyword.HAS_PAGE_INFO,
object_=code_block.model_dump_json(),
)
for k, v in code_block.properties.items():
await graph_db.insert(
subject=concat_namespace(file_info.file, k, v),
predicate=GraphKeyword.HAS_PAGE_INFO,
object_=code_block.model_dump_json(),
)
@staticmethod
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
for c in class_views:
filename, _ = c.package.split(":", 1)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
file_types = {".py": "python", ".js": "javascript"}
file_type = file_types.get(Path(filename).suffix, GraphKeyword.NULL)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=file_type)
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_CLASS, object_=c.package)
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS,
)
for vn, vt in c.attributes.items():
# class -> property
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.HAS_CLASS_PROPERTY,
object_=concat_namespace(c.package, vn),
)
# property detail
await graph_db.insert(
subject=concat_namespace(c.package, vn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_PROPERTY,
)
await graph_db.insert(
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
)
for fn, desc in c.methods.items():
if "</I>" in desc and "<I>" not in desc:
logger.error(desc)
# class -> function
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
object_=concat_namespace(c.package, fn),
)
# function detail
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
)
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.HAS_ARGS_DESC,
object_=desc,
)
@staticmethod
async def update_graph_db_with_class_relationship_views(
graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
):
for r in relationship_views:
await graph_db.insert(
subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest
)
if not r.label:
continue
await graph_db.insert(
subject=r.src,
predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
object_=concat_namespace(r.dest, r.label),
)

View file

@ -1,22 +1,22 @@
# Add syntax highlighting to code
from pygments import highlight as highlight_
from pygments.formatters import HtmlFormatter, TerminalFormatter
from pygments.lexers import PythonLexer, SqlLexer
def highlight(code: str, language: str = "python", formatter: str = "terminal"):
# Select the lexer for the requested language
if language.lower() == "python":
lexer = PythonLexer()
elif language.lower() == "sql":
lexer = SqlLexer()
else:
raise ValueError(f"Unsupported language: {language}")
# Select the output formatter
if formatter.lower() == "terminal":
formatter = TerminalFormatter()
elif formatter.lower() == "html":
formatter = HtmlFormatter()
else:
raise ValueError(f"Unsupported formatter: {formatter}")
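Usage sketch, assuming the truncated tail of the function returns `highlight_(code, lexer, formatter)` in the usual pygments pattern:

print(highlight("SELECT 1;", language="sql", formatter="terminal"))
html = highlight("def f():\n    return 1", language="python", formatter="html")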

View file

@ -18,17 +18,15 @@ from metagpt.config import CONFIG
def make_sk_kernel():
kernel = sk.Kernel()
if CONFIG.OPENAI_API_TYPE == "azure":
kernel.add_chat_service(
"chat_completion",
AzureChatCompletion(CONFIG.DEPLOYMENT_NAME, CONFIG.OPENAI_BASE_URL, CONFIG.OPENAI_API_KEY),
)
else:
kernel.add_chat_service(
"chat_completion",
OpenAIChatCompletion(CONFIG.OPENAI_API_MODEL, CONFIG.OPENAI_API_KEY),
)
return kernel

View file

@ -4,13 +4,15 @@
@Time : 2023/7/4 10:53
@Author : alexanderwu alitrack
@File : mermaid.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
import asyncio
import os
from pathlib import Path
import aiofiles
from metagpt.config import CONFIG
from metagpt.const import PROJECT_ROOT
from metagpt.logs import logger
from metagpt.utils.common import check_cmd_exists
@ -29,7 +31,9 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name)
tmp = Path(f"{output_file_without_suffix}.mmd")
async with aiofiles.open(tmp, "w", encoding="utf-8") as f:
await f.write(mermaid_code)
# tmp.write_text(mermaid_code, encoding="utf-8")
engine = CONFIG.mermaid_engine.lower()
if engine == "nodejs":
@ -69,7 +73,7 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
if stdout:
logger.info(stdout.decode())
if stderr:
logger.warning(stderr.decode())
else:
if engine == "playwright":
from metagpt.utils.mmdc_playwright import mermaid_to_file
@ -88,7 +92,8 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
return 0
MMC1 = """
classDiagram
class Main {
-SearchEngine search_engine
+main() str
@ -118,9 +123,11 @@ MMC1 = """classDiagram
SearchEngine --> Index
SearchEngine --> Ranking
SearchEngine --> Summary
Index --> KnowledgeBase
"""
MMC2 = """
sequenceDiagram
participant M as Main
participant SE as SearchEngine
participant I as Index
@ -136,11 +143,5 @@ MMC2 = """sequenceDiagram
R-->>SE: return ranked_results
SE->>S: summarize_results(ranked_results)
S-->>SE: return summary
SE-->>M: return summary
"""

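With the old `__main__` block removed, a minimal invocation sketch looks like this (the output path is hypothetical, and the engine selected by `CONFIG.mermaid_engine` must be installed):

import asyncio

asyncio.run(mermaid_to_file(MMC1, "/tmp/mermaid_demo/class_diagram"))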
View file

@ -6,9 +6,9 @@
@File : mermaid.py
"""
import base64
import os
from aiohttp import ClientError, ClientSession
from metagpt.logs import logger
@ -29,7 +29,7 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix):
async with session.get(url) as response:
if response.status == 200:
text = await response.content.read()
with open(output_file, "wb") as f:
f.write(text)
logger.info(f"Generating {output_file}..")
else:

View file

@ -8,10 +8,13 @@
import os
from urllib.parse import urljoin
from playwright.async_api import async_playwright
from metagpt.logs import logger
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
"""
Converts the given Mermaid code to various output formats and saves them to files.
@ -24,66 +27,72 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
Returns:
int: Returns 1 if the conversion and saving were successful, -1 otherwise.
"""
suffixes = ["png", "svg", "pdf"]
__dirname = os.path.dirname(os.path.abspath(__file__))
async with async_playwright() as p:
browser = await p.chromium.launch()
device_scale_factor = 1.0
context = await browser.new_context(
viewport={"width": width, "height": height},
device_scale_factor=device_scale_factor,
)
page = await context.new_page()
async def console_message(msg):
logger.info(msg.text)
page.on("console", console_message)
try:
await page.set_viewport_size({"width": width, "height": height})
mermaid_html_path = os.path.abspath(os.path.join(__dirname, "index.html"))
mermaid_html_url = urljoin("file:", mermaid_html_path)
await page.goto(mermaid_html_url)
await page.wait_for_load_state("networkidle")
await page.wait_for_selector("div#container", state="attached")
# mermaid_config = {}
background_color = "#ffffff"
# my_css = ""
await page.evaluate(f'document.body.style.background = "{background_color}";')
# metadata = await page.evaluate(
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
# const { mermaid, zenuml } = globalThis;
# await mermaid.registerExternalDiagrams([zenuml]);
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
# document.getElementById('container').innerHTML = svg;
# const svgElement = document.querySelector('svg');
# svgElement.style.backgroundColor = backgroundColor;
#
# if (myCSS) {
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
# style.appendChild(document.createTextNode(myCSS));
# svgElement.appendChild(style);
# }
#
# }""",
# [mermaid_code, mermaid_config, my_css, background_color],
# )
if "svg" in suffixes:
svg_xml = await page.evaluate(
"""() => {
const svg = document.querySelector('svg');
const xmlSerializer = new XMLSerializer();
return xmlSerializer.serializeToString(svg);
}"""
)
logger.info(f"Generating {output_file_without_suffix}.svg..")
with open(f"{output_file_without_suffix}.svg", "wb") as f:
f.write(svg_xml.encode("utf-8"))
if "png" in suffixes:
clip = await page.evaluate(
"""() => {
const svg = document.querySelector('svg');
const rect = svg.getBoundingClientRect();
return {
@ -92,16 +101,17 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
width: Math.ceil(rect.width),
height: Math.ceil(rect.height)
};
}"""
)
await page.set_viewport_size({"width": clip["x"] + clip["width"], "height": clip["y"] + clip["height"]})
screenshot = await page.screenshot(clip=clip, omit_background=True, scale="device")
logger.info(f"Generating {output_file_without_suffix}.png..")
with open(f'{output_file_without_suffix}.png', 'wb') as f:
with open(f"{output_file_without_suffix}.png", "wb") as f:
f.write(screenshot)
if 'pdf' in suffixes:
if "pdf" in suffixes:
pdf_data = await page.pdf(scale=device_scale_factor)
logger.info(f"Generating {output_file_without_suffix}.pdf..")
with open(f'{output_file_without_suffix}.pdf', 'wb') as f:
with open(f"{output_file_without_suffix}.pdf", "wb") as f:
f.write(pdf_data)
return 0
except Exception as e:
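For orientation, a minimal driver for the converter in this file; the diagram text and the "demo" output name are hypothetical, and mermaid_to_file is assumed importable from this module:

import asyncio

flowchart = """flowchart LR
    A[Start] --> B{OK?}
    B -->|yes| C[Done]
"""

async def main():
    # On success this writes demo.svg, demo.png and demo.pdf in the working directory.
    ret = await mermaid_to_file(flowchart, "demo", width=1024, height=768)
    print("mermaid_to_file returned", ret)  # 0 on success, -1 on failure

asyncio.run(main())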

View file

@ -7,11 +7,14 @@
"""
import os
from urllib.parse import urljoin
from pyppeteer import launch
from metagpt.logs import logger
from metagpt.config import CONFIG
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048)-> int:
from pyppeteer import launch
from metagpt.config import CONFIG
from metagpt.logs import logger
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
"""
Converts the given Mermaid code to various output formats and saves them to files.
@ -24,15 +27,15 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
Returns:
int: Returns 0 if the conversion and saving were successful, -1 otherwise.
"""
suffixes = ['png', 'svg', 'pdf']
suffixes = ["png", "svg", "pdf"]
__dirname = os.path.dirname(os.path.abspath(__file__))
if CONFIG.pyppeteer_executable_path:
browser = await launch(headless=True,
executablePath=CONFIG.pyppeteer_executable_path,
args=['--disable-extensions',"--no-sandbox"]
)
browser = await launch(
headless=True,
executablePath=CONFIG.pyppeteer_executable_path,
args=["--disable-extensions", "--no-sandbox"],
)
else:
logger.error("Please set the environment variable:PYPPETEER_EXECUTABLE_PATH.")
return -1
@ -41,50 +44,56 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
async def console_message(msg):
logger.info(msg.text)
page.on('console', console_message)
page.on("console", console_message)
try:
await page.setViewport(viewport={'width': width, 'height': height, 'deviceScaleFactor': device_scale_factor})
await page.setViewport(viewport={"width": width, "height": height, "deviceScaleFactor": device_scale_factor})
mermaid_html_path = os.path.abspath(
os.path.join(__dirname, 'index.html'))
mermaid_html_url = urljoin('file:', mermaid_html_path)
mermaid_html_path = os.path.abspath(os.path.join(__dirname, "index.html"))
mermaid_html_url = urljoin("file:", mermaid_html_path)
await page.goto(mermaid_html_url)
await page.querySelector("div#container")
mermaid_config = {}
# mermaid_config = {}
background_color = "#ffffff"
my_css = ""
# my_css = ""
await page.evaluate(f'document.body.style.background = "{background_color}";')
metadata = await page.evaluate('''async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
const { mermaid, zenuml } = globalThis;
await mermaid.registerExternalDiagrams([zenuml]);
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
document.getElementById('container').innerHTML = svg;
const svgElement = document.querySelector('svg');
svgElement.style.backgroundColor = backgroundColor;
# metadata = await page.evaluate(
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
# const { mermaid, zenuml } = globalThis;
# await mermaid.registerExternalDiagrams([zenuml]);
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
# document.getElementById('container').innerHTML = svg;
# const svgElement = document.querySelector('svg');
# svgElement.style.backgroundColor = backgroundColor;
#
# if (myCSS) {
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
# style.appendChild(document.createTextNode(myCSS));
# svgElement.appendChild(style);
# }
# }""",
# [mermaid_code, mermaid_config, my_css, background_color],
# )
if (myCSS) {
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
style.appendChild(document.createTextNode(myCSS));
svgElement.appendChild(style);
}
}''', [mermaid_code, mermaid_config, my_css, background_color])
if 'svg' in suffixes :
svg_xml = await page.evaluate('''() => {
if "svg" in suffixes:
svg_xml = await page.evaluate(
"""() => {
const svg = document.querySelector('svg');
const xmlSerializer = new XMLSerializer();
return xmlSerializer.serializeToString(svg);
}''')
}"""
)
logger.info(f"Generating {output_file_without_suffix}.svg..")
with open(f'{output_file_without_suffix}.svg', 'wb') as f:
f.write(svg_xml.encode('utf-8'))
with open(f"{output_file_without_suffix}.svg", "wb") as f:
f.write(svg_xml.encode("utf-8"))
if 'png' in suffixes:
clip = await page.evaluate('''() => {
if "png" in suffixes:
clip = await page.evaluate(
"""() => {
const svg = document.querySelector('svg');
const rect = svg.getBoundingClientRect();
return {
@ -93,16 +102,23 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
width: Math.ceil(rect.width),
height: Math.ceil(rect.height)
};
}''')
await page.setViewport({'width': clip['x'] + clip['width'], 'height': clip['y'] + clip['height'], 'deviceScaleFactor': device_scale_factor})
screenshot = await page.screenshot(clip=clip, omit_background=True, scale='device')
}"""
)
await page.setViewport(
{
"width": clip["x"] + clip["width"],
"height": clip["y"] + clip["height"],
"deviceScaleFactor": device_scale_factor,
}
)
screenshot = await page.screenshot(clip=clip, omit_background=True, scale="device")
logger.info(f"Generating {output_file_without_suffix}.png..")
with open(f'{output_file_without_suffix}.png', 'wb') as f:
with open(f"{output_file_without_suffix}.png", "wb") as f:
f.write(screenshot)
if 'pdf' in suffixes:
if "pdf" in suffixes:
pdf_data = await page.pdf(scale=device_scale_factor)
logger.info(f"Generating {output_file_without_suffix}.pdf..")
with open(f'{output_file_without_suffix}.pdf', 'wb') as f:
with open(f"{output_file_without_suffix}.pdf", "wb") as f:
f.write(pdf_data)
return 0
except Exception as e:
@ -110,4 +126,3 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
return -1
finally:
await browser.close()

View file

@ -5,7 +5,7 @@ from typing import Generator, Optional
from urllib.parse import urljoin, urlparse
from bs4 import BeautifulSoup
from pydantic import BaseModel
from pydantic import BaseModel, PrivateAttr
class WebPage(BaseModel):
@ -13,18 +13,15 @@ class WebPage(BaseModel):
html: str
url: str
class Config:
underscore_attrs_are_private = True
_soup : Optional[BeautifulSoup] = None
_title: Optional[str] = None
_soup: Optional[BeautifulSoup] = PrivateAttr(default=None)
_title: Optional[str] = PrivateAttr(default=None)
@property
def soup(self) -> BeautifulSoup:
if self._soup is None:
self._soup = BeautifulSoup(self.html, "html.parser")
return self._soup
@property
def title(self):
if self._title is None:
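The hunk above replaces pydantic v1's `underscore_attrs_are_private` Config switch with explicit v2 `PrivateAttr` declarations. A minimal standalone sketch of the same lazy-cache pattern (the `Page` class here is illustrative, not part of the codebase):

from typing import Optional

from pydantic import BaseModel, PrivateAttr

class Page(BaseModel):
    html: str
    _length: Optional[int] = PrivateAttr(default=None)  # private: excluded from the schema

    @property
    def length(self) -> int:
        if self._length is None:
            self._length = len(self.html)  # computed once, then memoized
        return self._length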

View file

@ -37,18 +37,26 @@ def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine:
if not isinstance(expr, cst.Expr):
return None
val = expr.value
if not isinstance(val, (cst.SimpleString, cst.ConcatenatedString)):
return None
evaluated_value = val.evaluated_value
evaluated_value = val.evaluated_value
if isinstance(evaluated_value, bytes):
return None
return statement
def has_decorator(node: DocstringNode, name: str) -> bool:
return hasattr(node, "decorators") and any(
(hasattr(i.decorator, "value") and i.decorator.value == name)
or (hasattr(i.decorator, "func") and hasattr(i.decorator.func, "value") and i.decorator.func.value == name)
for i in node.decorators
)
class DocstringCollector(cst.CSTVisitor):
"""A visitor class for collecting docstrings from a CST.
@ -56,6 +64,7 @@ class DocstringCollector(cst.CSTVisitor):
stack: A list to keep track of the current path in the CST.
docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
"""
def __init__(self):
self.stack: list[str] = []
self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}
@ -81,7 +90,7 @@ class DocstringCollector(cst.CSTVisitor):
def _leave(self, node: DocstringNode) -> None:
key = tuple(self.stack)
self.stack.pop()
if hasattr(node, "decorators") and any(i.decorator.value == "overload" for i in node.decorators):
if has_decorator(node, "overload"):
return
statement = get_docstring_statement(node)
@ -96,6 +105,7 @@ class DocstringTransformer(cst.CSTTransformer):
stack: A list to keep track of the current path in the CST.
docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
"""
def __init__(
self,
docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
@ -125,7 +135,7 @@ class DocstringTransformer(cst.CSTTransformer):
key = tuple(self.stack)
self.stack.pop()
if hasattr(updated_node, "decorators") and any((i.decorator.value == "overload") for i in updated_node.decorators):
if has_decorator(updated_node, "overload"):
return updated_node
statement = self.docstrings.get(key)
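The refactor above centralizes the decorator check in `has_decorator`, whose `func.value` branch also matches call-style decorators such as `@overload()`. A quick libcst check, assuming `has_decorator` is importable:

import libcst as cst

module = cst.parse_module(
    "from typing import overload\n"
    "@overload\n"
    "def f(x: int) -> int: ...\n"
)
func_def = module.body[-1]  # the FunctionDef node
print(has_decorator(func_def, "overload"))  # True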

View file

@ -8,6 +8,7 @@
import docx
def read_docx(file_path: str) -> list:
"""Open a docx file"""
doc = docx.Document(file_path)

metagpt/utils/redis.py Normal file
View file

@ -0,0 +1,79 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/27
@Author : mashenquan
@File : redis.py
"""
from __future__ import annotations
import traceback
from datetime import timedelta
import aioredis # https://aioredis.readthedocs.io/en/latest/getting-started/
from metagpt.config import CONFIG
from metagpt.logs import logger
class Redis:
def __init__(self):
self._client = None
async def _connect(self, force=False):
if self._client and not force:
return True
if not self.is_configured:
return False
try:
self._client = await aioredis.from_url(
f"redis://{CONFIG.REDIS_HOST}:{CONFIG.REDIS_PORT}",
username=CONFIG.REDIS_USER,
password=CONFIG.REDIS_PASSWORD,
db=CONFIG.REDIS_DB,
)
return True
except Exception as e:
logger.warning(f"Redis initialization has failed:{e}")
return False
async def get(self, key: str) -> bytes | None:
if not await self._connect() or not key:
return None
try:
v = await self._client.get(key)
return v
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")
return None
async def set(self, key: str, data: str, timeout_sec: int = None):
if not await self._connect() or not key:
return
try:
ex = None if not timeout_sec else timedelta(seconds=timeout_sec)
await self._client.set(key, data, ex=ex)
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")
async def close(self):
if not self._client:
return
await self._client.close()
self._client = None
@property
def is_valid(self) -> bool:
return self._client is not None
@property
def is_configured(self) -> bool:
return bool(
CONFIG.REDIS_HOST
and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
and CONFIG.REDIS_PORT
and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
and CONFIG.REDIS_DB is not None
and CONFIG.REDIS_PASSWORD is not None
)
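A hedged usage sketch of this wrapper; every call degrades to a no-op (or None) when the REDIS_* CONFIG fields are unset:

import asyncio

async def demo():
    cache = Redis()
    await cache.set("greeting", "hello", timeout_sec=60)  # no-op if unconfigured
    value = await cache.get("greeting")  # bytes on a hit, None otherwise
    print(value)
    await cache.close()

asyncio.run(demo())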

View file

@ -0,0 +1,314 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : repair llm raw output with particular conditions
import copy
from enum import Enum
from typing import Callable, Union
import regex as re
from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.utils.custom_decoder import CustomDecoder
class RepairType(Enum):
CS = "case sensitivity"
RKPM = "required key pair missing" # condition like `[key] xx` which lacks `[/key]`
SCM = "special character missing" # Usually the req_key appear in pairs like `[key] xx [/key]`
JSON = "json format"
def repair_case_sensitivity(output: str, req_key: str) -> str:
"""
Usually, req_key is the key name of the expected json or markdown content; it won't appear in the value part.
Fixes the case where the target is `"Shared Knowledge": ""` but the output actually contains `"Shared knowledge": ""`.
"""
if req_key in output:
return output
output_lower = output.lower()
req_key_lower = req_key.lower()
if req_key_lower in output_lower:
# find the sub-part index, and replace it with raw req_key
lidx = output_lower.find(req_key_lower)
source = output[lidx : lidx + len(req_key_lower)]
output = output.replace(source, req_key)
logger.info(f"repair_case_sensitivity: {req_key}")
return output
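For example, repairing a response in which the model lowercased a key:

out = repair_case_sensitivity('"Original requirements": []', "Original Requirements")
assert out == '"Original Requirements": []'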
def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str:
"""
Fixes
1. target string `[CONTENT] xx [CONTENT] xxx [CONTENT]`, which lacks `/` in the last `[CONTENT]`
2. target string `xx [CONTENT] xxx [CONTENT] xxxx`, which lacks `/` in the last `[CONTENT]`
"""
sc_arr = ["/"]
if req_key in output:
return output
for sc in sc_arr:
req_key_pure = req_key.replace(sc, "")
appear_cnt = output.count(req_key_pure)
if req_key_pure in output and appear_cnt > 1:
# the req_key variant with the special character usually appears at the tail
ridx = output.rfind(req_key_pure)
output = f"{output[:ridx]}{req_key}{output[ridx + len(req_key_pure):]}"
logger.info(f"repair_special_character_missing: {sc} in {req_key_pure} as position {ridx}")
return output
def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str:
"""
Insert the missing half of the req_key pair at the beginning or end of the content.
req_key format
1. `[req_key]`, and its pair `[/req_key]`
2. `[/req_key]`, and its pair `[req_key]`
"""
sc = "/" # special char
if req_key.startswith("[") and req_key.endswith("]"):
if sc in req_key:
left_key = req_key.replace(sc, "") # `[/req_key]` -> `[req_key]`
right_key = req_key
else:
left_key = req_key
right_key = f"{req_key[0]}{sc}{req_key[1:]}" # `[req_key]` -> `[/req_key]`
if left_key not in output:
output = left_key + "\n" + output
if right_key not in output:
def judge_potential_json(routput: str, left_key: str) -> Union[str, None]:
ridx = routput.rfind(left_key)
if ridx < 0:
return None
sub_output = routput[ridx:]
idx1 = sub_output.rfind("}")
idx2 = sub_output.rindex("]")
idx = idx1 if idx1 >= idx2 else idx2
sub_output = sub_output[: idx + 1]
return sub_output
if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)):
# # avoid appending [/req_key] in the `[req_key]xx[req_key]` case
output = output + "\n" + right_key
elif judge_potential_json(output, left_key) and (not output.strip().endswith(left_key)):
sub_content = judge_potential_json(output, left_key)
output = sub_content + "\n" + right_key
return output
def repair_json_format(output: str) -> str:
"""
Fix a stray `[` at the start or `]` at the end.
"""
output = output.strip()
if output.startswith("[{"):
output = output[1:]
logger.info(f"repair_json_format: {'[{'}")
elif output.endswith("}]"):
output = output[:-1]
logger.info(f"repair_json_format: {'}]'}")
elif output.startswith("{") and output.endswith("]"):
output = output[:-1] + "}"
return output
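Concretely, the three branches cover:

assert repair_json_format('[{"a": 1}') == '{"a": 1}'  # stray leading `[`
assert repair_json_format('{"a": 1}]') == '{"a": 1}'  # stray trailing `]`
assert repair_json_format('{"a": 1]') == '{"a": 1}'   # `]` closing a `{`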
def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = None) -> str:
repair_types = [repair_type] if repair_type else [item for item in RepairType if item not in [RepairType.JSON]]
for repair_type in repair_types:
if repair_type == RepairType.CS:
output = repair_case_sensitivity(output, req_key)
elif repair_type == RepairType.RKPM:
output = repair_required_key_pair_missing(output, req_key)
elif repair_type == RepairType.SCM:
output = repair_special_character_missing(output, req_key)
elif repair_type == RepairType.JSON:
output = repair_json_format(output)
return output
def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str:
"""
Open-source LLMs often can't follow instructions well and their output may be incomplete,
so we try to repair it here, applying all repair methods by default.
Typical cases
1. case sensitivity
target: "Original Requirements"
output: "Original requirements"
2. special character missing
target: [/CONTENT]
output: [CONTENT]
3. json format
target: { xxx }
output: { xxx }]
"""
if not CONFIG.repair_llm_output:
return output
# do the repair, usually needed for non-openai models
for req_key in req_keys:
output = _repair_llm_raw_output(output=output, req_key=req_key, repair_type=repair_type)
return output
def repair_invalid_json(output: str, error: str) -> str:
"""
Repair invalid json text using the line number reported in the decode error.
Error examples
example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765)
example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266)
"""
pattern = r"line ([0-9]+)"
matches = re.findall(pattern, error, re.DOTALL)
if len(matches) > 0:
line_no = int(matches[0]) - 1
# since CustomDecoder can handle `"": ''` or `'': ""`, convert `"""` -> `"` and `'''` -> `"`
output = output.replace('"""', '"').replace("'''", '"')
arr = output.split("\n")
line = arr[line_no].strip()
# different general problems
if line.endswith("],"):
# problem, redundant char `]`
new_line = line.replace("]", "")
elif line.endswith("},") and not output.endswith("},"):
# problem, redundant char `}`
new_line = line.replace("}", "")
elif line.endswith("},") and output.endswith("},"):
new_line = line[:-1]
elif '",' not in line and "," not in line:
new_line = f'{line}",'
elif "," not in line:
# problem, missing char `,` at the end.
new_line = f"{line},"
elif "," in line and len(line) == 1:
new_line = f'"{line}'
elif '",' in line:
new_line = line[:-2] + "',"
else:
new_line = line
arr[line_no] = new_line
output = "\n".join(arr)
logger.info(f"repair_invalid_json, raw error: {error}")
return output
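For instance, a stray `]` on the line the JSONDecodeError points at gets dropped:

bad = '{\n"a": "x"],\n"b": "y"\n}'
err = "Expecting ',' delimiter: line 2 column 10 (char 11)"
print(repair_invalid_json(bad, err))  # line 2 becomes '"a": "x",'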
def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]:
def run_and_passon(retry_state: RetryCallState) -> None:
"""
RetryCallState example
{
"start_time":143.098322024,
"retry_object":"<Retrying object at 0x7fabcaca25e0 (stop=<tenacity.stop.stop_after_attempt ... >)>",
"fn":"<function retry_parse_json_text_v2 at 0x7fabcac80ee0>",
"args":"(\"tag:[/CONTENT]\",)", # function input args
"kwargs":{}, # function input kwargs
"attempt_number":1, # retry number
"outcome":"<Future at xxx>", # type(outcome.result()) = "str", type(outcome.exception()) = "class"
"outcome_timestamp":143.098416904,
"idle_for":0,
"next_action":"None"
}
"""
if retry_state.outcome.failed:
if retry_state.args:
# # can't be used as args=retry_state.args
func_param_output = retry_state.args[0]
elif retry_state.kwargs:
func_param_output = retry_state.kwargs.get("output", "")
exp_str = str(retry_state.outcome.exception())
fix_str = "try to fix it, " if CONFIG.repair_llm_output else ""
logger.warning(
f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}"
)
repaired_output = repair_invalid_json(func_param_output, exp_str)
retry_state.kwargs["output"] = repaired_output
return run_and_passon
@retry(
stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0),
wait=wait_fixed(1),
after=run_after_exp_and_passon_next_retry(logger),
)
def retry_parse_json_text(output: str) -> Union[list, dict]:
"""
Repair json text that contains extra chars such as `]` or `}`.
Warning
if CONFIG.repair_llm_output is False, _aask_v1 retries {x=3} times and retry_parse_json_text's retry does not run;
if CONFIG.repair_llm_output is True, _aask_v1 and retry_parse_json_text together loop for {x=3*3} times.
It's a two-layer retry cycle.
"""
# logger.debug(f"output to json decode:\n{output}")
# if CONFIG.repair_llm_output is True, it will try to fix the output until the retries are exhausted
parsed_data = CustomDecoder(strict=False).decode(output)
return parsed_data
def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"):
"""extract xxx from [CONTENT](xxx)[/CONTENT] using regex pattern"""
def re_extract_content(cont: str, pattern: str) -> str:
matches = re.findall(pattern, cont, re.DOTALL)
for match in matches:
if match:
cont = match
break
return cont.strip()
# TODO construct the extract pattern with the `right_key`
raw_content = copy.deepcopy(content)
pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]"
new_content = re_extract_content(raw_content, pattern)
if not new_content.startswith("{"):
# TODO find a more general pattern
# # for the `[CONTENT]xxx[CONTENT]xxxx[/CONTENT]` situation
logger.warning(f"extract_content try another pattern: {pattern}")
if right_key not in new_content:
raw_content = copy.deepcopy(new_content + "\n" + right_key)
# # pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]"
new_content = re_extract_content(raw_content, pattern)
else:
if right_key in new_content:
idx = new_content.find(right_key)
new_content = new_content[:idx]
new_content = new_content.strip()
return new_content
def extract_state_value_from_output(content: str) -> str:
"""
OpenAI models always return the state number directly, but for open llm models the result may be a
long text containing the target number, so an extraction step is added here to improve the success rate.
Args:
content (str): llm's output from `Role._think`
"""
content = content.strip()  # handle outputs like " 0", "0\n", and so on.
pattern = r"([0-9])"  # TODO: locate the number more robustly instead of grabbing any digit via regex
matches = re.findall(pattern, content, re.DOTALL)
matches = list(set(matches))
state = matches[0] if len(matches) > 0 else "-1"
return state
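For example:

assert extract_state_value_from_output(" 2\n") == "2"
assert extract_state_value_from_output("The next state should be 3.") == "3"
assert extract_state_value_from_output("no digit at all") == "-1"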

metagpt/utils/s3.py Normal file
View file

@ -0,0 +1,170 @@
import base64
import os.path
import traceback
import uuid
from pathlib import Path
from typing import Optional
import aioboto3
import aiofiles
from metagpt.config import CONFIG
from metagpt.const import BASE64_FORMAT
from metagpt.logs import logger
class S3:
"""A class for interacting with Amazon S3 storage."""
def __init__(self):
self.session = aioboto3.Session()
self.auth_config = {
"service_name": "s3",
"aws_access_key_id": CONFIG.S3_ACCESS_KEY,
"aws_secret_access_key": CONFIG.S3_SECRET_KEY,
"endpoint_url": CONFIG.S3_ENDPOINT_URL,
}
async def upload_file(
self,
bucket: str,
local_path: str,
object_name: str,
) -> None:
"""Upload a file from the local path to the specified path of the storage bucket specified in s3.
Args:
bucket: The name of the S3 storage bucket.
local_path: The local file path, including the file name.
object_name: The complete path of the uploaded file to be stored in S3, including the file name.
Raises:
Exception: If an error occurs during the upload process, an exception is raised.
"""
try:
async with self.session.client(**self.auth_config) as client:
async with aiofiles.open(local_path, mode="rb") as reader:
body = await reader.read()
await client.put_object(Body=body, Bucket=bucket, Key=object_name)
logger.info(f"Successfully uploaded the file to path {object_name} in bucket {bucket} of s3.")
except Exception as e:
logger.error(f"Failed to upload the file to path {object_name} in bucket {bucket} of s3: {e}")
raise e
async def get_object_url(
self,
bucket: str,
object_name: str,
) -> str:
"""Get the URL for a downloadable or preview file stored in the specified S3 bucket.
Args:
bucket: The name of the S3 storage bucket.
object_name: The complete path of the file stored in S3, including the file name.
Returns:
The URL for the downloadable or preview file.
Raises:
Exception: If an error occurs while retrieving the URL, an exception is raised.
"""
try:
async with self.session.client(**self.auth_config) as client:
file = await client.get_object(Bucket=bucket, Key=object_name)
return str(file["Body"].url)
except Exception as e:
logger.error(f"Failed to get the url for a downloadable or preview file: {e}")
raise e
async def get_object(
self,
bucket: str,
object_name: str,
) -> bytes:
"""Get the binary data of a file stored in the specified S3 bucket.
Args:
bucket: The name of the S3 storage bucket.
object_name: The complete path of the file stored in S3, including the file name.
Returns:
The binary data of the requested file.
Raises:
Exception: If an error occurs while retrieving the file data, an exception is raised.
"""
try:
async with self.session.client(**self.auth_config) as client:
s3_object = await client.get_object(Bucket=bucket, Key=object_name)
return await s3_object["Body"].read()
except Exception as e:
logger.error(f"Failed to get the binary data of the file: {e}")
raise e
async def download_file(
self, bucket: str, object_name: str, local_path: str, chunk_size: Optional[int] = 128 * 1024
) -> None:
"""Download an S3 object to a local file.
Args:
bucket: The name of the S3 storage bucket.
object_name: The complete path of the file stored in S3, including the file name.
local_path: The local file path where the S3 object will be downloaded.
chunk_size: The size of data chunks to read and write at a time. Default is 128 KB.
Raises:
Exception: If an error occurs during the download process, an exception is raised.
"""
try:
async with self.session.client(**self.auth_config) as client:
s3_object = await client.get_object(Bucket=bucket, Key=object_name)
stream = s3_object["Body"]
async with aiofiles.open(local_path, mode="wb") as writer:
while True:
file_data = await stream.read(chunk_size)
if not file_data:
break
await writer.write(file_data)
except Exception as e:
logger.error(f"Failed to download the file from S3: {e}")
raise e
async def cache(self, data: str, file_ext: str, format: str = "") -> str:
"""Save data to remote S3 and return url"""
object_name = uuid.uuid4().hex + file_ext
path = Path(__file__).parent
pathname = path / object_name
try:
async with aiofiles.open(str(pathname), mode="wb") as file:
data = base64.b64decode(data) if format == BASE64_FORMAT else data.encode(encoding="utf-8")
await file.write(data)
bucket = CONFIG.S3_BUCKET
object_pathname = CONFIG.S3_BUCKET or "system"
object_pathname += f"/{object_name}"
object_pathname = os.path.normpath(object_pathname)
await self.upload_file(bucket=bucket, local_path=str(pathname), object_name=object_pathname)
pathname.unlink(missing_ok=True)
return await self.get_object_url(bucket=bucket, object_name=object_pathname)
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")
pathname.unlink(missing_ok=True)
return None
@property
def is_valid(self):
return self.is_configured
@property
def is_configured(self) -> bool:
return bool(
CONFIG.S3_ACCESS_KEY
and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
and CONFIG.S3_SECRET_KEY
and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
and CONFIG.S3_ENDPOINT_URL
and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
and CONFIG.S3_BUCKET
and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
)
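A hedged usage sketch; the bucket and object paths below are hypothetical, and the S3_* CONFIG fields must be populated:

import asyncio

async def demo():
    s3 = S3()
    if not s3.is_configured:
        return
    await s3.upload_file(bucket="my-bucket", local_path="/tmp/report.pdf", object_name="reports/report.pdf")
    url = await s3.get_object_url(bucket="my-bucket", object_name="reports/report.pdf")
    print(url)

asyncio.run(demo())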

View file

@ -4,13 +4,11 @@
import copy
import pickle
from typing import Dict, List
from metagpt.actions.action_output import ActionOutput
from metagpt.schema import Message
from metagpt.utils.common import import_class
def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
def actionoutout_schema_to_mapping(schema: dict) -> dict:
"""
Directly traverse the `properties` at the first level.
The schema structure looks like
@ -35,32 +33,50 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
if property["type"] == "string":
mapping[field] = (str, ...)
elif property["type"] == "array" and property["items"]["type"] == "string":
mapping[field] = (List[str], ...)
mapping[field] = (list[str], ...)
elif property["type"] == "array" and property["items"]["type"] == "array":
# here only consider the `List[List[str]]` situation
mapping[field] = (List[List[str]], ...)
# here only consider the `list[list[str]]` situation
mapping[field] = (list[list[str]], ...)
return mapping
def serialize_message(message: Message):
def actionoutput_mapping_to_str(mapping: dict) -> dict:
new_mapping = {}
for key, value in mapping.items():
new_mapping[key] = str(value)
return new_mapping
def actionoutput_str_to_mapping(mapping: dict) -> dict:
new_mapping = {}
for key, value in mapping.items():
if value == "(<class 'str'>, Ellipsis)":
new_mapping[key] = (str, ...)
else:
new_mapping[key] = eval(value)  # `"(list[str], Ellipsis)"` -> `(list[str], ...)`
return new_mapping
def serialize_message(message: "Message"):
message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference
ic = message_cp.instruct_content
if ic:
# models created via pydantic's create_model (e.g. `pydantic.main.prd`) can't be pickled directly
schema = ic.schema()
schema = ic.model_json_schema()
mapping = actionoutout_schema_to_mapping(schema)
message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
msg_ser = pickle.dumps(message_cp)
return msg_ser
def deserialize_message(message_ser: str) -> Message:
def deserialize_message(message_ser: str) -> "Message":
message = pickle.loads(message_ser)
if message.instruct_content:
ic = message.instruct_content
ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
ic_new = ic_obj(**ic["value"])
message.instruct_content = ic_new
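A round-trip sketch; `Message` comes from metagpt.schema, and deserialize_message presumably returns the rebuilt message (the tail of the hunk is truncated here). A plain message has no instruct_content, so only the pickle path is exercised:

from metagpt.schema import Message

msg = Message(content="hello")
restored = deserialize_message(serialize_message(msg))
print(restored.content)  # "hello"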

View file

@ -20,4 +20,3 @@ class Singleton(abc.ABCMeta, type):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
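Usage is the standard metaclass pattern; every call site receives the same cached instance (the `AppConfig` class is illustrative):

class AppConfig(metaclass=Singleton):
    def __init__(self):
        self.settings = {}

assert AppConfig() is AppConfig()  # __call__ returns the cached instance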

View file

@ -1,4 +1,4 @@
# token to separate different code messages in a WriteCode Message content
MSG_SEP = "#*000*#"
MSG_SEP = "#*000*#"
# token to separate the file name from the actual code text in a code message
FILENAME_CODE_SEP = "#*001*#"

View file

@ -3,7 +3,12 @@ from typing import Generator, Sequence
from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens
def reduce_message_length(msgs: Generator[str, None, None], model_name: str, system_text: str, reserved: int = 0,) -> str:
def reduce_message_length(
msgs: Generator[str, None, None],
model_name: str,
system_text: str,
reserved: int = 0,
) -> str:
"""Reduce the length of concatenated message segments to fit within the maximum token size.
Args:
@ -49,9 +54,9 @@ def generate_prompt_chunk(
current_token = 0
current_lines = []
reserved = reserved + count_string_tokens(prompt_template+system_text, model_name)
reserved = reserved + count_string_tokens(prompt_template + system_text, model_name)
# 100 is a magic number to ensure the maximum context length is not exceeded
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
while paragraphs:
paragraph = paragraphs.pop(0)
@ -103,7 +108,7 @@ def decode_unicode_escape(text: str) -> str:
return text.encode("utf-8").decode("unicode_escape", "ignore")
def _split_by_count(lst: Sequence , count: int):
def _split_by_count(lst: Sequence, count: int):
avg = len(lst) // count
remainder = len(lst) % count
start = 0

View file

@ -7,6 +7,7 @@
ref1: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
ref2: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/token_counter.py
ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py
ref4: https://ai.google.dev/models/gemini
"""
import tiktoken
@ -16,13 +17,18 @@ TOKEN_COSTS = {
"gpt-3.5-turbo-0613": {"prompt": 0.0015, "completion": 0.002},
"gpt-3.5-turbo-16k": {"prompt": 0.003, "completion": 0.004},
"gpt-3.5-turbo-16k-0613": {"prompt": 0.003, "completion": 0.004},
"gpt-35-turbo": {"prompt": 0.0015, "completion": 0.002},
"gpt-35-turbo-16k": {"prompt": 0.003, "completion": 0.004},
"gpt-3.5-turbo-1106": {"prompt": 0.001, "completion": 0.002},
"gpt-4-0314": {"prompt": 0.03, "completion": 0.06},
"gpt-4": {"prompt": 0.03, "completion": 0.06},
"gpt-4-32k": {"prompt": 0.06, "completion": 0.12},
"gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069} # 32k version, prompt + completion tokens=0.005¥/k-tokens
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
}
@ -32,13 +38,18 @@ TOKEN_MAX = {
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-16k": 16384,
"gpt-3.5-turbo-16k-0613": 16384,
"gpt-35-turbo": 4096,
"gpt-35-turbo-16k": 16384,
"gpt-3.5-turbo-1106": 16384,
"gpt-4-0314": 8192,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0314": 32768,
"gpt-4-0613": 8192,
"gpt-4-1106-preview": 128000,
"text-embedding-ada-002": 8192,
"chatglm_turbo": 32768
"chatglm_turbo": 32768,
"gemini-pro": 32768,
}
@ -52,22 +63,34 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
if model in {
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-35-turbo",
"gpt-35-turbo-16k",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-1106",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-1106-preview",
}:
tokens_per_message = 3
tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|>
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" in model:
elif "gpt-3.5-turbo" == model:
print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
return count_message_tokens(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model:
elif "gpt-4" == model:
print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return count_message_tokens(messages, model="gpt-4-0613")
elif "open-llm-model" == model:
"""
A self-hosted open_llm api can serve many different models, so the message token calculation is
inaccurate; treat the result as a reference value.
"""
tokens_per_message = 0 # ignore conversation message template prefix
tokens_per_name = 0
else:
raise NotImplementedError(
f"num_tokens_from_messages() is not implemented for model {model}. "
@ -96,7 +119,11 @@ def count_string_tokens(string: str, model_name: str) -> int:
Returns:
int: The number of tokens in the text string.
"""
encoding = tiktoken.encoding_for_model(model_name)
try:
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(string))
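With this fallback, model names unknown to tiktoken no longer raise. For example:

count_string_tokens("hello world", "gpt-4")  # uses gpt-4's own encoding
count_string_tokens("hello world", "chatglm_turbo")  # warns, then falls back to cl100k_base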