feat: +unit test

fixbug: PYTHONPATH

fixbug: unit test
This commit is contained in:
莘权 马 2023-12-27 11:24:22 +08:00
parent 641c71bf18
commit 0adabfe53f
33 changed files with 1561 additions and 297 deletions

View file

@@ -51,7 +51,7 @@ def check_cmd_exists(command) -> int:
def require_python_version(req_version: Tuple) -> bool:
if not (2 <= len(req_version) <= 3):
raise ValueError("req_version should be (3, 9) or (3, 10, 13)")
return True if sys.version_info > req_version else False
return bool(sys.version_info > req_version)
class OutputParser:
@@ -454,7 +454,7 @@ def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.C
return log_it
def read_json_file(json_file: str, encoding=None) -> list[Any]:
def read_json_file(json_file: str, encoding="utf-8") -> list[Any]:
if not Path(json_file).exists():
raise FileNotFoundError(f"json_file: {json_file} not exist, return []")

View file

@@ -14,7 +14,6 @@ from typing import Set
import aiofiles
from metagpt.config import CONFIG
from metagpt.utils.common import aread
from metagpt.utils.exceptions import handle_exception
@@ -86,7 +85,7 @@ class DependencyFile:
if persist:
await self.load()
root = CONFIG.git_repo.workdir
root = self._filename.parent
try:
key = Path(filename).relative_to(root)
except ValueError:

View file

@@ -81,10 +81,11 @@ class FileRepository:
:return: List of changed dependency filenames or paths.
"""
dependencies = await self.get_dependency(filename=filename)
changed_files = self.changed_files
changed_files = set(self.changed_files.keys())
changed_dependent_files = set()
for df in dependencies:
if df in changed_files.keys():
rdf = Path(df).relative_to(self._relative_path)
if str(rdf) in changed_files:
changed_dependent_files.add(df)
return changed_dependent_files

View file

@@ -17,7 +17,6 @@ from git.repo import Repo
from git.repo.fun import is_git_dir
from gitignore_parser import parse_gitignore
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.utils.dependency_file import DependencyFile
from metagpt.utils.file_repository import FileRepository
@@ -271,20 +270,3 @@ class GitRepository:
continue
files.append(filename)
return files
if __name__ == "__main__":
path = DEFAULT_WORKSPACE_ROOT / "git"
path.mkdir(exist_ok=True, parents=True)
repo = GitRepository()
repo.open(path, auto_init=True)
repo.filter_gitignore(filenames=["snake_game/snake_game/__pycache__", "snake_game/snake_game/game.py"])
changes = repo.changed_files
print(changes)
repo.add_change(changes)
print(repo.status)
repo.commit("test")
print(repo.status)
repo.delete_repository()

View file

@@ -13,7 +13,6 @@ from pathlib import Path
import aiofiles
from metagpt.config import CONFIG
from metagpt.const import METAGPT_ROOT
from metagpt.logs import logger
from metagpt.utils.common import check_cmd_exists
@@ -146,9 +145,3 @@ sequenceDiagram
S-->>SE: return summary
SE-->>M: return summary
"""
if __name__ == "__main__":
loop = asyncio.new_event_loop()
result = loop.run_until_complete(mermaid_to_file(MMC1, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/1"))
result = loop.run_until_complete(mermaid_to_file(MMC2, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/2"))
loop.close()

View file

@@ -1,219 +1,67 @@
# !/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { redis client }
# @Date: 2022/11/28 10:12
import json
"""
@Time : 2023/12/27
@Author : mashenquan
@File : redis.py
"""
import traceback
from datetime import timedelta
from enum import Enum
from typing import Awaitable, Callable, Dict, Optional, Union
from redis import asyncio as aioredis
import aioredis # https://aioredis.readthedocs.io/en/latest/getting-started/
from metagpt.config import CONFIG
from metagpt.logs import logger
class RedisTypeEnum(Enum):
"""Redis 数据类型"""
String = "String"
List = "List"
Hash = "Hash"
Set = "Set"
ZSet = "ZSet"
def make_url(
dialect: str,
*,
user: Optional[str] = None,
password: Optional[str] = None,
host: Optional[str] = None,
port: Optional[Union[str, int]] = None,
name: Optional[Union[str, int]] = None,
) -> str:
url_parts = [f"{dialect}://"]
if user or password:
if user:
url_parts.append(user)
if password:
url_parts.append(f":{password}")
url_parts.append("@")
if not host and not dialect.startswith("sqlite"):
host = "127.0.0.1"
if host:
url_parts.append(f"{host}")
if port:
url_parts.append(f":{port}")
# 比如redis可能传入0
if name is not None:
url_parts.append(f"/{name}")
return "".join(url_parts)
class RedisAsyncClient(aioredis.Redis):
"""异步的客户端
例子::
rdb = RedisAsyncClient()
print(rdb.url)
Args:
host: 服务器地址
port: 服务器端口
user: 用户名
db: 数据库
password: 密码
decode_responses: 字符串输入被编码成utf8存储在Redis里了而取出来的时候还是被编码后的bytes需要显示的decode才能变成字符串
health_check_interval: 定时检测连接防止出现ConnectionErrors (104, Connection reset by peer)
"""
def __init__(
self,
host: str = "localhost",
port: int = 6379,
db: int = 0,
password: str = None,
decode_responses=True,
health_check_interval=10,
socket_connect_timeout=5,
retry_on_timeout=True,
socket_keepalive=True,
**kwargs,
):
super().__init__(
host=host,
port=port,
db=db,
password=password,
decode_responses=decode_responses,
health_check_interval=health_check_interval,
socket_connect_timeout=socket_connect_timeout,
retry_on_timeout=retry_on_timeout,
socket_keepalive=socket_keepalive,
**kwargs,
)
self.url = make_url("redis", host=host, port=port, name=db, password=password)
class RedisCacheInfo(object):
"""统一缓存信息类"""
def __init__(self, key, timeout: Union[int, timedelta] = timedelta(seconds=60), data_type=RedisTypeEnum.String):
"""
缓存信息类初始化
Args:
key: 缓存的key
timeout: 缓存过期时间, 单位秒
data_type: 缓存采用的数据结构 (不传并不影响用于标记业务采用的是什么数据结构)
"""
self.key = key
self.timeout = timeout
self.data_type = data_type
def __str__(self):
return f"cache key {self.key} timeout {self.timeout}s"
class RedisManager:
client: RedisAsyncClient = None
@classmethod
def init_redis_conn(cls, host, port, password, db):
"""初始化redis 连接"""
if cls.client is None:
cls.client = RedisAsyncClient(host=host, port=port, password=password, db=db)
@classmethod
async def set_with_cache_info(cls, redis_cache_info: RedisCacheInfo, value):
"""
根据 RedisCacheInfo 设置 Redis 缓存
:param redis_cache_info: RedisCacheInfo缓存信息对象
:param value: 缓存的值
:return:
"""
await cls.client.setex(redis_cache_info.key, redis_cache_info.timeout, value)
@classmethod
async def get_with_cache_info(cls, redis_cache_info: RedisCacheInfo):
"""
根据 RedisCacheInfo 获取 Redis 缓存
:param redis_cache_info: RedisCacheInfo 缓存信息对象
:return:
"""
cache_info = await cls.client.get(redis_cache_info.key)
return cache_info
@classmethod
async def del_with_cache_info(cls, redis_cache_info: RedisCacheInfo):
"""
根据 RedisCacheInfo 删除 Redis 缓存
:param redis_cache_info: RedisCacheInfo缓存信息对象
:return:
"""
await cls.client.delete(redis_cache_info.key)
@staticmethod
async def get_or_set_cache(cache_info: RedisCacheInfo, fetch_data_func: Callable[[], Awaitable[dict]]) -> dict:
"""
获取缓存数据如果缓存不存在则从提供的函数中获取并设置缓存
当前版本仅支持 json 形式的 string 格式数据
"""
serialized_data = await RedisManager.get_with_cache_info(cache_info)
if serialized_data:
return json.loads(serialized_data)
data = await fetch_data_func()
try:
serialized_data = json.dumps(data)
await RedisManager.set_with_cache_info(cache_info, serialized_data)
except Exception as e:
logger.warning(f"数据 {data} 通过 json 进行序列化缓存失败:{e}")
return data
@classmethod
def is_valid(cls):
return cls.client is not None
class Redis:
def __init__(self, conf: Dict = None):
def __init__(self):
self._client = None
async def _connect(self, force=False):
if self._client and not force:
return True
if not CONFIG.REDIS_HOST or not CONFIG.REDIS_PORT or CONFIG.REDIS_DB is None or CONFIG.REDIS_PASSWORD is None:
return False
try:
host = CONFIG.REDIS_HOST
port = int(CONFIG.REDIS_PORT)
pwd = CONFIG.REDIS_PASSWORD
db = CONFIG.REDIS_DB
RedisManager.init_redis_conn(host=host, port=port, password=pwd, db=db)
self._client = await aioredis.from_url(
f"redis://{CONFIG.REDIS_HOST}:{CONFIG.REDIS_PORT}",
username=CONFIG.REDIS_USER,
password=CONFIG.REDIS_PASSWORD,
db=CONFIG.REDIS_DB,
)
return True
except Exception as e:
logger.warning(f"Redis initialization has failed:{e}")
return False
def is_valid(self):
return RedisManager.is_valid()
async def get(self, key: str) -> str:
if not self.is_valid() or not key:
async def get(self, key: str) -> bytes:
if not await self._connect() or not key:
return None
try:
v = await RedisManager.get_with_cache_info(redis_cache_info=RedisCacheInfo(key=key))
v = await self._client.get(key)
return v
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")
return None
async def set(self, key: str, data: str, timeout_sec: int):
if not self.is_valid() or not key:
async def set(self, key: str, data: str, timeout_sec: int = None):
if not await self._connect() or not key:
return
try:
await RedisManager.set_with_cache_info(
redis_cache_info=RedisCacheInfo(key=key, timeout=timeout_sec), value=data
)
ex = None if not timeout_sec else timedelta(seconds=timeout_sec)
await self._client.set(key, data, ex=ex)
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")
async def close(self):
if not self._client:
return
await self._client.close()
self._client = None
@property
def is_valid(self):
return bool(self._client)

View file

@@ -136,8 +136,7 @@ class S3:
pathname = path / object_name
try:
async with aiofiles.open(str(pathname), mode="wb") as file:
if format == BASE64_FORMAT:
data = base64.b64decode(data)
data = base64.b64decode(data) if format == BASE64_FORMAT else data.encode(encoding="utf-8")
await file.write(data)
bucket = CONFIG.S3_BUCKET