From 222ae5ada336e3736aadc260bfc8cd8026e799a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Wed, 27 Mar 2024 21:11:40 +0800 Subject: [PATCH] feat: +import external repo --- metagpt/actions/extract_readme.py | 13 +- metagpt/actions/import_repo.py | 178 +++++++++++++++++- metagpt/actions/rebuild_sequence_view.py | 16 +- metagpt/roles/engineer.py | 12 +- tests/metagpt/actions/test_import_repo.py | 7 + .../actions/test_rebuild_sequence_view.py | 2 +- tests/mock/mock_llm.py | 11 +- 7 files changed, 216 insertions(+), 23 deletions(-) diff --git a/metagpt/actions/extract_readme.py b/metagpt/actions/extract_readme.py index aeb3608a0..69f5503a9 100644 --- a/metagpt/actions/extract_readme.py +++ b/metagpt/actions/extract_readme.py @@ -63,6 +63,7 @@ class ExtractReadMe(Action): "You are a tool can summarize git repository README.md file.", "Return the summary about what is the repository.", ], + stream=False, ) return summary @@ -77,6 +78,7 @@ class ExtractReadMe(Action): f"2. cd `{self.install_to_path}`;\n" f"3. install the repository.", ], + stream=False, ) return install @@ -89,6 +91,7 @@ class ExtractReadMe(Action): "Return a bash code block of markdown object to configure the repository if necessary, otherwise return" " a empty bash code block of markdown object", ], + stream=False, ) return configuration @@ -100,13 +103,21 @@ class ExtractReadMe(Action): "You are a tool can summarize all usages of git repository according to README.md file.", "Return a list of code block of markdown objects to demonstrates the usage of the repository.", ], + stream=False, ) return usage async def _get(self) -> str: if self._readme is not None: return self._readme - filename = Path(self.i_context).resolve() / "README.md" + root = Path(self.i_context).resolve() + filename = None + for file_path in root.iterdir(): + if file_path.is_file() and file_path.stem == "README": + filename = file_path + break + if not filename: + return "" self._readme = await aread(filename=filename, encoding="utf-8") self._filename = str(filename) return self._readme diff --git a/metagpt/actions/import_repo.py b/metagpt/actions/import_repo.py index 07a0b639c..5b1624a2f 100644 --- a/metagpt/actions/import_repo.py +++ b/metagpt/actions/import_repo.py @@ -1,20 +1,44 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from typing import List +import json +import re +from pathlib import Path +from typing import List, Optional + +from pydantic import BaseModel from metagpt.actions import Action +from metagpt.actions.extract_readme import ExtractReadMe +from metagpt.actions.rebuild_class_view import RebuildClassView +from metagpt.actions.rebuild_sequence_view import RebuildSequenceView +from metagpt.const import GRAPH_REPO_FILE_REPO +from metagpt.logs import logger from metagpt.schema import Message from metagpt.tools.libs.git import git_clone +from metagpt.utils.common import ( + aread, + awrite, + list_files, + parse_json_code_block, + split_namespace, +) +from metagpt.utils.di_graph_repository import DiGraphRepository +from metagpt.utils.file_repository import FileRepository from metagpt.utils.git_repository import GitRepository +from metagpt.utils.graph_repository import GraphKeyword, GraphRepository from metagpt.utils.project_repo import ProjectRepo class ImportRepo(Action): repo_path: str + graph_db: Optional[GraphRepository] = None + rid: str = "" async def run(self, with_messages: List[Message] = None, **kwargs) -> Message: await self._create_repo() - pass + await self._create_prd() + await self._create_system_design() + self.context.git_repo.archive(comments="Import") async def _create_repo(self): path = await git_clone(url=self.repo_path, output_dir=self.config.workspace.path) @@ -22,3 +46,153 @@ class ImportRepo(Action): self.config.project_path = path self.context.git_repo = GitRepository(local_path=path, auto_init=True) self.context.repo = ProjectRepo(self.context.git_repo) + self.context.src_workspace = await self._guess_src_workspace() + await awrite( + filename=self.context.repo.workdir / ".src_workspace", + data=str(self.context.src_workspace.relative_to(self.context.repo.workdir)), + ) + + async def _create_prd(self): + action = ExtractReadMe(i_context=str(self.context.repo.workdir), context=self.context) + await action.run() + graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name + self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) + rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SUMMARY) + prd = {"Project Name": self.context.repo.workdir.name} + for r in rows: + if Path(r.subject).stem == "README": + prd["Original Requirements"] = r.object_ + break + self.rid = FileRepository.new_filename() + await self.repo.docs.prd.save(filename=self.rid + ".json", content=json.dumps(prd)) + + async def _create_system_design(self): + action = RebuildClassView( + name="ReverseEngineering", i_context=str(self.context.src_workspace), context=self.context + ) + await action.run() + rows = await action.graph_db.select(predicate="hasMermaidClassDiagramFile") + class_view_filename = rows[0].object_ + logger.info(f"class view:{class_view_filename}") + + rows = await action.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO) + tag = "__name__:__main__" + entries = [] + src_workspace = self.context.src_workspace.relative_to(self.context.repo.workdir) + for r in rows: + if tag in r.subject: + path = split_namespace(r.subject)[0] + elif tag in r.object_: + path = split_namespace(r.object_)[0] + else: + continue + if Path(path).is_relative_to(src_workspace): + entries.append(Path(path)) + main_entry = await self._guess_main_entry(entries) + full_path = RebuildSequenceView.get_full_filename(self.context.repo.workdir, main_entry) + action = RebuildSequenceView(context=self.context, i_context=str(full_path)) + try: + await action.run() + except Exception as e: + logger.warning(f"{e}, use the last successful version.") + files = list_files(self.context.repo.resources.data_api_design.workdir) + pattern = re.compile(r"[^a-zA-Z0-9]") + name = re.sub(pattern, "_", str(main_entry)) + filename = Path(name).with_suffix(".sequence_diagram.mmd") + postfix = str(filename) + sequence_files = [i for i in files if postfix in str(i)] + content = await aread(filename=sequence_files[0]) + await self.context.repo.resources.data_api_design.save( + filename=self.repo.workdir.stem + ".sequence_diagram.mmd", content=content + ) + await self._save_system_design() + + async def _save_system_design(self): + class_view = await self.context.repo.resources.data_api_design.get( + filename=self.repo.workdir.stem + ".class_diagram.mmd" + ) + sequence_view = await self.context.repo.resources.data_api_design.get( + filename=self.repo.workdir.stem + ".sequence_diagram.mmd" + ) + file_list = self.context.git_repo.get_files(relative_path=".", root_relative_path=self.context.src_workspace) + data = { + "Data structures and interfaces": class_view.content, + "Program call flow": sequence_view.content, + "File list": [str(i) for i in file_list], + } + await self.context.repo.docs.system_design.save(filename=self.rid + ".json", content=json.dumps(data)) + + async def _guess_src_workspace(self) -> Path: + files = list_files(self.context.repo.workdir) + dirs = [i.parent for i in files if i.name == "__init__.py"] + distinct = set() + for i in dirs: + done = False + for j in distinct: + if i.is_relative_to(j): + done = True + break + if j.is_relative_to(i): + break + if not done: + distinct = {j for j in distinct if not j.is_relative_to(i)} + distinct.add(i) + if len(distinct) == 1: + return list(distinct)[0] + prompt = "\n".join([f"- {str(i)}" for i in distinct]) + rsp = await self.llm.aask( + prompt, + system_msgs=[ + "You are a tool to choose the source code path from a list of paths based on the directory name.", + "You should identify the source code path among paths such as unit test path, examples path, etc.", + "Return a markdown JSON object containing:\n" + '- a "src" field containing the source code path;\n' + '- a "reason" field containing explaining why other paths is not the source code path\n', + ], + ) + logger.debug(rsp) + json_blocks = parse_json_code_block(rsp) + + class Data(BaseModel): + src: str + reason: str + + data = Data.model_validate_json(json_blocks[0]) + logger.info(f"src_workspace: {data.src}") + return Path(data.src) + + async def _guess_main_entry(self, entries: List[Path]) -> Path: + if len(entries) == 1: + return entries[0] + + file_list = "## File List\n" + file_list += "\n".join([f"- {i}" for i in entries]) + + rows = await self.graph_db.select(predicate=GraphKeyword.HAS_USAGE) + usage = "## Usage\n" + for r in rows: + if Path(r.subject).stem == "README": + usage += r.object_ + + prompt = file_list + "\n---\n" + usage + rsp = await self.llm.aask( + prompt, + system_msgs=[ + 'You are a tool to choose the source file path from "File List" which is used in "Usage".', + 'You choose the source file path based on the name of file and the class name and package name used in "Usage".', + "Return a markdown JSON object containing:\n" + '- a "filename" field containing the chosen source file path from "File List" which is used in "Usage";\n' + '- a "reason" field explaining why.', + ], + stream=False, + ) + logger.debug(rsp) + json_blocks = parse_json_code_block(rsp) + + class Data(BaseModel): + filename: str + reason: str + + data = Data.model_validate_json(json_blocks[0]) + logger.info(f"main: {data.filename}") + return Path(data.filename) diff --git a/metagpt/actions/rebuild_sequence_view.py b/metagpt/actions/rebuild_sequence_view.py index 0e67de908..fd356d58f 100644 --- a/metagpt/actions/rebuild_sequence_view.py +++ b/metagpt/actions/rebuild_sequence_view.py @@ -244,15 +244,6 @@ class RebuildSequenceView(Action): class_view = await self._get_uml_class_view(ns_class_name) source_code = await self._get_source_code(ns_class_name) - # prompt_blocks = [ - # "## Instruction\n" - # "You are a python code to UML 2.0 Use Case translator.\n" - # 'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".\n' - # "The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not " - # 'conflict with the information in "Mermaid Class Views".\n' - # 'The section under `if __name__ == "__main__":` of "Source Code" contains information about external ' - # "system interactions with the internal system.\n" - # ] prompt_blocks = [] block = "## Participants\n" for p in participants: @@ -340,6 +331,7 @@ class RebuildSequenceView(Action): system_msgs=[ "You are a Mermaid Sequence Diagram translator in function detail.", "Translate the markdown text to a Mermaid Sequence Diagram.", + "Response must be concise.", "Return a markdown mermaid code block.", ], stream=False, @@ -440,7 +432,7 @@ class RebuildSequenceView(Action): rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_PAGE_INFO) filename = split_namespace(ns_class_name=ns_class_name)[0] if not rows: - src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename) + src_filename = RebuildSequenceView.get_full_filename(root=self.i_context, pathname=filename) if not src_filename: return "" return await aread(filename=src_filename, encoding="utf-8") @@ -450,7 +442,7 @@ class RebuildSequenceView(Action): ) @staticmethod - def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None: + def get_full_filename(root: str | Path, pathname: str | Path) -> Path | None: """ Convert package name to the full path of the module. @@ -466,7 +458,7 @@ class RebuildSequenceView(Action): "metagpt/management/skill_manager.py", then the returned value will be "/User/xxx/github/MetaGPT/metagpt/management/skill_manager.py" """ - if re.match(r"^/.+", pathname): + if re.match(r"^/.+", str(pathname)): return pathname files = list_files(root=root) postfix = "/" + str(pathname) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index 9d8f6884f..007e586ac 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -45,7 +45,7 @@ from metagpt.schema import ( Documents, Message, ) -from metagpt.utils.common import any_to_name, any_to_str, any_to_str_set +from metagpt.utils.common import any_to_name, any_to_str, any_to_str_set, aread IS_PASS_PROMPT = """ {context} @@ -239,7 +239,8 @@ class Engineer(Role): async def _think(self) -> Action | None: if not self.src_workspace: - self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name + name = self._get_src_workspace_name() + self.src_workspace = self.git_repo.workdir / name write_plan_and_change_filters = any_to_str_set([WriteTasks, FixBug]) write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode]) summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview]) @@ -383,3 +384,10 @@ class Engineer(Role): def action_description(self) -> str: """AgentStore uses this attribute to display to the user what actions the current role should take.""" return self.next_todo_action + + async def _get_src_workspace_name(self): + name = self.git_repo.workdir.name + src_workspace_filename = self.git_repo.workdir / ".src_workspace" + if src_workspace_filename.exists(): + name = await aread(filename=src_workspace_filename) + return name diff --git a/tests/metagpt/actions/test_import_repo.py b/tests/metagpt/actions/test_import_repo.py index 6b60abe36..cfe08ee75 100644 --- a/tests/metagpt/actions/test_import_repo.py +++ b/tests/metagpt/actions/test_import_repo.py @@ -5,6 +5,7 @@ import pytest from metagpt.actions.import_repo import ImportRepo from metagpt.context import Context +from metagpt.utils.common import list_files @pytest.mark.asyncio @@ -13,6 +14,12 @@ async def test_import_repo(repo_path): context = Context() action = ImportRepo(repo_path=repo_path, context=context) await action.run() + assert context.repo + prd = list_files(context.repo.docs.prd.workdir) + assert prd + design = list_files(context.repo.docs.system_design.workdir) + assert design + assert prd[0].stem == design[0].stem if __name__ == "__main__": diff --git a/tests/metagpt/actions/test_rebuild_sequence_view.py b/tests/metagpt/actions/test_rebuild_sequence_view.py index 9be3e8a99..e2827c334 100644 --- a/tests/metagpt/actions/test_rebuild_sequence_view.py +++ b/tests/metagpt/actions/test_rebuild_sequence_view.py @@ -61,7 +61,7 @@ async def test_rebuild(context, mocker): ], ) def test_get_full_filename(root, pathname, want): - res = RebuildSequenceView._get_full_filename(root=root, pathname=pathname) + res = RebuildSequenceView.get_full_filename(root=root, pathname=pathname) assert res == want diff --git a/tests/mock/mock_llm.py b/tests/mock/mock_llm.py index c4262e080..f6c206d5e 100644 --- a/tests/mock/mock_llm.py +++ b/tests/mock/mock_llm.py @@ -3,6 +3,7 @@ from typing import Optional, Union from metagpt.config2 import config from metagpt.configs.llm_config import LLMType +from metagpt.const import LLM_API_TIMEOUT from metagpt.logs import logger from metagpt.provider.azure_openai_api import AzureOpenAILLM from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA @@ -22,7 +23,7 @@ class MockLLM(OriginalLLM): self.rsp_cache: dict = {} self.rsp_candidates: list[dict] = [] # a test can have multiple calls with the same llm, thus a list - async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, timeout=LLM_API_TIMEOUT) -> str: """Overwrite original acompletion_text to cancel retry""" if stream: resp = await self._achat_completion_stream(messages, timeout=timeout) @@ -37,7 +38,7 @@ class MockLLM(OriginalLLM): system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, images: Optional[Union[str, list[str]]] = None, - timeout=3, + timeout=LLM_API_TIMEOUT, stream=True, ) -> str: if system_msgs: @@ -56,7 +57,7 @@ class MockLLM(OriginalLLM): rsp = await self.acompletion_text(message, stream=stream, timeout=timeout) return rsp - async def original_aask_batch(self, msgs: list, timeout=3) -> str: + async def original_aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str: """A copy of metagpt.provider.base_llm.BaseLLM.aask_batch, we can't use super().aask because it will be mocked""" context = [] for msg in msgs: @@ -83,7 +84,7 @@ class MockLLM(OriginalLLM): system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, images: Optional[Union[str, list[str]]] = None, - timeout=3, + timeout=LLM_API_TIMEOUT, stream=True, ) -> str: # used to identify it a message has been called before @@ -98,7 +99,7 @@ class MockLLM(OriginalLLM): rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, images, timeout, stream) return rsp - async def aask_batch(self, msgs: list, timeout=3) -> str: + async def aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str: msg_key = "#MSG_SEP#".join([msg if isinstance(msg, str) else msg.content for msg in msgs]) rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout) return rsp