refine writeprd code

This commit is contained in:
geekan 2024-01-15 16:37:42 +08:00
parent 007102e022
commit 8baa6d094f
13 changed files with 150 additions and 114 deletions

View file

@ -34,8 +34,8 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
node: ActionNode = Field(default=None, exclude=True)
@property
def project_repo(self):
return ProjectRepo(self.context.git_repo)
def repo(self) -> ProjectRepo:
return self.context.repo
@property
def prompt_schema(self):

View file

@ -49,7 +49,7 @@ class DebugError(Action):
i_context: RunCodeContext = Field(default_factory=RunCodeContext)
async def run(self, *args, **kwargs) -> str:
output_doc = await self.project_repo.test_outputs.get(filename=self.i_context.output_filename)
output_doc = await self.repo.test_outputs.get(filename=self.i_context.output_filename)
if not output_doc:
return ""
output_detail = RunCodeResult.loads(output_doc.content)
@ -59,12 +59,12 @@ class DebugError(Action):
return ""
logger.info(f"Debug and rewrite {self.i_context.test_filename}")
code_doc = await self.project_repo.with_src_path(self.context.src_workspace).srcs.get(
code_doc = await self.repo.with_src_path(self.context.src_workspace).srcs.get(
filename=self.i_context.code_filename
)
if not code_doc:
return ""
test_doc = await self.project_repo.tests.get(filename=self.i_context.test_filename)
test_doc = await self.repo.tests.get(filename=self.i_context.test_filename)
if not test_doc:
return ""
prompt = PROMPT_TEMPLATE.format(code=code_doc.content, test_code=test_doc.content, logs=output_detail.stderr)

View file

@ -40,10 +40,10 @@ class WriteDesign(Action):
async def run(self, with_messages: Message, schema: str = None):
# Use `git status` to identify which PRD documents have been modified in the `docs/prds` directory.
changed_prds = self.project_repo.docs.prd.changed_files
changed_prds = self.repo.docs.prd.changed_files
# Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone
# changes.
changed_system_designs = self.project_repo.docs.system_design.changed_files
changed_system_designs = self.repo.docs.system_design.changed_files
# For those PRDs and design documents that have undergone changes, regenerate the design content.
changed_files = Documents()
@ -73,21 +73,21 @@ class WriteDesign(Action):
return system_design_doc
async def _update_system_design(self, filename) -> Document:
prd = await self.project_repo.docs.prd.get(filename)
old_system_design_doc = await self.project_repo.docs.system_design.get(filename)
prd = await self.repo.docs.prd.get(filename)
old_system_design_doc = await self.repo.docs.system_design.get(filename)
if not old_system_design_doc:
system_design = await self._new_system_design(context=prd.content)
doc = await self.project_repo.docs.system_design.save(
doc = await self.repo.docs.system_design.save(
filename=filename,
content=system_design.instruct_content.model_dump_json(),
dependencies={prd.root_relative_path},
)
else:
doc = await self._merge(prd_doc=prd, system_design_doc=old_system_design_doc)
await self.project_repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path})
await self.repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path})
await self._save_data_api_design(doc)
await self._save_seq_flow(doc)
await self.project_repo.resources.system_design.save_pdf(doc=doc)
await self.repo.resources.system_design.save_pdf(doc=doc)
return doc
async def _save_data_api_design(self, design_doc):
@ -95,7 +95,7 @@ class WriteDesign(Action):
data_api_design = m.get("Data structures and interfaces")
if not data_api_design:
return
pathname = self.project_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
pathname = self.repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
await self._save_mermaid_file(data_api_design, pathname)
logger.info(f"Save class view to {str(pathname)}")
@ -104,7 +104,7 @@ class WriteDesign(Action):
seq_flow = m.get("Program call flow")
if not seq_flow:
return
pathname = self.project_repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
pathname = self.repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
await self._save_mermaid_file(seq_flow, pathname)
logger.info(f"Saving sequence flow to {str(pathname)}")

View file

@ -15,6 +15,7 @@ from metagpt.actions import Action, ActionOutput
from metagpt.const import REQUIREMENT_FILENAME
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
class PrepareDocuments(Action):
@ -38,13 +39,14 @@ class PrepareDocuments(Action):
shutil.rmtree(path)
self.config.project_path = path
self.context.git_repo = GitRepository(local_path=path, auto_init=True)
self.context.repo = ProjectRepo(self.context.git_repo)
async def run(self, with_messages, **kwargs):
"""Create and initialize the workspace folder, initialize the Git environment."""
self._init_repo()
# Write the newly added requirements from the main parameter idea to `docs/requirement.txt`.
doc = await self.project_repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
doc = await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
# Send a Message notification to the WritePRD action, instructing it to process requirements using
# `docs/requirement.txt` and `docs/prds/`.
return ActionOutput(content=doc.content, instruct_content=doc)

