feat: software action + api interface

This commit is contained in:
莘权 马 2024-06-01 18:47:40 +08:00
parent f3b839847b
commit ce3260038a
15 changed files with 231 additions and 114 deletions

View file

@ -8,6 +8,7 @@
1. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name.
2. According to the design in Section 2.2.3.5.3 of RFC 135, add incremental iteration functionality.
@Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD.
@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236.
"""
import json
import uuid
@ -28,6 +29,7 @@ from metagpt.actions.design_api_an import (
from metagpt.const import DATA_API_DESIGN_FILE_REPO, SEQ_FLOW_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import AIMessage, Document, Documents, Message
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import aread, awrite, to_markdown_code_block
from metagpt.utils.mermaid import mermaid_to_file
from metagpt.utils.project_repo import ProjectRepo
@ -42,6 +44,7 @@ NEW_REQ_TEMPLATE = """
"""
@register_tool(tags=["software development", "write system design"])
class WriteDesign(Action):
name: str = ""
i_context: Optional[str] = None
@ -163,7 +166,7 @@ class WriteDesign(Action):
output_path=output_path,
)
self.input_args = with_messages[0].instruct_content
self.input_args = with_messages[-1].instruct_content
self.repo = ProjectRepo(self.input_args.project_path)
changed_prds = self.input_args.changed_prd_filenames
changed_system_designs = [
@ -283,7 +286,12 @@ class WriteDesign(Action):
)
if not output_path:
return AIMessage(content=design.instruct_content.model_dump_json())
return AIMessage(content=design.content)
output_filename = Path(output_path) / f"{uuid.uuid4().hex}.json"
await awrite(filename=output_filename, data=design.content)
return AIMessage(content=f'System Design filename: "{str(output_filename)}"')
kvs = {"changed_system_design_filenames": [output_filename]}
return AIMessage(
content=f'System Design filename: "{str(output_filename)}"',
instruct_content=AIMessage.create_instruct_value(kvs=kvs),
)

View file

@ -46,8 +46,8 @@ class PrepareDocuments(Action):
path = Path(self.config.project_path)
if path.exists() and not self.config.inc:
shutil.rmtree(path)
self.config.project_path = path
self.context.set_repo_dir(path)
self.context.kwargs.project_path = path
self.context.kwargs.inc = self.config.inc
return ProjectRepo(path)
async def run(self, with_messages, **kwargs):

View file

@ -8,6 +8,7 @@
1. Divide the context into three components: legacy code, unit test code, and console log.
2. Move the document storage operations related to WritePRD from the save operation of WriteDesign.
3. According to the design in Section 2.2.3.5.4 of RFC 135, add incremental iteration functionality.
@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236.
"""
import json
@ -21,6 +22,7 @@ from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODE
from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME
from metagpt.logs import logger
from metagpt.schema import AIMessage, Document, Documents, Message
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import aread, to_markdown_code_block
from metagpt.utils.project_repo import ProjectRepo
from metagpt.utils.report import DocsReporter
@ -34,6 +36,7 @@ NEW_REQ_TEMPLATE = """
"""
@register_tool(tags=["software development", "write a project schedule given a project system design file"])
class WriteTasks(Action):
name: str = "CreateTasks"
i_context: Optional[str] = None
@ -73,7 +76,7 @@ class WriteTasks(Action):
if not with_messages:
return await self._execute_api(user_requirement=user_requirement, design_filename=design_filename)
self.input_args = with_messages[0].instruct_content
self.input_args = with_messages[-1].instruct_content
self.repo = ProjectRepo(self.input_args.project_path)
changed_system_designs = self.input_args.changed_system_design_filenames
changed_tasks = [str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys())]
@ -82,7 +85,7 @@ class WriteTasks(Action):
# `docs/system_designs/`.
for filename in changed_system_designs:
task_doc = await self._update_tasks(filename=filename)
change_files.docs[filename] = task_doc
change_files.docs[str(self.repo.docs.task.workdir / task_doc.filename)] = task_doc
# Rewrite the task files that have undergone changes based on the git head diff under `docs/tasks/`.
for filename in changed_tasks:

