mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
feat: software action + api interface
This commit is contained in:
parent
f3b839847b
commit
ce3260038a
15 changed files with 231 additions and 114 deletions
|
|
@ -8,6 +8,7 @@
|
|||
1. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name.
|
||||
2. According to the design in Section 2.2.3.5.3 of RFC 135, add incremental iteration functionality.
|
||||
@Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD.
|
||||
@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236.
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
|
|
@ -28,6 +29,7 @@ from metagpt.actions.design_api_an import (
|
|||
from metagpt.const import DATA_API_DESIGN_FILE_REPO, SEQ_FLOW_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import AIMessage, Document, Documents, Message
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.common import aread, awrite, to_markdown_code_block
|
||||
from metagpt.utils.mermaid import mermaid_to_file
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
|
@ -42,6 +44,7 @@ NEW_REQ_TEMPLATE = """
|
|||
"""
|
||||
|
||||
|
||||
@register_tool(tags=["software development", "write system design"])
|
||||
class WriteDesign(Action):
|
||||
name: str = ""
|
||||
i_context: Optional[str] = None
|
||||
|
|
@ -163,7 +166,7 @@ class WriteDesign(Action):
|
|||
output_path=output_path,
|
||||
)
|
||||
|
||||
self.input_args = with_messages[0].instruct_content
|
||||
self.input_args = with_messages[-1].instruct_content
|
||||
self.repo = ProjectRepo(self.input_args.project_path)
|
||||
changed_prds = self.input_args.changed_prd_filenames
|
||||
changed_system_designs = [
|
||||
|
|
@ -283,7 +286,12 @@ class WriteDesign(Action):
|
|||
)
|
||||
|
||||
if not output_path:
|
||||
return AIMessage(content=design.instruct_content.model_dump_json())
|
||||
return AIMessage(content=design.content)
|
||||
output_filename = Path(output_path) / f"{uuid.uuid4().hex}.json"
|
||||
await awrite(filename=output_filename, data=design.content)
|
||||
return AIMessage(content=f'System Design filename: "{str(output_filename)}"')
|
||||
kvs = {"changed_system_design_filenames": [output_filename]}
|
||||
|
||||
return AIMessage(
|
||||
content=f'System Design filename: "{str(output_filename)}"',
|
||||
instruct_content=AIMessage.create_instruct_value(kvs=kvs),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -46,8 +46,8 @@ class PrepareDocuments(Action):
|
|||
path = Path(self.config.project_path)
|
||||
if path.exists() and not self.config.inc:
|
||||
shutil.rmtree(path)
|
||||
self.config.project_path = path
|
||||
self.context.set_repo_dir(path)
|
||||
self.context.kwargs.project_path = path
|
||||
self.context.kwargs.inc = self.config.inc
|
||||
return ProjectRepo(path)
|
||||
|
||||
async def run(self, with_messages, **kwargs):
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
1. Divide the context into three components: legacy code, unit test code, and console log.
|
||||
2. Move the document storage operations related to WritePRD from the save operation of WriteDesign.
|
||||
3. According to the design in Section 2.2.3.5.4 of RFC 135, add incremental iteration functionality.
|
||||
@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
|
@ -21,6 +22,7 @@ from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODE
|
|||
from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import AIMessage, Document, Documents, Message
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.common import aread, to_markdown_code_block
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from metagpt.utils.report import DocsReporter
|
||||
|
|
@ -34,6 +36,7 @@ NEW_REQ_TEMPLATE = """
|
|||
"""
|
||||
|
||||
|
||||
@register_tool(tags=["software development", "write a project schedule given a project system design file"])
|
||||
class WriteTasks(Action):
|
||||
name: str = "CreateTasks"
|
||||
i_context: Optional[str] = None
|
||||
|
|
@ -73,7 +76,7 @@ class WriteTasks(Action):
|
|||
if not with_messages:
|
||||
return await self._execute_api(user_requirement=user_requirement, design_filename=design_filename)
|
||||
|
||||
self.input_args = with_messages[0].instruct_content
|
||||
self.input_args = with_messages[-1].instruct_content
|
||||
self.repo = ProjectRepo(self.input_args.project_path)
|
||||
changed_system_designs = self.input_args.changed_system_design_filenames
|
||||
changed_tasks = [str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys())]
|
||||
|
|
@ -82,7 +85,7 @@ class WriteTasks(Action):
|
|||
# `docs/system_designs/`.
|
||||
for filename in changed_system_designs:
|
||||
task_doc = await self._update_tasks(filename=filename)
|
||||
change_files.docs[filename] = task_doc
|
||||
change_files.docs[str(self.repo.docs.task.workdir / task_doc.filename)] = task_doc
|
||||
|
||||
# Rewrite the task files that have undergone changes based on the git head diff under `docs/tasks/`.
|
||||
for filename in changed_tasks:
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
2. According to the design in Section 2.2.3.5.2 of RFC 135, add incremental iteration functionality.
|
||||
3. Move the document storage operations related to WritePRD from the save operation of WriteDesign.
|
||||
@Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD.
|
||||
@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -38,6 +39,7 @@ from metagpt.const import (
|
|||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import AIMessage, Document, Documents, Message
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.common import CodeParser, aread, awrite, to_markdown_code_block
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.mermaid import mermaid_to_file
|
||||
|
|
@ -64,6 +66,7 @@ NEW_REQ_TEMPLATE = """
|
|||
"""
|
||||
|
||||
|
||||
@register_tool(tags=["software development", "write product requirement documents"])
|
||||
class WritePRD(Action):
|
||||
"""WritePRD deal with the following situations:
|
||||
1. Bugfix: If the requirement is a bugfix, the bugfix document will be generated.
|
||||
|
|
@ -145,11 +148,11 @@ class WritePRD(Action):
|
|||
|
||||
self.input_args = with_messages[-1].instruct_content
|
||||
if not self.input_args:
|
||||
self.repo = ProjectRepo(self.config.project_path)
|
||||
self.repo = ProjectRepo(self.context.kwargs.project_path)
|
||||
await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[-1].content)
|
||||
self.input_args = AIMessage.create_instruct_value(
|
||||
kvs={
|
||||
"project_path": self.config.project_path,
|
||||
"project_path": self.context.kwargs.project_path,
|
||||
"requirements_filename": str(self.repo.docs.workdir / REQUIREMENT_FILENAME),
|
||||
"prd_filenames": [str(self.repo.docs.prd.workdir / i) for i in self.repo.docs.prd.all_files],
|
||||
},
|
||||
|
|
@ -183,6 +186,9 @@ class WritePRD(Action):
|
|||
kvs["changed_prd_filenames"] = [
|
||||
str(self.repo.docs.prd.workdir / i) for i in list(self.repo.docs.prd.changed_files.keys())
|
||||
]
|
||||
kvs["project_path"] = str(self.repo.workdir)
|
||||
kvs["requirements_filename"] = str(self.repo.docs.workdir / REQUIREMENT_FILENAME)
|
||||
self.context.kwargs.project_path = str(self.repo.workdir)
|
||||
return AIMessage(
|
||||
content="PRD is completed. "
|
||||
+ "\n".join(
|
||||
|
|
@ -302,7 +308,7 @@ class WritePRD(Action):
|
|||
async def _execute_api(
|
||||
self, user_requirement: str, output_path: str, exists_prd_filename: str, extra_info: str
|
||||
) -> AIMessage:
|
||||
content = to_markdown_code_block(val=user_requirement)
|
||||
content = to_markdown_code_block(val=user_requirement, type_="text")
|
||||
if extra_info:
|
||||
content += to_markdown_code_block(val=extra_info)
|
||||
|
||||
|
|
@ -320,4 +326,5 @@ class WritePRD(Action):
|
|||
|
||||
output_filename = Path(output_path) / f"{uuid.uuid4().hex}.json"
|
||||
await awrite(filename=output_filename, data=new_prd.content)
|
||||
return AIMessage(content=f'PRD filename: "{str(output_filename)}"')
|
||||
kvs = AIMessage.create_instruct_value({"changed_prd_filenames": [str(output_filename)]})
|
||||
return AIMessage(content=f'PRD filename: "{str(output_filename)}"', instruct_content=kvs)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
|
@ -22,8 +21,6 @@ from metagpt.utils.cost_manager import (
|
|||
FireworksCostManager,
|
||||
TokenCostManager,
|
||||
)
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
class AttrDict(BaseModel):
|
||||
|
|
@ -66,9 +63,6 @@ class Context(BaseModel):
|
|||
kwargs: AttrDict = AttrDict()
|
||||
config: Config = Config.default()
|
||||
|
||||
repo: Optional[ProjectRepo] = None
|
||||
git_repo: Optional[GitRepository] = None
|
||||
src_workspace: Optional[Path] = None
|
||||
cost_manager: CostManager = CostManager()
|
||||
|
||||
_llm: Optional[BaseLLM] = None
|
||||
|
|
@ -80,11 +74,6 @@ class Context(BaseModel):
|
|||
# env.update({k: v for k, v in i.items() if isinstance(v, str)})
|
||||
return env
|
||||
|
||||
def set_repo_dir(self, path: str | Path):
|
||||
repo_path = Path(path)
|
||||
self.git_repo = GitRepository(local_path=repo_path, auto_init=True)
|
||||
self.repo = ProjectRepo(self.git_repo)
|
||||
|
||||
def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
|
||||
"""Return a CostManager instance"""
|
||||
if llm_config.api_type == LLMType.FIREWORKS:
|
||||
|
|
@ -117,7 +106,6 @@ class Context(BaseModel):
|
|||
Dict[str, Any]: A dictionary containing serialized data.
|
||||
"""
|
||||
return {
|
||||
"workdir": str(self.repo.workdir) if self.repo else "",
|
||||
"kwargs": {k: v for k, v in self.kwargs.__dict__.items()},
|
||||
"cost_manager": self.cost_manager.model_dump_json(),
|
||||
}
|
||||
|
|
@ -130,13 +118,6 @@ class Context(BaseModel):
|
|||
"""
|
||||
if not serialized_data:
|
||||
return
|
||||
workdir = serialized_data.get("workdir")
|
||||
if workdir:
|
||||
self.git_repo = GitRepository(local_path=workdir, auto_init=True)
|
||||
self.repo = ProjectRepo(self.git_repo)
|
||||
src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
|
||||
if src_workspace.exists():
|
||||
self.src_workspace = src_workspace
|
||||
kwargs = serialized_data.get("kwargs")
|
||||
if kwargs:
|
||||
for k, v in kwargs.items():
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from metagpt.logs import logger
|
|||
from metagpt.memory import Memory
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from metagpt.roles.role import Role # noqa: F401
|
||||
|
|
@ -243,8 +244,9 @@ class Environment(ExtEnv):
|
|||
self.member_addrs[obj] = addresses
|
||||
|
||||
def archive(self, auto_archive=True):
|
||||
if auto_archive and self.context.git_repo:
|
||||
self.context.git_repo.archive()
|
||||
if auto_archive and self.context.kwargs.get("project_path"):
|
||||
git_repo = GitRepository(self.context.kwargs.project_path)
|
||||
git_repo.archive()
|
||||
|
||||
@classmethod
|
||||
def model_rebuild(cls, **kwargs):
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.roles.role import Role, RoleReactMode
|
||||
from metagpt.utils.common import any_to_name, any_to_str
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
|
@ -35,6 +35,7 @@ class ProductManager(Role):
|
|||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.enable_memory = False
|
||||
self.rc.react_mode = RoleReactMode.BY_ORDER
|
||||
self.set_actions([PrepareDocuments(send_to=any_to_str(self)), WritePRD])
|
||||
self._watch([UserRequirement, PrepareDocuments])
|
||||
self.todo_action = any_to_name(WritePRD)
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ class QaEngineer(Role):
|
|||
code_filename=context.code_doc.filename,
|
||||
test_filename=context.test_doc.filename,
|
||||
working_directory=str(self.repo.workdir),
|
||||
additional_python_paths=[str(self.context.src_workspace)],
|
||||
additional_python_paths=[str(self.repo.srcs.workdir)],
|
||||
)
|
||||
self.publish_message(
|
||||
AIMessage(content=run_code_context.model_dump_json(), cause_by=WriteTest, send_to=MESSAGE_ROUTE_TO_SELF)
|
||||
|
|
|
|||
|
|
@ -386,8 +386,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
msg = response
|
||||
else:
|
||||
msg = AIMessage(content=response or "", cause_by=self.rc.todo, sent_from=self)
|
||||
if self.enable_memory:
|
||||
self.rc.memory.add(msg)
|
||||
self.rc.memory.add(msg)
|
||||
|
||||
return msg
|
||||
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ def generate_repo(
|
|||
company.run_project(idea, send_to=any_to_str(ProductManager))
|
||||
asyncio.run(company.run(n_round=n_round))
|
||||
|
||||
return ctx.repo
|
||||
return ctx.kwargs.get("project_path")
|
||||
|
||||
|
||||
@app.command("", help="Start a new project.")
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import aiohttp.web
|
||||
|
|
@ -23,7 +22,6 @@ from metagpt.context import Context as MetagptContext
|
|||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from tests.mock.mock_aiohttp import MockAioResponse
|
||||
from tests.mock.mock_curl_cffi import MockCurlCffiResponse
|
||||
from tests.mock.mock_httplib2 import MockHttplib2Response
|
||||
|
|
@ -149,13 +147,14 @@ def loguru_caplog(caplog):
|
|||
@pytest.fixture(scope="function")
|
||||
def context(request):
|
||||
ctx = MetagptContext()
|
||||
ctx.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}")
|
||||
ctx.repo = ProjectRepo(ctx.git_repo)
|
||||
repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}")
|
||||
ctx.config.project_path = str(repo.workdir)
|
||||
|
||||
# Destroy git repo at the end of the test session.
|
||||
def fin():
|
||||
if ctx.git_repo:
|
||||
ctx.git_repo.delete_repository()
|
||||
if ctx.config.project_path:
|
||||
git_repo = GitRepository(ctx.config.project_path)
|
||||
git_repo.delete_repository()
|
||||
|
||||
# Register the function for destroying the environment.
|
||||
request.addfinalizer(fin)
|
||||
|
|
@ -279,6 +278,6 @@ def mermaid_mocker(aiohttp_mocker, mermaid_rsp_cache):
|
|||
@pytest.fixture
|
||||
def git_dir():
|
||||
"""Fixture to get the unittest directory."""
|
||||
git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}"
|
||||
git_dir = DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}"
|
||||
git_dir.mkdir(parents=True, exist_ok=True)
|
||||
return git_dir
|
||||
|
|
|
|||
|
|
@ -6,37 +6,105 @@
|
|||
@File : test_design_api.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
|
||||
"""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.schema import AIMessage, Message
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from tests.data.incremental_dev_project.mock import DESIGN_SAMPLE, REFINED_PRD_JSON
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api(context):
|
||||
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"] # PRD_SAMPLE
|
||||
for prd in inputs:
|
||||
await context.repo.docs.prd.save(filename="new_prd.txt", content=prd)
|
||||
async def test_design(context):
|
||||
# Mock new design env
|
||||
prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"
|
||||
context.kwargs.project_path = context.config.project_path
|
||||
context.kwargs.inc = False
|
||||
filename = "prd.txt"
|
||||
repo = ProjectRepo(context.kwargs.project_path)
|
||||
await repo.docs.prd.save(filename=filename, content=prd)
|
||||
kvs = {
|
||||
"project_path": str(context.kwargs.project_path),
|
||||
"changed_prd_filenames": [str(repo.docs.prd.workdir / filename)],
|
||||
}
|
||||
instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WritePRDOutput")
|
||||
|
||||
design_api = WriteDesign(context=context)
|
||||
|
||||
result = await design_api.run(Message(content=prd, instruct_content=None))
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refined_design_api(context):
|
||||
await context.repo.docs.prd.save(filename="1.txt", content=str(REFINED_PRD_JSON))
|
||||
await context.repo.docs.system_design.save(filename="1.txt", content=DESIGN_SAMPLE)
|
||||
|
||||
design_api = WriteDesign(context=context, llm=LLM())
|
||||
|
||||
result = await design_api.run(Message(content="", instruct_content=None))
|
||||
design_api = WriteDesign(context=context)
|
||||
result = await design_api.run([Message(content=prd, instruct_content=instruct_content)])
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.instruct_content
|
||||
assert repo.docs.system_design.changed_files
|
||||
|
||||
# Mock incremental design env
|
||||
context.kwargs.inc = True
|
||||
await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON))
|
||||
await repo.docs.system_design.save(filename=filename, content=DESIGN_SAMPLE)
|
||||
|
||||
result = await design_api.run([Message(content="", instruct_content=instruct_content)])
|
||||
logger.info(result)
|
||||
assert result
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.instruct_content
|
||||
assert repo.docs.system_design.changed_files
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("user_requirement", "prd_filename", "exists_design_filename"),
|
||||
[
|
||||
("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None),
|
||||
("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None),
|
||||
(
|
||||
"write 2048 game",
|
||||
str(METAGPT_ROOT / "tests/data/prd.json"),
|
||||
str(METAGPT_ROOT / "tests/data/system_design.json"),
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api(context, user_requirement, prd_filename, exists_design_filename):
|
||||
action = WriteDesign()
|
||||
result = await action.run(
|
||||
user_requirement=user_requirement, prd_filename=prd_filename, exists_design_filename=exists_design_filename
|
||||
)
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content
|
||||
m = json.loads(result.content)
|
||||
assert m
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("user_requirement", "prd_filename", "exists_design_filename"),
|
||||
[
|
||||
("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None),
|
||||
("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None),
|
||||
(
|
||||
"write 2048 game",
|
||||
str(METAGPT_ROOT / "tests/data/prd.json"),
|
||||
str(METAGPT_ROOT / "tests/data/system_design.json"),
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api_dir(context, user_requirement, prd_filename, exists_design_filename):
|
||||
action = WriteDesign()
|
||||
result = await action.run(
|
||||
user_requirement=user_requirement,
|
||||
prd_filename=prd_filename,
|
||||
exists_design_filename=exists_design_filename,
|
||||
output_path=context.config.project_path,
|
||||
)
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content
|
||||
assert str(context.config.project_path) in result.content
|
||||
assert result.instruct_content
|
||||
assert result.instruct_content.changed_system_design_filenames
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -5,13 +5,15 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_project_management.py
|
||||
"""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.schema import AIMessage, Message
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from tests.data.incremental_dev_project.mock import (
|
||||
REFINED_DESIGN_JSON,
|
||||
REFINED_PRD_JSON,
|
||||
|
|
@ -22,29 +24,46 @@ from tests.metagpt.actions.mock_json import DESIGN, PRD
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task(context):
|
||||
await context.repo.docs.prd.save("1.txt", content=str(PRD))
|
||||
await context.repo.docs.system_design.save("1.txt", content=str(DESIGN))
|
||||
logger.info(context.git_repo)
|
||||
# Mock write tasks env
|
||||
context.kwargs.project_path = context.config.project_path
|
||||
context.kwargs.inc = False
|
||||
repo = ProjectRepo(context.kwargs.project_path)
|
||||
filename = "1.txt"
|
||||
await repo.docs.prd.save(filename=filename, content=str(PRD))
|
||||
await repo.docs.system_design.save(filename=filename, content=str(DESIGN))
|
||||
kvs = {
|
||||
"project_path": context.kwargs.project_path,
|
||||
"changed_system_design_filenames": [str(repo.docs.system_design.workdir / filename)],
|
||||
}
|
||||
instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WriteDesignOutput")
|
||||
|
||||
action = WriteTasks(context=context)
|
||||
|
||||
result = await action.run(Message(content="", instruct_content=None))
|
||||
result = await action.run([Message(content="", instruct_content=instruct_content)])
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
assert result.instruct_content.changed_task_filenames
|
||||
|
||||
# Mock incremental env
|
||||
context.kwargs.inc = True
|
||||
await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON))
|
||||
await repo.docs.system_design.save(filename=filename, content=str(REFINED_DESIGN_JSON))
|
||||
await repo.docs.task.save(filename=filename, content=TASK_SAMPLE)
|
||||
|
||||
result = await action.run([Message(content="", instruct_content=instruct_content)])
|
||||
logger.info(result)
|
||||
assert result
|
||||
assert result.instruct_content.changed_task_filenames
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refined_task(context):
|
||||
await context.repo.docs.prd.save("2.txt", content=str(REFINED_PRD_JSON))
|
||||
await context.repo.docs.system_design.save("2.txt", content=str(REFINED_DESIGN_JSON))
|
||||
await context.repo.docs.task.save("2.txt", content=TASK_SAMPLE)
|
||||
|
||||
logger.info(context.git_repo)
|
||||
|
||||
action = WriteTasks(context=context, llm=LLM())
|
||||
|
||||
result = await action.run(Message(content="", instruct_content=None))
|
||||
logger.info(result)
|
||||
|
||||
async def test_task_api(context):
|
||||
action = WriteTasks()
|
||||
result = await action.run(design_filename=str(METAGPT_ROOT / "tests/data/system_design.json"))
|
||||
assert result
|
||||
assert result.content
|
||||
m = json.loads(result.content)
|
||||
assert m
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -26,12 +26,7 @@ from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPL
|
|||
|
||||
def setup_inc_workdir(context, inc: bool = False):
|
||||
"""setup incremental workdir for testing"""
|
||||
context.src_workspace = context.git_repo.workdir / "src"
|
||||
if inc:
|
||||
context.config.inc = inc
|
||||
context.repo.old_workspace = context.repo.git_repo.workdir / "old"
|
||||
context.config.project_path = "old"
|
||||
|
||||
context.config.inc = inc
|
||||
return context
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
@File : test_write_prd.py
|
||||
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, replace `handle` with `run`.
|
||||
"""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -14,17 +15,16 @@ from metagpt.const import REQUIREMENT_FILENAME
|
|||
from metagpt.logs import logger
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.roles.role import RoleReactMode
|
||||
from metagpt.schema import Message
|
||||
from metagpt.schema import AIMessage, Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE, PRD_SAMPLE
|
||||
from tests.metagpt.actions.test_write_code import setup_inc_workdir
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_prd(new_filename, context):
|
||||
product_manager = ProductManager(context=context)
|
||||
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"
|
||||
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
|
||||
product_manager.rc.react_mode = RoleReactMode.BY_ORDER
|
||||
prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement))
|
||||
assert prd.cause_by == any_to_str(WritePRD)
|
||||
|
|
@ -34,38 +34,39 @@ async def test_write_prd(new_filename, context):
|
|||
# Assert the prd is not None or empty
|
||||
assert prd is not None
|
||||
assert prd.content != ""
|
||||
assert product_manager.context.repo.docs.prd.changed_files
|
||||
repo = ProjectRepo(context.kwargs.project_path)
|
||||
assert repo.docs.prd.changed_files
|
||||
repo.git_repo.archive()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_prd_inc(new_filename, context, git_dir):
|
||||
context = setup_inc_workdir(context, inc=True)
|
||||
await context.repo.docs.prd.save("1.txt", PRD_SAMPLE)
|
||||
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE)
|
||||
# Mock incremental requirement
|
||||
context.config.inc = True
|
||||
context.config.project_path = context.kwargs.project_path
|
||||
repo = ProjectRepo(context.config.project_path)
|
||||
await repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE)
|
||||
|
||||
action = WritePRD(context=context)
|
||||
prd = await action.run(Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None))
|
||||
prd = await action.run([Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None)])
|
||||
logger.info(NEW_REQUIREMENT_SAMPLE)
|
||||
logger.info(prd)
|
||||
|
||||
# Assert the prd is not None or empty
|
||||
assert prd is not None
|
||||
assert prd.content != ""
|
||||
assert "Refined Requirements" in prd.content
|
||||
assert repo.git_repo.changed_files
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fix_debug(new_filename, context, git_dir):
|
||||
context.src_workspace = context.git_repo.workdir / context.git_repo.workdir.name
|
||||
# Mock legacy project
|
||||
context.kwargs.project_path = str(git_dir)
|
||||
repo = ProjectRepo(context.kwargs.project_path)
|
||||
repo.with_src_path(git_dir.name)
|
||||
await repo.srcs.save(filename="main.py", content='if __name__ == "__main__":\nmain()')
|
||||
requirements = "ValueError: undefined variable `st`."
|
||||
await repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
|
||||
|
||||
await context.repo.with_src_path(context.src_workspace).srcs.save(
|
||||
filename="main.py", content='if __name__ == "__main__":\nmain()'
|
||||
)
|
||||
requirements = "Please fix the bug in the code."
|
||||
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
|
||||
action = WritePRD(context=context)
|
||||
|
||||
prd = await action.run(Message(content=requirements, instruct_content=None))
|
||||
prd = await action.run([Message(content=requirements, instruct_content=None)])
|
||||
logger.info(prd)
|
||||
|
||||
# Assert the prd is not None or empty
|
||||
|
|
@ -73,5 +74,39 @@ async def test_fix_debug(new_filename, context, git_dir):
|
|||
assert prd.content != ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_prd_api(context):
|
||||
action = WritePRD()
|
||||
result = await action.run(user_requirement="write a snake game.")
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content
|
||||
m = json.loads(result.content)
|
||||
assert m
|
||||
|
||||
result = await action.run(user_requirement="write a snake game.", output_path=str(context.config.project_path))
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content
|
||||
assert result.instruct_content
|
||||
assert str(context.config.project_path) in result.content
|
||||
|
||||
legacy_prd_filename = result.instruct_content.changed_prd_filenames[-1]
|
||||
|
||||
result = await action.run(user_requirement="Add moving enemy.", exists_prd_filename=legacy_prd_filename)
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content
|
||||
m = json.loads(result.content)
|
||||
assert m
|
||||
|
||||
result = await action.run(
|
||||
user_requirement="Add moving enemy.",
|
||||
output_path=str(context.config.project_path),
|
||||
exists_prd_filename=legacy_prd_filename,
|
||||
)
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content
|
||||
assert result.instruct_content
|
||||
assert str(context.config.project_path) in result.content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue