diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index fba396896..3a56248c1 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -23,6 +23,7 @@ from metagpt.schema import ( SerializationMixin, TestingContext, ) +from metagpt.utils.file_repository import FileRepository class Action(SerializationMixin, is_polymorphic_base=True): @@ -40,6 +41,10 @@ class Action(SerializationMixin, is_polymorphic_base=True): def git_repo(self): return self.g_context.git_repo + @property + def file_repo(self): + return FileRepository(self.g_context.git_repo) + @property def src_workspace(self): return self.g_context.src_workspace diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index afae03cb5..ae5aaf2b5 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -47,7 +47,7 @@ class PrepareDocuments(Action): # Write the newly added requirements from the main parameter idea to `docs/requirement.txt`. doc = Document(root_path=DOCS_FILE_REPO, filename=REQUIREMENT_FILENAME, content=with_messages[0].content) - await FileRepository.save_file(filename=REQUIREMENT_FILENAME, content=doc.content, relative_path=DOCS_FILE_REPO) + await self.file_repo.save_file(filename=REQUIREMENT_FILENAME, content=doc.content, relative_path=DOCS_FILE_REPO) # Send a Message notification to the WritePRD action, instructing it to process requirements using # `docs/requirement.txt` and `docs/prds/`. diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 0ba5477c6..7ade1420c 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -95,14 +95,14 @@ class WriteCode(Action): return code async def run(self, *args, **kwargs) -> CodingContext: - bug_feedback = await FileRepository.get_file(filename=BUGFIX_FILENAME, relative_path=DOCS_FILE_REPO) + bug_feedback = await self.file_repo.get_file(filename=BUGFIX_FILENAME, relative_path=DOCS_FILE_REPO) coding_context = CodingContext.loads(self.context.content) - test_doc = await FileRepository.get_file( + test_doc = await self.file_repo.get_file( filename="test_" + coding_context.filename + ".json", relative_path=TEST_OUTPUTS_FILE_REPO ) summary_doc = None if coding_context.design_doc and coding_context.design_doc.filename: - summary_doc = await FileRepository.get_file( + summary_doc = await self.file_repo.get_file( filename=coding_context.design_doc.filename, relative_path=CODE_SUMMARIES_FILE_REPO ) logs = "" diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 427c8acb5..7f1a49231 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -49,8 +49,3 @@ class ProductManager(Role): async def _observe(self, ignore_memory=False) -> int: return await super()._observe(ignore_memory=True) - - @property - def todo(self) -> str: - """AgentStore uses this attribute to display to the user what actions the current role should take.""" - return self.todo_action diff --git a/metagpt/utils/file_repository.py b/metagpt/utils/file_repository.py index ff750fbbb..48e38b27a 100644 --- a/metagpt/utils/file_repository.py +++ b/metagpt/utils/file_repository.py @@ -16,7 +16,6 @@ from typing import Dict, List, Set import aiofiles -from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.schema import Document from metagpt.utils.common import aread @@ -201,8 +200,7 @@ class FileRepository: await self.save(filename=str(filename), content=json_to_markdown(m), dependencies=dependencies) logger.debug(f"File Saved: {str(filename)}") - @staticmethod - async def get_file(filename: Path | str, relative_path: Path | str = ".") -> Document | None: + async def get_file(self, filename: Path | str, relative_path: Path | str = ".") -> Document | None: """Retrieve a specific file from the file repository. :param filename: The name or path of the file to retrieve. @@ -212,11 +210,10 @@ class FileRepository: :return: The document representing the file, or None if not found. :rtype: Document or None """ - file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path) + file_repo = self._git_repo.new_file_repository(relative_path=relative_path) return await file_repo.get(filename=filename) - @staticmethod - async def get_all_files(relative_path: Path | str = ".") -> List[Document]: + async def get_all_files(self, relative_path: Path | str = ".") -> List[Document]: """Retrieve all files from the file repository. :param relative_path: The relative path within the file repository. @@ -224,11 +221,12 @@ class FileRepository: :return: A list of documents representing all files in the repository. :rtype: List[Document] """ - file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path) + file_repo = self._git_repo.new_file_repository(relative_path=relative_path) return await file_repo.get_all() - @staticmethod - async def save_file(filename: Path | str, content, dependencies: List[str] = None, relative_path: Path | str = "."): + async def save_file( + self, filename: Path | str, content, dependencies: List[str] = None, relative_path: Path | str = "." + ): """Save a file to the file repository. :param filename: The name or path of the file to save. @@ -239,12 +237,11 @@ class FileRepository: :param relative_path: The relative path within the file repository. :type relative_path: Path or str, optional """ - file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path) + file_repo = self._git_repo.new_file_repository(relative_path=relative_path) return await file_repo.save(filename=filename, content=content, dependencies=dependencies) - @staticmethod async def save_as( - doc: Document, with_suffix: str = None, dependencies: List[str] = None, relative_path: Path | str = "." + self, doc: Document, with_suffix: str = None, dependencies: List[str] = None, relative_path: Path | str = "." ): """Save a Document instance with optional modifications. @@ -262,7 +259,7 @@ class FileRepository: :return: A boolean indicating whether the save operation was successful. :rtype: bool """ - file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path) + file_repo = self._git_repo.new_file_repository(relative_path=relative_path) return await file_repo.save_doc(doc=doc, with_suffix=with_suffix, dependencies=dependencies) async def delete(self, filename: Path | str): @@ -282,7 +279,6 @@ class FileRepository: await dependency_file.update(filename=pathname, dependencies=None) logger.info(f"remove dependency key: {str(pathname)}") - @staticmethod - async def delete_file(filename: Path | str, relative_path: Path | str = "."): - file_repo = CONFIG.git_repo.new_file_repository(relative_path=relative_path) + async def delete_file(self, filename: Path | str, relative_path: Path | str = "."): + file_repo = self._git_repo.new_file_repository(relative_path=relative_path) await file_repo.delete(filename=filename) diff --git a/tests/conftest.py b/tests/conftest.py index 7ed66a61d..71afdff9f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,8 +15,9 @@ from typing import Optional import pytest -from metagpt.config import CONFIG, Config +from metagpt.config2 import config from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH +from metagpt.context import context from metagpt.llm import LLM from metagpt.logs import logger from metagpt.provider.openai_api import OpenAILLM @@ -53,6 +54,7 @@ class MockLLM(OpenAILLM): timeout=3, stream=True, ) -> str: + logger.debug(f"MockLLM.aask: {msg}") if msg not in self.rsp_cache: # Call the original unmocked method rsp = await self.original_aask(msg, system_msgs, format_msgs, timeout, stream) @@ -81,7 +83,8 @@ def rsp_cache(): @pytest.fixture(scope="function") def llm_mock(rsp_cache, mocker): - llm = MockLLM() + llm = MockLLM(config.get_llm_config()) + llm.cost_manager = context.cost_manager llm.rsp_cache = rsp_cache mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", llm.aask) yield mocker @@ -90,7 +93,7 @@ def llm_mock(rsp_cache, mocker): class Context: def __init__(self): self._llm_ui = None - self._llm_api = LLM(provider=CONFIG.get_default_llm_provider_enum()) + self._llm_api = LLM() @property def llm_api(self): @@ -153,12 +156,12 @@ def loguru_caplog(caplog): # init & dispose git repo @pytest.fixture(scope="session", autouse=True) def setup_and_teardown_git_repo(request): - CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / "unittest") - CONFIG.git_reinit = True + context.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / "unittest") + context.config.git_reinit = True # Destroy git repo at the end of the test session. def fin(): - CONFIG.git_repo.delete_repository() + context.git_repo.delete_repository() # Register the function for destroying the environment. request.addfinalizer(fin) @@ -166,4 +169,4 @@ def setup_and_teardown_git_repo(request): @pytest.fixture(scope="session", autouse=True) def init_config(): - Config() + pass diff --git a/tests/metagpt/actions/test_prepare_documents.py b/tests/metagpt/actions/test_prepare_documents.py index c7fb6af20..30aa3b482 100644 --- a/tests/metagpt/actions/test_prepare_documents.py +++ b/tests/metagpt/actions/test_prepare_documents.py @@ -25,6 +25,6 @@ async def test_prepare_documents(): await PrepareDocuments(g_context=context).run(with_messages=[msg]) assert context.git_repo - doc = await FileRepository.get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO) + doc = await FileRepository(context.git_repo).get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO) assert doc assert doc.content == msg.content diff --git a/tests/metagpt/roles/test_product_manager.py b/tests/metagpt/roles/test_product_manager.py index 0538cbe6d..34cf9ce6e 100644 --- a/tests/metagpt/roles/test_product_manager.py +++ b/tests/metagpt/roles/test_product_manager.py @@ -7,6 +7,7 @@ """ import pytest +from metagpt.context import context from metagpt.logs import logger from metagpt.roles import ProductManager from tests.metagpt.roles.mock import MockMessages @@ -15,7 +16,7 @@ from tests.metagpt.roles.mock import MockMessages @pytest.mark.asyncio @pytest.mark.usefixtures("llm_mock") async def test_product_manager(): - product_manager = ProductManager() + product_manager = ProductManager(context=context) rsp = await product_manager.run(MockMessages.req) logger.info(rsp) assert len(rsp.content) > 0