feat: +import external repo

This commit is contained in:
莘权 马 2024-03-27 21:11:40 +08:00
parent f4240ca483
commit 222ae5ada3
7 changed files with 216 additions and 23 deletions

View file

@ -63,6 +63,7 @@ class ExtractReadMe(Action):
"You are a tool can summarize git repository README.md file.",
"Return the summary about what is the repository.",
],
stream=False,
)
return summary
@ -77,6 +78,7 @@ class ExtractReadMe(Action):
f"2. cd `{self.install_to_path}`;\n"
f"3. install the repository.",
],
stream=False,
)
return install
@ -89,6 +91,7 @@ class ExtractReadMe(Action):
"Return a bash code block of markdown object to configure the repository if necessary, otherwise return"
" a empty bash code block of markdown object",
],
stream=False,
)
return configuration
@ -100,13 +103,21 @@ class ExtractReadMe(Action):
"You are a tool can summarize all usages of git repository according to README.md file.",
"Return a list of code block of markdown objects to demonstrates the usage of the repository.",
],
stream=False,
)
return usage
async def _get(self) -> str:
if self._readme is not None:
return self._readme
filename = Path(self.i_context).resolve() / "README.md"
root = Path(self.i_context).resolve()
filename = None
for file_path in root.iterdir():
if file_path.is_file() and file_path.stem == "README":
filename = file_path
break
if not filename:
return ""
self._readme = await aread(filename=filename, encoding="utf-8")
self._filename = str(filename)
return self._readme

View file

@ -1,20 +1,44 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import List
import json
import re
from pathlib import Path
from typing import List, Optional
from pydantic import BaseModel
from metagpt.actions import Action
from metagpt.actions.extract_readme import ExtractReadMe
from metagpt.actions.rebuild_class_view import RebuildClassView
from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.tools.libs.git import git_clone
from metagpt.utils.common import (
aread,
awrite,
list_files,
parse_json_code_block,
split_namespace,
)
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
from metagpt.utils.project_repo import ProjectRepo
class ImportRepo(Action):
repo_path: str
graph_db: Optional[GraphRepository] = None
rid: str = ""
async def run(self, with_messages: Optional[List[Message]] = None, **kwargs) -> Message:
    """Import an external repository and derive its project documents.

    Runs, in order: repository creation (``_create_repo``), PRD creation
    (``_create_prd``) and system-design creation (``_create_system_design``),
    then archives the git repository with the comment ``"Import"``.

    Args:
        with_messages: Not read by this method.
        **kwargs: Not read by this method.

    Returns:
        Nothing is returned explicitly, so the call evaluates to ``None``.
        NOTE(review): the ``Message`` annotation is kept for signature
        compatibility; confirm whether callers expect a real message.
    """
    await self._create_repo()
    await self._create_prd()
    await self._create_system_design()
    self.context.git_repo.archive(comments="Import")
async def _create_repo(self):
path = await git_clone(url=self.repo_path, output_dir=self.config.workspace.path)
@ -22,3 +46,153 @@ class ImportRepo(Action):
self.config.project_path = path
self.context.git_repo = GitRepository(local_path=path, auto_init=True)
self.context.repo = ProjectRepo(self.context.git_repo)
self.context.src_workspace = await self._guess_src_workspace()
await awrite(
filename=self.context.repo.workdir / ".src_workspace",
data=str(self.context.src_workspace.relative_to(self.context.repo.workdir)),
)
async def _create_prd(self):
    """Build a minimal PRD document from the repository's README summary.

    Runs ``ExtractReadMe`` on the repository working directory, loads the
    graph database stored under ``GRAPH_REPO_FILE_REPO`` into
    ``self.graph_db``, and copies the first ``HAS_SUMMARY`` row whose subject
    has the stem ``README`` into the PRD as "Original Requirements". A fresh
    id is stored in ``self.rid`` and the PRD is saved as ``<rid>.json``.
    """
    readme_action = ExtractReadMe(context=self.context, i_context=str(self.context.repo.workdir))
    await readme_action.run()

    git_workdir = self.context.git_repo.workdir
    graph_file = (git_workdir / GRAPH_REPO_FILE_REPO / git_workdir.name).with_suffix(".json")
    self.graph_db = await DiGraphRepository.load_from(str(graph_file))

    summaries = await self.graph_db.select(predicate=GraphKeyword.HAS_SUMMARY)
    prd = {"Project Name": self.context.repo.workdir.name}
    # Lazily stop at the first README row; later rows are never inspected.
    readme_row = next((row for row in summaries if Path(row.subject).stem == "README"), None)
    if readme_row is not None:
        prd["Original Requirements"] = readme_row.object_

    self.rid = FileRepository.new_filename()
    await self.repo.docs.prd.save(filename=self.rid + ".json", content=json.dumps(prd))