View file

@ -9,6 +9,7 @@
2. According to the design in Section 2.2.3.5.2 of RFC 135, add incremental iteration functionality.
3. Move the document storage operations related to WritePRD from the save operation of WriteDesign.
@Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD.
@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236.
"""
from __future__ import annotations
@ -38,6 +39,7 @@ from metagpt.const import (
)
from metagpt.logs import logger
from metagpt.schema import AIMessage, Document, Documents, Message
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import CodeParser, aread, awrite, to_markdown_code_block
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.mermaid import mermaid_to_file
@ -64,6 +66,7 @@ NEW_REQ_TEMPLATE = """
"""
@register_tool(tags=["software development", "write product requirement documents"])
class WritePRD(Action):
"""WritePRD deal with the following situations:
1. Bugfix: If the requirement is a bugfix, the bugfix document will be generated.
@ -145,11 +148,11 @@ class WritePRD(Action):
self.input_args = with_messages[-1].instruct_content
if not self.input_args:
self.repo = ProjectRepo(self.config.project_path)
self.repo = ProjectRepo(self.context.kwargs.project_path)
await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[-1].content)
self.input_args = AIMessage.create_instruct_value(
kvs={
"project_path": self.config.project_path,
"project_path": self.context.kwargs.project_path,
"requirements_filename": str(self.repo.docs.workdir / REQUIREMENT_FILENAME),
"prd_filenames": [str(self.repo.docs.prd.workdir / i) for i in self.repo.docs.prd.all_files],
},
@ -183,6 +186,9 @@ class WritePRD(Action):
kvs["changed_prd_filenames"] = [
str(self.repo.docs.prd.workdir / i) for i in list(self.repo.docs.prd.changed_files.keys())
]
kvs["project_path"] = str(self.repo.workdir)
kvs["requirements_filename"] = str(self.repo.docs.workdir / REQUIREMENT_FILENAME)
self.context.kwargs.project_path = str(self.repo.workdir)
return AIMessage(
content="PRD is completed. "
+ "\n".join(
@ -302,7 +308,7 @@ class WritePRD(Action):
async def _execute_api(
self, user_requirement: str, output_path: str, exists_prd_filename: str, extra_info: str
) -> AIMessage:
content = to_markdown_code_block(val=user_requirement)
content = to_markdown_code_block(val=user_requirement, type_="text")
if extra_info:
content += to_markdown_code_block(val=extra_info)
@ -320,4 +326,5 @@ class WritePRD(Action):
output_filename = Path(output_path) / f"{uuid.uuid4().hex}.json"
await awrite(filename=output_filename, data=new_prd.content)
return AIMessage(content=f'PRD filename: "{str(output_filename)}"')
kvs = AIMessage.create_instruct_value({"changed_prd_filenames": [str(output_filename)]})
return AIMessage(content=f'PRD filename: "{str(output_filename)}"', instruct_content=kvs)

View file

@ -8,7 +8,6 @@
from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict
@ -22,8 +21,6 @@ from metagpt.utils.cost_manager import (
FireworksCostManager,
TokenCostManager,
)
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
class AttrDict(BaseModel):
@ -66,9 +63,6 @@ class Context(BaseModel):
kwargs: AttrDict = AttrDict()
config: Config = Config.default()
repo: Optional[ProjectRepo] = None
git_repo: Optional[GitRepository] = None
src_workspace: Optional[Path] = None
cost_manager: CostManager = CostManager()
_llm: Optional[BaseLLM] = None
@ -80,11 +74,6 @@ class Context(BaseModel):
# env.update({k: v for k, v in i.items() if isinstance(v, str)})
return env
def set_repo_dir(self, path: str | Path):
repo_path = Path(path)
self.git_repo = GitRepository(local_path=repo_path, auto_init=True)
self.repo = ProjectRepo(self.git_repo)
def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
"""Return a CostManager instance"""
if llm_config.api_type == LLMType.FIREWORKS:
@ -117,7 +106,6 @@ class Context(BaseModel):
Dict[str, Any]: A dictionary containing serialized data.
"""
return {
"workdir": str(self.repo.workdir) if self.repo else "",
"kwargs": {k: v for k, v in self.kwargs.__dict__.items()},
"cost_manager": self.cost_manager.model_dump_json(),
}
@ -130,13 +118,6 @@ class Context(BaseModel):
"""
if not serialized_data:
return
workdir = serialized_data.get("workdir")
if workdir:
self.git_repo = GitRepository(local_path=workdir, auto_init=True)
self.repo = ProjectRepo(self.git_repo)
src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
if src_workspace.exists():
self.src_workspace = src_workspace
kwargs = serialized_data.get("kwargs")
if kwargs:
for k, v in kwargs.items():

View file

@ -22,6 +22,7 @@ from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.schema import Message
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
from metagpt.utils.git_repository import GitRepository
if TYPE_CHECKING:
from metagpt.roles.role import Role # noqa: F401
@ -243,8 +244,9 @@ class Environment(ExtEnv):
self.member_addrs[obj] = addresses
def archive(self, auto_archive=True):
if auto_archive and self.context.git_repo:
self.context.git_repo.archive()
if auto_archive and self.context.kwargs.get("project_path"):
git_repo = GitRepository(self.context.kwargs.project_path)
git_repo.archive()
@classmethod
def model_rebuild(cls, **kwargs):

View file

@ -10,7 +10,7 @@
from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.roles.role import Role
from metagpt.roles.role import Role, RoleReactMode
from metagpt.utils.common import any_to_name, any_to_str
from metagpt.utils.git_repository import GitRepository
@ -35,6 +35,7 @@ class ProductManager(Role):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.enable_memory = False
self.rc.react_mode = RoleReactMode.BY_ORDER
self.set_actions([PrepareDocuments(send_to=any_to_str(self)), WritePRD])
self._watch([UserRequirement, PrepareDocuments])
self.todo_action = any_to_name(WritePRD)

View file

@ -94,7 +94,7 @@ class QaEngineer(Role):
code_filename=context.code_doc.filename,
test_filename=context.test_doc.filename,
working_directory=str(self.repo.workdir),
additional_python_paths=[str(self.context.src_workspace)],
additional_python_paths=[str(self.repo.srcs.workdir)],
)
self.publish_message(
AIMessage(content=run_code_context.model_dump_json(), cause_by=WriteTest, send_to=MESSAGE_ROUTE_TO_SELF)

View file

@ -386,8 +386,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
msg = response
else:
msg = AIMessage(content=response or "", cause_by=self.rc.todo, sent_from=self)
if self.enable_memory:
self.rc.memory.add(msg)
self.rc.memory.add(msg)
return msg

View file

@ -68,7 +68,7 @@ def generate_repo(
company.run_project(idea, send_to=any_to_str(ProductManager))
asyncio.run(company.run(n_round=n_round))
return ctx.repo
return ctx.kwargs.get("project_path")
@app.command("", help="Start a new project.")

View file

@ -12,7 +12,6 @@ import logging
import os
import re
import uuid
from pathlib import Path
from typing import Callable
import aiohttp.web
@ -23,7 +22,6 @@ from metagpt.context import Context as MetagptContext
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
from tests.mock.mock_aiohttp import MockAioResponse
from tests.mock.mock_curl_cffi import MockCurlCffiResponse
from tests.mock.mock_httplib2 import MockHttplib2Response
@ -149,13 +147,14 @@ def loguru_caplog(caplog):
@pytest.fixture(scope="function")
def context(request):
ctx = MetagptContext()
ctx.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}")
ctx.repo = ProjectRepo(ctx.git_repo)
repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}")
ctx.config.project_path = str(repo.workdir)
# Destroy git repo at the end of the test session.
def fin():
if ctx.git_repo:
ctx.git_repo.delete_repository()
if ctx.config.project_path:
git_repo = GitRepository(ctx.config.project_path)
git_repo.delete_repository()
# Register the function for destroying the environment.
request.addfinalizer(fin)
@ -279,6 +278,6 @@ def mermaid_mocker(aiohttp_mocker, mermaid_rsp_cache):
@pytest.fixture
def git_dir():
"""Fixture to get the unittest directory."""
git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}"
git_dir = DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}"
git_dir.mkdir(parents=True, exist_ok=True)
return git_dir

