Merge main branch

mannaandpoem 2024-01-03 19:48:46 +08:00
commit 24e617b362
325 changed files with 11290 additions and 3760 deletions


@@ -23,11 +23,11 @@ import sys
import traceback
import typing
from pathlib import Path
from typing import Any, List, Tuple, Union, get_args, get_origin
from typing import Any, List, Tuple, Union
import aiofiles
import loguru
from pydantic.json import pydantic_encoder
from pydantic_core import to_jsonable_python
from tenacity import RetryCallState, _utils
from metagpt.const import MESSAGE_ROUTE_TO_ALL
@@ -48,10 +48,10 @@ def check_cmd_exists(command) -> int:
return result
def require_python_version(req_version: tuple[int]) -> bool:
def require_python_version(req_version: Tuple) -> bool:
if not (2 <= len(req_version) <= 3):
raise ValueError("req_version should be (3, 9) or (3, 10, 13)")
return True if sys.version_info > req_version else False
return bool(sys.version_info > req_version)
class OutputParser:
@@ -131,13 +131,11 @@ class OutputParser:
try:
content = cls.parse_code(text=content)
except Exception:
pass
# Try to parse as a list
try:
content = cls.parse_file_list(text=content)
except Exception:
pass
# Try to parse as a list
try:
content = cls.parse_file_list(text=content)
except Exception:
pass
parsed_data[block] = content
return parsed_data
@@ -149,19 +147,7 @@ class OutputParser:
if extracted_content:
return extracted_content.group(1).strip()
else:
return "No content found between [CONTENT] and [/CONTENT] tags."
@staticmethod
def is_supported_list_type(i):
origin = get_origin(i)
if origin is not List:
return False
args = get_args(i)
if args == (str,) or args == (Tuple[str, str],) or args == (List[str],):
return True
return False
raise ValueError(f"Could not find content between [{tag}] and [/{tag}]")
@classmethod
def parse_data_with_mapping(cls, data, mapping):
@@ -367,14 +353,14 @@ def get_class_name(cls) -> str:
return f"{cls.__module__}.{cls.__name__}"
def any_to_str(val: str | typing.Callable) -> str:
def any_to_str(val: Any) -> str:
"""Return the class name or the class name of the object, or 'val' if it's a string type."""
if isinstance(val, str):
return val
if not callable(val):
elif not callable(val):
return get_class_name(type(val))
return get_class_name(val)
else:
return get_class_name(val)
def any_to_str_set(val) -> set:
@@ -406,6 +392,21 @@ def is_subscribed(message: "Message", tags: set):
return False
def any_to_name(val):
"""
Convert a value to its name by extracting the last part of the dotted path.
:param val: The value to convert.
:return: The name of the value.
"""
return any_to_str(val).split(".")[-1]
def concat_namespace(*args) -> str:
return ":".join(str(value) for value in args)
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
"""
Generates a logging function to be used after a call is retried.
@@ -439,7 +440,7 @@ def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.C
return log_it
def read_json_file(json_file: str, encoding=None) -> list[Any]:
def read_json_file(json_file: str, encoding="utf-8") -> list[Any]:
if not Path(json_file).exists():
raise FileNotFoundError(f"json_file: {json_file} not exist, return []")
@@ -457,7 +458,7 @@ def write_json_file(json_file: str, data: list, encoding=None):
folder_path.mkdir(parents=True, exist_ok=True)
with open(json_file, "w", encoding=encoding) as fout:
json.dump(data, fout, ensure_ascii=False, indent=4, default=pydantic_encoder)
json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python)
def import_class(class_name: str, module_name: str) -> type:
@@ -497,7 +498,7 @@ def role_raise_decorator(func):
except KeyboardInterrupt as kbi:
logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project")
if self.latest_observed_msg:
self._rc.memory.delete(self.latest_observed_msg)
self.rc.memory.delete(self.latest_observed_msg)
# raise again to make it captured outside
raise Exception(format_trackback_info(limit=None))
except Exception:
@@ -507,7 +508,7 @@ def role_raise_decorator(func):
"we delete the newest role communication message in the role's memory."
)
# remove role newest observed msg to make it observed again
self._rc.memory.delete(self.latest_observed_msg)
self.rc.memory.delete(self.latest_observed_msg)
# raise again to make it captured outside
raise Exception(format_trackback_info(limit=None))
@@ -515,8 +516,33 @@ def role_raise_decorator(func):
@handle_exception
async def aread(file_path: str) -> str:
async def aread(filename: str | Path, encoding=None) -> str:
"""Read file asynchronously."""
async with aiofiles.open(str(file_path), mode="r") as reader:
async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
content = await reader.read()
return content
async def awrite(filename: str | Path, data: str, encoding=None):
"""Write file asynchronously."""
pathname = Path(filename)
pathname.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(str(pathname), mode="w", encoding=encoding) as writer:
await writer.write(data)
async def read_file_block(filename: str | Path, lineno: int, end_lineno: int):
if not Path(filename).exists():
return ""
lines = []
async with aiofiles.open(str(filename), mode="r") as reader:
ix = 0
while ix < end_lineno:
ix += 1
line = await reader.readline()
if ix < lineno:
continue
if ix > end_lineno:
break
lines.append(line)
return "".join(lines)

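The rewritten aread now accepts a Path and an encoding, and the new awrite/read_file_block helpers round it out. A minimal sketch of how the trio composes, assuming a metagpt checkout (aread's metagpt.utils.common import path is confirmed by the dependency_file diff below):

import asyncio
from metagpt.utils.common import aread, awrite, read_file_block

async def demo():
    # awrite creates parent directories before writing
    await awrite("/tmp/demo.txt", "line1\nline2\nline3\n", encoding="utf-8")
    print(await aread("/tmp/demo.txt", encoding="utf-8"))
    # read_file_block returns the inclusive 1-based line range
    print(await read_file_block("/tmp/demo.txt", lineno=2, end_lineno=3))  # "line2\nline3\n"

asyncio.run(demo())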

@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/28
@Author : mashenquan
@File : openai.py
@Desc : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting.
"""
from typing import NamedTuple
from pydantic import BaseModel
from metagpt.logs import logger
from metagpt.utils.token_counter import TOKEN_COSTS
class Costs(NamedTuple):
total_prompt_tokens: int
total_completion_tokens: int
total_cost: float
total_budget: float
class CostManager(BaseModel):
"""Calculate the overhead of using the interface."""
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
total_budget: float = 0
max_budget: float = 10.0
total_cost: float = 0
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
cost = (
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
) / 1000
self.total_cost += cost
logger.info(
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | "
f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
def get_total_prompt_tokens(self):
"""
Get the total number of prompt tokens.
Returns:
int: The total number of prompt tokens.
"""
return self.total_prompt_tokens
def get_total_completion_tokens(self):
"""
Get the total number of completion tokens.
Returns:
int: The total number of completion tokens.
"""
return self.total_completion_tokens
def get_total_cost(self):
"""
Get the total cost of API calls.
Returns:
float: The total cost of API calls.
"""
return self.total_cost
def get_costs(self) -> Costs:
"""Get all costs"""
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)

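A worked example of the cost arithmetic, assuming the CostManager above is in scope and TOKEN_COSTS carries the 2023 gpt-4 rates of $0.03 per 1K prompt tokens and $0.06 per 1K completion tokens:

cm = CostManager()
cm.update_cost(prompt_tokens=1000, completion_tokens=500, model="gpt-4")
# cost = (1000 * 0.03 + 500 * 0.06) / 1000 = $0.06
print(cm.get_costs())
# Costs(total_prompt_tokens=1000, total_completion_tokens=500, total_cost=0.06, total_budget=0)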

@@ -14,7 +14,6 @@ from typing import Set
import aiofiles
from metagpt.config import CONFIG
from metagpt.utils.common import aread
from metagpt.utils.exceptions import handle_exception
@@ -86,7 +85,7 @@ class DependencyFile:
if persist:
await self.load()
root = CONFIG.git_repo.workdir
root = self._filename.parent
try:
key = Path(filename).relative_to(root)
except ValueError:


@@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : di_graph_repository.py
@Desc : Graph repository based on DiGraph
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import List
import aiofiles
import networkx
from metagpt.utils.graph_repository import SPO, GraphRepository
class DiGraphRepository(GraphRepository):
def __init__(self, name: str, **kwargs):
super().__init__(name=name, **kwargs)
self._repo = networkx.DiGraph()
async def insert(self, subject: str, predicate: str, object_: str):
self._repo.add_edge(subject, object_, predicate=predicate)
async def upsert(self, subject: str, predicate: str, object_: str):
pass
async def update(self, subject: str, predicate: str, object_: str):
pass
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
result = []
for s, o, p in self._repo.edges(data="predicate"):
if subject and subject != s:
continue
if predicate and predicate != p:
continue
if object_ and object_ != o:
continue
result.append(SPO(subject=s, predicate=p, object_=o))
return result
def json(self) -> str:
m = networkx.node_link_data(self._repo)
data = json.dumps(m)
return data
async def save(self, path: str | Path = None):
data = self.json()
path = path or self._kwargs.get("root")
if not path.exists():
path.mkdir(parents=True, exist_ok=True)
pathname = Path(path) / self.name
async with aiofiles.open(str(pathname.with_suffix(".json")), mode="w", encoding="utf-8") as writer:
await writer.write(data)
async def load(self, pathname: str | Path):
async with aiofiles.open(str(pathname), mode="r", encoding="utf-8") as reader:
data = await reader.read(-1)
m = json.loads(data)
self._repo = networkx.node_link_graph(m)
@staticmethod
async def load_from(pathname: str | Path) -> GraphRepository:
pathname = Path(pathname)
name = pathname.with_suffix("").name
root = pathname.parent
graph = DiGraphRepository(name=name, root=root)
if pathname.exists():
await graph.load(pathname=pathname)
return graph
@property
def root(self) -> str:
return self._kwargs.get("root")
@property
def pathname(self) -> Path:
p = Path(self.root) / self.name
return p.with_suffix(".json")

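A minimal round-trip sketch for the new repository (import path assumed from the @File header; requires networkx):

import asyncio
from pathlib import Path
from metagpt.utils.di_graph_repository import DiGraphRepository

async def demo():
    repo = DiGraphRepository(name="demo", root=Path("/tmp"))
    await repo.insert(subject="a.py", predicate="is", object_="source_code")
    print(await repo.select(subject="a.py"))  # [SPO(subject='a.py', predicate='is', object_='source_code')]
    await repo.save()  # writes /tmp/demo.json
    loaded = await DiGraphRepository.load_from("/tmp/demo.json")
    print(loaded.pathname)  # /tmp/demo.json

asyncio.run(demo())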

@@ -81,10 +81,11 @@ class FileRepository:
:return: List of changed dependency filenames or paths.
"""
dependencies = await self.get_dependency(filename=filename)
changed_files = self.changed_files
changed_files = set(self.changed_files.keys())
changed_dependent_files = set()
for df in dependencies:
if df in changed_files.keys():
rdf = Path(df).relative_to(self._relative_path)
if str(rdf) in changed_files:
changed_dependent_files.add(df)
return changed_dependent_files


@@ -1,20 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/19 20:39
@Author : femto Zheng
@File : get_template.py
"""
from metagpt.config import CONFIG
def get_template(templates, schema=CONFIG.prompt_schema):
selected_templates = templates.get(schema)
if selected_templates is None:
raise ValueError(f"Can't find {schema} in passed in templates")
# Extract the selected templates
prompt_template = selected_templates["PROMPT_TEMPLATE"]
format_example = selected_templates["FORMAT_EXAMPLE"]
return prompt_template, format_example


@@ -17,7 +17,6 @@ from git.repo import Repo
from git.repo.fun import is_git_dir
from gitignore_parser import parse_gitignore
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.utils.dependency_file import DependencyFile
from metagpt.utils.file_repository import FileRepository
@@ -276,20 +275,3 @@ class GitRepository:
continue
files.append(filename)
return files
if __name__ == "__main__":
path = DEFAULT_WORKSPACE_ROOT / "git"
path.mkdir(exist_ok=True, parents=True)
repo = GitRepository()
repo.open(path, auto_init=True)
repo.filter_gitignore(filenames=["snake_game/snake_game/__pycache__", "snake_game/snake_game/game.py"])
changes = repo.changed_files
print(changes)
repo.add_change(changes)
print(repo.status)
repo.commit("test")
print(repo.status)
repo.delete_repository()


@@ -0,0 +1,150 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : graph_repository.py
@Desc : Superclass for graph repository.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List
from pydantic import BaseModel
from metagpt.repo_parser import ClassInfo, RepoFileInfo
from metagpt.utils.common import concat_namespace
class GraphKeyword:
IS = "is"
CLASS = "class"
FUNCTION = "function"
SOURCE_CODE = "source_code"
NULL = "<null>"
GLOBAL_VARIABLE = "global_variable"
CLASS_FUNCTION = "class_function"
CLASS_PROPERTY = "class_property"
HAS_CLASS = "has_class"
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
HAS_SEQUENCE_VIEW = "has_sequence_view"
HAS_ARGS_DESC = "has_args_desc"
HAS_TYPE_DESC = "has_type_desc"
class SPO(BaseModel):
subject: str
predicate: str
object_: str
class GraphRepository(ABC):
def __init__(self, name: str, **kwargs):
self._repo_name = name
self._kwargs = kwargs
@abstractmethod
async def insert(self, subject: str, predicate: str, object_: str):
pass
@abstractmethod
async def upsert(self, subject: str, predicate: str, object_: str):
pass
@abstractmethod
async def update(self, subject: str, predicate: str, object_: str):
pass
@abstractmethod
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
pass
@property
def name(self) -> str:
return self._repo_name
@staticmethod
async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo):
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
file_types = {".py": "python", ".js": "javascript"}
file_type = file_types.get(Path(file_info.file).suffix, GraphKeyword.NULL)
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type)
for c in file_info.classes:
class_name = c.get("name", "")
await graph_db.insert(
subject=file_info.file,
predicate=GraphKeyword.HAS_CLASS,
object_=concat_namespace(file_info.file, class_name),
)
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS,
)
methods = c.get("methods", [])
for fn in methods:
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
)
for f in file_info.functions:
await graph_db.insert(
subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION
)
for g in file_info.globals:
await graph_db.insert(
subject=concat_namespace(file_info.file, g),
predicate=GraphKeyword.IS,
object_=GraphKeyword.GLOBAL_VARIABLE,
)
for code_block in file_info.page_info:
if code_block.tokens:
await graph_db.insert(
subject=concat_namespace(file_info.file, *code_block.tokens),
predicate=GraphKeyword.HAS_PAGE_INFO,
object_=code_block.json(ensure_ascii=False),
)
for k, v in code_block.properties.items():
await graph_db.insert(
subject=concat_namespace(file_info.file, k, v),
predicate=GraphKeyword.HAS_PAGE_INFO,
object_=code_block.json(ensure_ascii=False),
)
@staticmethod
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
for c in class_views:
filename, class_name = c.package.split(":", 1)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
file_types = {".py": "python", ".js": "javascript"}
file_type = file_types.get(Path(filename).suffix, GraphKeyword.NULL)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=file_type)
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_CLASS, object_=class_name)
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS,
)
for vn, vt in c.attributes.items():
await graph_db.insert(
subject=concat_namespace(c.package, vn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_PROPERTY,
)
await graph_db.insert(
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
)
for fn, desc in c.methods.items():
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
)
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.HAS_ARGS_DESC,
object_=desc,
)


@@ -18,17 +18,15 @@ from metagpt.config import CONFIG
def make_sk_kernel():
kernel = sk.Kernel()
if CONFIG.openai_api_type == "azure":
if CONFIG.OPENAI_API_TYPE == "azure":
kernel.add_chat_service(
"chat_completion",
AzureChatCompletion(
deployment_name=CONFIG.deployment_name, endpoint=CONFIG.openai_base_url, api_key=CONFIG.openai_api_key
),
AzureChatCompletion(CONFIG.DEPLOYMENT_NAME, CONFIG.OPENAI_BASE_URL, CONFIG.OPENAI_API_KEY),
)
else:
kernel.add_chat_service(
"chat_completion",
OpenAIChatCompletion(model_id=CONFIG.openai_api_model, api_key=CONFIG.openai_api_key),
OpenAIChatCompletion(CONFIG.OPENAI_API_MODEL, CONFIG.OPENAI_API_KEY),
)
return kernel


@@ -10,7 +10,6 @@ import os
from pathlib import Path
from metagpt.config import CONFIG
from metagpt.const import METAGPT_ROOT
from metagpt.logs import logger
from metagpt.utils.common import check_cmd_exists
@@ -88,7 +87,8 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
return 0
MMC1 = """classDiagram
MMC1 = """
classDiagram
class Main {
-SearchEngine search_engine
+main() str
@@ -118,9 +118,11 @@ MMC1 = """classDiagram
SearchEngine --> Index
SearchEngine --> Ranking
SearchEngine --> Summary
Index --> KnowledgeBase"""
Index --> KnowledgeBase
"""
MMC1_REFINE = """classDiagram
MMC1_REFINE = """
classDiagram
class Main {
-SearchEngine search_engine
+main() str
@@ -156,9 +158,11 @@ MMC1_REFINE = """classDiagram
SearchEngine --> Index
SearchEngine --> Ranking
SearchEngine --> Summary
Index --> KnowledgeBase"""
Index --> KnowledgeBase
"""
MMC2 = """sequenceDiagram
MMC2 = """
sequenceDiagram
participant M as Main
participant SE as SearchEngine
participant I as Index
@@ -174,9 +178,11 @@ MMC2 = """sequenceDiagram
R-->>SE: return ranked_results
SE->>S: summarize_results(ranked_results)
S-->>SE: return summary
SE-->>M: return summary"""
SE-->>M: return summary
"""
MMC2_REFINE = """sequenceDiagram
MMC2_REFINE = """
sequenceDiagram
participant M as Main
participant SE as SearchEngine
participant I as Index
@@ -201,10 +207,5 @@ MMC2_REFINE = """sequenceDiagram
R-->>SE: newMethod() # Incremental change
SE->>S: newMethod() # Incremental change
S-->>SE: newMethod() # Incremental change
SE-->>M: newMethod() # Incremental change"""
if __name__ == "__main__":
loop = asyncio.new_event_loop()
result = loop.run_until_complete(mermaid_to_file(MMC1, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/1"))
result = loop.run_until_complete(mermaid_to_file(MMC2, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/1"))
loop.close()
SE-->>M: newMethod() # Incremental change
"""

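The removed __main__ demo can still be reproduced from user code. A sketch under the assumption that these constants live in metagpt.utils.mermaid; the distinct output names avoid the deleted block's habit of writing both diagrams to the same path:

import asyncio
from metagpt.config import CONFIG
from metagpt.const import METAGPT_ROOT
from metagpt.utils.mermaid import MMC1, MMC2, mermaid_to_file

async def demo():
    await mermaid_to_file(MMC1, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/class_diagram")
    await mermaid_to_file(MMC2, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/sequence_diagram")

asyncio.run(demo())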

@@ -5,7 +5,7 @@ from typing import Generator, Optional
from urllib.parse import urljoin, urlparse
from bs4 import BeautifulSoup
from pydantic import BaseModel
from pydantic import BaseModel, PrivateAttr
class WebPage(BaseModel):
@@ -13,11 +13,8 @@ class WebPage(BaseModel):
html: str
url: str
class Config:
underscore_attrs_are_private = True
_soup: Optional[BeautifulSoup] = None
_title: Optional[str] = None
_soup: Optional[BeautifulSoup] = PrivateAttr(default=None)
_title: Optional[str] = PrivateAttr(default=None)
@property
def soup(self) -> BeautifulSoup:

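The WebPage change is the pydantic v2 idiom: private attributes are declared with PrivateAttr instead of the v1 underscore_attrs_are_private config. A self-contained sketch (Page is a hypothetical stand-in, not the real WebPage):

from typing import Optional
from pydantic import BaseModel, PrivateAttr

class Page(BaseModel):
    html: str
    url: str
    _soup: Optional[object] = PrivateAttr(default=None)  # cached parse tree

p = Page(html="<p>hi</p>", url="https://example.com")
p._soup = object()     # private attrs are settable on instances
print(p.model_dump())  # {'html': '<p>hi</p>', 'url': 'https://example.com'} - _soup excluded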

@@ -49,6 +49,14 @@ def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine:
return statement
def has_decorator(node: DocstringNode, name: str) -> bool:
return hasattr(node, "decorators") and any(
(hasattr(i.decorator, "value") and i.decorator.value == name)
or (hasattr(i.decorator, "func") and hasattr(i.decorator.func, "value") and i.decorator.func.value == name)
for i in node.decorators
)
class DocstringCollector(cst.CSTVisitor):
"""A visitor class for collecting docstrings from a CST.
@@ -82,7 +90,7 @@ class DocstringCollector(cst.CSTVisitor):
def _leave(self, node: DocstringNode) -> None:
key = tuple(self.stack)
self.stack.pop()
if hasattr(node, "decorators") and any(i.decorator.value == "overload" for i in node.decorators):
if has_decorator(node, "overload"):
return
statement = get_docstring_statement(node)
@@ -127,9 +135,7 @@ class DocstringTransformer(cst.CSTTransformer):
key = tuple(self.stack)
self.stack.pop()
if hasattr(updated_node, "decorators") and any(
(i.decorator.value == "overload") for i in updated_node.decorators
):
if has_decorator(updated_node, "overload"):
return updated_node
statement = self.docstrings.get(key)

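A quick check of what the extracted has_decorator helper matches, assuming libcst is installed and the helper above is in scope; @overload() is contrived but exercises the Call branch:

import libcst as cst

module = cst.parse_module(
    "@overload\n"
    "def f(x): ...\n"
    "\n"
    "@overload()\n"
    "def g(x): ...\n"
)
f_def, g_def = module.body
print(has_decorator(f_def, "overload"))  # True via the Name branch (decorator.value)
print(has_decorator(g_def, "overload"))  # True via the Call branch (decorator.func.value)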
metagpt/utils/redis.py Normal file

@@ -0,0 +1,79 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/27
@Author : mashenquan
@File : redis.py
"""
from __future__ import annotations
import traceback
from datetime import timedelta
import aioredis # https://aioredis.readthedocs.io/en/latest/getting-started/
from metagpt.config import CONFIG
from metagpt.logs import logger
class Redis:
def __init__(self):
self._client = None
async def _connect(self, force=False):
if self._client and not force:
return True
if not self.is_configured:
return False
try:
self._client = await aioredis.from_url(
f"redis://{CONFIG.REDIS_HOST}:{CONFIG.REDIS_PORT}",
username=CONFIG.REDIS_USER,
password=CONFIG.REDIS_PASSWORD,
db=CONFIG.REDIS_DB,
)
return True
except Exception as e:
logger.warning(f"Redis initialization has failed:{e}")
return False
async def get(self, key: str) -> bytes | None:
if not await self._connect() or not key:
return None
try:
v = await self._client.get(key)
return v
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")
return None
async def set(self, key: str, data: str, timeout_sec: int = None):
if not await self._connect() or not key:
return
try:
ex = None if not timeout_sec else timedelta(seconds=timeout_sec)
await self._client.set(key, data, ex=ex)
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")
async def close(self):
if not self._client:
return
await self._client.close()
self._client = None
@property
def is_valid(self) -> bool:
return self._client is not None
@property
def is_configured(self) -> bool:
return bool(
CONFIG.REDIS_HOST
and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
and CONFIG.REDIS_PORT
and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
and CONFIG.REDIS_DB is not None
and CONFIG.REDIS_PASSWORD is not None
)

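A minimal usage sketch for the new wrapper (requires REDIS_* settings in CONFIG and a reachable server; "greeting" is an arbitrary key):

import asyncio
from metagpt.utils.redis import Redis

async def demo():
    conn = Redis()
    await conn.set("greeting", "hello", timeout_sec=60)  # silently no-ops if unconfigured
    print(await conn.get("greeting"))  # b'hello', or None when unconfigured
    await conn.close()

asyncio.run(demo())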

@@ -230,9 +230,11 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
elif retry_state.kwargs:
func_param_output = retry_state.kwargs.get("output", "")
exp_str = str(retry_state.outcome.exception())
fix_str = "try to fix it, " if CONFIG.repair_llm_output else ""
logger.warning(
f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}"
f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}"
)
repaired_output = repair_invalid_json(func_param_output, exp_str)

metagpt/utils/s3.py Normal file

@@ -0,0 +1,170 @@
import base64
import os.path
import traceback
import uuid
from pathlib import Path
from typing import Optional
import aioboto3
import aiofiles
from metagpt.config import CONFIG
from metagpt.const import BASE64_FORMAT
from metagpt.logs import logger
class S3:
"""A class for interacting with Amazon S3 storage."""
def __init__(self):
self.session = aioboto3.Session()
self.auth_config = {
"service_name": "s3",
"aws_access_key_id": CONFIG.S3_ACCESS_KEY,
"aws_secret_access_key": CONFIG.S3_SECRET_KEY,
"endpoint_url": CONFIG.S3_ENDPOINT_URL,
}
async def upload_file(
self,
bucket: str,
local_path: str,
object_name: str,
) -> None:
"""Upload a file from the local path to the specified path of the storage bucket specified in s3.
Args:
bucket: The name of the S3 storage bucket.
local_path: The local file path, including the file name.
object_name: The complete path of the uploaded file to be stored in S3, including the file name.
Raises:
Exception: If an error occurs during the upload process, an exception is raised.
"""
try:
async with self.session.client(**self.auth_config) as client:
async with aiofiles.open(local_path, mode="rb") as reader:
body = await reader.read()
await client.put_object(Body=body, Bucket=bucket, Key=object_name)
logger.info(f"Successfully uploaded the file to path {object_name} in bucket {bucket} of s3.")
except Exception as e:
logger.error(f"Failed to upload the file to path {object_name} in bucket {bucket} of s3: {e}")
raise e
async def get_object_url(
self,
bucket: str,
object_name: str,
) -> str:
"""Get the URL for a downloadable or preview file stored in the specified S3 bucket.
Args:
bucket: The name of the S3 storage bucket.
object_name: The complete path of the file stored in S3, including the file name.
Returns:
The URL for the downloadable or preview file.
Raises:
Exception: If an error occurs while retrieving the URL, an exception is raised.
"""
try:
async with self.session.client(**self.auth_config) as client:
file = await client.get_object(Bucket=bucket, Key=object_name)
return str(file["Body"].url)
except Exception as e:
logger.error(f"Failed to get the url for a downloadable or preview file: {e}")
raise e
async def get_object(
self,
bucket: str,
object_name: str,
) -> bytes:
"""Get the binary data of a file stored in the specified S3 bucket.
Args:
bucket: The name of the S3 storage bucket.
object_name: The complete path of the file stored in S3, including the file name.
Returns:
The binary data of the requested file.
Raises:
Exception: If an error occurs while retrieving the file data, an exception is raised.
"""
try:
async with self.session.client(**self.auth_config) as client:
s3_object = await client.get_object(Bucket=bucket, Key=object_name)
return await s3_object["Body"].read()
except Exception as e:
logger.error(f"Failed to get the binary data of the file: {e}")
raise e
async def download_file(
self, bucket: str, object_name: str, local_path: str, chunk_size: Optional[int] = 128 * 1024
) -> None:
"""Download an S3 object to a local file.
Args:
bucket: The name of the S3 storage bucket.
object_name: The complete path of the file stored in S3, including the file name.
local_path: The local file path where the S3 object will be downloaded.
chunk_size: The size of data chunks to read and write at a time. Default is 128 KB.
Raises:
Exception: If an error occurs during the download process, an exception is raised.
"""
try:
async with self.session.client(**self.auth_config) as client:
s3_object = await client.get_object(Bucket=bucket, Key=object_name)
stream = s3_object["Body"]
async with aiofiles.open(local_path, mode="wb") as writer:
while True:
file_data = await stream.read(chunk_size)
if not file_data:
break
await writer.write(file_data)
except Exception as e:
logger.error(f"Failed to download the file from S3: {e}")
raise e
async def cache(self, data: str, file_ext: str, format: str = "") -> str:
"""Save data to remote S3 and return url"""
object_name = uuid.uuid4().hex + file_ext
path = Path(__file__).parent
pathname = path / object_name
try:
async with aiofiles.open(str(pathname), mode="wb") as file:
data = base64.b64decode(data) if format == BASE64_FORMAT else data.encode(encoding="utf-8")
await file.write(data)
bucket = CONFIG.S3_BUCKET
object_pathname = CONFIG.S3_BUCKET or "system"
object_pathname += f"/{object_name}"
object_pathname = os.path.normpath(object_pathname)
await self.upload_file(bucket=bucket, local_path=str(pathname), object_name=object_pathname)
pathname.unlink(missing_ok=True)
return await self.get_object_url(bucket=bucket, object_name=object_pathname)
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")
pathname.unlink(missing_ok=True)
return None
@property
def is_valid(self):
return self.is_configured
@property
def is_configured(self) -> bool:
return bool(
CONFIG.S3_ACCESS_KEY
and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
and CONFIG.S3_SECRET_KEY
and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
and CONFIG.S3_ENDPOINT_URL
and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
and CONFIG.S3_BUCKET
and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
)

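A minimal usage sketch (requires S3_* credentials in CONFIG; the bucket and key names are hypothetical):

import asyncio
from metagpt.utils.s3 import S3

async def demo():
    s3 = S3()
    if not s3.is_configured:
        return
    await s3.upload_file(bucket="my-bucket", local_path="report.txt", object_name="reports/report.txt")
    data = await s3.get_object(bucket="my-bucket", object_name="reports/report.txt")
    print(len(data), "bytes")

asyncio.run(demo())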

@@ -62,10 +62,10 @@ def serialize_message(message: "Message"):
ic = message_cp.instruct_content
if ic:
# models created by pydantic's create_model (e.g. `pydantic.main.prd`) can't be pickled directly
schema = ic.schema()
schema = ic.model_json_schema()
mapping = actionoutout_schema_to_mapping(schema)
message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
msg_ser = pickle.dumps(message_cp)
return msg_ser

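The serialize_message change is the pydantic v1-to-v2 rename: .schema() becomes .model_json_schema() and .dict() becomes .model_dump(). A standalone sketch with a dynamic model like the ones this code pickles ("prd" and its field are illustrative):

from pydantic import create_model

Prd = create_model("prd", requirements=(str, ...))
ic = Prd(requirements="snake game")
print(ic.model_json_schema()["title"])  # 'prd' - replaces ic.schema()
print(ic.model_dump())                  # {'requirements': 'snake game'} - replaces ic.dict()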

@@ -84,6 +84,13 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
elif "gpt-4" == model:
print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return count_message_tokens(messages, model="gpt-4-0613")
elif "open-llm-model" == model:
"""
For self-hosted open LLM APIs, which serve many different models, the message-token
calculation is only approximate; treat the result as a reference value.
"""
tokens_per_message = 0 # ignore conversation message template prefix
tokens_per_name = 0
else:
raise NotImplementedError(
f"num_tokens_from_messages() is not implemented for model {model}. "
@@ -112,7 +119,11 @@ def count_string_tokens(string: str, model_name: str) -> int:
Returns:
int: The number of tokens in the text string.
"""
encoding = tiktoken.encoding_for_model(model_name)
try:
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(string))
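The new try/except follows tiktoken's behavior: encoding_for_model raises KeyError for unknown model names, and cl100k_base serves as a generic fallback. A standalone sketch:

import tiktoken

try:
    enc = tiktoken.encoding_for_model("open-llm-model")  # unknown to tiktoken -> KeyError
except KeyError:
    enc = tiktoken.get_encoding("cl100k_base")  # generic fallback, as in the patch
print(len(enc.encode("hello world")))  # token count under the fallback encoding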