diff --git a/.gitignore b/.gitignore index 922116d12..aa5edd74a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ ### Python template # Byte-compiled / optimized / DLL files -__pycache__/ +__pycache__ *.py[cod] *$py.class diff --git a/examples/di/arxiv_reader.py b/examples/di/arxiv_reader.py index fab0e2e48..6e1939b81 100644 --- a/examples/di/arxiv_reader.py +++ b/examples/di/arxiv_reader.py @@ -1,10 +1,5 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -""" -@Time : 2024/01/15 -@Author : mannaandpoem -@File : imitate_webpage.py -""" from metagpt.roles.di.data_interpreter import DataInterpreter diff --git a/examples/search_with_specific_engine.py b/examples/search_with_specific_engine.py index 1eee762d5..276431ed8 100644 --- a/examples/search_with_specific_engine.py +++ b/examples/search_with_specific_engine.py @@ -13,7 +13,7 @@ async def main(): question = "What are the most interesting human facts?" search = Config.default().search - kwargs = {"api_key": search.api_key, "cse_id": search.cse_id, "proxy": None} + kwargs = search.model_dump() await Searcher(search_engine=SearchEngine(engine=search.api_type, **kwargs)).run(question) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 3f822568e..27dde5a8c 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -331,7 +331,7 @@ class ActionNode: def compile_to(self, i: Dict, schema, kv_sep) -> str: if schema == "json": - return json.dumps(i, indent=4) + return json.dumps(i, indent=4, ensure_ascii=False) elif schema == "markdown": return dict_to_markdown(i, kv_sep=kv_sep) else: diff --git a/metagpt/actions/di/write_analysis_code.py b/metagpt/actions/di/write_analysis_code.py index 185926e31..711e56d39 100644 --- a/metagpt/actions/di/write_analysis_code.py +++ b/metagpt/actions/di/write_analysis_code.py @@ -18,7 +18,7 @@ from metagpt.prompts.di.write_analysis_code import ( STRUCTUAL_PROMPT, ) from metagpt.schema import Message, Plan -from metagpt.utils.common import CodeParser, process_message, remove_comments +from metagpt.utils.common import CodeParser, remove_comments class WriteAnalysisCode(Action): @@ -50,7 +50,7 @@ class WriteAnalysisCode(Action): ) working_memory = working_memory or [] - context = process_message([Message(content=structual_prompt, role="user")] + working_memory) + context = self.llm.format_msg([Message(content=structual_prompt, role="user")] + working_memory) # LLM call if use_reflection: diff --git a/metagpt/configs/search_config.py b/metagpt/configs/search_config.py index af928b02a..e28b14c99 100644 --- a/metagpt/configs/search_config.py +++ b/metagpt/configs/search_config.py @@ -7,6 +7,8 @@ """ from typing import Callable, Optional +from pydantic import Field + from metagpt.tools import SearchEngineType from metagpt.utils.yaml_model import YamlModel @@ -18,3 +20,11 @@ class SearchConfig(YamlModel): api_key: str = "" cse_id: str = "" # for google search_func: Optional[Callable] = None + params: dict = Field( + default_factory=lambda: { + "engine": "google", + "google_domain": "google.com", + "gl": "us", + "hl": "en", + } + ) diff --git a/metagpt/learn/skill_loader.py b/metagpt/learn/skill_loader.py index bcf28bb87..e98f73cf9 100644 --- a/metagpt/learn/skill_loader.py +++ b/metagpt/learn/skill_loader.py @@ -9,11 +9,11 @@ from pathlib import Path from typing import Dict, List, Optional -import aiofiles import yaml from pydantic import BaseModel, Field from metagpt.context import Context +from metagpt.utils.common import aread class Example(BaseModel): @@ -68,8 +68,7 @@ class SkillsDeclaration(BaseModel): async def load(skill_yaml_file_name: Path = None) -> "SkillsDeclaration": if not skill_yaml_file_name: skill_yaml_file_name = Path(__file__).parent.parent.parent / "docs/.well-known/skills.yaml" - async with aiofiles.open(str(skill_yaml_file_name), mode="r") as reader: - data = await reader.read(-1) + data = await aread(filename=skill_yaml_file_name) skill_data = yaml.safe_load(data) return SkillsDeclaration(**skill_data) diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index e085d0187..db2757ec3 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -74,6 +74,28 @@ class BaseLLM(ABC): def _system_msg(self, msg: str) -> dict[str, str]: return {"role": "system", "content": msg} + def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: + """convert messages to list[dict].""" + from metagpt.schema import Message + + if not isinstance(messages, list): + messages = [messages] + + processed_messages = [] + for msg in messages: + if isinstance(msg, str): + processed_messages.append({"role": "user", "content": msg}) + elif isinstance(msg, dict): + assert set(msg.keys()) == set(["role", "content"]) + processed_messages.append(msg) + elif isinstance(msg, Message): + processed_messages.append(msg.to_dict()) + else: + raise ValueError( + f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!" + ) + return processed_messages + def _system_msgs(self, msgs: list[str]) -> list[dict[str, str]]: return [self._system_msg(msg) for msg in msgs] diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index e041f4c87..7a35f0a9d 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- # @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart +import os from typing import Optional, Union import google.generativeai as genai @@ -16,9 +17,10 @@ from google.generativeai.types.generation_types import ( from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.const import USE_CONFIG_TIMEOUT -from metagpt.logs import log_llm_stream +from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider +from metagpt.schema import Message class GeminiGenerativeModel(GenerativeModel): @@ -52,6 +54,10 @@ class GeminiLLM(BaseLLM): self.llm = GeminiGenerativeModel(model_name=self.model) def __init_gemini(self, config: LLMConfig): + if config.proxy: + logger.info(f"Use proxy: {config.proxy}") + os.environ["HTTP_PROXY"] = config.proxy + os.environ["HTTP_PROXYS"] = config.proxy genai.configure(api_key=config.api_key) def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, str]: @@ -62,6 +68,35 @@ class GeminiLLM(BaseLLM): def _assistant_msg(self, msg: str) -> dict[str, str]: return {"role": "model", "parts": [msg]} + def _system_msg(self, msg: str) -> dict[str, str]: + return {"role": "user", "parts": [msg]} + + def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: + """convert messages to list[dict].""" + from metagpt.schema import Message + + if not isinstance(messages, list): + messages = [messages] + + # REF: https://ai.google.dev/tutorials/python_quickstart + # As a dictionary, the message requires `role` and `parts` keys. + # The role in a conversation can either be the `user`, which provides the prompts, + # or `model`, which provides the responses. + processed_messages = [] + for msg in messages: + if isinstance(msg, str): + processed_messages.append({"role": "user", "parts": [msg]}) + elif isinstance(msg, dict): + assert set(msg.keys()) == set(["role", "parts"]) + processed_messages.append(msg) + elif isinstance(msg, Message): + processed_messages.append({"role": "user" if msg.role == "user" else "model", "parts": [msg.content]}) + else: + raise ValueError( + f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!" + ) + return processed_messages + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream} return kwargs diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 10b7749d6..dbfed72df 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -30,12 +30,7 @@ from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_llm import BaseLLM from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA from metagpt.provider.llm_provider_registry import register_provider -from metagpt.utils.common import ( - CodeParser, - decode_image, - log_and_reraise, - process_message, -) +from metagpt.utils.common import CodeParser, decode_image, log_and_reraise from metagpt.utils.cost_manager import CostManager from metagpt.utils.exceptions import handle_exception from metagpt.utils.token_counter import ( @@ -151,7 +146,7 @@ class OpenAILLM(BaseLLM): async def _achat_completion_function( self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, **chat_configs ) -> ChatCompletion: - messages = process_message(messages) + messages = self.format_msg(messages) kwargs = self._cons_kwargs(messages=messages, timeout=self.get_timeout(timeout), **chat_configs) rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs) self._update_costs(rsp.usage) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index e9cef69a4..c8daba724 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -29,6 +29,7 @@ from typing import Any, Callable, List, Literal, Tuple, Union from urllib.parse import quote, unquote import aiofiles +import chardet import loguru import requests from PIL import Image @@ -663,14 +664,21 @@ def role_raise_decorator(func): @handle_exception -async def aread(filename: str | Path, encoding=None) -> str: +async def aread(filename: str | Path, encoding="utf-8") -> str: """Read file asynchronously.""" - async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader: - content = await reader.read() + try: + async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader: + content = await reader.read() + except UnicodeDecodeError: + async with aiofiles.open(str(filename), mode="rb") as reader: + raw = await reader.read() + result = chardet.detect(raw) + detected_encoding = result["encoding"] + content = raw.decode(detected_encoding) return content -async def awrite(filename: str | Path, data: str, encoding=None): +async def awrite(filename: str | Path, data: str, encoding="utf-8"): """Write file asynchronously.""" pathname = Path(filename) pathname.parent.mkdir(parents=True, exist_ok=True) @@ -802,29 +810,6 @@ def decode_image(img_url_or_b64: str) -> Image: return img -def process_message(messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: - """convert messages to list[dict].""" - from metagpt.schema import Message - - # 全部转成list - if not isinstance(messages, list): - messages = [messages] - - # 转成list[dict] - processed_messages = [] - for msg in messages: - if isinstance(msg, str): - processed_messages.append({"role": "user", "content": msg}) - elif isinstance(msg, dict): - assert set(msg.keys()) == set(["role", "content"]) - processed_messages.append(msg) - elif isinstance(msg, Message): - processed_messages.append(msg.to_dict()) - else: - raise ValueError(f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!") - return processed_messages - - def log_and_reraise(retry_state: RetryCallState): logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}") logger.warning( diff --git a/metagpt/utils/dependency_file.py b/metagpt/utils/dependency_file.py index d3add1171..0a375051c 100644 --- a/metagpt/utils/dependency_file.py +++ b/metagpt/utils/dependency_file.py @@ -13,9 +13,7 @@ import re from pathlib import Path from typing import Set -import aiofiles - -from metagpt.utils.common import aread +from metagpt.utils.common import aread, awrite from metagpt.utils.exceptions import handle_exception @@ -45,8 +43,7 @@ class DependencyFile: async def save(self): """Save dependencies to the file asynchronously.""" data = json.dumps(self._dependencies) - async with aiofiles.open(str(self._filename), mode="w") as writer: - await writer.write(data) + await awrite(filename=self._filename, data=data) async def update(self, filename: Path | str, dependencies: Set[Path | str], persist=True): """Update dependencies for a file asynchronously. diff --git a/metagpt/utils/file_repository.py b/metagpt/utils/file_repository.py index d2a06963a..d19f2b705 100644 --- a/metagpt/utils/file_repository.py +++ b/metagpt/utils/file_repository.py @@ -14,11 +14,9 @@ from datetime import datetime from pathlib import Path from typing import Dict, List, Set -import aiofiles - from metagpt.logs import logger from metagpt.schema import Document -from metagpt.utils.common import aread +from metagpt.utils.common import aread, awrite from metagpt.utils.json_to_markdown import json_to_markdown @@ -55,8 +53,7 @@ class FileRepository: pathname = self.workdir / filename pathname.parent.mkdir(parents=True, exist_ok=True) content = content if content else "" # avoid `argument must be str, not None` to make it continue - async with aiofiles.open(str(pathname), mode="w") as writer: - await writer.write(content) + await awrite(filename=str(pathname), data=content) logger.info(f"save to: {str(pathname)}") if dependencies is not None: diff --git a/metagpt/utils/mermaid.py b/metagpt/utils/mermaid.py index ae3c5118f..e1d140e84 100644 --- a/metagpt/utils/mermaid.py +++ b/metagpt/utils/mermaid.py @@ -9,11 +9,9 @@ import asyncio import os from pathlib import Path -import aiofiles - from metagpt.config2 import config from metagpt.logs import logger -from metagpt.utils.common import check_cmd_exists +from metagpt.utils.common import awrite, check_cmd_exists async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int: @@ -30,9 +28,7 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt if dir_name and not os.path.exists(dir_name): os.makedirs(dir_name) tmp = Path(f"{output_file_without_suffix}.mmd") - async with aiofiles.open(tmp, "w", encoding="utf-8") as f: - await f.write(mermaid_code) - # tmp.write_text(mermaid_code, encoding="utf-8") + await awrite(filename=tmp, data=mermaid_code) if engine == "nodejs": if check_cmd_exists(config.mermaid.path) != 0: diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index b8756e8c6..17e095c5f 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -340,7 +340,9 @@ def extract_state_value_from_output(content: str) -> str: content (str): llm's output from `Role._think` """ content = content.strip() # deal the output cases like " 0", "0\n" and so on. - pattern = r"([0-9])" # TODO find the number using a more proper method not just extract from content using pattern + pattern = ( + r"(? 0 else "-1" diff --git a/metagpt/utils/tree.py b/metagpt/utils/tree.py new file mode 100644 index 000000000..bd7922290 --- /dev/null +++ b/metagpt/utils/tree.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/3/11 +@Author : mashenquan +@File : tree.py +@Desc : Implement the same functionality as the `tree` command. + Example: + >>> print_tree(".") + utils + +-- serialize.py + +-- project_repo.py + +-- tree.py + +-- mmdc_playwright.py + +-- cost_manager.py + +-- __pycache__ + | +-- __init__.cpython-39.pyc + | +-- redis.cpython-39.pyc + | +-- singleton.cpython-39.pyc + | +-- embedding.cpython-39.pyc + | +-- make_sk_kernel.cpython-39.pyc + | +-- file_repository.cpython-39.pyc + +-- file.py + +-- save_code.py + +-- common.py + +-- redis.py +""" +from __future__ import annotations + +import subprocess +from pathlib import Path +from typing import Callable, Dict, List + +from gitignore_parser import parse_gitignore + + +def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = False) -> str: + """ + Recursively traverses the directory structure and prints it out in a tree-like format. + + Args: + root (str or Path): The root directory from which to start traversing. + gitignore (str or Path): The filename of gitignore file. + run_command (bool): Whether to execute `tree` command. Execute the `tree` command and return the result if True, + otherwise execute python code instead. + + Returns: + str: A string representation of the directory tree. + + Example: + >>> tree(".") + utils + +-- serialize.py + +-- project_repo.py + +-- tree.py + +-- mmdc_playwright.py + +-- __pycache__ + | +-- __init__.cpython-39.pyc + | +-- redis.cpython-39.pyc + | +-- singleton.cpython-39.pyc + +-- parse_docstring.py + + >>> tree(".", gitignore="../../.gitignore") + utils + +-- serialize.py + +-- project_repo.py + +-- tree.py + +-- mmdc_playwright.py + +-- parse_docstring.py + + >>> tree(".", gitignore="../../.gitignore", run_command=True) + utils + ├── serialize.py + ├── project_repo.py + ├── tree.py + ├── mmdc_playwright.py + └── parse_docstring.py + + + """ + root = Path(root).resolve() + if run_command: + return _execute_tree(root, gitignore) + + git_ignore_rules = parse_gitignore(gitignore) if gitignore else None + dir_ = {root.name: _list_children(root=root, git_ignore_rules=git_ignore_rules)} + v = _print_tree(dir_) + return "\n".join(v) + + +def _list_children(root: Path, git_ignore_rules: Callable) -> Dict[str, Dict]: + dir_ = {} + for i in root.iterdir(): + if git_ignore_rules and git_ignore_rules(str(i)): + continue + try: + if i.is_file(): + dir_[i.name] = {} + else: + dir_[i.name] = _list_children(root=i, git_ignore_rules=git_ignore_rules) + except (FileNotFoundError, PermissionError, OSError): + dir_[i.name] = {} + return dir_ + + +def _print_tree(dir_: Dict[str:Dict]) -> List[str]: + ret = [] + for name, children in dir_.items(): + ret.append(name) + if not children: + continue + lines = _print_tree(children) + for j, v in enumerate(lines): + if v[0] not in ["+", " ", "|"]: + ret = _add_line(ret) + row = f"+-- {v}" + else: + row = f" {v}" + ret.append(row) + return ret + + +def _add_line(rows: List[str]) -> List[str]: + for i in range(len(rows) - 1, -1, -1): + v = rows[i] + if v[0] != " ": + return rows + rows[i] = "|" + v[1:] + return rows + + +def _execute_tree(root: Path, gitignore: str | Path) -> str: + args = ["--gitfile", str(gitignore)] if gitignore else [] + try: + result = subprocess.run(["tree"] + args + [str(root)], capture_output=True, text=True, check=True) + if result.returncode != 0: + raise ValueError(f"tree exits with code {result.returncode}") + return result.stdout + except subprocess.CalledProcessError as e: + raise e diff --git a/setup.py b/setup.py index 7a14c6182..df9bedc9b 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ extras_require["dev"] = (["pylint~=3.0.3", "black~=23.3.0", "isort~=5.12.0", "pr setup( name="metagpt", - version="0.7.4", + version="0.7.6", description="The Multi-Agent Framework", long_description=long_description, long_description_content_type="text/markdown", diff --git a/tests/metagpt/roles/test_tutorial_assistant.py b/tests/metagpt/roles/test_tutorial_assistant.py index c12c2b26e..732f346fd 100644 --- a/tests/metagpt/roles/test_tutorial_assistant.py +++ b/tests/metagpt/roles/test_tutorial_assistant.py @@ -6,11 +6,11 @@ @File : test_tutorial_assistant.py """ -import aiofiles import pytest from metagpt.const import TUTORIAL_PATH from metagpt.roles.tutorial_assistant import TutorialAssistant +from metagpt.utils.common import aread @pytest.mark.asyncio @@ -20,9 +20,8 @@ async def test_tutorial_assistant(language: str, topic: str, context): msg = await role.run(topic) assert TUTORIAL_PATH.exists() filename = msg.content - async with aiofiles.open(filename, mode="r", encoding="utf-8") as reader: - content = await reader.read() - assert "pip" in content + content = await aread(filename=filename) + assert "pip" in content if __name__ == "__main__": diff --git a/tests/metagpt/tools/test_search_engine.py b/tests/metagpt/tools/test_search_engine.py index a1f03ef7b..964ead02f 100644 --- a/tests/metagpt/tools/test_search_engine.py +++ b/tests/metagpt/tools/test_search_engine.py @@ -11,7 +11,6 @@ from typing import Callable import pytest -from metagpt.config2 import config from metagpt.configs.search_config import SearchConfig from metagpt.logs import logger from metagpt.tools import SearchEngineType @@ -53,14 +52,11 @@ async def test_search_engine( search_engine_config = {"engine": search_engine_type, "run_func": run_func} if search_engine_type is SearchEngineType.SERPAPI_GOOGLE: - assert config.search search_engine_config["api_key"] = "mock-serpapi-key" elif search_engine_type is SearchEngineType.DIRECT_GOOGLE: - assert config.search search_engine_config["api_key"] = "mock-google-key" search_engine_config["cse_id"] = "mock-google-cse" elif search_engine_type is SearchEngineType.SERPER_GOOGLE: - assert config.search search_engine_config["api_key"] = "mock-serper-key" async def test(search_engine): diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index b365f424f..75e8ef4ad 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -13,7 +13,6 @@ import uuid from pathlib import Path from typing import Any, Set -import aiofiles import pytest from pydantic import BaseModel @@ -125,9 +124,7 @@ class TestGetProjectRoot: async def test_parse_data_exception(self, filename, want): pathname = Path(__file__).parent.parent.parent / "data/output_parser" / filename assert pathname.exists() - async with aiofiles.open(str(pathname), mode="r") as reader: - data = await reader.read() - + data = await aread(filename=pathname) result = OutputParser.parse_data(data=data) assert want in result @@ -198,12 +195,25 @@ class TestGetProjectRoot: @pytest.mark.asyncio async def test_read_write(self): - pathname = Path(__file__).parent / uuid.uuid4().hex / "test.tmp" + pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.tmp" await awrite(pathname, "ABC") data = await aread(pathname) assert data == "ABC" pathname.unlink(missing_ok=True) + @pytest.mark.asyncio + async def test_read_write_error_charset(self): + pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.txt" + content = "中国abc123\u27f6" + await awrite(filename=pathname, data=content) + data = await aread(filename=pathname) + assert data == content + + content = "GB18030 是中国国家标准局发布的新一代中文字符集标准,是 GBK 的升级版,支持更广泛的字符范围。" + await awrite(filename=pathname, data=content, encoding="gb2312") + data = await aread(filename=pathname, encoding="utf-8") + assert data == content + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_git_repository.py b/tests/metagpt/utils/test_git_repository.py index ea28b8f0b..480a22e24 100644 --- a/tests/metagpt/utils/test_git_repository.py +++ b/tests/metagpt/utils/test_git_repository.py @@ -10,15 +10,14 @@ import shutil from pathlib import Path -import aiofiles import pytest +from metagpt.utils.common import awrite from metagpt.utils.git_repository import GitRepository async def mock_file(filename, content=""): - async with aiofiles.open(str(filename), mode="w") as file: - await file.write(content) + await awrite(filename=filename, data=content) async def mock_repo(local_path) -> (GitRepository, Path): diff --git a/tests/metagpt/utils/test_s3.py b/tests/metagpt/utils/test_s3.py index 34c21612c..ef13c2325 100644 --- a/tests/metagpt/utils/test_s3.py +++ b/tests/metagpt/utils/test_s3.py @@ -9,7 +9,6 @@ import uuid from pathlib import Path import aioboto3 -import aiofiles import pytest from metagpt.config2 import Config @@ -46,7 +45,7 @@ async def test_s3(mocker): conn = S3(s3) object_name = "unittest.bak" await conn.upload_file(bucket=s3.bucket, local_path=__file__, object_name=object_name) - pathname = (Path(__file__).parent / uuid.uuid4().hex).with_suffix(".bak") + pathname = (Path(__file__).parent / "../../../workspace/unittest" / uuid.uuid4().hex).with_suffix(".bak") pathname.unlink(missing_ok=True) await conn.download_file(bucket=s3.bucket, object_name=object_name, local_path=str(pathname)) assert pathname.exists() @@ -54,8 +53,7 @@ async def test_s3(mocker): assert url bin_data = await conn.get_object(bucket=s3.bucket, object_name=object_name) assert bin_data - async with aiofiles.open(__file__, mode="r", encoding="utf-8") as reader: - data = await reader.read() + data = await aread(filename=__file__) res = await conn.cache(data, ".bak", "script") assert "http" in res @@ -69,8 +67,6 @@ async def test_s3(mocker): except Exception: pass - await reader.close() - if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_tree.py b/tests/metagpt/utils/test_tree.py new file mode 100644 index 000000000..03a2a5606 --- /dev/null +++ b/tests/metagpt/utils/test_tree.py @@ -0,0 +1,64 @@ +from pathlib import Path +from typing import List + +import pytest + +from metagpt.utils.tree import _print_tree, tree + + +@pytest.mark.parametrize( + ("root", "rules"), + [ + (str(Path(__file__).parent / "../.."), None), + (str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")), + ], +) +def test_tree(root: str, rules: str): + v = tree(root=root, gitignore=rules) + assert v + + +@pytest.mark.parametrize( + ("root", "rules"), + [ + (str(Path(__file__).parent / "../.."), None), + (str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")), + ], +) +def test_tree_command(root: str, rules: str): + v = tree(root=root, gitignore=rules, run_command=True) + assert v + + +@pytest.mark.parametrize( + ("tree", "want"), + [ + ({"a": {"b": {}, "c": {}}}, ["a", "+-- b", "+-- c"]), + ({"a": {"b": {}, "c": {"d": {}}}}, ["a", "+-- b", "+-- c", " +-- d"]), + ( + {"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}}, + ["a", "+-- b", "| +-- e", "| +-- f", "| +-- g", "+-- c", " +-- d"], + ), + ( + {"h": {"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}, "i": {}}}, + [ + "h", + "+-- a", + "| +-- b", + "| | +-- e", + "| | +-- f", + "| | +-- g", + "| +-- c", + "| +-- d", + "+-- i", + ], + ), + ], +) +def test__print_tree(tree: dict, want: List[str]): + v = _print_tree(tree) + assert v == want + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/mock/mock_llm.py b/tests/mock/mock_llm.py index b4cdfa0cf..c4262e080 100644 --- a/tests/mock/mock_llm.py +++ b/tests/mock/mock_llm.py @@ -8,7 +8,6 @@ from metagpt.provider.azure_openai_api import AzureOpenAILLM from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA from metagpt.provider.openai_api import OpenAILLM from metagpt.schema import Message -from metagpt.utils.common import process_message OriginalLLM = OpenAILLM if config.llm.api_type == LLMType.OPENAI else AzureOpenAILLM @@ -105,7 +104,7 @@ class MockLLM(OriginalLLM): return rsp async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict: - msg_key = json.dumps(process_message(messages), ensure_ascii=False) + msg_key = json.dumps(self.format_msg(messages), ensure_ascii=False) rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs) return rsp