View file

@ -6,37 +6,105 @@
@File : test_design_api.py
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
"""
import json
import pytest
from metagpt.actions.design_api import WriteDesign
from metagpt.llm import LLM
from metagpt.const import METAGPT_ROOT
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.schema import AIMessage, Message
from metagpt.utils.project_repo import ProjectRepo
from tests.data.incremental_dev_project.mock import DESIGN_SAMPLE, REFINED_PRD_JSON
@pytest.mark.asyncio
async def test_design_api(context):
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"] # PRD_SAMPLE
for prd in inputs:
await context.repo.docs.prd.save(filename="new_prd.txt", content=prd)
async def test_design(context):
# Mock new design env
prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"
context.kwargs.project_path = context.config.project_path
context.kwargs.inc = False
filename = "prd.txt"
repo = ProjectRepo(context.kwargs.project_path)
await repo.docs.prd.save(filename=filename, content=prd)
kvs = {
"project_path": str(context.kwargs.project_path),
"changed_prd_filenames": [str(repo.docs.prd.workdir / filename)],
}
instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WritePRDOutput")
design_api = WriteDesign(context=context)
result = await design_api.run(Message(content=prd, instruct_content=None))
logger.info(result)
assert result
@pytest.mark.asyncio
async def test_refined_design_api(context):
await context.repo.docs.prd.save(filename="1.txt", content=str(REFINED_PRD_JSON))
await context.repo.docs.system_design.save(filename="1.txt", content=DESIGN_SAMPLE)
design_api = WriteDesign(context=context, llm=LLM())
result = await design_api.run(Message(content="", instruct_content=None))
design_api = WriteDesign(context=context)
result = await design_api.run([Message(content=prd, instruct_content=instruct_content)])
logger.info(result)
assert result
assert isinstance(result, AIMessage)
assert result.instruct_content
assert repo.docs.system_design.changed_files
# Mock incremental design env
context.kwargs.inc = True
await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON))
await repo.docs.system_design.save(filename=filename, content=DESIGN_SAMPLE)
result = await design_api.run([Message(content="", instruct_content=instruct_content)])
logger.info(result)
assert result
assert isinstance(result, AIMessage)
assert result.instruct_content
assert repo.docs.system_design.changed_files
@pytest.mark.parametrize(
("user_requirement", "prd_filename", "exists_design_filename"),
[
("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None),
("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None),
(
"write 2048 game",
str(METAGPT_ROOT / "tests/data/prd.json"),
str(METAGPT_ROOT / "tests/data/system_design.json"),
),
],
)
@pytest.mark.asyncio
async def test_design_api(context, user_requirement, prd_filename, exists_design_filename):
action = WriteDesign()
result = await action.run(
user_requirement=user_requirement, prd_filename=prd_filename, exists_design_filename=exists_design_filename
)
assert isinstance(result, AIMessage)
assert result.content
m = json.loads(result.content)
assert m
@pytest.mark.parametrize(
("user_requirement", "prd_filename", "exists_design_filename"),
[
("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None),
("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None),
(
"write 2048 game",
str(METAGPT_ROOT / "tests/data/prd.json"),
str(METAGPT_ROOT / "tests/data/system_design.json"),
),
],
)
@pytest.mark.asyncio
async def test_design_api_dir(context, user_requirement, prd_filename, exists_design_filename):
action = WriteDesign()
result = await action.run(
user_requirement=user_requirement,
prd_filename=prd_filename,
exists_design_filename=exists_design_filename,
output_path=context.config.project_path,
)
assert isinstance(result, AIMessage)
assert result.content
assert str(context.config.project_path) in result.content
assert result.instruct_content
assert result.instruct_content.changed_system_design_filenames
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -5,13 +5,15 @@
@Author : alexanderwu
@File : test_project_management.py
"""
import json
import pytest
from metagpt.actions.project_management import WriteTasks
from metagpt.llm import LLM
from metagpt.const import METAGPT_ROOT
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.schema import AIMessage, Message
from metagpt.utils.project_repo import ProjectRepo
from tests.data.incremental_dev_project.mock import (
REFINED_DESIGN_JSON,
REFINED_PRD_JSON,
@ -22,29 +24,46 @@ from tests.metagpt.actions.mock_json import DESIGN, PRD
@pytest.mark.asyncio
async def test_task(context):
await context.repo.docs.prd.save("1.txt", content=str(PRD))
await context.repo.docs.system_design.save("1.txt", content=str(DESIGN))
logger.info(context.git_repo)
# Mock write tasks env
context.kwargs.project_path = context.config.project_path
context.kwargs.inc = False
repo = ProjectRepo(context.kwargs.project_path)
filename = "1.txt"
await repo.docs.prd.save(filename=filename, content=str(PRD))
await repo.docs.system_design.save(filename=filename, content=str(DESIGN))
kvs = {
"project_path": context.kwargs.project_path,
"changed_system_design_filenames": [str(repo.docs.system_design.workdir / filename)],
}
instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WriteDesignOutput")
action = WriteTasks(context=context)
result = await action.run(Message(content="", instruct_content=None))
result = await action.run([Message(content="", instruct_content=instruct_content)])
logger.info(result)
assert result
assert result.instruct_content.changed_task_filenames
# Mock incremental env
context.kwargs.inc = True
await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON))
await repo.docs.system_design.save(filename=filename, content=str(REFINED_DESIGN_JSON))
await repo.docs.task.save(filename=filename, content=TASK_SAMPLE)
result = await action.run([Message(content="", instruct_content=instruct_content)])
logger.info(result)
assert result
assert result.instruct_content.changed_task_filenames
@pytest.mark.asyncio
async def test_refined_task(context):
await context.repo.docs.prd.save("2.txt", content=str(REFINED_PRD_JSON))
await context.repo.docs.system_design.save("2.txt", content=str(REFINED_DESIGN_JSON))
await context.repo.docs.task.save("2.txt", content=TASK_SAMPLE)
logger.info(context.git_repo)
action = WriteTasks(context=context, llm=LLM())
result = await action.run(Message(content="", instruct_content=None))
logger.info(result)
async def test_task_api(context):
action = WriteTasks()
result = await action.run(design_filename=str(METAGPT_ROOT / "tests/data/system_design.json"))
assert result
assert result.content
m = json.loads(result.content)
assert m
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -26,12 +26,7 @@ from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPL
def setup_inc_workdir(context, inc: bool = False):
"""setup incremental workdir for testing"""
context.src_workspace = context.git_repo.workdir / "src"
if inc:
context.config.inc = inc
context.repo.old_workspace = context.repo.git_repo.workdir / "old"
context.config.project_path = "old"
context.config.inc = inc
return context

