mirror of https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-30 19:36:24 +02:00

commit 853086924a: Merge branch 'main' into dev_updated
429 changed files with 24237 additions and 5835 deletions

metagpt/utils/ahttp_client.py (new file, 49 lines)
@@ -0,0 +1,49 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc   : pure async http_client

from typing import Any, Mapping, Optional, Union

import aiohttp
from aiohttp.client import DEFAULT_TIMEOUT


async def apost(
    url: str,
    params: Optional[Mapping[str, str]] = None,
    json: Any = None,
    data: Any = None,
    headers: Optional[dict] = None,
    as_json: bool = False,
    encoding: str = "utf-8",
    timeout: int = DEFAULT_TIMEOUT.total,
) -> Union[str, dict]:
    async with aiohttp.ClientSession() as session:
        async with session.post(url=url, params=params, json=json, data=data, headers=headers, timeout=timeout) as resp:
            if as_json:
                data = await resp.json()
            else:
                data = await resp.read()
                data = data.decode(encoding)
    return data


async def apost_stream(
    url: str,
    params: Optional[Mapping[str, str]] = None,
    json: Any = None,
    data: Any = None,
    headers: Optional[dict] = None,
    encoding: str = "utf-8",
    timeout: int = DEFAULT_TIMEOUT.total,
) -> Any:
    """
    usage:
        result = apost_stream(url="xx")
        async for line in result:
            deal_with(line)
    """
    async with aiohttp.ClientSession() as session:
        async with session.post(url=url, params=params, json=json, data=data, headers=headers, timeout=timeout) as resp:
            async for line in resp.content:
                yield line.decode(encoding)
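A minimal usage sketch for the new client (the httpbin.org endpoint is illustrative, not part of the commit):

    import asyncio

    from metagpt.utils.ahttp_client import apost

    async def main():
        # POST a JSON body and decode the reply as JSON.
        reply = await apost(url="https://httpbin.org/post", json={"q": "hello"}, as_json=True)
        print(reply["json"])  # httpbin echoes the posted body back under "json"

    asyncio.run(main())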
metagpt/utils/common.py

@@ -4,16 +4,35 @@
@Time    : 2023/4/29 16:07
@Author  : alexanderwu
@File    : common.py
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.2 of RFC 116:
        Add generic class-to-string and object-to-string conversion functionality.
@Modified By: mashenquan, 2023/11/27. Bug fix: `parse_recipient` failed to parse the recipient in certain GPT-3.5
        responses.
"""
from __future__ import annotations

import ast
import contextlib
import importlib
import inspect
import json
import os
import platform
import re
-from typing import List, Tuple, Union
import sys
import traceback
import typing
from pathlib import Path
+from typing import Any, List, Tuple, Union

import aiofiles
import loguru
from pydantic_core import to_jsonable_python
from tenacity import RetryCallState, _utils

from metagpt.const import MESSAGE_ROUTE_TO_ALL
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception


def check_cmd_exists(command) -> int:

@@ -29,6 +48,12 @@ def check_cmd_exists(command) -> int:
    return result


+def require_python_version(req_version: Tuple) -> bool:
+    if not (2 <= len(req_version) <= 3):
+        raise ValueError("req_version should be (3, 9) or (3, 10, 13)")
+    return bool(sys.version_info > req_version)


class OutputParser:
    @classmethod
    def parse_blocks(cls, text: str):

@@ -85,10 +110,7 @@
    @staticmethod
    def parse_python_code(text: str) -> str:
-        for pattern in (
-            r"(.*?```python.*?\s+)?(?P<code>.*)(```.*?)",
-            r"(.*?```python.*?\s+)?(?P<code>.*)",
-        ):
+        for pattern in (r"(.*?```python.*?\s+)?(?P<code>.*)(```.*?)", r"(.*?```python.*?\s+)?(?P<code>.*)"):
            match = re.search(pattern, text, re.DOTALL)
            if not match:
                continue

@@ -109,18 +131,28 @@
            try:
                content = cls.parse_code(text=content)
            except Exception:
                pass
            # Try to parse as a list
            try:
                content = cls.parse_file_list(text=content)
            except Exception:
                pass
            parsed_data[block] = content
        return parsed_data

+    @staticmethod
+    def extract_content(text, tag="CONTENT"):
+        # Use regular expression to extract content between [CONTENT] and [/CONTENT]
+        extracted_content = re.search(rf"\[{tag}\](.*?)\[/{tag}\]", text, re.DOTALL)
+
+        if extracted_content:
+            return extracted_content.group(1).strip()
+        else:
+            raise ValueError(f"Could not find content between [{tag}] and [/{tag}]")
+
    @classmethod
    def parse_data_with_mapping(cls, data, mapping):
+        if "[CONTENT]" in data:
+            data = cls.extract_content(text=data)
        block_dict = cls.parse_blocks(data)
        parsed_data = {}
        for block, content in block_dict.items():

@@ -187,7 +219,7 @@
        result = ast.literal_eval(structure_text)

        # Ensure the result matches the specified data type
-        if isinstance(result, list) or isinstance(result, dict):
+        if isinstance(result, (list, dict)):
            return result

        raise ValueError(f"The extracted structure is not a {data_type}.")

@@ -219,10 +251,15 @@ class CodeParser:
        # Iterate over all blocks
        for block in blocks:
-            # If the block is not empty, keep processing it
-            if block.strip() != "":
+            if block.strip() == "":
+                continue
+            if "\n" not in block:
+                block_title = block
+                block_content = ""
+            else:
                # Split the block into title and content, stripping surrounding whitespace
                block_title, block_content = block.split("\n", 1)
-                block_dict[block_title.strip()] = block_content.strip()
+            block_dict[block_title.strip()] = block_content.strip()

        return block_dict

@@ -282,9 +319,6 @@ class NoMoneyException(Exception):
def print_members(module, indent=0):
    """
    https://stackoverflow.com/questions/1796180/how-can-i-get-a-list-of-all-classes-within-current-module-in-python
-    :param module:
-    :param indent:
-    :return:
    """
    prefix = " " * indent
    for name, obj in inspect.getmembers(module):

@@ -302,9 +336,16 @@ def print_members(module, indent=0):
def parse_recipient(text):
    # FIXME: use ActionNode instead.
    pattern = r"## Send To:\s*([A-Za-z]+)\s*?"  # hard code for now
    recipient = re.search(pattern, text)
-    return recipient.group(1) if recipient else ""
+    if recipient:
+        return recipient.group(1)
+    pattern = r"Send To:\s*([A-Za-z]+)\s*?"
+    recipient = re.search(pattern, text)
+    if recipient:
+        return recipient.group(1)
+    return ""


def create_func_config(func_schema: dict) -> dict:
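Why the two-pass lookup matters (a hedged sketch; the reply strings are illustrative):

    from metagpt.utils.common import parse_recipient

    print(parse_recipient("## Send To: Architect"))  # 'Architect' (matched by the original pattern)
    print(parse_recipient("Send To: Engineer"))      # some GPT-3.5 replies drop the '##'; the fallback now returns 'Engineer'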
@@ -329,3 +370,224 @@ def remove_comments(code_str):
    clean_code = re.sub(pattern, replace_func, code_str, flags=re.MULTILINE)
    clean_code = os.linesep.join([s.rstrip() for s in clean_code.splitlines() if s.strip()])
    return clean_code


def get_class_name(cls) -> str:
    """Return class name"""
    return f"{cls.__module__}.{cls.__name__}"


def any_to_str(val: Any) -> str:
    """Return the class name or the class name of the object, or 'val' if it's a string type."""
    if isinstance(val, str):
        return val
    elif not callable(val):
        return get_class_name(type(val))
    else:
        return get_class_name(val)


def any_to_str_set(val) -> set:
    """Convert any type to string set."""
    res = set()

    # Check if the value is iterable, but not a string (since strings are technically iterable)
    if isinstance(val, (dict, list, set, tuple)):
        # Special handling for dictionaries to iterate over values
        if isinstance(val, dict):
            val = val.values()

        for i in val:
            res.add(any_to_str(i))
    else:
        res.add(any_to_str(val))

    return res


def is_subscribed(message: "Message", tags: set):
    """Return whether it's consumer"""
    if MESSAGE_ROUTE_TO_ALL in message.send_to:
        return True

    for i in tags:
        if i in message.send_to:
            return True
    return False


def any_to_name(val):
    """
    Convert a value to its name by extracting the last part of the dotted path.

    :param val: The value to convert.

    :return: The name of the value.
    """
    return any_to_str(val).split(".")[-1]


def concat_namespace(*args) -> str:
    return ":".join(str(value) for value in args)


def split_namespace(ns_class_name: str) -> List[str]:
    return ns_class_name.split(":")


def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
    """
    Generates a logging function to be used after a call is retried.

    This generated function logs an error message with the outcome of the retried function call. It includes
    the name of the function, the time taken for the call in seconds (formatted according to `sec_format`),
    the number of attempts made, and the exception raised, if any.

    :param i: A Logger instance from the loguru library used to log the error message.
    :param sec_format: A string format specifier for how to format the number of seconds since the start of the call.
        Defaults to three decimal places.
    :return: A callable that accepts a RetryCallState object and returns None. This callable logs the details
        of the retried call.
    """

    def log_it(retry_state: "RetryCallState") -> None:
        # If the function name is not known, default to "<unknown>"
        if retry_state.fn is None:
            fn_name = "<unknown>"
        else:
            # Retrieve the callable's name using a utility function
            fn_name = _utils.get_callback_name(retry_state.fn)

        # Log an error message with the function name, time since start, attempt number, and the exception
        i.error(
            f"Finished call to '{fn_name}' after {sec_format % retry_state.seconds_since_start}(s), "
            f"this was the {_utils.to_ordinal(retry_state.attempt_number)} time calling it. "
            f"exp: {retry_state.outcome.exception()}"
        )

    return log_it


def read_json_file(json_file: str, encoding="utf-8") -> list[Any]:
    if not Path(json_file).exists():
        raise FileNotFoundError(f"json_file: {json_file} does not exist")

    with open(json_file, "r", encoding=encoding) as fin:
        try:
            data = json.load(fin)
        except Exception:
            raise ValueError(f"read json file: {json_file} failed")
    return data


def write_json_file(json_file: str, data: list, encoding=None):
    folder_path = Path(json_file).parent
    if not folder_path.exists():
        folder_path.mkdir(parents=True, exist_ok=True)

    with open(json_file, "w", encoding=encoding) as fout:
        json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python)


def import_class(class_name: str, module_name: str) -> type:
    module = importlib.import_module(module_name)
    a_class = getattr(module, class_name)
    return a_class


def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> object:
    a_class = import_class(class_name, module_name)
    class_inst = a_class(*args, **kwargs)
    return class_inst


def format_trackback_info(limit: int = 2):
    return traceback.format_exc(limit=limit)


def serialize_decorator(func):
    async def wrapper(self, *args, **kwargs):
        try:
            result = await func(self, *args, **kwargs)
            return result
        except KeyboardInterrupt:
            logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}")
        except Exception:
            logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}")
        self.serialize()  # Team.serialize

    return wrapper


def role_raise_decorator(func):
    async def wrapper(self, *args, **kwargs):
        try:
            return await func(self, *args, **kwargs)
        except KeyboardInterrupt as kbi:
            logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project")
            if self.latest_observed_msg:
                self.rc.memory.delete(self.latest_observed_msg)
            # raise again to make it captured outside
            raise Exception(format_trackback_info(limit=None))
        except Exception:
            if self.latest_observed_msg:
                logger.warning(
                    "There is an exception in the role's execution; in order to resume, "
                    "we delete the newest role communication message in the role's memory."
                )
                # remove role's newest observed msg so that it is observed again
                self.rc.memory.delete(self.latest_observed_msg)
            # raise again to make it captured outside
            raise Exception(format_trackback_info(limit=None))

    return wrapper


@handle_exception
async def aread(filename: str | Path, encoding=None) -> str:
    """Read file asynchronously."""
    async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
        content = await reader.read()
    return content


async def awrite(filename: str | Path, data: str, encoding=None):
    """Write file asynchronously."""
    pathname = Path(filename)
    pathname.parent.mkdir(parents=True, exist_ok=True)
    async with aiofiles.open(str(pathname), mode="w", encoding=encoding) as writer:
        await writer.write(data)


async def read_file_block(filename: str | Path, lineno: int, end_lineno: int):
    if not Path(filename).exists():
        return ""
    lines = []
    async with aiofiles.open(str(filename), mode="r") as reader:
        ix = 0
        while ix < end_lineno:
            ix += 1
            line = await reader.readline()
            if ix < lineno:
                continue
            if ix > end_lineno:
                break
            lines.append(line)
    return "".join(lines)


def list_files(root: str | Path) -> List[Path]:
    files = []
    try:
        directory_path = Path(root)
        if not directory_path.exists():
            return []
        for file_path in directory_path.iterdir():
            if file_path.is_file():
                files.append(file_path)
            else:
                subfolder_files = list_files(root=file_path)
                files.extend(subfolder_files)
    except Exception as e:
        logger.error(f"Error: {e}")
    return files
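A quick sketch of the RFC 116 conversion helpers added above (run as a script, so the module resolves to __main__):

    from metagpt.utils.common import any_to_name, any_to_str, any_to_str_set

    class Dummy:
        pass

    print(any_to_str(Dummy))             # '__main__.Dummy'
    print(any_to_str(Dummy()))           # '__main__.Dummy'; instances resolve to their class
    print(any_to_name(Dummy))            # 'Dummy'
    print(any_to_str_set({"k": Dummy}))  # {'__main__.Dummy'}; dicts contribute their values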
metagpt/utils/cost_manager.py (new file, 82 lines)

@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
"""
@Time    : 2023/8/28
@Author  : mashenquan
@File    : cost_manager.py
@Desc    : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting.
"""

from typing import NamedTuple

from pydantic import BaseModel

from metagpt.logs import logger
from metagpt.utils.token_counter import TOKEN_COSTS


class Costs(NamedTuple):
    total_prompt_tokens: int
    total_completion_tokens: int
    total_cost: float
    total_budget: float


class CostManager(BaseModel):
    """Calculate the overhead of using the interface."""

    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    total_budget: float = 0
    max_budget: float = 10.0
    total_cost: float = 0

    def update_cost(self, prompt_tokens, completion_tokens, model):
        """
        Update the total cost, prompt tokens, and completion tokens.

        Args:
            prompt_tokens (int): The number of tokens used in the prompt.
            completion_tokens (int): The number of tokens used in the completion.
            model (str): The model used for the API call.
        """
        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens
        cost = (
            prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
        ) / 1000
        self.total_cost += cost
        logger.info(
            f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | "
            f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
        )

    def get_total_prompt_tokens(self):
        """
        Get the total number of prompt tokens.

        Returns:
            int: The total number of prompt tokens.
        """
        return self.total_prompt_tokens

    def get_total_completion_tokens(self):
        """
        Get the total number of completion tokens.

        Returns:
            int: The total number of completion tokens.
        """
        return self.total_completion_tokens

    def get_total_cost(self):
        """
        Get the total cost of API calls.

        Returns:
            float: The total cost of API calls.
        """
        return self.total_cost

    def get_costs(self) -> Costs:
        """Get all costs"""
        return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
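Prices come from TOKEN_COSTS (per 1K tokens), so accounting is one call per API response. A sketch, assuming "gpt-4" has an entry in TOKEN_COSTS:

    from metagpt.utils.cost_manager import CostManager

    cm = CostManager(max_budget=5.0)
    cm.update_cost(prompt_tokens=1200, completion_tokens=300, model="gpt-4")
    print(cm.get_costs())  # Costs(total_prompt_tokens=1200, total_completion_tokens=300, total_cost=..., total_budget=0)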
@@ -25,7 +25,7 @@ def py_make_scanner(context):
        except IndexError:
            raise StopIteration(idx) from None

-        if nextchar == '"' or nextchar == "'":
+        if nextchar in ("'", '"'):
            if idx + 2 < len(string) and string[idx + 1] == nextchar and string[idx + 2] == nextchar:
                # Handle the case where the next two characters are the same as nextchar
                return parse_string(string, idx + 3, strict, delimiter=nextchar * 3)  # triple quote
metagpt/utils/dependency_file.py (new file, 102 lines)

@@ -0,0 +1,102 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/11/22
@Author  : mashenquan
@File    : dependency_file.py
@Desc    : Implementation of the dependency file described in Section 2.2.3.2 of RFC 135.
"""
from __future__ import annotations

import json
from pathlib import Path
from typing import Set

import aiofiles

from metagpt.utils.common import aread
from metagpt.utils.exceptions import handle_exception


class DependencyFile:
    """A class representing a DependencyFile for managing dependencies.

    :param workdir: The working directory path for the DependencyFile.
    """

    def __init__(self, workdir: Path | str):
        """Initialize a DependencyFile instance.

        :param workdir: The working directory path for the DependencyFile.
        """
        self._dependencies = {}
        self._filename = Path(workdir) / ".dependencies.json"

    async def load(self):
        """Load dependencies from the file asynchronously."""
        if not self._filename.exists():
            return
        self._dependencies = json.loads(await aread(self._filename))

    @handle_exception
    async def save(self):
        """Save dependencies to the file asynchronously."""
        data = json.dumps(self._dependencies)
        async with aiofiles.open(str(self._filename), mode="w") as writer:
            await writer.write(data)

    async def update(self, filename: Path | str, dependencies: Set[Path | str], persist=True):
        """Update dependencies for a file asynchronously.

        :param filename: The filename or path.
        :param dependencies: The set of dependencies.
        :param persist: Whether to persist the changes immediately.
        """
        if persist:
            await self.load()

        root = self._filename.parent
        try:
            key = Path(filename).relative_to(root)
        except ValueError:
            key = filename

        if dependencies:
            relative_paths = []
            for i in dependencies:
                try:
                    relative_paths.append(str(Path(i).relative_to(root)))
                except ValueError:
                    relative_paths.append(str(i))
            self._dependencies[str(key)] = relative_paths
        elif str(key) in self._dependencies:
            del self._dependencies[str(key)]

        if persist:
            await self.save()

    async def get(self, filename: Path | str, persist=True):
        """Get dependencies for a file asynchronously.

        :param filename: The filename or path.
        :param persist: Whether to load dependencies from the file immediately.
        :return: A set of dependencies.
        """
        if persist:
            await self.load()

        root = self._filename.parent
        try:
            key = Path(filename).relative_to(root)
        except ValueError:
            key = filename
        return set(self._dependencies.get(str(key), {}))

    def delete_file(self):
        """Delete the dependency file."""
        self._filename.unlink(missing_ok=True)

    @property
    def exists(self):
        """Check if the dependency file exists."""
        return self._filename.exists()
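A round-trip sketch (the /tmp workdir and file names are illustrative; the directory must exist before save() opens the JSON file):

    import asyncio
    from pathlib import Path

    from metagpt.utils.dependency_file import DependencyFile

    async def main():
        Path("/tmp/demo_repo").mkdir(parents=True, exist_ok=True)
        df = DependencyFile(workdir="/tmp/demo_repo")
        await df.update("docs/prd.md", {"docs/requirement.txt"})  # persists .dependencies.json
        print(await df.get("docs/prd.md"))  # {'docs/requirement.txt'}

    asyncio.run(main())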
metagpt/utils/di_graph_repository.py (new file, 82 lines)

@@ -0,0 +1,82 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/12/19
@Author  : mashenquan
@File    : di_graph_repository.py
@Desc    : Graph repository based on DiGraph
"""
from __future__ import annotations

import json
from pathlib import Path
from typing import List

import networkx

from metagpt.utils.common import aread, awrite
from metagpt.utils.graph_repository import SPO, GraphRepository


class DiGraphRepository(GraphRepository):
    def __init__(self, name: str, **kwargs):
        super().__init__(name=name, **kwargs)
        self._repo = networkx.DiGraph()

    async def insert(self, subject: str, predicate: str, object_: str):
        self._repo.add_edge(subject, object_, predicate=predicate)

    async def upsert(self, subject: str, predicate: str, object_: str):
        pass

    async def update(self, subject: str, predicate: str, object_: str):
        pass

    async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
        result = []
        for s, o, p in self._repo.edges(data="predicate"):
            if subject and subject != s:
                continue
            if predicate and predicate != p:
                continue
            if object_ and object_ != o:
                continue
            result.append(SPO(subject=s, predicate=p, object_=o))
        return result

    def json(self) -> str:
        m = networkx.node_link_data(self._repo)
        data = json.dumps(m)
        return data

    async def save(self, path: str | Path = None):
        data = self.json()
        path = path or self._kwargs.get("root")
        if not path.exists():
            path.mkdir(parents=True, exist_ok=True)
        pathname = Path(path) / self.name
        await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8")

    async def load(self, pathname: str | Path):
        data = await aread(filename=pathname, encoding="utf-8")
        m = json.loads(data)
        self._repo = networkx.node_link_graph(m)

    @staticmethod
    async def load_from(pathname: str | Path) -> GraphRepository:
        pathname = Path(pathname)
        name = pathname.with_suffix("").name
        root = pathname.parent
        graph = DiGraphRepository(name=name, root=root)
        if pathname.exists():
            await graph.load(pathname=pathname)
        return graph

    @property
    def root(self) -> str:
        return self._kwargs.get("root")

    @property
    def pathname(self) -> Path:
        p = Path(self.root) / self.name
        return p.with_suffix(".json")
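An in-memory sketch of the DiGraph-backed store (names and paths are illustrative):

    import asyncio

    from metagpt.utils.di_graph_repository import DiGraphRepository

    async def main():
        repo = DiGraphRepository(name="demo", root="/tmp/graph")
        await repo.insert(subject="a.py", predicate="has_class", object_="a.py:Foo")
        rows = await repo.select(predicate="has_class")
        print([(r.subject, r.object_) for r in rows])  # [('a.py', 'a.py:Foo')]

    asyncio.run(main())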
metagpt/utils/exceptions.py (new file, 61 lines)

@@ -0,0 +1,61 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/12/19 14:46
@Author  : alexanderwu
@File    : exceptions.py
"""


import asyncio
import functools
import traceback
from typing import Any, Callable, Tuple, Type, TypeVar, Union

from metagpt.logs import logger

ReturnType = TypeVar("ReturnType")


def handle_exception(
    _func: Callable[..., ReturnType] = None,
    *,
    exception_type: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception,
    exception_msg: str = "",
    default_return: Any = None,
) -> Callable[..., ReturnType]:
    """Handle an exception and return a default value instead."""

    def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
        @functools.wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
            try:
                return await func(*args, **kwargs)
            except exception_type as e:
                logger.opt(depth=1).error(
                    f"{e}: {exception_msg}, "
                    f"\nCalling {func.__name__} with args: {args}, kwargs: {kwargs} "
                    f"\nStack: {traceback.format_exc()}"
                )
                return default_return

        @functools.wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
            try:
                return func(*args, **kwargs)
            except exception_type as e:
                logger.opt(depth=1).error(
                    f"Calling {func.__name__} with args: {args}, kwargs: {kwargs} failed: {e}, "
                    f"stack: {traceback.format_exc()}"
                )
                return default_return

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper

    if _func is None:
        return decorator
    else:
        return decorator(_func)
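The decorator works with or without arguments, on sync and async callables alike. A sketch:

    from metagpt.utils.exceptions import handle_exception

    @handle_exception(exception_type=ZeroDivisionError, default_return=0.0)
    def safe_ratio(a: float, b: float) -> float:
        return a / b

    print(safe_ratio(1, 0))  # logs the ZeroDivisionError, then returns 0.0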
@@ -6,10 +6,12 @@
@File    : file.py
@Describe : General file operations.
"""
-import aiofiles
from pathlib import Path

+import aiofiles
+
from metagpt.logs import logger
+from metagpt.utils.exceptions import handle_exception


class File:

@@ -18,6 +20,7 @@ class File:
    CHUNK_SIZE = 64 * 1024

    @classmethod
+    @handle_exception
    async def write(cls, root_path: Path, filename: str, content: bytes) -> Path:
        """Write the file content to the local specified path.

@@ -32,18 +35,15 @@ class File:
        Raises:
            Exception: If an unexpected error occurs during the file writing process.
        """
-        try:
-            root_path.mkdir(parents=True, exist_ok=True)
-            full_path = root_path / filename
-            async with aiofiles.open(full_path, mode="wb") as writer:
-                await writer.write(content)
-                logger.debug(f"Successfully write file: {full_path}")
-                return full_path
-        except Exception as e:
-            logger.error(f"Error writing file: {e}")
-            raise e
+        root_path.mkdir(parents=True, exist_ok=True)
+        full_path = root_path / filename
+        async with aiofiles.open(full_path, mode="wb") as writer:
+            await writer.write(content)
+            logger.debug(f"Successfully write file: {full_path}")
+            return full_path

    @classmethod
+    @handle_exception
    async def read(cls, file_path: Path, chunk_size: int = None) -> bytes:
        """Partitioning read the file content from the local specified path.

@@ -57,19 +57,14 @@ class File:
        Raises:
            Exception: If an unexpected error occurs during the file reading process.
        """
-        try:
-            chunk_size = chunk_size or cls.CHUNK_SIZE
-            async with aiofiles.open(file_path, mode="rb") as reader:
-                chunks = list()
-                while True:
-                    chunk = await reader.read(chunk_size)
-                    if not chunk:
-                        break
-                    chunks.append(chunk)
-                content = b''.join(chunks)
-                logger.debug(f"Successfully read file, the path of file: {file_path}")
-                return content
-        except Exception as e:
-            logger.error(f"Error reading file: {e}")
-            raise e
+        chunk_size = chunk_size or cls.CHUNK_SIZE
+        async with aiofiles.open(file_path, mode="rb") as reader:
+            chunks = list()
+            while True:
+                chunk = await reader.read(chunk_size)
+                if not chunk:
+                    break
+                chunks.append(chunk)
+            content = b"".join(chunks)
+            logger.debug(f"Successfully read file, the path of file: {file_path}")
+            return content
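With @handle_exception replacing the inline try/except, failures now log and return None instead of raising. A usage sketch (paths are illustrative):

    import asyncio
    from pathlib import Path

    from metagpt.utils.file import File

    async def main():
        saved = await File.write(root_path=Path("/tmp/demo"), filename="hello.txt", content=b"hi")
        print(await File.read(saved))  # b'hi'

    asyncio.run(main())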
metagpt/utils/file_repository.py (new file, 290 lines)

@@ -0,0 +1,290 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/11/20
@Author  : mashenquan
@File    : file_repository.py
@Desc    : File repository management. RFC 135 2.2.3.2, 2.2.3.4 and 2.2.3.13.
"""
from __future__ import annotations

import json
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Set

import aiofiles

from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.schema import Document
from metagpt.utils.common import aread
from metagpt.utils.json_to_markdown import json_to_markdown


class FileRepository:
    """A class representing a FileRepository associated with a Git repository.

    :param git_repo: The associated GitRepository instance.
    :param relative_path: The relative path within the Git repository.

    Attributes:
        _relative_path (Path): The relative path within the Git repository.
        _git_repo (GitRepository): The associated GitRepository instance.
    """

    def __init__(self, git_repo, relative_path: Path = Path(".")):
        """Initialize a FileRepository instance.

        :param git_repo: The associated GitRepository instance.
        :param relative_path: The relative path within the Git repository.
        """
        self._relative_path = relative_path
        self._git_repo = git_repo

        # Initializing
        self.workdir.mkdir(parents=True, exist_ok=True)

    async def save(self, filename: Path | str, content, dependencies: List[str] = None):
        """Save content to a file and update its dependencies.

        :param filename: The filename or path within the repository.
        :param content: The content to be saved.
        :param dependencies: List of dependency filenames or paths.
        """
        pathname = self.workdir / filename
        pathname.parent.mkdir(parents=True, exist_ok=True)
        async with aiofiles.open(str(pathname), mode="w") as writer:
            await writer.write(content)
        logger.info(f"save to: {str(pathname)}")

        if dependencies is not None:
            dependency_file = await self._git_repo.get_dependency()
            await dependency_file.update(pathname, set(dependencies))
            logger.info(f"update dependency: {str(pathname)}:{dependencies}")

    async def get_dependency(self, filename: Path | str) -> Set[str]:
        """Get the dependencies of a file.

        :param filename: The filename or path within the repository.
        :return: Set of dependency filenames or paths.
        """
        pathname = self.workdir / filename
        dependency_file = await self._git_repo.get_dependency()
        return await dependency_file.get(pathname)

    async def get_changed_dependency(self, filename: Path | str) -> Set[str]:
        """Get the dependencies of a file that have changed.

        :param filename: The filename or path within the repository.
        :return: Set of changed dependency filenames or paths.
        """
        dependencies = await self.get_dependency(filename=filename)
        changed_files = set(self.changed_files.keys())
        changed_dependent_files = set()
        for df in dependencies:
            rdf = Path(df).relative_to(self._relative_path)
            if str(rdf) in changed_files:
                changed_dependent_files.add(df)
        return changed_dependent_files

    async def get(self, filename: Path | str) -> Document | None:
        """Read the content of a file.

        :param filename: The filename or path within the repository.
        :return: The content of the file.
        """
        doc = Document(root_path=str(self.root_path), filename=str(filename))
        path_name = self.workdir / filename
        if not path_name.exists():
            return None
        doc.content = await aread(path_name)
        return doc

    async def get_all(self) -> List[Document]:
        """Get the content of all files in the repository.

        :return: List of Document instances representing files.
        """
        docs = []
        for root, dirs, files in os.walk(str(self.workdir)):
            for file in files:
                file_path = Path(root) / file
                relative_path = file_path.relative_to(self.workdir)
                doc = await self.get(relative_path)
                docs.append(doc)
        return docs

    @property
    def workdir(self):
        """Return the absolute path to the working directory of the FileRepository.

        :return: The absolute path to the working directory.
        """
        return self._git_repo.workdir / self._relative_path

    @property
    def root_path(self):
        """Return the relative path from the Git repository root."""
        return self._relative_path

    @property
    def changed_files(self) -> Dict[str, str]:
        """Return a dictionary of changed files and their change types.

        :return: A dictionary where keys are file paths and values are change types.
        """
        files = self._git_repo.changed_files
        relative_files = {}
        for p, ct in files.items():
            if ct.value == "D":  # deleted
                continue
            try:
                rf = Path(p).relative_to(self._relative_path)
            except ValueError:
                continue
            relative_files[str(rf)] = ct
        return relative_files

    @property
    def all_files(self) -> List:
        """Get a list of all files in the repository.

        The list contains file paths relative to the current FileRepository.

        :return: A list of file paths.
        :rtype: List
        """
        return self._git_repo.get_files(relative_path=self._relative_path)

    def get_change_dir_files(self, dir: Path | str) -> List:
        """Get the files in a directory that have changed.

        :param dir: The directory path within the repository.
        :return: List of changed filenames or paths within the directory.
        """
        changed_files = self.changed_files
        children = []
        for f in changed_files:
            try:
                Path(f).relative_to(Path(dir))
            except ValueError:
                continue
            children.append(str(f))
        return children

    @staticmethod
    def new_filename():
        """Generate a new filename based on the current timestamp (the UUID suffix is currently disabled).

        :return: A new filename string.
        """
        current_time = datetime.now().strftime("%Y%m%d%H%M%S")
        return current_time
        # guid_suffix = str(uuid.uuid4())[:8]
        # return f"{current_time}x{guid_suffix}"

    async def save_doc(self, doc: Document, with_suffix: str = None, dependencies: List[str] = None):
        """Save a Document instance as a Markdown file.

        This method converts the content of the Document instance to Markdown,
        saves it to a file with an optional specified suffix, and logs the saved file.

        :param doc: The Document instance to be saved.
        :type doc: Document
        :param with_suffix: An optional suffix to append to the saved file's name.
        :type with_suffix: str, optional
        :param dependencies: A list of dependencies for the saved file.
        :type dependencies: List[str], optional
        """
        m = json.loads(doc.content)
        filename = Path(doc.filename).with_suffix(with_suffix) if with_suffix is not None else Path(doc.filename)
        await self.save(filename=str(filename), content=json_to_markdown(m), dependencies=dependencies)
        logger.debug(f"File Saved: {str(filename)}")

    @staticmethod
    async def get_file(filename: Path | str, relative_path: Path | str = ".") -> Document | None:
        """Retrieve a specific file from the file repository.

        :param filename: The name or path of the file to retrieve.
        :type filename: Path or str
        :param relative_path: The relative path within the file repository.
        :type relative_path: Path or str, optional
        :return: The document representing the file, or None if not found.
        :rtype: Document or None
        """
        file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
        return await file_repo.get(filename=filename)

    @staticmethod
    async def get_all_files(relative_path: Path | str = ".") -> List[Document]:
        """Retrieve all files from the file repository.

        :param relative_path: The relative path within the file repository.
        :type relative_path: Path or str, optional
        :return: A list of documents representing all files in the repository.
        :rtype: List[Document]
        """
        file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
        return await file_repo.get_all()

    @staticmethod
    async def save_file(filename: Path | str, content, dependencies: List[str] = None, relative_path: Path | str = "."):
        """Save a file to the file repository.

        :param filename: The name or path of the file to save.
        :type filename: Path or str
        :param content: The content of the file.
        :param dependencies: A list of dependencies for the file.
        :type dependencies: List[str], optional
        :param relative_path: The relative path within the file repository.
        :type relative_path: Path or str, optional
        """
        file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
        return await file_repo.save(filename=filename, content=content, dependencies=dependencies)

    @staticmethod
    async def save_as(
        doc: Document, with_suffix: str = None, dependencies: List[str] = None, relative_path: Path | str = "."
    ):
        """Save a Document instance with optional modifications.

        This static method creates a new FileRepository, saves the Document instance
        with optional modifications (such as a suffix), and logs the saved file.

        :param doc: The Document instance to be saved.
        :type doc: Document
        :param with_suffix: An optional suffix to append to the saved file's name.
        :type with_suffix: str, optional
        :param dependencies: A list of dependencies for the saved file.
        :type dependencies: List[str], optional
        :param relative_path: The relative path within the file repository.
        :type relative_path: Path or str, optional
        :return: A boolean indicating whether the save operation was successful.
        :rtype: bool
        """
        file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
        return await file_repo.save_doc(doc=doc, with_suffix=with_suffix, dependencies=dependencies)

    async def delete(self, filename: Path | str):
        """Delete a file from the file repository.

        This method deletes a file from the file repository based on the provided filename.

        :param filename: The name or path of the file to be deleted.
        :type filename: Path or str
        """
        pathname = self.workdir / filename
        if not pathname.exists():
            return
        pathname.unlink(missing_ok=True)

        dependency_file = await self._git_repo.get_dependency()
        await dependency_file.update(filename=pathname, dependencies=None)
        logger.info(f"remove dependency key: {str(pathname)}")

    @staticmethod
    async def delete_file(filename: Path | str, relative_path: Path | str = "."):
        file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path)
        await file_repo.delete(filename=filename)
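A save/get sketch, assuming an initialized GitRepository (defined in git_repository.py below) and a writable /tmp path:

    import asyncio

    from metagpt.utils.git_repository import GitRepository

    async def main():
        repo = GitRepository(local_path="/tmp/demo_ws", auto_init=True)
        file_repo = repo.new_file_repository(relative_path="docs")
        await file_repo.save(filename="prd.md", content="# PRD", dependencies=[])
        doc = await file_repo.get("prd.md")
        print(doc.content)  # '# PRD'

    asyncio.run(main())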
metagpt/utils/get_template.py (deleted)

@@ -1,20 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""
-@Time    : 2023/9/19 20:39
-@Author  : femto Zheng
-@File    : get_template.py
-"""
-from metagpt.config import CONFIG
-
-
-def get_template(templates, format=CONFIG.prompt_format):
-    selected_templates = templates.get(format)
-    if selected_templates is None:
-        raise ValueError(f"Can't find {format} in passed in templates")
-
-    # Extract the selected templates
-    prompt_template = selected_templates["PROMPT_TEMPLATE"]
-    format_example = selected_templates["FORMAT_EXAMPLE"]
-
-    return prompt_template, format_example
metagpt/utils/git_repository.py (new file, 272 lines)

@@ -0,0 +1,272 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/11/20
@Author  : mashenquan
@File    : git_repository.py
@Desc    : Git repository management. RFC 135 2.2.3.3.
"""
from __future__ import annotations

import shutil
from enum import Enum
from pathlib import Path
from typing import Dict, List

from git.repo import Repo
from git.repo.fun import is_git_dir
from gitignore_parser import parse_gitignore

from metagpt.logs import logger
from metagpt.utils.dependency_file import DependencyFile
from metagpt.utils.file_repository import FileRepository


class ChangeType(Enum):
    ADDED = "A"  # File was added
    COPIED = "C"  # File was copied
    DELETED = "D"  # File was deleted
    RENAMED = "R"  # File was renamed
    MODIFIED = "M"  # File was modified
    TYPE_CHANGED = "T"  # Type of the file was changed
    UNTRACTED = "U"  # File is untracked (not added to version control)


class GitRepository:
    """A class representing a Git repository.

    :param local_path: The local path to the Git repository.
    :param auto_init: If True, automatically initializes a new Git repository if the provided path is not a Git repository.

    Attributes:
        _repository (Repo): The GitPython `Repo` object representing the Git repository.
    """

    def __init__(self, local_path=None, auto_init=True):
        """Initialize a GitRepository instance.

        :param local_path: The local path to the Git repository.
        :param auto_init: If True, automatically initializes a new Git repository if the provided path is not a Git repository.
        """
        self._repository = None
        self._dependency = None
        self._gitignore_rules = None
        if local_path:
            self.open(local_path=local_path, auto_init=auto_init)

    def open(self, local_path: Path, auto_init=False):
        """Open an existing Git repository or initialize a new one if auto_init is True.

        :param local_path: The local path to the Git repository.
        :param auto_init: If True, automatically initializes a new Git repository if the provided path is not a Git repository.
        """
        local_path = Path(local_path)
        if self.is_git_dir(local_path):
            self._repository = Repo(local_path)
            self._gitignore_rules = parse_gitignore(full_path=str(local_path / ".gitignore"))
            return
        if not auto_init:
            return
        local_path.mkdir(parents=True, exist_ok=True)
        return self._init(local_path)

    def _init(self, local_path: Path):
        """Initialize a new Git repository at the specified path.

        :param local_path: The local path where the new Git repository will be initialized.
        """
        self._repository = Repo.init(path=Path(local_path))

        gitignore_filename = Path(local_path) / ".gitignore"
        ignores = ["__pycache__", "*.pyc"]
        with open(str(gitignore_filename), mode="w") as writer:
            writer.write("\n".join(ignores))
        self._repository.index.add([".gitignore"])
        self._repository.index.commit("Add .gitignore")
        self._gitignore_rules = parse_gitignore(full_path=gitignore_filename)

    def add_change(self, files: Dict):
        """Add or remove files from the staging area based on the provided changes.

        :param files: A dictionary where keys are file paths and values are instances of ChangeType.
        """
        if not self.is_valid or not files:
            return

        for k, v in files.items():
            self._repository.index.remove(k) if v is ChangeType.DELETED else self._repository.index.add([k])

    def commit(self, comments):
        """Commit the staged changes with the given comments.

        :param comments: Comments for the commit.
        """
        if self.is_valid:
            self._repository.index.commit(comments)

    def delete_repository(self):
        """Delete the entire repository directory."""
        if self.is_valid:
            shutil.rmtree(self._repository.working_dir)

    @property
    def changed_files(self) -> Dict[str, str]:
        """Return a dictionary of changed files and their change types.

        :return: A dictionary where keys are file paths and values are change types.
        """
        files = {i: ChangeType.UNTRACTED for i in self._repository.untracked_files}
        changed_files = {f.a_path: ChangeType(f.change_type) for f in self._repository.index.diff(None)}
        files.update(changed_files)
        return files

    @staticmethod
    def is_git_dir(local_path):
        """Check if the specified directory is a Git repository.

        :param local_path: The local path to check.
        :return: True if the directory is a Git repository, False otherwise.
        """
        git_dir = Path(local_path) / ".git"
        if git_dir.exists() and is_git_dir(git_dir):
            return True
        return False

    @property
    def is_valid(self):
        """Check if the Git repository is valid (exists and is initialized).

        :return: True if the repository is valid, False otherwise.
        """
        return bool(self._repository)

    @property
    def status(self) -> str:
        """Return the Git repository's status as a string."""
        if not self.is_valid:
            return ""
        return self._repository.git.status()

    @property
    def workdir(self) -> Path | None:
        """Return the path to the working directory of the Git repository.

        :return: The path to the working directory or None if the repository is not valid.
        """
        if not self.is_valid:
            return None
        return Path(self._repository.working_dir)

    def archive(self, comments="Archive"):
        """Archive the current state of the Git repository.

        :param comments: Comments for the archive commit.
        """
        logger.info(f"Archive: {list(self.changed_files.keys())}")
        self.add_change(self.changed_files)
        self.commit(comments)

    def new_file_repository(self, relative_path: Path | str = ".") -> FileRepository:
        """Create a new instance of FileRepository associated with this Git repository.

        :param relative_path: The relative path to the file repository within the Git repository.
        :return: A new instance of FileRepository.
        """
        path = Path(relative_path)
        try:
            path = path.relative_to(self.workdir)
        except ValueError:
            path = relative_path
        return FileRepository(git_repo=self, relative_path=Path(path))

    async def get_dependency(self) -> DependencyFile:
        """Get the dependency file associated with the Git repository.

        :return: An instance of DependencyFile.
        """
        if not self._dependency:
            self._dependency = DependencyFile(workdir=self.workdir)
        return self._dependency

    def rename_root(self, new_dir_name):
        """Rename the root directory of the Git repository.

        :param new_dir_name: The new name for the root directory.
        """
        if self.workdir.name == new_dir_name:
            return
        new_path = self.workdir.parent / new_dir_name
        if new_path.exists():
            logger.info(f"Delete directory {str(new_path)}")
            shutil.rmtree(new_path)
        try:
            shutil.move(src=str(self.workdir), dst=str(new_path))
        except Exception as e:
            logger.warning(f"Move {str(self.workdir)} to {str(new_path)} error: {e}")
        logger.info(f"Rename directory {str(self.workdir)} to {str(new_path)}")
        self._repository = Repo(new_path)
        self._gitignore_rules = parse_gitignore(full_path=str(new_path / ".gitignore"))

    def get_files(self, relative_path: Path | str, root_relative_path: Path | str = None, filter_ignored=True) -> List:
        """
        Retrieve a list of files in the specified relative path.

        The method returns a list of file paths relative to the current FileRepository.

        :param relative_path: The relative path within the repository.
        :type relative_path: Path or str
        :param root_relative_path: The root relative path within the repository.
        :type root_relative_path: Path or str
        :param filter_ignored: Flag to indicate whether to filter files based on .gitignore rules.
        :type filter_ignored: bool
        :return: A list of file paths in the specified directory.
        :rtype: List[str]
        """
        try:
            relative_path = Path(relative_path).relative_to(self.workdir)
        except ValueError:
            relative_path = Path(relative_path)

        if not root_relative_path:
            root_relative_path = Path(self.workdir) / relative_path
        files = []
        try:
            directory_path = Path(self.workdir) / relative_path
            if not directory_path.exists():
                return []
            for file_path in directory_path.iterdir():
                if file_path.is_file():
                    rpath = file_path.relative_to(root_relative_path)
                    files.append(str(rpath))
                else:
                    subfolder_files = self.get_files(
                        relative_path=file_path, root_relative_path=root_relative_path, filter_ignored=False
                    )
                    files.extend(subfolder_files)
        except Exception as e:
            logger.error(f"Error: {e}")
        if not filter_ignored:
            return files
        filtered_files = self.filter_gitignore(filenames=files, root_relative_path=root_relative_path)
        return filtered_files

    def filter_gitignore(self, filenames: List[str], root_relative_path: Path | str = None) -> List[str]:
        """
        Filter a list of filenames based on .gitignore rules.

        :param filenames: A list of filenames to be filtered.
        :type filenames: List[str]
        :param root_relative_path: The root relative path within the repository.
        :type root_relative_path: Path or str
        :return: A list of filenames that pass the .gitignore filtering.
        :rtype: List[str]
        """
        if root_relative_path is None:
            root_relative_path = self.workdir
        files = []
        for filename in filenames:
            pathname = root_relative_path / filename
            if self._gitignore_rules(str(pathname)):
                continue
            files.append(filename)
        return files
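A sketch of the change-tracking flow (paths are illustrative; committing assumes a usable local git identity):

    from metagpt.utils.git_repository import GitRepository

    repo = GitRepository(local_path="/tmp/demo_ws", auto_init=True)
    (repo.workdir / "notes.txt").write_text("draft")
    print(repo.changed_files)  # {'notes.txt': <ChangeType.UNTRACTED: 'U'>}
    repo.archive(comments="Add notes")  # stage everything that changed, then commit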
200
metagpt/utils/graph_repository.py
Normal file
200
metagpt/utils/graph_repository.py
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : graph_repository.py
|
||||
@Desc : Superclass for graph repository.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.logs import logger
from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace


class GraphKeyword:
    IS = "is"
    OF = "Of"
    ON = "On"
    CLASS = "class"
    FUNCTION = "function"
    HAS_FUNCTION = "has_function"
    SOURCE_CODE = "source_code"
    NULL = "<null>"
    GLOBAL_VARIABLE = "global_variable"
    CLASS_FUNCTION = "class_function"
    CLASS_PROPERTY = "class_property"
    HAS_CLASS_FUNCTION = "has_class_function"
    HAS_CLASS_PROPERTY = "has_class_property"
    HAS_CLASS = "has_class"
    HAS_PAGE_INFO = "has_page_info"
    HAS_CLASS_VIEW = "has_class_view"
    HAS_SEQUENCE_VIEW = "has_sequence_view"
    HAS_ARGS_DESC = "has_args_desc"
    HAS_TYPE_DESC = "has_type_desc"


class SPO(BaseModel):
    subject: str
    predicate: str
    object_: str


class GraphRepository(ABC):
    def __init__(self, name: str, **kwargs):
        self._repo_name = name
        self._kwargs = kwargs

    @abstractmethod
    async def insert(self, subject: str, predicate: str, object_: str):
        pass

    @abstractmethod
    async def upsert(self, subject: str, predicate: str, object_: str):
        pass

    @abstractmethod
    async def update(self, subject: str, predicate: str, object_: str):
        pass

    @abstractmethod
    async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
        pass

    @property
    def name(self) -> str:
        return self._repo_name

    @staticmethod
    async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo):
        await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
        file_types = {".py": "python", ".js": "javascript"}
        file_type = file_types.get(Path(file_info.file).suffix, GraphKeyword.NULL)
        await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type)
        for c in file_info.classes:
            class_name = c.get("name", "")
            # file -> class
            await graph_db.insert(
                subject=file_info.file,
                predicate=GraphKeyword.HAS_CLASS,
                object_=concat_namespace(file_info.file, class_name),
            )
            # class detail
            await graph_db.insert(
                subject=concat_namespace(file_info.file, class_name),
                predicate=GraphKeyword.IS,
                object_=GraphKeyword.CLASS,
            )
            methods = c.get("methods", [])
            for fn in methods:
                await graph_db.insert(
                    subject=concat_namespace(file_info.file, class_name),
                    predicate=GraphKeyword.HAS_CLASS_FUNCTION,
                    object_=concat_namespace(file_info.file, class_name, fn),
                )
                await graph_db.insert(
                    subject=concat_namespace(file_info.file, class_name, fn),
                    predicate=GraphKeyword.IS,
                    object_=GraphKeyword.CLASS_FUNCTION,
                )
        for f in file_info.functions:
            # file -> function
            await graph_db.insert(
                subject=file_info.file, predicate=GraphKeyword.HAS_FUNCTION, object_=concat_namespace(file_info.file, f)
            )
            # function detail
            await graph_db.insert(
                subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION
            )
        for g in file_info.globals:
            await graph_db.insert(
                subject=concat_namespace(file_info.file, g),
                predicate=GraphKeyword.IS,
                object_=GraphKeyword.GLOBAL_VARIABLE,
            )
        for code_block in file_info.page_info:
            if code_block.tokens:
                await graph_db.insert(
                    subject=concat_namespace(file_info.file, *code_block.tokens),
                    predicate=GraphKeyword.HAS_PAGE_INFO,
                    object_=code_block.model_dump_json(),
                )
            for k, v in code_block.properties.items():
                await graph_db.insert(
                    subject=concat_namespace(file_info.file, k, v),
                    predicate=GraphKeyword.HAS_PAGE_INFO,
                    object_=code_block.model_dump_json(),
                )

    @staticmethod
    async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
        for c in class_views:
            filename, _ = c.package.split(":", 1)
            await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
            file_types = {".py": "python", ".js": "javascript"}
            file_type = file_types.get(Path(filename).suffix, GraphKeyword.NULL)
            await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=file_type)
            await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_CLASS, object_=c.package)
            await graph_db.insert(
                subject=c.package,
                predicate=GraphKeyword.IS,
                object_=GraphKeyword.CLASS,
            )
            for vn, vt in c.attributes.items():
                # class -> property
                await graph_db.insert(
                    subject=c.package,
                    predicate=GraphKeyword.HAS_CLASS_PROPERTY,
                    object_=concat_namespace(c.package, vn),
                )
                # property detail
                await graph_db.insert(
                    subject=concat_namespace(c.package, vn),
                    predicate=GraphKeyword.IS,
                    object_=GraphKeyword.CLASS_PROPERTY,
                )
                await graph_db.insert(
                    subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
                )
            for fn, desc in c.methods.items():
                if "</I>" in desc and "<I>" not in desc:
                    logger.error(desc)
                # class -> function
                await graph_db.insert(
                    subject=c.package,
                    predicate=GraphKeyword.HAS_CLASS_FUNCTION,
                    object_=concat_namespace(c.package, fn),
                )
                # function detail
                await graph_db.insert(
                    subject=concat_namespace(c.package, fn),
                    predicate=GraphKeyword.IS,
                    object_=GraphKeyword.CLASS_FUNCTION,
                )
                await graph_db.insert(
                    subject=concat_namespace(c.package, fn),
                    predicate=GraphKeyword.HAS_ARGS_DESC,
                    object_=desc,
                )

    @staticmethod
    async def update_graph_db_with_class_relationship_views(
        graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
    ):
        for r in relationship_views:
            await graph_db.insert(
                subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest
            )
            if not r.label:
                continue
            await graph_db.insert(
                subject=r.src,
                predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
                object_=concat_namespace(r.dest, r.label),
            )
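Since GraphRepository only fixes the SPO-triple interface, a minimal in-memory backend is enough to exercise it. The sketch below is illustrative only: InMemoryGraphRepository is a hypothetical class, not one of MetaGPT's real graph backends (which are not part of this diff).

import asyncio
from typing import List


class InMemoryGraphRepository(GraphRepository):
    def __init__(self, name: str, **kwargs):
        super().__init__(name, **kwargs)
        self._triples: List[SPO] = []

    async def insert(self, subject: str, predicate: str, object_: str):
        self._triples.append(SPO(subject=subject, predicate=predicate, object_=object_))

    async def upsert(self, subject: str, predicate: str, object_: str):
        await self.update(subject, predicate, object_)

    async def update(self, subject: str, predicate: str, object_: str):
        # Replace any triple with the same subject and predicate.
        self._triples = [t for t in self._triples if not (t.subject == subject and t.predicate == predicate)]
        await self.insert(subject, predicate, object_)

    async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
        return [
            t
            for t in self._triples
            if (subject is None or t.subject == subject)
            and (predicate is None or t.predicate == predicate)
            and (object_ is None or t.object_ == object_)
        ]


async def demo():
    repo = InMemoryGraphRepository(name="demo")
    await repo.insert(subject="a.py", predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
    print(await repo.select(predicate=GraphKeyword.IS))


asyncio.run(demo())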
@@ -1,22 +1,22 @@
# Add syntax highlighting for code
from pygments import highlight as highlight_
from pygments.formatters import HtmlFormatter, TerminalFormatter
from pygments.lexers import PythonLexer, SqlLexer
from pygments.formatters import TerminalFormatter, HtmlFormatter


def highlight(code: str, language: str = 'python', formatter: str = 'terminal'):
def highlight(code: str, language: str = "python", formatter: str = "terminal"):
    # Specify the language to highlight
    if language.lower() == 'python':
    if language.lower() == "python":
        lexer = PythonLexer()
    elif language.lower() == 'sql':
    elif language.lower() == "sql":
        lexer = SqlLexer()
    else:
        raise ValueError(f"Unsupported language: {language}")

    # Specify the output format
    if formatter.lower() == 'terminal':
    if formatter.lower() == "terminal":
        formatter = TerminalFormatter()
    elif formatter.lower() == 'html':
    elif formatter.lower() == "html":
        formatter = HtmlFormatter()
    else:
        raise ValueError(f"Unsupported formatter: {formatter}")
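The hunk ends before the function's return statement; presumably the body finishes by delegating to Pygments, something like `return highlight_(code, lexer, formatter)`. A usage sketch under that assumption:

# Assumes the diffed helper above ends with `return highlight_(code, lexer, formatter)`.
colored = highlight("print('hello')", language="python", formatter="terminal")
print(colored)

html_snippet = highlight("SELECT 1;", language="sql", formatter="html")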
@@ -18,17 +18,15 @@ from metagpt.config import CONFIG

def make_sk_kernel():
    kernel = sk.Kernel()
    if CONFIG.openai_api_type == "azure":
    if CONFIG.OPENAI_API_TYPE == "azure":
        kernel.add_chat_service(
            "chat_completion",
            AzureChatCompletion(CONFIG.deployment_name, CONFIG.openai_api_base, CONFIG.openai_api_key),
            AzureChatCompletion(CONFIG.DEPLOYMENT_NAME, CONFIG.OPENAI_BASE_URL, CONFIG.OPENAI_API_KEY),
        )
    else:
        kernel.add_chat_service(
            "chat_completion",
            OpenAIChatCompletion(
                CONFIG.openai_api_model, CONFIG.openai_api_key, org_id=None, endpoint=CONFIG.openai_api_base
            ),
            OpenAIChatCompletion(CONFIG.OPENAI_API_MODEL, CONFIG.OPENAI_API_KEY),
        )

    return kernel
@@ -4,13 +4,15 @@
@Time : 2023/7/4 10:53
@Author : alexanderwu alitrack
@File : mermaid.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
import asyncio
import os
from pathlib import Path

import aiofiles

from metagpt.config import CONFIG
from metagpt.const import PROJECT_ROOT
from metagpt.logs import logger
from metagpt.utils.common import check_cmd_exists


@@ -29,7 +31,9 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name)
    tmp = Path(f"{output_file_without_suffix}.mmd")
    tmp.write_text(mermaid_code, encoding="utf-8")
    async with aiofiles.open(tmp, "w", encoding="utf-8") as f:
        await f.write(mermaid_code)
    # tmp.write_text(mermaid_code, encoding="utf-8")

    engine = CONFIG.mermaid_engine.lower()
    if engine == "nodejs":

@@ -69,7 +73,7 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
        if stdout:
            logger.info(stdout.decode())
        if stderr:
            logger.error(stderr.decode())
            logger.warning(stderr.decode())
    else:
        if engine == "playwright":
            from metagpt.utils.mmdc_playwright import mermaid_to_file

@@ -88,7 +92,8 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
    return 0


MMC1 = """classDiagram
MMC1 = """
classDiagram
    class Main {
        -SearchEngine search_engine
        +main() str

@@ -118,9 +123,11 @@ MMC1 = """classDiagram
    SearchEngine --> Index
    SearchEngine --> Ranking
    SearchEngine --> Summary
    Index --> KnowledgeBase"""
    Index --> KnowledgeBase
"""

MMC2 = """sequenceDiagram
MMC2 = """
sequenceDiagram
    participant M as Main
    participant SE as SearchEngine
    participant I as Index

@@ -136,11 +143,5 @@ MMC2 = """sequenceDiagram
    R-->>SE: return ranked_results
    SE->>S: summarize_results(ranked_results)
    S-->>SE: return summary
    SE-->>M: return summary"""


if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    result = loop.run_until_complete(mermaid_to_file(MMC1, PROJECT_ROOT / f"{CONFIG.mermaid_engine}/1"))
    result = loop.run_until_complete(mermaid_to_file(MMC2, PROJECT_ROOT / f"{CONFIG.mermaid_engine}/1"))
    loop.close()
    SE-->>M: return summary
"""
@@ -6,9 +6,9 @@
@File : mermaid.py
"""
import base64
import os

from aiohttp import ClientSession,ClientError
from aiohttp import ClientError, ClientSession

from metagpt.logs import logger


@@ -29,7 +29,7 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix):
        async with session.get(url) as response:
            if response.status == 200:
                text = await response.content.read()
                with open(output_file, 'wb') as f:
                with open(output_file, "wb") as f:
                    f.write(text)
                logger.info(f"Generating {output_file}..")
            else:
@@ -8,10 +8,13 @@

import os
from urllib.parse import urljoin

from playwright.async_api import async_playwright

from metagpt.logs import logger

async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048)-> int:


async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
    """
    Converts the given Mermaid code to various output formats and saves them to files.

@@ -24,66 +27,72 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
    Returns:
        int: Returns 0 if the conversion and saving were successful, -1 otherwise.
    """
    suffixes=['png', 'svg', 'pdf']
    suffixes = ["png", "svg", "pdf"]
    __dirname = os.path.dirname(os.path.abspath(__file__))

    async with async_playwright() as p:
        browser = await p.chromium.launch()
        device_scale_factor = 1.0
        context = await browser.new_context(
            viewport={'width': width, 'height': height},
            device_scale_factor=device_scale_factor,
        )
            viewport={"width": width, "height": height},
            device_scale_factor=device_scale_factor,
        )
        page = await context.new_page()

        async def console_message(msg):
            logger.info(msg.text)
        page.on('console', console_message)

        page.on("console", console_message)

        try:
            await page.set_viewport_size({'width': width, 'height': height})
            await page.set_viewport_size({"width": width, "height": height})

            mermaid_html_path = os.path.abspath(
                os.path.join(__dirname, 'index.html'))
            mermaid_html_url = urljoin('file:', mermaid_html_path)
            mermaid_html_path = os.path.abspath(os.path.join(__dirname, "index.html"))
            mermaid_html_url = urljoin("file:", mermaid_html_path)
            await page.goto(mermaid_html_url)
            await page.wait_for_load_state("networkidle")

            await page.wait_for_selector("div#container", state="attached")
            mermaid_config = {}
            # mermaid_config = {}
            background_color = "#ffffff"
            my_css = ""
            # my_css = ""
            await page.evaluate(f'document.body.style.background = "{background_color}";')

            metadata = await page.evaluate('''async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
                const { mermaid, zenuml } = globalThis;
                await mermaid.registerExternalDiagrams([zenuml]);
                mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
                const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
                document.getElementById('container').innerHTML = svg;
                const svgElement = document.querySelector('svg');
                svgElement.style.backgroundColor = backgroundColor;
            # metadata = await page.evaluate(
            #     """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
            #         const { mermaid, zenuml } = globalThis;
            #         await mermaid.registerExternalDiagrams([zenuml]);
            #         mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
            #         const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
            #         document.getElementById('container').innerHTML = svg;
            #         const svgElement = document.querySelector('svg');
            #         svgElement.style.backgroundColor = backgroundColor;
            #
            #         if (myCSS) {
            #             const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
            #             style.appendChild(document.createTextNode(myCSS));
            #             svgElement.appendChild(style);
            #         }
            #
            #     }""",
            #     [mermaid_code, mermaid_config, my_css, background_color],
            # )

                if (myCSS) {
                    const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
                    style.appendChild(document.createTextNode(myCSS));
                    svgElement.appendChild(style);
                }

            }''', [mermaid_code, mermaid_config, my_css, background_color])

            if 'svg' in suffixes :
                svg_xml = await page.evaluate('''() => {
            if "svg" in suffixes:
                svg_xml = await page.evaluate(
                    """() => {
                    const svg = document.querySelector('svg');
                    const xmlSerializer = new XMLSerializer();
                    return xmlSerializer.serializeToString(svg);
                }''')
                }"""
                )
                logger.info(f"Generating {output_file_without_suffix}.svg..")
                with open(f'{output_file_without_suffix}.svg', 'wb') as f:
                    f.write(svg_xml.encode('utf-8'))
                with open(f"{output_file_without_suffix}.svg", "wb") as f:
                    f.write(svg_xml.encode("utf-8"))

            if 'png' in suffixes:
                clip = await page.evaluate('''() => {
            if "png" in suffixes:
                clip = await page.evaluate(
                    """() => {
                    const svg = document.querySelector('svg');
                    const rect = svg.getBoundingClientRect();
                    return {

@@ -92,16 +101,17 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
                        width: Math.ceil(rect.width),
                        height: Math.ceil(rect.height)
                    };
                }''')
                await page.set_viewport_size({'width': clip['x'] + clip['width'], 'height': clip['y'] + clip['height']})
                screenshot = await page.screenshot(clip=clip, omit_background=True, scale='device')
                }"""
                )
                await page.set_viewport_size({"width": clip["x"] + clip["width"], "height": clip["y"] + clip["height"]})
                screenshot = await page.screenshot(clip=clip, omit_background=True, scale="device")
                logger.info(f"Generating {output_file_without_suffix}.png..")
                with open(f'{output_file_without_suffix}.png', 'wb') as f:
                with open(f"{output_file_without_suffix}.png", "wb") as f:
                    f.write(screenshot)
            if 'pdf' in suffixes:
            if "pdf" in suffixes:
                pdf_data = await page.pdf(scale=device_scale_factor)
                logger.info(f"Generating {output_file_without_suffix}.pdf..")
                with open(f'{output_file_without_suffix}.pdf', 'wb') as f:
                with open(f"{output_file_without_suffix}.pdf", "wb") as f:
                    f.write(pdf_data)
            return 0
        except Exception as e:
@@ -7,11 +7,14 @@
"""
import os
from urllib.parse import urljoin
from pyppeteer import launch
from metagpt.logs import logger
from metagpt.config import CONFIG

async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048)-> int:

from pyppeteer import launch

from metagpt.config import CONFIG
from metagpt.logs import logger


async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
    """
    Converts the given Mermaid code to various output formats and saves them to files.

@@ -24,15 +27,15 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
    Returns:
        int: Returns 0 if the conversion and saving were successful, -1 otherwise.
    """
    suffixes = ['png', 'svg', 'pdf']
    suffixes = ["png", "svg", "pdf"]
    __dirname = os.path.dirname(os.path.abspath(__file__))

    if CONFIG.pyppeteer_executable_path:
        browser = await launch(headless=True,
                               executablePath=CONFIG.pyppeteer_executable_path,
                               args=['--disable-extensions',"--no-sandbox"]
                               )
        browser = await launch(
            headless=True,
            executablePath=CONFIG.pyppeteer_executable_path,
            args=["--disable-extensions", "--no-sandbox"],
        )
    else:
        logger.error("Please set the environment variable:PYPPETEER_EXECUTABLE_PATH.")
        return -1

@@ -41,50 +44,56 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
        async def console_message(msg):
            logger.info(msg.text)
        page.on('console', console_message)

        page.on("console", console_message)

        try:
            await page.setViewport(viewport={'width': width, 'height': height, 'deviceScaleFactor': device_scale_factor})
            await page.setViewport(viewport={"width": width, "height": height, "deviceScaleFactor": device_scale_factor})

            mermaid_html_path = os.path.abspath(
                os.path.join(__dirname, 'index.html'))
            mermaid_html_url = urljoin('file:', mermaid_html_path)
            mermaid_html_path = os.path.abspath(os.path.join(__dirname, "index.html"))
            mermaid_html_url = urljoin("file:", mermaid_html_path)
            await page.goto(mermaid_html_url)

            await page.querySelector("div#container")
            mermaid_config = {}
            # mermaid_config = {}
            background_color = "#ffffff"
            my_css = ""
            # my_css = ""
            await page.evaluate(f'document.body.style.background = "{background_color}";')

            metadata = await page.evaluate('''async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
                const { mermaid, zenuml } = globalThis;
                await mermaid.registerExternalDiagrams([zenuml]);
                mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
                const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
                document.getElementById('container').innerHTML = svg;
                const svgElement = document.querySelector('svg');
                svgElement.style.backgroundColor = backgroundColor;
            # metadata = await page.evaluate(
            #     """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
            #         const { mermaid, zenuml } = globalThis;
            #         await mermaid.registerExternalDiagrams([zenuml]);
            #         mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
            #         const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
            #         document.getElementById('container').innerHTML = svg;
            #         const svgElement = document.querySelector('svg');
            #         svgElement.style.backgroundColor = backgroundColor;
            #
            #         if (myCSS) {
            #             const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
            #             style.appendChild(document.createTextNode(myCSS));
            #             svgElement.appendChild(style);
            #         }
            #     }""",
            #     [mermaid_code, mermaid_config, my_css, background_color],
            # )

                if (myCSS) {
                    const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
                    style.appendChild(document.createTextNode(myCSS));
                    svgElement.appendChild(style);
                }
            }''', [mermaid_code, mermaid_config, my_css, background_color])

            if 'svg' in suffixes :
                svg_xml = await page.evaluate('''() => {
            if "svg" in suffixes:
                svg_xml = await page.evaluate(
                    """() => {
                    const svg = document.querySelector('svg');
                    const xmlSerializer = new XMLSerializer();
                    return xmlSerializer.serializeToString(svg);
                }''')
                }"""
                )
                logger.info(f"Generating {output_file_without_suffix}.svg..")
                with open(f'{output_file_without_suffix}.svg', 'wb') as f:
                    f.write(svg_xml.encode('utf-8'))
                with open(f"{output_file_without_suffix}.svg", "wb") as f:
                    f.write(svg_xml.encode("utf-8"))

            if 'png' in suffixes:
                clip = await page.evaluate('''() => {
            if "png" in suffixes:
                clip = await page.evaluate(
                    """() => {
                    const svg = document.querySelector('svg');
                    const rect = svg.getBoundingClientRect();
                    return {

@@ -93,16 +102,23 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
                        width: Math.ceil(rect.width),
                        height: Math.ceil(rect.height)
                    };
                }''')
                await page.setViewport({'width': clip['x'] + clip['width'], 'height': clip['y'] + clip['height'], 'deviceScaleFactor': device_scale_factor})
                screenshot = await page.screenshot(clip=clip, omit_background=True, scale='device')
                }"""
                )
                await page.setViewport(
                    {
                        "width": clip["x"] + clip["width"],
                        "height": clip["y"] + clip["height"],
                        "deviceScaleFactor": device_scale_factor,
                    }
                )
                screenshot = await page.screenshot(clip=clip, omit_background=True, scale="device")
                logger.info(f"Generating {output_file_without_suffix}.png..")
                with open(f'{output_file_without_suffix}.png', 'wb') as f:
                with open(f"{output_file_without_suffix}.png", "wb") as f:
                    f.write(screenshot)
            if 'pdf' in suffixes:
            if "pdf" in suffixes:
                pdf_data = await page.pdf(scale=device_scale_factor)
                logger.info(f"Generating {output_file_without_suffix}.pdf..")
                with open(f'{output_file_without_suffix}.pdf', 'wb') as f:
                with open(f"{output_file_without_suffix}.pdf", "wb") as f:
                    f.write(pdf_data)
            return 0
        except Exception as e:

@@ -110,4 +126,3 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
            return -1
        finally:
            await browser.close()
@@ -5,7 +5,7 @@ from typing import Generator, Optional
from urllib.parse import urljoin, urlparse

from bs4 import BeautifulSoup
from pydantic import BaseModel
from pydantic import BaseModel, PrivateAttr


class WebPage(BaseModel):

@@ -13,18 +13,15 @@ class WebPage(BaseModel):
    html: str
    url: str

    class Config:
        underscore_attrs_are_private = True

    _soup : Optional[BeautifulSoup] = None
    _title: Optional[str] = None
    _soup: Optional[BeautifulSoup] = PrivateAttr(default=None)
    _title: Optional[str] = PrivateAttr(default=None)

    @property
    def soup(self) -> BeautifulSoup:
        if self._soup is None:
            self._soup = BeautifulSoup(self.html, "html.parser")
        return self._soup

    @property
    def title(self):
        if self._title is None:
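A minimal sketch of the lazy-parsing behavior this hunk migrates to pydantic v2's PrivateAttr. It assumes `html` and `url` are the only required fields, since anything above line 13 of the class is outside the diff:

from typing import Optional

from bs4 import BeautifulSoup
from pydantic import BaseModel, PrivateAttr


class WebPage(BaseModel):
    html: str
    url: str

    _soup: Optional[BeautifulSoup] = PrivateAttr(default=None)

    @property
    def soup(self) -> BeautifulSoup:
        # Parse once, cache in the private attribute.
        if self._soup is None:
            self._soup = BeautifulSoup(self.html, "html.parser")
        return self._soup


page = WebPage(html="<html><head><title>Hi</title></head></html>", url="https://example.com")
print(page.soup.title.string)  # "Hi" -- parsed lazily on first access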
@@ -37,18 +37,26 @@ def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine:

    if not isinstance(expr, cst.Expr):
        return None

    val = expr.value
    if not isinstance(val, (cst.SimpleString, cst.ConcatenatedString)):
        return None

    evaluated_value = val.evaluated_value

    evaluated_value = val.evaluated_value
    if isinstance(evaluated_value, bytes):
        return None

    return statement


def has_decorator(node: DocstringNode, name: str) -> bool:
    return hasattr(node, "decorators") and any(
        (hasattr(i.decorator, "value") and i.decorator.value == name)
        or (hasattr(i.decorator, "func") and hasattr(i.decorator.func, "value") and i.decorator.func.value == name)
        for i in node.decorators
    )


class DocstringCollector(cst.CSTVisitor):
    """A visitor class for collecting docstrings from a CST.

@@ -56,6 +64,7 @@ class DocstringCollector(cst.CSTVisitor):
    stack: A list to keep track of the current path in the CST.
    docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
    """

    def __init__(self):
        self.stack: list[str] = []
        self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}

@@ -81,7 +90,7 @@ class DocstringCollector(cst.CSTVisitor):
    def _leave(self, node: DocstringNode) -> None:
        key = tuple(self.stack)
        self.stack.pop()
        if hasattr(node, "decorators") and any(i.decorator.value == "overload" for i in node.decorators):
        if has_decorator(node, "overload"):
            return

        statement = get_docstring_statement(node)

@@ -96,6 +105,7 @@ class DocstringTransformer(cst.CSTTransformer):
    stack: A list to keep track of the current path in the CST.
    docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
    """

    def __init__(
        self,
        docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],

@@ -125,7 +135,7 @@ class DocstringTransformer(cst.CSTTransformer):
        key = tuple(self.stack)
        self.stack.pop()

        if hasattr(updated_node, "decorators") and any((i.decorator.value == "overload") for i in updated_node.decorators):
        if has_decorator(updated_node, "overload"):
            return updated_node

        statement = self.docstrings.get(key)
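To see the collector in action, here is a small sketch (not part of the diff) that drives it with libcst; it assumes DocstringCollector is importable from the module this diff touches and that its visit_/leave_ methods (outside these hunks) push each definition's name onto the stack:

import libcst as cst

source = '''
def greet(name: str) -> str:
    """Return a greeting."""
    return f"Hello, {name}"
'''

module = cst.parse_module(source)
collector = DocstringCollector()
module.visit(collector)  # walk the CST, filling collector.docstrings
for path, stmt in collector.docstrings.items():
    print(path, "->", cst.Module(body=[stmt]).code.strip())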
@@ -8,6 +8,7 @@

import docx


def read_docx(file_path: str) -> list:
    """Open a docx file"""
    doc = docx.Document(file_path)
79
metagpt/utils/redis.py
Normal file

@@ -0,0 +1,79 @@
# !/usr/bin/python3
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/27
@Author : mashenquan
@File : redis.py
"""
from __future__ import annotations

import traceback
from datetime import timedelta

import aioredis  # https://aioredis.readthedocs.io/en/latest/getting-started/

from metagpt.config import CONFIG
from metagpt.logs import logger


class Redis:
    def __init__(self):
        self._client = None

    async def _connect(self, force=False):
        if self._client and not force:
            return True
        if not self.is_configured:
            return False

        try:
            self._client = await aioredis.from_url(
                f"redis://{CONFIG.REDIS_HOST}:{CONFIG.REDIS_PORT}",
                username=CONFIG.REDIS_USER,
                password=CONFIG.REDIS_PASSWORD,
                db=CONFIG.REDIS_DB,
            )
            return True
        except Exception as e:
            logger.warning(f"Redis initialization has failed:{e}")
            return False

    async def get(self, key: str) -> bytes | None:
        if not await self._connect() or not key:
            return None
        try:
            v = await self._client.get(key)
            return v
        except Exception as e:
            logger.exception(f"{e}, stack:{traceback.format_exc()}")
            return None

    async def set(self, key: str, data: str, timeout_sec: int = None):
        if not await self._connect() or not key:
            return
        try:
            ex = None if not timeout_sec else timedelta(seconds=timeout_sec)
            await self._client.set(key, data, ex=ex)
        except Exception as e:
            logger.exception(f"{e}, stack:{traceback.format_exc()}")

    async def close(self):
        if not self._client:
            return
        await self._client.close()
        self._client = None

    @property
    def is_valid(self) -> bool:
        return self._client is not None

    @property
    def is_configured(self) -> bool:
        return bool(
            CONFIG.REDIS_HOST
            and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
            and CONFIG.REDIS_PORT
            and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
            and CONFIG.REDIS_DB is not None
            and CONFIG.REDIS_PASSWORD is not None
        )
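A minimal sketch of using this wrapper, assuming the REDIS_* values are present in the MetaGPT config; note that get() and set() deliberately degrade to no-ops when Redis is unconfigured or unreachable:

import asyncio


async def demo():
    cache = Redis()
    # set() silently no-ops if Redis is not configured or the connection fails.
    await cache.set("greeting", "hello", timeout_sec=60)
    value = await cache.get("greeting")  # bytes | None
    print(value and value.decode())
    await cache.close()


asyncio.run(demo())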
314
metagpt/utils/repair_llm_raw_output.py
Normal file

@@ -0,0 +1,314 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : repair llm raw output with particular conditions

import copy
from enum import Enum
from typing import Callable, Union

import regex as re
from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed

from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.utils.custom_decoder import CustomDecoder


class RepairType(Enum):
    CS = "case sensitivity"
    RKPM = "required key pair missing"  # condition like `[key] xx` which lacks `[/key]`
    SCM = "special character missing"  # usually the req_key appears in pairs like `[key] xx [/key]`
    JSON = "json format"


def repair_case_sensitivity(output: str, req_key: str) -> str:
    """
    Usually, req_key is the key name of the expected json or markdown content; it won't appear in the value part.
    Fixes a target string like `"Shared Knowledge": ""` when the output actually contains `"Shared knowledge": ""`.
    """
    if req_key in output:
        return output

    output_lower = output.lower()
    req_key_lower = req_key.lower()
    if req_key_lower in output_lower:
        # find the sub-part index, and replace it with raw req_key
        lidx = output_lower.find(req_key_lower)
        source = output[lidx : lidx + len(req_key_lower)]
        output = output.replace(source, req_key)
        logger.info(f"repair_case_sensitivity: {req_key}")

    return output


def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str:
    """
    Fixes
        1. a target string `[CONTENT] xx [CONTENT] xxx [CONTENT]` that lacks `/` in the last `[CONTENT]`
        2. a target string `xx [CONTENT] xxx [CONTENT] xxxx` that lacks `/` in the last `[CONTENT]`
    """
    sc_arr = ["/"]

    if req_key in output:
        return output

    for sc in sc_arr:
        req_key_pure = req_key.replace(sc, "")
        appear_cnt = output.count(req_key_pure)
        if req_key_pure in output and appear_cnt > 1:
            # the req_key with the special character is usually on the tail side
            ridx = output.rfind(req_key_pure)
            output = f"{output[:ridx]}{req_key}{output[ridx + len(req_key_pure):]}"
            logger.info(f"repair_special_character_missing: {sc} in {req_key_pure} at position {ridx}")

    return output


def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str:
    """
    Implements the req_key pair at the beginning or end of the content.
    req_key format
        1. `[req_key]`, and its pair `[/req_key]`
        2. `[/req_key]`, and its pair `[req_key]`
    """
    sc = "/"  # special char
    if req_key.startswith("[") and req_key.endswith("]"):
        if sc in req_key:
            left_key = req_key.replace(sc, "")  # `[/req_key]` -> `[req_key]`
            right_key = req_key
        else:
            left_key = req_key
            right_key = f"{req_key[0]}{sc}{req_key[1:]}"  # `[req_key]` -> `[/req_key]`

        if left_key not in output:
            output = left_key + "\n" + output
        if right_key not in output:

            def judge_potential_json(routput: str, left_key: str) -> Union[str, None]:
                ridx = routput.rfind(left_key)
                if ridx < 0:
                    return None
                sub_output = routput[ridx:]
                idx1 = sub_output.rfind("}")
                idx2 = sub_output.rindex("]")
                idx = idx1 if idx1 >= idx2 else idx2
                sub_output = sub_output[: idx + 1]
                return sub_output

            if output.strip().endswith("}") or (output.strip().endswith("]") and not output.strip().endswith(left_key)):
                # avoid the `[req_key]xx[req_key]` case appending an extra `[/req_key]`
                output = output + "\n" + right_key
            elif judge_potential_json(output, left_key) and (not output.strip().endswith(left_key)):
                sub_content = judge_potential_json(output, left_key)
                output = sub_content + "\n" + right_key

    return output


def repair_json_format(output: str) -> str:
    """
    Fixes an extra `[` or `}` at the start or end.
    """
    output = output.strip()

    if output.startswith("[{"):
        output = output[1:]
        logger.info(f"repair_json_format: {'[{'}")
    elif output.endswith("}]"):
        output = output[:-1]
        logger.info(f"repair_json_format: {'}]'}")
    elif output.startswith("{") and output.endswith("]"):
        output = output[:-1] + "}"

    return output


def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = None) -> str:
    repair_types = [repair_type] if repair_type else [item for item in RepairType if item not in [RepairType.JSON]]
    for repair_type in repair_types:
        if repair_type == RepairType.CS:
            output = repair_case_sensitivity(output, req_key)
        elif repair_type == RepairType.RKPM:
            output = repair_required_key_pair_missing(output, req_key)
        elif repair_type == RepairType.SCM:
            output = repair_special_character_missing(output, req_key)
        elif repair_type == RepairType.JSON:
            output = repair_json_format(output)
    return output


def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str:
    """
    Open-source LLMs usually can't follow instructions well and their output may be incomplete,
    so here we try to repair it, using all repair methods by default.

    Typical cases
        1. case sensitivity
            target: "Original Requirements"
            output: "Original requirements"
        2. special character missing
            target: [/CONTENT]
            output: [CONTENT]
        3. json format
            target: { xxx }
            output: { xxx }]
    """
    if not CONFIG.repair_llm_output:
        return output

    # do the repair, usually for non-openai models
    for req_key in req_keys:
        output = _repair_llm_raw_output(output=output, req_key=req_key, repair_type=repair_type)
    return output


def repair_invalid_json(output: str, error: str) -> str:
    """
    Repairs the situation where there are extra chars.
    error examples
        example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765)
        example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266)
    """
    pattern = r"line ([0-9]+)"

    matches = re.findall(pattern, error, re.DOTALL)
    if len(matches) > 0:
        line_no = int(matches[0]) - 1

        # since CustomDecoder can handle `"": ''` or `'': ""`, convert `"""` -> `"`, `'''` -> `'`
        output = output.replace('"""', '"').replace("'''", '"')
        arr = output.split("\n")
        line = arr[line_no].strip()
        # different general problems
        if line.endswith("],"):
            # problem: redundant char `]`
            new_line = line.replace("]", "")
        elif line.endswith("},") and not output.endswith("},"):
            # problem: redundant char `}`
            new_line = line.replace("}", "")
        elif line.endswith("},") and output.endswith("},"):
            new_line = line[:-1]
        elif '",' not in line and "," not in line:
            new_line = f'{line}",'
        elif "," not in line:
            # problem: missing char `,` at the end
            new_line = f"{line},"
        elif "," in line and len(line) == 1:
            new_line = f'"{line}'
        elif '",' in line:
            new_line = line[:-2] + "',"
        else:
            new_line = line

        arr[line_no] = new_line
        output = "\n".join(arr)
        logger.info(f"repair_invalid_json, raw error: {error}")

    return output


def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]:
    def run_and_passon(retry_state: RetryCallState) -> None:
        """
        RetryCallState example
            {
                "start_time": 143.098322024,
                "retry_object": "<Retrying object at 0x7fabcaca25e0 (stop=<tenacity.stop.stop_after_attempt ... >)>",
                "fn": "<function retry_parse_json_text_v2 at 0x7fabcac80ee0>",
                "args": "(\"tag:[/CONTENT]\",)",  # function input args
                "kwargs": {},  # function input kwargs
                "attempt_number": 1,  # retry number
                "outcome": "<Future at xxx>",  # type(outcome.result()) = "str", type(outcome.exception()) = "class"
                "outcome_timestamp": 143.098416904,
                "idle_for": 0,
                "next_action": "None"
            }
        """
        if retry_state.outcome.failed:
            if retry_state.args:
                # can't be used as args=retry_state.args
                func_param_output = retry_state.args[0]
            elif retry_state.kwargs:
                func_param_output = retry_state.kwargs.get("output", "")
            exp_str = str(retry_state.outcome.exception())

            fix_str = "try to fix it, " if CONFIG.repair_llm_output else ""
            logger.warning(
                f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
                f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}"
            )

            repaired_output = repair_invalid_json(func_param_output, exp_str)
            retry_state.kwargs["output"] = repaired_output

    return run_and_passon


@retry(
    stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0),
    wait=wait_fixed(1),
    after=run_after_exp_and_passon_next_retry(logger),
)
def retry_parse_json_text(output: str) -> Union[list, dict]:
    """
    Repairs the json-text situation where there are extra chars like `]` or `}`.

    Warning
        If CONFIG.repair_llm_output is False, _aask_v1 retries {x=3} times and retry_parse_json_text's retry does not work.
        If CONFIG.repair_llm_output is True, _aask_v1 and retry_parse_json_text together loop for {x=3*3} times;
        it's a two-layer retry cycle.
    """
    # logger.debug(f"output to json decode:\n{output}")

    # if CONFIG.repair_llm_output is True, it will try to fix output until the retry breaks
    parsed_data = CustomDecoder(strict=False).decode(output)

    return parsed_data


def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"):
    """extract xxx from [CONTENT](xxx)[/CONTENT] using a regex pattern"""

    def re_extract_content(cont: str, pattern: str) -> str:
        matches = re.findall(pattern, cont, re.DOTALL)
        for match in matches:
            if match:
                cont = match
                break
        return cont.strip()

    # TODO construct the extract pattern with the `right_key`
    raw_content = copy.deepcopy(content)
    pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]"
    new_content = re_extract_content(raw_content, pattern)

    if not new_content.startswith("{"):
        # TODO find a more general pattern
        # for the `[CONTENT]xxx[CONTENT]xxxx[/CONTENT]` situation
        logger.warning(f"extract_content try another pattern: {pattern}")
        if right_key not in new_content:
            raw_content = copy.deepcopy(new_content + "\n" + right_key)
        # pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]"
        new_content = re_extract_content(raw_content, pattern)
    else:
        if right_key in new_content:
            idx = new_content.find(right_key)
            new_content = new_content[:idx]
            new_content = new_content.strip()

    return new_content


def extract_state_value_from_output(content: str) -> str:
    """
    OpenAI models will always return a state number, but for open LLM models the instruction result may be a
    long text containing the target number, so an extraction step is added here to improve the success rate.

    Args:
        content (str): llm's output from `Role._think`
    """
    content = content.strip()  # deal with output cases like " 0", "0\n" and so on
    pattern = r"([0-9])"  # TODO find the number with a more proper method, not just a pattern match on content
    matches = re.findall(pattern, content, re.DOTALL)
    matches = list(set(matches))
    state = matches[0] if len(matches) > 0 else "-1"
    return state
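A short sketch of the repair pipeline in action. The helpers above are pure string transforms, so they can be exercised directly on a typical malformed open-LLM reply:

broken = '[CONTENT]\n{"Original requirements": ["make a 2048 game"]}\n[CONTENT]'

# 1. fix the key's case: "Original requirements" -> "Original Requirements"
fixed = repair_case_sensitivity(broken, req_key="Original Requirements")
# 2. fix the missing "/" in the trailing tag: the last [CONTENT] -> [/CONTENT]
fixed = repair_special_character_missing(fixed, req_key="[/CONTENT]")
# 3. pull out the json body between the tags
body = extract_content_from_output(fixed)
print(body)  # {"Original Requirements": ["make a 2048 game"]}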
170
metagpt/utils/s3.py
Normal file

@@ -0,0 +1,170 @@
import base64
import os.path
import traceback
import uuid
from pathlib import Path
from typing import Optional

import aioboto3
import aiofiles

from metagpt.config import CONFIG
from metagpt.const import BASE64_FORMAT
from metagpt.logs import logger


class S3:
    """A class for interacting with Amazon S3 storage."""

    def __init__(self):
        self.session = aioboto3.Session()
        self.auth_config = {
            "service_name": "s3",
            "aws_access_key_id": CONFIG.S3_ACCESS_KEY,
            "aws_secret_access_key": CONFIG.S3_SECRET_KEY,
            "endpoint_url": CONFIG.S3_ENDPOINT_URL,
        }

    async def upload_file(
        self,
        bucket: str,
        local_path: str,
        object_name: str,
    ) -> None:
        """Upload a file from the local path to the specified path of the storage bucket specified in s3.

        Args:
            bucket: The name of the S3 storage bucket.
            local_path: The local file path, including the file name.
            object_name: The complete path of the uploaded file to be stored in S3, including the file name.

        Raises:
            Exception: If an error occurs during the upload process, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                async with aiofiles.open(local_path, mode="rb") as reader:
                    body = await reader.read()
                    await client.put_object(Body=body, Bucket=bucket, Key=object_name)
                    logger.info(f"Successfully uploaded the file to path {object_name} in bucket {bucket} of s3.")
        except Exception as e:
            logger.error(f"Failed to upload the file to path {object_name} in bucket {bucket} of s3: {e}")
            raise e

    async def get_object_url(
        self,
        bucket: str,
        object_name: str,
    ) -> str:
        """Get the URL for a downloadable or preview file stored in the specified S3 bucket.

        Args:
            bucket: The name of the S3 storage bucket.
            object_name: The complete path of the file stored in S3, including the file name.

        Returns:
            The URL for the downloadable or preview file.

        Raises:
            Exception: If an error occurs while retrieving the URL, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                file = await client.get_object(Bucket=bucket, Key=object_name)
                return str(file["Body"].url)
        except Exception as e:
            logger.error(f"Failed to get the url for a downloadable or preview file: {e}")
            raise e

    async def get_object(
        self,
        bucket: str,
        object_name: str,
    ) -> bytes:
        """Get the binary data of a file stored in the specified S3 bucket.

        Args:
            bucket: The name of the S3 storage bucket.
            object_name: The complete path of the file stored in S3, including the file name.

        Returns:
            The binary data of the requested file.

        Raises:
            Exception: If an error occurs while retrieving the file data, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                s3_object = await client.get_object(Bucket=bucket, Key=object_name)
                return await s3_object["Body"].read()
        except Exception as e:
            logger.error(f"Failed to get the binary data of the file: {e}")
            raise e

    async def download_file(
        self, bucket: str, object_name: str, local_path: str, chunk_size: Optional[int] = 128 * 1024
    ) -> None:
        """Download an S3 object to a local file.

        Args:
            bucket: The name of the S3 storage bucket.
            object_name: The complete path of the file stored in S3, including the file name.
            local_path: The local file path where the S3 object will be downloaded.
            chunk_size: The size of data chunks to read and write at a time. Default is 128 KB.

        Raises:
            Exception: If an error occurs during the download process, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                s3_object = await client.get_object(Bucket=bucket, Key=object_name)
                stream = s3_object["Body"]
                async with aiofiles.open(local_path, mode="wb") as writer:
                    while True:
                        file_data = await stream.read(chunk_size)
                        if not file_data:
                            break
                        await writer.write(file_data)
        except Exception as e:
            logger.error(f"Failed to download the file from S3: {e}")
            raise e

    async def cache(self, data: str, file_ext: str, format: str = "") -> str:
        """Save data to remote S3 and return url"""
        object_name = uuid.uuid4().hex + file_ext
        path = Path(__file__).parent
        pathname = path / object_name
        try:
            async with aiofiles.open(str(pathname), mode="wb") as file:
                data = base64.b64decode(data) if format == BASE64_FORMAT else data.encode(encoding="utf-8")
                await file.write(data)

            bucket = CONFIG.S3_BUCKET
            object_pathname = CONFIG.S3_BUCKET or "system"
            object_pathname += f"/{object_name}"
            object_pathname = os.path.normpath(object_pathname)
            await self.upload_file(bucket=bucket, local_path=str(pathname), object_name=object_pathname)
            pathname.unlink(missing_ok=True)

            return await self.get_object_url(bucket=bucket, object_name=object_pathname)
        except Exception as e:
            logger.exception(f"{e}, stack:{traceback.format_exc()}")
            pathname.unlink(missing_ok=True)
            return None

    @property
    def is_valid(self):
        return self.is_configured

    @property
    def is_configured(self) -> bool:
        return bool(
            CONFIG.S3_ACCESS_KEY
            and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
            and CONFIG.S3_SECRET_KEY
            and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
            and CONFIG.S3_ENDPOINT_URL
            and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
            and CONFIG.S3_BUCKET
            and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
        )
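A minimal sketch of round-tripping a file through this wrapper; the bucket name and paths below are illustrative, and the S3_* values must be set in the MetaGPT config for is_configured to pass:

import asyncio


async def demo():
    s3 = S3()
    if not s3.is_configured:
        return
    await s3.upload_file(bucket="my-bucket", local_path="./report.pdf", object_name="reports/report.pdf")
    url = await s3.get_object_url(bucket="my-bucket", object_name="reports/report.pdf")
    data = await s3.get_object(bucket="my-bucket", object_name="reports/report.pdf")
    print(url, len(data))


asyncio.run(demo())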
@@ -4,13 +4,11 @@

import copy
import pickle
from typing import Dict, List

from metagpt.actions.action_output import ActionOutput
from metagpt.schema import Message
from metagpt.utils.common import import_class


def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
def actionoutout_schema_to_mapping(schema: dict) -> dict:
    """
    directly traverse the `properties` in the first level.
    the schema structure looks like

@@ -35,32 +33,50 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
        if property["type"] == "string":
            mapping[field] = (str, ...)
        elif property["type"] == "array" and property["items"]["type"] == "string":
            mapping[field] = (List[str], ...)
            mapping[field] = (list[str], ...)
        elif property["type"] == "array" and property["items"]["type"] == "array":
            # here only consider the `List[List[str]]` situation
            mapping[field] = (List[List[str]], ...)
            # here only consider the `list[list[str]]` situation
            mapping[field] = (list[list[str]], ...)
    return mapping


def serialize_message(message: Message):
def actionoutput_mapping_to_str(mapping: dict) -> dict:
    new_mapping = {}
    for key, value in mapping.items():
        new_mapping[key] = str(value)
    return new_mapping


def actionoutput_str_to_mapping(mapping: dict) -> dict:
    new_mapping = {}
    for key, value in mapping.items():
        if value == "(<class 'str'>, Ellipsis)":
            new_mapping[key] = (str, ...)
        else:
            new_mapping[key] = eval(value)  # `"(list[str], Ellipsis)"` to `(list[str], ...)`
    return new_mapping


def serialize_message(message: "Message"):
    message_cp = copy.deepcopy(message)  # avoid `instruct_content` value update by reference
    ic = message_cp.instruct_content
    if ic:
        # a model created by pydantic create_model, like `pydantic.main.prd`, can't be pickle.dump'ed directly
        schema = ic.schema()
        schema = ic.model_json_schema()
        mapping = actionoutout_schema_to_mapping(schema)

        message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
        message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
    msg_ser = pickle.dumps(message_cp)

    return msg_ser


def deserialize_message(message_ser: str) -> Message:
def deserialize_message(message_ser: str) -> "Message":
    message = pickle.loads(message_ser)
    if message.instruct_content:
        ic = message.instruct_content
        ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
        actionnode_class = import_class("ActionNode", "metagpt.actions.action_node")  # avoid circular import
        ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
        ic_new = ic_obj(**ic["value"])
        message.instruct_content = ic_new
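To see what these helpers produce, here is a small sketch with a hand-written JSON schema of the shape the function expects (the loop over `schema["properties"]` sits above the visible hunk, so that iteration is assumed):

schema = {
    "title": "prd",
    "type": "object",
    "properties": {
        "Original Requirements": {"type": "string"},
        "Product Goals": {"type": "array", "items": {"type": "string"}},
        "Requirement Pool": {"type": "array", "items": {"type": "array"}},
    },
}

mapping = actionoutout_schema_to_mapping(schema)
# {'Original Requirements': (str, Ellipsis),
#  'Product Goals': (list[str], Ellipsis),
#  'Requirement Pool': (list[list[str]], Ellipsis)}

# the string round-trip used when persisting the mapping:
as_str = actionoutput_mapping_to_str(mapping)
restored = actionoutput_str_to_mapping(as_str)
assert restored == mapping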
@@ -20,4 +20,3 @@ class Singleton(abc.ABCMeta, type):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]
@@ -1,4 +1,4 @@
# token to separate different code messages in a WriteCode Message content
MSG_SEP = "#*000*#"
# token to separate file name and the actual code text in a code message
FILENAME_CODE_SEP = "#*001*#"
|
|||
|
|
@ -3,7 +3,12 @@ from typing import Generator, Sequence
|
|||
from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens
|
||||
|
||||
|
||||
def reduce_message_length(msgs: Generator[str, None, None], model_name: str, system_text: str, reserved: int = 0,) -> str:
|
||||
def reduce_message_length(
|
||||
msgs: Generator[str, None, None],
|
||||
model_name: str,
|
||||
system_text: str,
|
||||
reserved: int = 0,
|
||||
) -> str:
|
||||
"""Reduce the length of concatenated message segments to fit within the maximum token size.
|
||||
|
||||
Args:
|
||||
|
|
@ -49,9 +54,9 @@ def generate_prompt_chunk(
|
|||
current_token = 0
|
||||
current_lines = []
|
||||
|
||||
reserved = reserved + count_string_tokens(prompt_template+system_text, model_name)
|
||||
reserved = reserved + count_string_tokens(prompt_template + system_text, model_name)
|
||||
# 100 is a magic number to ensure the maximum context length is not exceeded
|
||||
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
|
||||
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
|
||||
|
||||
while paragraphs:
|
||||
paragraph = paragraphs.pop(0)
|
||||
|
|
@ -103,7 +108,7 @@ def decode_unicode_escape(text: str) -> str:
|
|||
return text.encode("utf-8").decode("unicode_escape", "ignore")
|
||||
|
||||
|
||||
def _split_by_count(lst: Sequence , count: int):
|
||||
def _split_by_count(lst: Sequence, count: int):
|
||||
avg = len(lst) // count
|
||||
remainder = len(lst) % count
|
||||
start = 0
|
||||
|
|
|
|||
|
|
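The tail of _split_by_count falls outside this hunk; from the visible avg/remainder/start setup it is presumably the standard near-equal chunking generator, sketched here under that assumption:

from typing import Sequence


def _split_by_count(lst: Sequence, count: int):
    # Split `lst` into `count` chunks whose sizes differ by at most one;
    # the first `remainder` chunks get one extra element.
    avg = len(lst) // count
    remainder = len(lst) % count
    start = 0
    for i in range(count):
        end = start + avg + (1 if i < remainder else 0)
        yield lst[start:end]
        start = end


print(list(_split_by_count("abcdefg", 3)))  # ['abc', 'de', 'fg']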
@@ -7,6 +7,7 @@
ref1: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
ref2: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/token_counter.py
ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py
ref4: https://ai.google.dev/models/gemini
"""
import tiktoken


@@ -16,13 +17,18 @@ TOKEN_COSTS = {
    "gpt-3.5-turbo-0613": {"prompt": 0.0015, "completion": 0.002},
    "gpt-3.5-turbo-16k": {"prompt": 0.003, "completion": 0.004},
    "gpt-3.5-turbo-16k-0613": {"prompt": 0.003, "completion": 0.004},
    "gpt-35-turbo": {"prompt": 0.0015, "completion": 0.002},
    "gpt-35-turbo-16k": {"prompt": 0.003, "completion": 0.004},
    "gpt-3.5-turbo-1106": {"prompt": 0.001, "completion": 0.002},
    "gpt-4-0314": {"prompt": 0.03, "completion": 0.06},
    "gpt-4": {"prompt": 0.03, "completion": 0.06},
    "gpt-4-32k": {"prompt": 0.06, "completion": 0.12},
    "gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
    "gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
    "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
    "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
    "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}  # 32k version, prompt + completion tokens=0.005¥/k-tokens
    "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069},  # 32k version, prompt + completion tokens=0.005¥/k-tokens
    "gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
}


@@ -32,13 +38,18 @@ TOKEN_MAX = {
    "gpt-3.5-turbo-0613": 4096,
    "gpt-3.5-turbo-16k": 16384,
    "gpt-3.5-turbo-16k-0613": 16384,
    "gpt-35-turbo": 4096,
    "gpt-35-turbo-16k": 16384,
    "gpt-3.5-turbo-1106": 16384,
    "gpt-4-0314": 8192,
    "gpt-4": 8192,
    "gpt-4-32k": 32768,
    "gpt-4-32k-0314": 32768,
    "gpt-4-0613": 8192,
    "gpt-4-1106-preview": 128000,
    "text-embedding-ada-002": 8192,
    "chatglm_turbo": 32768
    "chatglm_turbo": 32768,
    "gemini-pro": 32768,
}


@@ -52,22 +63,34 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
    if model in {
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-35-turbo",
        "gpt-35-turbo-16k",
        "gpt-3.5-turbo-16k",
        "gpt-3.5-turbo-1106",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k-0613",
        "gpt-4-1106-preview",
    }:
        tokens_per_message = 3
        tokens_per_message = 3  # every reply is primed with <|start|>assistant<|message|>
        tokens_per_name = 1
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif "gpt-3.5-turbo" in model:
    elif "gpt-3.5-turbo" == model:
        print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
        return count_message_tokens(messages, model="gpt-3.5-turbo-0613")
    elif "gpt-4" in model:
    elif "gpt-4" == model:
        print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
        return count_message_tokens(messages, model="gpt-4-0613")
    elif "open-llm-model" == model:
        """
        For a self-hosted open_llm api, which serves lots of different models, the message-token calculation is
        inaccurate; it's a reference result.
        """
        tokens_per_message = 0  # ignore conversation message template prefix
        tokens_per_name = 0
    else:
        raise NotImplementedError(
            f"num_tokens_from_messages() is not implemented for model {model}. "

@@ -96,7 +119,11 @@ def count_string_tokens(string: str, model_name: str) -> int:
    Returns:
        int: The number of tokens in the text string.
    """
    encoding = tiktoken.encoding_for_model(model_name)
    try:
        encoding = tiktoken.encoding_for_model(model_name)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(string))
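A quick sketch of the counting helpers, assuming tiktoken is installed; the new try/except branch makes unknown model names degrade to the cl100k_base encoding instead of raising, and the tail of count_message_tokens (outside these hunks) is assumed to sum per-message tokens as in the OpenAI cookbook:

print(count_string_tokens("hello world", "gpt-4"))         # known model
print(count_string_tokens("hello world", "my-local-llm"))  # falls back to cl100k_base

messages = [{"role": "user", "content": "hello"}]
print(count_message_tokens(messages, model="gpt-4"))  # warns, then counts as gpt-4-0613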