diff --git a/config/vault.example.yaml b/config/vault.example.yaml new file mode 100644 index 000000000..0e197d2a8 --- /dev/null +++ b/config/vault.example.yaml @@ -0,0 +1,48 @@ +# Usage: +# 1. Get value. +# >>> from metagpt.tools.libs.env import get_env +# >>> access_token = await get_env(key="access_token", app_name="github") +# >>> print(access_token) +# YOUR_ACCESS_TOKEN +# +# 2. Get description for LLM understanding. +# >>> from metagpt.tools.libs.env import get_env_description +# >>> descriptions = await get_env_description() +# >>> for k, desc in descriptions.items(): +# >>> print(f"{k}:{desc}") +# await get_env(key="access_token", app_name="github"):Get github access token +# await get_env(key="access_token", app_name="gitlab"):Get gitlab access token +# ... + +vault: + github: + values: + access_token: "YOUR_ACCESS_TOKEN" + descriptions: + access_token: "Get github access token" + gitlab: + values: + access_token: "YOUR_ACCESS_TOKEN" + descriptions: + access_token: "Get gitlab access token" + iflytek_tts: + values: + api_id: "YOUR_APP_ID" + api_key: "YOUR_API_KEY" + api_secret: "YOUR_API_SECRET" + descriptions: + api_id: "Get the API ID of IFlyTek Text to Speech" + api_key: "Get the API KEY of IFlyTek Text to Speech" + api_secret: "Get the API SECRET of IFlyTek Text to Speech" + azure_tts: + values: + subscription_key: "YOUR_SUBSCRIPTION_KEY" + region: "YOUR_REGION" + descriptions: + subscription_key: "Get the subscription key of Azure Text to Speech." + region: "Get the region of Azure Text to Speech." + default: # All key-value pairs whose app name is an empty string are placed below + values: + proxy: "YOUR_PROXY" + descriptions: + proxy: "Get proxy for tools like requests, playwright, selenium, etc." 
\ No newline at end of file diff --git a/examples/di/crawl_webpage.py b/examples/di/crawl_webpage.py index b8226f4f4..10b230f2b 100644 --- a/examples/di/crawl_webpage.py +++ b/examples/di/crawl_webpage.py @@ -6,16 +6,19 @@ """ from metagpt.roles.di.data_interpreter import DataInterpreter +from metagpt.tools.libs.browser import Browser as _ + PAPER_LIST_REQ = """" Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/, -and save it to a csv file. paper title must include `multiagent` or `large language model`. *notice: print key variables* +and save it to a csv file. paper title must include `multiagent` or `large language model`. +**Notice: view the page element before writing scraping code** """ ECOMMERCE_REQ = """ Get products data from website https://scrapeme.live/shop/ and save it as a csv file. -**Notice: Firstly parse the web page encoding and the text HTML structure; -The first page product name, price, product URL, and image URL must be saved in the csv;** +The first page product name, price, product URL, and image URL must be saved in the csv. +**Notice: view the page element before writing scraping code** """ NEWS_36KR_REQ = """从36kr创投平台https://pitchhub.36kr.com/financing-flash 所有初创企业融资的信息, **注意: 这是一个中文网站**; @@ -25,11 +28,12 @@ NEWS_36KR_REQ = """从36kr创投平台https://pitchhub.36kr.com/financing-flash 3. 反思*快讯的html内容示例*中的规律, 设计正则匹配表达式来获取*`快讯`*的标题、链接、时间; 4. 筛选最近3天的初创企业融资*`快讯`*, 以list[dict]形式打印前5个。 5. 
将全部结果存在本地csv中 +**Notice: view the page element before writing scraping code** """ async def main(): - di = DataInterpreter(tools=["scrape_web_playwright"]) + di = DataInterpreter(tools=["Browser"]) await di.run(ECOMMERCE_REQ) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 5fd538720..b760c96d8 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -22,7 +22,6 @@ from metagpt.schema import ( SerializationMixin, TestingContext, ) -from metagpt.utils.project_repo import ProjectRepo class Action(SerializationMixin, ContextMixin, BaseModel): @@ -36,12 +35,6 @@ class Action(SerializationMixin, ContextMixin, BaseModel): desc: str = "" # for skill manager node: ActionNode = Field(default=None, exclude=True) - @property - def repo(self) -> ProjectRepo: - if not self.context.repo: - self.context.repo = ProjectRepo(self.context.git_repo) - return self.context.repo - @property def prompt_schema(self): return self.config.prompt_schema diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index b027616f7..8f0f52266 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -9,13 +9,15 @@ 2. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name. 
""" import re +from typing import Optional -from pydantic import Field +from pydantic import BaseModel, Field from metagpt.actions.action import Action from metagpt.logs import logger from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.common import CodeParser +from metagpt.utils.project_repo import ProjectRepo PROMPT_TEMPLATE = """ NOTICE @@ -47,6 +49,8 @@ Now you should start rewriting the code: class DebugError(Action): i_context: RunCodeContext = Field(default_factory=RunCodeContext) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) async def run(self, *args, **kwargs) -> str: output_doc = await self.repo.test_outputs.get(filename=self.i_context.output_filename) @@ -59,9 +63,7 @@ class DebugError(Action): return "" logger.info(f"Debug and rewrite {self.i_context.test_filename}") - code_doc = await self.repo.with_src_path(self.context.src_workspace).srcs.get( - filename=self.i_context.code_filename - ) + code_doc = await self.repo.srcs.get(filename=self.i_context.code_filename) if not code_doc: return "" test_doc = await self.repo.tests.get(filename=self.i_context.test_filename) diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 613c4a47b..981dde53f 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -8,10 +8,14 @@ 1. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name. 2. According to the design in Section 2.2.3.5.3 of RFC 135, add incremental iteration functionality. @Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD. +@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236. 
""" import json +import uuid from pathlib import Path -from typing import Optional +from typing import List, Optional + +from pydantic import BaseModel, Field from metagpt.actions import Action from metagpt.actions.design_api_an import ( @@ -22,10 +26,17 @@ from metagpt.actions.design_api_an import ( REFINED_DESIGN_NODE, REFINED_PROGRAM_CALL_FLOW, ) -from metagpt.const import DATA_API_DESIGN_FILE_REPO, SEQ_FLOW_FILE_REPO +from metagpt.const import ( + DATA_API_DESIGN_FILE_REPO, + DEFAULT_WORKSPACE_ROOT, + SEQ_FLOW_FILE_REPO, +) from metagpt.logs import logger from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import aread, awrite, to_markdown_code_block from metagpt.utils.mermaid import mermaid_to_file +from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.report import DocsReporter, GalleryReporter NEW_REQ_TEMPLATE = """ @@ -37,6 +48,7 @@ NEW_REQ_TEMPLATE = """ """ +@register_tool(tags=["software development", "write system design"]) class WriteDesign(Action): name: str = "" i_context: Optional[str] = None @@ -45,21 +57,134 @@ class WriteDesign(Action): "data structures, library tables, processes, and paths. Please provide your design, feedback " "clearly and in detail." ) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) - async def run(self, with_messages: Message, schema: str = None): - # Use `git status` to identify which PRD documents have been modified in the `docs/prd` directory. - changed_prds = self.repo.docs.prd.changed_files - # Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone - # changes. 
- changed_system_designs = self.repo.docs.system_design.changed_files + async def run( + self, + with_messages: List[Message] = None, + *, + user_requirement: str = "", + prd_filename: str = "", + legacy_design_filename: str = "", + extra_info: str = "", + output_pathname: str = "", + **kwargs, + ) -> AIMessage: + """ + Write a system design. + + Args: + user_requirement (str): The user's requirements for the system design. + prd_filename (str, optional): The filename of the Product Requirement Document (PRD). + legacy_design_filename (str, optional): The filename of the legacy design document. + extra_info (str, optional): Additional information to be included in the system design. + output_pathname (str, optional): The output path name of file that the system design should be saved to. + + Returns: + AIMessage: An AIMessage object containing the system design. + + Example: + # Write a new system design. + >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info) + >>> print(result.content) + System Design filename: "/path/to/design/filename" + + # Modify an existing system design. + >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> legacy_design_filename = "/path/to/exists/design/filename" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename) + >>> print(result.content) + System Design filename: "/path/to/design/filename" + + # Write a new system design with the given PRD(Product Requirement Document). 
+ >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> prd_filename = "/path/to/prd/filename" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, prd_filename=prd_filename) + >>> print(result.content) + System Design filename: "/path/to/design/filename" + + # Modify an existing system design with the given PRD(Product Requirement Document). + >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> prd_filename = "/path/to/prd/filename" + >>> legacy_design_filename = "/path/to/exists/design/filename" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename, prd_filename=prd_filename) + >>> print(result.content) + System Design filename: "/path/to/design/filename" + + # Write a new system design and save to the path name. + >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> output_pathname = "/path/to/design/filename" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, output_pathname=output_pathname) + >>> print(result.content) + System Design filename: "/path/to/design/filename" + + # Modify an existing system design and save to the path name. 
+ >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> legacy_design_filename = "/path/to/exists/design/filename" + >>> output_pathname = "/path/to/design/filename" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename, output_pathname=output_pathname) + >>> print(result.content) + System Design filename: "/path/to/design/filename" + + # Write a new system design with the given PRD(Product Requirement Document) and save to the path name. + >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> prd_filename = "/path/to/prd/filename" + >>> output_pathname = "/path/to/design/filename" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, prd_filename=prd_filename, output_pathname=output_pathname) + >>> print(result.content) + System Design filename: "/path/to/design/filename" + + # Modify an existing system design with the given PRD(Product Requirement Document) and save to the path name. 
+ >>> user_requirement = "Your user requirements" + >>> extra_info = "Your extra information" + >>> prd_filename = "/path/to/prd/filename" + >>> legacy_design_filename = "/path/to/exists/design/filename" + >>> output_pathname = "/path/to/design/filename" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename, prd_filename=prd_filename, output_pathname=output_pathname) + >>> print(result.content) + System Design filename: "/path/to/design/filename" + """ + if not with_messages: + return await self._execute_api( + user_requirement=user_requirement, + prd_filename=prd_filename, + legacy_design_filename=legacy_design_filename, + extra_info=extra_info, + output_pathname=output_pathname, + ) + + self.input_args = with_messages[-1].instruct_content + self.repo = ProjectRepo(self.input_args.project_path) + changed_prds = self.input_args.changed_prd_filenames + changed_system_designs = [ + str(self.repo.docs.system_design.workdir / i) + for i in list(self.repo.docs.system_design.changed_files.keys()) + ] # For those PRDs and design documents that have undergone changes, regenerate the design content. changed_files = Documents() - for filename in changed_prds.keys(): + for filename in changed_prds: doc = await self._update_system_design(filename=filename) changed_files.docs[filename] = doc - for filename in changed_system_designs.keys(): + for filename in changed_system_designs: if filename in changed_files.docs: continue doc = await self._update_system_design(filename=filename) @@ -68,6 +193,11 @@ class WriteDesign(Action): logger.info("Nothing has changed.") # Wait until all files under `docs/system_designs/` are processed before sending the publish message, # leaving room for global optimization in subsequent steps. 
+ kvs = self.input_args.model_dump() + kvs["changed_system_design_filenames"] = [ + str(self.repo.docs.system_design.workdir / i) + for i in list(self.repo.docs.system_design.changed_files.keys()) + ] return AIMessage( content="Designing is complete. " + "\n".join( @@ -75,6 +205,7 @@ class WriteDesign(Action): + list(self.repo.resources.data_api_design.changed_files.keys()) + list(self.repo.resources.seq_flow.changed_files.keys()) ), + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteDesignOutput"), cause_by=self, ) @@ -89,14 +220,15 @@ class WriteDesign(Action): return system_design_doc async def _update_system_design(self, filename) -> Document: - prd = await self.repo.docs.prd.get(filename) - old_system_design_doc = await self.repo.docs.system_design.get(filename) + root_relative_path = Path(filename).relative_to(self.repo.workdir) + prd = await Document.load(filename=filename, project_path=self.repo.workdir) + old_system_design_doc = await self.repo.docs.system_design.get(root_relative_path.name) async with DocsReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "design"}, "meta") if not old_system_design_doc: system_design = await self._new_system_design(context=prd.content) doc = await self.repo.docs.system_design.save( - filename=filename, + filename=prd.filename, content=system_design.instruct_content.model_dump_json(), dependencies={prd.root_relative_path}, ) @@ -133,3 +265,40 @@ class WriteDesign(Action): image_path = pathname.parent / f"{pathname.name}.png" if image_path.exists(): await GalleryReporter().async_report(image_path, "path") + + async def _execute_api( + self, + user_requirement: str = "", + prd_filename: str = "", + legacy_design_filename: str = "", + extra_info: str = "", + output_pathname: str = "", + ) -> AIMessage: + prd_content = "" + if prd_filename: + prd_content = await aread(filename=prd_filename) + context = "### User Requirements\n{user_requirement}\n### 
Extra_info\n{extra_info}\n### PRD\n{prd}\n".format( + user_requirement=to_markdown_code_block(user_requirement), + extra_info=to_markdown_code_block(extra_info), + prd=to_markdown_code_block(prd_content), + ) + if not legacy_design_filename: + node = await self._new_system_design(context=context) + design = Document(content=node.instruct_content.model_dump_json()) + else: + old_design_content = await aread(filename=legacy_design_filename) + design = await self._merge( + prd_doc=Document(content=context), system_design_doc=Document(content=old_design_content) + ) + + if not output_pathname: + output_path = DEFAULT_WORKSPACE_ROOT + output_path.mkdir(parents=True, exist_ok=True) + output_pathname = Path(output_path) / f"{uuid.uuid4().hex}.json" + await awrite(filename=output_pathname, data=design.content) + kvs = {"changed_system_design_filenames": [output_pathname]} + + return AIMessage( + content=f'System Design filename: "{str(output_pathname)}"', + instruct_content=AIMessage.create_instruct_value(kvs=kvs), + ) diff --git a/metagpt/actions/di/execute_nb_code.py b/metagpt/actions/di/execute_nb_code.py index b4fe949fe..64620d9cc 100644 --- a/metagpt/actions/di/execute_nb_code.py +++ b/metagpt/actions/di/execute_nb_code.py @@ -65,7 +65,7 @@ class ExecuteNbCode(Action): """execute notebook code block, return result to llm, and display it.""" nb: NotebookNode - nb_client: NotebookClient = None + nb_client: RealtimeOutputNotebookClient = None console: Console interaction: str timeout: int = 600 @@ -78,11 +78,15 @@ class ExecuteNbCode(Action): interaction=("ipython" if self.is_ipython() else "terminal"), ) self.reporter = NotebookReporter() + self.set_nb_client() + + def set_nb_client(self): self.nb_client = RealtimeOutputNotebookClient( - nb, - timeout=timeout, + self.nb, + timeout=self.timeout, resources={"metadata": {"path": DEFAULT_WORKSPACE_ROOT}}, notebook_reporter=self.reporter, + coalesce_streams=True, ) async def build(self): @@ -118,7 +122,7 @@ class 
ExecuteNbCode(Action): # sleep 1s to wait for the kernel to be cleaned up completely await asyncio.sleep(1) await self.build() - self.nb_client = NotebookClient(self.nb, timeout=self.timeout) + self.set_nb_client() def add_code_cell(self, code: str): self.nb.cells.append(new_code_cell(source=code)) diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index eb674374c..393c483cc 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -17,6 +17,7 @@ from metagpt.logs import logger from metagpt.schema import AIMessage from metagpt.utils.common import any_to_str from metagpt.utils.file_repository import FileRepository +from metagpt.utils.project_repo import ProjectRepo class PrepareDocuments(Action): @@ -36,7 +37,7 @@ class PrepareDocuments(Action): def config(self): return self.context.config - def _init_repo(self): + def _init_repo(self) -> ProjectRepo: """Initialize the Git environment.""" if not self.config.project_path: name = self.config.project_name or FileRepository.new_filename() @@ -45,8 +46,9 @@ class PrepareDocuments(Action): path = Path(self.config.project_path) if path.exists() and not self.config.inc: shutil.rmtree(path) - self.config.project_path = path - self.context.set_repo_dir(path) + self.context.kwargs.project_path = path + self.context.kwargs.inc = self.config.inc + return ProjectRepo(path) async def run(self, with_messages, **kwargs): """Create and initialize the workspace folder, initialize the Git environment.""" @@ -67,10 +69,22 @@ class PrepareDocuments(Action): max_auto_summarize_code=0, ) - self._init_repo() + repo = self._init_repo() # Write the newly added requirements from the main parameter idea to `docs/requirement.txt`. 
- doc = await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content) + await repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content) # Send a Message notification to the WritePRD action, instructing it to process requirements using # `docs/requirement.txt` and `docs/prd/`. - return AIMessage(content="", instruct_content=doc, cause_by=self, send_to=self.send_to) + return AIMessage( + content="", + instruct_content=AIMessage.create_instruct_value( + kvs={ + "project_path": str(repo.workdir), + "requirements_filename": str(repo.docs.workdir / REQUIREMENT_FILENAME), + "prd_filenames": [str(repo.docs.prd.workdir / i) for i in repo.docs.prd.all_files], + }, + class_name="PrepareDocumentsOutput", + ), + cause_by=self, + send_to=self.send_to, + ) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index ef0fe6fc6..b44bfb9f3 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -8,16 +8,23 @@ 1. Divide the context into three components: legacy code, unit test code, and console log. 2. Move the document storage operations related to WritePRD from the save operation of WriteDesign. 3. According to the design in Section 2.2.3.5.4 of RFC 135, add incremental iteration functionality. +@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236. 
""" import json -from typing import Optional +from pathlib import Path +from typing import List, Optional + +from pydantic import BaseModel, Field from metagpt.actions.action import Action from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODE from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME from metagpt.logs import logger -from metagpt.schema import AIMessage, Document, Documents +from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import aread, to_markdown_code_block +from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.report import DocsReporter NEW_REQ_TEMPLATE = """ @@ -29,19 +36,56 @@ NEW_REQ_TEMPLATE = """ """ +@register_tool(tags=["software development", "write a project schedule given a project system design file"]) class WriteTasks(Action): name: str = "CreateTasks" i_context: Optional[str] = None + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) - async def run(self, with_messages): - changed_system_designs = self.repo.docs.system_design.changed_files - changed_tasks = self.repo.docs.task.changed_files + async def run( + self, with_messages: List[Message] = None, *, user_requirement: str = "", design_filename: str = "", **kwargs + ) -> AIMessage: + """ + Write a project schedule given a project system design file. + + Args: + user_requirement (str, optional): A string specifying the user's requirements. Defaults to an empty string. + design_filename (str): The filename of the project system design file. Defaults to an empty string. + **kwargs: Additional keyword arguments. + + Returns: + AIMessage: The generated project schedule. + + Example: + # Write a new project schedule. 
+ >>> design_filename = "/path/to/design/filename" + >>> action = WriteTasks() + >>> result = await action.run(design_filename=design_filename) + >>> print(result.content) + The project schedule is balabala... + + # Write a new project schedule with the user requirement. + >>> design_filename = "/path/to/design/filename" + >>> user_requirement = "Your user requirements" + >>> action = WriteTasks() + >>> result = await action.run(design_filename=design_filename, user_requirement=user_requirement) + >>> print(result.content) + The project schedule is balabala... + """ + if not with_messages: + return await self._execute_api(user_requirement=user_requirement, design_filename=design_filename) + + self.input_args = with_messages[-1].instruct_content + self.repo = ProjectRepo(self.input_args.project_path) + changed_system_designs = self.input_args.changed_system_design_filenames + changed_tasks = [str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys())] change_files = Documents() # Rewrite the system designs that have undergone changes based on the git head diff under # `docs/system_designs/`. for filename in changed_system_designs: task_doc = await self._update_tasks(filename=filename) - change_files.docs[filename] = task_doc + change_files.docs[str(self.repo.docs.task.workdir / task_doc.filename)] = task_doc # Rewrite the task files that have undergone changes based on the git head diff under `docs/tasks/`. for filename in changed_tasks: @@ -54,6 +98,11 @@ class WriteTasks(Action): logger.info("Nothing has changed.") # Wait until all files under `docs/tasks/` are processed before sending the publish_message, leaving room for # global optimization in subsequent steps. 
+ kvs = self.input_args.model_dump() + kvs["changed_task_filenames"] = [ + str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys()) + ] + kvs["python_package_dependency_filename"] = str(self.repo.workdir / PACKAGE_REQUIREMENTS_FILENAME) return AIMessage( content="WBS is completed. " + "\n".join( @@ -61,12 +110,14 @@ class WriteTasks(Action): + list(self.repo.docs.task.changed_files.keys()) + list(self.repo.resources.api_spec_and_task.changed_files.keys()) ), + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteTaskOutput"), cause_by=self, ) async def _update_tasks(self, filename): - system_design_doc = await self.repo.docs.system_design.get(filename) - task_doc = await self.repo.docs.task.get(filename) + root_relative_path = Path(filename).relative_to(self.repo.workdir) + system_design_doc = await Document.load(filename=filename, project_path=self.repo.workdir) + task_doc = await self.repo.docs.task.get(root_relative_path.name) async with DocsReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "task"}, "meta") if task_doc: @@ -75,7 +126,7 @@ class WriteTasks(Action): else: rsp = await self._run_new_tasks(context=system_design_doc.content) task_doc = await self.repo.docs.task.save( - filename=filename, + filename=system_design_doc.filename, content=rsp.instruct_content.model_dump_json(), dependencies={system_design_doc.root_relative_path}, ) @@ -84,7 +135,7 @@ class WriteTasks(Action): await reporter.async_report(self.repo.workdir / md.root_relative_path, "path") return task_doc - async def _run_new_tasks(self, context): + async def _run_new_tasks(self, context: str): node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema) return node @@ -106,3 +157,11 @@ class WriteTasks(Action): continue packages.add(pkg) await self.repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages)) + + async def _execute_api(self, user_requirement: str = "", 
design_filename: str = ""): + context = to_markdown_code_block(user_requirement) + if design_filename: + content = await aread(filename=design_filename) + context += to_markdown_code_block(content) + node = await self._run_new_tasks(context) + return AIMessage(content=node.instruct_content.model_dump_json()) diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index d21b62f83..e3556caa7 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -6,13 +6,16 @@ @Modified By: mashenquan, 2023/12/5. Archive the summarization content of issue discovery for use in WriteCode. """ from pathlib import Path +from typing import Optional -from pydantic import Field +from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.logs import logger from metagpt.schema import CodeSummarizeContext +from metagpt.utils.common import get_markdown_code_block_type +from metagpt.utils.project_repo import ProjectRepo PROMPT_TEMPLATE = """ NOTICE @@ -90,6 +93,8 @@ flowchart TB class SummarizeCode(Action): name: str = "SummarizeCode" i_context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60)) async def summarize_code(self, prompt): @@ -101,11 +106,10 @@ class SummarizeCode(Action): design_doc = await self.repo.docs.system_design.get(filename=design_pathname.name) task_pathname = Path(self.i_context.task_filename) task_doc = await self.repo.docs.task.get(filename=task_pathname.name) - src_file_repo = self.repo.with_src_path(self.context.src_workspace).srcs code_blocks = [] for filename in self.i_context.codes_filenames: - code_doc = await src_file_repo.get(filename) - code_block = 
f"```python\n{code_doc.content}\n```\n-----" + code_doc = await self.repo.srcs.get(filename) + code_block = f"```{get_markdown_code_block_type(filename)}\n{code_doc.content}\n```\n---\n" code_blocks.append(code_block) format_example = FORMAT_EXAMPLE prompt = PROMPT_TEMPLATE.format( diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 67b859d23..da25fe621 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -16,17 +16,18 @@ """ import json +from pathlib import Path +from typing import Optional -from pydantic import Field +from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST from metagpt.actions.write_code_plan_and_change_an import REFINED_TEMPLATE -from metagpt.const import BUGFIX_FILENAME, REQUIREMENT_FILENAME from metagpt.logs import logger from metagpt.schema import CodingContext, Document, RunCodeResult -from metagpt.utils.common import CodeParser +from metagpt.utils.common import CodeParser, get_markdown_code_block_type from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.report import EditorReporter @@ -44,9 +45,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc {task} ## Legacy Code -```Code {code} -``` ## Debug logs ```text @@ -61,14 +60,14 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc ``` # Format example -## Code: {filename} +## Code: {demo_filename}.py ```python -## {filename} +## {demo_filename}.py ... ``` -## Code: {filename} +## Code: {demo_filename}.js ```javascript -// {filename} +// {demo_filename}.js ... ``` @@ -89,6 +88,8 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. 
Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" i_context: Document = Field(default_factory=Document) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: @@ -97,10 +98,16 @@ class WriteCode(Action): return code async def run(self, *args, **kwargs) -> CodingContext: - bug_feedback = await self.repo.docs.get(filename=BUGFIX_FILENAME) + bug_feedback = None + if self.input_args and hasattr(self.input_args, "issue_filename"): + bug_feedback = await Document.load(self.input_args.issue_filename) coding_context = CodingContext.loads(self.i_context.content) + if not coding_context.code_plan_and_change_doc: + coding_context.code_plan_and_change_doc = await self.repo.docs.code_plan_and_change.get( + filename=coding_context.task_doc.filename + ) test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json") - requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME) + requirement_doc = await Document.load(self.input_args.requirements_filename) summary_doc = None if coding_context.design_doc and coding_context.design_doc.filename: summary_doc = await self.repo.docs.code_summary.get(filename=coding_context.design_doc.filename) @@ -109,29 +116,28 @@ class WriteCode(Action): test_detail = RunCodeResult.loads(test_doc.content) logs = test_detail.stderr - if bug_feedback: - code_context = coding_context.code_doc.content - elif self.config.inc: + if self.config.inc or bug_feedback: code_context = await self.get_codes( coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo, use_inc=True ) else: code_context = await self.get_codes( - coding_context.task_doc, - exclude=self.i_context.filename, - project_repo=self.repo.with_src_path(self.context.src_workspace), + 
coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo ) if self.config.inc: prompt = REFINED_TEMPLATE.format( user_requirement=requirement_doc.content if requirement_doc else "", - code_plan_and_change=str(coding_context.code_plan_and_change_doc), + code_plan_and_change=coding_context.code_plan_and_change_doc.content + if coding_context.code_plan_and_change_doc + else "", design=coding_context.design_doc.content if coding_context.design_doc else "", task=coding_context.task_doc.content if coding_context.task_doc else "", code=code_context, logs=logs, feedback=bug_feedback.content if bug_feedback else "", filename=self.i_context.filename, + demo_filename=Path(self.i_context.filename).stem, summary_log=summary_doc.content if summary_doc else "", ) else: @@ -142,6 +148,7 @@ class WriteCode(Action): logs=logs, feedback=bug_feedback.content if bug_feedback else "", filename=self.i_context.filename, + demo_filename=Path(self.i_context.filename).stem, summary_log=summary_doc.content if summary_doc else "", ) logger.info(f"Writing {coding_context.filename}..") @@ -150,10 +157,11 @@ class WriteCode(Action): code = await self.write_code(prompt) if not coding_context.code_doc: # avoid root_path pydantic ValidationError if use WriteCode alone - root_path = self.context.src_workspace if self.context.src_workspace else "" - coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path)) + coding_context.code_doc = Document( + filename=coding_context.filename, root_path=str(self.repo.src_relative_path) + ) coding_context.code_doc.content = code - await reporter.async_report(self.repo.workdir / coding_context.code_doc.root_relative_path, "path") + await reporter.async_report(coding_context.code_doc, "document") return coding_context @staticmethod @@ -178,35 +186,32 @@ class WriteCode(Action): code_filenames = m.get(TASK_LIST.key, []) if not use_inc else m.get(REFINED_TASK_LIST.key, []) codes = [] src_file_repo = 
project_repo.srcs - # Incremental development scenario if use_inc: - src_files = src_file_repo.all_files - # Get the old workspace contained the old codes and old workspace are created in previous CodePlanAndChange - old_file_repo = project_repo.git_repo.new_file_repository(relative_path=project_repo.old_workspace) - old_files = old_file_repo.all_files - # Get the union of the files in the src and old workspaces - union_files_list = list(set(src_files) | set(old_files)) - for filename in union_files_list: + for filename in src_file_repo.all_files: + code_block_type = get_markdown_code_block_type(filename) # Exclude the current file from the all code snippets if filename == exclude: # If the file is in the old workspace, use the old code # Exclude unnecessary code to maintain a clean and focused main.py file, ensuring only relevant and # essential functionality is included for the project’s requirements - if filename in old_files and filename != "main.py": + if filename != "main.py": # Use old code - doc = await old_file_repo.get(filename=filename) + doc = await src_file_repo.get(filename=filename) # If the file is in the src workspace, skip it else: continue - codes.insert(0, f"-----Now, {filename} to be rewritten\n```{doc.content}```\n=====") + codes.insert( + 0, f"### The name of file to rewrite: `{filename}`\n```{code_block_type}\n{doc.content}```\n" + ) + logger.info(f"Prepare to rewrite `{filename}`") # The code snippets are generated from the src workspace else: doc = await src_file_repo.get(filename=filename) # If the file does not exist in the src workspace, skip it if not doc: continue - codes.append(f"----- {filename}\n```{doc.content}```") + codes.append(f"### File Name: `{filename}`\n```{code_block_type}\n{doc.content}```\n\n") # Normal scenario else: @@ -217,6 +222,7 @@ class WriteCode(Action): doc = await src_file_repo.get(filename=filename) if not doc: continue - codes.append(f"----- {filename}\n```{doc.content}```") + code_block_type = 
get_markdown_code_block_type(filename) + codes.append(f"### File Name: `{filename}`\n```{code_block_type}\n{doc.content}```\n\n") return "\n".join(codes) diff --git a/metagpt/actions/write_code_plan_and_change_an.py b/metagpt/actions/write_code_plan_and_change_an.py index a3c0e50a4..31482a94d 100644 --- a/metagpt/actions/write_code_plan_and_change_an.py +++ b/metagpt/actions/write_code_plan_and_change_an.py @@ -5,15 +5,16 @@ @Author : mannaandpoem @File : write_code_plan_and_change_an.py """ -import os -from typing import List +from typing import List, Optional -from pydantic import Field +from pydantic import BaseModel, Field from metagpt.actions.action import Action from metagpt.actions.action_node import ActionNode from metagpt.logs import logger -from metagpt.schema import CodePlanAndChangeContext +from metagpt.schema import CodePlanAndChangeContext, Document +from metagpt.utils.common import get_markdown_code_block_type +from metagpt.utils.project_repo import ProjectRepo DEVELOPMENT_PLAN = ActionNode( key="Development Plan", @@ -162,9 +163,8 @@ Role: You are a professional engineer; The main goal is to complete incremental {task} ## Legacy Code -```Code {code} -``` + ## Debug logs ```text @@ -179,14 +179,14 @@ Role: You are a professional engineer; The main goal is to complete incremental ``` # Format example -## Code: {filename} +## Code: {demo_filename}.py ```python -## {filename} +## {demo_filename}.py ... ``` -## Code: {filename} +## Code: {demo_filename}.js ```javascript -// {filename} +// {demo_filename}.js ... 
``` @@ -211,13 +211,15 @@ WRITE_CODE_PLAN_AND_CHANGE_NODE = ActionNode.from_children("WriteCodePlanAndChan class WriteCodePlanAndChange(Action): name: str = "WriteCodePlanAndChange" i_context: CodePlanAndChangeContext = Field(default_factory=CodePlanAndChangeContext) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) async def run(self, *args, **kwargs): self.llm.system_prompt = "You are a professional software engineer, your primary responsibility is to " "meticulously craft comprehensive incremental development plan and deliver detailed incremental change" - prd_doc = await self.repo.docs.prd.get(filename=self.i_context.prd_filename) - design_doc = await self.repo.docs.system_design.get(filename=self.i_context.design_filename) - task_doc = await self.repo.docs.task.get(filename=self.i_context.task_filename) + prd_doc = await Document.load(filename=self.i_context.prd_filename) + design_doc = await Document.load(filename=self.i_context.design_filename) + task_doc = await Document.load(filename=self.i_context.task_filename) context = CODE_PLAN_AND_CHANGE_CONTEXT.format( requirement=f"```text\n{self.i_context.requirement}\n```", issue=f"```text\n{self.i_context.issue}\n```", @@ -230,8 +232,9 @@ class WriteCodePlanAndChange(Action): return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json") async def get_old_codes(self) -> str: - self.repo.old_workspace = self.repo.git_repo.workdir / os.path.basename(self.config.project_path) - old_file_repo = self.repo.git_repo.new_file_repository(relative_path=self.repo.old_workspace) - old_codes = await old_file_repo.get_all() - codes = [f"----- {code.filename}\n```{code.content}```" for code in old_codes] + old_codes = await self.repo.srcs.get_all() + codes = [ + f"### File Name: `{code.filename}`\n```{get_markdown_code_block_type(code.filename)}\n{code.content}```\n" + for code in old_codes + ] return 
"\n".join(codes) diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index f0faea701..ad99de2dd 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -7,16 +7,18 @@ @Modified By: mashenquan, 2023/11/27. Following the think-act principle, solidify the task parameters when creating the WriteCode object, rather than passing them in when calling the run function. """ +from typing import Optional -from pydantic import Field +from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions import WriteCode from metagpt.actions.action import Action -from metagpt.const import REQUIREMENT_FILENAME from metagpt.logs import logger -from metagpt.schema import CodingContext +from metagpt.schema import CodingContext, Document from metagpt.utils.common import CodeParser +from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.report import EditorReporter PROMPT_TEMPLATE = """ # System @@ -126,18 +128,27 @@ or class WriteCodeReview(Action): name: str = "WriteCodeReview" i_context: CodingContext = Field(default_factory=CodingContext) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) - async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename): + async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, doc): + filename = doc.filename cr_rsp = await self._aask(context_prompt + cr_prompt) result = CodeParser.parse_block("Code Review Result", cr_rsp) if "LGTM" in result: return result, None # if LBTM, rewrite code - rewrite_prompt = f"{context_prompt}\n{cr_rsp}\n{REWRITE_CODE_TEMPLATE.format(filename=filename)}" - code_rsp = await self._aask(rewrite_prompt) - code = CodeParser.parse_code(text=code_rsp) + async with 
EditorReporter(enable_llm_stream=True) as reporter: + await reporter.async_report( + {"type": "code", "filename": filename, "src_path": doc.root_relative_path}, "meta" + ) + rewrite_prompt = f"{context_prompt}\n{cr_rsp}\n{REWRITE_CODE_TEMPLATE.format(filename=filename)}" + code_rsp = await self._aask(rewrite_prompt) + code = CodeParser.parse_code(text=code_rsp) + doc.content = code + await reporter.async_report(doc, "document") return result, code async def run(self, *args, **kwargs) -> CodingContext: @@ -150,7 +161,7 @@ class WriteCodeReview(Action): code_context = await WriteCode.get_codes( self.i_context.task_doc, exclude=self.i_context.filename, - project_repo=self.repo.with_src_path(self.context.src_workspace), + project_repo=self.repo, use_inc=self.config.inc, ) @@ -160,7 +171,7 @@ class WriteCodeReview(Action): "## Code Files\n" + code_context + "\n", ] if self.config.inc: - requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME) + requirement_doc = await Document.load(filename=self.input_args.requirements_filename) insert_ctx_list = [ "## User New Requirements\n" + str(requirement_doc) + "\n", "## Code Plan And Change\n" + str(self.i_context.code_plan_and_change_doc) + "\n", @@ -182,7 +193,7 @@ class WriteCodeReview(Action): f"len(self.i_context.code_doc.content)={len2}" ) result, rewrited_code = await self.write_code_review_and_rewrite( - context_prompt, cr_prompt, self.i_context.code_doc.filename + context_prompt, cr_prompt, self.i_context.code_doc ) if "LBTM" in result: iterative_code = rewrited_code diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index a4f6e1dd1..5ed45cab4 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -9,12 +9,17 @@ 2. According to the design in Section 2.2.3.5.2 of RFC 135, add incremental iteration functionality. 3. Move the document storage operations related to WritePRD from the save operation of WriteDesign. @Modified By: mashenquan, 2023/12/5. 
Move the generation logic of the project name to WritePRD. +@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236. """ from __future__ import annotations import json +import uuid from pathlib import Path +from typing import List, Optional + +from pydantic import BaseModel, Field from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode @@ -30,13 +35,16 @@ from metagpt.actions.write_prd_an import ( from metagpt.const import ( BUGFIX_FILENAME, COMPETITIVE_ANALYSIS_FILE_REPO, + DEFAULT_WORKSPACE_ROOT, REQUIREMENT_FILENAME, ) from metagpt.logs import logger from metagpt.schema import AIMessage, Document, Documents, Message -from metagpt.utils.common import CodeParser +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import CodeParser, aread, awrite, to_markdown_code_block from metagpt.utils.file_repository import FileRepository from metagpt.utils.mermaid import mermaid_to_file +from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.report import DocsReporter, GalleryReporter CONTEXT_TEMPLATE = """ @@ -59,6 +67,7 @@ NEW_REQ_TEMPLATE = """ """ +@register_tool(tags=["software development", "write product requirement documents"]) class WritePRD(Action): """WritePRD deal with the following situations: 1. Bugfix: If the requirement is a bugfix, the bugfix document will be generated. @@ -66,10 +75,97 @@ class WritePRD(Action): 3. Requirement update: If the requirement is an update, the PRD document will be updated. 
""" - async def run(self, with_messages, *args, **kwargs) -> Message: - """Run the action.""" - req: Document = await self.repo.requirement - docs: list[Document] = await self.repo.docs.prd.get_all() + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) + + async def run( + self, + with_messages: List[Message] = None, + *, + user_requirement: str = "", + output_pathname: str = "", + legacy_prd_filename: str = "", + extra_info: str = "", + **kwargs, + ) -> AIMessage: + """ + Write a Product Requirement Document. + + Args: + user_requirement (str): A string detailing the user's requirements. + output_pathname (str, optional): The path name of file that the output document should be saved to. Defaults to "". + legacy_prd_filename (str, optional): The file path of the legacy Product Requirement Document to use as a reference. Defaults to "". + extra_info (str, optional): Additional information to include in the document. Defaults to "". + **kwargs: Additional keyword arguments. + + Returns: + AIMessage: The resulting message after generating the Product Requirement Document. 
+ + Example: + # Write a new PRD(Product Requirement Document) + >>> user_requirement = "YOUR REQUIREMENTS" + >>> extra_info = "YOUR EXTRA INFO" + >>> write_prd = WritePRD() + >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info) + >>> print(result.content) + PRD filename: "/path/to/prd/directory/213434ad.json" + + # Modify a exists PRD(Product Requirement Document) + >>> user_requirement = "YOUR REQUIREMENTS" + >>> extra_info = "YOUR EXTRA INFO" + >>> legacy_prd_filename = "/path/to/exists/prd_filename" + >>> write_prd = WritePRD() + >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, legacy_prd_filename=legacy_prd_filename) + >>> print(result.content) + PRD filename: "/path/to/prd/directory/213434ad.json" + + # Write and save a new PRD(Product Requirement Document) to the path name. + >>> user_requirement = "YOUR REQUIREMENTS" + >>> extra_info = "YOUR EXTRA INFO" + >>> output_pathname = "/path/to/prd/directory/213434ad.json" + >>> write_prd = WritePRD() + >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, output_pathname=output_pathname) + >>> print(result.content) + PRD filename: "/path/to/prd/directory/213434ad.json" + + # Modify a exists PRD(Product Requirement Document) and save to the path name. 
+ >>> user_requirement = "YOUR REQUIREMENTS" + >>> extra_info = "YOUR EXTRA INFO" + >>> legacy_prd_filename = "/path/to/exists/prd_filename" + >>> output_pathname = "/path/to/prd/directory/213434ad.json" + >>> write_prd = WritePRD() + >>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, legacy_prd_filename=legacy_prd_filename, output_pathname=output_pathname) + >>> print(result.content) + PRD filename: "/path/to/prd/directory/213434ad.json" + + """ + if not with_messages: + return await self._execute_api( + user_requirement=user_requirement, + output_pathname=output_pathname, + legacy_prd_filename=legacy_prd_filename, + extra_info=extra_info, + ) + + self.input_args = with_messages[-1].instruct_content + if not self.input_args: + self.repo = ProjectRepo(self.context.kwargs.project_path) + await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[-1].content) + self.input_args = AIMessage.create_instruct_value( + kvs={ + "project_path": self.context.kwargs.project_path, + "requirements_filename": str(self.repo.docs.workdir / REQUIREMENT_FILENAME), + "prd_filenames": [str(self.repo.docs.prd.workdir / i) for i in self.repo.docs.prd.all_files], + }, + class_name="PrepareDocumentsOutput", + ) + else: + self.repo = ProjectRepo(self.input_args.project_path) + req = await Document.load(filename=self.input_args.requirements_filename) + docs: list[Document] = [ + await Document.load(filename=i, project_path=self.repo.workdir) for i in self.input_args.prd_filenames + ] + if not req: raise FileNotFoundError("No requirement document found.") @@ -82,10 +178,18 @@ class WritePRD(Action): # if requirement is related to other documents, update them, otherwise create a new one if related_docs := await self.get_related_docs(req, docs): logger.info(f"Requirement update detected: {req.content}") - await self._handle_requirement_update(req, related_docs) + await self._handle_requirement_update(req=req, related_docs=related_docs) 
else: logger.info(f"New requirement detected: {req.content}") await self._handle_new_requirement(req) + + kvs = self.input_args.model_dump() + kvs["changed_prd_filenames"] = [ + str(self.repo.docs.prd.workdir / i) for i in list(self.repo.docs.prd.changed_files.keys()) + ] + kvs["project_path"] = str(self.repo.workdir) + kvs["requirements_filename"] = str(self.repo.docs.workdir / REQUIREMENT_FILENAME) + self.context.kwargs.project_path = str(self.repo.workdir) return AIMessage( content="PRD is completed. " + "\n".join( @@ -93,6 +197,7 @@ class WritePRD(Action): + list(self.repo.resources.prd.changed_files.keys()) + list(self.repo.resources.competitive_analysis.changed_files.keys()) ), + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WritePRDOutput"), cause_by=self, ) @@ -103,19 +208,31 @@ class WritePRD(Action): return AIMessage( content=f"A new issue is received: {BUGFIX_FILENAME}", cause_by=FixBug, + instruct_content=AIMessage.create_instruct_value( + { + "project_path": str(self.repo.workdir), + "issue_filename": str(self.repo.docs.workdir / BUGFIX_FILENAME), + "requirements_filename": str(self.repo.docs.workdir / REQUIREMENT_FILENAME), + }, + class_name="IssueDetail", + ), send_to="Alex", # the name of Engineer ) + async def _new_prd(self, requirement: str) -> ActionNode: + project_name = self.project_name + context = CONTEXT_TEMPLATE.format(requirements=requirement, project_name=project_name) + exclude = [PROJECT_NAME.key] if project_name else [] + node = await WRITE_PRD_NODE.fill( + context=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema + ) # schema=schema + return node + async def _handle_new_requirement(self, req: Document) -> ActionOutput: """handle new requirement""" async with DocsReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "prd"}, "meta") - project_name = self.project_name - context = CONTEXT_TEMPLATE.format(requirements=req, project_name=project_name) - exclude = 
[PROJECT_NAME.key] if project_name else [] - node = await WRITE_PRD_NODE.fill( - context=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema - ) # schema=schema + node = await self._new_prd(req.content) await self._rename_workspace(node) new_prd_doc = await self.repo.docs.prd.save( filename=FileRepository.new_filename() + ".json", content=node.instruct_content.model_dump_json() @@ -128,7 +245,7 @@ class WritePRD(Action): async def _handle_requirement_update(self, req: Document, related_docs: list[Document]) -> ActionOutput: # ... requirement update logic ... for doc in related_docs: - await self._update_prd(req, doc) + await self._update_prd(req=req, prd_doc=doc) return Documents.from_iterable(documents=related_docs).to_action_output() async def _is_bugfix(self, context: str) -> bool: @@ -159,7 +276,7 @@ class WritePRD(Action): async def _update_prd(self, req: Document, prd_doc: Document) -> Document: async with DocsReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "prd"}, "meta") - new_prd_doc: Document = await self._merge(req, prd_doc) + new_prd_doc: Document = await self._merge(req=req, related_doc=prd_doc) await self.repo.docs.prd.save_doc(doc=new_prd_doc) await self._save_competitive_analysis(new_prd_doc) md = await self.repo.resources.prd.save_pdf(doc=new_prd_doc) @@ -186,4 +303,29 @@ class WritePRD(Action): ws_name = CodeParser.parse_str(block="Project Name", text=prd) if ws_name: self.project_name = ws_name - self.repo.git_repo.rename_root(self.project_name) + if self.repo: + self.repo.git_repo.rename_root(self.project_name) + + async def _execute_api( + self, user_requirement: str, output_pathname: str, legacy_prd_filename: str, extra_info: str + ) -> AIMessage: + content = "#### User Requirements\n{user_requirement}\n#### Extra Info\n{extra_info}\n".format( + user_requirement=to_markdown_code_block(val=user_requirement), + extra_info=to_markdown_code_block(val=extra_info), + ) + req = 
Document(content=content) + if not legacy_prd_filename: + node = await self._new_prd(requirement=req.content) + new_prd = Document(content=node.instruct_content.model_dump_json()) + else: + content = await aread(filename=legacy_prd_filename) + old_prd = Document(content=content) + new_prd = await self._merge(req=req, related_doc=old_prd) + + if not output_pathname: + output_path = DEFAULT_WORKSPACE_ROOT + output_path.mkdir(parents=True, exist_ok=True) + output_pathname = Path(output_path) / f"{uuid.uuid4().hex}.json" + await awrite(filename=output_pathname, data=new_prd.content) + kvs = AIMessage.create_instruct_value({"changed_prd_filenames": [str(output_pathname)]}) + return AIMessage(content=f'PRD filename: "{str(output_pathname)}"', instruct_content=kvs) diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py index a33685cd3..1ceb2aade 100644 --- a/metagpt/actions/write_prd_an.py +++ b/metagpt/actions/write_prd_an.py @@ -5,7 +5,7 @@ @Author : alexanderwu @File : write_prd_an.py """ -from typing import List +from typing import List, Union from metagpt.actions.action_node import ActionNode @@ -132,7 +132,7 @@ REQUIREMENT_ANALYSIS = ActionNode( REFINED_REQUIREMENT_ANALYSIS = ActionNode( key="Refined Requirement Analysis", - expected_type=List[str], + expected_type=Union[List[str], str], instruction="Review and refine the existing requirement analysis into a string list to align with the evolving needs of the project " "due to incremental development. 
Ensure the analysis comprehensively covers the new features and enhancements " "required for the refined project scope.", diff --git a/metagpt/config2.py b/metagpt/config2.py index 8c61fdbf2..717fe63a9 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Literal, Optional from pydantic import BaseModel, model_validator from metagpt.configs.browser_config import BrowserConfig +from metagpt.configs.embedding_config import EmbeddingConfig from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.configs.mermaid_config import MermaidConfig from metagpt.configs.redis_config import RedisConfig @@ -48,6 +49,9 @@ class Config(CLIParams, YamlModel): # Key Parameters llm: LLMConfig + # RAG Embedding + embedding: EmbeddingConfig = EmbeddingConfig() + # Global Proxy. Not used by LLM, but by other tools such as browsers. proxy: str = "" diff --git a/metagpt/configs/embedding_config.py b/metagpt/configs/embedding_config.py new file mode 100644 index 000000000..20de47999 --- /dev/null +++ b/metagpt/configs/embedding_config.py @@ -0,0 +1,50 @@ +from enum import Enum +from typing import Optional + +from pydantic import field_validator + +from metagpt.utils.yaml_model import YamlModel + + +class EmbeddingType(Enum): + OPENAI = "openai" + AZURE = "azure" + GEMINI = "gemini" + OLLAMA = "ollama" + + +class EmbeddingConfig(YamlModel): + """Config for Embedding. 
+ + Examples: + --------- + api_type: "openai" + api_key: "YOU_API_KEY" + + api_type: "azure" + api_key: "YOU_API_KEY" + base_url: "YOU_BASE_URL" + api_version: "YOU_API_VERSION" + + api_type: "gemini" + api_key: "YOU_API_KEY" + + api_type: "ollama" + base_url: "YOU_BASE_URL" + model: "YOU_MODEL" + """ + + api_type: Optional[EmbeddingType] = None + api_key: Optional[str] = None + base_url: Optional[str] = None + api_version: Optional[str] = None + + model: Optional[str] = None + embed_batch_size: Optional[int] = None + + @field_validator("api_type", mode="before") + @classmethod + def check_api_type(cls, v): + if v == "": + return None + return v diff --git a/metagpt/context.py b/metagpt/context.py index f1c3568d9..384e8da48 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -8,7 +8,6 @@ from __future__ import annotations import os -from pathlib import Path from typing import Any, Dict, Optional from pydantic import BaseModel, ConfigDict @@ -22,8 +21,6 @@ from metagpt.utils.cost_manager import ( FireworksCostManager, TokenCostManager, ) -from metagpt.utils.git_repository import GitRepository -from metagpt.utils.project_repo import ProjectRepo class AttrDict(BaseModel): @@ -66,9 +63,6 @@ class Context(BaseModel): kwargs: AttrDict = AttrDict() config: Config = Config.default() - repo: Optional[ProjectRepo] = None - git_repo: Optional[GitRepository] = None - src_workspace: Optional[Path] = None cost_manager: CostManager = CostManager() _llm: Optional[BaseLLM] = None @@ -80,11 +74,6 @@ class Context(BaseModel): # env.update({k: v for k, v in i.items() if isinstance(v, str)}) return env - def set_repo_dir(self, path: str | Path): - repo_path = Path(path) - self.git_repo = GitRepository(local_path=repo_path, auto_init=True) - self.repo = ProjectRepo(self.git_repo) - def _select_costmanager(self, llm_config: LLMConfig) -> CostManager: """Return a CostManager instance""" if llm_config.api_type == LLMType.FIREWORKS: @@ -117,7 +106,6 @@ class Context(BaseModel): 
Dict[str, Any]: A dictionary containing serialized data. """ return { - "workdir": str(self.repo.workdir) if self.repo else "", "kwargs": {k: v for k, v in self.kwargs.__dict__.items()}, "cost_manager": self.cost_manager.model_dump_json(), } @@ -130,13 +118,6 @@ class Context(BaseModel): """ if not serialized_data: return - workdir = serialized_data.get("workdir") - if workdir: - self.git_repo = GitRepository(local_path=workdir, auto_init=True) - self.repo = ProjectRepo(self.git_repo) - src_workspace = self.git_repo.workdir / self.git_repo.workdir.name - if src_workspace.exists(): - self.src_workspace = src_workspace kwargs = serialized_data.get("kwargs") if kwargs: for k, v in kwargs.items(): diff --git a/metagpt/environment/base_env.py b/metagpt/environment/base_env.py index fe1660fc6..5d6d3a286 100644 --- a/metagpt/environment/base_env.py +++ b/metagpt/environment/base_env.py @@ -22,6 +22,7 @@ from metagpt.logs import logger from metagpt.memory import Memory from metagpt.schema import Message from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to +from metagpt.utils.git_repository import GitRepository if TYPE_CHECKING: from metagpt.roles.role import Role # noqa: F401 @@ -243,8 +244,9 @@ class Environment(ExtEnv): self.member_addrs[obj] = addresses def archive(self, auto_archive=True): - if auto_archive and self.context.git_repo: - self.context.git_repo.archive() + if auto_archive and self.context.kwargs.get("project_path"): + git_repo = GitRepository(self.context.kwargs.project_path) + git_repo.archive() @classmethod def model_rebuild(cls, **kwargs): diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 5c5810308..c237dcf69 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -4,7 +4,7 @@ import json import os from typing import Any, Optional, Union -from llama_index.core import SimpleDirectoryReader, VectorStoreIndex +from llama_index.core import SimpleDirectoryReader from 
llama_index.core.callbacks.base import CallbackManager from llama_index.core.embeddings import BaseEmbedding from llama_index.core.embeddings.mock_embed_model import MockEmbedding @@ -63,7 +63,7 @@ class SimpleEngine(RetrieverQueryEngine): response_synthesizer: Optional[BaseSynthesizer] = None, node_postprocessors: Optional[list[BaseNodePostprocessor]] = None, callback_manager: Optional[CallbackManager] = None, - index: Optional[BaseIndex] = None, + transformations: Optional[list[TransformComponent]] = None, ) -> None: super().__init__( retriever=retriever, @@ -71,7 +71,7 @@ class SimpleEngine(RetrieverQueryEngine): node_postprocessors=node_postprocessors, callback_manager=callback_manager, ) - self.index = index + self._transformations = transformations or self._default_transformations() @classmethod def from_docs( @@ -103,12 +103,17 @@ class SimpleEngine(RetrieverQueryEngine): documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data() cls._fix_document_metadata(documents) - index = VectorStoreIndex.from_documents( - documents=documents, - transformations=transformations or [SentenceSplitter()], - embed_model=cls._resolve_embed_model(embed_model, retriever_configs), + transformations = transformations or cls._default_transformations() + nodes = run_transformations(documents, transformations=transformations) + + return cls._from_nodes( + nodes=nodes, + transformations=transformations, + embed_model=embed_model, + llm=llm, + retriever_configs=retriever_configs, + ranker_configs=ranker_configs, ) - return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) @classmethod def from_objs( @@ -137,12 +142,15 @@ class SimpleEngine(RetrieverQueryEngine): raise ValueError("In BM25RetrieverConfig, Objs must not be empty.") nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] - index = VectorStoreIndex( + + return cls._from_nodes( nodes=nodes, - 
transformations=transformations or [SentenceSplitter()], - embed_model=cls._resolve_embed_model(embed_model, retriever_configs), + transformations=transformations, + embed_model=embed_model, + llm=llm, + retriever_configs=retriever_configs, + ranker_configs=ranker_configs, ) - return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) @classmethod def from_index( @@ -161,6 +169,13 @@ class SimpleEngine(RetrieverQueryEngine): """Inplement tools.SearchInterface""" return await self.aquery(content) + def retrieve(self, query: QueryType) -> list[NodeWithScore]: + query_bundle = QueryBundle(query) if isinstance(query, str) else query + + nodes = super().retrieve(query_bundle) + self._try_reconstruct_obj(nodes) + return nodes + async def aretrieve(self, query: QueryType) -> list[NodeWithScore]: """Allow query to be str.""" query_bundle = QueryBundle(query) if isinstance(query, str) else query @@ -176,7 +191,7 @@ class SimpleEngine(RetrieverQueryEngine): documents = SimpleDirectoryReader(input_files=input_files).load_data() self._fix_document_metadata(documents) - nodes = run_transformations(documents, transformations=self.index._transformations) + nodes = run_transformations(documents, transformations=self._transformations) self._save_nodes(nodes) def add_objs(self, objs: list[RAGObject]): @@ -192,6 +207,29 @@ class SimpleEngine(RetrieverQueryEngine): self._persist(str(persist_dir), **kwargs) + @classmethod + def _from_nodes( + cls, + nodes: list[BaseNode], + transformations: Optional[list[TransformComponent]] = None, + embed_model: BaseEmbedding = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, + ) -> "SimpleEngine": + embed_model = cls._resolve_embed_model(embed_model, retriever_configs) + llm = llm or get_rag_llm() + + retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model) + rankers = 
get_rankers(configs=ranker_configs, llm=llm) # Default [] + + return cls( + retriever=retriever, + node_postprocessors=rankers, + response_synthesizer=get_response_synthesizer(llm=llm), + transformations=transformations, + ) + @classmethod def _from_index( cls, @@ -201,6 +239,7 @@ class SimpleEngine(RetrieverQueryEngine): ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": llm = llm or get_rag_llm() + retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever rankers = get_rankers(configs=ranker_configs, llm=llm) # Default [] @@ -208,7 +247,6 @@ class SimpleEngine(RetrieverQueryEngine): retriever=retriever, node_postprocessors=rankers, response_synthesizer=get_response_synthesizer(llm=llm), - index=index, ) def _ensure_retriever_modifiable(self): @@ -259,3 +297,7 @@ class SimpleEngine(RetrieverQueryEngine): return MockEmbedding(embed_dim=1) return embed_model or get_rag_embedding() + + @staticmethod + def _default_transformations(): + return [SentenceSplitter()] diff --git a/metagpt/rag/factories/base.py b/metagpt/rag/factories/base.py index fbdfbf1a8..e58643efe 100644 --- a/metagpt/rag/factories/base.py +++ b/metagpt/rag/factories/base.py @@ -26,6 +26,9 @@ class GenericFactory: if creator: return creator(**kwargs) + self._raise_for_key(key) + + def _raise_for_key(self, key: Any): raise ValueError(f"Creator not registered for key: {key}") @@ -33,19 +36,26 @@ class ConfigBasedFactory(GenericFactory): """Designed to get objects based on object type.""" def get_instance(self, key: Any, **kwargs) -> Any: - """Key is config, such as a pydantic model. + """Get instance by the type of key. - Call func by the type of key, and the key will be passed to func. + Key is config, such as a pydantic model, call func by the type of key, and the key will be passed to func. + Raise Exception if key not found. 
""" creator = self._creators.get(type(key)) if creator: return creator(key, **kwargs) + self._raise_for_key(key) + + def _raise_for_key(self, key: Any): raise ValueError(f"Unknown config: `{type(key)}`, {key}") @staticmethod def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any: - """It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.""" + """It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs. + + Return None if not found. + """ if config is not None and hasattr(config, key): val = getattr(config, key) if val is not None: @@ -54,6 +64,4 @@ class ConfigBasedFactory(GenericFactory): if key in kwargs: return kwargs[key] - raise KeyError( - f"The key '{key}' is required but not provided in either configuration object or keyword arguments." - ) + return None diff --git a/metagpt/rag/factories/embedding.py b/metagpt/rag/factories/embedding.py index 4247db256..3613fd228 100644 --- a/metagpt/rag/factories/embedding.py +++ b/metagpt/rag/factories/embedding.py @@ -1,37 +1,103 @@ """RAG Embedding Factory.""" +from __future__ import annotations + +from typing import Any from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding +from llama_index.embeddings.gemini import GeminiEmbedding +from llama_index.embeddings.ollama import OllamaEmbedding from llama_index.embeddings.openai import OpenAIEmbedding from metagpt.config2 import config +from metagpt.configs.embedding_config import EmbeddingType from metagpt.configs.llm_config import LLMType from metagpt.rag.factories.base import GenericFactory class RAGEmbeddingFactory(GenericFactory): - """Create LlamaIndex Embedding with MetaGPT's config.""" + """Create LlamaIndex Embedding with MetaGPT's embedding config.""" def __init__(self): creators = { + EmbeddingType.OPENAI: self._create_openai, + EmbeddingType.AZURE: self._create_azure, + 
EmbeddingType.GEMINI: self._create_gemini, + EmbeddingType.OLLAMA: self._create_ollama, + # For backward compatibility LLMType.OPENAI: self._create_openai, LLMType.AZURE: self._create_azure, } super().__init__(creators) - def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding: - """Key is LLMType, default use config.llm.api_type.""" - return super().get_instance(key or config.llm.api_type) + def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding: + """Key is EmbeddingType.""" + return super().get_instance(key or self._resolve_embedding_type()) - def _create_openai(self): - return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url) + def _resolve_embedding_type(self) -> EmbeddingType | LLMType: + """Resolves the embedding type. - def _create_azure(self): - return AzureOpenAIEmbedding( - azure_endpoint=config.llm.base_url, - api_key=config.llm.api_key, - api_version=config.llm.api_version, + If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE. + Raise TypeError if embedding type not found. 
+ """ + if config.embedding.api_type: + return config.embedding.api_type + + if config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]: + return config.llm.api_type + + raise TypeError("To use RAG, please set your embedding in config2.yaml.") + + def _create_openai(self) -> OpenAIEmbedding: + params = dict( + api_key=config.embedding.api_key or config.llm.api_key, + api_base=config.embedding.base_url or config.llm.base_url, ) + self._try_set_model_and_batch_size(params) + + return OpenAIEmbedding(**params) + + def _create_azure(self) -> AzureOpenAIEmbedding: + params = dict( + api_key=config.embedding.api_key or config.llm.api_key, + azure_endpoint=config.embedding.base_url or config.llm.base_url, + api_version=config.embedding.api_version or config.llm.api_version, + ) + + self._try_set_model_and_batch_size(params) + + return AzureOpenAIEmbedding(**params) + + def _create_gemini(self) -> GeminiEmbedding: + params = dict( + api_key=config.embedding.api_key, + api_base=config.embedding.base_url, + ) + + self._try_set_model_and_batch_size(params) + + return GeminiEmbedding(**params) + + def _create_ollama(self) -> OllamaEmbedding: + params = dict( + base_url=config.embedding.base_url, + ) + + self._try_set_model_and_batch_size(params) + + return OllamaEmbedding(**params) + + def _try_set_model_and_batch_size(self, params: dict): + """Set the model_name and embed_batch_size only when they are specified.""" + if config.embedding.model: + params["model_name"] = config.embedding.model + + if config.embedding.embed_batch_size: + params["embed_batch_size"] = config.embedding.embed_batch_size + + def _raise_for_key(self, key: Any): + raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}") + get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index a56471359..f897af3ad 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -48,7 
+48,7 @@ class RAGIndexFactory(ConfigBasedFactory): def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex: db = chromadb.PersistentClient(str(config.persist_path)) - chroma_collection = db.get_or_create_collection(config.collection_name) + chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata) vector_store = ChromaVectorStore(chroma_collection=chroma_collection) return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs) diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py index 17c499b76..9fd19cab5 100644 --- a/metagpt/rag/factories/llm.py +++ b/metagpt/rag/factories/llm.py @@ -1,5 +1,5 @@ """RAG LLM.""" - +import asyncio from typing import Any from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW @@ -15,7 +15,7 @@ from pydantic import Field from metagpt.config2 import config from metagpt.llm import LLM from metagpt.provider.base_llm import BaseLLM -from metagpt.utils.async_helper import run_coroutine_in_new_loop +from metagpt.utils.async_helper import NestAsyncio from metagpt.utils.token_counter import TOKEN_MAX @@ -39,7 +39,8 @@ class RAGLLM(CustomLLM): @llm_completion_callback() def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: - return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs)) + NestAsyncio.apply_once() + return asyncio.get_event_loop().run_until_complete(self.acomplete(prompt, **kwargs)) @llm_completion_callback() async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse: diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index 476fe8c1a..7abda162a 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -8,6 +8,8 @@ from metagpt.rag.factories.base import ConfigBasedFactory from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor from metagpt.rag.schema import ( BaseRankerConfig, + 
BGERerankConfig, + CohereRerankConfig, ColbertRerankConfig, LLMRankerConfig, ObjectRankerConfig, @@ -22,6 +24,8 @@ class RankerFactory(ConfigBasedFactory): LLMRankerConfig: self._create_llm_ranker, ColbertRerankConfig: self._create_colbert_ranker, ObjectRankerConfig: self._create_object_ranker, + CohereRerankConfig: self._create_cohere_rerank, + BGERerankConfig: self._create_bge_rerank, } super().__init__(creators) @@ -45,6 +49,26 @@ class RankerFactory(ConfigBasedFactory): ) return ColbertRerank(**config.model_dump()) + def _create_cohere_rerank(self, config: CohereRerankConfig, **kwargs) -> LLMRerank: + try: + from llama_index.postprocessor.cohere_rerank import CohereRerank + except ImportError: + raise ImportError( + "`llama-index-postprocessor-cohere-rerank` package not found, please run `pip install llama-index-postprocessor-cohere-rerank`" + ) + return CohereRerank(**config.model_dump()) + + def _create_bge_rerank(self, config: BGERerankConfig, **kwargs) -> LLMRerank: + try: + from llama_index.postprocessor.flag_embedding_reranker import ( + FlagEmbeddingReranker, + ) + except ImportError: + raise ImportError( + "`llama-index-postprocessor-flag-embedding-reranker` package not found, please run `pip install llama-index-postprocessor-flag-embedding-reranker`" + ) + return FlagEmbeddingReranker(**config.model_dump()) + def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank: return ObjectSortPostprocessor(**config.model_dump()) diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index 65729002e..1460e131b 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -1,10 +1,13 @@ """RAG Retriever Factory.""" -import copy + +from functools import wraps import chromadb import faiss from llama_index.core import StorageContext, VectorStoreIndex +from llama_index.core.embeddings import BaseEmbedding +from llama_index.core.schema import BaseNode from 
llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.elasticsearch import ElasticsearchStore @@ -24,10 +27,25 @@ from metagpt.rag.schema import ( ElasticsearchKeywordRetrieverConfig, ElasticsearchRetrieverConfig, FAISSRetrieverConfig, - IndexRetrieverConfig, ) +def get_or_build_index(build_index_func): + """Decorator to get or build an index. + + Get index using `_extract_index` method, if not found, using build_index_func. + """ + + @wraps(build_index_func) + def wrapper(self, config, **kwargs): + index = self._extract_index(config, **kwargs) + if index is not None: + return index + return build_index_func(self, config, **kwargs) + + return wrapper + + class RetrieverFactory(ConfigBasedFactory): """Modify creators for dynamically instance implementation.""" @@ -54,48 +72,79 @@ class RetrieverFactory(ConfigBasedFactory): return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0] def _create_default(self, **kwargs) -> RAGRetriever: - return self._extract_index(**kwargs).as_retriever() + index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs) + + return index.as_retriever() def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever: - vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) - config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + config.index = self._build_faiss_index(config, **kwargs) return FAISSRetriever(**config.model_dump()) def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever: - config.index = copy.deepcopy(self._extract_index(config, **kwargs)) + index = self._extract_index(config, **kwargs) + nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs) - return 
DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump()) + return DynamicBM25Retriever(nodes=nodes, **config.model_dump()) def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever: - db = chromadb.PersistentClient(path=str(config.persist_path)) - chroma_collection = db.get_or_create_collection(config.collection_name) - - vector_store = ChromaVectorStore(chroma_collection=chroma_collection) - config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + config.index = self._build_chroma_index(config, **kwargs) return ChromaRetriever(**config.model_dump()) def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever: - vector_store = ElasticsearchStore(**config.store_config.model_dump()) - config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + config.index = self._build_es_index(config, **kwargs) return ElasticsearchRetriever(**config.model_dump()) def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex: return self._val_from_config_or_kwargs("index", config, **kwargs) + def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[BaseNode]: + return self._val_from_config_or_kwargs("nodes", config, **kwargs) + + def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding: + return self._val_from_config_or_kwargs("embed_model", config, **kwargs) + + def _build_default_index(self, **kwargs) -> VectorStoreIndex: + index = VectorStoreIndex( + nodes=self._extract_nodes(**kwargs), + embed_model=self._extract_embed_model(**kwargs), + ) + + return index + + @get_or_build_index + def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex: + vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) + + return self._build_index_from_vector_store(config, vector_store, 
**kwargs) + + @get_or_build_index + def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex: + db = chromadb.PersistentClient(path=str(config.persist_path)) + chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata) + vector_store = ChromaVectorStore(chroma_collection=chroma_collection) + + return self._build_index_from_vector_store(config, vector_store, **kwargs) + + @get_or_build_index + def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex: + vector_store = ElasticsearchStore(**config.store_config.model_dump()) + + return self._build_index_from_vector_store(config, vector_store, **kwargs) + def _build_index_from_vector_store( - self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs + self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs ) -> VectorStoreIndex: storage_context = StorageContext.from_defaults(vector_store=vector_store) - old_index = self._extract_index(config, **kwargs) - new_index = VectorStoreIndex( - nodes=list(old_index.docstore.docs.values()), + index = VectorStoreIndex( + nodes=self._extract_nodes(config, **kwargs), storage_context=storage_context, - embed_model=old_index._embed_model, + embed_model=self._extract_embed_model(config, **kwargs), ) - return new_index + + return index get_retriever = RetrieverFactory().get_retriever diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py index 241820cf4..3b085cb73 100644 --- a/metagpt/rag/retrievers/bm25_retriever.py +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -40,8 +40,10 @@ class DynamicBM25Retriever(BM25Retriever): self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes] self.bm25 = BM25Okapi(self._corpus) - self._index.insert_nodes(nodes, **kwargs) + if self._index: + self._index.insert_nodes(nodes, **kwargs) def persist(self, persist_dir: str, **kwargs) -> 
None: """Support persist.""" - self._index.storage_context.persist(persist_dir) + if self._index: + self._index.storage_context.persist(persist_dir) \ No newline at end of file diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 183f6e0c7..e7b2e5ce9 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -1,14 +1,17 @@ """RAG schemas.""" from pathlib import Path -from typing import Any, Literal, Union +from typing import Any, ClassVar, Literal, Optional, Union +from chromadb.api.types import CollectionMetadata from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.core.schema import TextNode from llama_index.core.vector_stores.types import VectorStoreQueryMode -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from metagpt.config2 import config +from metagpt.configs.embedding_config import EmbeddingType from metagpt.rag.interface import RAGObject @@ -31,7 +34,19 @@ class IndexRetrieverConfig(BaseRetrieverConfig): class FAISSRetrieverConfig(IndexRetrieverConfig): """Config for FAISS-based retrievers.""" - dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.") + dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.") + + _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = { + EmbeddingType.GEMINI: 768, + EmbeddingType.OLLAMA: 4096, + } + + @model_validator(mode="after") + def check_dimensions(self): + if self.dimensions == 0: + self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536) + + return self class BM25RetrieverConfig(IndexRetrieverConfig): @@ -45,6 +60,9 @@ class ChromaRetrieverConfig(IndexRetrieverConfig): persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.") 
collection_name: str = Field(default="metagpt", description="The name of the collection.") + metadata: Optional[CollectionMetadata] = Field( + default=None, description="Optional metadata to associate with the collection" + ) class ElasticsearchStoreConfig(BaseModel): @@ -101,6 +119,16 @@ class ColbertRerankConfig(BaseRankerConfig): keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.") +class CohereRerankConfig(BaseRankerConfig): + model: str = Field(default="rerank-english-v3.0") + api_key: str = Field(default="YOUR_COHERE_API") + + +class BGERerankConfig(BaseRankerConfig): + model: str = Field(default="BAAI/bge-reranker-large", description="BAAI Reranker model name.") + use_fp16: bool = Field(default=True, description="Whether to use fp16 for inference.") + + class ObjectRankerConfig(BaseRankerConfig): field_name: str = Field(..., description="field name of the object, field's value must can be compared.") order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.") @@ -130,6 +158,9 @@ class ChromaIndexConfig(VectorIndexConfig): """Config for chroma-based index.""" collection_name: str = Field(default="metagpt", description="The name of the collection.") + metadata: Optional[CollectionMetadata] = Field( + default=None, description="Optional metadata to associate with the collection" + ) class BM25IndexConfig(BaseIndexConfig): diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index 465beff05..9e1761c85 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -6,11 +6,9 @@ @File : architect.py """ -from metagpt.actions import UserRequirement, WritePRD +from metagpt.actions import WritePRD from metagpt.actions.design_api import WriteDesign -from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.roles.role import Role -from metagpt.utils.common import any_to_str class Architect(Role): @@ -36,22 +34,7 @@ class 
Architect(Role): super().__init__(**kwargs) self.enable_memory = False # Initialize actions specific to the Architect role - self.set_actions([PrepareDocuments(send_to=any_to_str(self), context=self.context), WriteDesign]) + self.set_actions([WriteDesign]) # Set events or actions the Architect should watch or be aware of - self._watch({UserRequirement, PrepareDocuments, WritePRD}) - - async def _think(self) -> bool: - """Decide what to do""" - mappings = { - any_to_str(UserRequirement): 0, - any_to_str(PrepareDocuments): 1, - any_to_str(WritePRD): 1, - } - for i in self.rc.news: - idx = mappings.get(i.cause_by, -1) - if idx < 0: - continue - self.rc.todo = self.actions[idx] - return bool(self.rc.todo) - return False + self._watch({WritePRD}) diff --git a/metagpt/roles/di/data_analyst.py b/metagpt/roles/di/data_analyst.py index 0fc95b9d6..fc298ea4c 100644 --- a/metagpt/roles/di/data_analyst.py +++ b/metagpt/roles/di/data_analyst.py @@ -20,6 +20,7 @@ from metagpt.strategy.thinking_command import ( ) from metagpt.tools.tool_recommend import BM25ToolRecommender from metagpt.utils.common import CodeParser +from metagpt.utils.report import ThoughtReporter class DataAnalyst(DataInterpreter): @@ -82,8 +83,8 @@ class DataAnalyst(DataInterpreter): available_commands=prepare_command_prompt(self.available_commands), ) context = self.llm.format_msg(self.working_memory.get() + [Message(content=prompt, role="user")]) - - rsp = await self.llm.aask(context) + async with ThoughtReporter(): + rsp = await self.llm.aask(context) self.commands = json.loads(CodeParser.parse_code(block=None, text=rsp)) self.rc.memory.add(Message(content=rsp, role="assistant")) diff --git a/metagpt/roles/di/data_interpreter.py b/metagpt/roles/di/data_interpreter.py index e147cbbe3..bdfc0e294 100644 --- a/metagpt/roles/di/data_interpreter.py +++ b/metagpt/roles/di/data_interpreter.py @@ -15,6 +15,7 @@ from metagpt.schema import Message, Task, TaskResult from metagpt.strategy.task_type import TaskType from 
metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.utils.common import CodeParser +from metagpt.utils.report import ThoughtReporter REACT_THINK_PROMPT = """ # User Requirement @@ -73,7 +74,8 @@ class DataInterpreter(Role): return True prompt = REACT_THINK_PROMPT.format(user_requirement=self.user_requirement, context=context) - rsp = await self.llm.aask(prompt) + async with ThoughtReporter(): + rsp = await self.llm.aask(prompt) rsp_dict = json.loads(CodeParser.parse_code(text=rsp)) self.working_memory.add(Message(content=rsp_dict["thoughts"], role="assistant")) need_action = rsp_dict["state"] diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index afd4a68c4..62df9d580 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -19,6 +19,7 @@ from metagpt.strategy.planner import Planner from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.tools.tool_registry import register_tool from metagpt.utils.common import CodeParser +from metagpt.utils.report import ThoughtReporter @register_tool(include_functions=["ask_human", "reply_to_human"]) @@ -118,7 +119,8 @@ class RoleZero(Role): ) context = self.llm.format_msg(self.rc.memory.get(self.memory_k) + [Message(content=prompt, role="user")]) print(*context, sep="\n" + "*" * 5 + "\n") - self.command_rsp = await self.llm.aask(context, system_msgs=self.system_msg) + async with ThoughtReporter(): + self.command_rsp = await self.llm.aask(context, system_msgs=self.system_msg) self.rc.memory.add(Message(content=self.command_rsp, role="assistant")) return True diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index d9e375a9a..919b4bf13 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -24,20 +24,15 @@ from collections import defaultdict from pathlib import Path from typing import List, Optional, Set -from metagpt.actions import ( - Action, - UserRequirement, - WriteCode, - 
WriteCodeReview, - WriteTasks, -) +from pydantic import BaseModel, Field + +from metagpt.actions import WriteCode, WriteCodeReview, WriteTasks from metagpt.actions.fix_bug import FixBug from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST from metagpt.actions.summarize_code import SummarizeCode from metagpt.actions.write_code_plan_and_change_an import WriteCodePlanAndChange from metagpt.const import ( - BUGFIX_FILENAME, CODE_PLAN_AND_CHANGE_FILE_REPO, MESSAGE_ROUTE_TO_SELF, REQUIREMENT_FILENAME, @@ -63,6 +58,7 @@ from metagpt.utils.common import ( init_python_folder, ) from metagpt.utils.git_repository import ChangeType +from metagpt.utils.project_repo import ProjectRepo IS_PASS_PROMPT = """ {context} @@ -100,23 +96,14 @@ class Engineer(Role): summarize_todos: list = [] next_todo_action: str = "" n_summarize: int = 0 + input_args: Optional[BaseModel] = Field(default=None, exclude=True) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.enable_memory = False self.set_actions([WriteCode]) - self._watch( - [ - UserRequirement, - PrepareDocuments, - WriteTasks, - SummarizeCode, - WriteCode, - WriteCodeReview, - FixBug, - WriteCodePlanAndChange, - ] - ) + self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug, WriteCodePlanAndChange]) self.code_todos = [] self.summarize_todos = [] self.next_todo_action = any_to_name(WriteCode) @@ -139,14 +126,20 @@ class Engineer(Role): coding_context = await todo.run() # Code review if review: - action = WriteCodeReview(i_context=coding_context, context=self.context, llm=self.llm) + action = WriteCodeReview( + i_context=coding_context, + repo=self.repo, + input_args=self.input_args, + context=self.context, + llm=self.llm, + ) self._init_action(action) coding_context = await action.run() dependencies = 
{coding_context.design_doc.root_relative_path, coding_context.task_doc.root_relative_path} if self.config.inc: dependencies.add(coding_context.code_plan_and_change_doc.root_relative_path) - await self.project_repo.srcs.save( + await self.repo.srcs.save( filename=coding_context.filename, dependencies=list(dependencies), content=coding_context.code_doc.content, @@ -186,9 +179,9 @@ class Engineer(Role): summary_filename = Path(todo.i_context.design_filename).with_suffix(".md").name dependencies = {todo.i_context.design_filename, todo.i_context.task_filename} for filename in todo.i_context.codes_filenames: - rpath = self.project_repo.src_relative_path / filename + rpath = self.repo.src_relative_path / filename dependencies.add(str(rpath)) - await self.project_repo.resources.code_summary.save( + await self.repo.resources.code_summary.save( filename=summary_filename, content=summary, dependencies=dependencies ) is_pass, reason = await self._is_pass(summary) @@ -196,23 +189,39 @@ class Engineer(Role): todo.i_context.reason = reason tasks.append(todo.i_context.model_dump()) - await self.project_repo.docs.code_summary.save( + await self.repo.docs.code_summary.save( filename=Path(todo.i_context.design_filename).name, content=todo.i_context.model_dump_json(), dependencies=dependencies, ) else: - await self.project_repo.docs.code_summary.delete(filename=Path(todo.i_context.design_filename).name) + await self.repo.docs.code_summary.delete(filename=Path(todo.i_context.design_filename).name) self.summarize_todos = [] logger.info(f"--max-auto-summarize-code={self.config.max_auto_summarize_code}") if not tasks or self.config.max_auto_summarize_code == 0: self.n_summarize = 0 + kvs = self.input_args.model_dump() + kvs["changed_src_filenames"] = [ + str(self.repo.srcs.workdir / i) for i in list(self.repo.srcs.changed_files.keys()) + ] + if self.repo.docs.code_plan_and_change.changed_files: + kvs["changed_code_plan_and_change_filenames"] = [ + 
str(self.repo.docs.code_plan_and_change.workdir / i) + for i in list(self.repo.docs.code_plan_and_change.changed_files.keys()) + ] + if self.repo.docs.code_summary.changed_files: + kvs["changed_code_summary_filenames"] = [ + str(self.repo.docs.code_summary.workdir / i) + for i in list(self.repo.docs.code_summary.changed_files.keys()) + ] return AIMessage( - content=f"Coding is complete. The source code is at {self.project_repo.workdir.name}/{self.project_repo.srcs.root_path}, containing: " + content=f"Coding is complete. The source code is at {self.repo.workdir.name}/{self.repo.srcs.root_path}, containing: " + "\n".join( - list(self.project_repo.resources.code_summary.changed_files.keys()) - + list(self.project_repo.srcs.changed_files.keys()) + list(self.repo.resources.code_summary.changed_files.keys()) + + list(self.repo.srcs.changed_files.keys()) + + list(self.repo.resources.code_plan_and_change.changed_files.keys()) ), + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="SummarizeCodeOutput"), cause_by=SummarizeCode, send_to="Edward", # The name of QaEngineer ) @@ -227,15 +236,15 @@ class Engineer(Role): code_plan_and_change = node.instruct_content.model_dump_json() dependencies = { REQUIREMENT_FILENAME, - str(self.project_repo.docs.prd.root_path / self.rc.todo.i_context.prd_filename), - str(self.project_repo.docs.system_design.root_path / self.rc.todo.i_context.design_filename), - str(self.project_repo.docs.task.root_path / self.rc.todo.i_context.task_filename), + str(Path(self.rc.todo.i_context.prd_filename).relative_to(self.repo.workdir)), + str(Path(self.rc.todo.i_context.design_filename).relative_to(self.repo.workdir)), + str(Path(self.rc.todo.i_context.task_filename).relative_to(self.repo.workdir)), } code_plan_and_change_filepath = Path(self.rc.todo.i_context.design_filename) - await self.project_repo.docs.code_plan_and_change.save( + await self.repo.docs.code_plan_and_change.save( filename=code_plan_and_change_filepath.name, 
content=code_plan_and_change, dependencies=dependencies ) - await self.project_repo.resources.code_plan_and_change.save( + await self.repo.resources.code_plan_and_change.save( filename=code_plan_and_change_filepath.with_suffix(".md").name, content=node.content, dependencies=dependencies, @@ -250,55 +259,49 @@ class Engineer(Role): return True, rsp return False, rsp - async def _think(self) -> Action | None: + async def _think(self) -> bool: if not self.rc.news: - return None + return False msg = self.rc.news[0] - if msg.cause_by == any_to_str(UserRequirement): - self.rc.todo = PrepareDocuments( - key_descriptions={ - "project_path": 'the project path if exists in "Original Requirement"', - "src_filename": 'the file name of the source code file explicitly requested for modification if exists in "Original Requirement"', - }, - context=self.context, - send_to=any_to_str(self), - ) - return self.rc.todo - - if not self.src_workspace: - self.src_workspace = get_project_srcs_path(self.project_repo.workdir) + input_args = msg.instruct_content + if msg.cause_by in {any_to_str(WriteTasks), any_to_str(FixBug)}: + self.input_args = input_args + self.repo = ProjectRepo(input_args.project_path) + if self.repo.src_relative_path is None: + path = get_project_srcs_path(self.repo.workdir) + self.repo.with_src_path(path) write_plan_and_change_filters = any_to_str_set([PrepareDocuments, WriteTasks, FixBug]) write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode]) summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview]) if self.config.inc and msg.cause_by in write_plan_and_change_filters: logger.debug(f"TODO WriteCodePlanAndChange:{msg.model_dump_json()}") await self._new_code_plan_and_change_action(cause_by=msg.cause_by) - return self.rc.todo + return bool(self.rc.todo) if msg.cause_by in write_code_filters: logger.debug(f"TODO WriteCode:{msg.model_dump_json()}") await self._new_code_actions() - return self.rc.todo + return 
bool(self.rc.todo) if msg.cause_by in summarize_code_filters and msg.sent_from == any_to_str(self): logger.debug(f"TODO SummarizeCode:{msg.model_dump_json()}") await self._new_summarize_actions() - return self.rc.todo - return None + return bool(self.rc.todo) + return False async def _new_coding_context(self, filename, dependency) -> Optional[CodingContext]: - old_code_doc = await self.project_repo.srcs.get(filename) + old_code_doc = await self.repo.srcs.get(filename) if not old_code_doc: - old_code_doc = Document(root_path=str(self.project_repo.src_relative_path), filename=filename, content="") + old_code_doc = Document(root_path=str(self.repo.src_relative_path), filename=filename, content="") dependencies = {Path(i) for i in await dependency.get(old_code_doc.root_relative_path)} task_doc = None design_doc = None code_plan_and_change_doc = await self._get_any_code_plan_and_change() if await self._is_fixbug() else None for i in dependencies: if str(i.parent) == TASK_FILE_REPO: - task_doc = await self.project_repo.docs.task.get(i.name) + task_doc = await self.repo.docs.task.get(i.name) elif str(i.parent) == SYSTEM_DESIGN_FILE_REPO: - design_doc = await self.project_repo.docs.system_design.get(i.name) + design_doc = await self.repo.docs.system_design.get(i.name) elif str(i.parent) == CODE_PLAN_AND_CHANGE_FILE_REPO: - code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(i.name) + code_plan_and_change_doc = await self.repo.docs.code_plan_and_change.get(i.name) if not task_doc or not design_doc: if filename == "__init__.py": # `__init__.py` created by `init_python_folder` return None @@ -318,34 +321,66 @@ class Engineer(Role): if not context: return None # `__init__.py` created by `init_python_folder` coding_doc = Document( - root_path=str(self.project_repo.src_relative_path), filename=filename, content=context.model_dump_json() + root_path=str(self.repo.src_relative_path), filename=filename, content=context.model_dump_json() ) return 
coding_doc async def _new_code_actions(self): bug_fix = await self._is_fixbug() # Prepare file repos - changed_src_files = self.project_repo.srcs.changed_files + changed_src_files = self.repo.srcs.changed_files if self.context.kwargs.src_filename: changed_src_files = {self.context.kwargs.src_filename: ChangeType.UNTRACTED} if bug_fix: - changed_src_files = self.project_repo.srcs.all_files - changed_task_files = self.project_repo.docs.task.changed_files + changed_src_files = self.repo.srcs.all_files changed_files = Documents() # Recode caused by upstream changes. - for filename in changed_task_files: - design_doc = await self.project_repo.docs.system_design.get(filename) - task_doc = await self.project_repo.docs.task.get(filename) - code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(filename) + if hasattr(self.input_args, "changed_task_filenames"): + changed_task_filenames = self.input_args.changed_task_filenames + else: + changed_task_filenames = [ + str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys()) + ] + for filename in changed_task_filenames: + task_filename = Path(filename) + design_filename = None + if hasattr(self.input_args, "changed_system_design_filenames"): + changed_system_design_filenames = self.input_args.changed_system_design_filenames + else: + changed_system_design_filenames = [ + str(self.repo.docs.system_design.workdir / i) + for i in list(self.repo.docs.system_design.changed_files.keys()) + ] + for i in changed_system_design_filenames: + if task_filename.name == Path(i).name: + design_filename = Path(i) + break + code_plan_and_change_filename = None + if hasattr(self.input_args, "changed_code_plan_and_change_filenames"): + changed_code_plan_and_change_filenames = self.input_args.changed_code_plan_and_change_filenames + else: + changed_code_plan_and_change_filenames = [ + str(self.repo.docs.code_plan_and_change.workdir / i) + for i in 
list(self.repo.docs.code_plan_and_change.changed_files.keys()) + ] + for i in changed_code_plan_and_change_filenames: + if task_filename.name == Path(i).name: + code_plan_and_change_filename = Path(i) + break + design_doc = await Document.load(filename=design_filename, project_path=self.repo.workdir) + task_doc = await Document.load(filename=task_filename, project_path=self.repo.workdir) + code_plan_and_change_doc = await Document.load( + filename=code_plan_and_change_filename, project_path=self.repo.workdir + ) task_list = self._parse_tasks(task_doc) await self._init_python_folder(task_list) for task_filename in task_list: if self.context.kwargs.src_filename and task_filename != self.context.kwargs.src_filename: continue - old_code_doc = await self.project_repo.srcs.get(task_filename) + old_code_doc = await self.repo.srcs.get(task_filename) if not old_code_doc: old_code_doc = Document( - root_path=str(self.project_repo.src_relative_path), filename=task_filename, content="" + root_path=str(self.repo.src_relative_path), filename=task_filename, content="" ) if not code_plan_and_change_doc: context = CodingContext( @@ -360,7 +395,7 @@ class Engineer(Role): code_plan_and_change_doc=code_plan_and_change_doc, ) coding_doc = Document( - root_path=str(self.project_repo.src_relative_path), + root_path=str(self.repo.src_relative_path), filename=task_filename, content=context.model_dump_json(), ) @@ -371,10 +406,11 @@ class Engineer(Role): ) changed_files.docs[task_filename] = coding_doc self.code_todos = [ - WriteCode(i_context=i, context=self.context, llm=self.llm) for i in changed_files.docs.values() + WriteCode(i_context=i, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm) + for i in changed_files.docs.values() ] # Code directly modified by the user. 
- dependency = await self.git_repo.get_dependency() + dependency = await self.repo.git_repo.get_dependency() for filename in changed_src_files: if filename in changed_files.docs: continue @@ -382,24 +418,30 @@ class Engineer(Role): if not coding_doc: continue # `__init__.py` created by `init_python_folder` changed_files.docs[filename] = coding_doc - self.code_todos.append(WriteCode(i_context=coding_doc, context=self.context, llm=self.llm)) + self.code_todos.append( + WriteCode( + i_context=coding_doc, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm + ) + ) if self.code_todos: self.set_todo(self.code_todos[0]) async def _new_summarize_actions(self): - src_files = self.project_repo.srcs.all_files + src_files = self.repo.srcs.all_files # Generate a SummarizeCode action for each pair of (system_design_doc, task_doc). summarizations = defaultdict(list) for filename in src_files: - dependencies = await self.project_repo.srcs.get_dependency(filename=filename) + dependencies = await self.repo.srcs.get_dependency(filename=filename) ctx = CodeSummarizeContext.loads(filenames=list(dependencies)) summarizations[ctx].append(filename) for ctx, filenames in summarizations.items(): if not ctx.design_filename or not ctx.task_filename: continue # cause by `__init__.py` which is created by `init_python_folder` ctx.codes_filenames = filenames - new_summarize = SummarizeCode(i_context=ctx, context=self.context, llm=self.llm) + new_summarize = SummarizeCode( + i_context=ctx, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm + ) for i, act in enumerate(self.summarize_todos): if act.i_context.task_filename == new_summarize.i_context.task_filename: self.summarize_todos[i] = new_summarize @@ -412,16 +454,37 @@ class Engineer(Role): async def _new_code_plan_and_change_action(self, cause_by: str): """Create a WriteCodePlanAndChange action for subsequent to-do actions.""" - files = self.project_repo.all_files options = {} if cause_by 
!= any_to_str(FixBug): - requirement_doc = await self.project_repo.docs.get(REQUIREMENT_FILENAME) + requirement_doc = await Document.load(filename=self.input_args.requirements_filename) options["requirement"] = requirement_doc.content else: - fixbug_doc = await self.project_repo.docs.get(BUGFIX_FILENAME) + fixbug_doc = await Document.load(filename=self.input_args.issue_filename) options["issue"] = fixbug_doc.content - code_plan_and_change_ctx = CodePlanAndChangeContext.loads(files, **options) - self.rc.todo = WriteCodePlanAndChange(i_context=code_plan_and_change_ctx, context=self.context, llm=self.llm) + # The code here is flawed: if there are multiple unrelated requirements, this piece of logic will break + if hasattr(self.input_args, "changed_prd_filenames"): + code_plan_and_change_ctx = CodePlanAndChangeContext( + requirement=options.get("requirement", ""), + issue=options.get("issue", ""), + prd_filename=self.input_args.changed_prd_filenames[0], + design_filename=self.input_args.changed_system_design_filenames[0], + task_filename=self.input_args.changed_task_filenames[0], + ) + else: + code_plan_and_change_ctx = CodePlanAndChangeContext( + requirement=options.get("requirement", ""), + issue=options.get("issue", ""), + prd_filename=str(self.repo.docs.prd.workdir / self.repo.docs.prd.all_files[0]), + design_filename=str(self.repo.docs.system_design.workdir / self.repo.docs.system_design.all_files[0]), + task_filename=str(self.repo.docs.task.workdir / self.repo.docs.task.all_files[0]), + ) + self.rc.todo = WriteCodePlanAndChange( + i_context=code_plan_and_change_ctx, + repo=self.repo, + input_args=self.input_args, + context=self.context, + llm=self.llm, + ) @property def action_description(self) -> str: @@ -433,17 +496,16 @@ class Engineer(Role): filename = Path(i) if filename.suffix != ".py": continue - workdir = self.src_workspace / filename.parent + workdir = self.repo.srcs.workdir / filename.parent await init_python_folder(workdir) async def _is_fixbug(self) 
-> bool: - fixbug_doc = await self.project_repo.docs.get(BUGFIX_FILENAME) - return bool(fixbug_doc and fixbug_doc.content) + return bool(self.input_args and hasattr(self.input_args, "issue_filename")) async def _get_any_code_plan_and_change(self) -> Optional[Document]: - changed_files = self.project_repo.docs.code_plan_and_change.changed_files + changed_files = self.repo.docs.code_plan_and_change.changed_files for filename in changed_files.keys(): - doc = await self.project_repo.docs.code_plan_and_change.get(filename) + doc = await self.repo.docs.code_plan_and_change.get(filename) if doc and doc.content: return doc return None diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 4beab5366..d08933cb0 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -7,10 +7,12 @@ @Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135. """ + from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.roles.role import Role, RoleReactMode from metagpt.utils.common import any_to_name, any_to_str +from metagpt.utils.git_repository import GitRepository class ProductManager(Role): @@ -40,7 +42,7 @@ class ProductManager(Role): async def _think(self) -> bool: """Decide what to do""" - if self.git_repo and not self.config.git_reinit: + if GitRepository.is_git_dir(self.config.project_path) and not self.config.git_reinit: self._set_state(1) else: self._set_state(0) diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 70bd3bf8b..d6374e673 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -6,11 +6,9 @@ @File : project_manager.py """ -from metagpt.actions import UserRequirement, WriteTasks +from metagpt.actions import WriteTasks from metagpt.actions.design_api import WriteDesign -from metagpt.actions.prepare_documents import 
PrepareDocuments from metagpt.roles.role import Role -from metagpt.utils.common import any_to_str class ProjectManager(Role): @@ -35,20 +33,5 @@ class ProjectManager(Role): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.enable_memory = False - self.set_actions([PrepareDocuments(send_to=any_to_str(self), context=self.context), WriteTasks]) - self._watch([UserRequirement, PrepareDocuments, WriteDesign]) - - async def _think(self) -> bool: - """Decide what to do""" - mappings = { - any_to_str(UserRequirement): 0, - any_to_str(PrepareDocuments): 1, - any_to_str(WriteDesign): 1, - } - for i in self.rc.news: - idx = mappings.get(i.cause_by, -1) - if idx < 0: - continue - self.rc.todo = self.actions[idx] - return bool(self.rc.todo) - return False + self.set_actions([WriteTasks]) + self._watch([WriteDesign]) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index f76baff3f..fc8fa5353 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -14,6 +14,9 @@ @Modified By: mashenquan, 2023-12-5. Enhance the workflow to navigate to WriteCode or QaEngineer based on the results of SummarizeCode. 
""" +from typing import Optional + +from pydantic import BaseModel, Field from metagpt.actions import DebugError, RunCode, UserRequirement, WriteTest from metagpt.actions.prepare_documents import PrepareDocuments @@ -25,9 +28,11 @@ from metagpt.schema import AIMessage, Document, Message, RunCodeContext, Testing from metagpt.utils.common import ( any_to_str, any_to_str_set, + get_project_srcs_path, init_python_folder, parse_recipient, ) +from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.report import EditorReporter @@ -41,6 +46,8 @@ class QaEngineer(Role): ) test_round_allowed: int = 5 test_round: int = 0 + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) def __init__(self, **kwargs): super().__init__(**kwargs) @@ -48,31 +55,26 @@ class QaEngineer(Role): # FIXME: a bit hack here, only init one action to circumvent _think() logic, # will overwrite _think() in future updates - self.set_actions( - [ - WriteTest, - ] - ) - self._watch([UserRequirement, PrepareDocuments, SummarizeCode, WriteTest, RunCode, DebugError]) + self.set_actions([WriteTest]) + self._watch([SummarizeCode, WriteTest, RunCode, DebugError]) self.test_round = 0 async def _write_test(self, message: Message) -> None: - src_file_repo = self.project_repo.with_src_path(self.context.src_workspace).srcs reqa_file = self.context.kwargs.reqa_file or self.config.reqa_file - changed_files = {reqa_file} if reqa_file else set(src_file_repo.changed_files.keys()) + changed_files = {reqa_file} if reqa_file else set(self.repo.srcs.changed_files.keys()) for filename in changed_files: # write tests if not filename or "test" in filename: continue - code_doc = await src_file_repo.get(filename) - if not code_doc: + code_doc = await self.repo.srcs.get(filename) + if not code_doc or not code_doc.content: continue if not code_doc.filename.endswith(".py"): continue - test_doc = await self.project_repo.tests.get("test_" + 
code_doc.filename) + test_doc = await self.repo.tests.get("test_" + code_doc.filename) if not test_doc: test_doc = Document( - root_path=str(self.project_repo.tests.root_path), filename="test_" + code_doc.filename, content="" + root_path=str(self.repo.tests.root_path), filename="test_" + code_doc.filename, content="" ) logger.info(f"Writing {test_doc.filename}..") context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc) @@ -81,40 +83,38 @@ class QaEngineer(Role): async with EditorReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "test", "filename": test_doc.filename}, "meta") - doc = await self.project_repo.tests.save_doc( + doc = await self.repo.tests.save_doc( doc=context.test_doc, dependencies={context.code_doc.root_relative_path} ) - await reporter.async_report(self.project_repo.workdir / doc.root_relative_path, "path") + await reporter.async_report(self.repo.workdir / doc.root_relative_path, "path") # prepare context for run tests in next round run_code_context = RunCodeContext( command=["python", context.test_doc.root_relative_path], code_filename=context.code_doc.filename, test_filename=context.test_doc.filename, - working_directory=str(self.project_repo.workdir), - additional_python_paths=[str(self.context.src_workspace)], + working_directory=str(self.repo.workdir), + additional_python_paths=[str(self.repo.srcs.workdir)], ) self.publish_message( AIMessage(content=run_code_context.model_dump_json(), cause_by=WriteTest, send_to=MESSAGE_ROUTE_TO_SELF) ) - logger.info(f"Done {str(self.project_repo.tests.workdir)} generating.") + logger.info(f"Done {str(self.repo.tests.workdir)} generating.") async def _run_code(self, msg): run_code_context = RunCodeContext.loads(msg.content) - src_doc = await self.project_repo.with_src_path(self.context.src_workspace).srcs.get( - run_code_context.code_filename - ) + src_doc = await self.repo.srcs.get(run_code_context.code_filename) if not src_doc: return - 
test_doc = await self.project_repo.tests.get(run_code_context.test_filename) + test_doc = await self.repo.tests.get(run_code_context.test_filename) if not test_doc: return run_code_context.code = src_doc.content run_code_context.test_code = test_doc.content result = await RunCode(i_context=run_code_context, context=self.context, llm=self.llm).run() run_code_context.output_filename = run_code_context.test_filename + ".json" - await self.project_repo.test_outputs.save( + await self.repo.test_outputs.save( filename=run_code_context.output_filename, content=result.model_dump_json(), dependencies={src_doc.root_relative_path, test_doc.root_relative_path}, @@ -124,31 +124,53 @@ class QaEngineer(Role): # the recipient might be Engineer or myself recipient = parse_recipient(result.summary) mappings = {"Engineer": "Alex", "QaEngineer": "Edward"} - self.publish_message( - AIMessage( - content=run_code_context.model_dump_json(), - cause_by=RunCode, - send_to=mappings.get(recipient, MESSAGE_ROUTE_TO_NONE), + if recipient != "Engineer": + self.publish_message( + AIMessage( + content=run_code_context.model_dump_json(), + cause_by=RunCode, + instruct_content=self.input_args, + send_to=MESSAGE_ROUTE_TO_SELF, + ) + ) + else: + kvs = self.input_args.model_dump() + kvs["changed_test_filenames"] = [ + str(self.repo.tests.workdir / i) for i in list(self.repo.tests.changed_files.keys()) + ] + self.publish_message( + AIMessage( + content=run_code_context.model_dump_json(), + cause_by=RunCode, + instruct_content=self.input_args, + send_to=mappings.get(recipient, MESSAGE_ROUTE_TO_NONE), + ) ) - ) async def _debug_error(self, msg): run_code_context = RunCodeContext.loads(msg.content) - code = await DebugError(i_context=run_code_context, context=self.context, llm=self.llm).run() - await self.project_repo.tests.save(filename=run_code_context.test_filename, content=code) + code = await DebugError( + i_context=run_code_context, repo=self.repo, input_args=self.input_args, context=self.context, 
llm=self.llm + ).run() + await self.repo.tests.save(filename=run_code_context.test_filename, content=code) run_code_context.output = None self.publish_message( AIMessage(content=run_code_context.model_dump_json(), cause_by=DebugError, send_to=MESSAGE_ROUTE_TO_SELF) ) async def _act(self) -> Message: - if self.project_path: - await init_python_folder(self.project_repo.tests.workdir) + if self.input_args.project_path: + await init_python_folder(self.repo.tests.workdir) if self.test_round > self.test_round_allowed: + kvs = self.input_args.model_dump() + kvs["changed_test_filenames"] = [ + str(self.repo.tests.workdir / i) for i in list(self.repo.tests.changed_files.keys()) + ] result_msg = AIMessage( content=f"Exceeding {self.test_round_allowed} rounds of tests, stop. " - + "\n".join(list(self.project_repo.tests.changed_files.keys())), + + "\n".join(list(self.repo.tests.changed_files.keys())), cause_by=WriteTest, + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteTestOutput"), send_to=MESSAGE_ROUTE_TO_NONE, ) return result_msg @@ -171,8 +193,13 @@ class QaEngineer(Role): elif msg.cause_by == any_to_str(UserRequirement): return await self._parse_user_requirement(msg) self.test_round += 1 + kvs = self.input_args.model_dump() + kvs["changed_test_filenames"] = [ + str(self.repo.tests.workdir / i) for i in list(self.repo.tests.changed_files.keys()) + ] return AIMessage( content=f"Round {self.test_round} of tests done", + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteTestOutput"), cause_by=WriteTest, send_to=MESSAGE_ROUTE_TO_NONE, ) @@ -190,3 +217,15 @@ class QaEngineer(Role): if not self.src_workspace: self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name return rsp + + async def _think(self) -> bool: + if not self.rc.news: + return False + msg = self.rc.news[0] + if msg.cause_by == any_to_str(SummarizeCode): + self.input_args = msg.instruct_content + self.repo = 
ProjectRepo(self.input_args.project_path) + if self.repo.src_relative_path is None: + path = get_project_srcs_path(self.repo.workdir) + self.repo.with_src_path(path) + return True diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 1eaa77fa3..344e1df5e 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -45,7 +45,6 @@ from metagpt.schema import ( ) from metagpt.strategy.planner import Planner from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator -from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output if TYPE_CHECKING: @@ -196,29 +195,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel): value.context = self.context self.rc.todo = value - @property - def git_repo(self): - """Git repo""" - return self.context.git_repo - - @git_repo.setter - def git_repo(self, value): - self.context.git_repo = value - - @property - def src_workspace(self): - """Source workspace under git repo""" - return self.context.src_workspace - - @src_workspace.setter - def src_workspace(self, value): - self.context.src_workspace = value - - @property - def project_repo(self) -> ProjectRepo: - project_repo = ProjectRepo(self.context.git_repo) - return project_repo.with_src_path(self.context.src_workspace) if self.context.src_workspace else project_repo - @property def prompt_schema(self): """Prompt schema: json/markdown""" @@ -410,8 +386,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel): msg = response else: msg = AIMessage(content=response or "", cause_by=self.rc.todo, sent_from=self) - if self.enable_memory: - self.rc.memory.add(msg) + self.rc.memory.add(msg) return msg diff --git a/metagpt/schema.py b/metagpt/schema.py index d867ef125..69c7a519b 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -31,6 +31,7 @@ from pydantic import ( ConfigDict, Field, PrivateAttr, + create_model, field_serializer, field_validator, model_serializer, @@ -43,14 
+44,19 @@ from metagpt.const import ( MESSAGE_ROUTE_FROM, MESSAGE_ROUTE_TO, MESSAGE_ROUTE_TO_ALL, - PRDS_FILE_REPO, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO, ) from metagpt.logs import logger from metagpt.repo_parser import DotClassInfo from metagpt.tools.tool_registry import register_tool -from metagpt.utils.common import CodeParser, any_to_str, any_to_str_set, import_class +from metagpt.utils.common import ( + CodeParser, + any_to_str, + any_to_str_set, + aread, + import_class, +) from metagpt.utils.exceptions import handle_exception from metagpt.utils.report import TaskReporter from metagpt.utils.serialize import ( @@ -158,6 +164,30 @@ class Document(BaseModel): def __repr__(self): return self.content + @classmethod + async def load( + cls, filename: Union[str, Path], project_path: Optional[Union[str, Path]] = None + ) -> Optional["Document"]: + """ + Load a document from a file. + + Args: + filename (Union[str, Path]): The path to the file to load. + project_path (Optional[Union[str, Path]], optional): The path to the project. Defaults to None. + + Returns: + Optional[Document]: The loaded document, or None if the file does not exist. + + """ + if not filename or not Path(filename).exists(): + return None + content = await aread(filename=filename) + doc = cls(content=content, filename=str(filename)) + if project_path and Path(filename).is_relative_to(project_path): + doc.root_path = Path(filename).relative_to(project_path).parent + doc.filename = Path(filename).name + return doc + class Documents(BaseModel): """A class representing a collection of documents. @@ -361,6 +391,22 @@ class Message(BaseModel): def add_metadata(self, key: str, value: str): self.metadata[key] = value + @staticmethod + def create_instruct_value(kvs: Dict[str, Any], class_name: str = "") -> BaseModel: + """ + Dynamically creates a Pydantic BaseModel subclass based on a given dictionary. + + Parameters: + - data: A dictionary from which to create the BaseModel subclass. 
+ + Returns: + - A Pydantic BaseModel subclass instance populated with the given data. + """ + if not class_name: + class_name = "DM" + uuid.uuid4().hex[0:8] + dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()}) + return dynamic_class.model_validate(kvs) + class UserMessage(Message): """便于支持OpenAI的消息 @@ -787,22 +833,6 @@ class CodePlanAndChangeContext(BaseModel): design_filename: str = "" task_filename: str = "" - @staticmethod - def loads(filenames: List, **kwargs) -> CodePlanAndChangeContext: - ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", ""), issue=kwargs.get("issue", "")) - for filename in filenames: - filename = Path(filename) - if filename.is_relative_to(PRDS_FILE_REPO): - ctx.prd_filename = filename.name - continue - if filename.is_relative_to(SYSTEM_DESIGN_FILE_REPO): - ctx.design_filename = filename.name - continue - if filename.is_relative_to(TASK_FILE_REPO): - ctx.task_filename = filename.name - continue - return ctx - # mermaid class view class UMLClassMeta(BaseModel): diff --git a/metagpt/software_company.py b/metagpt/software_company.py index 7f0c56388..2ea16f55f 100644 --- a/metagpt/software_company.py +++ b/metagpt/software_company.py @@ -68,7 +68,7 @@ def generate_repo( company.run_project(idea, send_to=any_to_str(ProductManager)) asyncio.run(company.run(n_round=n_round)) - return ctx.repo + return ctx.kwargs.get("project_path") @app.command("", help="Start a new project.") diff --git a/metagpt/tools/libs/browser.py b/metagpt/tools/libs/browser.py index 7fde804fe..955058ea0 100644 --- a/metagpt/tools/libs/browser.py +++ b/metagpt/tools/libs/browser.py @@ -1,9 +1,14 @@ from __future__ import annotations +import contextlib +from uuid import uuid4 + from playwright.async_api import async_playwright from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.tools.tool_registry import register_tool +from metagpt.utils.file import MemoryFileSystem +from 
metagpt.utils.parse_html import simplify_html from metagpt.utils.report import BrowserReporter @@ -35,16 +40,48 @@ class Browser: print("Now on page ", url) await self._view() - async def open_new_page(self, url: str): + async def open_new_page(self, url: str, timeout: float = 30000): """open a new page in the browser and view the page""" async with self.reporter as reporter: page = await self.browser.new_page() await reporter.async_report(url, "url") - await page.goto(url) + await page.goto(url, timeout=timeout) self.pages[url] = page await self._set_current_page(page, url) await reporter.async_report(page, "page") + async def view_page_element_to_scrape(self, requirement: str, keep_links: bool = False) -> None: + """view the HTML content of current page to understand the structure. When executed, the content will be printed out + + Args: + requirement (str): Providing a clear and detailed requirement helps in focusing the inspection on the desired elements. + keep_links (bool): Whether to keep the hyperlinks in the HTML content. Set to True if links are required + """ + html = await self.current_page.content() + html = simplify_html(html, url=self.current_page.url, keep_links=keep_links) + mem_fs = MemoryFileSystem() + filename = f"{uuid4().hex}.html" + with mem_fs.open(filename, "w") as f: + f.write(html) + + # Since RAG is an optional optimization, if it fails, the simplified HTML can be used as a fallback. 
+ with contextlib.suppress(Exception): + from metagpt.rag.engines import SimpleEngine # avoid circular import + + # TODO make `from_docs` asynchronous + engine = SimpleEngine.from_docs(input_files=[filename], fs=mem_fs) + nodes = await engine.aretrieve(requirement) + html = "\n".join(i.text for i in nodes) + + mem_fs.rm_file(filename) + print(html) + + async def get_page_content(self) -> str: + """Get the HTML content of current page.""" + html = await self.current_page.content() + html_content = html.strip() + return html_content + async def switch_page(self, url: str): """switch to an opened page in the browser and view the page""" if url in self.pages: @@ -152,8 +189,8 @@ class Browser: async def _view(self, keep_len: int = 5000) -> str: """simulate human viewing the current page, return the visible text with links""" - visible_text_with_links = await self.current_page.evaluate(VIEW_CONTENT_JS) - print("The visible text and their links (if any): ", visible_text_with_links[:keep_len]) + # visible_text_with_links = await self.current_page.evaluate(VIEW_CONTENT_JS) + # print("The visible text and their links (if any): ", visible_text_with_links[:keep_len]) # html_content = await self._view_page_html(keep_len=keep_len) # print("The html content: ", html_content) diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 354e5330d..9b36d5eea 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -103,7 +103,7 @@ class Editor: file_path=file_path, block_content=block_content, ) - self.resource.report(result.file_path, "path") + self.resource.report(result.file_path, "path", extra={"type": "search", "line": i, "symbol": symbol}) return result return None diff --git a/metagpt/tools/libs/git.py b/metagpt/tools/libs/git.py index eb3fd6822..b4d759bf4 100644 --- a/metagpt/tools/libs/git.py +++ b/metagpt/tools/libs/git.py @@ -9,7 +9,6 @@ from github.Issue import Issue from github.PullRequest import PullRequest from 
metagpt.tools.tool_registry import register_tool -from metagpt.utils.git_repository import GitBranch, GitRepository @register_tool(tags=["software development", "git", "Commit the changes and push to remote git repository."]) @@ -18,7 +17,7 @@ async def git_push( access_token: str, comments: str = "Commit", new_branch: str = "", -) -> GitBranch: +) -> "GitBranch": """ Pushes changes from a local Git repository to its remote counterpart. @@ -49,6 +48,8 @@ async def git_push( base branch:'master', head branch:'feature/new', repo_name:'iorisa/snake-game' """ + from metagpt.utils.git_repository import GitRepository + if not GitRepository.is_git_dir(local_path): raise ValueError("Invalid local git repository") diff --git a/metagpt/utils/async_helper.py b/metagpt/utils/async_helper.py index ee440ef44..cecb20c5d 100644 --- a/metagpt/utils/async_helper.py +++ b/metagpt/utils/async_helper.py @@ -20,3 +20,18 @@ def run_coroutine_in_new_loop(coroutine) -> Any: new_loop.call_soon_threadsafe(new_loop.stop) t.join() new_loop.close() + + +class NestAsyncio: + """Make asyncio event loop reentrant.""" + + is_applied = False + + @classmethod + def apply_once(cls): + """Ensures `nest_asyncio.apply()` is called only once.""" + if not cls.is_applied: + import nest_asyncio + + nest_asyncio.apply() + cls.is_applied = True diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index e2520ef13..28bfe623e 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -646,7 +646,7 @@ def role_raise_decorator(func): raise Exception(format_trackback_info(limit=None)) except Exception as e: if self.latest_observed_msg: - logger.warning( + logger.exception( "There is a exception in role's execution, in order to resume, " "we delete the newest role communication message in the role's memory." 
) @@ -667,6 +667,8 @@ def role_raise_decorator(func): @handle_exception async def aread(filename: str | Path, encoding="utf-8") -> str: """Read file asynchronously.""" + if not filename or not Path(filename).exists(): + return "" try: async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader: content = await reader.read() @@ -899,3 +901,51 @@ async def init_python_folder(workdir: str | Path): return async with aiofiles.open(init_filename, "a"): os.utime(init_filename, None) + + +def get_markdown_code_block_type(filename: str) -> str: + if not filename: + return "" + ext = Path(filename).suffix + types = { + ".py": "python", + ".js": "javascript", + ".java": "java", + ".cpp": "cpp", + ".c": "c", + ".html": "html", + ".css": "css", + ".xml": "xml", + ".json": "json", + ".yaml": "yaml", + ".md": "markdown", + ".sql": "sql", + ".rb": "ruby", + ".php": "php", + ".sh": "bash", + ".swift": "swift", + ".go": "go", + ".rs": "rust", + ".pl": "perl", + ".asm": "assembly", + ".r": "r", + ".scss": "scss", + ".sass": "sass", + ".lua": "lua", + ".ts": "typescript", + ".tsx": "tsx", + ".jsx": "jsx", + ".yml": "yaml", + ".ini": "ini", + ".toml": "toml", + ".svg": "xml", # SVG can often be treated as XML + # Add more file extensions and corresponding code block types as needed + } + return types.get(ext, "") + + +def to_markdown_code_block(val: str, type_: str = "") -> str: + if not val: + return val or "" + val = val.replace("```", "\\`\\`\\`") + return f"\n```{type_}\n{val}\n```\n" diff --git a/metagpt/utils/file.py b/metagpt/utils/file.py index f62b44eb8..8861f65dc 100644 --- a/metagpt/utils/file.py +++ b/metagpt/utils/file.py @@ -9,6 +9,7 @@ from pathlib import Path import aiofiles +from fsspec.implementations.memory import MemoryFileSystem as _MemoryFileSystem from metagpt.logs import logger from metagpt.utils.exceptions import handle_exception @@ -68,3 +69,9 @@ class File: content = b"".join(chunks) logger.debug(f"Successfully read file, the path of file: 
{file_path}") return content + + +class MemoryFileSystem(_MemoryFileSystem): + @classmethod + def _strip_protocol(cls, path): + return super()._strip_protocol(str(path)) diff --git a/metagpt/utils/git_repository.py b/metagpt/utils/git_repository.py index 7b09c1775..f3d6350bd 100644 --- a/metagpt/utils/git_repository.py +++ b/metagpt/utils/git_repository.py @@ -156,6 +156,8 @@ class GitRepository: :param local_path: The local path to check. :return: True if the directory is a Git repository, False otherwise. """ + if not local_path: + return False git_dir = Path(local_path) / ".git" if git_dir.exists() and is_git_dir(git_dir): return True diff --git a/metagpt/utils/parse_html.py b/metagpt/utils/parse_html.py index 65aa3f236..1ed3a620c 100644 --- a/metagpt/utils/parse_html.py +++ b/metagpt/utils/parse_html.py @@ -4,6 +4,7 @@ from __future__ import annotations from typing import Generator, Optional from urllib.parse import urljoin, urlparse +import htmlmin from bs4 import BeautifulSoup from pydantic import BaseModel, PrivateAttr @@ -38,6 +39,22 @@ class WebPage(BaseModel): elif url.startswith(("http://", "https://")): yield urljoin(self.url, url) + def get_slim_soup(self, keep_links: bool = False): + soup = _get_soup(self.html) + keep_attrs = ["class"] + if keep_links: + keep_attrs.append("href") + + for i in soup.find_all(True): + for name in list(i.attrs): + if i[name] and name not in keep_attrs: + del i[name] + + for i in soup.find_all(["svg", "img", "video", "audio"]): + i.decompose() + + return soup + def get_html_content(page: str, base: str): soup = _get_soup(page) @@ -48,7 +65,12 @@ def get_html_content(page: str, base: str): def _get_soup(page: str): soup = BeautifulSoup(page, "html.parser") # https://stackoverflow.com/questions/1936466/how-to-scrape-only-visible-webpage-text-with-beautifulsoup - for s in soup(["style", "script", "[document]", "head", "title"]): + for s in soup(["style", "script", "[document]", "head", "title", "footer"]): s.extract() return 
soup + + +def simplify_html(html: str, url: str, keep_links: bool = False): + html = WebPage(inner_text="", html=html, url=url).get_slim_soup(keep_links).decode() + return htmlmin.minify(html, remove_comments=True, remove_empty_space=True) diff --git a/metagpt/utils/project_repo.py b/metagpt/utils/project_repo.py index 64ed602a9..5761c0188 100644 --- a/metagpt/utils/project_repo.py +++ b/metagpt/utils/project_repo.py @@ -140,10 +140,11 @@ class ProjectRepo(FileRepository): return bool(code_files) def with_src_path(self, path: str | Path) -> ProjectRepo: - try: - self._srcs_path = Path(path).relative_to(self.workdir) - except ValueError: - self._srcs_path = Path(path) + path = Path(path) + if path.is_relative_to(self.workdir): + self._srcs_path = path.relative_to(self.workdir) + else: + self._srcs_path = path return self @property diff --git a/metagpt/utils/report.py b/metagpt/utils/report.py index a61c77381..2d72af111 100644 --- a/metagpt/utils/report.py +++ b/metagpt/utils/report.py @@ -39,6 +39,7 @@ class BlockType(str, Enum): GALLERY = "Gallery" NOTEBOOK = "Notebook" DOCS = "Docs" + THOUGHT = "Thought" END_MARKER_NAME = "end_marker" @@ -55,23 +56,23 @@ class ResourceReporter(BaseModel): callback_url: str = Field(METAGPT_REPORTER_DEFAULT_URL, description="The URL to which the report should be sent") _llm_task: Optional[asyncio.Task] = PrivateAttr(None) - def report(self, value: Any, name: str): + def report(self, value: Any, name: str, extra: Optional[dict] = None): """Synchronously report resource observation data. Args: value: The data to report. name: The type name of the data. """ - return self._report(value, name) + return self._report(value, name, extra) - async def async_report(self, value: Any, name: str): + async def async_report(self, value: Any, name: str, extra: Optional[dict] = None): """Asynchronously report resource observation data. Args: value: The data to report. name: The type name of the data. 
""" - return await self._async_report(value, name) + return await self._async_report(value, name, extra) @classmethod def set_report_fn(cls, fn: Callable): @@ -100,20 +101,20 @@ class ResourceReporter(BaseModel): """ cls._async_report = fn - def _report(self, value: Any, name: str): + def _report(self, value: Any, name: str, extra: Optional[dict] = None): if not self.callback_url: return - data = self._format_data(value, name) + data = self._format_data(value, name, extra) resp = requests.post(self.callback_url, json=data) resp.raise_for_status() return resp.text - async def _async_report(self, value: Any, name: str): + async def _async_report(self, value: Any, name: str, extra: Optional[dict] = None): if not self.callback_url: return - data = self._format_data(value, name) + data = self._format_data(value, name, extra) url = self.callback_url _result = urlparse(url) sessiion_kwargs = {} @@ -129,9 +130,16 @@ class ResourceReporter(BaseModel): resp.raise_for_status() return await resp.text() - def _format_data(self, value, name): + def _format_data(self, value, name, extra): data = self.model_dump(mode="json", exclude=("callback_url", "llm_stream")) - data["value"] = str(value) if isinstance(value, Path) else value + if isinstance(value, BaseModel): + value = value.model_dump(mode="json") + elif isinstance(value, Path): + value = str(value) + + if name == "path": + value = os.path.abspath(value) + data["value"] = value data["name"] = name role = CURRENT_ROLE.get(None) if role: @@ -139,6 +147,8 @@ class ResourceReporter(BaseModel): else: role_name = os.environ.get("METAGPT_ROLE") data["role"] = role_name + if extra: + data["extra"] = extra return data def __enter__(self): @@ -252,6 +262,16 @@ class TaskReporter(ObjectReporter): block: Literal[BlockType.TASK] = BlockType.TASK +class ThoughtReporter(ObjectReporter): + """Reporter for object resources to Task Block.""" + + block: Literal[BlockType.THOUGHT] = BlockType.THOUGHT + + async def __aenter__(self): + await 
self.async_report({}) + return await super().__aenter__() + + class FileReporter(ResourceReporter): """File resource callback for reporting complete file paths. @@ -259,13 +279,23 @@ class FileReporter(ResourceReporter): if the file can be partially output for display first, use streaming callback. """ - def report(self, value: Union[Path, dict, Any], name: Literal["path", "meta", "content"] = "path"): + def report( + self, + value: Union[Path, dict, Any], + name: Literal["path", "meta", "content"] = "path", + extra: Optional[dict] = None, + ): """Report file resource synchronously.""" - return super().report(value, name) + return super().report(value, name, extra) - async def async_report(self, value: Path, name: Literal["path", "meta", "content"] = "path"): + async def async_report( + self, + value: Union[Path, dict, Any], + name: Literal["path", "meta", "content"] = "path", + extra: Optional[dict] = None, + ): """Report file resource asynchronously.""" - return await super().async_report(value, name) + return await super().async_report(value, name, extra) class NotebookReporter(FileReporter): diff --git a/requirements.txt b/requirements.txt index b40c69c9f..83a904156 100644 --- a/requirements.txt +++ b/requirements.txt @@ -71,4 +71,6 @@ dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation gymnasium==0.29.1 pylint~=3.0.3 -pygithub~=2.3 \ No newline at end of file +pygithub~=2.3 +htmlmin +fsspec diff --git a/setup.py b/setup.py index 382e13a47..79b65ad47 100644 --- a/setup.py +++ b/setup.py @@ -32,12 +32,15 @@ extras_require = { "llama-index-core==0.10.15", "llama-index-embeddings-azure-openai==0.1.6", "llama-index-embeddings-openai==0.1.5", + "llama-index-embeddings-gemini==0.1.6", + "llama-index-embeddings-ollama==0.1.2", "llama-index-llms-azure-openai==0.1.4", "llama-index-readers-file==0.1.4", "llama-index-retrievers-bm25==0.1.3", "llama-index-vector-stores-faiss==0.1.1", "llama-index-vector-stores-elasticsearch==0.1.6", 
"llama-index-vector-stores-chroma==0.1.6", + "docx2txt==0.8", ], } diff --git a/tests/conftest.py b/tests/conftest.py index f26ab2ef9..1f6661f7c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,6 @@ import logging import os import re import uuid -from pathlib import Path from typing import Callable import aiohttp.web @@ -23,7 +22,6 @@ from metagpt.context import Context as MetagptContext from metagpt.llm import LLM from metagpt.logs import logger from metagpt.utils.git_repository import GitRepository -from metagpt.utils.project_repo import ProjectRepo from tests.mock.mock_aiohttp import MockAioResponse from tests.mock.mock_curl_cffi import MockCurlCffiResponse from tests.mock.mock_httplib2 import MockHttplib2Response @@ -149,13 +147,14 @@ def loguru_caplog(caplog): @pytest.fixture(scope="function") def context(request): ctx = MetagptContext() - ctx.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}") - ctx.repo = ProjectRepo(ctx.git_repo) + repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}") + ctx.config.project_path = str(repo.workdir) # Destroy git repo at the end of the test session. def fin(): - if ctx.git_repo: - ctx.git_repo.delete_repository() + if ctx.config.project_path: + git_repo = GitRepository(ctx.config.project_path) + git_repo.delete_repository() # Register the function for destroying the environment. 
request.addfinalizer(fin) @@ -279,6 +278,6 @@ def mermaid_mocker(aiohttp_mocker, mermaid_rsp_cache): @pytest.fixture def git_dir(): """Fixture to get the unittest directory.""" - git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}" + git_dir = DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}" git_dir.mkdir(parents=True, exist_ok=True) return git_dir diff --git a/tests/metagpt/actions/test_action_node.py b/tests/metagpt/actions/test_action_node.py index 989e2249c..bc85925a8 100644 --- a/tests/metagpt/actions/test_action_node.py +++ b/tests/metagpt/actions/test_action_node.py @@ -303,5 +303,4 @@ def test_action_node_from_pydantic_and_print_everything(): if __name__ == "__main__": - test_create_model_class() - test_create_model_class_with_mapping() + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_design_api.py b/tests/metagpt/actions/test_design_api.py index 9924a2e84..3398be5e6 100644 --- a/tests/metagpt/actions/test_design_api.py +++ b/tests/metagpt/actions/test_design_api.py @@ -6,37 +6,104 @@ @File : test_design_api.py @Modifiled By: mashenquan, 2023-12-6. 
According to RFC 135 """ +from pathlib import Path + import pytest from metagpt.actions.design_api import WriteDesign -from metagpt.llm import LLM +from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT from metagpt.logs import logger -from metagpt.schema import Message +from metagpt.schema import AIMessage, Message +from metagpt.utils.project_repo import ProjectRepo from tests.data.incremental_dev_project.mock import DESIGN_SAMPLE, REFINED_PRD_JSON @pytest.mark.asyncio -async def test_design_api(context): - inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"] # PRD_SAMPLE - for prd in inputs: - await context.repo.docs.prd.save(filename="new_prd.txt", content=prd) +async def test_design(context): + # Mock new design env + prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。" + context.kwargs.project_path = context.config.project_path + context.kwargs.inc = False + filename = "prd.txt" + repo = ProjectRepo(context.kwargs.project_path) + await repo.docs.prd.save(filename=filename, content=prd) + kvs = { + "project_path": str(context.kwargs.project_path), + "changed_prd_filenames": [str(repo.docs.prd.workdir / filename)], + } + instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WritePRDOutput") - design_api = WriteDesign(context=context) - - result = await design_api.run(Message(content=prd, instruct_content=None)) - logger.info(result) - - assert result - - -@pytest.mark.asyncio -async def test_refined_design_api(context): - await context.repo.docs.prd.save(filename="1.txt", content=str(REFINED_PRD_JSON)) - await context.repo.docs.system_design.save(filename="1.txt", content=DESIGN_SAMPLE) - - design_api = WriteDesign(context=context, llm=LLM()) - - result = await design_api.run(Message(content="", instruct_content=None)) + design_api = WriteDesign(context=context) + result = await design_api.run([Message(content=prd, instruct_content=instruct_content)]) logger.info(result) - assert result + assert isinstance(result, AIMessage) + assert 
result.instruct_content + assert repo.docs.system_design.changed_files + + # Mock incremental design env + context.kwargs.inc = True + await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON)) + await repo.docs.system_design.save(filename=filename, content=DESIGN_SAMPLE) + + result = await design_api.run([Message(content="", instruct_content=instruct_content)]) + logger.info(result) + assert result + assert isinstance(result, AIMessage) + assert result.instruct_content + assert repo.docs.system_design.changed_files + + +@pytest.mark.parametrize( + ("user_requirement", "prd_filename", "legacy_design_filename"), + [ + ("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None), + ("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None), + ( + "write 2048 game", + str(METAGPT_ROOT / "tests/data/prd.json"), + str(METAGPT_ROOT / "tests/data/system_design.json"), + ), + ], +) +@pytest.mark.asyncio +async def test_design_api(context, user_requirement, prd_filename, legacy_design_filename): + action = WriteDesign() + result = await action.run( + user_requirement=user_requirement, prd_filename=prd_filename, legacy_design_filename=legacy_design_filename + ) + assert isinstance(result, AIMessage) + assert result.content + assert str(DEFAULT_WORKSPACE_ROOT) in result.content + + +@pytest.mark.parametrize( + ("user_requirement", "prd_filename", "legacy_design_filename"), + [ + ("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None), + ("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None), + ( + "write 2048 game", + str(METAGPT_ROOT / "tests/data/prd.json"), + str(METAGPT_ROOT / "tests/data/system_design.json"), + ), + ], +) +@pytest.mark.asyncio +async def test_design_api_dir(context, user_requirement, prd_filename, legacy_design_filename): + action = WriteDesign() + result = await action.run( + user_requirement=user_requirement, + prd_filename=prd_filename, + legacy_design_filename=legacy_design_filename, + 
output_pathname=str(Path(context.config.project_path) / "1.txt"), + ) + assert isinstance(result, AIMessage) + assert result.content + assert str(context.config.project_path) in result.content + assert result.instruct_content + assert result.instruct_content.changed_system_design_filenames + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_project_management.py b/tests/metagpt/actions/test_project_management.py index 5d0d11efb..26699dea7 100644 --- a/tests/metagpt/actions/test_project_management.py +++ b/tests/metagpt/actions/test_project_management.py @@ -5,13 +5,15 @@ @Author : alexanderwu @File : test_project_management.py """ +import json import pytest from metagpt.actions.project_management import WriteTasks -from metagpt.llm import LLM +from metagpt.const import METAGPT_ROOT from metagpt.logs import logger -from metagpt.schema import Message +from metagpt.schema import AIMessage, Message +from metagpt.utils.project_repo import ProjectRepo from tests.data.incremental_dev_project.mock import ( REFINED_DESIGN_JSON, REFINED_PRD_JSON, @@ -22,29 +24,46 @@ from tests.metagpt.actions.mock_json import DESIGN, PRD @pytest.mark.asyncio async def test_task(context): - await context.repo.docs.prd.save("1.txt", content=str(PRD)) - await context.repo.docs.system_design.save("1.txt", content=str(DESIGN)) - logger.info(context.git_repo) + # Mock write tasks env + context.kwargs.project_path = context.config.project_path + context.kwargs.inc = False + repo = ProjectRepo(context.kwargs.project_path) + filename = "1.txt" + await repo.docs.prd.save(filename=filename, content=str(PRD)) + await repo.docs.system_design.save(filename=filename, content=str(DESIGN)) + kvs = { + "project_path": context.kwargs.project_path, + "changed_system_design_filenames": [str(repo.docs.system_design.workdir / filename)], + } + instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WriteDesignOutput") action = 
WriteTasks(context=context) - - result = await action.run(Message(content="", instruct_content=None)) + result = await action.run([Message(content="", instruct_content=instruct_content)]) logger.info(result) - assert result + assert result.instruct_content.changed_task_filenames + + # Mock incremental env + context.kwargs.inc = True + await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON)) + await repo.docs.system_design.save(filename=filename, content=str(REFINED_DESIGN_JSON)) + await repo.docs.task.save(filename=filename, content=TASK_SAMPLE) + + result = await action.run([Message(content="", instruct_content=instruct_content)]) + logger.info(result) + assert result + assert result.instruct_content.changed_task_filenames @pytest.mark.asyncio -async def test_refined_task(context): - await context.repo.docs.prd.save("2.txt", content=str(REFINED_PRD_JSON)) - await context.repo.docs.system_design.save("2.txt", content=str(REFINED_DESIGN_JSON)) - await context.repo.docs.task.save("2.txt", content=TASK_SAMPLE) - - logger.info(context.git_repo) - - action = WriteTasks(context=context, llm=LLM()) - - result = await action.run(Message(content="", instruct_content=None)) - logger.info(result) - +async def test_task_api(context): + action = WriteTasks() + result = await action.run(design_filename=str(METAGPT_ROOT / "tests/data/system_design.json")) assert result + assert result.content + m = json.loads(result.content) + assert m + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/actions/test_write_code.py b/tests/metagpt/actions/test_write_code.py index 42623f807..1c1772031 100644 --- a/tests/metagpt/actions/test_write_code.py +++ b/tests/metagpt/actions/test_write_code.py @@ -26,12 +26,7 @@ from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPL def setup_inc_workdir(context, inc: bool = False): """setup incremental workdir for testing""" - context.src_workspace = context.git_repo.workdir / 
"src" - if inc: - context.config.inc = inc - context.repo.old_workspace = context.repo.git_repo.workdir / "old" - context.config.project_path = "old" - + context.config.inc = inc return context diff --git a/tests/metagpt/actions/test_write_prd.py b/tests/metagpt/actions/test_write_prd.py index 43aa336b7..6cf1da1dc 100644 --- a/tests/metagpt/actions/test_write_prd.py +++ b/tests/metagpt/actions/test_write_prd.py @@ -6,25 +6,26 @@ @File : test_write_prd.py @Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, replace `handle` with `run`. """ +import uuid +from pathlib import Path import pytest from metagpt.actions import UserRequirement, WritePRD -from metagpt.const import REQUIREMENT_FILENAME +from metagpt.const import DEFAULT_WORKSPACE_ROOT, REQUIREMENT_FILENAME from metagpt.logs import logger from metagpt.roles.product_manager import ProductManager from metagpt.roles.role import RoleReactMode -from metagpt.schema import Message +from metagpt.schema import AIMessage, Message from metagpt.utils.common import any_to_str -from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE, PRD_SAMPLE -from tests.metagpt.actions.test_write_code import setup_inc_workdir +from metagpt.utils.project_repo import ProjectRepo +from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE @pytest.mark.asyncio async def test_write_prd(new_filename, context): product_manager = ProductManager(context=context) requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结" - await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements) product_manager.rc.react_mode = RoleReactMode.BY_ORDER prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement)) assert prd.cause_by == any_to_str(WritePRD) @@ -34,38 +35,39 @@ async def test_write_prd(new_filename, context): # Assert the prd is not None or empty assert prd is not None assert prd.content != "" - assert 
product_manager.context.repo.docs.prd.changed_files + repo = ProjectRepo(context.kwargs.project_path) + assert repo.docs.prd.changed_files + repo.git_repo.archive() - -@pytest.mark.asyncio -async def test_write_prd_inc(new_filename, context, git_dir): - context = setup_inc_workdir(context, inc=True) - await context.repo.docs.prd.save("1.txt", PRD_SAMPLE) - await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE) + # Mock incremental requirement + context.config.inc = True + context.config.project_path = context.kwargs.project_path + repo = ProjectRepo(context.config.project_path) + await repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE) action = WritePRD(context=context) - prd = await action.run(Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None)) + prd = await action.run([Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None)]) logger.info(NEW_REQUIREMENT_SAMPLE) logger.info(prd) # Assert the prd is not None or empty assert prd is not None assert prd.content != "" - assert "Refined Requirements" in prd.content + assert repo.git_repo.changed_files @pytest.mark.asyncio async def test_fix_debug(new_filename, context, git_dir): - context.src_workspace = context.git_repo.workdir / context.git_repo.workdir.name + # Mock legacy project + context.kwargs.project_path = str(git_dir) + repo = ProjectRepo(context.kwargs.project_path) + repo.with_src_path(git_dir.name) + await repo.srcs.save(filename="main.py", content='if __name__ == "__main__":\nmain()') + requirements = "ValueError: undefined variable `st`." + await repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements) - await context.repo.with_src_path(context.src_workspace).srcs.save( - filename="main.py", content='if __name__ == "__main__":\nmain()' - ) - requirements = "Please fix the bug in the code." 
- await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements) action = WritePRD(context=context) - - prd = await action.run(Message(content=requirements, instruct_content=None)) + prd = await action.run([Message(content=requirements, instruct_content=None)]) logger.info(prd) # Assert the prd is not None or empty @@ -73,5 +75,40 @@ async def test_fix_debug(new_filename, context, git_dir): assert prd.content != "" +@pytest.mark.asyncio +async def test_write_prd_api(context): + action = WritePRD() + result = await action.run(user_requirement="write a snake game.") + assert isinstance(result, AIMessage) + assert result.content + assert str(DEFAULT_WORKSPACE_ROOT) in result.content + + result = await action.run( + user_requirement="write a snake game.", + output_pathname=str(Path(context.config.project_path) / f"{uuid.uuid4().hex}.json"), + ) + assert isinstance(result, AIMessage) + assert result.content + assert result.instruct_content + assert str(context.config.project_path) in result.content + + legacy_prd_filename = result.instruct_content.changed_prd_filenames[-1] + + result = await action.run(user_requirement="Add moving enemy.", legacy_prd_filename=legacy_prd_filename) + assert isinstance(result, AIMessage) + assert result.content + assert str(DEFAULT_WORKSPACE_ROOT) in result.content + + result = await action.run( + user_requirement="Add moving enemy.", + output_pathname=str(Path(context.config.project_path) / f"{uuid.uuid4().hex}.json"), + legacy_prd_filename=legacy_prd_filename, + ) + assert isinstance(result, AIMessage) + assert result.content + assert result.instruct_content + assert str(context.config.project_path) in result.content + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/rag/engines/test_simple.py b/tests/metagpt/rag/engines/test_simple.py index 9262ccb07..8c7a15be2 100644 --- a/tests/metagpt/rag/engines/test_simple.py +++ b/tests/metagpt/rag/engines/test_simple.py @@ -25,10 +25,6 @@ 
class TestSimpleEngine: def mock_simple_directory_reader(self, mocker): return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader") - @pytest.fixture - def mock_vector_store_index(self, mocker): - return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents") - @pytest.fixture def mock_get_retriever(self, mocker): return mocker.patch("metagpt.rag.engines.simple.get_retriever") @@ -45,7 +41,6 @@ class TestSimpleEngine: self, mocker, mock_simple_directory_reader, - mock_vector_store_index, mock_get_retriever, mock_get_rankers, mock_get_response_synthesizer, @@ -81,11 +76,8 @@ class TestSimpleEngine: # Assert mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files) - mock_vector_store_index.assert_called_once() - mock_get_retriever.assert_called_once_with( - configs=retriever_configs, index=mock_vector_store_index.return_value - ) - mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm) + mock_get_retriever.assert_called_once() + mock_get_rankers.assert_called_once() mock_get_response_synthesizer.assert_called_once_with(llm=llm) assert isinstance(engine, SimpleEngine) @@ -119,7 +111,7 @@ class TestSimpleEngine: # Assert assert isinstance(engine, SimpleEngine) - assert engine.index is not None + assert engine._transformations is not None def test_from_objs_with_bm25_config(self): # Setup @@ -137,6 +129,7 @@ class TestSimpleEngine: def test_from_index(self, mocker, mock_llm, mock_embedding): # Mock mock_index = mocker.MagicMock(spec=VectorStoreIndex) + mock_index.as_retriever.return_value = "retriever" mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index") mock_get_index.return_value = mock_index @@ -149,7 +142,7 @@ class TestSimpleEngine: # Assert assert isinstance(engine, SimpleEngine) - assert engine.index is mock_index + assert engine._retriever == "retriever" @pytest.mark.asyncio async def test_asearch(self, mocker): @@ -200,14 +193,11 @@ class 
TestSimpleEngine: mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever) - mock_index = mocker.MagicMock(spec=VectorStoreIndex) - mock_index._transformations = mocker.MagicMock() - mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations") mock_run_transformations.return_value = ["node1", "node2"] # Setup - engine = SimpleEngine(retriever=mock_retriever, index=mock_index) + engine = SimpleEngine(retriever=mock_retriever) input_files = ["test_file1", "test_file2"] # Exec @@ -230,7 +220,7 @@ class TestSimpleEngine: return "" objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)] - engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock()) + engine = SimpleEngine(retriever=mock_retriever) # Exec engine.add_objs(objs=objs) diff --git a/tests/metagpt/rag/factories/test_base.py b/tests/metagpt/rag/factories/test_base.py index 1d41e1872..0b0a44976 100644 --- a/tests/metagpt/rag/factories/test_base.py +++ b/tests/metagpt/rag/factories/test_base.py @@ -97,6 +97,5 @@ class TestConfigBasedFactory: def test_val_from_config_or_kwargs_key_error(self): # Test KeyError when the key is not found in both config object and kwargs config = DummyConfig(name=None) - with pytest.raises(KeyError) as exc_info: - ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config) - assert "The key 'missing_key' is required but not provided" in str(exc_info.value) + val = ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config) + assert val is None diff --git a/tests/metagpt/rag/factories/test_embedding.py b/tests/metagpt/rag/factories/test_embedding.py index 1ded6b4a8..1a9e9b2c9 100644 --- a/tests/metagpt/rag/factories/test_embedding.py +++ b/tests/metagpt/rag/factories/test_embedding.py @@ -1,5 +1,6 @@ import pytest +from metagpt.configs.embedding_config import EmbeddingType from metagpt.configs.llm_config import LLMType from metagpt.rag.factories.embedding import RAGEmbeddingFactory @@ 
-10,30 +11,51 @@ class TestRAGEmbeddingFactory: self.embedding_factory = RAGEmbeddingFactory() @pytest.fixture - def mock_openai_embedding(self, mocker): + def mock_config(self, mocker): + return mocker.patch("metagpt.rag.factories.embedding.config") + + @staticmethod + def mock_openai_embedding(mocker): return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding") - @pytest.fixture - def mock_azure_embedding(self, mocker): + @staticmethod + def mock_azure_embedding(mocker): return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding") - def test_get_rag_embedding_openai(self, mock_openai_embedding): - # Exec - self.embedding_factory.get_rag_embedding(LLMType.OPENAI) + @staticmethod + def mock_gemini_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding") - # Assert - mock_openai_embedding.assert_called_once() + @staticmethod + def mock_ollama_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding") - def test_get_rag_embedding_azure(self, mock_azure_embedding): - # Exec - self.embedding_factory.get_rag_embedding(LLMType.AZURE) - - # Assert - mock_azure_embedding.assert_called_once() - - def test_get_rag_embedding_default(self, mocker, mock_openai_embedding): + @pytest.mark.parametrize( + ("mock_func", "embedding_type"), + [ + (mock_openai_embedding, LLMType.OPENAI), + (mock_azure_embedding, LLMType.AZURE), + (mock_openai_embedding, EmbeddingType.OPENAI), + (mock_azure_embedding, EmbeddingType.AZURE), + (mock_gemini_embedding, EmbeddingType.GEMINI), + (mock_ollama_embedding, EmbeddingType.OLLAMA), + ], + ) + def test_get_rag_embedding(self, mock_func, embedding_type, mocker): # Mock - mock_config = mocker.patch("metagpt.rag.factories.embedding.config") + mock = mock_func(mocker) + + # Exec + self.embedding_factory.get_rag_embedding(embedding_type) + + # Assert + mock.assert_called_once() + + def test_get_rag_embedding_default(self, mocker, mock_config): + # Mock + 
mock_openai_embedding = self.mock_openai_embedding(mocker) + + mock_config.embedding.api_type = None mock_config.llm.api_type = LLMType.OPENAI # Exec @@ -41,3 +63,44 @@ class TestRAGEmbeddingFactory: # Assert mock_openai_embedding.assert_called_once() + + @pytest.mark.parametrize( + "model, embed_batch_size, expected_params", + [("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})], + ) + def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params): + # Mock + mock_config.embedding.model = model + mock_config.embedding.embed_batch_size = embed_batch_size + + # Setup + test_params = {} + + # Exec + self.embedding_factory._try_set_model_and_batch_size(test_params) + + # Assert + assert test_params == expected_params + + def test_resolve_embedding_type(self, mock_config): + # Mock + mock_config.embedding.api_type = EmbeddingType.OPENAI + + # Exec + embedding_type = self.embedding_factory._resolve_embedding_type() + + # Assert + assert embedding_type == EmbeddingType.OPENAI + + def test_resolve_embedding_type_exception(self, mock_config): + # Mock + mock_config.embedding.api_type = None + mock_config.llm.api_type = LLMType.GEMINI + + # Assert + with pytest.raises(TypeError): + self.embedding_factory._resolve_embedding_type() + + def test_raise_for_key(self): + with pytest.raises(ValueError): + self.embedding_factory._raise_for_key("key") diff --git a/tests/metagpt/rag/factories/test_retriever.py b/tests/metagpt/rag/factories/test_retriever.py index ef1cef7e0..cd55a32db 100644 --- a/tests/metagpt/rag/factories/test_retriever.py +++ b/tests/metagpt/rag/factories/test_retriever.py @@ -1,6 +1,8 @@ import faiss import pytest from llama_index.core import VectorStoreIndex +from llama_index.core.embeddings import MockEmbedding +from llama_index.core.schema import TextNode from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.elasticsearch import 
ElasticsearchStore @@ -43,6 +45,14 @@ class TestRetrieverFactory: def mock_es_vector_store(self, mocker): return mocker.MagicMock(spec=ElasticsearchStore) + @pytest.fixture + def mock_nodes(self, mocker): + return [TextNode(text="msg")] + + @pytest.fixture + def mock_embedding(self): + return MockEmbedding(embed_dim=1) + def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index): mock_config = FAISSRetrieverConfig(dimensions=128) mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index) @@ -52,42 +62,40 @@ class TestRetrieverFactory: assert isinstance(retriever, FAISSRetriever) - def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index): + def test_get_retriever_with_bm25_config(self, mocker, mock_nodes): mock_config = BM25RetrieverConfig() mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_config]) + retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=mock_nodes) assert isinstance(retriever, DynamicBM25Retriever) - def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index): - mock_faiss_config = FAISSRetrieverConfig(dimensions=128) + def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_nodes, mock_embedding): + mock_faiss_config = FAISSRetrieverConfig(dimensions=1) mock_bm25_config = BM25RetrieverConfig() mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config]) + retriever = self.retriever_factory.get_retriever( + configs=[mock_faiss_config, mock_bm25_config], nodes=mock_nodes, embed_model=mock_embedding + ) assert 
isinstance(retriever, SimpleHybridRetriever) - def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store): + def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store, mock_embedding): mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection") mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient") mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock() mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_config]) + retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding) assert isinstance(retriever, ChromaRetriever) - def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store): + def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding): mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig()) mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store) - mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = self.retriever_factory.get_retriever(configs=[mock_config]) + retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding) assert isinstance(retriever, ElasticsearchRetriever) @@ -111,3 +119,19 @@ class TestRetrieverFactory: extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index) assert extracted_index == mock_vector_store_index + + def test_get_or_build_when_get(self, mocker): + want = "existing_index" + mocker.patch.object(self.retriever_factory, 
"_extract_index", return_value=want) + + got = self.retriever_factory._build_es_index(None) + + assert got == want + + def test_get_or_build_when_build(self, mocker): + want = "call_build_es_index" + mocker.patch.object(self.retriever_factory, "_build_es_index", return_value=want) + + got = self.retriever_factory._build_es_index(None) + + assert got == want diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 6f54b062d..48f13f4a2 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -392,5 +392,11 @@ async def test_parse_resources(context, content: str, key_descriptions): assert k in result +@pytest.mark.parametrize(("name", "value"), [("c1", {"age": 10, "name": "Alice"}), ("", {"path": __file__})]) +def test_create_instruct_value(name, value): + obj = Message.create_instruct_value(kvs=value, class_name=name) + assert obj.model_dump() == value + + if __name__ == "__main__": pytest.main([__file__, "-s"])