mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
feat: +import external repo
This commit is contained in:
parent
f4240ca483
commit
222ae5ada3
7 changed files with 216 additions and 23 deletions
|
|
@ -63,6 +63,7 @@ class ExtractReadMe(Action):
|
|||
"You are a tool can summarize git repository README.md file.",
|
||||
"Return the summary about what is the repository.",
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
return summary
|
||||
|
||||
|
|
@ -77,6 +78,7 @@ class ExtractReadMe(Action):
|
|||
f"2. cd `{self.install_to_path}`;\n"
|
||||
f"3. install the repository.",
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
return install
|
||||
|
||||
|
|
@ -89,6 +91,7 @@ class ExtractReadMe(Action):
|
|||
"Return a bash code block of markdown object to configure the repository if necessary, otherwise return"
|
||||
" a empty bash code block of markdown object",
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
return configuration
|
||||
|
||||
|
|
@ -100,13 +103,21 @@ class ExtractReadMe(Action):
|
|||
"You are a tool can summarize all usages of git repository according to README.md file.",
|
||||
"Return a list of code block of markdown objects to demonstrates the usage of the repository.",
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
return usage
|
||||
|
||||
async def _get(self) -> str:
|
||||
if self._readme is not None:
|
||||
return self._readme
|
||||
filename = Path(self.i_context).resolve() / "README.md"
|
||||
root = Path(self.i_context).resolve()
|
||||
filename = None
|
||||
for file_path in root.iterdir():
|
||||
if file_path.is_file() and file_path.stem == "README":
|
||||
filename = file_path
|
||||
break
|
||||
if not filename:
|
||||
return ""
|
||||
self._readme = await aread(filename=filename, encoding="utf-8")
|
||||
self._filename = str(filename)
|
||||
return self._readme
|
||||
|
|
|
|||
|
|
@ -1,20 +1,44 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import List
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.extract_readme import ExtractReadMe
|
||||
from metagpt.actions.rebuild_class_view import RebuildClassView
|
||||
from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
|
||||
from metagpt.const import GRAPH_REPO_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools.libs.git import git_clone
|
||||
from metagpt.utils.common import (
|
||||
aread,
|
||||
awrite,
|
||||
list_files,
|
||||
parse_json_code_block,
|
||||
split_namespace,
|
||||
)
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
class ImportRepo(Action):
|
||||
repo_path: str
|
||||
graph_db: Optional[GraphRepository] = None
|
||||
rid: str = ""
|
||||
|
||||
async def run(self, with_messages: List[Message] = None, **kwargs) -> Message:
|
||||
await self._create_repo()
|
||||
pass
|
||||
await self._create_prd()
|
||||
await self._create_system_design()
|
||||
self.context.git_repo.archive(comments="Import")
|
||||
|
||||
async def _create_repo(self):
|
||||
path = await git_clone(url=self.repo_path, output_dir=self.config.workspace.path)
|
||||
|
|
@ -22,3 +46,153 @@ class ImportRepo(Action):
|
|||
self.config.project_path = path
|
||||
self.context.git_repo = GitRepository(local_path=path, auto_init=True)
|
||||
self.context.repo = ProjectRepo(self.context.git_repo)
|
||||
self.context.src_workspace = await self._guess_src_workspace()
|
||||
await awrite(
|
||||
filename=self.context.repo.workdir / ".src_workspace",
|
||||
data=str(self.context.src_workspace.relative_to(self.context.repo.workdir)),
|
||||
)
|
||||
|
||||
async def _create_prd(self):
|
||||
action = ExtractReadMe(i_context=str(self.context.repo.workdir), context=self.context)
|
||||
await action.run()
|
||||
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
|
||||
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SUMMARY)
|
||||
prd = {"Project Name": self.context.repo.workdir.name}
|
||||
for r in rows:
|
||||
if Path(r.subject).stem == "README":
|
||||
prd["Original Requirements"] = r.object_
|
||||
break
|
||||
self.rid = FileRepository.new_filename()
|
||||
await self.repo.docs.prd.save(filename=self.rid + ".json", content=json.dumps(prd))
|
||||
|
||||
async def _create_system_design(self):
|
||||
action = RebuildClassView(
|
||||
name="ReverseEngineering", i_context=str(self.context.src_workspace), context=self.context
|
||||
)
|
||||
await action.run()
|
||||
rows = await action.graph_db.select(predicate="hasMermaidClassDiagramFile")
|
||||
class_view_filename = rows[0].object_
|
||||
logger.info(f"class view:{class_view_filename}")
|
||||
|
||||
rows = await action.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
|
||||
tag = "__name__:__main__"
|
||||
entries = []
|
||||
src_workspace = self.context.src_workspace.relative_to(self.context.repo.workdir)
|
||||
for r in rows:
|
||||
if tag in r.subject:
|
||||
path = split_namespace(r.subject)[0]
|
||||
elif tag in r.object_:
|
||||
path = split_namespace(r.object_)[0]
|
||||
else:
|
||||
continue
|
||||
if Path(path).is_relative_to(src_workspace):
|
||||
entries.append(Path(path))
|
||||
main_entry = await self._guess_main_entry(entries)
|
||||
full_path = RebuildSequenceView.get_full_filename(self.context.repo.workdir, main_entry)
|
||||
action = RebuildSequenceView(context=self.context, i_context=str(full_path))
|
||||
try:
|
||||
await action.run()
|
||||
except Exception as e:
|
||||
logger.warning(f"{e}, use the last successful version.")
|
||||
files = list_files(self.context.repo.resources.data_api_design.workdir)
|
||||
pattern = re.compile(r"[^a-zA-Z0-9]")
|
||||
name = re.sub(pattern, "_", str(main_entry))
|
||||
filename = Path(name).with_suffix(".sequence_diagram.mmd")
|
||||
postfix = str(filename)
|
||||
sequence_files = [i for i in files if postfix in str(i)]
|
||||
content = await aread(filename=sequence_files[0])
|
||||
await self.context.repo.resources.data_api_design.save(
|
||||
filename=self.repo.workdir.stem + ".sequence_diagram.mmd", content=content
|
||||
)
|
||||
await self._save_system_design()
|
||||
|
||||
async def _save_system_design(self):
|
||||
class_view = await self.context.repo.resources.data_api_design.get(
|
||||
filename=self.repo.workdir.stem + ".class_diagram.mmd"
|
||||
)
|
||||
sequence_view = await self.context.repo.resources.data_api_design.get(
|
||||
filename=self.repo.workdir.stem + ".sequence_diagram.mmd"
|
||||
)
|
||||
file_list = self.context.git_repo.get_files(relative_path=".", root_relative_path=self.context.src_workspace)
|
||||
data = {
|
||||
"Data structures and interfaces": class_view.content,
|
||||
"Program call flow": sequence_view.content,
|
||||
"File list": [str(i) for i in file_list],
|
||||
}
|
||||
await self.context.repo.docs.system_design.save(filename=self.rid + ".json", content=json.dumps(data))
|
||||
|
||||
async def _guess_src_workspace(self) -> Path:
|
||||
files = list_files(self.context.repo.workdir)
|
||||
dirs = [i.parent for i in files if i.name == "__init__.py"]
|
||||
distinct = set()
|
||||
for i in dirs:
|
||||
done = False
|
||||
for j in distinct:
|
||||
if i.is_relative_to(j):
|
||||
done = True
|
||||
break
|
||||
if j.is_relative_to(i):
|
||||
break
|
||||
if not done:
|
||||
distinct = {j for j in distinct if not j.is_relative_to(i)}
|
||||
distinct.add(i)
|
||||
if len(distinct) == 1:
|
||||
return list(distinct)[0]
|
||||
prompt = "\n".join([f"- {str(i)}" for i in distinct])
|
||||
rsp = await self.llm.aask(
|
||||
prompt,
|
||||
system_msgs=[
|
||||
"You are a tool to choose the source code path from a list of paths based on the directory name.",
|
||||
"You should identify the source code path among paths such as unit test path, examples path, etc.",
|
||||
"Return a markdown JSON object containing:\n"
|
||||
'- a "src" field containing the source code path;\n'
|
||||
'- a "reason" field containing explaining why other paths is not the source code path\n',
|
||||
],
|
||||
)
|
||||
logger.debug(rsp)
|
||||
json_blocks = parse_json_code_block(rsp)
|
||||
|
||||
class Data(BaseModel):
|
||||
src: str
|
||||
reason: str
|
||||
|
||||
data = Data.model_validate_json(json_blocks[0])
|
||||
logger.info(f"src_workspace: {data.src}")
|
||||
return Path(data.src)
|
||||
|
||||
async def _guess_main_entry(self, entries: List[Path]) -> Path:
|
||||
if len(entries) == 1:
|
||||
return entries[0]
|
||||
|
||||
file_list = "## File List\n"
|
||||
file_list += "\n".join([f"- {i}" for i in entries])
|
||||
|
||||
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_USAGE)
|
||||
usage = "## Usage\n"
|
||||
for r in rows:
|
||||
if Path(r.subject).stem == "README":
|
||||
usage += r.object_
|
||||
|
||||
prompt = file_list + "\n---\n" + usage
|
||||
rsp = await self.llm.aask(
|
||||
prompt,
|
||||
system_msgs=[
|
||||
'You are a tool to choose the source file path from "File List" which is used in "Usage".',
|
||||
'You choose the source file path based on the name of file and the class name and package name used in "Usage".',
|
||||
"Return a markdown JSON object containing:\n"
|
||||
'- a "filename" field containing the chosen source file path from "File List" which is used in "Usage";\n'
|
||||
'- a "reason" field explaining why.',
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
logger.debug(rsp)
|
||||
json_blocks = parse_json_code_block(rsp)
|
||||
|
||||
class Data(BaseModel):
|
||||
filename: str
|
||||
reason: str
|
||||
|
||||
data = Data.model_validate_json(json_blocks[0])
|
||||
logger.info(f"main: {data.filename}")
|
||||
return Path(data.filename)
|
||||
|
|
|
|||
|
|
@ -244,15 +244,6 @@ class RebuildSequenceView(Action):
|
|||
class_view = await self._get_uml_class_view(ns_class_name)
|
||||
source_code = await self._get_source_code(ns_class_name)
|
||||
|
||||
# prompt_blocks = [
|
||||
# "## Instruction\n"
|
||||
# "You are a python code to UML 2.0 Use Case translator.\n"
|
||||
# 'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".\n'
|
||||
# "The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not "
|
||||
# 'conflict with the information in "Mermaid Class Views".\n'
|
||||
# 'The section under `if __name__ == "__main__":` of "Source Code" contains information about external '
|
||||
# "system interactions with the internal system.\n"
|
||||
# ]
|
||||
prompt_blocks = []
|
||||
block = "## Participants\n"
|
||||
for p in participants:
|
||||
|
|
@ -340,6 +331,7 @@ class RebuildSequenceView(Action):
|
|||
system_msgs=[
|
||||
"You are a Mermaid Sequence Diagram translator in function detail.",
|
||||
"Translate the markdown text to a Mermaid Sequence Diagram.",
|
||||
"Response must be concise.",
|
||||
"Return a markdown mermaid code block.",
|
||||
],
|
||||
stream=False,
|
||||
|
|
@ -440,7 +432,7 @@ class RebuildSequenceView(Action):
|
|||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_PAGE_INFO)
|
||||
filename = split_namespace(ns_class_name=ns_class_name)[0]
|
||||
if not rows:
|
||||
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
|
||||
src_filename = RebuildSequenceView.get_full_filename(root=self.i_context, pathname=filename)
|
||||
if not src_filename:
|
||||
return ""
|
||||
return await aread(filename=src_filename, encoding="utf-8")
|
||||
|
|
@ -450,7 +442,7 @@ class RebuildSequenceView(Action):
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
|
||||
def get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
|
||||
"""
|
||||
Convert package name to the full path of the module.
|
||||
|
||||
|
|
@ -466,7 +458,7 @@ class RebuildSequenceView(Action):
|
|||
"metagpt/management/skill_manager.py", then the returned value will be
|
||||
"/User/xxx/github/MetaGPT/metagpt/management/skill_manager.py"
|
||||
"""
|
||||
if re.match(r"^/.+", pathname):
|
||||
if re.match(r"^/.+", str(pathname)):
|
||||
return pathname
|
||||
files = list_files(root=root)
|
||||
postfix = "/" + str(pathname)
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ from metagpt.schema import (
|
|||
Documents,
|
||||
Message,
|
||||
)
|
||||
from metagpt.utils.common import any_to_name, any_to_str, any_to_str_set
|
||||
from metagpt.utils.common import any_to_name, any_to_str, any_to_str_set, aread
|
||||
|
||||
IS_PASS_PROMPT = """
|
||||
{context}
|
||||
|
|
@ -239,7 +239,8 @@ class Engineer(Role):
|
|||
|
||||
async def _think(self) -> Action | None:
|
||||
if not self.src_workspace:
|
||||
self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
|
||||
name = self._get_src_workspace_name()
|
||||
self.src_workspace = self.git_repo.workdir / name
|
||||
write_plan_and_change_filters = any_to_str_set([WriteTasks, FixBug])
|
||||
write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode])
|
||||
summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
|
||||
|
|
@ -383,3 +384,10 @@ class Engineer(Role):
|
|||
def action_description(self) -> str:
|
||||
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
|
||||
return self.next_todo_action
|
||||
|
||||
async def _get_src_workspace_name(self):
|
||||
name = self.git_repo.workdir.name
|
||||
src_workspace_filename = self.git_repo.workdir / ".src_workspace"
|
||||
if src_workspace_filename.exists():
|
||||
name = await aread(filename=src_workspace_filename)
|
||||
return name
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import pytest
|
|||
|
||||
from metagpt.actions.import_repo import ImportRepo
|
||||
from metagpt.context import Context
|
||||
from metagpt.utils.common import list_files
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -13,6 +14,12 @@ async def test_import_repo(repo_path):
|
|||
context = Context()
|
||||
action = ImportRepo(repo_path=repo_path, context=context)
|
||||
await action.run()
|
||||
assert context.repo
|
||||
prd = list_files(context.repo.docs.prd.workdir)
|
||||
assert prd
|
||||
design = list_files(context.repo.docs.system_design.workdir)
|
||||
assert design
|
||||
assert prd[0].stem == design[0].stem
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ async def test_rebuild(context, mocker):
|
|||
],
|
||||
)
|
||||
def test_get_full_filename(root, pathname, want):
|
||||
res = RebuildSequenceView._get_full_filename(root=root, pathname=pathname)
|
||||
res = RebuildSequenceView.get_full_filename(root=root, pathname=pathname)
|
||||
assert res == want
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Optional, Union
|
|||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.const import LLM_API_TIMEOUT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
||||
|
|
@ -22,7 +23,7 @@ class MockLLM(OriginalLLM):
|
|||
self.rsp_cache: dict = {}
|
||||
self.rsp_candidates: list[dict] = [] # a test can have multiple calls with the same llm, thus a list
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=LLM_API_TIMEOUT) -> str:
|
||||
"""Overwrite original acompletion_text to cancel retry"""
|
||||
if stream:
|
||||
resp = await self._achat_completion_stream(messages, timeout=timeout)
|
||||
|
|
@ -37,7 +38,7 @@ class MockLLM(OriginalLLM):
|
|||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=3,
|
||||
timeout=LLM_API_TIMEOUT,
|
||||
stream=True,
|
||||
) -> str:
|
||||
if system_msgs:
|
||||
|
|
@ -56,7 +57,7 @@ class MockLLM(OriginalLLM):
|
|||
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
|
||||
return rsp
|
||||
|
||||
async def original_aask_batch(self, msgs: list, timeout=3) -> str:
|
||||
async def original_aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str:
|
||||
"""A copy of metagpt.provider.base_llm.BaseLLM.aask_batch, we can't use super().aask because it will be mocked"""
|
||||
context = []
|
||||
for msg in msgs:
|
||||
|
|
@ -83,7 +84,7 @@ class MockLLM(OriginalLLM):
|
|||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=3,
|
||||
timeout=LLM_API_TIMEOUT,
|
||||
stream=True,
|
||||
) -> str:
|
||||
# used to identify it a message has been called before
|
||||
|
|
@ -98,7 +99,7 @@ class MockLLM(OriginalLLM):
|
|||
rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, images, timeout, stream)
|
||||
return rsp
|
||||
|
||||
async def aask_batch(self, msgs: list, timeout=3) -> str:
|
||||
async def aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str:
|
||||
msg_key = "#MSG_SEP#".join([msg if isinstance(msg, str) else msg.content for msg in msgs])
|
||||
rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
|
||||
return rsp
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue