From 9eba81862b12923f15ae5d9b4cc647d9c2bec8e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Thu, 16 May 2024 12:40:55 +0800 Subject: [PATCH 01/20] fixbug: circular import --- metagpt/tools/libs/git.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/metagpt/tools/libs/git.py b/metagpt/tools/libs/git.py index eb3fd6822..b4d759bf4 100644 --- a/metagpt/tools/libs/git.py +++ b/metagpt/tools/libs/git.py @@ -9,7 +9,6 @@ from github.Issue import Issue from github.PullRequest import PullRequest from metagpt.tools.tool_registry import register_tool -from metagpt.utils.git_repository import GitBranch, GitRepository @register_tool(tags=["software development", "git", "Commit the changes and push to remote git repository."]) @@ -18,7 +17,7 @@ async def git_push( access_token: str, comments: str = "Commit", new_branch: str = "", -) -> GitBranch: +) -> "GitBranch": """ Pushes changes from a local Git repository to its remote counterpart. @@ -49,6 +48,8 @@ async def git_push( base branch:'master', head branch:'feature/new', repo_name:'iorisa/snake-game' """ + from metagpt.utils.git_repository import GitRepository + if not GitRepository.is_git_dir(local_path): raise ValueError("Invalid local git repository") From 6326526a08d1aed1bc6178393540e98db31d804c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Fri, 17 May 2024 13:28:28 +0800 Subject: [PATCH 02/20] feat: +vault config demo --- config/vault.example.yaml | 48 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 config/vault.example.yaml diff --git a/config/vault.example.yaml b/config/vault.example.yaml new file mode 100644 index 000000000..0e197d2a8 --- /dev/null +++ b/config/vault.example.yaml @@ -0,0 +1,48 @@ +# Usage: +# 1. Get value. 
+# >>> from metagpt.tools.libs.env import get_env +# >>> access_token = await get_env(key="access_token", app_name="github") +# >>> print(access_token) +# YOUR_ACCESS_TOKEN +# +# 2. Get description for LLM understanding. +# >>> from metagpt.tools.libs.env import get_env_description +# >>> descriptions = await get_env_description() +# >>> for k, desc in descriptions.items(): +# >>> print(f"{k}:{desc}") +# await get_env(key="access_token", app_name="github"):Get github access token +# await get_env(key="access_token", app_name="gitlab"):Get gitlab access token +# ... + +vault: + github: + values: + access_token: "YOUR_ACCESS_TOKEN" + descriptions: + access_token: "Get github access token" + gitlab: + values: + access_token: "YOUR_ACCESS_TOKEN" + descriptions: + access_token: "Get gitlab access token" + iflytek_tts: + values: + api_id: "YOUR_APP_ID" + api_key: "YOUR_API_KEY" + api_secret: "YOUR_API_SECRET" + descriptions: + api_id: "Get the API ID of IFlyTek Text to Speech" + api_key: "Get the API KEY of IFlyTek Text to Speech" + api_secret: "Get the API SECRET of IFlyTek Text to Speech" + azure_tts: + values: + subscription_key: "YOUR_SUBSCRIPTION_KEY" + region: "YOUR_REGION" + descriptions: + subscription_key: "Get the subscription key of Azure Text to Speech." + region: "Get the region of Azure Text to Speech." + default: # All key-value pairs whose app name is an empty string are placed below + values: + proxy: "YOUR_PROXY" + descriptions: + proxy: "Get proxy for tools like requests, playwright, selenium, etc." 
\ No newline at end of file From ee0b9d2039e41625ff5669b1e0c4c8ad0d1f79d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Sat, 18 May 2024 14:36:22 +0800 Subject: [PATCH 03/20] feat: new/inc/patch pass --- metagpt/actions/action.py | 7 - metagpt/actions/debug_error.py | 10 +- metagpt/actions/design_api.py | 34 ++- metagpt/actions/prepare_documents.py | 22 +- metagpt/actions/project_management.py | 25 +- metagpt/actions/summarize_code.py | 12 +- metagpt/actions/write_code.py | 72 +++--- .../actions/write_code_plan_and_change_an.py | 37 +-- metagpt/actions/write_code_review.py | 13 +- metagpt/actions/write_prd.py | 47 +++- metagpt/actions/write_prd_an.py | 4 +- metagpt/roles/engineer.py | 213 ++++++++++++------ metagpt/roles/product_manager.py | 7 +- metagpt/roles/qa_engineer.py | 95 +++++--- metagpt/roles/role.py | 24 -- metagpt/schema.py | 66 ++++-- metagpt/utils/common.py | 41 ++++ metagpt/utils/git_repository.py | 2 + metagpt/utils/project_repo.py | 9 +- tests/metagpt/actions/test_action_node.py | 3 +- tests/metagpt/test_schema.py | 6 + 21 files changed, 512 insertions(+), 237 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 5fd538720..b760c96d8 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -22,7 +22,6 @@ from metagpt.schema import ( SerializationMixin, TestingContext, ) -from metagpt.utils.project_repo import ProjectRepo class Action(SerializationMixin, ContextMixin, BaseModel): @@ -36,12 +35,6 @@ class Action(SerializationMixin, ContextMixin, BaseModel): desc: str = "" # for skill manager node: ActionNode = Field(default=None, exclude=True) - @property - def repo(self) -> ProjectRepo: - if not self.context.repo: - self.context.repo = ProjectRepo(self.context.git_repo) - return self.context.repo - @property def prompt_schema(self): return self.config.prompt_schema diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index b027616f7..8f0f52266 
100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -9,13 +9,15 @@ 2. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name. """ import re +from typing import Optional -from pydantic import Field +from pydantic import BaseModel, Field from metagpt.actions.action import Action from metagpt.logs import logger from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.common import CodeParser +from metagpt.utils.project_repo import ProjectRepo PROMPT_TEMPLATE = """ NOTICE @@ -47,6 +49,8 @@ Now you should start rewriting the code: class DebugError(Action): i_context: RunCodeContext = Field(default_factory=RunCodeContext) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) async def run(self, *args, **kwargs) -> str: output_doc = await self.repo.test_outputs.get(filename=self.i_context.output_filename) @@ -59,9 +63,7 @@ class DebugError(Action): return "" logger.info(f"Debug and rewrite {self.i_context.test_filename}") - code_doc = await self.repo.with_src_path(self.context.src_workspace).srcs.get( - filename=self.i_context.code_filename - ) + code_doc = await self.repo.srcs.get(filename=self.i_context.code_filename) if not code_doc: return "" test_doc = await self.repo.tests.get(filename=self.i_context.test_filename) diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 613c4a47b..2e84cc463 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -13,6 +13,8 @@ import json from pathlib import Path from typing import Optional +from pydantic import BaseModel, Field + from metagpt.actions import Action from metagpt.actions.design_api_an import ( DATA_STRUCTURES_AND_INTERFACES, @@ -26,6 +28,7 @@ from metagpt.const import DATA_API_DESIGN_FILE_REPO, SEQ_FLOW_FILE_REPO from metagpt.logs import logger from metagpt.schema import AIMessage, Document, 
Documents, Message from metagpt.utils.mermaid import mermaid_to_file +from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.report import DocsReporter, GalleryReporter NEW_REQ_TEMPLATE = """ @@ -45,21 +48,25 @@ class WriteDesign(Action): "data structures, library tables, processes, and paths. Please provide your design, feedback " "clearly and in detail." ) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) async def run(self, with_messages: Message, schema: str = None): - # Use `git status` to identify which PRD documents have been modified in the `docs/prd` directory. - changed_prds = self.repo.docs.prd.changed_files - # Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone - # changes. - changed_system_designs = self.repo.docs.system_design.changed_files + self.input_args = with_messages[0].instruct_content + self.repo = ProjectRepo(self.input_args.project_path) + changed_prds = self.input_args.changed_prd_filenames + changed_system_designs = [ + str(self.repo.docs.system_design.workdir / i) + for i in list(self.repo.docs.system_design.changed_files.keys()) + ] # For those PRDs and design documents that have undergone changes, regenerate the design content. changed_files = Documents() - for filename in changed_prds.keys(): + for filename in changed_prds: doc = await self._update_system_design(filename=filename) changed_files.docs[filename] = doc - for filename in changed_system_designs.keys(): + for filename in changed_system_designs: if filename in changed_files.docs: continue doc = await self._update_system_design(filename=filename) @@ -68,6 +75,11 @@ class WriteDesign(Action): logger.info("Nothing has changed.") # Wait until all files under `docs/system_designs/` are processed before sending the publish message, # leaving room for global optimization in subsequent steps. 
+ kvs = self.input_args.model_dump() + kvs["changed_system_design_filenames"] = [ + str(self.repo.docs.system_design.workdir / i) + for i in list(self.repo.docs.system_design.changed_files.keys()) + ] return AIMessage( content="Designing is complete. " + "\n".join( @@ -75,6 +87,7 @@ class WriteDesign(Action): + list(self.repo.resources.data_api_design.changed_files.keys()) + list(self.repo.resources.seq_flow.changed_files.keys()) ), + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteDesignOutput"), cause_by=self, ) @@ -89,14 +102,15 @@ class WriteDesign(Action): return system_design_doc async def _update_system_design(self, filename) -> Document: - prd = await self.repo.docs.prd.get(filename) - old_system_design_doc = await self.repo.docs.system_design.get(filename) + root_relative_path = Path(filename).relative_to(self.repo.workdir) + prd = await Document.load(filename=filename, project_path=self.repo.workdir) + old_system_design_doc = await self.repo.docs.system_design.get(root_relative_path.name) async with DocsReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "design"}, "meta") if not old_system_design_doc: system_design = await self._new_system_design(context=prd.content) doc = await self.repo.docs.system_design.save( - filename=filename, + filename=prd.filename, content=system_design.instruct_content.model_dump_json(), dependencies={prd.root_relative_path}, ) diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index eb674374c..89ebd59a3 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -17,6 +17,7 @@ from metagpt.logs import logger from metagpt.schema import AIMessage from metagpt.utils.common import any_to_str from metagpt.utils.file_repository import FileRepository +from metagpt.utils.project_repo import ProjectRepo class PrepareDocuments(Action): @@ -36,7 +37,7 @@ class PrepareDocuments(Action): def config(self): 
return self.context.config - def _init_repo(self): + def _init_repo(self) -> ProjectRepo: """Initialize the Git environment.""" if not self.config.project_path: name = self.config.project_name or FileRepository.new_filename() @@ -47,6 +48,7 @@ class PrepareDocuments(Action): shutil.rmtree(path) self.config.project_path = path self.context.set_repo_dir(path) + return ProjectRepo(path) async def run(self, with_messages, **kwargs): """Create and initialize the workspace folder, initialize the Git environment.""" @@ -67,10 +69,22 @@ class PrepareDocuments(Action): max_auto_summarize_code=0, ) - self._init_repo() + repo = self._init_repo() # Write the newly added requirements from the main parameter idea to `docs/requirement.txt`. - doc = await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content) + await repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content) # Send a Message notification to the WritePRD action, instructing it to process requirements using # `docs/requirement.txt` and `docs/prd/`. 
- return AIMessage(content="", instruct_content=doc, cause_by=self, send_to=self.send_to) + return AIMessage( + content="", + instruct_content=AIMessage.create_instruct_value( + kvs={ + "project_path": str(repo.workdir), + "requirements_filename": str(repo.docs.workdir / REQUIREMENT_FILENAME), + "prd_filenames": [str(repo.docs.prd.workdir / i) for i in repo.docs.prd.all_files], + }, + class_name="PrepareDocumentsOutput", + ), + cause_by=self, + send_to=self.send_to, + ) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index ef0fe6fc6..55356f58b 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -11,13 +11,17 @@ """ import json +from pathlib import Path from typing import Optional +from pydantic import BaseModel, Field + from metagpt.actions.action import Action from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODE from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME from metagpt.logs import logger from metagpt.schema import AIMessage, Document, Documents +from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.report import DocsReporter NEW_REQ_TEMPLATE = """ @@ -32,10 +36,14 @@ NEW_REQ_TEMPLATE = """ class WriteTasks(Action): name: str = "CreateTasks" i_context: Optional[str] = None + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) async def run(self, with_messages): - changed_system_designs = self.repo.docs.system_design.changed_files - changed_tasks = self.repo.docs.task.changed_files + self.input_args = with_messages[0].instruct_content + self.repo = ProjectRepo(self.input_args.project_path) + changed_system_designs = self.input_args.changed_system_design_filenames + changed_tasks = [str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys())] change_files = Documents() # Rewrite the system designs that have undergone 
changes based on the git head diff under # `docs/system_designs/`. @@ -54,6 +62,11 @@ class WriteTasks(Action): logger.info("Nothing has changed.") # Wait until all files under `docs/tasks/` are processed before sending the publish_message, leaving room for # global optimization in subsequent steps. + kvs = self.input_args.model_dump() + kvs["changed_task_filenames"] = [ + str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys()) + ] + kvs["python_package_dependency_filename"] = str(self.repo.workdir / PACKAGE_REQUIREMENTS_FILENAME) return AIMessage( content="WBS is completed. " + "\n".join( @@ -61,12 +74,14 @@ class WriteTasks(Action): + list(self.repo.docs.task.changed_files.keys()) + list(self.repo.resources.api_spec_and_task.changed_files.keys()) ), + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteTaskOutput"), cause_by=self, ) async def _update_tasks(self, filename): - system_design_doc = await self.repo.docs.system_design.get(filename) - task_doc = await self.repo.docs.task.get(filename) + root_relative_path = Path(filename).relative_to(self.repo.workdir) + system_design_doc = await Document.load(filename=filename, project_path=self.repo.workdir) + task_doc = await self.repo.docs.task.get(root_relative_path.name) async with DocsReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "task"}, "meta") if task_doc: @@ -75,7 +90,7 @@ class WriteTasks(Action): else: rsp = await self._run_new_tasks(context=system_design_doc.content) task_doc = await self.repo.docs.task.save( - filename=filename, + filename=system_design_doc.filename, content=rsp.instruct_content.model_dump_json(), dependencies={system_design_doc.root_relative_path}, ) diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index d21b62f83..e3556caa7 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -6,13 +6,16 @@ @Modified By: mashenquan, 
2023/12/5. Archive the summarization content of issue discovery for use in WriteCode. """ from pathlib import Path +from typing import Optional -from pydantic import Field +from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.logs import logger from metagpt.schema import CodeSummarizeContext +from metagpt.utils.common import get_markdown_code_block_type +from metagpt.utils.project_repo import ProjectRepo PROMPT_TEMPLATE = """ NOTICE @@ -90,6 +93,8 @@ flowchart TB class SummarizeCode(Action): name: str = "SummarizeCode" i_context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60)) async def summarize_code(self, prompt): @@ -101,11 +106,10 @@ class SummarizeCode(Action): design_doc = await self.repo.docs.system_design.get(filename=design_pathname.name) task_pathname = Path(self.i_context.task_filename) task_doc = await self.repo.docs.task.get(filename=task_pathname.name) - src_file_repo = self.repo.with_src_path(self.context.src_workspace).srcs code_blocks = [] for filename in self.i_context.codes_filenames: - code_doc = await src_file_repo.get(filename) - code_block = f"```python\n{code_doc.content}\n```\n-----" + code_doc = await self.repo.srcs.get(filename) + code_block = f"```{get_markdown_code_block_type(filename)}\n{code_doc.content}\n```\n---\n" code_blocks.append(code_block) format_example = FORMAT_EXAMPLE prompt = PROMPT_TEMPLATE.format( diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 67b859d23..7f225d469 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -16,17 +16,18 @@ """ import json +from pathlib import Path +from typing import Optional -from pydantic 
import Field +from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST from metagpt.actions.write_code_plan_and_change_an import REFINED_TEMPLATE -from metagpt.const import BUGFIX_FILENAME, REQUIREMENT_FILENAME from metagpt.logs import logger from metagpt.schema import CodingContext, Document, RunCodeResult -from metagpt.utils.common import CodeParser +from metagpt.utils.common import CodeParser, get_markdown_code_block_type from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.report import EditorReporter @@ -44,9 +45,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc {task} ## Legacy Code -```Code {code} -``` ## Debug logs ```text @@ -61,14 +60,14 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc ``` # Format example -## Code: {filename} +## Code: {demo_filename}.py ```python -## {filename} +## {demo_filename}.py ... ``` -## Code: {filename} +## Code: {demo_filename}.js ```javascript -// {filename} +// {demo_filename}.js ... ``` @@ -89,6 +88,8 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. 
Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" i_context: Document = Field(default_factory=Document) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: @@ -97,10 +98,16 @@ class WriteCode(Action): return code async def run(self, *args, **kwargs) -> CodingContext: - bug_feedback = await self.repo.docs.get(filename=BUGFIX_FILENAME) + bug_feedback = None + if self.input_args and hasattr(self.input_args, "issue_filename"): + bug_feedback = await Document.load(self.input_args.issue_filename) coding_context = CodingContext.loads(self.i_context.content) + if not coding_context.code_plan_and_change_doc: + coding_context.code_plan_and_change_doc = await self.repo.docs.code_plan_and_change.get( + filename=coding_context.task_doc.filename + ) test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json") - requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME) + requirement_doc = await Document.load(self.input_args.requirements_filename) summary_doc = None if coding_context.design_doc and coding_context.design_doc.filename: summary_doc = await self.repo.docs.code_summary.get(filename=coding_context.design_doc.filename) @@ -109,29 +116,28 @@ class WriteCode(Action): test_detail = RunCodeResult.loads(test_doc.content) logs = test_detail.stderr - if bug_feedback: - code_context = coding_context.code_doc.content - elif self.config.inc: + if self.config.inc or bug_feedback: code_context = await self.get_codes( coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo, use_inc=True ) else: code_context = await self.get_codes( - coding_context.task_doc, - exclude=self.i_context.filename, - project_repo=self.repo.with_src_path(self.context.src_workspace), + 
coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo ) if self.config.inc: prompt = REFINED_TEMPLATE.format( user_requirement=requirement_doc.content if requirement_doc else "", - code_plan_and_change=str(coding_context.code_plan_and_change_doc), + code_plan_and_change=coding_context.code_plan_and_change_doc.content + if coding_context.code_plan_and_change_doc + else "", design=coding_context.design_doc.content if coding_context.design_doc else "", task=coding_context.task_doc.content if coding_context.task_doc else "", code=code_context, logs=logs, feedback=bug_feedback.content if bug_feedback else "", filename=self.i_context.filename, + demo_filename=Path(self.i_context.filename).stem, summary_log=summary_doc.content if summary_doc else "", ) else: @@ -142,6 +148,7 @@ class WriteCode(Action): logs=logs, feedback=bug_feedback.content if bug_feedback else "", filename=self.i_context.filename, + demo_filename=Path(self.i_context.filename).stem, summary_log=summary_doc.content if summary_doc else "", ) logger.info(f"Writing {coding_context.filename}..") @@ -150,8 +157,9 @@ class WriteCode(Action): code = await self.write_code(prompt) if not coding_context.code_doc: # avoid root_path pydantic ValidationError if use WriteCode alone - root_path = self.context.src_workspace if self.context.src_workspace else "" - coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path)) + coding_context.code_doc = Document( + filename=coding_context.filename, root_path=str(self.repo.src_relative_path) + ) coding_context.code_doc.content = code await reporter.async_report(self.repo.workdir / coding_context.code_doc.root_relative_path, "path") return coding_context @@ -178,35 +186,32 @@ class WriteCode(Action): code_filenames = m.get(TASK_LIST.key, []) if not use_inc else m.get(REFINED_TASK_LIST.key, []) codes = [] src_file_repo = project_repo.srcs - # Incremental development scenario if use_inc: - src_files = 
src_file_repo.all_files - # Get the old workspace contained the old codes and old workspace are created in previous CodePlanAndChange - old_file_repo = project_repo.git_repo.new_file_repository(relative_path=project_repo.old_workspace) - old_files = old_file_repo.all_files - # Get the union of the files in the src and old workspaces - union_files_list = list(set(src_files) | set(old_files)) - for filename in union_files_list: + for filename in src_file_repo.all_files: + code_block_type = get_markdown_code_block_type(filename) # Exclude the current file from the all code snippets if filename == exclude: # If the file is in the old workspace, use the old code # Exclude unnecessary code to maintain a clean and focused main.py file, ensuring only relevant and # essential functionality is included for the project’s requirements - if filename in old_files and filename != "main.py": + if filename != "main.py": # Use old code - doc = await old_file_repo.get(filename=filename) + doc = await src_file_repo.get(filename=filename) # If the file is in the src workspace, skip it else: continue - codes.insert(0, f"-----Now, {filename} to be rewritten\n```{doc.content}```\n=====") + codes.insert( + 0, f"### The name of file to rewrite: `{filename}`\n```{code_block_type}\n{doc.content}```\n" + ) + logger.info(f"Prepare to rewrite `{filename}`") # The code snippets are generated from the src workspace else: doc = await src_file_repo.get(filename=filename) # If the file does not exist in the src workspace, skip it if not doc: continue - codes.append(f"----- {filename}\n```{doc.content}```") + codes.append(f"### File Name: `{filename}`\n```{code_block_type}\n{doc.content}```\n\n") # Normal scenario else: @@ -217,6 +222,7 @@ class WriteCode(Action): doc = await src_file_repo.get(filename=filename) if not doc: continue - codes.append(f"----- {filename}\n```{doc.content}```") + code_block_type = get_markdown_code_block_type(filename) + codes.append(f"### File Name: 
`{filename}`\n```{code_block_type}\n{doc.content}```\n\n") return "\n".join(codes) diff --git a/metagpt/actions/write_code_plan_and_change_an.py b/metagpt/actions/write_code_plan_and_change_an.py index a3c0e50a4..31482a94d 100644 --- a/metagpt/actions/write_code_plan_and_change_an.py +++ b/metagpt/actions/write_code_plan_and_change_an.py @@ -5,15 +5,16 @@ @Author : mannaandpoem @File : write_code_plan_and_change_an.py """ -import os -from typing import List +from typing import List, Optional -from pydantic import Field +from pydantic import BaseModel, Field from metagpt.actions.action import Action from metagpt.actions.action_node import ActionNode from metagpt.logs import logger -from metagpt.schema import CodePlanAndChangeContext +from metagpt.schema import CodePlanAndChangeContext, Document +from metagpt.utils.common import get_markdown_code_block_type +from metagpt.utils.project_repo import ProjectRepo DEVELOPMENT_PLAN = ActionNode( key="Development Plan", @@ -162,9 +163,8 @@ Role: You are a professional engineer; The main goal is to complete incremental {task} ## Legacy Code -```Code {code} -``` + ## Debug logs ```text @@ -179,14 +179,14 @@ Role: You are a professional engineer; The main goal is to complete incremental ``` # Format example -## Code: {filename} +## Code: {demo_filename}.py ```python -## {filename} +## {demo_filename}.py ... ``` -## Code: {filename} +## Code: {demo_filename}.js ```javascript -// {filename} +// {demo_filename}.js ... 
``` @@ -211,13 +211,15 @@ WRITE_CODE_PLAN_AND_CHANGE_NODE = ActionNode.from_children("WriteCodePlanAndChan class WriteCodePlanAndChange(Action): name: str = "WriteCodePlanAndChange" i_context: CodePlanAndChangeContext = Field(default_factory=CodePlanAndChangeContext) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) async def run(self, *args, **kwargs): self.llm.system_prompt = "You are a professional software engineer, your primary responsibility is to " "meticulously craft comprehensive incremental development plan and deliver detailed incremental change" - prd_doc = await self.repo.docs.prd.get(filename=self.i_context.prd_filename) - design_doc = await self.repo.docs.system_design.get(filename=self.i_context.design_filename) - task_doc = await self.repo.docs.task.get(filename=self.i_context.task_filename) + prd_doc = await Document.load(filename=self.i_context.prd_filename) + design_doc = await Document.load(filename=self.i_context.design_filename) + task_doc = await Document.load(filename=self.i_context.task_filename) context = CODE_PLAN_AND_CHANGE_CONTEXT.format( requirement=f"```text\n{self.i_context.requirement}\n```", issue=f"```text\n{self.i_context.issue}\n```", @@ -230,8 +232,9 @@ class WriteCodePlanAndChange(Action): return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json") async def get_old_codes(self) -> str: - self.repo.old_workspace = self.repo.git_repo.workdir / os.path.basename(self.config.project_path) - old_file_repo = self.repo.git_repo.new_file_repository(relative_path=self.repo.old_workspace) - old_codes = await old_file_repo.get_all() - codes = [f"----- {code.filename}\n```{code.content}```" for code in old_codes] + old_codes = await self.repo.srcs.get_all() + codes = [ + f"### File Name: `{code.filename}`\n```{get_markdown_code_block_type(code.filename)}\n{code.content}```\n" + for code in old_codes + ] return 
"\n".join(codes) diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index f0faea701..3912095df 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -7,16 +7,17 @@ @Modified By: mashenquan, 2023/11/27. Following the think-act principle, solidify the task parameters when creating the WriteCode object, rather than passing them in when calling the run function. """ +from typing import Optional -from pydantic import Field +from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions import WriteCode from metagpt.actions.action import Action -from metagpt.const import REQUIREMENT_FILENAME from metagpt.logs import logger -from metagpt.schema import CodingContext +from metagpt.schema import CodingContext, Document from metagpt.utils.common import CodeParser +from metagpt.utils.project_repo import ProjectRepo PROMPT_TEMPLATE = """ # System @@ -126,6 +127,8 @@ or class WriteCodeReview(Action): name: str = "WriteCodeReview" i_context: CodingContext = Field(default_factory=CodingContext) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename): @@ -150,7 +153,7 @@ class WriteCodeReview(Action): code_context = await WriteCode.get_codes( self.i_context.task_doc, exclude=self.i_context.filename, - project_repo=self.repo.with_src_path(self.context.src_workspace), + project_repo=self.repo, use_inc=self.config.inc, ) @@ -160,7 +163,7 @@ class WriteCodeReview(Action): "## Code Files\n" + code_context + "\n", ] if self.config.inc: - requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME) + requirement_doc = await Document.load(filename=self.input_args.requirements_filename) 
insert_ctx_list = [ "## User New Requirements\n" + str(requirement_doc) + "\n", "## Code Plan And Change\n" + str(self.i_context.code_plan_and_change_doc) + "\n", diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index a4f6e1dd1..3275619f7 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -15,6 +15,9 @@ from __future__ import annotations import json from pathlib import Path +from typing import Optional + +from pydantic import BaseModel, Field from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode @@ -37,6 +40,7 @@ from metagpt.schema import AIMessage, Document, Documents, Message from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository from metagpt.utils.mermaid import mermaid_to_file +from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.report import DocsReporter, GalleryReporter CONTEXT_TEMPLATE = """ @@ -66,10 +70,30 @@ class WritePRD(Action): 3. Requirement update: If the requirement is an update, the PRD document will be updated. 
""" + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) + async def run(self, with_messages, *args, **kwargs) -> Message: """Run the action.""" - req: Document = await self.repo.requirement - docs: list[Document] = await self.repo.docs.prd.get_all() + self.input_args = with_messages[-1].instruct_content + if not self.input_args: + self.repo = ProjectRepo(self.config.project_path) + await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[-1].content) + self.input_args = AIMessage.create_instruct_value( + kvs={ + "project_path": self.config.project_path, + "requirements_filename": str(self.repo.docs.workdir / REQUIREMENT_FILENAME), + "prd_filenames": [str(self.repo.docs.prd.workdir / i) for i in self.repo.docs.prd.all_files], + }, + class_name="PrepareDocumentsOutput", + ) + else: + self.repo = ProjectRepo(self.input_args.project_path) + req = await Document.load(filename=self.input_args.requirements_filename) + docs: list[Document] = [ + await Document.load(filename=i, project_path=self.repo.workdir) for i in self.input_args.prd_filenames + ] + if not req: raise FileNotFoundError("No requirement document found.") @@ -82,10 +106,14 @@ class WritePRD(Action): # if requirement is related to other documents, update them, otherwise create a new one if related_docs := await self.get_related_docs(req, docs): logger.info(f"Requirement update detected: {req.content}") - await self._handle_requirement_update(req, related_docs) + await self._handle_requirement_update(req=req, related_docs=related_docs) else: logger.info(f"New requirement detected: {req.content}") await self._handle_new_requirement(req) + kvs = self.input_args.model_dump() + kvs["changed_prd_filenames"] = [ + str(self.repo.docs.prd.workdir / i) for i in list(self.repo.docs.prd.changed_files.keys()) + ] return AIMessage( content="PRD is completed. 
" + "\n".join( @@ -93,6 +121,7 @@ class WritePRD(Action): + list(self.repo.resources.prd.changed_files.keys()) + list(self.repo.resources.competitive_analysis.changed_files.keys()) ), + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WritePRDOutput"), cause_by=self, ) @@ -103,6 +132,14 @@ class WritePRD(Action): return AIMessage( content=f"A new issue is received: {BUGFIX_FILENAME}", cause_by=FixBug, + instruct_content=AIMessage.create_instruct_value( + { + "project_path": str(self.repo.workdir), + "issue_filename": str(self.repo.docs.workdir / BUGFIX_FILENAME), + "requirements_filename": str(self.repo.docs.workdir / REQUIREMENT_FILENAME), + }, + class_name="IssueDetail", + ), send_to="Alex", # the name of Engineer ) @@ -128,7 +165,7 @@ class WritePRD(Action): async def _handle_requirement_update(self, req: Document, related_docs: list[Document]) -> ActionOutput: # ... requirement update logic ... for doc in related_docs: - await self._update_prd(req, doc) + await self._update_prd(req=req, prd_doc=doc) return Documents.from_iterable(documents=related_docs).to_action_output() async def _is_bugfix(self, context: str) -> bool: @@ -159,7 +196,7 @@ class WritePRD(Action): async def _update_prd(self, req: Document, prd_doc: Document) -> Document: async with DocsReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "prd"}, "meta") - new_prd_doc: Document = await self._merge(req, prd_doc) + new_prd_doc: Document = await self._merge(req=req, related_doc=prd_doc) await self.repo.docs.prd.save_doc(doc=new_prd_doc) await self._save_competitive_analysis(new_prd_doc) md = await self.repo.resources.prd.save_pdf(doc=new_prd_doc) diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py index a33685cd3..1ceb2aade 100644 --- a/metagpt/actions/write_prd_an.py +++ b/metagpt/actions/write_prd_an.py @@ -5,7 +5,7 @@ @Author : alexanderwu @File : write_prd_an.py """ -from typing import List +from typing import List, 
Union from metagpt.actions.action_node import ActionNode @@ -132,7 +132,7 @@ REQUIREMENT_ANALYSIS = ActionNode( REFINED_REQUIREMENT_ANALYSIS = ActionNode( key="Refined Requirement Analysis", - expected_type=List[str], + expected_type=Union[List[str], str], instruction="Review and refine the existing requirement analysis into a string list to align with the evolving needs of the project " "due to incremental development. Ensure the analysis comprehensively covers the new features and enhancements " "required for the refined project scope.", diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index d9e375a9a..111e534a6 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -24,20 +24,15 @@ from collections import defaultdict from pathlib import Path from typing import List, Optional, Set -from metagpt.actions import ( - Action, - UserRequirement, - WriteCode, - WriteCodeReview, - WriteTasks, -) +from pydantic import BaseModel, Field + +from metagpt.actions import UserRequirement, WriteCode, WriteCodeReview, WriteTasks from metagpt.actions.fix_bug import FixBug from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST from metagpt.actions.summarize_code import SummarizeCode from metagpt.actions.write_code_plan_and_change_an import WriteCodePlanAndChange from metagpt.const import ( - BUGFIX_FILENAME, CODE_PLAN_AND_CHANGE_FILE_REPO, MESSAGE_ROUTE_TO_SELF, REQUIREMENT_FILENAME, @@ -63,6 +58,7 @@ from metagpt.utils.common import ( init_python_folder, ) from metagpt.utils.git_repository import ChangeType +from metagpt.utils.project_repo import ProjectRepo IS_PASS_PROMPT = """ {context} @@ -100,6 +96,8 @@ class Engineer(Role): summarize_todos: list = [] next_todo_action: str = "" n_summarize: int = 0 + input_args: Optional[BaseModel] = Field(default=None, exclude=True) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) def __init__(self, 
**kwargs) -> None: super().__init__(**kwargs) @@ -139,14 +137,20 @@ class Engineer(Role): coding_context = await todo.run() # Code review if review: - action = WriteCodeReview(i_context=coding_context, context=self.context, llm=self.llm) + action = WriteCodeReview( + i_context=coding_context, + repo=self.repo, + input_args=self.input_args, + context=self.context, + llm=self.llm, + ) self._init_action(action) coding_context = await action.run() dependencies = {coding_context.design_doc.root_relative_path, coding_context.task_doc.root_relative_path} if self.config.inc: dependencies.add(coding_context.code_plan_and_change_doc.root_relative_path) - await self.project_repo.srcs.save( + await self.repo.srcs.save( filename=coding_context.filename, dependencies=list(dependencies), content=coding_context.code_doc.content, @@ -186,9 +190,9 @@ class Engineer(Role): summary_filename = Path(todo.i_context.design_filename).with_suffix(".md").name dependencies = {todo.i_context.design_filename, todo.i_context.task_filename} for filename in todo.i_context.codes_filenames: - rpath = self.project_repo.src_relative_path / filename + rpath = self.repo.src_relative_path / filename dependencies.add(str(rpath)) - await self.project_repo.resources.code_summary.save( + await self.repo.resources.code_summary.save( filename=summary_filename, content=summary, dependencies=dependencies ) is_pass, reason = await self._is_pass(summary) @@ -196,23 +200,39 @@ class Engineer(Role): todo.i_context.reason = reason tasks.append(todo.i_context.model_dump()) - await self.project_repo.docs.code_summary.save( + await self.repo.docs.code_summary.save( filename=Path(todo.i_context.design_filename).name, content=todo.i_context.model_dump_json(), dependencies=dependencies, ) else: - await self.project_repo.docs.code_summary.delete(filename=Path(todo.i_context.design_filename).name) + await self.repo.docs.code_summary.delete(filename=Path(todo.i_context.design_filename).name) self.summarize_todos = [] 
logger.info(f"--max-auto-summarize-code={self.config.max_auto_summarize_code}") if not tasks or self.config.max_auto_summarize_code == 0: self.n_summarize = 0 + kvs = self.input_args.model_dump() + kvs["changed_src_filenames"] = [ + str(self.repo.srcs.workdir / i) for i in list(self.repo.srcs.changed_files.keys()) + ] + if self.repo.docs.code_plan_and_change.changed_files: + kvs["changed_code_plan_and_change_filenames"] = [ + str(self.repo.docs.code_plan_and_change.workdir / i) + for i in list(self.repo.docs.code_plan_and_change.changed_files.keys()) + ] + if self.repo.docs.code_summary.changed_files: + kvs["changed_code_summary_filenames"] = [ + str(self.repo.docs.code_summary.workdir / i) + for i in list(self.repo.docs.code_summary.changed_files.keys()) + ] return AIMessage( - content=f"Coding is complete. The source code is at {self.project_repo.workdir.name}/{self.project_repo.srcs.root_path}, containing: " + content=f"Coding is complete. The source code is at {self.repo.workdir.name}/{self.repo.srcs.root_path}, containing: " + "\n".join( - list(self.project_repo.resources.code_summary.changed_files.keys()) - + list(self.project_repo.srcs.changed_files.keys()) + list(self.repo.resources.code_summary.changed_files.keys()) + + list(self.repo.srcs.changed_files.keys()) + + list(self.repo.resources.code_plan_and_change.changed_files.keys()) ), + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="SummarizeCodeOutput"), cause_by=SummarizeCode, send_to="Edward", # The name of QaEngineer ) @@ -227,15 +247,15 @@ class Engineer(Role): code_plan_and_change = node.instruct_content.model_dump_json() dependencies = { REQUIREMENT_FILENAME, - str(self.project_repo.docs.prd.root_path / self.rc.todo.i_context.prd_filename), - str(self.project_repo.docs.system_design.root_path / self.rc.todo.i_context.design_filename), - str(self.project_repo.docs.task.root_path / self.rc.todo.i_context.task_filename), + 
str(Path(self.rc.todo.i_context.prd_filename).relative_to(self.repo.workdir)), + str(Path(self.rc.todo.i_context.design_filename).relative_to(self.repo.workdir)), + str(Path(self.rc.todo.i_context.task_filename).relative_to(self.repo.workdir)), } code_plan_and_change_filepath = Path(self.rc.todo.i_context.design_filename) - await self.project_repo.docs.code_plan_and_change.save( + await self.repo.docs.code_plan_and_change.save( filename=code_plan_and_change_filepath.name, content=code_plan_and_change, dependencies=dependencies ) - await self.project_repo.resources.code_plan_and_change.save( + await self.repo.resources.code_plan_and_change.save( filename=code_plan_and_change_filepath.with_suffix(".md").name, content=node.content, dependencies=dependencies, @@ -250,10 +270,11 @@ class Engineer(Role): return True, rsp return False, rsp - async def _think(self) -> Action | None: + async def _think(self) -> bool: if not self.rc.news: - return None + return False msg = self.rc.news[0] + input_args = msg.instruct_content if msg.cause_by == any_to_str(UserRequirement): self.rc.todo = PrepareDocuments( key_descriptions={ @@ -263,42 +284,47 @@ class Engineer(Role): context=self.context, send_to=any_to_str(self), ) - return self.rc.todo - - if not self.src_workspace: - self.src_workspace = get_project_srcs_path(self.project_repo.workdir) + self.repo = ProjectRepo(input_args.project_path) + self.input_args = input_args + return bool(self.rc.todo) + elif msg.cause_by in {any_to_str(WriteTasks), any_to_str(FixBug)}: + self.input_args = input_args + self.repo = ProjectRepo(input_args.project_path) + if self.repo.src_relative_path is None: + path = get_project_srcs_path(self.repo.workdir) + self.repo.with_src_path(path) write_plan_and_change_filters = any_to_str_set([PrepareDocuments, WriteTasks, FixBug]) write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode]) summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview]) if self.config.inc 
and msg.cause_by in write_plan_and_change_filters: logger.debug(f"TODO WriteCodePlanAndChange:{msg.model_dump_json()}") await self._new_code_plan_and_change_action(cause_by=msg.cause_by) - return self.rc.todo + return bool(self.rc.todo) if msg.cause_by in write_code_filters: logger.debug(f"TODO WriteCode:{msg.model_dump_json()}") await self._new_code_actions() - return self.rc.todo + return bool(self.rc.todo) if msg.cause_by in summarize_code_filters and msg.sent_from == any_to_str(self): logger.debug(f"TODO SummarizeCode:{msg.model_dump_json()}") await self._new_summarize_actions() - return self.rc.todo - return None + return bool(self.rc.todo) + return False async def _new_coding_context(self, filename, dependency) -> Optional[CodingContext]: - old_code_doc = await self.project_repo.srcs.get(filename) + old_code_doc = await self.repo.srcs.get(filename) if not old_code_doc: - old_code_doc = Document(root_path=str(self.project_repo.src_relative_path), filename=filename, content="") + old_code_doc = Document(root_path=str(self.repo.src_relative_path), filename=filename, content="") dependencies = {Path(i) for i in await dependency.get(old_code_doc.root_relative_path)} task_doc = None design_doc = None code_plan_and_change_doc = await self._get_any_code_plan_and_change() if await self._is_fixbug() else None for i in dependencies: if str(i.parent) == TASK_FILE_REPO: - task_doc = await self.project_repo.docs.task.get(i.name) + task_doc = await self.repo.docs.task.get(i.name) elif str(i.parent) == SYSTEM_DESIGN_FILE_REPO: - design_doc = await self.project_repo.docs.system_design.get(i.name) + design_doc = await self.repo.docs.system_design.get(i.name) elif str(i.parent) == CODE_PLAN_AND_CHANGE_FILE_REPO: - code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(i.name) + code_plan_and_change_doc = await self.repo.docs.code_plan_and_change.get(i.name) if not task_doc or not design_doc: if filename == "__init__.py": # `__init__.py` created by 
`init_python_folder` return None @@ -318,34 +344,66 @@ class Engineer(Role): if not context: return None # `__init__.py` created by `init_python_folder` coding_doc = Document( - root_path=str(self.project_repo.src_relative_path), filename=filename, content=context.model_dump_json() + root_path=str(self.repo.src_relative_path), filename=filename, content=context.model_dump_json() ) return coding_doc async def _new_code_actions(self): bug_fix = await self._is_fixbug() # Prepare file repos - changed_src_files = self.project_repo.srcs.changed_files + changed_src_files = self.repo.srcs.changed_files if self.context.kwargs.src_filename: changed_src_files = {self.context.kwargs.src_filename: ChangeType.UNTRACTED} if bug_fix: - changed_src_files = self.project_repo.srcs.all_files - changed_task_files = self.project_repo.docs.task.changed_files + changed_src_files = self.repo.srcs.all_files changed_files = Documents() # Recode caused by upstream changes. - for filename in changed_task_files: - design_doc = await self.project_repo.docs.system_design.get(filename) - task_doc = await self.project_repo.docs.task.get(filename) - code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(filename) + if hasattr(self.input_args, "changed_task_filenames"): + changed_task_filenames = self.input_args.changed_task_filenames + else: + changed_task_filenames = [ + str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys()) + ] + for filename in changed_task_filenames: + task_filename = Path(filename) + design_filename = None + if hasattr(self.input_args, "changed_system_design_filenames"): + changed_system_design_filenames = self.input_args.changed_system_design_filenames + else: + changed_system_design_filenames = [ + str(self.repo.docs.system_design.workdir / i) + for i in list(self.repo.docs.system_design.changed_files.keys()) + ] + for i in changed_system_design_filenames: + if task_filename.name == Path(i).name: + 
design_filename = Path(i) + break + code_plan_and_change_filename = None + if hasattr(self.input_args, "changed_code_plan_and_change_filenames"): + changed_code_plan_and_change_filenames = self.input_args.changed_code_plan_and_change_filenames + else: + changed_code_plan_and_change_filenames = [ + str(self.repo.docs.code_plan_and_change.workdir / i) + for i in list(self.repo.docs.code_plan_and_change.changed_files.keys()) + ] + for i in changed_code_plan_and_change_filenames: + if task_filename.name == Path(i).name: + code_plan_and_change_filename = Path(i) + break + design_doc = await Document.load(filename=design_filename, project_path=self.repo.workdir) + task_doc = await Document.load(filename=task_filename, project_path=self.repo.workdir) + code_plan_and_change_doc = await Document.load( + filename=code_plan_and_change_filename, project_path=self.repo.workdir + ) task_list = self._parse_tasks(task_doc) await self._init_python_folder(task_list) for task_filename in task_list: if self.context.kwargs.src_filename and task_filename != self.context.kwargs.src_filename: continue - old_code_doc = await self.project_repo.srcs.get(task_filename) + old_code_doc = await self.repo.srcs.get(task_filename) if not old_code_doc: old_code_doc = Document( - root_path=str(self.project_repo.src_relative_path), filename=task_filename, content="" + root_path=str(self.repo.src_relative_path), filename=task_filename, content="" ) if not code_plan_and_change_doc: context = CodingContext( @@ -360,7 +418,7 @@ class Engineer(Role): code_plan_and_change_doc=code_plan_and_change_doc, ) coding_doc = Document( - root_path=str(self.project_repo.src_relative_path), + root_path=str(self.repo.src_relative_path), filename=task_filename, content=context.model_dump_json(), ) @@ -371,10 +429,11 @@ class Engineer(Role): ) changed_files.docs[task_filename] = coding_doc self.code_todos = [ - WriteCode(i_context=i, context=self.context, llm=self.llm) for i in changed_files.docs.values() + 
WriteCode(i_context=i, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm) + for i in changed_files.docs.values() ] # Code directly modified by the user. - dependency = await self.git_repo.get_dependency() + dependency = await self.repo.git_repo.get_dependency() for filename in changed_src_files: if filename in changed_files.docs: continue @@ -382,24 +441,30 @@ class Engineer(Role): if not coding_doc: continue # `__init__.py` created by `init_python_folder` changed_files.docs[filename] = coding_doc - self.code_todos.append(WriteCode(i_context=coding_doc, context=self.context, llm=self.llm)) + self.code_todos.append( + WriteCode( + i_context=coding_doc, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm + ) + ) if self.code_todos: self.set_todo(self.code_todos[0]) async def _new_summarize_actions(self): - src_files = self.project_repo.srcs.all_files + src_files = self.repo.srcs.all_files # Generate a SummarizeCode action for each pair of (system_design_doc, task_doc). 
summarizations = defaultdict(list) for filename in src_files: - dependencies = await self.project_repo.srcs.get_dependency(filename=filename) + dependencies = await self.repo.srcs.get_dependency(filename=filename) ctx = CodeSummarizeContext.loads(filenames=list(dependencies)) summarizations[ctx].append(filename) for ctx, filenames in summarizations.items(): if not ctx.design_filename or not ctx.task_filename: continue # cause by `__init__.py` which is created by `init_python_folder` ctx.codes_filenames = filenames - new_summarize = SummarizeCode(i_context=ctx, context=self.context, llm=self.llm) + new_summarize = SummarizeCode( + i_context=ctx, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm + ) for i, act in enumerate(self.summarize_todos): if act.i_context.task_filename == new_summarize.i_context.task_filename: self.summarize_todos[i] = new_summarize @@ -412,16 +477,37 @@ class Engineer(Role): async def _new_code_plan_and_change_action(self, cause_by: str): """Create a WriteCodePlanAndChange action for subsequent to-do actions.""" - files = self.project_repo.all_files options = {} if cause_by != any_to_str(FixBug): - requirement_doc = await self.project_repo.docs.get(REQUIREMENT_FILENAME) + requirement_doc = await Document.load(filename=self.input_args.requirements_filename) options["requirement"] = requirement_doc.content else: - fixbug_doc = await self.project_repo.docs.get(BUGFIX_FILENAME) + fixbug_doc = await Document.load(filename=self.input_args.issue_filename) options["issue"] = fixbug_doc.content - code_plan_and_change_ctx = CodePlanAndChangeContext.loads(files, **options) - self.rc.todo = WriteCodePlanAndChange(i_context=code_plan_and_change_ctx, context=self.context, llm=self.llm) + # The code here is flawed: if there are multiple unrelated requirements, this piece of logic will break + if hasattr(self.input_args, "changed_prd_filenames"): + code_plan_and_change_ctx = CodePlanAndChangeContext( + 
requirement=options.get("requirement", ""), + issue=options.get("issue", ""), + prd_filename=self.input_args.changed_prd_filenames[0], + design_filename=self.input_args.changed_system_design_filenames[0], + task_filename=self.input_args.changed_task_filenames[0], + ) + else: + code_plan_and_change_ctx = CodePlanAndChangeContext( + requirement=options.get("requirement", ""), + issue=options.get("issue", ""), + prd_filename=str(self.repo.docs.prd.workdir / self.repo.docs.prd.all_files[0]), + design_filename=str(self.repo.docs.system_design.workdir / self.repo.docs.system_design.all_files[0]), + task_filename=str(self.repo.docs.task.workdir / self.repo.docs.task.all_files[0]), + ) + self.rc.todo = WriteCodePlanAndChange( + i_context=code_plan_and_change_ctx, + repo=self.repo, + input_args=self.input_args, + context=self.context, + llm=self.llm, + ) @property def action_description(self) -> str: @@ -433,17 +519,16 @@ class Engineer(Role): filename = Path(i) if filename.suffix != ".py": continue - workdir = self.src_workspace / filename.parent + workdir = self.repo.srcs.workdir / filename.parent await init_python_folder(workdir) async def _is_fixbug(self) -> bool: - fixbug_doc = await self.project_repo.docs.get(BUGFIX_FILENAME) - return bool(fixbug_doc and fixbug_doc.content) + return bool(self.input_args and hasattr(self.input_args, "issue_filename")) async def _get_any_code_plan_and_change(self) -> Optional[Document]: - changed_files = self.project_repo.docs.code_plan_and_change.changed_files + changed_files = self.repo.docs.code_plan_and_change.changed_files for filename in changed_files.keys(): - doc = await self.project_repo.docs.code_plan_and_change.get(filename) + doc = await self.repo.docs.code_plan_and_change.get(filename) if doc and doc.content: return doc return None diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 4beab5366..58d8076ab 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ 
-7,10 +7,12 @@ @Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135. """ + from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.prepare_documents import PrepareDocuments -from metagpt.roles.role import Role, RoleReactMode +from metagpt.roles.role import Role from metagpt.utils.common import any_to_name, any_to_str +from metagpt.utils.git_repository import GitRepository class ProductManager(Role): @@ -35,12 +37,11 @@ class ProductManager(Role): self.enable_memory = False self.set_actions([PrepareDocuments(send_to=any_to_str(self)), WritePRD]) self._watch([UserRequirement, PrepareDocuments]) - self.rc.react_mode = RoleReactMode.BY_ORDER self.todo_action = any_to_name(WritePRD) async def _think(self) -> bool: """Decide what to do""" - if self.git_repo and not self.config.git_reinit: + if GitRepository.is_git_dir(self.config.project_path) and not self.config.git_reinit: self._set_state(1) else: self._set_state(0) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index f76baff3f..48ed24c2c 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -14,6 +14,9 @@ @Modified By: mashenquan, 2023-12-5. Enhance the workflow to navigate to WriteCode or QaEngineer based on the results of SummarizeCode. 
""" +from typing import Optional + +from pydantic import BaseModel, Field from metagpt.actions import DebugError, RunCode, UserRequirement, WriteTest from metagpt.actions.prepare_documents import PrepareDocuments @@ -25,9 +28,11 @@ from metagpt.schema import AIMessage, Document, Message, RunCodeContext, Testing from metagpt.utils.common import ( any_to_str, any_to_str_set, + get_project_srcs_path, init_python_folder, parse_recipient, ) +from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.report import EditorReporter @@ -41,6 +46,8 @@ class QaEngineer(Role): ) test_round_allowed: int = 5 test_round: int = 0 + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) def __init__(self, **kwargs): super().__init__(**kwargs) @@ -57,22 +64,21 @@ class QaEngineer(Role): self.test_round = 0 async def _write_test(self, message: Message) -> None: - src_file_repo = self.project_repo.with_src_path(self.context.src_workspace).srcs reqa_file = self.context.kwargs.reqa_file or self.config.reqa_file - changed_files = {reqa_file} if reqa_file else set(src_file_repo.changed_files.keys()) + changed_files = {reqa_file} if reqa_file else set(self.repo.srcs.changed_files.keys()) for filename in changed_files: # write tests if not filename or "test" in filename: continue - code_doc = await src_file_repo.get(filename) - if not code_doc: + code_doc = await self.repo.srcs.get(filename) + if not code_doc or not code_doc.content: continue if not code_doc.filename.endswith(".py"): continue - test_doc = await self.project_repo.tests.get("test_" + code_doc.filename) + test_doc = await self.repo.tests.get("test_" + code_doc.filename) if not test_doc: test_doc = Document( - root_path=str(self.project_repo.tests.root_path), filename="test_" + code_doc.filename, content="" + root_path=str(self.repo.tests.root_path), filename="test_" + code_doc.filename, content="" ) logger.info(f"Writing 
{test_doc.filename}..") context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc) @@ -81,40 +87,38 @@ class QaEngineer(Role): async with EditorReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "test", "filename": test_doc.filename}, "meta") - doc = await self.project_repo.tests.save_doc( + doc = await self.repo.tests.save_doc( doc=context.test_doc, dependencies={context.code_doc.root_relative_path} ) - await reporter.async_report(self.project_repo.workdir / doc.root_relative_path, "path") + await reporter.async_report(self.repo.workdir / doc.root_relative_path, "path") # prepare context for run tests in next round run_code_context = RunCodeContext( command=["python", context.test_doc.root_relative_path], code_filename=context.code_doc.filename, test_filename=context.test_doc.filename, - working_directory=str(self.project_repo.workdir), + working_directory=str(self.repo.workdir), additional_python_paths=[str(self.context.src_workspace)], ) self.publish_message( AIMessage(content=run_code_context.model_dump_json(), cause_by=WriteTest, send_to=MESSAGE_ROUTE_TO_SELF) ) - logger.info(f"Done {str(self.project_repo.tests.workdir)} generating.") + logger.info(f"Done {str(self.repo.tests.workdir)} generating.") async def _run_code(self, msg): run_code_context = RunCodeContext.loads(msg.content) - src_doc = await self.project_repo.with_src_path(self.context.src_workspace).srcs.get( - run_code_context.code_filename - ) + src_doc = await self.repo.srcs.get(run_code_context.code_filename) if not src_doc: return - test_doc = await self.project_repo.tests.get(run_code_context.test_filename) + test_doc = await self.repo.tests.get(run_code_context.test_filename) if not test_doc: return run_code_context.code = src_doc.content run_code_context.test_code = test_doc.content result = await RunCode(i_context=run_code_context, context=self.context, llm=self.llm).run() run_code_context.output_filename = 
run_code_context.test_filename + ".json" - await self.project_repo.test_outputs.save( + await self.repo.test_outputs.save( filename=run_code_context.output_filename, content=result.model_dump_json(), dependencies={src_doc.root_relative_path, test_doc.root_relative_path}, @@ -124,31 +128,53 @@ class QaEngineer(Role): # the recipient might be Engineer or myself recipient = parse_recipient(result.summary) mappings = {"Engineer": "Alex", "QaEngineer": "Edward"} - self.publish_message( - AIMessage( - content=run_code_context.model_dump_json(), - cause_by=RunCode, - send_to=mappings.get(recipient, MESSAGE_ROUTE_TO_NONE), + if recipient != "Engineer": + self.publish_message( + AIMessage( + content=run_code_context.model_dump_json(), + cause_by=RunCode, + instruct_content=self.input_args, + send_to=MESSAGE_ROUTE_TO_SELF, + ) + ) + else: + kvs = self.input_args.model_dump() + kvs["changed_test_filenames"] = [ + str(self.repo.tests.workdir / i) for i in list(self.repo.tests.changed_files.keys()) + ] + self.publish_message( + AIMessage( + content=run_code_context.model_dump_json(), + cause_by=RunCode, + instruct_content=self.input_args, + send_to=mappings.get(recipient, MESSAGE_ROUTE_TO_NONE), + ) ) - ) async def _debug_error(self, msg): run_code_context = RunCodeContext.loads(msg.content) - code = await DebugError(i_context=run_code_context, context=self.context, llm=self.llm).run() - await self.project_repo.tests.save(filename=run_code_context.test_filename, content=code) + code = await DebugError( + i_context=run_code_context, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm + ).run() + await self.repo.tests.save(filename=run_code_context.test_filename, content=code) run_code_context.output = None self.publish_message( AIMessage(content=run_code_context.model_dump_json(), cause_by=DebugError, send_to=MESSAGE_ROUTE_TO_SELF) ) async def _act(self) -> Message: - if self.project_path: - await init_python_folder(self.project_repo.tests.workdir) + 
if self.input_args.project_path: + await init_python_folder(self.repo.tests.workdir) if self.test_round > self.test_round_allowed: + kvs = self.input_args.model_dump() + kvs["changed_test_filenames"] = [ + str(self.repo.tests.workdir / i) for i in list(self.repo.tests.changed_files.keys()) + ] result_msg = AIMessage( content=f"Exceeding {self.test_round_allowed} rounds of tests, stop. " - + "\n".join(list(self.project_repo.tests.changed_files.keys())), + + "\n".join(list(self.repo.tests.changed_files.keys())), cause_by=WriteTest, + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteTestOutput"), send_to=MESSAGE_ROUTE_TO_NONE, ) return result_msg @@ -171,8 +197,13 @@ class QaEngineer(Role): elif msg.cause_by == any_to_str(UserRequirement): return await self._parse_user_requirement(msg) self.test_round += 1 + kvs = self.input_args.model_dump() + kvs["changed_test_filenames"] = [ + str(self.repo.tests.workdir / i) for i in list(self.repo.tests.changed_files.keys()) + ] return AIMessage( content=f"Round {self.test_round} of tests done", + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteTestOutput"), cause_by=WriteTest, send_to=MESSAGE_ROUTE_TO_NONE, ) @@ -190,3 +221,15 @@ class QaEngineer(Role): if not self.src_workspace: self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name return rsp + + async def _think(self) -> bool: + if not self.rc.news: + return False + msg = self.rc.news[0] + if msg.cause_by == any_to_str(SummarizeCode): + self.input_args = msg.instruct_content + self.repo = ProjectRepo(self.input_args.project_path) + if self.repo.src_relative_path is None: + path = get_project_srcs_path(self.repo.workdir) + self.repo.with_src_path(path) + return True diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 1eaa77fa3..5592841eb 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -45,7 +45,6 @@ from metagpt.schema import ( ) from metagpt.strategy.planner import Planner 
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator -from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output if TYPE_CHECKING: @@ -196,29 +195,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel): value.context = self.context self.rc.todo = value - @property - def git_repo(self): - """Git repo""" - return self.context.git_repo - - @git_repo.setter - def git_repo(self, value): - self.context.git_repo = value - - @property - def src_workspace(self): - """Source workspace under git repo""" - return self.context.src_workspace - - @src_workspace.setter - def src_workspace(self, value): - self.context.src_workspace = value - - @property - def project_repo(self) -> ProjectRepo: - project_repo = ProjectRepo(self.context.git_repo) - return project_repo.with_src_path(self.context.src_workspace) if self.context.src_workspace else project_repo - @property def prompt_schema(self): """Prompt schema: json/markdown""" diff --git a/metagpt/schema.py b/metagpt/schema.py index 5af16bc38..5106fad4d 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -31,6 +31,7 @@ from pydantic import ( ConfigDict, Field, PrivateAttr, + create_model, field_serializer, field_validator, model_serializer, @@ -43,13 +44,18 @@ from metagpt.const import ( MESSAGE_ROUTE_FROM, MESSAGE_ROUTE_TO, MESSAGE_ROUTE_TO_ALL, - PRDS_FILE_REPO, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO, ) from metagpt.logs import logger from metagpt.repo_parser import DotClassInfo -from metagpt.utils.common import CodeParser, any_to_str, any_to_str_set, import_class +from metagpt.utils.common import ( + CodeParser, + any_to_str, + any_to_str_set, + aread, + import_class, +) from metagpt.utils.exceptions import handle_exception from metagpt.utils.report import TaskReporter from metagpt.utils.serialize import ( @@ -157,6 +163,30 @@ class Document(BaseModel): def __repr__(self): return self.content + @classmethod + async def 
load( + cls, filename: Union[str, Path], project_path: Optional[Union[str, Path]] = None + ) -> Optional["Document"]: + """ + Load a document from a file. + + Args: + filename (Union[str, Path]): The path to the file to load. + project_path (Optional[Union[str, Path]], optional): The path to the project. Defaults to None. + + Returns: + Optional[Document]: The loaded document, or None if the file does not exist. + + """ + if not filename or not Path(filename).exists(): + return None + content = await aread(filename=filename) + doc = cls(content=content, filename=str(filename)) + if project_path and Path(filename).is_relative_to(project_path): + doc.root_path = Path(filename).relative_to(project_path).parent + doc.filename = Path(filename).name + return doc + class Documents(BaseModel): """A class representing a collection of documents. @@ -360,6 +390,22 @@ class Message(BaseModel): def add_metadata(self, key: str, value: str): self.metadata[key] = value + @staticmethod + def create_instruct_value(kvs: Dict[str, Any], class_name: str = "") -> BaseModel: + """ + Dynamically creates a Pydantic BaseModel subclass based on a given dictionary. + + Parameters: + - data: A dictionary from which to create the BaseModel subclass. + + Returns: + - A Pydantic BaseModel subclass instance populated with the given data. + """ + if not class_name: + class_name = "DM" + uuid.uuid4().hex[0:8] + dynamic_class = create_model(class_name, **{key: (value.__class__, ...) 
for key, value in kvs.items()}) + return dynamic_class.model_validate(kvs) + class UserMessage(Message): """便于支持OpenAI的消息 @@ -762,22 +808,6 @@ class CodePlanAndChangeContext(BaseModel): design_filename: str = "" task_filename: str = "" - @staticmethod - def loads(filenames: List, **kwargs) -> CodePlanAndChangeContext: - ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", ""), issue=kwargs.get("issue", "")) - for filename in filenames: - filename = Path(filename) - if filename.is_relative_to(PRDS_FILE_REPO): - ctx.prd_filename = filename.name - continue - if filename.is_relative_to(SYSTEM_DESIGN_FILE_REPO): - ctx.design_filename = filename.name - continue - if filename.is_relative_to(TASK_FILE_REPO): - ctx.task_filename = filename.name - continue - return ctx - # mermaid class view class UMLClassMeta(BaseModel): diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index e2520ef13..6d40828af 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -899,3 +899,44 @@ async def init_python_folder(workdir: str | Path): return async with aiofiles.open(init_filename, "a"): os.utime(init_filename, None) + + +def get_markdown_code_block_type(filename: str) -> str: + if not filename: + return "" + ext = Path(filename).suffix + types = { + ".py": "python", + ".js": "javascript", + ".java": "java", + ".cpp": "cpp", + ".c": "c", + ".html": "html", + ".css": "css", + ".xml": "xml", + ".json": "json", + ".yaml": "yaml", + ".md": "markdown", + ".sql": "sql", + ".rb": "ruby", + ".php": "php", + ".sh": "bash", + ".swift": "swift", + ".go": "go", + ".rs": "rust", + ".pl": "perl", + ".asm": "assembly", + ".r": "r", + ".scss": "scss", + ".sass": "sass", + ".lua": "lua", + ".ts": "typescript", + ".tsx": "tsx", + ".jsx": "jsx", + ".yml": "yaml", + ".ini": "ini", + ".toml": "toml", + ".svg": "xml", # SVG can often be treated as XML + # Add more file extensions and corresponding code block types as needed + } + return types.get(ext, "") diff --git 
a/metagpt/utils/git_repository.py b/metagpt/utils/git_repository.py index 7b09c1775..f3d6350bd 100644 --- a/metagpt/utils/git_repository.py +++ b/metagpt/utils/git_repository.py @@ -156,6 +156,8 @@ class GitRepository: :param local_path: The local path to check. :return: True if the directory is a Git repository, False otherwise. """ + if not local_path: + return False git_dir = Path(local_path) / ".git" if git_dir.exists() and is_git_dir(git_dir): return True diff --git a/metagpt/utils/project_repo.py b/metagpt/utils/project_repo.py index 64ed602a9..5761c0188 100644 --- a/metagpt/utils/project_repo.py +++ b/metagpt/utils/project_repo.py @@ -140,10 +140,11 @@ class ProjectRepo(FileRepository): return bool(code_files) def with_src_path(self, path: str | Path) -> ProjectRepo: - try: - self._srcs_path = Path(path).relative_to(self.workdir) - except ValueError: - self._srcs_path = Path(path) + path = Path(path) + if path.is_relative_to(self.workdir): + self._srcs_path = path.relative_to(self.workdir) + else: + self._srcs_path = path return self @property diff --git a/tests/metagpt/actions/test_action_node.py b/tests/metagpt/actions/test_action_node.py index 989e2249c..bc85925a8 100644 --- a/tests/metagpt/actions/test_action_node.py +++ b/tests/metagpt/actions/test_action_node.py @@ -303,5 +303,4 @@ def test_action_node_from_pydantic_and_print_everything(): if __name__ == "__main__": - test_create_model_class() - test_create_model_class_with_mapping() + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 6f54b062d..48f13f4a2 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -392,5 +392,11 @@ async def test_parse_resources(context, content: str, key_descriptions): assert k in result +@pytest.mark.parametrize(("name", "value"), [("c1", {"age": 10, "name": "Alice"}), ("", {"path": __file__})]) +def test_create_instruct_value(name, value): + obj = Message.create_instruct_value(kvs=value, 
class_name=name) + assert obj.model_dump() == value + + if __name__ == "__main__": pytest.main([__file__, "-s"]) From d0486f8e11b9e1ff544444ea237b199b81e95874 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Sat, 11 May 2024 19:47:33 +0800 Subject: [PATCH 04/20] add cr reporter --- metagpt/actions/write_code.py | 2 +- metagpt/actions/write_code_review.py | 16 +++++++++++----- metagpt/utils/report.py | 8 ++++++-- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 67b859d23..dc8a6dee5 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -153,7 +153,7 @@ class WriteCode(Action): root_path = self.context.src_workspace if self.context.src_workspace else "" coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path)) coding_context.code_doc.content = code - await reporter.async_report(self.repo.workdir / coding_context.code_doc.root_relative_path, "path") + await reporter.async_report(coding_context.code_doc, "document") return coding_context @staticmethod diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index f0faea701..1b9f9554b 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -17,6 +17,7 @@ from metagpt.const import REQUIREMENT_FILENAME from metagpt.logs import logger from metagpt.schema import CodingContext from metagpt.utils.common import CodeParser +from metagpt.utils.report import EditorReporter PROMPT_TEMPLATE = """ # System @@ -128,16 +129,21 @@ class WriteCodeReview(Action): i_context: CodingContext = Field(default_factory=CodingContext) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) - async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename): + async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, doc): + filename = doc.filename cr_rsp = await 
self._aask(context_prompt + cr_prompt) result = CodeParser.parse_block("Code Review Result", cr_rsp) if "LGTM" in result: return result, None # if LBTM, rewrite code - rewrite_prompt = f"{context_prompt}\n{cr_rsp}\n{REWRITE_CODE_TEMPLATE.format(filename=filename)}" - code_rsp = await self._aask(rewrite_prompt) - code = CodeParser.parse_code(text=code_rsp) + async with EditorReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "code", "filename": filename, "src_path": doc.root_relative_path}, "meta") + rewrite_prompt = f"{context_prompt}\n{cr_rsp}\n{REWRITE_CODE_TEMPLATE.format(filename=filename)}" + code_rsp = await self._aask(rewrite_prompt) + code = CodeParser.parse_code(text=code_rsp) + doc.content = code + await reporter.async_report(doc, "document") return result, code async def run(self, *args, **kwargs) -> CodingContext: @@ -182,7 +188,7 @@ class WriteCodeReview(Action): f"len(self.i_context.code_doc.content)={len2}" ) result, rewrited_code = await self.write_code_review_and_rewrite( - context_prompt, cr_prompt, self.i_context.code_doc.filename + context_prompt, cr_prompt, self.i_context.code_doc ) if "LBTM" in result: iterative_code = rewrited_code diff --git a/metagpt/utils/report.py b/metagpt/utils/report.py index a61c77381..616a52f30 100644 --- a/metagpt/utils/report.py +++ b/metagpt/utils/report.py @@ -131,7 +131,11 @@ class ResourceReporter(BaseModel): def _format_data(self, value, name): data = self.model_dump(mode="json", exclude=("callback_url", "llm_stream")) - data["value"] = str(value) if isinstance(value, Path) else value + if isinstance(value, BaseModel): + value = value.model_dump(mode="json") + elif isinstance(value, Path): + value = str(value) + data["value"] = value data["name"] = name role = CURRENT_ROLE.get(None) if role: @@ -263,7 +267,7 @@ class FileReporter(ResourceReporter): """Report file resource synchronously.""" return super().report(value, name) - async def async_report(self, value: Path, name: 
Literal["path", "meta", "content"] = "path"): + async def async_report(self, value: Path, name: Literal["path", "meta", "content", "document"] = "path"): """Report file resource asynchronously.""" return await super().async_report(value, name) From 2079cbf3ae69ff2ae9d67d6d76fb9d963bf0c239 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Wed, 15 May 2024 14:34:15 +0800 Subject: [PATCH 05/20] report abs path --- metagpt/utils/report.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/metagpt/utils/report.py b/metagpt/utils/report.py index 616a52f30..85f9bfa22 100644 --- a/metagpt/utils/report.py +++ b/metagpt/utils/report.py @@ -135,6 +135,9 @@ class ResourceReporter(BaseModel): value = value.model_dump(mode="json") elif isinstance(value, Path): value = str(value) + + if name == "path": + value = os.path.abspath(value) data["value"] = value data["name"] = name role = CURRENT_ROLE.get(None) From 7745e09c397b78846723a4f1bfc5fdf17cd80649 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Fri, 17 May 2024 10:29:25 +0800 Subject: [PATCH 06/20] add thought reporter --- metagpt/roles/di/data_analyst.py | 5 +++-- metagpt/roles/di/data_interpreter.py | 4 +++- metagpt/roles/di/team_leader.py | 4 +++- metagpt/roles/role.py | 5 +++-- metagpt/utils/common.py | 2 +- metagpt/utils/report.py | 11 +++++++++++ 6 files changed, 24 insertions(+), 7 deletions(-) diff --git a/metagpt/roles/di/data_analyst.py b/metagpt/roles/di/data_analyst.py index 0fc95b9d6..fc298ea4c 100644 --- a/metagpt/roles/di/data_analyst.py +++ b/metagpt/roles/di/data_analyst.py @@ -20,6 +20,7 @@ from metagpt.strategy.thinking_command import ( ) from metagpt.tools.tool_recommend import BM25ToolRecommender from metagpt.utils.common import CodeParser +from metagpt.utils.report import ThoughtReporter class DataAnalyst(DataInterpreter): @@ -82,8 +83,8 @@ class DataAnalyst(DataInterpreter): available_commands=prepare_command_prompt(self.available_commands), ) context = self.llm.format_msg(self.working_memory.get() + 
[Message(content=prompt, role="user")]) - - rsp = await self.llm.aask(context) + async with ThoughtReporter(): + rsp = await self.llm.aask(context) self.commands = json.loads(CodeParser.parse_code(block=None, text=rsp)) self.rc.memory.add(Message(content=rsp, role="assistant")) diff --git a/metagpt/roles/di/data_interpreter.py b/metagpt/roles/di/data_interpreter.py index e147cbbe3..bdfc0e294 100644 --- a/metagpt/roles/di/data_interpreter.py +++ b/metagpt/roles/di/data_interpreter.py @@ -15,6 +15,7 @@ from metagpt.schema import Message, Task, TaskResult from metagpt.strategy.task_type import TaskType from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.utils.common import CodeParser +from metagpt.utils.report import ThoughtReporter REACT_THINK_PROMPT = """ # User Requirement @@ -73,7 +74,8 @@ class DataInterpreter(Role): return True prompt = REACT_THINK_PROMPT.format(user_requirement=self.user_requirement, context=context) - rsp = await self.llm.aask(prompt) + async with ThoughtReporter(): + rsp = await self.llm.aask(prompt) rsp_dict = json.loads(CodeParser.parse_code(text=rsp)) self.working_memory.add(Message(content=rsp_dict["thoughts"], role="assistant")) need_action = rsp_dict["state"] diff --git a/metagpt/roles/di/team_leader.py b/metagpt/roles/di/team_leader.py index a1ef11fa6..2fa782ade 100644 --- a/metagpt/roles/di/team_leader.py +++ b/metagpt/roles/di/team_leader.py @@ -20,6 +20,7 @@ from metagpt.strategy.thinking_command import ( run_commands, ) from metagpt.utils.common import CodeParser +from metagpt.utils.report import ThoughtReporter class TeamLeader(Role): @@ -69,7 +70,8 @@ class TeamLeader(Role): ) context = self.llm.format_msg(self.get_memories(k=10) + [Message(content=prompt, role="user")]) - rsp = await self.llm.aask(context, system_msgs=[SYSTEM_PROMPT]) + async with ThoughtReporter(): + rsp = await self.llm.aask(context, system_msgs=[SYSTEM_PROMPT]) self.commands = json.loads(CodeParser.parse_code(text=rsp)) 
self.rc.memory.add(Message(content=rsp, role="assistant")) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 1eaa77fa3..f6d26eeb1 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -47,6 +47,7 @@ from metagpt.strategy.planner import Planner from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output +from metagpt.utils.report import ThoughtReporter if TYPE_CHECKING: from metagpt.environment import Environment # noqa: F401 @@ -381,8 +382,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel): n_states=len(self.states) - 1, previous_state=self.rc.state, ) - - next_state = await self.llm.aask(prompt) + async with ThoughtReporter(): + next_state = await self.llm.aask(prompt) next_state = extract_state_value_from_output(next_state) logger.debug(f"{prompt=}") diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index e2520ef13..fd7fdcb7a 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -646,7 +646,7 @@ def role_raise_decorator(func): raise Exception(format_trackback_info(limit=None)) except Exception as e: if self.latest_observed_msg: - logger.warning( + logger.exception( "There is a exception in role's execution, in order to resume, " "we delete the newest role communication message in the role's memory." 
) diff --git a/metagpt/utils/report.py b/metagpt/utils/report.py index 85f9bfa22..491688f3a 100644 --- a/metagpt/utils/report.py +++ b/metagpt/utils/report.py @@ -39,6 +39,7 @@ class BlockType(str, Enum): GALLERY = "Gallery" NOTEBOOK = "Notebook" DOCS = "Docs" + THOUGHT = "Thought" END_MARKER_NAME = "end_marker" @@ -259,6 +260,16 @@ class TaskReporter(ObjectReporter): block: Literal[BlockType.TASK] = BlockType.TASK +class ThoughtReporter(ObjectReporter): + """Reporter for object resources to Task Block.""" + + block: Literal[BlockType.THOUGHT] = BlockType.THOUGHT + + async def __aenter__(self): + await self.async_report({}) + return await super().__aenter__() + + class FileReporter(ResourceReporter): """File resource callback for reporting complete file paths. From 76e3a14d38caa1a6cbd138598e6cfd5a3b3b0d9d Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Tue, 21 May 2024 10:28:24 +0800 Subject: [PATCH 07/20] add extra field for report --- metagpt/tools/libs/editor.py | 3 ++- metagpt/utils/report.py | 38 ++++++++++++++++++++++++------------ 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index e032dcef5..a2670a2bd 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -100,7 +100,8 @@ class Editor: file_path=file_path, block_content=block_content, ) - self.resource.report(result.file_path, "path") + self.resource.report(result.file_path, "path", + extra={"type": "search", "line_range": {"start": start, "end": end}}) return result return None diff --git a/metagpt/utils/report.py b/metagpt/utils/report.py index 491688f3a..2d72af111 100644 --- a/metagpt/utils/report.py +++ b/metagpt/utils/report.py @@ -56,23 +56,23 @@ class ResourceReporter(BaseModel): callback_url: str = Field(METAGPT_REPORTER_DEFAULT_URL, description="The URL to which the report should be sent") _llm_task: Optional[asyncio.Task] = PrivateAttr(None) - def report(self, value: Any, name: str): + def 
report(self, value: Any, name: str, extra: Optional[dict] = None): """Synchronously report resource observation data. Args: value: The data to report. name: The type name of the data. """ - return self._report(value, name) + return self._report(value, name, extra) - async def async_report(self, value: Any, name: str): + async def async_report(self, value: Any, name: str, extra: Optional[dict] = None): """Asynchronously report resource observation data. Args: value: The data to report. name: The type name of the data. """ - return await self._async_report(value, name) + return await self._async_report(value, name, extra) @classmethod def set_report_fn(cls, fn: Callable): @@ -101,20 +101,20 @@ class ResourceReporter(BaseModel): """ cls._async_report = fn - def _report(self, value: Any, name: str): + def _report(self, value: Any, name: str, extra: Optional[dict] = None): if not self.callback_url: return - data = self._format_data(value, name) + data = self._format_data(value, name, extra) resp = requests.post(self.callback_url, json=data) resp.raise_for_status() return resp.text - async def _async_report(self, value: Any, name: str): + async def _async_report(self, value: Any, name: str, extra: Optional[dict] = None): if not self.callback_url: return - data = self._format_data(value, name) + data = self._format_data(value, name, extra) url = self.callback_url _result = urlparse(url) sessiion_kwargs = {} @@ -130,7 +130,7 @@ class ResourceReporter(BaseModel): resp.raise_for_status() return await resp.text() - def _format_data(self, value, name): + def _format_data(self, value, name, extra): data = self.model_dump(mode="json", exclude=("callback_url", "llm_stream")) if isinstance(value, BaseModel): value = value.model_dump(mode="json") @@ -147,6 +147,8 @@ class ResourceReporter(BaseModel): else: role_name = os.environ.get("METAGPT_ROLE") data["role"] = role_name + if extra: + data["extra"] = extra return data def __enter__(self): @@ -277,13 +279,23 @@ class 
FileReporter(ResourceReporter): if the file can be partially output for display first, use streaming callback. """ - def report(self, value: Union[Path, dict, Any], name: Literal["path", "meta", "content"] = "path"): + def report( + self, + value: Union[Path, dict, Any], + name: Literal["path", "meta", "content"] = "path", + extra: Optional[dict] = None, + ): """Report file resource synchronously.""" - return super().report(value, name) + return super().report(value, name, extra) - async def async_report(self, value: Path, name: Literal["path", "meta", "content", "document"] = "path"): + async def async_report( + self, + value: Union[Path, dict, Any], + name: Literal["path", "meta", "content"] = "path", + extra: Optional[dict] = None, + ): """Report file resource asynchronously.""" - return await super().async_report(value, name) + return await super().async_report(value, name, extra) class NotebookReporter(FileReporter): From 49ffb79433ab21feb9a50ed388603f59d37700dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Tue, 21 May 2024 21:59:01 +0800 Subject: [PATCH 08/20] feat: remove UserRequirement --- metagpt/roles/architect.py | 23 +++-------------------- metagpt/roles/engineer.py | 29 +++-------------------------- metagpt/roles/project_manager.py | 23 +++-------------------- metagpt/roles/qa_engineer.py | 8 ++------ 4 files changed, 11 insertions(+), 72 deletions(-) diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index 465beff05..9e1761c85 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -6,11 +6,9 @@ @File : architect.py """ -from metagpt.actions import UserRequirement, WritePRD +from metagpt.actions import WritePRD from metagpt.actions.design_api import WriteDesign -from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.roles.role import Role -from metagpt.utils.common import any_to_str class Architect(Role): @@ -36,22 +34,7 @@ class Architect(Role): 
super().__init__(**kwargs) self.enable_memory = False # Initialize actions specific to the Architect role - self.set_actions([PrepareDocuments(send_to=any_to_str(self), context=self.context), WriteDesign]) + self.set_actions([WriteDesign]) # Set events or actions the Architect should watch or be aware of - self._watch({UserRequirement, PrepareDocuments, WritePRD}) - - async def _think(self) -> bool: - """Decide what to do""" - mappings = { - any_to_str(UserRequirement): 0, - any_to_str(PrepareDocuments): 1, - any_to_str(WritePRD): 1, - } - for i in self.rc.news: - idx = mappings.get(i.cause_by, -1) - if idx < 0: - continue - self.rc.todo = self.actions[idx] - return bool(self.rc.todo) - return False + self._watch({WritePRD}) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index 111e534a6..919b4bf13 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -26,7 +26,7 @@ from typing import List, Optional, Set from pydantic import BaseModel, Field -from metagpt.actions import UserRequirement, WriteCode, WriteCodeReview, WriteTasks +from metagpt.actions import WriteCode, WriteCodeReview, WriteTasks from metagpt.actions.fix_bug import FixBug from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST @@ -103,18 +103,7 @@ class Engineer(Role): super().__init__(**kwargs) self.enable_memory = False self.set_actions([WriteCode]) - self._watch( - [ - UserRequirement, - PrepareDocuments, - WriteTasks, - SummarizeCode, - WriteCode, - WriteCodeReview, - FixBug, - WriteCodePlanAndChange, - ] - ) + self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug, WriteCodePlanAndChange]) self.code_todos = [] self.summarize_todos = [] self.next_todo_action = any_to_name(WriteCode) @@ -275,19 +264,7 @@ class Engineer(Role): return False msg = self.rc.news[0] input_args = msg.instruct_content - if msg.cause_by == any_to_str(UserRequirement): - self.rc.todo 
= PrepareDocuments( - key_descriptions={ - "project_path": 'the project path if exists in "Original Requirement"', - "src_filename": 'the file name of the source code file explicitly requested for modification if exists in "Original Requirement"', - }, - context=self.context, - send_to=any_to_str(self), - ) - self.repo = ProjectRepo(input_args.project_path) - self.input_args = input_args - return bool(self.rc.todo) - elif msg.cause_by in {any_to_str(WriteTasks), any_to_str(FixBug)}: + if msg.cause_by in {any_to_str(WriteTasks), any_to_str(FixBug)}: self.input_args = input_args self.repo = ProjectRepo(input_args.project_path) if self.repo.src_relative_path is None: diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 70bd3bf8b..d6374e673 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -6,11 +6,9 @@ @File : project_manager.py """ -from metagpt.actions import UserRequirement, WriteTasks +from metagpt.actions import WriteTasks from metagpt.actions.design_api import WriteDesign -from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.roles.role import Role -from metagpt.utils.common import any_to_str class ProjectManager(Role): @@ -35,20 +33,5 @@ class ProjectManager(Role): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.enable_memory = False - self.set_actions([PrepareDocuments(send_to=any_to_str(self), context=self.context), WriteTasks]) - self._watch([UserRequirement, PrepareDocuments, WriteDesign]) - - async def _think(self) -> bool: - """Decide what to do""" - mappings = { - any_to_str(UserRequirement): 0, - any_to_str(PrepareDocuments): 1, - any_to_str(WriteDesign): 1, - } - for i in self.rc.news: - idx = mappings.get(i.cause_by, -1) - if idx < 0: - continue - self.rc.todo = self.actions[idx] - return bool(self.rc.todo) - return False + self.set_actions([WriteTasks]) + self._watch([WriteDesign]) diff --git a/metagpt/roles/qa_engineer.py 
b/metagpt/roles/qa_engineer.py index 48ed24c2c..4cc9eb9e2 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -55,12 +55,8 @@ class QaEngineer(Role): # FIXME: a bit hack here, only init one action to circumvent _think() logic, # will overwrite _think() in future updates - self.set_actions( - [ - WriteTest, - ] - ) - self._watch([UserRequirement, PrepareDocuments, SummarizeCode, WriteTest, RunCode, DebugError]) + self.set_actions([WriteTest]) + self._watch([SummarizeCode, WriteTest, RunCode, DebugError]) self.test_round = 0 async def _write_test(self, message: Message) -> None: From 6a38d5173362b2fd34de6e8b50963a5a05499ea7 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Thu, 23 May 2024 21:16:17 +0800 Subject: [PATCH 09/20] report the search content result --- metagpt/tools/libs/editor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index a2670a2bd..78560e375 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -100,8 +100,7 @@ class Editor: file_path=file_path, block_content=block_content, ) - self.resource.report(result.file_path, "path", - extra={"type": "search", "line_range": {"start": start, "end": end}}) + self.resource.report(result.file_path, "path", extra={"type": "search", "line": i, "symbol": symbol}) return result return None From 9ab36acdfc6f18ce91843b3a3b027ed868f4fd8f Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Tue, 28 May 2024 23:29:20 +0800 Subject: [PATCH 10/20] coalesce the stream output of the notebook --- metagpt/actions/di/execute_nb_code.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/metagpt/actions/di/execute_nb_code.py b/metagpt/actions/di/execute_nb_code.py index b4fe949fe..64620d9cc 100644 --- a/metagpt/actions/di/execute_nb_code.py +++ b/metagpt/actions/di/execute_nb_code.py @@ -65,7 +65,7 @@ class ExecuteNbCode(Action): """execute notebook code block, return 
result to llm, and display it.""" nb: NotebookNode - nb_client: NotebookClient = None + nb_client: RealtimeOutputNotebookClient = None console: Console interaction: str timeout: int = 600 @@ -78,11 +78,15 @@ class ExecuteNbCode(Action): interaction=("ipython" if self.is_ipython() else "terminal"), ) self.reporter = NotebookReporter() + self.set_nb_client() + + def set_nb_client(self): self.nb_client = RealtimeOutputNotebookClient( - nb, - timeout=timeout, + self.nb, + timeout=self.timeout, resources={"metadata": {"path": DEFAULT_WORKSPACE_ROOT}}, notebook_reporter=self.reporter, + coalesce_streams=True, ) async def build(self): @@ -118,7 +122,7 @@ class ExecuteNbCode(Action): # sleep 1s to wait for the kernel to be cleaned up completely await asyncio.sleep(1) await self.build() - self.nb_client = NotebookClient(self.nb, timeout=self.timeout) + self.set_nb_client() def add_code_cell(self, code: str): self.nb.cells.append(new_code_cell(source=code)) From 3afb8a87f3ab6525eebf6b414b7fd3e7363aea6c Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Thu, 30 May 2024 11:00:05 +0800 Subject: [PATCH 11/20] undo thought reporter in the base role --- metagpt/roles/role.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index f6d26eeb1..1eaa77fa3 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -47,7 +47,6 @@ from metagpt.strategy.planner import Planner from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output -from metagpt.utils.report import ThoughtReporter if TYPE_CHECKING: from metagpt.environment import Environment # noqa: F401 @@ -382,8 +381,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel): n_states=len(self.states) - 1, previous_state=self.rc.state, ) - async with ThoughtReporter(): - next_state = await self.llm.aask(prompt) + 
+ next_state = await self.llm.aask(prompt) next_state = extract_state_value_from_output(next_state) logger.debug(f"{prompt=}") From 4f43b905a2aa886203291de33f0bb2301ffeccf5 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Thu, 30 May 2024 20:04:02 +0800 Subject: [PATCH 12/20] add crawler tools --- examples/di/crawl_webpage.py | 12 ++++++---- metagpt/rag/engines/simple.py | 4 +++- metagpt/tools/libs/browser.py | 45 +++++++++++++++++++++++++++++++---- metagpt/utils/file.py | 8 +++++++ metagpt/utils/parse_html.py | 25 ++++++++++++++++++- 5 files changed, 83 insertions(+), 11 deletions(-) diff --git a/examples/di/crawl_webpage.py b/examples/di/crawl_webpage.py index b8226f4f4..10b230f2b 100644 --- a/examples/di/crawl_webpage.py +++ b/examples/di/crawl_webpage.py @@ -6,16 +6,19 @@ """ from metagpt.roles.di.data_interpreter import DataInterpreter +from metagpt.tools.libs.browser import Browser as _ + PAPER_LIST_REQ = """" Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/, -and save it to a csv file. paper title must include `multiagent` or `large language model`. *notice: print key variables* +and save it to a csv file. paper title must include `multiagent` or `large language model`. +**Notice: view the page element before writing scraping code** """ ECOMMERCE_REQ = """ Get products data from website https://scrapeme.live/shop/ and save it as a csv file. -**Notice: Firstly parse the web page encoding and the text HTML structure; -The first page product name, price, product URL, and image URL must be saved in the csv;** +The first page product name, price, product URL, and image URL must be saved in the csv. +**Notice: view the page element before writing scraping code** """ NEWS_36KR_REQ = """从36kr创投平台https://pitchhub.36kr.com/financing-flash 所有初创企业融资的信息, **注意: 这是一个中文网站**; @@ -25,11 +28,12 @@ NEWS_36KR_REQ = """从36kr创投平台https://pitchhub.36kr.com/financing-flash 3. 
反思*快讯的html内容示例*中的规律, 设计正则匹配表达式来获取*`快讯`*的标题、链接、时间; 4. 筛选最近3天的初创企业融资*`快讯`*, 以list[dict]形式打印前5个。 5. 将全部结果存在本地csv中 +**Notice: view the page element before writing scraping code** """ async def main(): - di = DataInterpreter(tools=["scrape_web_playwright"]) + di = DataInterpreter(tools=["Browser"]) await di.run(ECOMMERCE_REQ) diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 5c5810308..623b3f350 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -4,6 +4,7 @@ import json import os from typing import Any, Optional, Union +from fsspec import AbstractFileSystem from llama_index.core import SimpleDirectoryReader, VectorStoreIndex from llama_index.core.callbacks.base import CallbackManager from llama_index.core.embeddings import BaseEmbedding @@ -83,6 +84,7 @@ class SimpleEngine(RetrieverQueryEngine): llm: LLM = None, retriever_configs: list[BaseRetrieverConfig] = None, ranker_configs: list[BaseRankerConfig] = None, + fs: Optional[AbstractFileSystem] = None, ) -> "SimpleEngine": """From docs. 
@@ -100,7 +102,7 @@ class SimpleEngine(RetrieverQueryEngine): if not input_dir and not input_files: raise ValueError("Must provide either `input_dir` or `input_files`.") - documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data() + documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files, fs=fs).load_data() cls._fix_document_metadata(documents) index = VectorStoreIndex.from_documents( diff --git a/metagpt/tools/libs/browser.py b/metagpt/tools/libs/browser.py index 7fde804fe..223434b3a 100644 --- a/metagpt/tools/libs/browser.py +++ b/metagpt/tools/libs/browser.py @@ -1,9 +1,12 @@ from __future__ import annotations +import contextlib from playwright.async_api import async_playwright - +from metagpt.utils.file import MemoryFileSystem +from uuid import uuid4 from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.tools.tool_registry import register_tool +from metagpt.utils.parse_html import simplify_html from metagpt.utils.report import BrowserReporter @@ -35,16 +38,48 @@ class Browser: print("Now on page ", url) await self._view() - async def open_new_page(self, url: str): + async def open_new_page(self, url: str, timeout: float = 30000): """open a new page in the browser and view the page""" async with self.reporter as reporter: page = await self.browser.new_page() await reporter.async_report(url, "url") - await page.goto(url) + await page.goto(url, timeout=timeout) self.pages[url] = page await self._set_current_page(page, url) await reporter.async_report(page, "page") + async def view_page_element_to_scrape(self, requirement: str, keep_links: bool = False) -> None: + """view the HTML content of current page to understand the structure. When executed, the content will be printed out + + Args: + requirement (str): Providing a clear and detailed requirement helps in focusing the inspection on the desired elements. + keep_links (bool): Whether to keep the hyperlinks in the HTML content. 
Set to True if links are required + """ + html = await self.current_page.content() + html = simplify_html(html, url=self.current_page.url, keep_links=keep_links) + mem_fs = MemoryFileSystem() + filename = f"{uuid4().hex}.html" + with mem_fs.open(filename, "w") as f: + f.write(html) + + with contextlib.suppress(Exception): + + from metagpt.rag.engines import SimpleEngine # avoid circular import + + # TODO make `from_docs` asynchronous + engine = SimpleEngine.from_docs(input_files=[filename], fs=mem_fs) + nodes = await engine.aretrieve(requirement) + html = "\n".join(i.text for i in nodes) + + mem_fs.rm_file(filename) + print(html) + + async def get_page_content(self) -> str: + """Get the HTML content of current page.""" + html = await self.current_page.content() + html_content = html.strip() + return html_content + async def switch_page(self, url: str): """switch to an opened page in the browser and view the page""" if url in self.pages: @@ -152,8 +187,8 @@ class Browser: async def _view(self, keep_len: int = 5000) -> str: """simulate human viewing the current page, return the visible text with links""" - visible_text_with_links = await self.current_page.evaluate(VIEW_CONTENT_JS) - print("The visible text and their links (if any): ", visible_text_with_links[:keep_len]) + # visible_text_with_links = await self.current_page.evaluate(VIEW_CONTENT_JS) + # print("The visible text and their links (if any): ", visible_text_with_links[:keep_len]) # html_content = await self._view_page_html(keep_len=keep_len) # print("The html content: ", html_content) diff --git a/metagpt/utils/file.py b/metagpt/utils/file.py index f62b44eb8..a8ed482d9 100644 --- a/metagpt/utils/file.py +++ b/metagpt/utils/file.py @@ -9,6 +9,7 @@ from pathlib import Path import aiofiles +from fsspec.implementations.memory import MemoryFileSystem as _MemoryFileSystem from metagpt.logs import logger from metagpt.utils.exceptions import handle_exception @@ -68,3 +69,10 @@ class File: content = b"".join(chunks) 
logger.debug(f"Successfully read file, the path of file: {file_path}") return content + + +class MemoryFileSystem(_MemoryFileSystem): + + @classmethod + def _strip_protocol(cls, path): + return super()._strip_protocol(str(path)) diff --git a/metagpt/utils/parse_html.py b/metagpt/utils/parse_html.py index 65aa3f236..3aac8ca6c 100644 --- a/metagpt/utils/parse_html.py +++ b/metagpt/utils/parse_html.py @@ -7,6 +7,8 @@ from urllib.parse import urljoin, urlparse from bs4 import BeautifulSoup from pydantic import BaseModel, PrivateAttr +import htmlmin + class WebPage(BaseModel): inner_text: str @@ -38,6 +40,22 @@ class WebPage(BaseModel): elif url.startswith(("http://", "https://")): yield urljoin(self.url, url) + def get_slim_soup(self, keep_links: bool = False): + soup = _get_soup(self.html) + keep_attrs = ["class"] + if keep_links: + keep_attrs.append("href") + + for i in soup.find_all(True): + for name in list(i.attrs): + if i[name] and name not in keep_attrs: + del i[name] + + for i in soup.find_all(["svg", "img", "video", "audio"]): + i.decompose() + + return soup + def get_html_content(page: str, base: str): soup = _get_soup(page) @@ -48,7 +66,12 @@ def get_html_content(page: str, base: str): def _get_soup(page: str): soup = BeautifulSoup(page, "html.parser") # https://stackoverflow.com/questions/1936466/how-to-scrape-only-visible-webpage-text-with-beautifulsoup - for s in soup(["style", "script", "[document]", "head", "title"]): + for s in soup(["style", "script", "[document]", "head", "title", "footer"]): s.extract() return soup + + +def simplify_html(html: str, url: str, keep_links: bool = False): + html = WebPage(inner_text="", html=html, url=url).get_slim_soup(keep_links).decode() + return htmlmin.minify(html, remove_comments=True, remove_empty_space=True) From aeef03e29c48b6ca64244ec949850129cb1b53d1 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Thu, 30 May 2024 21:01:35 +0800 Subject: [PATCH 13/20] update the requirements --- requirements.txt | 4 +++- 1 
file changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b40c69c9f..83a904156 100644 --- a/requirements.txt +++ b/requirements.txt @@ -71,4 +71,6 @@ dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation gymnasium==0.29.1 pylint~=3.0.3 -pygithub~=2.3 \ No newline at end of file +pygithub~=2.3 +htmlmin +fsspec From 9dc5212d4709733d63ffd0b9999660939f7369fd Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Fri, 31 May 2024 15:50:05 +0800 Subject: [PATCH 14/20] Add explanation for error suppression in the method --- metagpt/tools/libs/browser.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metagpt/tools/libs/browser.py b/metagpt/tools/libs/browser.py index 223434b3a..8d6daec11 100644 --- a/metagpt/tools/libs/browser.py +++ b/metagpt/tools/libs/browser.py @@ -62,6 +62,7 @@ class Browser: with mem_fs.open(filename, "w") as f: f.write(html) + # Since RAG is an optional optimization, if it fails, the simplified HTML can be used as a fallback. with contextlib.suppress(Exception): from metagpt.rag.engines import SimpleEngine # avoid circular import From f3b839847b76fcb939c97d3aebc86cf0f8c438fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Fri, 31 May 2024 17:06:46 +0800 Subject: [PATCH 15/20] feat: Implement Chapter 3 of RFC 236. --- metagpt/actions/design_api.py | 144 +++++++++++++++++++++++++- metagpt/actions/project_management.py | 49 ++++++++- metagpt/actions/write_prd.py | 119 +++++++++++++++++++-- metagpt/utils/common.py | 7 ++ 4 files changed, 302 insertions(+), 17 deletions(-) diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 2e84cc463..8bf11356a 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -10,8 +10,9 @@ @Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD. 
""" import json +import uuid from pathlib import Path -from typing import Optional +from typing import List, Optional from pydantic import BaseModel, Field @@ -27,6 +28,7 @@ from metagpt.actions.design_api_an import ( from metagpt.const import DATA_API_DESIGN_FILE_REPO, SEQ_FLOW_FILE_REPO from metagpt.logs import logger from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.utils.common import aread, awrite, to_markdown_code_block from metagpt.utils.mermaid import mermaid_to_file from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.report import DocsReporter, GalleryReporter @@ -51,7 +53,116 @@ class WriteDesign(Action): repo: Optional[ProjectRepo] = Field(default=None, exclude=True) input_args: Optional[BaseModel] = Field(default=None, exclude=True) - async def run(self, with_messages: Message, schema: str = None): + async def run( + self, + with_messages: List[Message] = None, + *, + user_requirement: str = "", + prd_filename: str = "", + exists_design_filename: str = "", + extra_info: str = "", + output_path: str = "", + **kwargs, + ) -> AIMessage: + """ + Write a system design. + + Args: + user_requirement (str): The user's requirements for the system design. + prd_filename (str, optional): The filename of the Product Requirement Document (PRD). + exists_design_filename (str, optional): The filename of the existing design document. + extra_info (str, optional): Additional information to be included in the system design. + output_path (str, optional): The output path where the system design should be saved. + + Returns: + AIMessage: An AIMessage object containing the system design. + + Example: + # Write a new system design. + >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info) + >>> print(result.content) + The design is balabala... 
+ + # Modify an exists system design. + >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> exists_design_filename = "/path/to/exists/design/filename" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, exists_design_filename=exists_design_filename) + >>> print(result.content) + The design is balabala... + + # Write a new system design with the given PRD(Product Requirement Document). + >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> prd_filename = "/path/to/prd/filename" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, prd_filename=prd_filename) + >>> print(result.content) + The design is balabala... + + # Modify an exists system design with the given PRD(Product Requirement Document). + >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> prd_filename = "/path/to/prd/filename" + >>> exists_design_filename = "/path/to/exists/design/filename" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, exists_design_filename=exists_design_filename, prd_filename=prd_filename) + >>> print(result.content) + The design is balabala... + + # Write a new system design and save to the directory. + >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> output_path = "/path/to/save/" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, output_path=output_path) + >>> print(result.content) + System Design filename: "/path/to/design/filename" + + # Modify an exists system design and save to the directory. 
+ >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> exists_design_filename = "/path/to/exists/design/filename" + >>> output_path = "/path/to/save/" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, exists_design_filename=exists_design_filename) + >>> print(result.content) + System Design filename: "/path/to/design/filename" + + # Write a new system design with the given PRD(Product Requirement Document) and save to the directory. + >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> prd_filename = "/path/to/prd/filename" + >>> output_path = "/path/to/save/" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, prd_filename=prd_filename) + >>> print(result.content) + System Design filename: "/path/to/design/filename" + + # Modify an exists system design with the given PRD(Product Requirement Document) and save to the directory. 
+ >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> prd_filename = "/path/to/prd/filename" + >>> exists_design_filename = "/path/to/exists/design/filename" + >>> output_path = "/path/to/save/" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, exists_design_filename=exists_design_filename, prd_filename=prd_filename) + >>> print(result.content) + System Design filename: "/path/to/design/filename" + """ + if not with_messages: + return await self._execute_api( + user_requirement=user_requirement, + prd_filename=prd_filename, + exists_design_filename=exists_design_filename, + extra_info=extra_info, + output_path=output_path, + ) + self.input_args = with_messages[0].instruct_content self.repo = ProjectRepo(self.input_args.project_path) changed_prds = self.input_args.changed_prd_filenames @@ -147,3 +258,32 @@ class WriteDesign(Action): image_path = pathname.parent / f"{pathname.name}.png" if image_path.exists(): await GalleryReporter().async_report(image_path, "path") + + async def _execute_api( + self, + user_requirement: str = "", + prd_filename: str = "", + exists_design_filename: str = "", + extra_info: str = "", + output_path: str = "", + ) -> AIMessage: + context = to_markdown_code_block(user_requirement) + if extra_info: + context = to_markdown_code_block(extra_info) + if prd_filename: + prd_content = await aread(filename=prd_filename) + context += to_markdown_code_block(prd_content) + if not exists_design_filename: + node = await self._new_system_design(context=context) + design = Document(content=node.instruct_content.model_dump_json()) + else: + old_design_content = await aread(filename=exists_design_filename) + design = await self._merge( + prd_doc=Document(content=context), system_design_doc=Document(content=old_design_content) + ) + + if not output_path: + return AIMessage(content=design.instruct_content.model_dump_json()) + output_filename = 
Path(output_path) / f"{uuid.uuid4().hex}.json" + await awrite(filename=output_filename, data=design.content) + return AIMessage(content=f'System Design filename: "{str(output_filename)}"') diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index 55356f58b..9880a10f3 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -12,7 +12,7 @@ import json from pathlib import Path -from typing import Optional +from typing import List, Optional from pydantic import BaseModel, Field @@ -20,7 +20,8 @@ from metagpt.actions.action import Action from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODE from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME from metagpt.logs import logger -from metagpt.schema import AIMessage, Document, Documents +from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.utils.common import aread, to_markdown_code_block from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.report import DocsReporter @@ -39,7 +40,39 @@ class WriteTasks(Action): repo: Optional[ProjectRepo] = Field(default=None, exclude=True) input_args: Optional[BaseModel] = Field(default=None, exclude=True) - async def run(self, with_messages): + async def run( + self, with_messages: List[Message] = None, *, user_requirement: str = "", design_filename: str = "", **kwargs + ) -> AIMessage: + """ + Write a project schedule given a project system design file. + + Args: + user_requirement (str, optional): A string specifying the user's requirements. Defaults to an empty string. + design_filename (str): The filename of the project system design file. Defaults to an empty string. + **kwargs: Additional keyword arguments. + + Returns: + AIMessage: The generated project schedule. + + Example: + # Write a new project schedule. 
+ >>> design_filename = "/path/to/design/filename" + >>> action = WriteTasks() + >>> result = await action.run(design_filename=design_filename) + >>> print(result.content) + The project schedule is balabala... + + # Write a new project schedule with the user requirement. + >>> design_filename = "/path/to/design/filename" + >>> user_requirement = "Your user requirements" + >>> action = WriteTasks() + >>> result = await action.run(design_filename=design_filename, user_requirement=user_requirement) + >>> print(result.content) + The project schedule is balabala... + """ + if not with_messages: + return await self._execute_api(user_requirement=user_requirement, design_filename=design_filename) + self.input_args = with_messages[0].instruct_content self.repo = ProjectRepo(self.input_args.project_path) changed_system_designs = self.input_args.changed_system_design_filenames @@ -99,7 +132,7 @@ class WriteTasks(Action): await reporter.async_report(self.repo.workdir / md.root_relative_path, "path") return task_doc - async def _run_new_tasks(self, context): + async def _run_new_tasks(self, context: str): node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema) return node @@ -121,3 +154,11 @@ class WriteTasks(Action): continue packages.add(pkg) await self.repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages)) + + async def _execute_api(self, user_requirement: str = "", design_filename: str = ""): + context = to_markdown_code_block(user_requirement) + if not design_filename: + content = await aread(filename=design_filename) + context += to_markdown_code_block(content) + node = await self._run_new_tasks(context) + return AIMessage(content=node.instruct_content.model_dump_json()) diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 3275619f7..0584a247f 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -14,8 +14,9 @@ from __future__ import annotations import json +import uuid from 
pathlib import Path -from typing import Optional +from typing import List, Optional from pydantic import BaseModel, Field @@ -37,7 +38,7 @@ from metagpt.const import ( ) from metagpt.logs import logger from metagpt.schema import AIMessage, Document, Documents, Message -from metagpt.utils.common import CodeParser +from metagpt.utils.common import CodeParser, aread, awrite, to_markdown_code_block from metagpt.utils.file_repository import FileRepository from metagpt.utils.mermaid import mermaid_to_file from metagpt.utils.project_repo import ProjectRepo @@ -73,8 +74,75 @@ class WritePRD(Action): repo: Optional[ProjectRepo] = Field(default=None, exclude=True) input_args: Optional[BaseModel] = Field(default=None, exclude=True) - async def run(self, with_messages, *args, **kwargs) -> Message: - """Run the action.""" + async def run( + self, + with_messages: List[Message] = None, + *, + user_requirement: str = "", + output_path: str = "", + exists_prd_filename: str = "", + extra_info: str = "", + **kwargs, + ) -> AIMessage: + """ + Write a Product Requirement Document. + + Args: + user_requirement (str): A string detailing the user's requirements. + output_path (str, optional): The file path where the output document should be saved. Defaults to "". + exists_prd_filename (str, optional): The file path of an existing Product Requirement Document to use as a reference. Defaults to "". + extra_info (str, optional): Additional information to include in the document. Defaults to "". + **kwargs: Additional keyword arguments. + + Returns: + AIMessage: The resulting message after generating the Product Requirement Document. + + Example: + # Write a new PRD(Product Requirement Document) + >>> user_requirement = "YOUR REQUIREMENTS" + >>> extra_info = "YOUR EXTRA INFO" + >>> write_prd = WritePRD() + >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info) + >>> print(result.content) + The PRD is about balabala... 
+ + # Modify a exists PRD(Product Requirement Document) + >>> user_requirement = "YOUR REQUIREMENTS" + >>> extra_info = "YOUR EXTRA INFO" + >>> exists_prd_filename = "/path/to/exists/prd_filename" + >>> write_prd = WritePRD() + >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, exists_prd_filename=exists_prd_filename) + >>> print(result.content) + The PRD is about balabala... + + # Write and save a new PRD(Product Requirement Document) to the directory. + >>> user_requirement = "YOUR REQUIREMENTS" + >>> extra_info = "YOUR EXTRA INFO" + >>> output_path = "/path/to/prd/directory/" + >>> write_prd = WritePRD() + >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, output_path=output_path) + >>> print(result.content) + PRD filename: "/path/to/prd/directory/213434ad.json" + + # Modify a exists PRD(Product Requirement Document) and save to the directory. + >>> user_requirement = "YOUR REQUIREMENTS" + >>> extra_info = "YOUR EXTRA INFO" + >>> exists_prd_filename = "/path/to/exists/prd_filename" + >>> output_path = "/path/to/prd/directory/" + >>> write_prd = WritePRD() + >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, exists_prd_filename=exists_prd_filename, output_path=output_path) + >>> print(result.content) + PRD filename: "/path/to/prd/directory/213434ad.json" + + """ + if not with_messages: + return await self._execute_api( + user_requirement=user_requirement, + output_path=output_path, + exists_prd_filename=exists_prd_filename, + extra_info=extra_info, + ) + self.input_args = with_messages[-1].instruct_content if not self.input_args: self.repo = ProjectRepo(self.config.project_path) @@ -110,6 +178,7 @@ class WritePRD(Action): else: logger.info(f"New requirement detected: {req.content}") await self._handle_new_requirement(req) + kvs = self.input_args.model_dump() kvs["changed_prd_filenames"] = [ str(self.repo.docs.prd.workdir / i) for i in 
list(self.repo.docs.prd.changed_files.keys()) @@ -143,16 +212,20 @@ class WritePRD(Action): send_to="Alex", # the name of Engineer ) + async def _new_prd(self, requirement: str) -> ActionNode: + project_name = self.project_name + context = CONTEXT_TEMPLATE.format(requirements=requirement, project_name=project_name) + exclude = [PROJECT_NAME.key] if project_name else [] + node = await WRITE_PRD_NODE.fill( + context=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema + ) # schema=schema + return node + async def _handle_new_requirement(self, req: Document) -> ActionOutput: """handle new requirement""" async with DocsReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "prd"}, "meta") - project_name = self.project_name - context = CONTEXT_TEMPLATE.format(requirements=req, project_name=project_name) - exclude = [PROJECT_NAME.key] if project_name else [] - node = await WRITE_PRD_NODE.fill( - context=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema - ) # schema=schema + node = await self._new_prd(req.content) await self._rename_workspace(node) new_prd_doc = await self.repo.docs.prd.save( filename=FileRepository.new_filename() + ".json", content=node.instruct_content.model_dump_json() @@ -223,4 +296,28 @@ class WritePRD(Action): ws_name = CodeParser.parse_str(block="Project Name", text=prd) if ws_name: self.project_name = ws_name - self.repo.git_repo.rename_root(self.project_name) + if self.repo: + self.repo.git_repo.rename_root(self.project_name) + + async def _execute_api( + self, user_requirement: str, output_path: str, exists_prd_filename: str, extra_info: str + ) -> AIMessage: + content = to_markdown_code_block(val=user_requirement) + if extra_info: + content += to_markdown_code_block(val=extra_info) + + req = Document(content=content) + if not exists_prd_filename: + node = await self._new_prd(requirement=req.content) + new_prd = Document(content=node.instruct_content.model_dump_json()) + else: + content = 
await aread(filename=exists_prd_filename) + old_prd = Document(content=content) + new_prd = await self._merge(req=req, related_doc=old_prd) + + if not output_path: + return AIMessage(content=new_prd.content) + + output_filename = Path(output_path) / f"{uuid.uuid4().hex}.json" + await awrite(filename=output_filename, data=new_prd.content) + return AIMessage(content=f'PRD filename: "{str(output_filename)}"') diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 6d40828af..9e9bb034c 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -667,6 +667,8 @@ def role_raise_decorator(func): @handle_exception async def aread(filename: str | Path, encoding="utf-8") -> str: """Read file asynchronously.""" + if not filename or not Path(filename).exists(): + return "" try: async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader: content = await reader.read() @@ -940,3 +942,8 @@ def get_markdown_code_block_type(filename: str) -> str: # Add more file extensions and corresponding code block types as needed } return types.get(ext, "") + + +def to_markdown_code_block(val: str, type_: str = "") -> str: + val = val.replace("```", "\\`\\`\\`") + return f"\n```{type_}\n{val}\n```\n" From ce3260038a2da5843039aaa0b58b2ef41a04ca88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Sat, 1 Jun 2024 18:47:40 +0800 Subject: [PATCH 16/20] feat: software action + api interface --- metagpt/actions/design_api.py | 14 ++- metagpt/actions/prepare_documents.py | 4 +- metagpt/actions/project_management.py | 7 +- metagpt/actions/write_prd.py | 15 ++- metagpt/context.py | 19 --- metagpt/environment/base_env.py | 6 +- metagpt/roles/product_manager.py | 3 +- metagpt/roles/qa_engineer.py | 2 +- metagpt/roles/role.py | 3 +- metagpt/software_company.py | 2 +- tests/conftest.py | 13 +- tests/metagpt/actions/test_design_api.py | 114 ++++++++++++++---- .../actions/test_project_management.py | 59 ++++++--- 
tests/metagpt/actions/test_write_code.py | 7 +- tests/metagpt/actions/test_write_prd.py | 77 ++++++++---- 15 files changed, 231 insertions(+), 114 deletions(-) diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 8bf11356a..981e1405a 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -8,6 +8,7 @@ 1. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name. 2. According to the design in Section 2.2.3.5.3 of RFC 135, add incremental iteration functionality. @Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD. +@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236. """ import json import uuid @@ -28,6 +29,7 @@ from metagpt.actions.design_api_an import ( from metagpt.const import DATA_API_DESIGN_FILE_REPO, SEQ_FLOW_FILE_REPO from metagpt.logs import logger from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.tools.tool_registry import register_tool from metagpt.utils.common import aread, awrite, to_markdown_code_block from metagpt.utils.mermaid import mermaid_to_file from metagpt.utils.project_repo import ProjectRepo @@ -42,6 +44,7 @@ NEW_REQ_TEMPLATE = """ """ +@register_tool(tags=["software development", "write system design"]) class WriteDesign(Action): name: str = "" i_context: Optional[str] = None @@ -163,7 +166,7 @@ class WriteDesign(Action): output_path=output_path, ) - self.input_args = with_messages[0].instruct_content + self.input_args = with_messages[-1].instruct_content self.repo = ProjectRepo(self.input_args.project_path) changed_prds = self.input_args.changed_prd_filenames changed_system_designs = [ @@ -283,7 +286,12 @@ class WriteDesign(Action): ) if not output_path: - return AIMessage(content=design.instruct_content.model_dump_json()) + return AIMessage(content=design.content) output_filename = Path(output_path) / f"{uuid.uuid4().hex}.json" await 
awrite(filename=output_filename, data=design.content) - return AIMessage(content=f'System Design filename: "{str(output_filename)}"') + kvs = {"changed_system_design_filenames": [output_filename]} + + return AIMessage( + content=f'System Design filename: "{str(output_filename)}"', + instruct_content=AIMessage.create_instruct_value(kvs=kvs), + ) diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index 89ebd59a3..393c483cc 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -46,8 +46,8 @@ class PrepareDocuments(Action): path = Path(self.config.project_path) if path.exists() and not self.config.inc: shutil.rmtree(path) - self.config.project_path = path - self.context.set_repo_dir(path) + self.context.kwargs.project_path = path + self.context.kwargs.inc = self.config.inc return ProjectRepo(path) async def run(self, with_messages, **kwargs): diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index 9880a10f3..b44bfb9f3 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -8,6 +8,7 @@ 1. Divide the context into three components: legacy code, unit test code, and console log. 2. Move the document storage operations related to WritePRD from the save operation of WriteDesign. 3. According to the design in Section 2.2.3.5.4 of RFC 135, add incremental iteration functionality. +@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236. 
""" import json @@ -21,6 +22,7 @@ from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODE from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME from metagpt.logs import logger from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.tools.tool_registry import register_tool from metagpt.utils.common import aread, to_markdown_code_block from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.report import DocsReporter @@ -34,6 +36,7 @@ NEW_REQ_TEMPLATE = """ """ +@register_tool(tags=["software development", "write a project schedule given a project system design file"]) class WriteTasks(Action): name: str = "CreateTasks" i_context: Optional[str] = None @@ -73,7 +76,7 @@ class WriteTasks(Action): if not with_messages: return await self._execute_api(user_requirement=user_requirement, design_filename=design_filename) - self.input_args = with_messages[0].instruct_content + self.input_args = with_messages[-1].instruct_content self.repo = ProjectRepo(self.input_args.project_path) changed_system_designs = self.input_args.changed_system_design_filenames changed_tasks = [str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys())] @@ -82,7 +85,7 @@ class WriteTasks(Action): # `docs/system_designs/`. for filename in changed_system_designs: task_doc = await self._update_tasks(filename=filename) - change_files.docs[filename] = task_doc + change_files.docs[str(self.repo.docs.task.workdir / task_doc.filename)] = task_doc # Rewrite the task files that have undergone changes based on the git head diff under `docs/tasks/`. for filename in changed_tasks: diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 0584a247f..de3bcde84 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -9,6 +9,7 @@ 2. According to the design in Section 2.2.3.5.2 of RFC 135, add incremental iteration functionality. 3. 
Move the document storage operations related to WritePRD from the save operation of WriteDesign. @Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD. +@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236. """ from __future__ import annotations @@ -38,6 +39,7 @@ from metagpt.const import ( ) from metagpt.logs import logger from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.tools.tool_registry import register_tool from metagpt.utils.common import CodeParser, aread, awrite, to_markdown_code_block from metagpt.utils.file_repository import FileRepository from metagpt.utils.mermaid import mermaid_to_file @@ -64,6 +66,7 @@ NEW_REQ_TEMPLATE = """ """ +@register_tool(tags=["software development", "write product requirement documents"]) class WritePRD(Action): """WritePRD deal with the following situations: 1. Bugfix: If the requirement is a bugfix, the bugfix document will be generated. @@ -145,11 +148,11 @@ class WritePRD(Action): self.input_args = with_messages[-1].instruct_content if not self.input_args: - self.repo = ProjectRepo(self.config.project_path) + self.repo = ProjectRepo(self.context.kwargs.project_path) await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[-1].content) self.input_args = AIMessage.create_instruct_value( kvs={ - "project_path": self.config.project_path, + "project_path": self.context.kwargs.project_path, "requirements_filename": str(self.repo.docs.workdir / REQUIREMENT_FILENAME), "prd_filenames": [str(self.repo.docs.prd.workdir / i) for i in self.repo.docs.prd.all_files], }, @@ -183,6 +186,9 @@ class WritePRD(Action): kvs["changed_prd_filenames"] = [ str(self.repo.docs.prd.workdir / i) for i in list(self.repo.docs.prd.changed_files.keys()) ] + kvs["project_path"] = str(self.repo.workdir) + kvs["requirements_filename"] = str(self.repo.docs.workdir / REQUIREMENT_FILENAME) + self.context.kwargs.project_path = str(self.repo.workdir) 
return AIMessage( content="PRD is completed. " + "\n".join( @@ -302,7 +308,7 @@ class WritePRD(Action): async def _execute_api( self, user_requirement: str, output_path: str, exists_prd_filename: str, extra_info: str ) -> AIMessage: - content = to_markdown_code_block(val=user_requirement) + content = to_markdown_code_block(val=user_requirement, type_="text") if extra_info: content += to_markdown_code_block(val=extra_info) @@ -320,4 +326,5 @@ class WritePRD(Action): output_filename = Path(output_path) / f"{uuid.uuid4().hex}.json" await awrite(filename=output_filename, data=new_prd.content) - return AIMessage(content=f'PRD filename: "{str(output_filename)}"') + kvs = AIMessage.create_instruct_value({"changed_prd_filenames": [str(output_filename)]}) + return AIMessage(content=f'PRD filename: "{str(output_filename)}"', instruct_content=kvs) diff --git a/metagpt/context.py b/metagpt/context.py index f1c3568d9..384e8da48 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -8,7 +8,6 @@ from __future__ import annotations import os -from pathlib import Path from typing import Any, Dict, Optional from pydantic import BaseModel, ConfigDict @@ -22,8 +21,6 @@ from metagpt.utils.cost_manager import ( FireworksCostManager, TokenCostManager, ) -from metagpt.utils.git_repository import GitRepository -from metagpt.utils.project_repo import ProjectRepo class AttrDict(BaseModel): @@ -66,9 +63,6 @@ class Context(BaseModel): kwargs: AttrDict = AttrDict() config: Config = Config.default() - repo: Optional[ProjectRepo] = None - git_repo: Optional[GitRepository] = None - src_workspace: Optional[Path] = None cost_manager: CostManager = CostManager() _llm: Optional[BaseLLM] = None @@ -80,11 +74,6 @@ class Context(BaseModel): # env.update({k: v for k, v in i.items() if isinstance(v, str)}) return env - def set_repo_dir(self, path: str | Path): - repo_path = Path(path) - self.git_repo = GitRepository(local_path=repo_path, auto_init=True) - self.repo = ProjectRepo(self.git_repo) - def 
_select_costmanager(self, llm_config: LLMConfig) -> CostManager: """Return a CostManager instance""" if llm_config.api_type == LLMType.FIREWORKS: @@ -117,7 +106,6 @@ class Context(BaseModel): Dict[str, Any]: A dictionary containing serialized data. """ return { - "workdir": str(self.repo.workdir) if self.repo else "", "kwargs": {k: v for k, v in self.kwargs.__dict__.items()}, "cost_manager": self.cost_manager.model_dump_json(), } @@ -130,13 +118,6 @@ class Context(BaseModel): """ if not serialized_data: return - workdir = serialized_data.get("workdir") - if workdir: - self.git_repo = GitRepository(local_path=workdir, auto_init=True) - self.repo = ProjectRepo(self.git_repo) - src_workspace = self.git_repo.workdir / self.git_repo.workdir.name - if src_workspace.exists(): - self.src_workspace = src_workspace kwargs = serialized_data.get("kwargs") if kwargs: for k, v in kwargs.items(): diff --git a/metagpt/environment/base_env.py b/metagpt/environment/base_env.py index fe1660fc6..5d6d3a286 100644 --- a/metagpt/environment/base_env.py +++ b/metagpt/environment/base_env.py @@ -22,6 +22,7 @@ from metagpt.logs import logger from metagpt.memory import Memory from metagpt.schema import Message from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to +from metagpt.utils.git_repository import GitRepository if TYPE_CHECKING: from metagpt.roles.role import Role # noqa: F401 @@ -243,8 +244,9 @@ class Environment(ExtEnv): self.member_addrs[obj] = addresses def archive(self, auto_archive=True): - if auto_archive and self.context.git_repo: - self.context.git_repo.archive() + if auto_archive and self.context.kwargs.get("project_path"): + git_repo = GitRepository(self.context.kwargs.project_path) + git_repo.archive() @classmethod def model_rebuild(cls, **kwargs): diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 58d8076ab..1f66758ea 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ 
-10,7 +10,7 @@ from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.prepare_documents import PrepareDocuments -from metagpt.roles.role import Role +from metagpt.roles.role import Role, RoleReactMode from metagpt.utils.common import any_to_name, any_to_str from metagpt.utils.git_repository import GitRepository @@ -35,6 +35,7 @@ class ProductManager(Role): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.enable_memory = False + self.rc.react_mode = RoleReactMode.BY_ORDER self.set_actions([PrepareDocuments(send_to=any_to_str(self)), WritePRD]) self._watch([UserRequirement, PrepareDocuments]) self.todo_action = any_to_name(WritePRD) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 4cc9eb9e2..fc8fa5353 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -94,7 +94,7 @@ class QaEngineer(Role): code_filename=context.code_doc.filename, test_filename=context.test_doc.filename, working_directory=str(self.repo.workdir), - additional_python_paths=[str(self.context.src_workspace)], + additional_python_paths=[str(self.repo.srcs.workdir)], ) self.publish_message( AIMessage(content=run_code_context.model_dump_json(), cause_by=WriteTest, send_to=MESSAGE_ROUTE_TO_SELF) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 5592841eb..344e1df5e 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -386,8 +386,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel): msg = response else: msg = AIMessage(content=response or "", cause_by=self.rc.todo, sent_from=self) - if self.enable_memory: - self.rc.memory.add(msg) + self.rc.memory.add(msg) return msg diff --git a/metagpt/software_company.py b/metagpt/software_company.py index 7f0c56388..2ea16f55f 100644 --- a/metagpt/software_company.py +++ b/metagpt/software_company.py @@ -68,7 +68,7 @@ def generate_repo( company.run_project(idea, send_to=any_to_str(ProductManager)) asyncio.run(company.run(n_round=n_round)) - 
return ctx.repo + return ctx.kwargs.get("project_path") @app.command("", help="Start a new project.") diff --git a/tests/conftest.py b/tests/conftest.py index f26ab2ef9..1f6661f7c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,6 @@ import logging import os import re import uuid -from pathlib import Path from typing import Callable import aiohttp.web @@ -23,7 +22,6 @@ from metagpt.context import Context as MetagptContext from metagpt.llm import LLM from metagpt.logs import logger from metagpt.utils.git_repository import GitRepository -from metagpt.utils.project_repo import ProjectRepo from tests.mock.mock_aiohttp import MockAioResponse from tests.mock.mock_curl_cffi import MockCurlCffiResponse from tests.mock.mock_httplib2 import MockHttplib2Response @@ -149,13 +147,14 @@ def loguru_caplog(caplog): @pytest.fixture(scope="function") def context(request): ctx = MetagptContext() - ctx.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}") - ctx.repo = ProjectRepo(ctx.git_repo) + repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}") + ctx.config.project_path = str(repo.workdir) # Destroy git repo at the end of the test session. def fin(): - if ctx.git_repo: - ctx.git_repo.delete_repository() + if ctx.config.project_path: + git_repo = GitRepository(ctx.config.project_path) + git_repo.delete_repository() # Register the function for destroying the environment. 
request.addfinalizer(fin) @@ -279,6 +278,6 @@ def mermaid_mocker(aiohttp_mocker, mermaid_rsp_cache): @pytest.fixture def git_dir(): """Fixture to get the unittest directory.""" - git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}" + git_dir = DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}" git_dir.mkdir(parents=True, exist_ok=True) return git_dir diff --git a/tests/metagpt/actions/test_design_api.py b/tests/metagpt/actions/test_design_api.py index 9924a2e84..1351b418a 100644 --- a/tests/metagpt/actions/test_design_api.py +++ b/tests/metagpt/actions/test_design_api.py @@ -6,37 +6,105 @@ @File : test_design_api.py @Modifiled By: mashenquan, 2023-12-6. According to RFC 135 """ +import json + import pytest from metagpt.actions.design_api import WriteDesign -from metagpt.llm import LLM +from metagpt.const import METAGPT_ROOT from metagpt.logs import logger -from metagpt.schema import Message +from metagpt.schema import AIMessage, Message +from metagpt.utils.project_repo import ProjectRepo from tests.data.incremental_dev_project.mock import DESIGN_SAMPLE, REFINED_PRD_JSON @pytest.mark.asyncio -async def test_design_api(context): - inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"] # PRD_SAMPLE - for prd in inputs: - await context.repo.docs.prd.save(filename="new_prd.txt", content=prd) +async def test_design(context): + # Mock new design env + prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。" + context.kwargs.project_path = context.config.project_path + context.kwargs.inc = False + filename = "prd.txt" + repo = ProjectRepo(context.kwargs.project_path) + await repo.docs.prd.save(filename=filename, content=prd) + kvs = { + "project_path": str(context.kwargs.project_path), + "changed_prd_filenames": [str(repo.docs.prd.workdir / filename)], + } + instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WritePRDOutput") - design_api = WriteDesign(context=context) - - result = await design_api.run(Message(content=prd, instruct_content=None)) - 
logger.info(result) - - assert result - - -@pytest.mark.asyncio -async def test_refined_design_api(context): - await context.repo.docs.prd.save(filename="1.txt", content=str(REFINED_PRD_JSON)) - await context.repo.docs.system_design.save(filename="1.txt", content=DESIGN_SAMPLE) - - design_api = WriteDesign(context=context, llm=LLM()) - - result = await design_api.run(Message(content="", instruct_content=None)) + design_api = WriteDesign(context=context) + result = await design_api.run([Message(content=prd, instruct_content=instruct_content)]) logger.info(result) - assert result + assert isinstance(result, AIMessage) + assert result.instruct_content + assert repo.docs.system_design.changed_files + + # Mock incremental design env + context.kwargs.inc = True + await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON)) + await repo.docs.system_design.save(filename=filename, content=DESIGN_SAMPLE) + + result = await design_api.run([Message(content="", instruct_content=instruct_content)]) + logger.info(result) + assert result + assert isinstance(result, AIMessage) + assert result.instruct_content + assert repo.docs.system_design.changed_files + + +@pytest.mark.parametrize( + ("user_requirement", "prd_filename", "exists_design_filename"), + [ + ("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None), + ("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None), + ( + "write 2048 game", + str(METAGPT_ROOT / "tests/data/prd.json"), + str(METAGPT_ROOT / "tests/data/system_design.json"), + ), + ], +) +@pytest.mark.asyncio +async def test_design_api(context, user_requirement, prd_filename, exists_design_filename): + action = WriteDesign() + result = await action.run( + user_requirement=user_requirement, prd_filename=prd_filename, exists_design_filename=exists_design_filename + ) + assert isinstance(result, AIMessage) + assert result.content + m = json.loads(result.content) + assert m + + +@pytest.mark.parametrize( + ("user_requirement", "prd_filename", 
"exists_design_filename"), + [ + ("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None), + ("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None), + ( + "write 2048 game", + str(METAGPT_ROOT / "tests/data/prd.json"), + str(METAGPT_ROOT / "tests/data/system_design.json"), + ), + ], +) +@pytest.mark.asyncio +async def test_design_api_dir(context, user_requirement, prd_filename, exists_design_filename): + action = WriteDesign() + result = await action.run( + user_requirement=user_requirement, + prd_filename=prd_filename, + exists_design_filename=exists_design_filename, + output_path=context.config.project_path, + ) + assert isinstance(result, AIMessage) + assert result.content + assert str(context.config.project_path) in result.content + assert result.instruct_content + assert result.instruct_content.changed_system_design_filenames + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_project_management.py b/tests/metagpt/actions/test_project_management.py index 5d0d11efb..26699dea7 100644 --- a/tests/metagpt/actions/test_project_management.py +++ b/tests/metagpt/actions/test_project_management.py @@ -5,13 +5,15 @@ @Author : alexanderwu @File : test_project_management.py """ +import json import pytest from metagpt.actions.project_management import WriteTasks -from metagpt.llm import LLM +from metagpt.const import METAGPT_ROOT from metagpt.logs import logger -from metagpt.schema import Message +from metagpt.schema import AIMessage, Message +from metagpt.utils.project_repo import ProjectRepo from tests.data.incremental_dev_project.mock import ( REFINED_DESIGN_JSON, REFINED_PRD_JSON, @@ -22,29 +24,46 @@ from tests.metagpt.actions.mock_json import DESIGN, PRD @pytest.mark.asyncio async def test_task(context): - await context.repo.docs.prd.save("1.txt", content=str(PRD)) - await context.repo.docs.system_design.save("1.txt", content=str(DESIGN)) - logger.info(context.git_repo) + # Mock write tasks env + 
context.kwargs.project_path = context.config.project_path + context.kwargs.inc = False + repo = ProjectRepo(context.kwargs.project_path) + filename = "1.txt" + await repo.docs.prd.save(filename=filename, content=str(PRD)) + await repo.docs.system_design.save(filename=filename, content=str(DESIGN)) + kvs = { + "project_path": context.kwargs.project_path, + "changed_system_design_filenames": [str(repo.docs.system_design.workdir / filename)], + } + instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WriteDesignOutput") action = WriteTasks(context=context) - - result = await action.run(Message(content="", instruct_content=None)) + result = await action.run([Message(content="", instruct_content=instruct_content)]) logger.info(result) - assert result + assert result.instruct_content.changed_task_filenames + + # Mock incremental env + context.kwargs.inc = True + await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON)) + await repo.docs.system_design.save(filename=filename, content=str(REFINED_DESIGN_JSON)) + await repo.docs.task.save(filename=filename, content=TASK_SAMPLE) + + result = await action.run([Message(content="", instruct_content=instruct_content)]) + logger.info(result) + assert result + assert result.instruct_content.changed_task_filenames @pytest.mark.asyncio -async def test_refined_task(context): - await context.repo.docs.prd.save("2.txt", content=str(REFINED_PRD_JSON)) - await context.repo.docs.system_design.save("2.txt", content=str(REFINED_DESIGN_JSON)) - await context.repo.docs.task.save("2.txt", content=TASK_SAMPLE) - - logger.info(context.git_repo) - - action = WriteTasks(context=context, llm=LLM()) - - result = await action.run(Message(content="", instruct_content=None)) - logger.info(result) - +async def test_task_api(context): + action = WriteTasks() + result = await action.run(design_filename=str(METAGPT_ROOT / "tests/data/system_design.json")) assert result + assert result.content + m = 
json.loads(result.content) + assert m + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_write_code.py b/tests/metagpt/actions/test_write_code.py index 42623f807..1c1772031 100644 --- a/tests/metagpt/actions/test_write_code.py +++ b/tests/metagpt/actions/test_write_code.py @@ -26,12 +26,7 @@ from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPL def setup_inc_workdir(context, inc: bool = False): """setup incremental workdir for testing""" - context.src_workspace = context.git_repo.workdir / "src" - if inc: - context.config.inc = inc - context.repo.old_workspace = context.repo.git_repo.workdir / "old" - context.config.project_path = "old" - + context.config.inc = inc return context diff --git a/tests/metagpt/actions/test_write_prd.py b/tests/metagpt/actions/test_write_prd.py index 43aa336b7..8cbc01716 100644 --- a/tests/metagpt/actions/test_write_prd.py +++ b/tests/metagpt/actions/test_write_prd.py @@ -6,6 +6,7 @@ @File : test_write_prd.py @Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, replace `handle` with `run`. 
""" +import json import pytest @@ -14,17 +15,16 @@ from metagpt.const import REQUIREMENT_FILENAME from metagpt.logs import logger from metagpt.roles.product_manager import ProductManager from metagpt.roles.role import RoleReactMode -from metagpt.schema import Message +from metagpt.schema import AIMessage, Message from metagpt.utils.common import any_to_str -from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE, PRD_SAMPLE -from tests.metagpt.actions.test_write_code import setup_inc_workdir +from metagpt.utils.project_repo import ProjectRepo +from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE @pytest.mark.asyncio async def test_write_prd(new_filename, context): product_manager = ProductManager(context=context) requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结" - await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements) product_manager.rc.react_mode = RoleReactMode.BY_ORDER prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement)) assert prd.cause_by == any_to_str(WritePRD) @@ -34,38 +34,39 @@ async def test_write_prd(new_filename, context): # Assert the prd is not None or empty assert prd is not None assert prd.content != "" - assert product_manager.context.repo.docs.prd.changed_files + repo = ProjectRepo(context.kwargs.project_path) + assert repo.docs.prd.changed_files + repo.git_repo.archive() - -@pytest.mark.asyncio -async def test_write_prd_inc(new_filename, context, git_dir): - context = setup_inc_workdir(context, inc=True) - await context.repo.docs.prd.save("1.txt", PRD_SAMPLE) - await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE) + # Mock incremental requirement + context.config.inc = True + context.config.project_path = context.kwargs.project_path + repo = ProjectRepo(context.config.project_path) + await repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE) action = 
WritePRD(context=context) - prd = await action.run(Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None)) + prd = await action.run([Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None)]) logger.info(NEW_REQUIREMENT_SAMPLE) logger.info(prd) # Assert the prd is not None or empty assert prd is not None assert prd.content != "" - assert "Refined Requirements" in prd.content + assert repo.git_repo.changed_files @pytest.mark.asyncio async def test_fix_debug(new_filename, context, git_dir): - context.src_workspace = context.git_repo.workdir / context.git_repo.workdir.name + # Mock legacy project + context.kwargs.project_path = str(git_dir) + repo = ProjectRepo(context.kwargs.project_path) + repo.with_src_path(git_dir.name) + await repo.srcs.save(filename="main.py", content='if __name__ == "__main__":\nmain()') + requirements = "ValueError: undefined variable `st`." + await repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements) - await context.repo.with_src_path(context.src_workspace).srcs.save( - filename="main.py", content='if __name__ == "__main__":\nmain()' - ) - requirements = "Please fix the bug in the code." 
- await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements) action = WritePRD(context=context) - - prd = await action.run(Message(content=requirements, instruct_content=None)) + prd = await action.run([Message(content=requirements, instruct_content=None)]) logger.info(prd) # Assert the prd is not None or empty @@ -73,5 +74,39 @@ async def test_fix_debug(new_filename, context, git_dir): assert prd.content != "" +@pytest.mark.asyncio +async def test_write_prd_api(context): + action = WritePRD() + result = await action.run(user_requirement="write a snake game.") + assert isinstance(result, AIMessage) + assert result.content + m = json.loads(result.content) + assert m + + result = await action.run(user_requirement="write a snake game.", output_path=str(context.config.project_path)) + assert isinstance(result, AIMessage) + assert result.content + assert result.instruct_content + assert str(context.config.project_path) in result.content + + legacy_prd_filename = result.instruct_content.changed_prd_filenames[-1] + + result = await action.run(user_requirement="Add moving enemy.", exists_prd_filename=legacy_prd_filename) + assert isinstance(result, AIMessage) + assert result.content + m = json.loads(result.content) + assert m + + result = await action.run( + user_requirement="Add moving enemy.", + output_path=str(context.config.project_path), + exists_prd_filename=legacy_prd_filename, + ) + assert isinstance(result, AIMessage) + assert result.content + assert result.instruct_content + assert str(context.config.project_path) in result.content + + if __name__ == "__main__": pytest.main([__file__, "-s"]) From 2a0107679e6e8bc1bfe582c3533f83543183cc47 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Mon, 3 Jun 2024 10:44:03 +0800 Subject: [PATCH 17/20] merge the newest rag in github --- metagpt/config2.py | 4 + metagpt/configs/embedding_config.py | 50 ++++++++++ metagpt/rag/engines/simple.py | 74 ++++++++++---- metagpt/rag/factories/base.py | 20 
++-- metagpt/rag/factories/embedding.py | 88 ++++++++++++++--- metagpt/rag/factories/index.py | 2 +- metagpt/rag/factories/llm.py | 7 +- metagpt/rag/factories/ranker.py | 24 +++++ metagpt/rag/factories/retriever.py | 89 +++++++++++++---- metagpt/rag/retrievers/bm25_retriever.py | 6 +- metagpt/rag/schema.py | 37 ++++++- metagpt/utils/async_helper.py | 15 +++ setup.py | 3 + tests/metagpt/rag/engines/test_simple.py | 24 ++--- tests/metagpt/rag/factories/test_base.py | 5 +- tests/metagpt/rag/factories/test_embedding.py | 97 +++++++++++++++---- tests/metagpt/rag/factories/test_retriever.py | 50 +++++++--- 17 files changed, 482 insertions(+), 113 deletions(-) create mode 100644 metagpt/configs/embedding_config.py diff --git a/metagpt/config2.py b/metagpt/config2.py index 8c61fdbf2..717fe63a9 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Literal, Optional from pydantic import BaseModel, model_validator from metagpt.configs.browser_config import BrowserConfig +from metagpt.configs.embedding_config import EmbeddingConfig from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.configs.mermaid_config import MermaidConfig from metagpt.configs.redis_config import RedisConfig @@ -48,6 +49,9 @@ class Config(CLIParams, YamlModel): # Key Parameters llm: LLMConfig + # RAG Embedding + embedding: EmbeddingConfig = EmbeddingConfig() + # Global Proxy. Not used by LLM, but by other tools such as browsers. 
proxy: str = "" diff --git a/metagpt/configs/embedding_config.py b/metagpt/configs/embedding_config.py new file mode 100644 index 000000000..20de47999 --- /dev/null +++ b/metagpt/configs/embedding_config.py @@ -0,0 +1,50 @@ +from enum import Enum +from typing import Optional + +from pydantic import field_validator + +from metagpt.utils.yaml_model import YamlModel + + +class EmbeddingType(Enum): + OPENAI = "openai" + AZURE = "azure" + GEMINI = "gemini" + OLLAMA = "ollama" + + +class EmbeddingConfig(YamlModel): + """Config for Embedding. + + Examples: + --------- + api_type: "openai" + api_key: "YOU_API_KEY" + + api_type: "azure" + api_key: "YOU_API_KEY" + base_url: "YOU_BASE_URL" + api_version: "YOU_API_VERSION" + + api_type: "gemini" + api_key: "YOU_API_KEY" + + api_type: "ollama" + base_url: "YOU_BASE_URL" + model: "YOU_MODEL" + """ + + api_type: Optional[EmbeddingType] = None + api_key: Optional[str] = None + base_url: Optional[str] = None + api_version: Optional[str] = None + + model: Optional[str] = None + embed_batch_size: Optional[int] = None + + @field_validator("api_type", mode="before") + @classmethod + def check_api_type(cls, v): + if v == "": + return None + return v diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 623b3f350..c237dcf69 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -4,8 +4,7 @@ import json import os from typing import Any, Optional, Union -from fsspec import AbstractFileSystem -from llama_index.core import SimpleDirectoryReader, VectorStoreIndex +from llama_index.core import SimpleDirectoryReader from llama_index.core.callbacks.base import CallbackManager from llama_index.core.embeddings import BaseEmbedding from llama_index.core.embeddings.mock_embed_model import MockEmbedding @@ -64,7 +63,7 @@ class SimpleEngine(RetrieverQueryEngine): response_synthesizer: Optional[BaseSynthesizer] = None, node_postprocessors: Optional[list[BaseNodePostprocessor]] = None, 
callback_manager: Optional[CallbackManager] = None, - index: Optional[BaseIndex] = None, + transformations: Optional[list[TransformComponent]] = None, ) -> None: super().__init__( retriever=retriever, @@ -72,7 +71,7 @@ class SimpleEngine(RetrieverQueryEngine): node_postprocessors=node_postprocessors, callback_manager=callback_manager, ) - self.index = index + self._transformations = transformations or self._default_transformations() @classmethod def from_docs( @@ -84,7 +83,6 @@ class SimpleEngine(RetrieverQueryEngine): llm: LLM = None, retriever_configs: list[BaseRetrieverConfig] = None, ranker_configs: list[BaseRankerConfig] = None, - fs: Optional[AbstractFileSystem] = None, ) -> "SimpleEngine": """From docs. @@ -102,15 +100,20 @@ class SimpleEngine(RetrieverQueryEngine): if not input_dir and not input_files: raise ValueError("Must provide either `input_dir` or `input_files`.") - documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files, fs=fs).load_data() + documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data() cls._fix_document_metadata(documents) - index = VectorStoreIndex.from_documents( - documents=documents, - transformations=transformations or [SentenceSplitter()], - embed_model=cls._resolve_embed_model(embed_model, retriever_configs), + transformations = transformations or cls._default_transformations() + nodes = run_transformations(documents, transformations=transformations) + + return cls._from_nodes( + nodes=nodes, + transformations=transformations, + embed_model=embed_model, + llm=llm, + retriever_configs=retriever_configs, + ranker_configs=ranker_configs, ) - return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) @classmethod def from_objs( @@ -139,12 +142,15 @@ class SimpleEngine(RetrieverQueryEngine): raise ValueError("In BM25RetrieverConfig, Objs must not be empty.") nodes = [ObjectNode(text=obj.rag_key(), 
metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] - index = VectorStoreIndex( + + return cls._from_nodes( nodes=nodes, - transformations=transformations or [SentenceSplitter()], - embed_model=cls._resolve_embed_model(embed_model, retriever_configs), + transformations=transformations, + embed_model=embed_model, + llm=llm, + retriever_configs=retriever_configs, + ranker_configs=ranker_configs, ) - return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) @classmethod def from_index( @@ -163,6 +169,13 @@ class SimpleEngine(RetrieverQueryEngine): """Inplement tools.SearchInterface""" return await self.aquery(content) + def retrieve(self, query: QueryType) -> list[NodeWithScore]: + query_bundle = QueryBundle(query) if isinstance(query, str) else query + + nodes = super().retrieve(query_bundle) + self._try_reconstruct_obj(nodes) + return nodes + async def aretrieve(self, query: QueryType) -> list[NodeWithScore]: """Allow query to be str.""" query_bundle = QueryBundle(query) if isinstance(query, str) else query @@ -178,7 +191,7 @@ class SimpleEngine(RetrieverQueryEngine): documents = SimpleDirectoryReader(input_files=input_files).load_data() self._fix_document_metadata(documents) - nodes = run_transformations(documents, transformations=self.index._transformations) + nodes = run_transformations(documents, transformations=self._transformations) self._save_nodes(nodes) def add_objs(self, objs: list[RAGObject]): @@ -194,6 +207,29 @@ class SimpleEngine(RetrieverQueryEngine): self._persist(str(persist_dir), **kwargs) + @classmethod + def _from_nodes( + cls, + nodes: list[BaseNode], + transformations: Optional[list[TransformComponent]] = None, + embed_model: BaseEmbedding = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, + ) -> "SimpleEngine": + embed_model = cls._resolve_embed_model(embed_model, retriever_configs) + llm = llm or 
get_rag_llm() + + retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model) + rankers = get_rankers(configs=ranker_configs, llm=llm) # Default [] + + return cls( + retriever=retriever, + node_postprocessors=rankers, + response_synthesizer=get_response_synthesizer(llm=llm), + transformations=transformations, + ) + @classmethod def _from_index( cls, @@ -203,6 +239,7 @@ class SimpleEngine(RetrieverQueryEngine): ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": llm = llm or get_rag_llm() + retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever rankers = get_rankers(configs=ranker_configs, llm=llm) # Default [] @@ -210,7 +247,6 @@ class SimpleEngine(RetrieverQueryEngine): retriever=retriever, node_postprocessors=rankers, response_synthesizer=get_response_synthesizer(llm=llm), - index=index, ) def _ensure_retriever_modifiable(self): @@ -261,3 +297,7 @@ class SimpleEngine(RetrieverQueryEngine): return MockEmbedding(embed_dim=1) return embed_model or get_rag_embedding() + + @staticmethod + def _default_transformations(): + return [SentenceSplitter()] diff --git a/metagpt/rag/factories/base.py b/metagpt/rag/factories/base.py index fbdfbf1a8..e58643efe 100644 --- a/metagpt/rag/factories/base.py +++ b/metagpt/rag/factories/base.py @@ -26,6 +26,9 @@ class GenericFactory: if creator: return creator(**kwargs) + self._raise_for_key(key) + + def _raise_for_key(self, key: Any): raise ValueError(f"Creator not registered for key: {key}") @@ -33,19 +36,26 @@ class ConfigBasedFactory(GenericFactory): """Designed to get objects based on object type.""" def get_instance(self, key: Any, **kwargs) -> Any: - """Key is config, such as a pydantic model. + """Get instance by the type of key. - Call func by the type of key, and the key will be passed to func. + Key is config, such as a pydantic model, call func by the type of key, and the key will be passed to func. + Raise Exception if key not found. 
""" creator = self._creators.get(type(key)) if creator: return creator(key, **kwargs) + self._raise_for_key(key) + + def _raise_for_key(self, key: Any): raise ValueError(f"Unknown config: `{type(key)}`, {key}") @staticmethod def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any: - """It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.""" + """It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs. + + Return None if not found. + """ if config is not None and hasattr(config, key): val = getattr(config, key) if val is not None: @@ -54,6 +64,4 @@ class ConfigBasedFactory(GenericFactory): if key in kwargs: return kwargs[key] - raise KeyError( - f"The key '{key}' is required but not provided in either configuration object or keyword arguments." - ) + return None diff --git a/metagpt/rag/factories/embedding.py b/metagpt/rag/factories/embedding.py index 4247db256..3613fd228 100644 --- a/metagpt/rag/factories/embedding.py +++ b/metagpt/rag/factories/embedding.py @@ -1,37 +1,103 @@ """RAG Embedding Factory.""" +from __future__ import annotations + +from typing import Any from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding +from llama_index.embeddings.gemini import GeminiEmbedding +from llama_index.embeddings.ollama import OllamaEmbedding from llama_index.embeddings.openai import OpenAIEmbedding from metagpt.config2 import config +from metagpt.configs.embedding_config import EmbeddingType from metagpt.configs.llm_config import LLMType from metagpt.rag.factories.base import GenericFactory class RAGEmbeddingFactory(GenericFactory): - """Create LlamaIndex Embedding with MetaGPT's config.""" + """Create LlamaIndex Embedding with MetaGPT's embedding config.""" def __init__(self): creators = { + EmbeddingType.OPENAI: self._create_openai, + EmbeddingType.AZURE: self._create_azure, + 
EmbeddingType.GEMINI: self._create_gemini, + EmbeddingType.OLLAMA: self._create_ollama, + # For backward compatibility LLMType.OPENAI: self._create_openai, LLMType.AZURE: self._create_azure, } super().__init__(creators) - def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding: - """Key is LLMType, default use config.llm.api_type.""" - return super().get_instance(key or config.llm.api_type) + def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding: + """Key is EmbeddingType.""" + return super().get_instance(key or self._resolve_embedding_type()) - def _create_openai(self): - return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url) + def _resolve_embedding_type(self) -> EmbeddingType | LLMType: + """Resolves the embedding type. - def _create_azure(self): - return AzureOpenAIEmbedding( - azure_endpoint=config.llm.base_url, - api_key=config.llm.api_key, - api_version=config.llm.api_version, + If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE. + Raise TypeError if embedding type not found. 
+ """ + if config.embedding.api_type: + return config.embedding.api_type + + if config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]: + return config.llm.api_type + + raise TypeError("To use RAG, please set your embedding in config2.yaml.") + + def _create_openai(self) -> OpenAIEmbedding: + params = dict( + api_key=config.embedding.api_key or config.llm.api_key, + api_base=config.embedding.base_url or config.llm.base_url, ) + self._try_set_model_and_batch_size(params) + + return OpenAIEmbedding(**params) + + def _create_azure(self) -> AzureOpenAIEmbedding: + params = dict( + api_key=config.embedding.api_key or config.llm.api_key, + azure_endpoint=config.embedding.base_url or config.llm.base_url, + api_version=config.embedding.api_version or config.llm.api_version, + ) + + self._try_set_model_and_batch_size(params) + + return AzureOpenAIEmbedding(**params) + + def _create_gemini(self) -> GeminiEmbedding: + params = dict( + api_key=config.embedding.api_key, + api_base=config.embedding.base_url, + ) + + self._try_set_model_and_batch_size(params) + + return GeminiEmbedding(**params) + + def _create_ollama(self) -> OllamaEmbedding: + params = dict( + base_url=config.embedding.base_url, + ) + + self._try_set_model_and_batch_size(params) + + return OllamaEmbedding(**params) + + def _try_set_model_and_batch_size(self, params: dict): + """Set the model_name and embed_batch_size only when they are specified.""" + if config.embedding.model: + params["model_name"] = config.embedding.model + + if config.embedding.embed_batch_size: + params["embed_batch_size"] = config.embedding.embed_batch_size + + def _raise_for_key(self, key: Any): + raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}") + get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index a56471359..f897af3ad 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -48,7 
+48,7 @@ class RAGIndexFactory(ConfigBasedFactory): def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex: db = chromadb.PersistentClient(str(config.persist_path)) - chroma_collection = db.get_or_create_collection(config.collection_name) + chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata) vector_store = ChromaVectorStore(chroma_collection=chroma_collection) return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs) diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py index 17c499b76..9fd19cab5 100644 --- a/metagpt/rag/factories/llm.py +++ b/metagpt/rag/factories/llm.py @@ -1,5 +1,5 @@ """RAG LLM.""" - +import asyncio from typing import Any from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW @@ -15,7 +15,7 @@ from pydantic import Field from metagpt.config2 import config from metagpt.llm import LLM from metagpt.provider.base_llm import BaseLLM -from metagpt.utils.async_helper import run_coroutine_in_new_loop +from metagpt.utils.async_helper import NestAsyncio from metagpt.utils.token_counter import TOKEN_MAX @@ -39,7 +39,8 @@ class RAGLLM(CustomLLM): @llm_completion_callback() def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: - return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs)) + NestAsyncio.apply_once() + return asyncio.get_event_loop().run_until_complete(self.acomplete(prompt, **kwargs)) @llm_completion_callback() async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse: diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index 476fe8c1a..7abda162a 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -8,6 +8,8 @@ from metagpt.rag.factories.base import ConfigBasedFactory from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor from metagpt.rag.schema import ( BaseRankerConfig, + 
BGERerankConfig, + CohereRerankConfig, ColbertRerankConfig, LLMRankerConfig, ObjectRankerConfig, @@ -22,6 +24,8 @@ class RankerFactory(ConfigBasedFactory): LLMRankerConfig: self._create_llm_ranker, ColbertRerankConfig: self._create_colbert_ranker, ObjectRankerConfig: self._create_object_ranker, + CohereRerankConfig: self._create_cohere_rerank, + BGERerankConfig: self._create_bge_rerank, } super().__init__(creators) @@ -45,6 +49,26 @@ class RankerFactory(ConfigBasedFactory): ) return ColbertRerank(**config.model_dump()) + def _create_cohere_rerank(self, config: CohereRerankConfig, **kwargs) -> LLMRerank: + try: + from llama_index.postprocessor.cohere_rerank import CohereRerank + except ImportError: + raise ImportError( + "`llama-index-postprocessor-cohere-rerank` package not found, please run `pip install llama-index-postprocessor-cohere-rerank`" + ) + return CohereRerank(**config.model_dump()) + + def _create_bge_rerank(self, config: BGERerankConfig, **kwargs) -> LLMRerank: + try: + from llama_index.postprocessor.flag_embedding_reranker import ( + FlagEmbeddingReranker, + ) + except ImportError: + raise ImportError( + "`llama-index-postprocessor-flag-embedding-reranker` package not found, please run `pip install llama-index-postprocessor-flag-embedding-reranker`" + ) + return FlagEmbeddingReranker(**config.model_dump()) + def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank: return ObjectSortPostprocessor(**config.model_dump()) diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index 65729002e..1460e131b 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -1,10 +1,13 @@ """RAG Retriever Factory.""" -import copy + +from functools import wraps import chromadb import faiss from llama_index.core import StorageContext, VectorStoreIndex +from llama_index.core.embeddings import BaseEmbedding +from llama_index.core.schema import BaseNode from 
llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.elasticsearch import ElasticsearchStore @@ -24,10 +27,25 @@ from metagpt.rag.schema import ( ElasticsearchKeywordRetrieverConfig, ElasticsearchRetrieverConfig, FAISSRetrieverConfig, - IndexRetrieverConfig, ) +def get_or_build_index(build_index_func): + """Decorator to get or build an index. + + Get index using `_extract_index` method, if not found, using build_index_func. + """ + + @wraps(build_index_func) + def wrapper(self, config, **kwargs): + index = self._extract_index(config, **kwargs) + if index is not None: + return index + return build_index_func(self, config, **kwargs) + + return wrapper + + class RetrieverFactory(ConfigBasedFactory): """Modify creators for dynamically instance implementation.""" @@ -54,48 +72,79 @@ class RetrieverFactory(ConfigBasedFactory): return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0] def _create_default(self, **kwargs) -> RAGRetriever: - return self._extract_index(**kwargs).as_retriever() + index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs) + + return index.as_retriever() def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever: - vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) - config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + config.index = self._build_faiss_index(config, **kwargs) return FAISSRetriever(**config.model_dump()) def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever: - config.index = copy.deepcopy(self._extract_index(config, **kwargs)) + index = self._extract_index(config, **kwargs) + nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs) - return 
DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump()) + return DynamicBM25Retriever(nodes=nodes, **config.model_dump()) def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever: - db = chromadb.PersistentClient(path=str(config.persist_path)) - chroma_collection = db.get_or_create_collection(config.collection_name) - - vector_store = ChromaVectorStore(chroma_collection=chroma_collection) - config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + config.index = self._build_chroma_index(config, **kwargs) return ChromaRetriever(**config.model_dump()) def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever: - vector_store = ElasticsearchStore(**config.store_config.model_dump()) - config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + config.index = self._build_es_index(config, **kwargs) return ElasticsearchRetriever(**config.model_dump()) def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex: return self._val_from_config_or_kwargs("index", config, **kwargs) + def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[BaseNode]: + return self._val_from_config_or_kwargs("nodes", config, **kwargs) + + def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding: + return self._val_from_config_or_kwargs("embed_model", config, **kwargs) + + def _build_default_index(self, **kwargs) -> VectorStoreIndex: + index = VectorStoreIndex( + nodes=self._extract_nodes(**kwargs), + embed_model=self._extract_embed_model(**kwargs), + ) + + return index + + @get_or_build_index + def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex: + vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) + + return self._build_index_from_vector_store(config, vector_store, 
**kwargs) + + @get_or_build_index + def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex: + db = chromadb.PersistentClient(path=str(config.persist_path)) + chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata) + vector_store = ChromaVectorStore(chroma_collection=chroma_collection) + + return self._build_index_from_vector_store(config, vector_store, **kwargs) + + @get_or_build_index + def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex: + vector_store = ElasticsearchStore(**config.store_config.model_dump()) + + return self._build_index_from_vector_store(config, vector_store, **kwargs) + def _build_index_from_vector_store( - self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs + self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs ) -> VectorStoreIndex: storage_context = StorageContext.from_defaults(vector_store=vector_store) - old_index = self._extract_index(config, **kwargs) - new_index = VectorStoreIndex( - nodes=list(old_index.docstore.docs.values()), + index = VectorStoreIndex( + nodes=self._extract_nodes(config, **kwargs), storage_context=storage_context, - embed_model=old_index._embed_model, + embed_model=self._extract_embed_model(config, **kwargs), ) - return new_index + + return index get_retriever = RetrieverFactory().get_retriever diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py index 241820cf4..3b085cb73 100644 --- a/metagpt/rag/retrievers/bm25_retriever.py +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -40,8 +40,10 @@ class DynamicBM25Retriever(BM25Retriever): self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes] self.bm25 = BM25Okapi(self._corpus) - self._index.insert_nodes(nodes, **kwargs) + if self._index: + self._index.insert_nodes(nodes, **kwargs) def persist(self, persist_dir: str, **kwargs) -> 
None: """Support persist.""" - self._index.storage_context.persist(persist_dir) + if self._index: + self._index.storage_context.persist(persist_dir) \ No newline at end of file diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 183f6e0c7..e7b2e5ce9 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -1,14 +1,17 @@ """RAG schemas.""" from pathlib import Path -from typing import Any, Literal, Union +from typing import Any, ClassVar, Literal, Optional, Union +from chromadb.api.types import CollectionMetadata from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.core.schema import TextNode from llama_index.core.vector_stores.types import VectorStoreQueryMode -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from metagpt.config2 import config +from metagpt.configs.embedding_config import EmbeddingType from metagpt.rag.interface import RAGObject @@ -31,7 +34,19 @@ class IndexRetrieverConfig(BaseRetrieverConfig): class FAISSRetrieverConfig(IndexRetrieverConfig): """Config for FAISS-based retrievers.""" - dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.") + dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.") + + _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = { + EmbeddingType.GEMINI: 768, + EmbeddingType.OLLAMA: 4096, + } + + @model_validator(mode="after") + def check_dimensions(self): + if self.dimensions == 0: + self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536) + + return self class BM25RetrieverConfig(IndexRetrieverConfig): @@ -45,6 +60,9 @@ class ChromaRetrieverConfig(IndexRetrieverConfig): persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.") 
collection_name: str = Field(default="metagpt", description="The name of the collection.") + metadata: Optional[CollectionMetadata] = Field( + default=None, description="Optional metadata to associate with the collection" + ) class ElasticsearchStoreConfig(BaseModel): @@ -101,6 +119,16 @@ class ColbertRerankConfig(BaseRankerConfig): keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.") +class CohereRerankConfig(BaseRankerConfig): + model: str = Field(default="rerank-english-v3.0") + api_key: str = Field(default="YOUR_COHERE_API") + + +class BGERerankConfig(BaseRankerConfig): + model: str = Field(default="BAAI/bge-reranker-large", description="BAAI Reranker model name.") + use_fp16: bool = Field(default=True, description="Whether to use fp16 for inference.") + + class ObjectRankerConfig(BaseRankerConfig): field_name: str = Field(..., description="field name of the object, field's value must can be compared.") order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.") @@ -130,6 +158,9 @@ class ChromaIndexConfig(VectorIndexConfig): """Config for chroma-based index.""" collection_name: str = Field(default="metagpt", description="The name of the collection.") + metadata: Optional[CollectionMetadata] = Field( + default=None, description="Optional metadata to associate with the collection" + ) class BM25IndexConfig(BaseIndexConfig): diff --git a/metagpt/utils/async_helper.py b/metagpt/utils/async_helper.py index ee440ef44..cecb20c5d 100644 --- a/metagpt/utils/async_helper.py +++ b/metagpt/utils/async_helper.py @@ -20,3 +20,18 @@ def run_coroutine_in_new_loop(coroutine) -> Any: new_loop.call_soon_threadsafe(new_loop.stop) t.join() new_loop.close() + + +class NestAsyncio: + """Make asyncio event loop reentrant.""" + + is_applied = False + + @classmethod + def apply_once(cls): + """Ensures `nest_asyncio.apply()` is called only once.""" + if not cls.is_applied: + import nest_asyncio 
+ + nest_asyncio.apply() + cls.is_applied = True diff --git a/setup.py b/setup.py index 382e13a47..79b65ad47 100644 --- a/setup.py +++ b/setup.py @@ -32,12 +32,15 @@ extras_require = { "llama-index-core==0.10.15", "llama-index-embeddings-azure-openai==0.1.6", "llama-index-embeddings-openai==0.1.5", + "llama-index-embeddings-gemini==0.1.6", + "llama-index-embeddings-ollama==0.1.2", "llama-index-llms-azure-openai==0.1.4", "llama-index-readers-file==0.1.4", "llama-index-retrievers-bm25==0.1.3", "llama-index-vector-stores-faiss==0.1.1", "llama-index-vector-stores-elasticsearch==0.1.6", "llama-index-vector-stores-chroma==0.1.6", + "docx2txt==0.8", ], } diff --git a/tests/metagpt/rag/engines/test_simple.py b/tests/metagpt/rag/engines/test_simple.py index 9262ccb07..8c7a15be2 100644 --- a/tests/metagpt/rag/engines/test_simple.py +++ b/tests/metagpt/rag/engines/test_simple.py @@ -25,10 +25,6 @@ class TestSimpleEngine: def mock_simple_directory_reader(self, mocker): return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader") - @pytest.fixture - def mock_vector_store_index(self, mocker): - return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents") - @pytest.fixture def mock_get_retriever(self, mocker): return mocker.patch("metagpt.rag.engines.simple.get_retriever") @@ -45,7 +41,6 @@ class TestSimpleEngine: self, mocker, mock_simple_directory_reader, - mock_vector_store_index, mock_get_retriever, mock_get_rankers, mock_get_response_synthesizer, @@ -81,11 +76,8 @@ class TestSimpleEngine: # Assert mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files) - mock_vector_store_index.assert_called_once() - mock_get_retriever.assert_called_once_with( - configs=retriever_configs, index=mock_vector_store_index.return_value - ) - mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm) + mock_get_retriever.assert_called_once() + mock_get_rankers.assert_called_once() 
mock_get_response_synthesizer.assert_called_once_with(llm=llm) assert isinstance(engine, SimpleEngine) @@ -119,7 +111,7 @@ class TestSimpleEngine: # Assert assert isinstance(engine, SimpleEngine) - assert engine.index is not None + assert engine._transformations is not None def test_from_objs_with_bm25_config(self): # Setup @@ -137,6 +129,7 @@ class TestSimpleEngine: def test_from_index(self, mocker, mock_llm, mock_embedding): # Mock mock_index = mocker.MagicMock(spec=VectorStoreIndex) + mock_index.as_retriever.return_value = "retriever" mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index") mock_get_index.return_value = mock_index @@ -149,7 +142,7 @@ class TestSimpleEngine: # Assert assert isinstance(engine, SimpleEngine) - assert engine.index is mock_index + assert engine._retriever == "retriever" @pytest.mark.asyncio async def test_asearch(self, mocker): @@ -200,14 +193,11 @@ class TestSimpleEngine: mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever) - mock_index = mocker.MagicMock(spec=VectorStoreIndex) - mock_index._transformations = mocker.MagicMock() - mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations") mock_run_transformations.return_value = ["node1", "node2"] # Setup - engine = SimpleEngine(retriever=mock_retriever, index=mock_index) + engine = SimpleEngine(retriever=mock_retriever) input_files = ["test_file1", "test_file2"] # Exec @@ -230,7 +220,7 @@ class TestSimpleEngine: return "" objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)] - engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock()) + engine = SimpleEngine(retriever=mock_retriever) # Exec engine.add_objs(objs=objs) diff --git a/tests/metagpt/rag/factories/test_base.py b/tests/metagpt/rag/factories/test_base.py index 1d41e1872..0b0a44976 100644 --- a/tests/metagpt/rag/factories/test_base.py +++ b/tests/metagpt/rag/factories/test_base.py @@ -97,6 +97,5 @@ class 
TestConfigBasedFactory: def test_val_from_config_or_kwargs_key_error(self): # Test KeyError when the key is not found in both config object and kwargs config = DummyConfig(name=None) - with pytest.raises(KeyError) as exc_info: - ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config) - assert "The key 'missing_key' is required but not provided" in str(exc_info.value) + val = ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config) + assert val is None diff --git a/tests/metagpt/rag/factories/test_embedding.py b/tests/metagpt/rag/factories/test_embedding.py index 1ded6b4a8..1a9e9b2c9 100644 --- a/tests/metagpt/rag/factories/test_embedding.py +++ b/tests/metagpt/rag/factories/test_embedding.py @@ -1,5 +1,6 @@ import pytest +from metagpt.configs.embedding_config import EmbeddingType from metagpt.configs.llm_config import LLMType from metagpt.rag.factories.embedding import RAGEmbeddingFactory @@ -10,30 +11,51 @@ class TestRAGEmbeddingFactory: self.embedding_factory = RAGEmbeddingFactory() @pytest.fixture - def mock_openai_embedding(self, mocker): + def mock_config(self, mocker): + return mocker.patch("metagpt.rag.factories.embedding.config") + + @staticmethod + def mock_openai_embedding(mocker): return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding") - @pytest.fixture - def mock_azure_embedding(self, mocker): + @staticmethod + def mock_azure_embedding(mocker): return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding") - def test_get_rag_embedding_openai(self, mock_openai_embedding): - # Exec - self.embedding_factory.get_rag_embedding(LLMType.OPENAI) + @staticmethod + def mock_gemini_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding") - # Assert - mock_openai_embedding.assert_called_once() + @staticmethod + def mock_ollama_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding") - def test_get_rag_embedding_azure(self, 
mock_azure_embedding): - # Exec - self.embedding_factory.get_rag_embedding(LLMType.AZURE) - - # Assert - mock_azure_embedding.assert_called_once() - - def test_get_rag_embedding_default(self, mocker, mock_openai_embedding): + @pytest.mark.parametrize( + ("mock_func", "embedding_type"), + [ + (mock_openai_embedding, LLMType.OPENAI), + (mock_azure_embedding, LLMType.AZURE), + (mock_openai_embedding, EmbeddingType.OPENAI), + (mock_azure_embedding, EmbeddingType.AZURE), + (mock_gemini_embedding, EmbeddingType.GEMINI), + (mock_ollama_embedding, EmbeddingType.OLLAMA), + ], + ) + def test_get_rag_embedding(self, mock_func, embedding_type, mocker): # Mock - mock_config = mocker.patch("metagpt.rag.factories.embedding.config") + mock = mock_func(mocker) + + # Exec + self.embedding_factory.get_rag_embedding(embedding_type) + + # Assert + mock.assert_called_once() + + def test_get_rag_embedding_default(self, mocker, mock_config): + # Mock + mock_openai_embedding = self.mock_openai_embedding(mocker) + + mock_config.embedding.api_type = None mock_config.llm.api_type = LLMType.OPENAI # Exec @@ -41,3 +63,44 @@ class TestRAGEmbeddingFactory: # Assert mock_openai_embedding.assert_called_once() + + @pytest.mark.parametrize( + "model, embed_batch_size, expected_params", + [("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})], + ) + def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params): + # Mock + mock_config.embedding.model = model + mock_config.embedding.embed_batch_size = embed_batch_size + + # Setup + test_params = {} + + # Exec + self.embedding_factory._try_set_model_and_batch_size(test_params) + + # Assert + assert test_params == expected_params + + def test_resolve_embedding_type(self, mock_config): + # Mock + mock_config.embedding.api_type = EmbeddingType.OPENAI + + # Exec + embedding_type = self.embedding_factory._resolve_embedding_type() + + # Assert + assert embedding_type == 
EmbeddingType.OPENAI + + def test_resolve_embedding_type_exception(self, mock_config): + # Mock + mock_config.embedding.api_type = None + mock_config.llm.api_type = LLMType.GEMINI + + # Assert + with pytest.raises(TypeError): + self.embedding_factory._resolve_embedding_type() + + def test_raise_for_key(self): + with pytest.raises(ValueError): + self.embedding_factory._raise_for_key("key") diff --git a/tests/metagpt/rag/factories/test_retriever.py b/tests/metagpt/rag/factories/test_retriever.py index ef1cef7e0..cd55a32db 100644 --- a/tests/metagpt/rag/factories/test_retriever.py +++ b/tests/metagpt/rag/factories/test_retriever.py @@ -1,6 +1,8 @@ import faiss import pytest from llama_index.core import VectorStoreIndex +from llama_index.core.embeddings import MockEmbedding +from llama_index.core.schema import TextNode from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.elasticsearch import ElasticsearchStore @@ -43,6 +45,14 @@ class TestRetrieverFactory: def mock_es_vector_store(self, mocker): return mocker.MagicMock(spec=ElasticsearchStore) + @pytest.fixture + def mock_nodes(self, mocker): + return [TextNode(text="msg")] + + @pytest.fixture + def mock_embedding(self): + return MockEmbedding(embed_dim=1) + def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index): mock_config = FAISSRetrieverConfig(dimensions=128) mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index) @@ -52,42 +62,40 @@ class TestRetrieverFactory: assert isinstance(retriever, FAISSRetriever) - def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index): + def test_get_retriever_with_bm25_config(self, mocker, mock_nodes): mock_config = BM25RetrieverConfig() mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_config]) 
+ retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=mock_nodes) assert isinstance(retriever, DynamicBM25Retriever) - def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index): - mock_faiss_config = FAISSRetrieverConfig(dimensions=128) + def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_nodes, mock_embedding): + mock_faiss_config = FAISSRetrieverConfig(dimensions=1) mock_bm25_config = BM25RetrieverConfig() mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config]) + retriever = self.retriever_factory.get_retriever( + configs=[mock_faiss_config, mock_bm25_config], nodes=mock_nodes, embed_model=mock_embedding + ) assert isinstance(retriever, SimpleHybridRetriever) - def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store): + def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store, mock_embedding): mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection") mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient") mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock() mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_config]) + retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding) assert isinstance(retriever, ChromaRetriever) - def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store): + def 
test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding): mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig()) mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_config]) + retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding) assert isinstance(retriever, ElasticsearchRetriever) @@ -111,3 +119,19 @@ class TestRetrieverFactory: extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index) assert extracted_index == mock_vector_store_index + + def test_get_or_build_when_get(self, mocker): + want = "existing_index" + mocker.patch.object(self.retriever_factory, "_extract_index", return_value=want) + + got = self.retriever_factory._build_es_index(None) + + assert got == want + + def test_get_or_build_when_build(self, mocker): + want = "call_build_es_index" + mocker.patch.object(self.retriever_factory, "_build_es_index", return_value=want) + + got = self.retriever_factory._build_es_index(None) + + assert got == want From b0bafc12d4ed5820dca3469056e2f8f78ef3913b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Tue, 4 Jun 2024 14:32:51 +0800 Subject: [PATCH 18/20] refactor: rollback code --- metagpt/roles/product_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 1f66758ea..d08933cb0 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -35,9 +35,9 @@ class ProductManager(Role): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.enable_memory = False - self.rc.react_mode = RoleReactMode.BY_ORDER 
self.set_actions([PrepareDocuments(send_to=any_to_str(self)), WritePRD]) self._watch([UserRequirement, PrepareDocuments]) + self.rc.react_mode = RoleReactMode.BY_ORDER self.todo_action = any_to_name(WritePRD) async def _think(self) -> bool: From 04e72c05976ec815dce60036a1b2a013a53b806a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Tue, 4 Jun 2024 15:47:03 +0800 Subject: [PATCH 19/20] =?UTF-8?q?refactor:=20`exists`=20=E6=94=B9=20`legac?= =?UTF-8?q?y`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- metagpt/actions/design_api.py | 28 ++++++++++++------------ metagpt/actions/write_prd.py | 20 ++++++++--------- tests/metagpt/actions/test_design_api.py | 12 +++++----- tests/metagpt/actions/test_write_prd.py | 4 ++-- 4 files changed, 32 insertions(+), 32 deletions(-) diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 981e1405a..b0a6a2861 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -62,7 +62,7 @@ class WriteDesign(Action): *, user_requirement: str = "", prd_filename: str = "", - exists_design_filename: str = "", + legacy_design_filename: str = "", extra_info: str = "", output_path: str = "", **kwargs, @@ -73,7 +73,7 @@ class WriteDesign(Action): Args: user_requirement (str): The user's requirements for the system design. prd_filename (str, optional): The filename of the Product Requirement Document (PRD). - exists_design_filename (str, optional): The filename of the existing design document. + legacy_design_filename (str, optional): The filename of the legacy design document. extra_info (str, optional): Additional information to be included in the system design. output_path (str, optional): The output path where the system design should be saved. @@ -92,9 +92,9 @@ class WriteDesign(Action): # Modify an exists system design. 
>>> user_requirement = "Your user requirements" >>> extra_info = "Your extra information" - >>> exists_design_filename = "/path/to/exists/design/filename" + >>> legacy_design_filename = "/path/to/exists/design/filename" >>> action = WriteDesign() - >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, exists_design_filename=exists_design_filename) + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename) >>> print(result.content) The design is balabala... @@ -111,9 +111,9 @@ class WriteDesign(Action): >>> user_requirement = "Your user requirements" >>> extra_info = "Your extra information" >>> prd_filename = "/path/to/prd/filename" - >>> exists_design_filename = "/path/to/exists/design/filename" + >>> legacy_design_filename = "/path/to/exists/design/filename" >>> action = WriteDesign() - >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, exists_design_filename=exists_design_filename, prd_filename=prd_filename) + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename, prd_filename=prd_filename) >>> print(result.content) The design is balabala... @@ -129,10 +129,10 @@ class WriteDesign(Action): # Modify an exists system design and save to the directory. 
>>> user_requirement = "Your user requirements" >>> extra_info = "Your extra information" - >>> exists_design_filename = "/path/to/exists/design/filename" + >>> legacy_design_filename = "/path/to/exists/design/filename" >>> output_path = "/path/to/save/" >>> action = WriteDesign() - >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, exists_design_filename=exists_design_filename) + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename) >>> print(result.content) System Design filename: "/path/to/design/filename" @@ -150,10 +150,10 @@ class WriteDesign(Action): >>> user_requirement = "Your user requirements" >>> extra_info = "Your extra information" >>> prd_filename = "/path/to/prd/filename" - >>> exists_design_filename = "/path/to/exists/design/filename" + >>> legacy_design_filename = "/path/to/exists/design/filename" >>> output_path = "/path/to/save/" >>> action = WriteDesign() - >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, exists_design_filename=exists_design_filename, prd_filename=prd_filename) + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename, prd_filename=prd_filename) >>> print(result.content) System Design filename: "/path/to/design/filename" """ @@ -161,7 +161,7 @@ class WriteDesign(Action): return await self._execute_api( user_requirement=user_requirement, prd_filename=prd_filename, - exists_design_filename=exists_design_filename, + legacy_design_filename=legacy_design_filename, extra_info=extra_info, output_path=output_path, ) @@ -266,7 +266,7 @@ class WriteDesign(Action): self, user_requirement: str = "", prd_filename: str = "", - exists_design_filename: str = "", + legacy_design_filename: str = "", extra_info: str = "", output_path: str = "", ) -> AIMessage: @@ -276,11 +276,11 @@ class WriteDesign(Action): if 
prd_filename: prd_content = await aread(filename=prd_filename) context += to_markdown_code_block(prd_content) - if not exists_design_filename: + if not legacy_design_filename: node = await self._new_system_design(context=context) design = Document(content=node.instruct_content.model_dump_json()) else: - old_design_content = await aread(filename=exists_design_filename) + old_design_content = await aread(filename=legacy_design_filename) design = await self._merge( prd_doc=Document(content=context), system_design_doc=Document(content=old_design_content) ) diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index de3bcde84..58db97c75 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -83,7 +83,7 @@ class WritePRD(Action): *, user_requirement: str = "", output_path: str = "", - exists_prd_filename: str = "", + legacy_prd_filename: str = "", extra_info: str = "", **kwargs, ) -> AIMessage: @@ -93,7 +93,7 @@ class WritePRD(Action): Args: user_requirement (str): A string detailing the user's requirements. output_path (str, optional): The file path where the output document should be saved. Defaults to "". - exists_prd_filename (str, optional): The file path of an existing Product Requirement Document to use as a reference. Defaults to "". + legacy_prd_filename (str, optional): The file path of the legacy Product Requirement Document to use as a reference. Defaults to "". extra_info (str, optional): Additional information to include in the document. Defaults to "". **kwargs: Additional keyword arguments. 
@@ -112,9 +112,9 @@ class WritePRD(Action): # Modify a exists PRD(Product Requirement Document) >>> user_requirement = "YOUR REQUIREMENTS" >>> extra_info = "YOUR EXTRA INFO" - >>> exists_prd_filename = "/path/to/exists/prd_filename" + >>> legacy_prd_filename = "/path/to/exists/prd_filename" >>> write_prd = WritePRD() - >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, exists_prd_filename=exists_prd_filename) + >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, legacy_prd_filename=legacy_prd_filename) >>> print(result.content) The PRD is about balabala... @@ -130,10 +130,10 @@ class WritePRD(Action): # Modify a exists PRD(Product Requirement Document) and save to the directory. >>> user_requirement = "YOUR REQUIREMENTS" >>> extra_info = "YOUR EXTRA INFO" - >>> exists_prd_filename = "/path/to/exists/prd_filename" + >>> legacy_prd_filename = "/path/to/exists/prd_filename" >>> output_path = "/path/to/prd/directory/" >>> write_prd = WritePRD() - >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, exists_prd_filename=exists_prd_filename, output_path=output_path) + >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, legacy_prd_filename=legacy_prd_filename, output_path=output_path) >>> print(result.content) PRD filename: "/path/to/prd/directory/213434ad.json" @@ -142,7 +142,7 @@ class WritePRD(Action): return await self._execute_api( user_requirement=user_requirement, output_path=output_path, - exists_prd_filename=exists_prd_filename, + legacy_prd_filename=legacy_prd_filename, extra_info=extra_info, ) @@ -306,18 +306,18 @@ class WritePRD(Action): self.repo.git_repo.rename_root(self.project_name) async def _execute_api( - self, user_requirement: str, output_path: str, exists_prd_filename: str, extra_info: str + self, user_requirement: str, output_path: str, legacy_prd_filename: str, extra_info: str ) -> AIMessage: 
content = to_markdown_code_block(val=user_requirement, type_="text") if extra_info: content += to_markdown_code_block(val=extra_info) req = Document(content=content) - if not exists_prd_filename: + if not legacy_prd_filename: node = await self._new_prd(requirement=req.content) new_prd = Document(content=node.instruct_content.model_dump_json()) else: - content = await aread(filename=exists_prd_filename) + content = await aread(filename=legacy_prd_filename) old_prd = Document(content=content) new_prd = await self._merge(req=req, related_doc=old_prd) diff --git a/tests/metagpt/actions/test_design_api.py b/tests/metagpt/actions/test_design_api.py index 1351b418a..0a792fb15 100644 --- a/tests/metagpt/actions/test_design_api.py +++ b/tests/metagpt/actions/test_design_api.py @@ -55,7 +55,7 @@ async def test_design(context): @pytest.mark.parametrize( - ("user_requirement", "prd_filename", "exists_design_filename"), + ("user_requirement", "prd_filename", "legacy_design_filename"), [ ("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None), ("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None), @@ -67,10 +67,10 @@ async def test_design(context): ], ) @pytest.mark.asyncio -async def test_design_api(context, user_requirement, prd_filename, exists_design_filename): +async def test_design_api(context, user_requirement, prd_filename, legacy_design_filename): action = WriteDesign() result = await action.run( - user_requirement=user_requirement, prd_filename=prd_filename, exists_design_filename=exists_design_filename + user_requirement=user_requirement, prd_filename=prd_filename, legacy_design_filename=legacy_design_filename ) assert isinstance(result, AIMessage) assert result.content @@ -79,7 +79,7 @@ async def test_design_api(context, user_requirement, prd_filename, exists_design @pytest.mark.parametrize( - ("user_requirement", "prd_filename", "exists_design_filename"), + ("user_requirement", "prd_filename", "legacy_design_filename"), [ 
("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None), ("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None), @@ -91,12 +91,12 @@ async def test_design_api(context, user_requirement, prd_filename, exists_design ], ) @pytest.mark.asyncio -async def test_design_api_dir(context, user_requirement, prd_filename, exists_design_filename): +async def test_design_api_dir(context, user_requirement, prd_filename, legacy_design_filename): action = WriteDesign() result = await action.run( user_requirement=user_requirement, prd_filename=prd_filename, - exists_design_filename=exists_design_filename, + legacy_design_filename=legacy_design_filename, output_path=context.config.project_path, ) assert isinstance(result, AIMessage) diff --git a/tests/metagpt/actions/test_write_prd.py b/tests/metagpt/actions/test_write_prd.py index 8cbc01716..93a1b150c 100644 --- a/tests/metagpt/actions/test_write_prd.py +++ b/tests/metagpt/actions/test_write_prd.py @@ -91,7 +91,7 @@ async def test_write_prd_api(context): legacy_prd_filename = result.instruct_content.changed_prd_filenames[-1] - result = await action.run(user_requirement="Add moving enemy.", exists_prd_filename=legacy_prd_filename) + result = await action.run(user_requirement="Add moving enemy.", legacy_prd_filename=legacy_prd_filename) assert isinstance(result, AIMessage) assert result.content m = json.loads(result.content) @@ -100,7 +100,7 @@ async def test_write_prd_api(context): result = await action.run( user_requirement="Add moving enemy.", output_path=str(context.config.project_path), - exists_prd_filename=legacy_prd_filename, + legacy_prd_filename=legacy_prd_filename, ) assert isinstance(result, AIMessage) assert result.content From 4359abd84a3a53ffc3862ed9421f83a6e3a66ae6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Wed, 5 Jun 2024 12:11:45 +0800 Subject: [PATCH 20/20] =?UTF-8?q?feat:=20=E4=BF=AE=E6=94=B9output=5Fpath?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit --- metagpt/actions/design_api.py | 69 +++++++++++++----------- metagpt/actions/write_prd.py | 47 ++++++++-------- metagpt/tools/libs/browser.py | 7 +-- metagpt/utils/common.py | 2 + metagpt/utils/file.py | 1 - metagpt/utils/parse_html.py | 3 +- tests/metagpt/actions/test_design_api.py | 9 ++-- tests/metagpt/actions/test_write_prd.py | 18 ++++--- 8 files changed, 83 insertions(+), 73 deletions(-) diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index b0a6a2861..981dde53f 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -26,7 +26,11 @@ from metagpt.actions.design_api_an import ( REFINED_DESIGN_NODE, REFINED_PROGRAM_CALL_FLOW, ) -from metagpt.const import DATA_API_DESIGN_FILE_REPO, SEQ_FLOW_FILE_REPO +from metagpt.const import ( + DATA_API_DESIGN_FILE_REPO, + DEFAULT_WORKSPACE_ROOT, + SEQ_FLOW_FILE_REPO, +) from metagpt.logs import logger from metagpt.schema import AIMessage, Document, Documents, Message from metagpt.tools.tool_registry import register_tool @@ -64,7 +68,7 @@ class WriteDesign(Action): prd_filename: str = "", legacy_design_filename: str = "", extra_info: str = "", - output_path: str = "", + output_pathname: str = "", **kwargs, ) -> AIMessage: """ @@ -75,7 +79,7 @@ class WriteDesign(Action): prd_filename (str, optional): The filename of the Product Requirement Document (PRD). legacy_design_filename (str, optional): The filename of the legacy design document. extra_info (str, optional): Additional information to be included in the system design. - output_path (str, optional): The output path where the system design should be saved. + output_pathname (str, optional): The output path name of file that the system design should be saved to. Returns: AIMessage: An AIMessage object containing the system design. 
@@ -87,7 +91,7 @@ class WriteDesign(Action): >>> action = WriteDesign() >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info) >>> print(result.content) - The design is balabala... + System Design filename: "/path/to/design/filename" # Modify an exists system design. >>> user_requirement = "Your user requirements" @@ -96,7 +100,7 @@ class WriteDesign(Action): >>> action = WriteDesign() >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename) >>> print(result.content) - The design is balabala... + System Design filename: "/path/to/design/filename" # Write a new system design with the given PRD(Product Requirement Document). >>> user_requirement = "Your user requirements" @@ -105,7 +109,7 @@ class WriteDesign(Action): >>> action = WriteDesign() >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, prd_filename=prd_filename) >>> print(result.content) - The design is balabala... + System Design filename: "/path/to/design/filename" # Modify an exists system design with the given PRD(Product Requirement Document). >>> user_requirement = "Your user requirements" @@ -115,45 +119,45 @@ class WriteDesign(Action): >>> action = WriteDesign() >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename, prd_filename=prd_filename) >>> print(result.content) - The design is balabala... + TSystem Design filename: "/path/to/design/filename" - # Write a new system design and save to the directory. + # Write a new system design and save to the path name. 
>>> user_requirement = "Your user requirements" >>> extra_info = "Your extra information" - >>> output_path = "/path/to/save/" + >>> output_pathname = "/path/to/design/filename" >>> action = WriteDesign() - >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, output_path=output_path) + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, output_pathname=output_pathname) >>> print(result.content) System Design filename: "/path/to/design/filename" - # Modify an exists system design and save to the directory. + # Modify an exists system design and save to the path name. >>> user_requirement = "Your user requirements" >>> extra_info = "Your extra information" >>> legacy_design_filename = "/path/to/exists/design/filename" - >>> output_path = "/path/to/save/" + >>> output_pathname = "/path/to/design/filename" >>> action = WriteDesign() - >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename) + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename, output_pathname=output_pathname) >>> print(result.content) System Design filename: "/path/to/design/filename" - # Write a new system design with the given PRD(Product Requirement Document) and save to the directory. + # Write a new system design with the given PRD(Product Requirement Document) and save to the path name. 
>>> user_requirement = "Your user requirements" >>> extra_info = "Your extra information" >>> prd_filename = "/path/to/prd/filename" - >>> output_path = "/path/to/save/" + >>> output_pathname = "/path/to/design/filename" >>> action = WriteDesign() - >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, prd_filename=prd_filename) + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, prd_filename=prd_filename, output_pathname=output_pathname) >>> print(result.content) System Design filename: "/path/to/design/filename" - # Modify an exists system design with the given PRD(Product Requirement Document) and save to the directory. + # Modify an exists system design with the given PRD(Product Requirement Document) and save to the path name. >>> user_requirement = "Your user requirements" >>> extra_info = "Your extra information" >>> prd_filename = "/path/to/prd/filename" >>> legacy_design_filename = "/path/to/exists/design/filename" - >>> output_path = "/path/to/save/" + >>> output_pathname = "/path/to/design/filename" >>> action = WriteDesign() - >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename, prd_filename=prd_filename) + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename, prd_filename=prd_filename, output_pathname=output_pathname) >>> print(result.content) System Design filename: "/path/to/design/filename" """ @@ -163,7 +167,7 @@ class WriteDesign(Action): prd_filename=prd_filename, legacy_design_filename=legacy_design_filename, extra_info=extra_info, - output_path=output_path, + output_pathname=output_pathname, ) self.input_args = with_messages[-1].instruct_content @@ -268,14 +272,16 @@ class WriteDesign(Action): prd_filename: str = "", legacy_design_filename: str = "", extra_info: str = "", - output_path: str = "", + 
output_pathname: str = "", ) -> AIMessage: - context = to_markdown_code_block(user_requirement) - if extra_info: - context = to_markdown_code_block(extra_info) + prd_content = "" if prd_filename: prd_content = await aread(filename=prd_filename) - context += to_markdown_code_block(prd_content) + context = "### User Requirements\n{user_requirement}\n### Extra_info\n{extra_info}\n### PRD\n{prd}\n".format( + user_requirement=to_markdown_code_block(user_requirement), + extra_info=to_markdown_code_block(extra_info), + prd=to_markdown_code_block(prd_content), + ) if not legacy_design_filename: node = await self._new_system_design(context=context) design = Document(content=node.instruct_content.model_dump_json()) @@ -285,13 +291,14 @@ class WriteDesign(Action): prd_doc=Document(content=context), system_design_doc=Document(content=old_design_content) ) - if not output_path: - return AIMessage(content=design.content) - output_filename = Path(output_path) / f"{uuid.uuid4().hex}.json" - await awrite(filename=output_filename, data=design.content) - kvs = {"changed_system_design_filenames": [output_filename]} + if not output_pathname: + output_path = DEFAULT_WORKSPACE_ROOT + output_path.mkdir(parents=True, exist_ok=True) + output_pathname = Path(output_path) / f"{uuid.uuid4().hex}.json" + await awrite(filename=output_pathname, data=design.content) + kvs = {"changed_system_design_filenames": [output_pathname]} return AIMessage( - content=f'System Design filename: "{str(output_filename)}"', + content=f'System Design filename: "{str(output_pathname)}"', instruct_content=AIMessage.create_instruct_value(kvs=kvs), ) diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 58db97c75..5ed45cab4 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -35,6 +35,7 @@ from metagpt.actions.write_prd_an import ( from metagpt.const import ( BUGFIX_FILENAME, COMPETITIVE_ANALYSIS_FILE_REPO, + DEFAULT_WORKSPACE_ROOT, REQUIREMENT_FILENAME, ) from 
metagpt.logs import logger @@ -82,7 +83,7 @@ class WritePRD(Action): with_messages: List[Message] = None, *, user_requirement: str = "", - output_path: str = "", + output_pathname: str = "", legacy_prd_filename: str = "", extra_info: str = "", **kwargs, @@ -92,7 +93,7 @@ class WritePRD(Action): Args: user_requirement (str): A string detailing the user's requirements. - output_path (str, optional): The file path where the output document should be saved. Defaults to "". + output_pathname (str, optional): The path name of file that the output document should be saved to. Defaults to "". legacy_prd_filename (str, optional): The file path of the legacy Product Requirement Document to use as a reference. Defaults to "". extra_info (str, optional): Additional information to include in the document. Defaults to "". **kwargs: Additional keyword arguments. @@ -107,7 +108,7 @@ class WritePRD(Action): >>> write_prd = WritePRD() >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info) >>> print(result.content) - The PRD is about balabala... + PRD filename: "/path/to/prd/directory/213434ad.json" # Modify a exists PRD(Product Requirement Document) >>> user_requirement = "YOUR REQUIREMENTS" @@ -116,24 +117,24 @@ class WritePRD(Action): >>> write_prd = WritePRD() >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, legacy_prd_filename=legacy_prd_filename) >>> print(result.content) - The PRD is about balabala... + PRD filename: "/path/to/prd/directory/213434ad.json" - # Write and save a new PRD(Product Requirement Document) to the directory. + # Write and save a new PRD(Product Requirement Document) to the path name. 
>>> user_requirement = "YOUR REQUIREMENTS" >>> extra_info = "YOUR EXTRA INFO" - >>> output_path = "/path/to/prd/directory/" + >>> output_pathname = "/path/to/prd/directory/213434ad.json" >>> write_prd = WritePRD() - >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, output_path=output_path) + >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, output_pathname=output_pathname) >>> print(result.content) PRD filename: "/path/to/prd/directory/213434ad.json" - # Modify a exists PRD(Product Requirement Document) and save to the directory. + # Modify a exists PRD(Product Requirement Document) and save to the path name. >>> user_requirement = "YOUR REQUIREMENTS" >>> extra_info = "YOUR EXTRA INFO" >>> legacy_prd_filename = "/path/to/exists/prd_filename" - >>> output_path = "/path/to/prd/directory/" + >>> output_pathname = "/path/to/prd/directory/213434ad.json" >>> write_prd = WritePRD() - >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, legacy_prd_filename=legacy_prd_filename, output_path=output_path) + >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, legacy_prd_filename=legacy_prd_filename, output_pathname=output_pathname) >>> print(result.content) PRD filename: "/path/to/prd/directory/213434ad.json" @@ -141,7 +142,7 @@ class WritePRD(Action): if not with_messages: return await self._execute_api( user_requirement=user_requirement, - output_path=output_path, + output_pathname=output_pathname, legacy_prd_filename=legacy_prd_filename, extra_info=extra_info, ) @@ -306,12 +307,12 @@ class WritePRD(Action): self.repo.git_repo.rename_root(self.project_name) async def _execute_api( - self, user_requirement: str, output_path: str, legacy_prd_filename: str, extra_info: str + self, user_requirement: str, output_pathname: str, legacy_prd_filename: str, extra_info: str ) -> AIMessage: - content = 
to_markdown_code_block(val=user_requirement, type_="text") - if extra_info: - content += to_markdown_code_block(val=extra_info) - + content = "#### User Requirements\n{user_requirement}\n#### Extra Info\n{extra_info}\n".format( + user_requirement=to_markdown_code_block(val=user_requirement), + extra_info=to_markdown_code_block(val=extra_info), + ) req = Document(content=content) if not legacy_prd_filename: node = await self._new_prd(requirement=req.content) @@ -321,10 +322,10 @@ class WritePRD(Action): old_prd = Document(content=content) new_prd = await self._merge(req=req, related_doc=old_prd) - if not output_path: - return AIMessage(content=new_prd.content) - - output_filename = Path(output_path) / f"{uuid.uuid4().hex}.json" - await awrite(filename=output_filename, data=new_prd.content) - kvs = AIMessage.create_instruct_value({"changed_prd_filenames": [str(output_filename)]}) - return AIMessage(content=f'PRD filename: "{str(output_filename)}"', instruct_content=kvs) + if not output_pathname: + output_path = DEFAULT_WORKSPACE_ROOT + output_path.mkdir(parents=True, exist_ok=True) + output_pathname = Path(output_path) / f"{uuid.uuid4().hex}.json" + await awrite(filename=output_pathname, data=new_prd.content) + kvs = AIMessage.create_instruct_value({"changed_prd_filenames": [str(output_pathname)]}) + return AIMessage(content=f'PRD filename: "{str(output_pathname)}"', instruct_content=kvs) diff --git a/metagpt/tools/libs/browser.py b/metagpt/tools/libs/browser.py index 8d6daec11..955058ea0 100644 --- a/metagpt/tools/libs/browser.py +++ b/metagpt/tools/libs/browser.py @@ -1,11 +1,13 @@ from __future__ import annotations + import contextlib +from uuid import uuid4 from playwright.async_api import async_playwright -from metagpt.utils.file import MemoryFileSystem -from uuid import uuid4 + from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.tools.tool_registry import register_tool +from metagpt.utils.file import MemoryFileSystem from metagpt.utils.parse_html 
import simplify_html from metagpt.utils.report import BrowserReporter @@ -64,7 +66,6 @@ class Browser: # Since RAG is an optional optimization, if it fails, the simplified HTML can be used as a fallback. with contextlib.suppress(Exception): - from metagpt.rag.engines import SimpleEngine # avoid circular import # TODO make `from_docs` asynchronous diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 7303d1f47..28bfe623e 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -945,5 +945,7 @@ def get_markdown_code_block_type(filename: str) -> str: def to_markdown_code_block(val: str, type_: str = "") -> str: + if not val: + return val or "" val = val.replace("```", "\\`\\`\\`") return f"\n```{type_}\n{val}\n```\n" diff --git a/metagpt/utils/file.py b/metagpt/utils/file.py index a8ed482d9..8861f65dc 100644 --- a/metagpt/utils/file.py +++ b/metagpt/utils/file.py @@ -72,7 +72,6 @@ class File: class MemoryFileSystem(_MemoryFileSystem): - @classmethod def _strip_protocol(cls, path): return super()._strip_protocol(str(path)) diff --git a/metagpt/utils/parse_html.py b/metagpt/utils/parse_html.py index 3aac8ca6c..1ed3a620c 100644 --- a/metagpt/utils/parse_html.py +++ b/metagpt/utils/parse_html.py @@ -4,11 +4,10 @@ from __future__ import annotations from typing import Generator, Optional from urllib.parse import urljoin, urlparse +import htmlmin from bs4 import BeautifulSoup from pydantic import BaseModel, PrivateAttr -import htmlmin - class WebPage(BaseModel): inner_text: str diff --git a/tests/metagpt/actions/test_design_api.py b/tests/metagpt/actions/test_design_api.py index 0a792fb15..3398be5e6 100644 --- a/tests/metagpt/actions/test_design_api.py +++ b/tests/metagpt/actions/test_design_api.py @@ -6,12 +6,12 @@ @File : test_design_api.py @Modifiled By: mashenquan, 2023-12-6. 
According to RFC 135 """ -import json +from pathlib import Path import pytest from metagpt.actions.design_api import WriteDesign -from metagpt.const import METAGPT_ROOT +from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT from metagpt.logs import logger from metagpt.schema import AIMessage, Message from metagpt.utils.project_repo import ProjectRepo @@ -74,8 +74,7 @@ async def test_design_api(context, user_requirement, prd_filename, legacy_design ) assert isinstance(result, AIMessage) assert result.content - m = json.loads(result.content) - assert m + assert str(DEFAULT_WORKSPACE_ROOT) in result.content @pytest.mark.parametrize( @@ -97,7 +96,7 @@ async def test_design_api_dir(context, user_requirement, prd_filename, legacy_de user_requirement=user_requirement, prd_filename=prd_filename, legacy_design_filename=legacy_design_filename, - output_path=context.config.project_path, + output_pathname=str(Path(context.config.project_path) / "1.txt"), ) assert isinstance(result, AIMessage) assert result.content diff --git a/tests/metagpt/actions/test_write_prd.py b/tests/metagpt/actions/test_write_prd.py index 93a1b150c..6cf1da1dc 100644 --- a/tests/metagpt/actions/test_write_prd.py +++ b/tests/metagpt/actions/test_write_prd.py @@ -6,12 +6,13 @@ @File : test_write_prd.py @Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, replace `handle` with `run`. 
""" -import json +import uuid +from pathlib import Path import pytest from metagpt.actions import UserRequirement, WritePRD -from metagpt.const import REQUIREMENT_FILENAME +from metagpt.const import DEFAULT_WORKSPACE_ROOT, REQUIREMENT_FILENAME from metagpt.logs import logger from metagpt.roles.product_manager import ProductManager from metagpt.roles.role import RoleReactMode @@ -80,10 +81,12 @@ async def test_write_prd_api(context): result = await action.run(user_requirement="write a snake game.") assert isinstance(result, AIMessage) assert result.content - m = json.loads(result.content) - assert m + assert str(DEFAULT_WORKSPACE_ROOT) in result.content - result = await action.run(user_requirement="write a snake game.", output_path=str(context.config.project_path)) + result = await action.run( + user_requirement="write a snake game.", + output_pathname=str(Path(context.config.project_path) / f"{uuid.uuid4().hex}.json"), + ) assert isinstance(result, AIMessage) assert result.content assert result.instruct_content @@ -94,12 +97,11 @@ async def test_write_prd_api(context): result = await action.run(user_requirement="Add moving enemy.", legacy_prd_filename=legacy_prd_filename) assert isinstance(result, AIMessage) assert result.content - m = json.loads(result.content) - assert m + assert str(DEFAULT_WORKSPACE_ROOT) in result.content result = await action.run( user_requirement="Add moving enemy.", - output_path=str(context.config.project_path), + output_pathname=str(Path(context.config.project_path) / f"{uuid.uuid4().hex}.json"), legacy_prd_filename=legacy_prd_filename, ) assert isinstance(result, AIMessage)