async def _create_system_design(self):
    """Reverse-engineer class and sequence diagrams, then save the design.

    Steps:
      1. Run ``RebuildClassView`` on the source workspace and log the class
         diagram file recorded in its graph database.
      2. Collect every ``HAS_PAGE_INFO`` row tagged ``__name__:__main__``
         whose file lies inside the source workspace; these are the
         candidate entry points.
      3. Choose the main entry and run ``RebuildSequenceView`` on it. A
         failure is only logged, and the files already present are used.
      4. Copy the matching ``*.sequence_diagram.mmd`` file to
         ``<workdir stem>.sequence_diagram.mmd`` and save the system design.
    """
    class_view_action = RebuildClassView(
        name="ReverseEngineering", i_context=str(self.context.src_workspace), context=self.context
    )
    await class_view_action.run()
    diagram_rows = await class_view_action.graph_db.select(predicate="hasMermaidClassDiagramFile")
    logger.info(f"class view:{diagram_rows[0].object_}")

    page_rows = await class_view_action.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
    main_tag = "__name__:__main__"
    entries = []
    workspace_rel = self.context.src_workspace.relative_to(self.context.repo.workdir)
    for row in page_rows:
        # The tag may sit on either side of the triple; the subject wins.
        if main_tag in row.subject:
            namespace = row.subject
        elif main_tag in row.object_:
            namespace = row.object_
        else:
            continue
        candidate = Path(split_namespace(namespace)[0])
        if candidate.is_relative_to(workspace_rel):
            entries.append(candidate)

    main_entry = await self._guess_main_entry(entries)
    entry_path = RebuildSequenceView.get_full_filename(self.context.repo.workdir, main_entry)
    sequence_action = RebuildSequenceView(context=self.context, i_context=str(entry_path))
    try:
        await sequence_action.run()
    except Exception as e:
        logger.warning(f"{e}, use the last successful version.")

    design_repo = self.context.repo.resources.data_api_design
    design_files = list_files(design_repo.workdir)
    # Every non-alphanumeric character of the entry path becomes "_".
    sanitized = re.sub(r"[^a-zA-Z0-9]", "_", str(main_entry))
    postfix = str(Path(sanitized).with_suffix(".sequence_diagram.mmd"))
    sequence_files = [f for f in design_files if postfix in str(f)]
    content = await aread(filename=sequence_files[0])
    await design_repo.save(filename=self.repo.workdir.stem + ".sequence_diagram.mmd", content=content)
    await self._save_system_design()
async def _save_system_design(self):
    """Assemble the system-design JSON document and save it as ``<rid>.json``.

    Reads ``<workdir stem>.class_diagram.mmd`` and
    ``<workdir stem>.sequence_diagram.mmd`` from the data-api-design
    resources, lists the files of the source workspace, and stores all three
    under the keys "Data structures and interfaces", "Program call flow" and
    "File list".
    """
    design_repo = self.context.repo.resources.data_api_design
    class_view = await design_repo.get(filename=self.repo.workdir.stem + ".class_diagram.mmd")
    sequence_view = await design_repo.get(filename=self.repo.workdir.stem + ".sequence_diagram.mmd")
    source_files = self.context.git_repo.get_files(
        relative_path=".", root_relative_path=self.context.src_workspace
    )
    design = {
        "Data structures and interfaces": class_view.content,
        "Program call flow": sequence_view.content,
        "File list": list(map(str, source_files)),
    }
    payload = json.dumps(design)
    await self.context.repo.docs.system_design.save(filename=self.rid + ".json", content=payload)
async def _guess_src_workspace(self) -> Path:
    """Pick the directory that holds the repository's main source package.

    Every directory containing an ``__init__.py`` is collected and reduced to
    the outermost package directories (a directory nested inside another
    candidate is dropped). Then:

    * no candidate  -> the repository working directory is returned;
    * one candidate -> it is returned directly;
    * several       -> the LLM is asked to choose among them.

    Returns:
        The chosen source directory. In the LLM case this is whatever path
        the model put in the ``src`` field of its JSON reply.
    """
    files = list_files(self.context.repo.workdir)
    package_dirs = [i.parent for i in files if i.name == "__init__.py"]

    # Keep only outermost package directories. Invariant: no element of
    # `top_level` lies inside another element.
    top_level = set()
    for candidate in package_dirs:
        if any(candidate.is_relative_to(kept) for kept in top_level):
            continue  # already covered by an enclosing (or identical) directory
        top_level = {kept for kept in top_level if not kept.is_relative_to(candidate)}
        top_level.add(candidate)

    if not top_level:
        # Nothing to choose from; asking the LLM with an empty list would
        # only yield an invented path.
        logger.warning("No python package found, use the repository root as the source workspace.")
        return self.context.repo.workdir

    # Sort so the prompt is identical from run to run; set order is not.
    candidates = sorted(top_level)
    if len(candidates) == 1:
        return candidates[0]

    prompt = "\n".join([f"- {str(i)}" for i in candidates])
    rsp = await self.llm.aask(
        prompt,
        system_msgs=[
            "You are a tool to choose the source code path from a list of paths based on the directory name.",
            "You should identify the source code path among paths such as unit test path, examples path, etc.",
            "Return a markdown JSON object containing:\n"
            '- a "src" field containing the source code path;\n'
            '- a "reason" field containing explaining why other paths is not the source code path\n',
        ],
        stream=False,  # consistent with the other aask calls in this action
    )
    logger.debug(rsp)
    json_blocks = parse_json_code_block(rsp)

    class Data(BaseModel):
        src: str
        reason: str

    data = Data.model_validate_json(json_blocks[0])
    logger.info(f"src_workspace: {data.src}")
    return Path(data.src)
async def _guess_main_entry(self, entries: List[Path]) -> Path:
if len(entries) == 1:
return entries[0]
file_list = "## File List\n"
file_list += "\n".join([f"- {i}" for i in entries])
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_USAGE)
usage = "## Usage\n"
for r in rows:
if Path(r.subject).stem == "README":
usage += r.object_
prompt = file_list + "\n---\n" + usage
rsp = await self.llm.aask(
prompt,
system_msgs=[
'You are a tool to choose the source file path from "File List" which is used in "Usage".',
'You choose the source file path based on the name of file and the class name and package name used in "Usage".',
"Return a markdown JSON object containing:\n"
'- a "filename" field containing the chosen source file path from "File List" which is used in "Usage";\n'
'- a "reason" field explaining why.',
],
stream=False,
)
logger.debug(rsp)
json_blocks = parse_json_code_block(rsp)
class Data(BaseModel):
filename: str
reason: str
data = Data.model_validate_json(json_blocks[0])
logger.info(f"main: {data.filename}")
return Path(data.filename)

View file

@ -244,15 +244,6 @@ class RebuildSequenceView(Action):
class_view = await self._get_uml_class_view(ns_class_name)
source_code = await self._get_source_code(ns_class_name)
# prompt_blocks = [
# "## Instruction\n"
# "You are a python code to UML 2.0 Use Case translator.\n"
# 'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".\n'
# "The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not "
# 'conflict with the information in "Mermaid Class Views".\n'
# 'The section under `if __name__ == "__main__":` of "Source Code" contains information about external '
# "system interactions with the internal system.\n"
# ]
prompt_blocks = []
block = "## Participants\n"
for p in participants:
@ -340,6 +331,7 @@ class RebuildSequenceView(Action):
system_msgs=[
"You are a Mermaid Sequence Diagram translator in function detail.",
"Translate the markdown text to a Mermaid Sequence Diagram.",
"Response must be concise.",
"Return a markdown mermaid code block.",
],
stream=False,
@ -440,7 +432,7 @@ class RebuildSequenceView(Action):
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_PAGE_INFO)
filename = split_namespace(ns_class_name=ns_class_name)[0]
if not rows:
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
src_filename = RebuildSequenceView.get_full_filename(root=self.i_context, pathname=filename)
if not src_filename:
return ""
return await aread(filename=src_filename, encoding="utf-8")
@ -450,7 +442,7 @@ class RebuildSequenceView(Action):
)
@staticmethod
def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
def get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
"""
Convert package name to the full path of the module.
@ -466,7 +458,7 @@ class RebuildSequenceView(Action):
"metagpt/management/skill_manager.py", then the returned value will be
"/User/xxx/github/MetaGPT/metagpt/management/skill_manager.py"
"""
if re.match(r"^/.+", pathname):
if re.match(r"^/.+", str(pathname)):
return pathname
files = list_files(root=root)
postfix = "/" + str(pathname)

View file

@ -45,7 +45,7 @@ from metagpt.schema import (
Documents,
Message,
)
from metagpt.utils.common import any_to_name, any_to_str, any_to_str_set
from metagpt.utils.common import any_to_name, any_to_str, any_to_str_set, aread
IS_PASS_PROMPT = """
{context}
@ -239,7 +239,8 @@ class Engineer(Role):
async def _think(self) -> Action | None:
if not self.src_workspace:
self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
name = self._get_src_workspace_name()
self.src_workspace = self.git_repo.workdir / name
write_plan_and_change_filters = any_to_str_set([WriteTasks, FixBug])
write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode])
summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
@ -383,3 +384,10 @@ class Engineer(Role):
def action_description(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
return self.next_todo_action
async def _get_src_workspace_name(self):
name = self.git_repo.workdir.name
src_workspace_filename = self.git_repo.workdir / ".src_workspace"
if src_workspace_filename.exists():
name = await aread(filename=src_workspace_filename)
return name

View file

@ -5,6 +5,7 @@ import pytest
from metagpt.actions.import_repo import ImportRepo
from metagpt.context import Context
from metagpt.utils.common import list_files
@pytest.mark.asyncio
@ -13,6 +14,12 @@ async def test_import_repo(repo_path):
context = Context()
action = ImportRepo(repo_path=repo_path, context=context)
await action.run()
assert context.repo
prd = list_files(context.repo.docs.prd.workdir)
assert prd
design = list_files(context.repo.docs.system_design.workdir)
assert design
assert prd[0].stem == design[0].stem
if __name__ == "__main__":

View file

@ -61,7 +61,7 @@ async def test_rebuild(context, mocker):
],
)
def test_get_full_filename(root, pathname, want):
res = RebuildSequenceView._get_full_filename(root=root, pathname=pathname)
res = RebuildSequenceView.get_full_filename(root=root, pathname=pathname)
assert res == want

View file

@ -3,6 +3,7 @@ from typing import Optional, Union
from metagpt.config2 import config
from metagpt.configs.llm_config import LLMType
from metagpt.const import LLM_API_TIMEOUT
from metagpt.logs import logger
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
@ -22,7 +23,7 @@ class MockLLM(OriginalLLM):
self.rsp_cache: dict = {}
self.rsp_candidates: list[dict] = [] # a test can have multiple calls with the same llm, thus a list
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=LLM_API_TIMEOUT) -> str:
"""Overwrite original acompletion_text to cancel retry"""
if stream:
resp = await self._achat_completion_stream(messages, timeout=timeout)
@ -37,7 +38,7 @@ class MockLLM(OriginalLLM):
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=3,
timeout=LLM_API_TIMEOUT,
stream=True,
) -> str:
if system_msgs:
@ -56,7 +57,7 @@ class MockLLM(OriginalLLM):
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
return rsp
async def original_aask_batch(self, msgs: list, timeout=3) -> str:
async def original_aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str:
"""A copy of metagpt.provider.base_llm.BaseLLM.aask_batch, we can't use super().aask because it will be mocked"""
context = []
for msg in msgs:
@ -83,7 +84,7 @@ class MockLLM(OriginalLLM):
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=3,
timeout=LLM_API_TIMEOUT,
stream=True,
) -> str:
# used to identify it a message has been called before
@ -98,7 +99,7 @@ class MockLLM(OriginalLLM):
rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, images, timeout, stream)
return rsp
async def aask_batch(self, msgs: list, timeout=3) -> str:
async def aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str:
msg_key = "#MSG_SEP#".join([msg if isinstance(msg, str) else msg.content for msg in msgs])
rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
return rsp