View file

@ -13,8 +13,8 @@
import json
from typing import Optional
from metagpt.actions import ActionOutput
from metagpt.actions.action import Action
from metagpt.actions.action_output import ActionOutput
from metagpt.actions.project_management_an import PM_NODE
from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME
from metagpt.logs import logger
@ -34,8 +34,8 @@ class WriteTasks(Action):
i_context: Optional[str] = None
async def run(self, with_messages):
changed_system_designs = self.project_repo.docs.system_design.changed_files
changed_tasks = self.project_repo.docs.task.changed_files
changed_system_designs = self.repo.docs.system_design.changed_files
changed_tasks = self.repo.docs.task.changed_files
change_files = Documents()
# Rewrite the system designs that have undergone changes based on the git head diff under
# `docs/system_designs/`.
@ -57,16 +57,14 @@ class WriteTasks(Action):
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
async def _update_tasks(self, filename):
system_design_doc = await self.project_repo.docs.system_design.get(filename)
task_doc = await self.project_repo.docs.task.get(filename)
system_design_doc = await self.repo.docs.system_design.get(filename)
task_doc = await self.repo.docs.task.get(filename)
if task_doc:
task_doc = await self._merge(system_design_doc=system_design_doc, task_doc=task_doc)
await self.project_repo.docs.task.save_doc(
doc=task_doc, dependencies={system_design_doc.root_relative_path}
)
await self.repo.docs.task.save_doc(doc=task_doc, dependencies={system_design_doc.root_relative_path})
else:
rsp = await self._run_new_tasks(context=system_design_doc.content)
task_doc = await self.project_repo.docs.task.save(
task_doc = await self.repo.docs.task.save(
filename=filename,
content=rsp.instruct_content.model_dump_json(),
dependencies={system_design_doc.root_relative_path},
@ -87,7 +85,7 @@ class WriteTasks(Action):
async def _update_requirements(self, doc):
m = json.loads(doc.content)
packages = set(m.get("Required Python third-party packages", set()))
requirement_doc = await self.project_repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
requirement_doc = await self.repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
if not requirement_doc:
requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")
lines = requirement_doc.content.splitlines()
@ -95,4 +93,4 @@ class WriteTasks(Action):
if pkg == "":
continue
packages.add(pkg)
await self.project_repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))
await self.repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))

View file

@ -98,10 +98,10 @@ class SummarizeCode(Action):
async def run(self):
design_pathname = Path(self.i_context.design_filename)
design_doc = await self.project_repo.docs.system_design.get(filename=design_pathname.name)
design_doc = await self.repo.docs.system_design.get(filename=design_pathname.name)
task_pathname = Path(self.i_context.task_filename)
task_doc = await self.project_repo.docs.task.get(filename=task_pathname.name)
src_file_repo = self.project_repo.with_src_path(self.context.src_workspace).srcs
task_doc = await self.repo.docs.task.get(filename=task_pathname.name)
src_file_repo = self.repo.with_src_path(self.context.src_workspace).srcs
code_blocks = []
for filename in self.i_context.codes_filenames:
code_doc = await src_file_repo.get(filename)

View file

