mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-23 15:48:11 +02:00
feat: +unit test
fixbug: PYTHONPATH fixbug: unit test
This commit is contained in:
parent
641c71bf18
commit
0adabfe53f
33 changed files with 1561 additions and 297 deletions
|
|
@ -36,7 +36,8 @@ class PrepareDocuments(Action):
|
|||
if not path:
|
||||
name = CONFIG.project_name or FileRepository.new_filename()
|
||||
path = Path(CONFIG.workspace_path) / name
|
||||
|
||||
else:
|
||||
path = Path(CONFIG.project_path)
|
||||
if path.exists() and not CONFIG.inc:
|
||||
shutil.rmtree(path)
|
||||
CONFIG.git_repo = GitRepository(local_path=path, auto_init=True)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from metagpt.actions.write_prd_an import (
|
|||
WP_IS_RELATIVE_NODE,
|
||||
WP_ISSUE_TYPE_NODE,
|
||||
WRITE_PRD_NODE,
|
||||
WRITE_PRD_NODE_NO_NAME,
|
||||
)
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
|
|
@ -123,7 +124,8 @@ class WritePRD(Action):
|
|||
# logger.info(rsp)
|
||||
project_name = CONFIG.project_name if CONFIG.project_name else ""
|
||||
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
|
||||
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm) # schema=schema
|
||||
write_prd_node = WRITE_PRD_NODE if not project_name else WRITE_PRD_NODE_NO_NAME
|
||||
node = await write_prd_node.fill(context=context, llm=self.llm) # schema=schema
|
||||
await self._rename_workspace(node)
|
||||
return node
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ ORIGINAL_REQUIREMENTS = ActionNode(
|
|||
PROJECT_NAME = ActionNode(
|
||||
key="Project Name",
|
||||
expected_type=str,
|
||||
instruction="Name the project using snake case style, like 'game_2048' or 'simple_crm'.",
|
||||
instruction="According to the content of \"Original Requirements,\" name the project using snake case style , like 'game_2048' or 'simple_crm.",
|
||||
example="game_2048",
|
||||
)
|
||||
|
||||
|
|
@ -141,7 +141,6 @@ NODES = [
|
|||
LANGUAGE,
|
||||
PROGRAMMING_LANGUAGE,
|
||||
ORIGINAL_REQUIREMENTS,
|
||||
PROJECT_NAME,
|
||||
PRODUCT_GOALS,
|
||||
USER_STORIES,
|
||||
COMPETITIVE_ANALYSIS,
|
||||
|
|
@ -152,7 +151,8 @@ NODES = [
|
|||
ANYTHING_UNCLEAR,
|
||||
]
|
||||
|
||||
WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES)
|
||||
WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES + [PROJECT_NAME])
|
||||
WRITE_PRD_NODE_NO_NAME = ActionNode.from_children("WritePRD", NODES)
|
||||
WP_ISSUE_TYPE_NODE = ActionNode.from_children("WP_ISSUE_TYPE", [ISSUE_TYPE, REASON])
|
||||
WP_IS_RELATIVE_NODE = ActionNode.from_children("WP_IS_RELATIVE", [IS_RELATIVE, REASON])
|
||||
|
||||
|
|
|
|||
|
|
@ -8,5 +8,5 @@ async def google_search(query: str, max_results: int = 6, **kwargs):
|
|||
:param max_results: The number of search results to retrieve
|
||||
:return: The web search results in markdown format.
|
||||
"""
|
||||
resluts = await SearchEngine().run(query, max_results=max_results, as_string=False)
|
||||
return "\n".join(f"{i}. [{j['title']}]({j['link']}): {j['snippet']}" for i, j in enumerate(resluts, 1))
|
||||
results = await SearchEngine().run(query, max_results=max_results, as_string=False)
|
||||
return "\n".join(f"{i}. [{j['title']}]({j['link']}): {j['snippet']}" for i, j in enumerate(results, 1))
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.provider.openai_api import OpenAILLM as GPTAPI
|
||||
|
||||
ICL_SAMPLE = """Interface definition:
|
||||
|
|
@ -174,6 +176,9 @@ class UTGenerator:
|
|||
return doc
|
||||
|
||||
for name, prop in node.items():
|
||||
if not isinstance(prop, dict):
|
||||
doc += f'{" " * level}{self._para_to_str(node)}\n'
|
||||
break
|
||||
doc += f'{" " * level}{self.para_to_str(name, prop, prop_object_required)}\n'
|
||||
doc += dive_into_object(prop)
|
||||
if prop["type"] == "array":
|
||||
|
|
@ -202,12 +207,12 @@ class UTGenerator:
|
|||
|
||||
return tags
|
||||
|
||||
def generate_ut(self, include_tags) -> bool:
|
||||
async def generate_ut(self, include_tags) -> bool:
|
||||
"""Generate test case files"""
|
||||
tags = self.get_tags_mapping()
|
||||
for tag, paths in tags.items():
|
||||
if include_tags is None or tag in include_tags:
|
||||
self._generate_ut(tag, paths)
|
||||
await self._generate_ut(tag, paths)
|
||||
return True
|
||||
|
||||
def build_api_doc(self, node: dict, path: str, method: str) -> str:
|
||||
|
|
@ -250,21 +255,22 @@ class UTGenerator:
|
|||
|
||||
return doc
|
||||
|
||||
def _store(self, data, base, folder, fname):
|
||||
async def _store(self, data, base, folder, fname):
|
||||
"""Store data in a file."""
|
||||
file_path = self.get_file_path(Path(base) / folder, fname)
|
||||
with open(file_path, "w", encoding="utf-8") as file:
|
||||
file.write(data)
|
||||
async with aiofiles.open(file_path, mode="w", encoding="utf-8") as file:
|
||||
await file.write(data)
|
||||
|
||||
def ask_gpt_and_save(self, question: str, tag: str, fname: str):
|
||||
async def ask_gpt_and_save(self, question: str, tag: str, fname: str):
|
||||
"""Generate questions and store both questions and answers"""
|
||||
messages = [self.icl_sample, question]
|
||||
result = self.gpt_msgs_to_code(messages=messages)
|
||||
result = await self.gpt_msgs_to_code(messages=messages)
|
||||
|
||||
self._store(question, self.questions_path, tag, f"{fname}.txt")
|
||||
self._store(result, self.ut_py_path, tag, f"{fname}.py")
|
||||
await self._store(question, self.questions_path, tag, f"{fname}.txt")
|
||||
data = result.get("code", "") if result else ""
|
||||
await self._store(data, self.ut_py_path, tag, f"{fname}.py")
|
||||
|
||||
def _generate_ut(self, tag, paths):
|
||||
async def _generate_ut(self, tag, paths):
|
||||
"""Process the structure under a data path
|
||||
|
||||
Args:
|
||||
|
|
@ -276,13 +282,13 @@ class UTGenerator:
|
|||
summary = node["summary"]
|
||||
question = self.template_prefix
|
||||
question += self.build_api_doc(node, path, method)
|
||||
self.ask_gpt_and_save(question, tag, summary)
|
||||
await self.ask_gpt_and_save(question, tag, summary)
|
||||
|
||||
async def gpt_msgs_to_code(self, messages: list) -> str:
|
||||
"""Choose based on different calling methods"""
|
||||
result = ""
|
||||
if self.chatgpt_method == "API":
|
||||
result = await GPTAPI().aask_code(msgs=messages)
|
||||
result = await GPTAPI().aask_code(messages=messages)
|
||||
|
||||
return result
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from typing import Any, Callable, Coroutine, Literal, overload
|
||||
from typing import Any, Callable, Coroutine, overload
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import WebBrowserEngineType
|
||||
|
|
@ -46,12 +46,3 @@ class WebBrowserEngine:
|
|||
|
||||
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
|
||||
return await self.run_func(url, *urls)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, engine_type: Literal["playwright", "selenium"] = "playwright", **kwargs):
|
||||
return await WebBrowserEngine(engine=WebBrowserEngineType(engine_type), **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
|
|
@ -142,12 +142,3 @@ async def _log_stream(sr, log_func):
|
|||
|
||||
_install_lock: asyncio.Lock = None
|
||||
_install_cache = set()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, browser_type: str = "chromium", **kwargs):
|
||||
return await PlaywrightWrapper(browser_type=browser_type, **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
|
|
@ -118,12 +118,3 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
|||
return WebDriver(options=deepcopy(options), service=Service(executable_path=executable_path))
|
||||
|
||||
return _get_driver
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, browser_type: str = "chrome", **kwargs):
|
||||
return await SeleniumWrapper(browser_type=browser_type, **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ def check_cmd_exists(command) -> int:
|
|||
def require_python_version(req_version: Tuple) -> bool:
|
||||
if not (2 <= len(req_version) <= 3):
|
||||
raise ValueError("req_version should be (3, 9) or (3, 10, 13)")
|
||||
return True if sys.version_info > req_version else False
|
||||
return bool(sys.version_info > req_version)
|
||||
|
||||
|
||||
class OutputParser:
|
||||
|
|
@ -454,7 +454,7 @@ def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.C
|
|||
return log_it
|
||||
|
||||
|
||||
def read_json_file(json_file: str, encoding=None) -> list[Any]:
|
||||
def read_json_file(json_file: str, encoding="utf-8") -> list[Any]:
|
||||
if not Path(json_file).exists():
|
||||
raise FileNotFoundError(f"json_file: {json_file} not exist, return []")
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from typing import Set
|
|||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
|
@ -86,7 +85,7 @@ class DependencyFile:
|
|||
if persist:
|
||||
await self.load()
|
||||
|
||||
root = CONFIG.git_repo.workdir
|
||||
root = self._filename.parent
|
||||
try:
|
||||
key = Path(filename).relative_to(root)
|
||||
except ValueError:
|
||||
|
|
|
|||
|
|
@ -81,10 +81,11 @@ class FileRepository:
|
|||
:return: List of changed dependency filenames or paths.
|
||||
"""
|
||||
dependencies = await self.get_dependency(filename=filename)
|
||||
changed_files = self.changed_files
|
||||
changed_files = set(self.changed_files.keys())
|
||||
changed_dependent_files = set()
|
||||
for df in dependencies:
|
||||
if df in changed_files.keys():
|
||||
rdf = Path(df).relative_to(self._relative_path)
|
||||
if str(rdf) in changed_files:
|
||||
changed_dependent_files.add(df)
|
||||
return changed_dependent_files
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from git.repo import Repo
|
|||
from git.repo.fun import is_git_dir
|
||||
from gitignore_parser import parse_gitignore
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.dependency_file import DependencyFile
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
|
@ -271,20 +270,3 @@ class GitRepository:
|
|||
continue
|
||||
files.append(filename)
|
||||
return files
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
path = DEFAULT_WORKSPACE_ROOT / "git"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
repo = GitRepository()
|
||||
repo.open(path, auto_init=True)
|
||||
repo.filter_gitignore(filenames=["snake_game/snake_game/__pycache__", "snake_game/snake_game/game.py"])
|
||||
|
||||
changes = repo.changed_files
|
||||
print(changes)
|
||||
repo.add_change(changes)
|
||||
print(repo.status)
|
||||
repo.commit("test")
|
||||
print(repo.status)
|
||||
repo.delete_repository()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from pathlib import Path
|
|||
import aiofiles
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import check_cmd_exists
|
||||
|
||||
|
|
@ -146,9 +145,3 @@ sequenceDiagram
|
|||
S-->>SE: return summary
|
||||
SE-->>M: return summary
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.new_event_loop()
|
||||
result = loop.run_until_complete(mermaid_to_file(MMC1, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/1"))
|
||||
result = loop.run_until_complete(mermaid_to_file(MMC2, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/2"))
|
||||
loop.close()
|
||||
|
|
|
|||
|
|
@ -1,219 +1,67 @@
|
|||
# !/usr/bin/python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Author: Hui
|
||||
# @Desc: { redis client }
|
||||
# @Date: 2022/11/28 10:12
|
||||
import json
|
||||
"""
|
||||
@Time : 2023/12/27
|
||||
@Author : mashenquan
|
||||
@File : redis.py
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, Dict, Optional, Union
|
||||
|
||||
from redis import asyncio as aioredis
|
||||
import aioredis # https://aioredis.readthedocs.io/en/latest/getting-started/
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class RedisTypeEnum(Enum):
|
||||
"""Redis 数据类型"""
|
||||
|
||||
String = "String"
|
||||
List = "List"
|
||||
Hash = "Hash"
|
||||
Set = "Set"
|
||||
ZSet = "ZSet"
|
||||
|
||||
|
||||
def make_url(
|
||||
dialect: str,
|
||||
*,
|
||||
user: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[Union[str, int]] = None,
|
||||
name: Optional[Union[str, int]] = None,
|
||||
) -> str:
|
||||
url_parts = [f"{dialect}://"]
|
||||
if user or password:
|
||||
if user:
|
||||
url_parts.append(user)
|
||||
if password:
|
||||
url_parts.append(f":{password}")
|
||||
url_parts.append("@")
|
||||
|
||||
if not host and not dialect.startswith("sqlite"):
|
||||
host = "127.0.0.1"
|
||||
|
||||
if host:
|
||||
url_parts.append(f"{host}")
|
||||
if port:
|
||||
url_parts.append(f":{port}")
|
||||
|
||||
# 比如redis可能传入0
|
||||
if name is not None:
|
||||
url_parts.append(f"/{name}")
|
||||
return "".join(url_parts)
|
||||
|
||||
|
||||
class RedisAsyncClient(aioredis.Redis):
|
||||
"""异步的客户端
|
||||
例子::
|
||||
|
||||
rdb = RedisAsyncClient()
|
||||
print(rdb.url)
|
||||
|
||||
Args:
|
||||
host: 服务器地址
|
||||
port: 服务器端口
|
||||
user: 用户名
|
||||
db: 数据库
|
||||
password: 密码
|
||||
decode_responses: 字符串输入被编码成utf8存储在Redis里了,而取出来的时候还是被编码后的bytes,需要显示的decode才能变成字符串
|
||||
health_check_interval: 定时检测连接,防止出现ConnectionErrors (104, Connection reset by peer)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "localhost",
|
||||
port: int = 6379,
|
||||
db: int = 0,
|
||||
password: str = None,
|
||||
decode_responses=True,
|
||||
health_check_interval=10,
|
||||
socket_connect_timeout=5,
|
||||
retry_on_timeout=True,
|
||||
socket_keepalive=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=decode_responses,
|
||||
health_check_interval=health_check_interval,
|
||||
socket_connect_timeout=socket_connect_timeout,
|
||||
retry_on_timeout=retry_on_timeout,
|
||||
socket_keepalive=socket_keepalive,
|
||||
**kwargs,
|
||||
)
|
||||
self.url = make_url("redis", host=host, port=port, name=db, password=password)
|
||||
|
||||
|
||||
class RedisCacheInfo(object):
|
||||
"""统一缓存信息类"""
|
||||
|
||||
def __init__(self, key, timeout: Union[int, timedelta] = timedelta(seconds=60), data_type=RedisTypeEnum.String):
|
||||
"""
|
||||
缓存信息类初始化
|
||||
Args:
|
||||
key: 缓存的key
|
||||
timeout: 缓存过期时间, 单位秒
|
||||
data_type: 缓存采用的数据结构 (不传并不影响,用于标记业务采用的是什么数据结构)
|
||||
"""
|
||||
self.key = key
|
||||
self.timeout = timeout
|
||||
self.data_type = data_type
|
||||
|
||||
def __str__(self):
|
||||
return f"cache key {self.key} timeout {self.timeout}s"
|
||||
|
||||
|
||||
class RedisManager:
|
||||
client: RedisAsyncClient = None
|
||||
|
||||
@classmethod
|
||||
def init_redis_conn(cls, host, port, password, db):
|
||||
"""初始化redis 连接"""
|
||||
if cls.client is None:
|
||||
cls.client = RedisAsyncClient(host=host, port=port, password=password, db=db)
|
||||
|
||||
@classmethod
|
||||
async def set_with_cache_info(cls, redis_cache_info: RedisCacheInfo, value):
|
||||
"""
|
||||
根据 RedisCacheInfo 设置 Redis 缓存
|
||||
:param redis_cache_info: RedisCacheInfo缓存信息对象
|
||||
:param value: 缓存的值
|
||||
:return:
|
||||
"""
|
||||
await cls.client.setex(redis_cache_info.key, redis_cache_info.timeout, value)
|
||||
|
||||
@classmethod
|
||||
async def get_with_cache_info(cls, redis_cache_info: RedisCacheInfo):
|
||||
"""
|
||||
根据 RedisCacheInfo 获取 Redis 缓存
|
||||
:param redis_cache_info: RedisCacheInfo 缓存信息对象
|
||||
:return:
|
||||
"""
|
||||
cache_info = await cls.client.get(redis_cache_info.key)
|
||||
return cache_info
|
||||
|
||||
@classmethod
|
||||
async def del_with_cache_info(cls, redis_cache_info: RedisCacheInfo):
|
||||
"""
|
||||
根据 RedisCacheInfo 删除 Redis 缓存
|
||||
:param redis_cache_info: RedisCacheInfo缓存信息对象
|
||||
:return:
|
||||
"""
|
||||
await cls.client.delete(redis_cache_info.key)
|
||||
|
||||
@staticmethod
|
||||
async def get_or_set_cache(cache_info: RedisCacheInfo, fetch_data_func: Callable[[], Awaitable[dict]]) -> dict:
|
||||
"""
|
||||
获取缓存数据,如果缓存不存在,则从提供的函数中获取并设置缓存
|
||||
当前版本仅支持 json 形式的 string 格式数据
|
||||
"""
|
||||
|
||||
serialized_data = await RedisManager.get_with_cache_info(cache_info)
|
||||
|
||||
if serialized_data:
|
||||
return json.loads(serialized_data)
|
||||
|
||||
data = await fetch_data_func()
|
||||
try:
|
||||
serialized_data = json.dumps(data)
|
||||
await RedisManager.set_with_cache_info(cache_info, serialized_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"数据 {data} 通过 json 进行序列化缓存失败:{e}")
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def is_valid(cls):
|
||||
return cls.client is not None
|
||||
|
||||
|
||||
class Redis:
|
||||
def __init__(self, conf: Dict = None):
|
||||
def __init__(self):
|
||||
self._client = None
|
||||
|
||||
async def _connect(self, force=False):
|
||||
if self._client and not force:
|
||||
return True
|
||||
if not CONFIG.REDIS_HOST or not CONFIG.REDIS_PORT or CONFIG.REDIS_DB is None or CONFIG.REDIS_PASSWORD is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
host = CONFIG.REDIS_HOST
|
||||
port = int(CONFIG.REDIS_PORT)
|
||||
pwd = CONFIG.REDIS_PASSWORD
|
||||
db = CONFIG.REDIS_DB
|
||||
RedisManager.init_redis_conn(host=host, port=port, password=pwd, db=db)
|
||||
self._client = await aioredis.from_url(
|
||||
f"redis://{CONFIG.REDIS_HOST}:{CONFIG.REDIS_PORT}",
|
||||
username=CONFIG.REDIS_USER,
|
||||
password=CONFIG.REDIS_PASSWORD,
|
||||
db=CONFIG.REDIS_DB,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis initialization has failed:{e}")
|
||||
return False
|
||||
|
||||
def is_valid(self):
|
||||
return RedisManager.is_valid()
|
||||
|
||||
async def get(self, key: str) -> str:
|
||||
if not self.is_valid() or not key:
|
||||
async def get(self, key: str) -> bytes:
|
||||
if not await self._connect() or not key:
|
||||
return None
|
||||
try:
|
||||
v = await RedisManager.get_with_cache_info(redis_cache_info=RedisCacheInfo(key=key))
|
||||
v = await self._client.get(key)
|
||||
return v
|
||||
except Exception as e:
|
||||
logger.exception(f"{e}, stack:{traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
async def set(self, key: str, data: str, timeout_sec: int):
|
||||
if not self.is_valid() or not key:
|
||||
async def set(self, key: str, data: str, timeout_sec: int = None):
|
||||
if not await self._connect() or not key:
|
||||
return
|
||||
try:
|
||||
await RedisManager.set_with_cache_info(
|
||||
redis_cache_info=RedisCacheInfo(key=key, timeout=timeout_sec), value=data
|
||||
)
|
||||
ex = None if not timeout_sec else timedelta(seconds=timeout_sec)
|
||||
await self._client.set(key, data, ex=ex)
|
||||
except Exception as e:
|
||||
logger.exception(f"{e}, stack:{traceback.format_exc()}")
|
||||
|
||||
async def close(self):
|
||||
if not self._client:
|
||||
return
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def is_valid(self):
|
||||
return bool(self._client)
|
||||
|
|
|
|||
|
|
@ -136,8 +136,7 @@ class S3:
|
|||
pathname = path / object_name
|
||||
try:
|
||||
async with aiofiles.open(str(pathname), mode="wb") as file:
|
||||
if format == BASE64_FORMAT:
|
||||
data = base64.b64decode(data)
|
||||
data = base64.b64decode(data) if format == BASE64_FORMAT else data.encode(encoding="utf-8")
|
||||
await file.write(data)
|
||||
|
||||
bucket = CONFIG.S3_BUCKET
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue