diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 8bf11356a..981e1405a 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -8,6 +8,7 @@ 1. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name. 2. According to the design in Section 2.2.3.5.3 of RFC 135, add incremental iteration functionality. @Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD. +@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236. """ import json import uuid @@ -28,6 +29,7 @@ from metagpt.actions.design_api_an import ( from metagpt.const import DATA_API_DESIGN_FILE_REPO, SEQ_FLOW_FILE_REPO from metagpt.logs import logger from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.tools.tool_registry import register_tool from metagpt.utils.common import aread, awrite, to_markdown_code_block from metagpt.utils.mermaid import mermaid_to_file from metagpt.utils.project_repo import ProjectRepo @@ -42,6 +44,7 @@ NEW_REQ_TEMPLATE = """ """ +@register_tool(tags=["software development", "write system design"]) class WriteDesign(Action): name: str = "" i_context: Optional[str] = None @@ -163,7 +166,7 @@ class WriteDesign(Action): output_path=output_path, ) - self.input_args = with_messages[0].instruct_content + self.input_args = with_messages[-1].instruct_content self.repo = ProjectRepo(self.input_args.project_path) changed_prds = self.input_args.changed_prd_filenames changed_system_designs = [ @@ -283,7 +286,12 @@ class WriteDesign(Action): ) if not output_path: - return AIMessage(content=design.instruct_content.model_dump_json()) + return AIMessage(content=design.content) output_filename = Path(output_path) / f"{uuid.uuid4().hex}.json" await awrite(filename=output_filename, data=design.content) - return AIMessage(content=f'System Design filename: "{str(output_filename)}"') + kvs = {"changed_system_design_filenames": [output_filename]} + + return AIMessage( + content=f'System Design filename: "{str(output_filename)}"', + instruct_content=AIMessage.create_instruct_value(kvs=kvs), + ) diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index 89ebd59a3..393c483cc 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -46,8 +46,8 @@ class PrepareDocuments(Action): path = Path(self.config.project_path) if path.exists() and not self.config.inc: shutil.rmtree(path) - self.config.project_path = path - self.context.set_repo_dir(path) + self.context.kwargs.project_path = path + self.context.kwargs.inc = self.config.inc return ProjectRepo(path) async def run(self, with_messages, **kwargs): diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index 9880a10f3..b44bfb9f3 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -8,6 +8,7 @@ 1. Divide the context into three components: legacy code, unit test code, and console log. 2. Move the document storage operations related to WritePRD from the save operation of WriteDesign. 3. According to the design in Section 2.2.3.5.4 of RFC 135, add incremental iteration functionality. +@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236. """ import json @@ -21,6 +22,7 @@ from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODE from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME from metagpt.logs import logger from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.tools.tool_registry import register_tool from metagpt.utils.common import aread, to_markdown_code_block from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.report import DocsReporter @@ -34,6 +36,7 @@ NEW_REQ_TEMPLATE = """ """ +@register_tool(tags=["software development", "write a project schedule given a project system design file"]) class WriteTasks(Action): name: str = "CreateTasks" i_context: Optional[str] = None @@ -73,7 +76,7 @@ class WriteTasks(Action): if not with_messages: return await self._execute_api(user_requirement=user_requirement, design_filename=design_filename) - self.input_args = with_messages[0].instruct_content + self.input_args = with_messages[-1].instruct_content self.repo = ProjectRepo(self.input_args.project_path) changed_system_designs = self.input_args.changed_system_design_filenames changed_tasks = [str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys())] @@ -82,7 +85,7 @@ class WriteTasks(Action): # `docs/system_designs/`. for filename in changed_system_designs: task_doc = await self._update_tasks(filename=filename) - change_files.docs[filename] = task_doc + change_files.docs[str(self.repo.docs.task.workdir / task_doc.filename)] = task_doc # Rewrite the task files that have undergone changes based on the git head diff under `docs/tasks/`. for filename in changed_tasks: diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 0584a247f..de3bcde84 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -9,6 +9,7 @@ 2. According to the design in Section 2.2.3.5.2 of RFC 135, add incremental iteration functionality. 3. Move the document storage operations related to WritePRD from the save operation of WriteDesign. @Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD. +@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236. """ from __future__ import annotations @@ -38,6 +39,7 @@ from metagpt.const import ( ) from metagpt.logs import logger from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.tools.tool_registry import register_tool from metagpt.utils.common import CodeParser, aread, awrite, to_markdown_code_block from metagpt.utils.file_repository import FileRepository from metagpt.utils.mermaid import mermaid_to_file @@ -64,6 +66,7 @@ NEW_REQ_TEMPLATE = """ """ +@register_tool(tags=["software development", "write product requirement documents"]) class WritePRD(Action): """WritePRD deal with the following situations: 1. Bugfix: If the requirement is a bugfix, the bugfix document will be generated. @@ -145,11 +148,11 @@ class WritePRD(Action): self.input_args = with_messages[-1].instruct_content if not self.input_args: - self.repo = ProjectRepo(self.config.project_path) + self.repo = ProjectRepo(self.context.kwargs.project_path) await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[-1].content) self.input_args = AIMessage.create_instruct_value( kvs={ - "project_path": self.config.project_path, + "project_path": self.context.kwargs.project_path, "requirements_filename": str(self.repo.docs.workdir / REQUIREMENT_FILENAME), "prd_filenames": [str(self.repo.docs.prd.workdir / i) for i in self.repo.docs.prd.all_files], }, @@ -183,6 +186,9 @@ class WritePRD(Action): kvs["changed_prd_filenames"] = [ str(self.repo.docs.prd.workdir / i) for i in list(self.repo.docs.prd.changed_files.keys()) ] + kvs["project_path"] = str(self.repo.workdir) + kvs["requirements_filename"] = str(self.repo.docs.workdir / REQUIREMENT_FILENAME) + self.context.kwargs.project_path = str(self.repo.workdir) return AIMessage( content="PRD is completed. " + "\n".join( @@ -302,7 +308,7 @@ class WritePRD(Action): async def _execute_api( self, user_requirement: str, output_path: str, exists_prd_filename: str, extra_info: str ) -> AIMessage: - content = to_markdown_code_block(val=user_requirement) + content = to_markdown_code_block(val=user_requirement, type_="text") if extra_info: content += to_markdown_code_block(val=extra_info) @@ -320,4 +326,5 @@ class WritePRD(Action): output_filename = Path(output_path) / f"{uuid.uuid4().hex}.json" await awrite(filename=output_filename, data=new_prd.content) - return AIMessage(content=f'PRD filename: "{str(output_filename)}"') + kvs = AIMessage.create_instruct_value({"changed_prd_filenames": [str(output_filename)]}) + return AIMessage(content=f'PRD filename: "{str(output_filename)}"', instruct_content=kvs) diff --git a/metagpt/context.py b/metagpt/context.py index f1c3568d9..384e8da48 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -8,7 +8,6 @@ from __future__ import annotations import os -from pathlib import Path from typing import Any, Dict, Optional from pydantic import BaseModel, ConfigDict @@ -22,8 +21,6 @@ from metagpt.utils.cost_manager import ( FireworksCostManager, TokenCostManager, ) -from metagpt.utils.git_repository import GitRepository -from metagpt.utils.project_repo import ProjectRepo class AttrDict(BaseModel): @@ -66,9 +63,6 @@ class Context(BaseModel): kwargs: AttrDict = AttrDict() config: Config = Config.default() - repo: Optional[ProjectRepo] = None - git_repo: Optional[GitRepository] = None - src_workspace: Optional[Path] = None cost_manager: CostManager = CostManager() _llm: Optional[BaseLLM] = None @@ -80,11 +74,6 @@ class Context(BaseModel): # env.update({k: v for k, v in i.items() if isinstance(v, str)}) return env - def set_repo_dir(self, path: str | Path): - repo_path = Path(path) - self.git_repo = GitRepository(local_path=repo_path, auto_init=True) - self.repo = ProjectRepo(self.git_repo) - def _select_costmanager(self, llm_config: LLMConfig) -> CostManager: """Return a CostManager instance""" if llm_config.api_type == LLMType.FIREWORKS: @@ -117,7 +106,6 @@ class Context(BaseModel): Dict[str, Any]: A dictionary containing serialized data. """ return { - "workdir": str(self.repo.workdir) if self.repo else "", "kwargs": {k: v for k, v in self.kwargs.__dict__.items()}, "cost_manager": self.cost_manager.model_dump_json(), } @@ -130,13 +118,6 @@ class Context(BaseModel): """ if not serialized_data: return - workdir = serialized_data.get("workdir") - if workdir: - self.git_repo = GitRepository(local_path=workdir, auto_init=True) - self.repo = ProjectRepo(self.git_repo) - src_workspace = self.git_repo.workdir / self.git_repo.workdir.name - if src_workspace.exists(): - self.src_workspace = src_workspace kwargs = serialized_data.get("kwargs") if kwargs: for k, v in kwargs.items(): diff --git a/metagpt/environment/base_env.py b/metagpt/environment/base_env.py index fe1660fc6..5d6d3a286 100644 --- a/metagpt/environment/base_env.py +++ b/metagpt/environment/base_env.py @@ -22,6 +22,7 @@ from metagpt.logs import logger from metagpt.memory import Memory from metagpt.schema import Message from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to +from metagpt.utils.git_repository import GitRepository if TYPE_CHECKING: from metagpt.roles.role import Role # noqa: F401 @@ -243,8 +244,9 @@ class Environment(ExtEnv): self.member_addrs[obj] = addresses def archive(self, auto_archive=True): - if auto_archive and self.context.git_repo: - self.context.git_repo.archive() + if auto_archive and self.context.kwargs.get("project_path"): + git_repo = GitRepository(self.context.kwargs.project_path) + git_repo.archive() @classmethod def model_rebuild(cls, **kwargs): diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 58d8076ab..1f66758ea 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -10,7 +10,7 @@ from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.prepare_documents import PrepareDocuments -from metagpt.roles.role import Role +from metagpt.roles.role import Role, RoleReactMode from metagpt.utils.common import any_to_name, any_to_str from metagpt.utils.git_repository import GitRepository @@ -35,6 +35,7 @@ class ProductManager(Role): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.enable_memory = False + self.rc.react_mode = RoleReactMode.BY_ORDER self.set_actions([PrepareDocuments(send_to=any_to_str(self)), WritePRD]) self._watch([UserRequirement, PrepareDocuments]) self.todo_action = any_to_name(WritePRD) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 4cc9eb9e2..fc8fa5353 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -94,7 +94,7 @@ class QaEngineer(Role): code_filename=context.code_doc.filename, test_filename=context.test_doc.filename, working_directory=str(self.repo.workdir), - additional_python_paths=[str(self.context.src_workspace)], + additional_python_paths=[str(self.repo.srcs.workdir)], ) self.publish_message( AIMessage(content=run_code_context.model_dump_json(), cause_by=WriteTest, send_to=MESSAGE_ROUTE_TO_SELF) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 5592841eb..344e1df5e 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -386,8 +386,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel): msg = response else: msg = AIMessage(content=response or "", cause_by=self.rc.todo, sent_from=self) - if self.enable_memory: - self.rc.memory.add(msg) + self.rc.memory.add(msg) return msg diff --git a/metagpt/software_company.py b/metagpt/software_company.py index 7f0c56388..2ea16f55f 100644 --- a/metagpt/software_company.py +++ b/metagpt/software_company.py @@ -68,7 +68,7 @@ def generate_repo( company.run_project(idea, send_to=any_to_str(ProductManager)) asyncio.run(company.run(n_round=n_round)) - return ctx.repo + return ctx.kwargs.get("project_path") @app.command("", help="Start a new project.") diff --git a/tests/conftest.py b/tests/conftest.py index f26ab2ef9..1f6661f7c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,6 @@ import logging import os import re import uuid -from pathlib import Path from typing import Callable import aiohttp.web @@ -23,7 +22,6 @@ from metagpt.context import Context as MetagptContext from metagpt.llm import LLM from metagpt.logs import logger from metagpt.utils.git_repository import GitRepository -from metagpt.utils.project_repo import ProjectRepo from tests.mock.mock_aiohttp import MockAioResponse from tests.mock.mock_curl_cffi import MockCurlCffiResponse from tests.mock.mock_httplib2 import MockHttplib2Response @@ -149,13 +147,14 @@ def loguru_caplog(caplog): @pytest.fixture(scope="function") def context(request): ctx = MetagptContext() - ctx.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}") - ctx.repo = ProjectRepo(ctx.git_repo) + repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}") + ctx.config.project_path = str(repo.workdir) # Destroy git repo at the end of the test session. def fin(): - if ctx.git_repo: - ctx.git_repo.delete_repository() + if ctx.config.project_path: + git_repo = GitRepository(ctx.config.project_path) + git_repo.delete_repository() # Register the function for destroying the environment. request.addfinalizer(fin) @@ -279,6 +278,6 @@ def mermaid_mocker(aiohttp_mocker, mermaid_rsp_cache): @pytest.fixture def git_dir(): """Fixture to get the unittest directory.""" - git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}" + git_dir = DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}" git_dir.mkdir(parents=True, exist_ok=True) return git_dir diff --git a/tests/metagpt/actions/test_design_api.py b/tests/metagpt/actions/test_design_api.py index 9924a2e84..1351b418a 100644 --- a/tests/metagpt/actions/test_design_api.py +++ b/tests/metagpt/actions/test_design_api.py @@ -6,37 +6,105 @@ @File : test_design_api.py @Modifiled By: mashenquan, 2023-12-6. According to RFC 135 """ +import json + import pytest from metagpt.actions.design_api import WriteDesign -from metagpt.llm import LLM +from metagpt.const import METAGPT_ROOT from metagpt.logs import logger -from metagpt.schema import Message +from metagpt.schema import AIMessage, Message +from metagpt.utils.project_repo import ProjectRepo from tests.data.incremental_dev_project.mock import DESIGN_SAMPLE, REFINED_PRD_JSON @pytest.mark.asyncio -async def test_design_api(context): - inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"] # PRD_SAMPLE - for prd in inputs: - await context.repo.docs.prd.save(filename="new_prd.txt", content=prd) +async def test_design(context): + # Mock new design env + prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。" + context.kwargs.project_path = context.config.project_path + context.kwargs.inc = False + filename = "prd.txt" + repo = ProjectRepo(context.kwargs.project_path) + await repo.docs.prd.save(filename=filename, content=prd) + kvs = { + "project_path": str(context.kwargs.project_path), + "changed_prd_filenames": [str(repo.docs.prd.workdir / filename)], + } + instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WritePRDOutput") - design_api = WriteDesign(context=context) - - result = await design_api.run(Message(content=prd, instruct_content=None)) - logger.info(result) - - assert result - - -@pytest.mark.asyncio -async def test_refined_design_api(context): - await context.repo.docs.prd.save(filename="1.txt", content=str(REFINED_PRD_JSON)) - await context.repo.docs.system_design.save(filename="1.txt", content=DESIGN_SAMPLE) - - design_api = WriteDesign(context=context, llm=LLM()) - - result = await design_api.run(Message(content="", instruct_content=None)) + design_api = WriteDesign(context=context) + result = await design_api.run([Message(content=prd, instruct_content=instruct_content)]) logger.info(result) - assert result + assert isinstance(result, AIMessage) + assert result.instruct_content + assert repo.docs.system_design.changed_files + + # Mock incremental design env + context.kwargs.inc = True + await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON)) + await repo.docs.system_design.save(filename=filename, content=DESIGN_SAMPLE) + + result = await design_api.run([Message(content="", instruct_content=instruct_content)]) + logger.info(result) + assert result + assert isinstance(result, AIMessage) + assert result.instruct_content + assert repo.docs.system_design.changed_files + + +@pytest.mark.parametrize( + ("user_requirement", "prd_filename", "exists_design_filename"), + [ + ("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None), + ("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None), + ( + "write 2048 game", + str(METAGPT_ROOT / "tests/data/prd.json"), + str(METAGPT_ROOT / "tests/data/system_design.json"), + ), + ], +) +@pytest.mark.asyncio +async def test_design_api(context, user_requirement, prd_filename, exists_design_filename): + action = WriteDesign() + result = await action.run( + user_requirement=user_requirement, prd_filename=prd_filename, exists_design_filename=exists_design_filename + ) + assert isinstance(result, AIMessage) + assert result.content + m = json.loads(result.content) + assert m + + +@pytest.mark.parametrize( + ("user_requirement", "prd_filename", "exists_design_filename"), + [ + ("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None), + ("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None), + ( + "write 2048 game", + str(METAGPT_ROOT / "tests/data/prd.json"), + str(METAGPT_ROOT / "tests/data/system_design.json"), + ), + ], +) +@pytest.mark.asyncio +async def test_design_api_dir(context, user_requirement, prd_filename, exists_design_filename): + action = WriteDesign() + result = await action.run( + user_requirement=user_requirement, + prd_filename=prd_filename, + exists_design_filename=exists_design_filename, + output_path=context.config.project_path, + ) + assert isinstance(result, AIMessage) + assert result.content + assert str(context.config.project_path) in result.content + assert result.instruct_content + assert result.instruct_content.changed_system_design_filenames + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_project_management.py b/tests/metagpt/actions/test_project_management.py index 5d0d11efb..26699dea7 100644 --- a/tests/metagpt/actions/test_project_management.py +++ b/tests/metagpt/actions/test_project_management.py @@ -5,13 +5,15 @@ @Author : alexanderwu @File : test_project_management.py """ +import json import pytest from metagpt.actions.project_management import WriteTasks -from metagpt.llm import LLM +from metagpt.const import METAGPT_ROOT from metagpt.logs import logger -from metagpt.schema import Message +from metagpt.schema import AIMessage, Message +from metagpt.utils.project_repo import ProjectRepo from tests.data.incremental_dev_project.mock import ( REFINED_DESIGN_JSON, REFINED_PRD_JSON, @@ -22,29 +24,46 @@ from tests.metagpt.actions.mock_json import DESIGN, PRD @pytest.mark.asyncio async def test_task(context): - await context.repo.docs.prd.save("1.txt", content=str(PRD)) - await context.repo.docs.system_design.save("1.txt", content=str(DESIGN)) - logger.info(context.git_repo) + # Mock write tasks env + context.kwargs.project_path = context.config.project_path + context.kwargs.inc = False + repo = ProjectRepo(context.kwargs.project_path) + filename = "1.txt" + await repo.docs.prd.save(filename=filename, content=str(PRD)) + await repo.docs.system_design.save(filename=filename, content=str(DESIGN)) + kvs = { + "project_path": context.kwargs.project_path, + "changed_system_design_filenames": [str(repo.docs.system_design.workdir / filename)], + } + instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WriteDesignOutput") action = WriteTasks(context=context) - - result = await action.run(Message(content="", instruct_content=None)) + result = await action.run([Message(content="", instruct_content=instruct_content)]) logger.info(result) - assert result + assert result.instruct_content.changed_task_filenames + + # Mock incremental env + context.kwargs.inc = True + await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON)) + await repo.docs.system_design.save(filename=filename, content=str(REFINED_DESIGN_JSON)) + await repo.docs.task.save(filename=filename, content=TASK_SAMPLE) + + result = await action.run([Message(content="", instruct_content=instruct_content)]) + logger.info(result) + assert result + assert result.instruct_content.changed_task_filenames @pytest.mark.asyncio -async def test_refined_task(context): - await context.repo.docs.prd.save("2.txt", content=str(REFINED_PRD_JSON)) - await context.repo.docs.system_design.save("2.txt", content=str(REFINED_DESIGN_JSON)) - await context.repo.docs.task.save("2.txt", content=TASK_SAMPLE) - - logger.info(context.git_repo) - - action = WriteTasks(context=context, llm=LLM()) - - result = await action.run(Message(content="", instruct_content=None)) - logger.info(result) - +async def test_task_api(context): + action = WriteTasks() + result = await action.run(design_filename=str(METAGPT_ROOT / "tests/data/system_design.json")) assert result + assert result.content + m = json.loads(result.content) + assert m + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_write_code.py b/tests/metagpt/actions/test_write_code.py index 42623f807..1c1772031 100644 --- a/tests/metagpt/actions/test_write_code.py +++ b/tests/metagpt/actions/test_write_code.py @@ -26,12 +26,7 @@ from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPL def setup_inc_workdir(context, inc: bool = False): """setup incremental workdir for testing""" - context.src_workspace = context.git_repo.workdir / "src" - if inc: - context.config.inc = inc - context.repo.old_workspace = context.repo.git_repo.workdir / "old" - context.config.project_path = "old" - + context.config.inc = inc return context diff --git a/tests/metagpt/actions/test_write_prd.py b/tests/metagpt/actions/test_write_prd.py index 43aa336b7..8cbc01716 100644 --- a/tests/metagpt/actions/test_write_prd.py +++ b/tests/metagpt/actions/test_write_prd.py @@ -6,6 +6,7 @@ @File : test_write_prd.py @Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, replace `handle` with `run`. """ +import json import pytest @@ -14,17 +15,16 @@ from metagpt.const import REQUIREMENT_FILENAME from metagpt.logs import logger from metagpt.roles.product_manager import ProductManager from metagpt.roles.role import RoleReactMode -from metagpt.schema import Message +from metagpt.schema import AIMessage, Message from metagpt.utils.common import any_to_str -from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE, PRD_SAMPLE -from tests.metagpt.actions.test_write_code import setup_inc_workdir +from metagpt.utils.project_repo import ProjectRepo +from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE @pytest.mark.asyncio async def test_write_prd(new_filename, context): product_manager = ProductManager(context=context) requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结" - await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements) product_manager.rc.react_mode = RoleReactMode.BY_ORDER prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement)) assert prd.cause_by == any_to_str(WritePRD) @@ -34,38 +34,39 @@ async def test_write_prd(new_filename, context): # Assert the prd is not None or empty assert prd is not None assert prd.content != "" - assert product_manager.context.repo.docs.prd.changed_files + repo = ProjectRepo(context.kwargs.project_path) + assert repo.docs.prd.changed_files + repo.git_repo.archive() - -@pytest.mark.asyncio -async def test_write_prd_inc(new_filename, context, git_dir): - context = setup_inc_workdir(context, inc=True) - await context.repo.docs.prd.save("1.txt", PRD_SAMPLE) - await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE) + # Mock incremental requirement + context.config.inc = True + context.config.project_path = context.kwargs.project_path + repo = ProjectRepo(context.config.project_path) + await repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE) action = WritePRD(context=context) - prd = await action.run(Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None)) + prd = await action.run([Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None)]) logger.info(NEW_REQUIREMENT_SAMPLE) logger.info(prd) # Assert the prd is not None or empty assert prd is not None assert prd.content != "" - assert "Refined Requirements" in prd.content + assert repo.git_repo.changed_files @pytest.mark.asyncio async def test_fix_debug(new_filename, context, git_dir): - context.src_workspace = context.git_repo.workdir / context.git_repo.workdir.name + # Mock legacy project + context.kwargs.project_path = str(git_dir) + repo = ProjectRepo(context.kwargs.project_path) + repo.with_src_path(git_dir.name) + await repo.srcs.save(filename="main.py", content='if __name__ == "__main__":\nmain()') + requirements = "ValueError: undefined variable `st`." + await repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements) - await context.repo.with_src_path(context.src_workspace).srcs.save( - filename="main.py", content='if __name__ == "__main__":\nmain()' - ) - requirements = "Please fix the bug in the code." - await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements) action = WritePRD(context=context) - - prd = await action.run(Message(content=requirements, instruct_content=None)) + prd = await action.run([Message(content=requirements, instruct_content=None)]) logger.info(prd) # Assert the prd is not None or empty @@ -73,5 +74,39 @@ async def test_fix_debug(new_filename, context, git_dir): assert prd.content != "" +@pytest.mark.asyncio +async def test_write_prd_api(context): + action = WritePRD() + result = await action.run(user_requirement="write a snake game.") + assert isinstance(result, AIMessage) + assert result.content + m = json.loads(result.content) + assert m + + result = await action.run(user_requirement="write a snake game.", output_path=str(context.config.project_path)) + assert isinstance(result, AIMessage) + assert result.content + assert result.instruct_content + assert str(context.config.project_path) in result.content + + legacy_prd_filename = result.instruct_content.changed_prd_filenames[-1] + + result = await action.run(user_requirement="Add moving enemy.", exists_prd_filename=legacy_prd_filename) + assert isinstance(result, AIMessage) + assert result.content + m = json.loads(result.content) + assert m + + result = await action.run( + user_requirement="Add moving enemy.", + output_path=str(context.config.project_path), + exists_prd_filename=legacy_prd_filename, + ) + assert isinstance(result, AIMessage) + assert result.content + assert result.instruct_content + assert str(context.config.project_path) in result.content + + if __name__ == "__main__": pytest.main([__file__, "-s"])