@ -88,12 +88,12 @@ class WriteCode(Action):
return code
async def run(self, *args, **kwargs) -> CodingContext:
bug_feedback = await self.project_repo.docs.get(filename=BUGFIX_FILENAME)
bug_feedback = await self.repo.docs.get(filename=BUGFIX_FILENAME)
coding_context = CodingContext.loads(self.i_context.content)
test_doc = await self.project_repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
summary_doc = None
if coding_context.design_doc and coding_context.design_doc.filename:
summary_doc = await self.project_repo.docs.code_summary.get(filename=coding_context.design_doc.filename)
summary_doc = await self.repo.docs.code_summary.get(filename=coding_context.design_doc.filename)
logs = ""
if test_doc:
test_detail = RunCodeResult.loads(test_doc.content)
@ -105,7 +105,7 @@ class WriteCode(Action):
code_context = await self.get_codes(
coding_context.task_doc,
exclude=self.i_context.filename,
project_repo=self.project_repo.with_src_path(self.context.src_workspace),
project_repo=self.repo.with_src_path(self.context.src_workspace),
)
prompt = PROMPT_TEMPLATE.format(

View file

@ -143,7 +143,7 @@ class WriteCodeReview(Action):
code_context = await WriteCode.get_codes(
self.i_context.task_doc,
exclude=self.i_context.filename,
project_repo=self.project_repo.with_src_path(self.context.src_workspace),
project_repo=self.repo.with_src_path(self.context.src_workspace),
)
context = "\n".join(
[

View file

@ -15,7 +15,6 @@ from __future__ import annotations
import json
from pathlib import Path
from typing import Optional
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
@ -58,96 +57,106 @@ NEW_REQ_TEMPLATE = """
class WritePRD(Action):
name: str = "WritePRD"
content: Optional[str] = None
"""WritePRD deal with the following situations:
1. Bugfix: If the requirement is a bugfix, the bugfix document will be generated.
2. New requirement: If the requirement is a new requirement, the PRD document will be generated.
3. Requirement update: If the requirement is an update, the PRD document will be updated.
"""
async def run(self, with_messages, *args, **kwargs) -> ActionOutput | Message:
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
# related to the PRD. If they are related, rewrite the PRD.
requirement_doc = await self.project_repo.docs.get(filename=REQUIREMENT_FILENAME)
if requirement_doc and await self._is_bugfix(requirement_doc.content):
await self.project_repo.docs.save(filename=BUGFIX_FILENAME, content=requirement_doc.content)
await self.project_repo.docs.save(filename=REQUIREMENT_FILENAME, content="")
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
return Message(
content=bug_fix.model_dump_json(),
instruct_content=bug_fix,
role="",
cause_by=FixBug,
sent_from=self,
send_to="Alex", # the name of Engineer
)
"""Run the action."""
req: Document = await self.repo.requirement
docs: list[Document] = await self.repo.docs.prd.get_all()
if not req:
raise FileNotFoundError("No requirement document found.")
if await self._is_bugfix(req.content):
logger.info(f"Bugfix detected: {req.content}")
return await self._handle_bugfix(req)
# remove bugfix file from last round in case of conflict
await self.repo.docs.delete(filename=BUGFIX_FILENAME)
# if requirement is related to other documents, update them, otherwise create a new one
if related_docs := await self.get_related_docs(req, docs):
logger.info(f"Requirement update detected: {req.content}")
return await self._handle_requirement_update(req, related_docs)
else:
await self.project_repo.docs.delete(filename=BUGFIX_FILENAME)
logger.info(f"New requirement detected: {req.content}")
return await self._handle_new_requirement(req)
prd_docs = await self.project_repo.docs.prd.get_all()
change_files = Documents()
for prd_doc in prd_docs:
prd_doc = await self._update_prd(requirement_doc=requirement_doc, prd_doc=prd_doc, *args, **kwargs)
if not prd_doc:
continue
change_files.docs[prd_doc.filename] = prd_doc
logger.info(f"rewrite prd: {prd_doc.filename}")
# If there is no existing PRD, generate one using 'docs/requirement.txt'.
if not change_files.docs:
prd_doc = await self._update_prd(requirement_doc=requirement_doc, *args, **kwargs)
if prd_doc:
change_files.docs[prd_doc.filename] = prd_doc
logger.debug(f"new prd: {prd_doc.filename}")
# Once all files under 'docs/prds/' have been compared with the newly added requirements, trigger the
# 'publish' message to transition the workflow to the next stage. This design allows room for global
# optimization in subsequent steps.
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
async def _handle_bugfix(self, req: Document) -> Message:
    """Turn the current requirement into a bugfix request.

    The requirement text is written to the bugfix file, the requirement file
    is overwritten with an empty string, and a message carrying a
    ``BugFixContext`` is returned.

    :param req: The requirement document whose content describes the bug.
    :return: A Message caused by ``FixBug`` and addressed to "Alex".
    """
    docs_repo = self.repo.docs
    # Store the bug description first, then blank out the requirement file.
    await docs_repo.save(filename=BUGFIX_FILENAME, content=req.content)
    await docs_repo.save(filename=REQUIREMENT_FILENAME, content="")

    bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
    payload = bug_fix.model_dump_json()
    return Message(
        content=payload,
        instruct_content=bug_fix,
        role="",
        cause_by=FixBug,
        sent_from=self,
        send_to="Alex",  # the name of Engineer
    )
async def _run_new_requirement(self, requirements) -> ActionOutput:
async def _handle_new_requirement(self, req: Document) -> ActionOutput:
"""handle new requirement"""
project_name = self.project_name
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
context = CONTEXT_TEMPLATE.format(requirements=req, project_name=project_name)
exclude = [PROJECT_NAME.key] if project_name else []
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema
await self._rename_workspace(node)
return node
new_prd_doc = await self.repo.docs.prd.save(
filename=FileRepository.new_filename() + ".json", content=node.instruct_content.model_dump_json()
)
await self._save_competitive_analysis(new_prd_doc)
await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
return Documents.from_iterable(documents=[new_prd_doc]).to_action_output()
async def _is_relative(self, new_requirement_doc, old_prd_doc) -> bool:
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd_doc.content, requirements=new_requirement_doc.content)
async def _handle_requirement_update(self, req: Document, related_docs: list[Document]) -> ActionOutput:
    """Apply a requirement update to every related PRD document.

    Each related document is passed through ``_update_prd`` one after another.
    The return values of ``_update_prd`` are not used; the output is built
    from ``related_docs`` itself.

    :param req: The requirement document describing the update.
    :param related_docs: PRD documents judged to be related to ``req``.
    :return: The action output produced from ``related_docs``.
    """
    for related in related_docs:
        await self._update_prd(req, related)
    updated = Documents.from_iterable(documents=related_docs)
    return updated.to_action_output()
async def _is_bugfix(self, context: str) -> bool:
if not self.repo.code_files_exists():
return False
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
return node.get("issue_type") == "BUG"
async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Document]:
    """Select the documents that are related to the requirement.

    Documents are checked one at a time, in order, via ``_is_related``.

    :param req: The requirement document.
    :param docs: Candidate documents to check.
    :return: The candidates for which ``_is_related`` was truthy, in input order.
    """
    # refine: use gather to speed up
    related = []
    for candidate in docs:
        if await self._is_related(req, candidate):
            related.append(candidate)
    return related
async def _is_related(self, req: Document, old_prd: Document) -> bool:
    """Check whether a requirement is related to an existing PRD.

    The contents of both documents are formatted into ``NEW_REQ_TEMPLATE``
    and passed to ``WP_IS_RELATIVE_NODE`` together with the LLM.

    :param req: The requirement document.
    :param old_prd: The existing PRD document to compare against.
    :return: ``True`` if the node's "is_relative" value is "YES".
    """
    prompt = NEW_REQ_TEMPLATE.format(old_prd=old_prd.content, requirements=req.content)
    node = await WP_IS_RELATIVE_NODE.fill(prompt, self.llm)
    verdict = node.get("is_relative")
    return verdict == "YES"
async def _merge(self, new_requirement_doc, prd_doc) -> Document:
async def _merge(self, req: Document, related_doc: Document) -> Document:
if not self.project_name:
self.project_name = Path(self.project_path).name
prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content)
prompt = NEW_REQ_TEMPLATE.format(requirements=req.content, old_prd=related_doc.content)
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema)
prd_doc.content = node.instruct_content.model_dump_json()
related_doc.content = node.instruct_content.model_dump_json()
await self._rename_workspace(node)
return prd_doc
return related_doc
async def _update_prd(self, requirement_doc, prd_doc=None, *args, **kwargs) -> Document | None:
if not prd_doc:
prd = await self._run_new_requirement(
requirements=[requirement_doc.content if requirement_doc else ""], *args, **kwargs
)
new_prd_doc = await self.project_repo.docs.prd.save(
filename=FileRepository.new_filename() + ".json", content=prd.instruct_content.model_dump_json()
)
elif await self._is_relative(requirement_doc, prd_doc):
new_prd_doc = await self._merge(requirement_doc, prd_doc)
self.project_repo.docs.prd.save_doc(doc=new_prd_doc)
else:
return None
async def _update_prd(self, req: Document, prd_doc: Document) -> Document:
new_prd_doc: Document = await self._merge(req, prd_doc)
self.repo.docs.prd.save_doc(doc=new_prd_doc)
await self._save_competitive_analysis(new_prd_doc)
await self.project_repo.resources.prd.save_pdf(doc=new_prd_doc)
await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
return new_prd_doc
async def _save_competitive_analysis(self, prd_doc):
async def _save_competitive_analysis(self, prd_doc: Document):
m = json.loads(prd_doc.content)
quadrant_chart = m.get("Competitive Quadrant Chart")
if not quadrant_chart:
return
pathname = (
self.project_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("")
)
if not pathname.parent.exists():
pathname.parent.mkdir(parents=True, exist_ok=True)
pathname = self.repo.workdir / COMPETITIVE_ANALYSIS_FILE_REPO / Path(prd_doc.filename).stem
pathname.parent.mkdir(parents=True, exist_ok=True)
await mermaid_to_file(self.config.mermaid_engine, quadrant_chart, pathname)
async def _rename_workspace(self, prd):
@ -158,15 +167,4 @@ class WritePRD(Action):
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
if ws_name:
self.project_name = ws_name
self.project_repo.git_repo.rename_root(self.project_name)
async def _is_bugfix(self, context) -> bool:
git_workdir = self.project_repo.git_repo.workdir
src_workdir = git_workdir / git_workdir.name
if not src_workdir.exists():
return False
code_files = self.project_repo.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
if not code_files:
return False
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
return node.get("issue_type") == "BUG"
self.repo.git_repo.rename_root(self.project_name)

View file

@ -17,6 +17,7 @@ from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import create_llm_instance
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
class AttrDict(BaseModel):
@ -58,6 +59,8 @@ class Context(BaseModel):
kwargs: AttrDict = AttrDict()
config: Config = Config.default()
repo: Optional[ProjectRepo] = None
git_repo: Optional[GitRepository] = None
src_workspace: Optional[Path] = None
cost_manager: CostManager = CostManager()

View file

@ -10,8 +10,9 @@
from pydantic import Field
from metagpt.actions import ActionOutput, SearchAndSummarize
from metagpt.actions import SearchAndSummarize
from metagpt.actions.action_node import ActionNode
from metagpt.actions.action_output import ActionOutput
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message

View file

@ -23,7 +23,7 @@ from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
from pydantic import (
BaseModel,
@ -36,6 +36,7 @@ from pydantic import (
model_validator,
)
from metagpt.actions.action_output import ActionOutput
from metagpt.const import (
MESSAGE_ROUTE_CAUSE_BY,
MESSAGE_ROUTE_FROM,
@ -162,6 +163,25 @@ class Documents(BaseModel):
docs: Dict[str, Document] = Field(default_factory=dict)
@classmethod
def from_iterable(cls, documents: Iterable[Document]) -> Documents:
    """Create a Documents instance from an iterable of Document instances.

    Documents are keyed by their ``filename``; if several share a filename,
    the last one in iteration order is kept.

    :param documents: Any iterable of Document instances.
    :return: An instance of ``cls`` whose ``docs`` maps filename to Document.
    """
    docs = {doc.filename: doc for doc in documents}
    # Construct through cls so a subclass gets an instance of its own type.
    return cls(docs=docs)
def to_action_output(self) -> ActionOutput:
    """Wrap this collection in an ActionOutput.

    :return: An ActionOutput whose ``content`` is this object's JSON dump and
        whose ``instruct_content`` is this object itself.
    """
    serialized = self.model_dump_json()
    return ActionOutput(content=serialized, instruct_content=self)
class Message(BaseModel):
"""list[<role>: <content>]"""

View file

@ -21,6 +21,7 @@ from metagpt.const import (
GRAPH_REPO_FILE_REPO,
PRD_PDF_FILE_REPO,
PRDS_FILE_REPO,
REQUIREMENT_FILENAME,
RESOURCES_FILE_REPO,
SD_OUTPUT_FILE_REPO,
SEQ_FLOW_FILE_REPO,
@ -93,6 +94,10 @@ class ProjectRepo(FileRepository):
self.test_outputs = self._git_repo.new_file_repository(relative_path=TEST_OUTPUTS_FILE_REPO)
self._srcs_path = None
@property
async def requirement(self):
    """Fetch the requirement document from the ``docs`` repository.

    This is an ``async`` property, so reading it yields a coroutine that the
    caller must await (``await repo.requirement``).

    :return: Whatever ``docs.get`` returns for ``REQUIREMENT_FILENAME``.
    """
    requirement_doc = await self.docs.get(filename=REQUIREMENT_FILENAME)
    return requirement_doc
@property
def git_repo(self) -> GitRepository:
return self._git_repo
@ -107,6 +112,15 @@ class ProjectRepo(FileRepository):
raise ValueError("Call with_srcs first.")
return self._git_repo.new_file_repository(self._srcs_path)
def code_files_exists(self) -> bool:
    """Return whether the project's source directory exists and contains files.

    The source directory is looked up at ``<workdir>/<workdir.name>``, i.e. a
    folder named after the git working directory, located inside it.

    :return: ``True`` if that directory exists and its ``srcs`` repository
        reports at least one file, ``False`` otherwise.
    """
    git_workdir = self.git_repo.workdir
    src_workdir = git_workdir / git_workdir.name
    if not src_workdir.exists():
        return False
    # NOTE: with_src_path() also records this path on the repo as a side effect.
    code_files = self.with_src_path(path=src_workdir).srcs.all_files
    # Return an explicit bool: previously the function fell off the end and
    # returned None (falsy) even when code files were present.
    return bool(code_files)
def with_src_path(self, path: str | Path) -> ProjectRepo:
try:
self._srcs_path = Path(path).relative_to(self.workdir)