View file

@ -6,6 +6,7 @@
@File : test_write_prd.py
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, replace `handle` with `run`.
"""
import json
import pytest
@ -14,17 +15,16 @@ from metagpt.const import REQUIREMENT_FILENAME
from metagpt.logs import logger
from metagpt.roles.product_manager import ProductManager
from metagpt.roles.role import RoleReactMode
from metagpt.schema import Message
from metagpt.schema import AIMessage, Message
from metagpt.utils.common import any_to_str
from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE, PRD_SAMPLE
from tests.metagpt.actions.test_write_code import setup_inc_workdir
from metagpt.utils.project_repo import ProjectRepo
from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE
@pytest.mark.asyncio
async def test_write_prd(new_filename, context):
product_manager = ProductManager(context=context)
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
product_manager.rc.react_mode = RoleReactMode.BY_ORDER
prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement))
assert prd.cause_by == any_to_str(WritePRD)
@ -34,38 +34,39 @@ async def test_write_prd(new_filename, context):
# Assert the prd is not None or empty
assert prd is not None
assert prd.content != ""
assert product_manager.context.repo.docs.prd.changed_files
repo = ProjectRepo(context.kwargs.project_path)
assert repo.docs.prd.changed_files
repo.git_repo.archive()
@pytest.mark.asyncio
async def test_write_prd_inc(new_filename, context, git_dir):
context = setup_inc_workdir(context, inc=True)
await context.repo.docs.prd.save("1.txt", PRD_SAMPLE)
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE)
# Mock incremental requirement
context.config.inc = True
context.config.project_path = context.kwargs.project_path
repo = ProjectRepo(context.config.project_path)
await repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE)
action = WritePRD(context=context)
prd = await action.run(Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None))
prd = await action.run([Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None)])
logger.info(NEW_REQUIREMENT_SAMPLE)
logger.info(prd)
# Assert the prd is not None or empty
assert prd is not None
assert prd.content != ""
assert "Refined Requirements" in prd.content
assert repo.git_repo.changed_files
@pytest.mark.asyncio
async def test_fix_debug(new_filename, context, git_dir):
context.src_workspace = context.git_repo.workdir / context.git_repo.workdir.name
# Mock legacy project
context.kwargs.project_path = str(git_dir)
repo = ProjectRepo(context.kwargs.project_path)
repo.with_src_path(git_dir.name)
await repo.srcs.save(filename="main.py", content='if __name__ == "__main__":\nmain()')
requirements = "ValueError: undefined variable `st`."
await repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
await context.repo.with_src_path(context.src_workspace).srcs.save(
filename="main.py", content='if __name__ == "__main__":\nmain()'
)
requirements = "Please fix the bug in the code."
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
action = WritePRD(context=context)
prd = await action.run(Message(content=requirements, instruct_content=None))
prd = await action.run([Message(content=requirements, instruct_content=None)])
logger.info(prd)
# Assert the prd is not None or empty
@ -73,5 +74,39 @@ async def test_fix_debug(new_filename, context, git_dir):
assert prd.content != ""
@pytest.mark.asyncio
async def test_write_prd_api(context):
action = WritePRD()
result = await action.run(user_requirement="write a snake game.")
assert isinstance(result, AIMessage)
assert result.content
m = json.loads(result.content)
assert m
result = await action.run(user_requirement="write a snake game.", output_path=str(context.config.project_path))
assert isinstance(result, AIMessage)
assert result.content
assert result.instruct_content
assert str(context.config.project_path) in result.content
legacy_prd_filename = result.instruct_content.changed_prd_filenames[-1]
result = await action.run(user_requirement="Add moving enemy.", exists_prd_filename=legacy_prd_filename)
assert isinstance(result, AIMessage)
assert result.content
m = json.loads(result.content)
assert m
result = await action.run(
user_requirement="Add moving enemy.",
output_path=str(context.config.project_path),
exists_prd_filename=legacy_prd_filename,
)
assert isinstance(result, AIMessage)
assert result.content
assert result.instruct_content
assert str(context.config.project_path) in result.content
if __name__ == "__main__":
pytest.main([__file__, "-s"])