mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
Merge branch 'feature/import_repo' into featur/intent_detect
This commit is contained in:
commit
2e82a16e74
54 changed files with 1736 additions and 142 deletions
39
examples/di/software_company.py
Normal file
39
examples/di/software_company.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
import fire
|
||||
|
||||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
|
||||
|
||||
async def main():
    """Run a DataInterpreter on a snake-game requirement using the software-company tools."""
    requirement_steps = """
This is a software requirement:
```text
write a snake game
```
---
1. Writes a PRD based on software requirements.
2. Writes a design to the project repository, based on the PRD of the project.
3. Writes a project plan to the project repository, based on the design of the project.
4. Writes codes to the project repository, based on the project plan of the project.
5. Run QA test on the project repository.
6. Stage and commit changes for the project repository using Git.
Note: All required dependencies and environments have been fully installed and configured.
"""
    # Names of the tools the interpreter is allowed to use while following the steps above.
    tool_names = [
        "write_prd",
        "write_design",
        "write_project_plan",
        "write_codes",
        "run_qa_test",
        "fix_bug",
        "git_archive",
    ]
    interpreter = DataInterpreter(tools=tool_names)

    await interpreter.run(requirement_steps)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Hand `main` to python-fire so the script can be launched from the command line.
    fire.Fire(main)
|
||||
|
|
@ -11,9 +11,13 @@ from metagpt.rag.schema import (
|
|||
BM25RetrieverConfig,
|
||||
ChromaIndexConfig,
|
||||
ChromaRetrieverConfig,
|
||||
ElasticsearchIndexConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
ElasticsearchStoreConfig,
|
||||
FAISSRetrieverConfig,
|
||||
LLMRankerConfig,
|
||||
)
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
|
||||
QUESTION = "What are key qualities to be a good writer?"
|
||||
|
|
@ -39,12 +43,22 @@ class Player(BaseModel):
|
|||
class RAGExample:
|
||||
"""Show how to use RAG."""
|
||||
|
||||
def __init__(self, engine: SimpleEngine = None):
    """Store an optional pre-built engine.

    Only one constructor is kept: an earlier `__init__` that eagerly built the engine was
    shadowed by this definition and never ran.

    Args:
        engine: Engine to use for the examples. When omitted, the `engine` property builds
            a default one on first access.
    """
    self._engine = engine
|
||||
|
||||
@property
def engine(self):
    """Return the engine, building the default one from `DOC_PATH` on first access."""
    if self._engine:
        return self._engine

    # Default setup: FAISS + BM25 retrievers with an LLM ranker.
    default_retrievers = [FAISSRetrieverConfig(), BM25RetrieverConfig()]
    default_rankers = [LLMRankerConfig()]
    self._engine = SimpleEngine.from_docs(
        input_files=[DOC_PATH],
        retriever_configs=default_retrievers,
        ranker_configs=default_rankers,
    )
    return self._engine
|
||||
|
||||
@engine.setter
def engine(self, value: SimpleEngine):
    """Replace the engine returned by the `engine` property."""
    self._engine = value
|
||||
|
||||
async def run_pipeline(self, question=QUESTION, print_title=True):
|
||||
"""This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like:
|
||||
|
|
@ -97,6 +111,7 @@ class RAGExample:
|
|||
self.engine.add_docs([travel_filepath])
|
||||
await self.run_pipeline(question=travel_question, print_title=False)
|
||||
|
||||
@handle_exception
|
||||
async def add_objects(self, print_title=True):
|
||||
"""This example show how to add objects.
|
||||
|
||||
|
|
@ -154,20 +169,41 @@ class RAGExample:
|
|||
"""
|
||||
self._print_title("Init And Query ChromaDB")
|
||||
|
||||
# save index
|
||||
# 1. save index
|
||||
output_dir = DATA_PATH / "rag"
|
||||
SimpleEngine.from_docs(
|
||||
input_files=[TRAVEL_DOC_PATH],
|
||||
retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)],
|
||||
)
|
||||
|
||||
# load index
|
||||
engine = SimpleEngine.from_index(
|
||||
index_config=ChromaIndexConfig(persist_path=output_dir),
|
||||
# 2. load index
|
||||
engine = SimpleEngine.from_index(index_config=ChromaIndexConfig(persist_path=output_dir))
|
||||
|
||||
# 3. query
|
||||
answer = await engine.aquery(TRAVEL_QUESTION)
|
||||
self._print_query_result(answer)
|
||||
|
||||
@handle_exception
async def init_and_query_es(self):
    """This example shows how to use Elasticsearch: how to save and load an index. It will print something like:

    Query Result:
    Bob likes traveling.
    """
    self._print_title("Init And Query Elasticsearch")

    # 1. create es index and save docs
    store_config = ElasticsearchStoreConfig(index_name="travel", es_url="http://127.0.0.1:9200")
    SimpleEngine.from_docs(
        input_files=[TRAVEL_DOC_PATH],
        retriever_configs=[ElasticsearchRetrieverConfig(store_config=store_config)],
    )

    # 2. load index
    engine = SimpleEngine.from_index(index_config=ElasticsearchIndexConfig(store_config=store_config))

    # 3. query
    # Only the async query is issued; a blocking `engine.query` whose result was
    # overwritten before use has been dropped.
    answer = await engine.aquery(TRAVEL_QUESTION)
    self._print_query_result(answer)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -205,6 +241,7 @@ async def main():
|
|||
await e.add_objects()
|
||||
await e.init_objects()
|
||||
await e.init_and_query_chromadb()
|
||||
await e.init_and_query_es()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -104,3 +104,8 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
|
|||
if self.node:
|
||||
return await self._run_action_node(*args, **kwargs)
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
def override_context(self):
    """Set `private_context` and `context` to the same `Context` object.

    An already-set `private_context` is left untouched.
    """
    if self.private_context:
        return
    self.private_context = self.context
|
||||
|
|
|
|||
|
|
@ -340,10 +340,7 @@ class ActionNode:
|
|||
def tagging(self, text, schema, tag="") -> str:
    """Wrap `text` between `[tag]` and `[/tag]` marker lines.

    Args:
        text: The text to wrap.
        schema: Kept for interface compatibility; the wrapping is the same for every schema.
        tag: Marker name. When empty, `text` is returned unchanged.

    Returns:
        The tagged text, or `text` itself when no tag is given.
    """
    if not tag:
        return text
    # The "json" and markdown branches produced the same string, so one return suffices.
    return f"[{tag}]\n{text}\n[/{tag}]"
|
||||
|
||||
def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str:
|
||||
nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude)
|
||||
|
|
@ -375,7 +372,7 @@ class ActionNode:
|
|||
schema="markdown": 编译context, example(markdown), instruction(markdown), constraint, action
|
||||
"""
|
||||
if schema == "raw":
|
||||
return context + "\n\n## Actions\n" + LANGUAGE_CONSTRAINT + "\n" + self.instruction
|
||||
return f"{context}\n\n## Actions\n{LANGUAGE_CONSTRAINT}\n{self.instruction}"
|
||||
|
||||
### 直接使用 pydantic BaseModel 生成 instruction 与 example,仅限 JSON
|
||||
# child_class = self._create_children_class()
|
||||
|
|
|
|||
123
metagpt/actions/extract_readme.py
Normal file
123
metagpt/actions/extract_readme.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Module Description: This script defines the ExtractReadMe class, which is an action to learn from the contents of
|
||||
a README.md file.
|
||||
Author: mashenquan
|
||||
Date: 2024-3-20
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.const import GRAPH_REPO_FILE_REPO
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
|
||||
|
||||
|
||||
class ExtractReadMe(Action):
    """
    An action to extract summary, installation, configuration, usages from the contents of a README.md file.

    Attributes:
        graph_db (Optional[GraphRepository]): A graph database repository.
        install_to_path (Optional[str]): The path where the repository to install to.
    """

    graph_db: Optional[GraphRepository] = None
    install_to_path: Optional[str] = Field(default="/TO/PATH")
    # Cached README text and the path it was read from; both stay None until a README is found.
    _readme: Optional[str] = None
    _filename: Optional[str] = None

    async def run(self, with_messages=None, **kwargs):
        """
        Implementation of `Action`'s `run` method.

        Loads the repository's graph database, stores the README's summary, installation,
        configuration and usage texts in it under the README's path, and saves the database.

        Args:
            with_messages: Optional messages to react to; not used by this action.

        Returns:
            Message: An empty message whose `cause_by` is this action.
        """
        graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
        self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))

        await self._get()
        # Without a README there is no subject to attach the extracted facts to, so skip extraction.
        if self._filename is not None:
            summary = await self._summarize()
            await self.graph_db.insert(subject=self._filename, predicate=GraphKeyword.HAS_SUMMARY, object_=summary)
            install = await self._extract_install()
            await self.graph_db.insert(subject=self._filename, predicate=GraphKeyword.HAS_INSTALL, object_=install)
            conf = await self._extract_configuration()
            await self.graph_db.insert(subject=self._filename, predicate=GraphKeyword.HAS_CONFIG, object_=conf)
            usage = await self._extract_usage()
            await self.graph_db.insert(subject=self._filename, predicate=GraphKeyword.HAS_USAGE, object_=usage)

        await self.graph_db.save()

        return Message(content="", cause_by=self)

    async def _summarize(self) -> str:
        """Ask the LLM for a summary of what the repository is, based on the README."""
        readme = await self._get()
        summary = await self.llm.aask(
            readme,
            system_msgs=[
                "You are a tool can summarize git repository README.md file.",
                "Return the summary about what is the repository.",
            ],
            stream=False,
        )
        return summary

    async def _extract_install(self) -> str:
        """Ask the LLM for a bash block that clones and installs the repository into `install_to_path`."""
        # Use the value returned by `_get`: `self._readme` is still None when no README exists.
        readme = await self._get()
        install = await self.llm.aask(
            readme,
            system_msgs=[
                "You are a tool can install git repository according to README.md file.",
                "Return a bash code block of markdown including:\n"
                f"1. git clone the repository to the directory `{self.install_to_path}`;\n"
                f"2. cd `{self.install_to_path}`;\n"
                "3. install the repository.",
            ],
            stream=False,
        )
        return install

    async def _extract_configuration(self) -> str:
        """Ask the LLM for a bash block that configures the repository (empty when none is needed)."""
        readme = await self._get()
        configuration = await self.llm.aask(
            readme,
            system_msgs=[
                "You are a tool can configure git repository according to README.md file.",
                "Return a bash code block of markdown object to configure the repository if necessary, otherwise return"
                " a empty bash code block of markdown object",
            ],
            stream=False,
        )
        return configuration

    async def _extract_usage(self) -> str:
        """Ask the LLM for code blocks demonstrating the usages described in the README."""
        readme = await self._get()
        usage = await self.llm.aask(
            readme,
            system_msgs=[
                "You are a tool can summarize all usages of git repository according to README.md file.",
                "Return a list of code block of markdown objects to demonstrates the usage of the repository.",
            ],
            stream=False,
        )
        return usage

    async def _get(self) -> str:
        """Return the README text, reading and caching it on first use.

        Looks directly inside the directory given by `i_context` for a file whose stem is
        `README`. Returns an empty string, and caches nothing, when no such file exists.
        """
        if self._readme is not None:
            return self._readme
        root = Path(self.i_context).resolve()
        filename = next(
            (file_path for file_path in root.iterdir() if file_path.is_file() and file_path.stem == "README"),
            None,
        )
        if not filename:
            return ""
        self._readme = await aread(filename=filename, encoding="utf-8")
        self._filename = str(filename)
        return self._readme
|
||||
226
metagpt/actions/import_repo.py
Normal file
226
metagpt/actions/import_repo.py
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
|
||||
This script defines an action to import a Git repository into the MetaGPT project format, enabling incremental
|
||||
appending of requirements.
|
||||
The MetaGPT project format encompasses a structured representation of project data compatible with MetaGPT's
|
||||
capabilities, facilitating the integration of Git repositories into MetaGPT workflows while allowing for the gradual
|
||||
addition of requirements.
|
||||
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.extract_readme import ExtractReadMe
|
||||
from metagpt.actions.rebuild_class_view import RebuildClassView
|
||||
from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
|
||||
from metagpt.const import GRAPH_REPO_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools.libs.git import git_clone
|
||||
from metagpt.utils.common import (
|
||||
aread,
|
||||
awrite,
|
||||
list_files,
|
||||
parse_json_code_block,
|
||||
split_namespace,
|
||||
)
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
class ImportRepo(Action):
    """
    An action to import a Git repository into a graph database and create related artifacts.

    Attributes:
        repo_path (str): The URL of the Git repository to import. Once the repository has been cloned,
            this is overwritten with the local path of the clone.
        graph_db (Optional[GraphRepository]): The output graph database of the Git repository.
        rid (str): The output requirement ID.
    """

    repo_path: str  # input, git repo url.
    graph_db: Optional[GraphRepository] = None  # output. graph db of the git repository
    rid: str = ""  # output, requirement ID.

    async def run(self, with_messages: List[Message] = None, **kwargs) -> Message:
        """
        Runs the import process for the Git repository.

        Clones the repository, creates the PRD and the system design documents, then archives the
        result in Git with the comment "Import".

        Args:
            with_messages (List[Message], optional): Additional messages to include.
            **kwargs: Additional keyword arguments.

        Returns:
            Message: A message indicating the completion of the import process.
        """
        await self._create_repo()
        await self._create_prd()
        await self._create_system_design()
        self.context.git_repo.archive(comments="Import")
        # The signature and docstring promise a Message; previously nothing was returned.
        return Message(content="", cause_by=self)

    async def _create_repo(self):
        """Clone the repository into the workspace and point the context at the clone."""
        path = await git_clone(url=self.repo_path, output_dir=self.config.workspace.path)
        self.repo_path = str(path)
        self.config.project_path = path
        self.context.git_repo = GitRepository(local_path=path, auto_init=True)
        self.context.repo = ProjectRepo(self.context.git_repo)
        self.context.src_workspace = await self._guess_src_workspace()
        # Record the chosen source directory, relative to the repository root, in `.src_workspace`.
        await awrite(
            filename=self.context.repo.workdir / ".src_workspace",
            data=str(self.context.src_workspace.relative_to(self.context.repo.workdir)),
        )

    async def _create_prd(self):
        """Extract README facts into the graph database and save a PRD built from the README summary."""
        action = ExtractReadMe(i_context=str(self.context.repo.workdir), context=self.context)
        await action.run()
        graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
        self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
        rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SUMMARY)
        prd = {"Project Name": self.context.repo.workdir.name}
        for r in rows:
            if Path(r.subject).stem == "README":
                prd["Original Requirements"] = r.object_
                break
        self.rid = FileRepository.new_filename()
        await self.repo.docs.prd.save(filename=self.rid + ".json", content=json.dumps(prd))

    async def _create_system_design(self):
        """Rebuild the class view and the main entry's sequence view, then save the system design."""
        action = RebuildClassView(
            name="ReverseEngineering", i_context=str(self.context.src_workspace), context=self.context
        )
        await action.run()
        rows = await action.graph_db.select(predicate="hasMermaidClassDiagramFile")
        class_view_filename = rows[0].object_
        logger.info(f"class view:{class_view_filename}")

        # Collect the files inside the source workspace whose page info carries the
        # `__name__:__main__` tag; they are the candidates for the main entry.
        rows = await action.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
        tag = "__name__:__main__"
        entries = []
        src_workspace = self.context.src_workspace.relative_to(self.context.repo.workdir)
        for r in rows:
            if tag in r.subject:
                path = split_namespace(r.subject)[0]
            elif tag in r.object_:
                path = split_namespace(r.object_)[0]
            else:
                continue
            if Path(path).is_relative_to(src_workspace):
                entries.append(Path(path))
        main_entry = await self._guess_main_entry(entries)
        full_path = RebuildSequenceView.get_full_filename(self.context.repo.workdir, main_entry)
        action = RebuildSequenceView(context=self.context, i_context=str(full_path))
        try:
            await action.run()
        except Exception as e:
            # Deliberate best effort: fall back to whatever sequence diagram is already on disk.
            logger.warning(f"{e}, use the last successful version.")
        files = list_files(self.context.repo.resources.data_api_design.workdir)
        name = re.sub(r"[^a-zA-Z0-9]", "_", str(main_entry))
        postfix = str(Path(name).with_suffix(".sequence_diagram.mmd"))
        sequence_files = [i for i in files if postfix in str(i)]
        content = await aread(filename=sequence_files[0])
        await self.context.repo.resources.data_api_design.save(
            filename=self.repo.workdir.stem + ".sequence_diagram.mmd", content=content
        )
        await self._save_system_design()

    async def _save_system_design(self):
        """Combine the class diagram, the sequence diagram and the file list into the system design doc."""
        class_view = await self.context.repo.resources.data_api_design.get(
            filename=self.repo.workdir.stem + ".class_diagram.mmd"
        )
        sequence_view = await self.context.repo.resources.data_api_design.get(
            filename=self.repo.workdir.stem + ".sequence_diagram.mmd"
        )
        file_list = self.context.git_repo.get_files(relative_path=".", root_relative_path=self.context.src_workspace)
        data = {
            "Data structures and interfaces": class_view.content,
            "Program call flow": sequence_view.content,
            "File list": [str(i) for i in file_list],
        }
        await self.context.repo.docs.system_design.save(filename=self.rid + ".json", content=json.dumps(data))

    async def _guess_src_workspace(self) -> Path:
        """Choose the source code directory of the repository.

        Collects the outermost directories that contain an `__init__.py`. A single candidate is
        returned directly; otherwise the LLM is asked to pick one from the list.
        """
        files = list_files(self.context.repo.workdir)
        package_dirs = [i.parent for i in files if i.name == "__init__.py"]
        distinct = set()
        for candidate in package_dirs:
            # Skip directories nested inside one that is already kept.
            if any(candidate.is_relative_to(kept) for kept in distinct):
                continue
            # The candidate replaces every kept directory nested inside it.
            distinct = {kept for kept in distinct if not kept.is_relative_to(candidate)}
            distinct.add(candidate)
        if len(distinct) == 1:
            return next(iter(distinct))
        # Sorted so the prompt is the same from run to run.
        prompt = "\n".join(f"- {i}" for i in sorted(distinct))
        rsp = await self.llm.aask(
            prompt,
            system_msgs=[
                "You are a tool to choose the source code path from a list of paths based on the directory name.",
                "You should identify the source code path among paths such as unit test path, examples path, etc.",
                "Return a markdown JSON object containing:\n"
                '- a "src" field containing the source code path;\n'
                '- a "reason" field containing explaining why other paths is not the source code path\n',
            ],
        )
        logger.debug(rsp)
        json_blocks = parse_json_code_block(rsp)

        class Data(BaseModel):
            src: str
            reason: str

        data = Data.model_validate_json(json_blocks[0])
        logger.info(f"src_workspace: {data.src}")
        return Path(data.src)

    async def _guess_main_entry(self, entries: List[Path]) -> Path:
        """Choose the main entry file among `entries`.

        A single entry is returned directly; otherwise the LLM picks the file that best matches
        the README usage stored in the graph database.
        """
        if len(entries) == 1:
            return entries[0]

        file_list = "## File List\n"
        file_list += "\n".join(f"- {i}" for i in entries)

        rows = await self.graph_db.select(predicate=GraphKeyword.HAS_USAGE)
        usage = "## Usage\n"
        for r in rows:
            if Path(r.subject).stem == "README":
                usage += r.object_

        prompt = file_list + "\n---\n" + usage
        rsp = await self.llm.aask(
            prompt,
            system_msgs=[
                'You are a tool to choose the source file path from "File List" which is used in "Usage".',
                'You choose the source file path based on the name of file and the class name and package name used in "Usage".',
                "Return a markdown JSON object containing:\n"
                '- a "filename" field containing the chosen source file path from "File List" which is used in "Usage";\n'
                '- a "reason" field explaining why.',
            ],
            stream=False,
        )
        logger.debug(rsp)
        json_blocks = parse_json_code_block(rsp)

        class Data(BaseModel):
            filename: str
            reason: str

        data = Data.model_validate_json(json_blocks[0])
        logger.info(f"main: {data.filename}")
        return Path(data.filename)
|
||||
|
|
@ -14,8 +14,6 @@ from typing import Optional
|
|||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.const import REQUIREMENT_FILENAME
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
class PrepareDocuments(Action):
|
||||
|
|
@ -38,8 +36,7 @@ class PrepareDocuments(Action):
|
|||
if path.exists() and not self.config.inc:
|
||||
shutil.rmtree(path)
|
||||
self.config.project_path = path
|
||||
self.context.git_repo = GitRepository(local_path=path, auto_init=True)
|
||||
self.context.repo = ProjectRepo(self.context.git_repo)
|
||||
self.context.set_repo_dir(path)
|
||||
|
||||
async def run(self, with_messages, **kwargs):
|
||||
"""Create and initialize the workspace folder, initialize the Git environment."""
|
||||
|
|
|
|||
|
|
@ -244,15 +244,6 @@ class RebuildSequenceView(Action):
|
|||
class_view = await self._get_uml_class_view(ns_class_name)
|
||||
source_code = await self._get_source_code(ns_class_name)
|
||||
|
||||
# prompt_blocks = [
|
||||
# "## Instruction\n"
|
||||
# "You are a python code to UML 2.0 Use Case translator.\n"
|
||||
# 'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".\n'
|
||||
# "The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not "
|
||||
# 'conflict with the information in "Mermaid Class Views".\n'
|
||||
# 'The section under `if __name__ == "__main__":` of "Source Code" contains information about external '
|
||||
# "system interactions with the internal system.\n"
|
||||
# ]
|
||||
prompt_blocks = []
|
||||
block = "## Participants\n"
|
||||
for p in participants:
|
||||
|
|
@ -340,6 +331,7 @@ class RebuildSequenceView(Action):
|
|||
system_msgs=[
|
||||
"You are a Mermaid Sequence Diagram translator in function detail.",
|
||||
"Translate the markdown text to a Mermaid Sequence Diagram.",
|
||||
"Response must be concise.",
|
||||
"Return a markdown mermaid code block.",
|
||||
],
|
||||
stream=False,
|
||||
|
|
@ -440,7 +432,7 @@ class RebuildSequenceView(Action):
|
|||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_PAGE_INFO)
|
||||
filename = split_namespace(ns_class_name=ns_class_name)[0]
|
||||
if not rows:
|
||||
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
|
||||
src_filename = RebuildSequenceView.get_full_filename(root=self.i_context, pathname=filename)
|
||||
if not src_filename:
|
||||
return ""
|
||||
return await aread(filename=src_filename, encoding="utf-8")
|
||||
|
|
@ -450,7 +442,7 @@ class RebuildSequenceView(Action):
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
|
||||
def get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
|
||||
"""
|
||||
Convert package name to the full path of the module.
|
||||
|
||||
|
|
@ -466,7 +458,7 @@ class RebuildSequenceView(Action):
|
|||
"metagpt/management/skill_manager.py", then the returned value will be
|
||||
"/User/xxx/github/MetaGPT/metagpt/management/skill_manager.py"
|
||||
"""
|
||||
if re.match(r"^/.+", pathname):
|
||||
if re.match(r"^/.+", str(pathname)):
|
||||
return pathname
|
||||
files = list_files(root=root)
|
||||
postfix = "/" + str(pathname)
|
||||
|
|
|
|||
|
|
@ -128,6 +128,9 @@ CODE_PLAN_AND_CHANGE_CONTEXT = """
|
|||
## User New Requirements
|
||||
{requirement}
|
||||
|
||||
## Issue
|
||||
{issue}
|
||||
|
||||
## PRD
|
||||
{prd}
|
||||
|
||||
|
|
@ -211,7 +214,8 @@ class WriteCodePlanAndChange(Action):
|
|||
design_doc = await self.repo.docs.system_design.get(filename=self.i_context.design_filename)
|
||||
task_doc = await self.repo.docs.task.get(filename=self.i_context.task_filename)
|
||||
context = CODE_PLAN_AND_CHANGE_CONTEXT.format(
|
||||
requirement=self.i_context.requirement,
|
||||
requirement=f"```text\n{self.i_context.requirement}\n```",
|
||||
issue=f"```text\n{self.i_context.issue}\n```",
|
||||
prd=prd_doc.content,
|
||||
design=design_doc.content,
|
||||
task=task_doc.content,
|
||||
|
|
|
|||
|
|
@ -133,10 +133,10 @@ REQUIREMENT_ANALYSIS = ActionNode(
|
|||
REFINED_REQUIREMENT_ANALYSIS = ActionNode(
|
||||
key="Refined Requirement Analysis",
|
||||
expected_type=List[str],
|
||||
instruction="Review and refine the existing requirement analysis to align with the evolving needs of the project "
|
||||
instruction="Review and refine the existing requirement analysis into a string list to align with the evolving needs of the project "
|
||||
"due to incremental development. Ensure the analysis comprehensively covers the new features and enhancements "
|
||||
"required for the refined project scope.",
|
||||
example=["Require add/update/modify ..."],
|
||||
example=["Require add ...", "Require modify ..."],
|
||||
)
|
||||
|
||||
REQUIREMENT_POOL = ActionNode(
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class Config(CLIParams, YamlModel):
|
|||
# Key Parameters
|
||||
llm: LLMConfig
|
||||
|
||||
# Global Proxy. Will be used if llm.proxy is not set
|
||||
# Global Proxy. Not used by LLM, but by other tools such as browsers.
|
||||
proxy: str = ""
|
||||
|
||||
# Tool Parameters
|
||||
|
|
|
|||
|
|
@ -5,9 +5,11 @@
|
|||
@Author : alexanderwu
|
||||
@File : context.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
|
@ -78,11 +80,10 @@ class Context(BaseModel):
|
|||
# env.update({k: v for k, v in i.items() if isinstance(v, str)})
|
||||
return env
|
||||
|
||||
# def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
|
||||
# """Use a LLM instance"""
|
||||
# self._llm_config = self.config.get_llm_config(name, provider)
|
||||
# self._llm = None
|
||||
# return self._llm
|
||||
def set_repo_dir(self, path: str | Path):
    """Point this context at the project located at `path`.

    Creates a `GitRepository` for the path (with `auto_init=True`) and a `ProjectRepo` on top of it.
    """
    self.git_repo = GitRepository(local_path=Path(path), auto_init=True)
    self.repo = ProjectRepo(self.git_repo)
|
||||
|
||||
def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
|
||||
"""Return a CostManager instance"""
|
||||
|
|
@ -108,3 +109,38 @@ class Context(BaseModel):
|
|||
if llm.cost_manager is None:
|
||||
llm.cost_manager = self._select_costmanager(llm_config)
|
||||
return llm
|
||||
|
||||
def serialize(self) -> Dict[str, Any]:
    """Serialize the object's attributes into a dictionary.

    Returns:
        Dict[str, Any]: The repository work directory as a string (empty when no repo is set),
        a copy of the `kwargs` attributes, and the cost manager dumped as JSON.
    """
    workdir = str(self.repo.workdir) if self.repo else ""
    kwargs_copy = dict(self.kwargs.__dict__.items())
    cost_manager_json = self.cost_manager.model_dump_json()
    return {
        "workdir": workdir,
        "kwargs": kwargs_copy,
        "cost_manager": cost_manager_json,
    }
|
||||
|
||||
def deserialize(self, serialized_data: Dict[str, Any]):
    """Deserialize the given serialized data and update the object's attributes accordingly.

    Args:
        serialized_data (Dict[str, Any]): A dictionary containing serialized data, as produced by
            `serialize`. Missing or empty entries leave the matching attribute untouched.
    """
    if not serialized_data:
        return
    workdir = serialized_data.get("workdir")
    if workdir:
        self.git_repo = GitRepository(local_path=workdir, auto_init=True)
        self.repo = ProjectRepo(self.git_repo)
        # The source workspace is taken to be the sub-directory named after the project, if present.
        src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
        if src_workspace.exists():
            self.src_workspace = src_workspace
    kwargs = serialized_data.get("kwargs")
    if kwargs:
        for k, v in kwargs.items():
            self.kwargs.set(k, v)
    cost_manager = serialized_data.get("cost_manager")
    if cost_manager:
        # `model_validate_json` returns a new object instead of updating in place, so its result
        # must be stored; it was previously discarded and the saved costs were lost.
        # NOTE(review): objects still holding the previous cost manager are not updated.
        self.cost_manager = self.cost_manager.model_validate_json(cost_manager)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
from dataclasses import asdict
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import google.generativeai as genai
|
||||
from google.ai import generativelanguage as glm
|
||||
|
|
@ -11,6 +12,7 @@ from google.generativeai.generative_models import GenerativeModel
|
|||
from google.generativeai.types import content_types
|
||||
from google.generativeai.types.generation_types import (
|
||||
AsyncGenerateContentResponse,
|
||||
BlockedPromptException,
|
||||
GenerateContentResponse,
|
||||
GenerationConfig,
|
||||
)
|
||||
|
|
@ -141,7 +143,11 @@ class GeminiLLM(BaseLLM):
|
|||
)
|
||||
collected_content = []
|
||||
async for chunk in resp:
|
||||
content = chunk.text
|
||||
try:
|
||||
content = chunk.text
|
||||
except Exception as e:
|
||||
logger.warning(f"messages: {messages}\nerrors: {e}\n{BlockedPromptException(str(chunk))}")
|
||||
raise BlockedPromptException(str(chunk))
|
||||
log_llm_stream(content)
|
||||
collected_content.append(content)
|
||||
log_llm_stream("\n")
|
||||
|
|
@ -150,3 +156,10 @@ class GeminiLLM(BaseLLM):
|
|||
usage = await self.aget_usage(messages, full_content)
|
||||
self._update_costs(usage)
|
||||
return full_content
|
||||
|
||||
def list_models(self) -> List:
    """List the models reported by `genai.list_models`, log them as JSON, and return them as dicts."""
    models = [asdict(model) for model in genai.list_models(page_size=100)]
    logger.info(json.dumps(models))
    return models
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class HumanProvider(BaseLLM):
|
|||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig):
    """Keep the LLM configuration on the instance.

    Args:
        config: The LLM configuration to store on `self.config`.
    """
    self.config = config
|
||||
|
||||
def ask(self, msg: str, timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
logger.info("It's your turn, please type in your response. You may also refer to the context below")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Engines init"""
|
||||
|
||||
from metagpt.rag.engines.simple import SimpleEngine
|
||||
from metagpt.rag.engines.flare import FLAREEngine
|
||||
|
||||
__all__ = ["SimpleEngine"]
|
||||
__all__ = ["SimpleEngine", "FLAREEngine"]
|
||||
|
|
|
|||
9
metagpt/rag/engines/flare.py
Normal file
9
metagpt/rag/engines/flare.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
"""FLARE Engine.
|
||||
|
||||
Use llamaindex's FLAREInstructQueryEngine as FLAREEngine, which accepts other engines as parameters.
|
||||
For example, Create a simple engine, and then pass it to FLAREEngine.
|
||||
"""
|
||||
|
||||
from llama_index.core.query_engine import ( # noqa: F401
|
||||
FLAREInstructQueryEngine as FLAREEngine,
|
||||
)
|
||||
|
|
@ -130,10 +130,12 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
|
||||
ranker_configs: Configuration for rankers.
|
||||
"""
|
||||
objs = objs or []
|
||||
retriever_configs = retriever_configs or []
|
||||
|
||||
if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
|
||||
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
|
||||
|
||||
objs = objs or []
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
index = VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class ConfigBasedFactory(GenericFactory):
|
|||
if creator:
|
||||
return creator(key, **kwargs)
|
||||
|
||||
raise ValueError(f"Unknown config: {key}")
|
||||
raise ValueError(f"Unknown config: `{type(key)}`, {key}")
|
||||
|
||||
@staticmethod
|
||||
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import chromadb
|
|||
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory
|
||||
|
|
@ -11,6 +13,8 @@ from metagpt.rag.schema import (
|
|||
BaseIndexConfig,
|
||||
BM25IndexConfig,
|
||||
ChromaIndexConfig,
|
||||
ElasticsearchIndexConfig,
|
||||
ElasticsearchKeywordIndexConfig,
|
||||
FAISSIndexConfig,
|
||||
)
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
|
|
@ -22,6 +26,8 @@ class RAGIndexFactory(ConfigBasedFactory):
|
|||
FAISSIndexConfig: self._create_faiss,
|
||||
ChromaIndexConfig: self._create_chroma,
|
||||
BM25IndexConfig: self._create_bm25,
|
||||
ElasticsearchIndexConfig: self._create_es,
|
||||
ElasticsearchKeywordIndexConfig: self._create_es,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
|
|
@ -30,31 +36,44 @@ class RAGIndexFactory(ConfigBasedFactory):
|
|||
return super().get_instance(config, **kwargs)
|
||||
|
||||
def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
embed_model = self._extract_embed_model(config, **kwargs)
|
||||
|
||||
vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)
|
||||
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
|
||||
return index
|
||||
|
||||
return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)
|
||||
|
||||
def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)
|
||||
|
||||
return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)
|
||||
|
||||
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
embed_model = self._extract_embed_model(config, **kwargs)
|
||||
|
||||
db = chromadb.PersistentClient(str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name)
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
vector_store,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
return index
|
||||
|
||||
def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)
|
||||
|
||||
def _create_es(self, config: ElasticsearchIndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
vector_store = ElasticsearchStore(**config.store_config.model_dump())
|
||||
|
||||
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)
|
||||
|
||||
def _index_from_storage(
|
||||
self, storage_context: StorageContext, config: BaseIndexConfig, **kwargs
|
||||
) -> VectorStoreIndex:
|
||||
embed_model = self._extract_embed_model(config, **kwargs)
|
||||
|
||||
storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)
|
||||
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
|
||||
return index
|
||||
return load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
|
||||
|
||||
def _index_from_vector_store(
|
||||
self, vector_store: BasePydanticVectorStore, config: BaseIndexConfig, **kwargs
|
||||
) -> VectorStoreIndex:
|
||||
embed_model = self._extract_embed_model(config, **kwargs)
|
||||
|
||||
return VectorStoreIndex.from_vector_store(
|
||||
vector_store=vector_store,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
|
||||
def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
|
||||
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,9 @@ class RAGLLM(CustomLLM):
|
|||
@property
|
||||
def metadata(self) -> LLMMetadata:
|
||||
"""Get LLM metadata."""
|
||||
return LLMMetadata(context_window=self.context_window, num_output=self.num_output, model_name=self.model_name)
|
||||
return LLMMetadata(
|
||||
context_window=self.context_window, num_output=self.num_output, model_name=self.model_name or "unknown"
|
||||
)
|
||||
|
||||
@llm_completion_callback()
|
||||
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
|
||||
|
|
|
|||
|
|
@ -3,9 +3,16 @@
|
|||
from llama_index.core.llms import LLM
|
||||
from llama_index.core.postprocessor import LLMRerank
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.postprocessor.colbert_rerank import ColbertRerank
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory
|
||||
from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig
|
||||
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
|
||||
from metagpt.rag.schema import (
|
||||
BaseRankerConfig,
|
||||
ColbertRerankConfig,
|
||||
LLMRankerConfig,
|
||||
ObjectRankerConfig,
|
||||
)
|
||||
|
||||
|
||||
class RankerFactory(ConfigBasedFactory):
|
||||
|
|
@ -14,6 +21,8 @@ class RankerFactory(ConfigBasedFactory):
|
|||
def __init__(self):
|
||||
creators = {
|
||||
LLMRankerConfig: self._create_llm_ranker,
|
||||
ColbertRerankConfig: self._create_colbert_ranker,
|
||||
ObjectRankerConfig: self._create_object_ranker,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
|
|
@ -28,6 +37,12 @@ class RankerFactory(ConfigBasedFactory):
|
|||
config.llm = self._extract_llm(config, **kwargs)
|
||||
return LLMRerank(**config.model_dump())
|
||||
|
||||
def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> ColbertRerank:
    """Build a ``ColbertRerank`` postprocessor from the dumped fields of ``config``.

    Args:
        config: Ranker settings; every field from ``config.model_dump()`` is
            forwarded to ``ColbertRerank`` as a keyword argument.
        **kwargs: Accepted so the signature matches the other creators; unused here.

    Returns:
        The constructed ``ColbertRerank`` instance.
    """
    return ColbertRerank(**config.model_dump())
|
||||
|
||||
def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> ObjectSortPostprocessor:
    """Build an ``ObjectSortPostprocessor`` from the dumped fields of ``config``.

    Args:
        config: Ranker settings; every field from ``config.model_dump()`` is
            forwarded to ``ObjectSortPostprocessor`` as a keyword argument.
        **kwargs: Accepted so the signature matches the other creators; unused here.

    Returns:
        The constructed ``ObjectSortPostprocessor`` instance.
    """
    return ObjectSortPostprocessor(**config.model_dump())
|
||||
|
||||
def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
|
||||
return self._val_from_config_or_kwargs("llm", config, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,18 +6,22 @@ import chromadb
|
|||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
|
||||
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BaseRetrieverConfig,
|
||||
BM25RetrieverConfig,
|
||||
ChromaRetrieverConfig,
|
||||
ElasticsearchKeywordRetrieverConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
IndexRetrieverConfig,
|
||||
)
|
||||
|
|
@ -32,6 +36,8 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
FAISSRetrieverConfig: self._create_faiss_retriever,
|
||||
BM25RetrieverConfig: self._create_bm25_retriever,
|
||||
ChromaRetrieverConfig: self._create_chroma_retriever,
|
||||
ElasticsearchRetrieverConfig: self._create_es_retriever,
|
||||
ElasticsearchKeywordRetrieverConfig: self._create_es_retriever,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
|
|
@ -53,20 +59,29 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
return FAISSRetriever(**config.model_dump())
|
||||
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
|
||||
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
|
||||
nodes = list(config.index.docstore.docs.values())
|
||||
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
|
||||
|
||||
return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
|
||||
|
||||
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
|
||||
db = chromadb.PersistentClient(path=str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name)
|
||||
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
return ChromaRetriever(**config.model_dump())
|
||||
|
||||
def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
|
||||
vector_store = ElasticsearchStore(**config.store_config.model_dump())
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
return ElasticsearchRetriever(**config.model_dump())
|
||||
|
||||
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
|
||||
return self._val_from_config_or_kwargs("index", config, **kwargs)
|
||||
|
||||
|
|
|
|||
55
metagpt/rag/rankers/object_ranker.py
Normal file
55
metagpt/rag/rankers/object_ranker.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
"""Object ranker."""
|
||||
|
||||
import heapq
|
||||
import json
|
||||
from typing import Literal, Optional
|
||||
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.rag.schema import ObjectNode
|
||||
|
||||
|
||||
class ObjectSortPostprocessor(BaseNodePostprocessor):
    """Sort nodes by one field of the object serialized in their metadata, desc or asc.

    Each node is expected to carry an ``obj_json`` metadata entry containing a JSON
    object; the value stored under ``field_name`` is used as the sort key.
    Assumes nodes is a list of ObjectNode with score.
    """

    field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
    order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")
    top_n: int = 5

    @classmethod
    def class_name(cls) -> str:
        """Return the class name used to identify this postprocessor."""
        return "ObjectSortPostprocessor"

    def _postprocess_nodes(
        self,
        nodes: list[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> list[NodeWithScore]:
        """Return at most ``top_n`` nodes ordered by the configured object field.

        Args:
            nodes: Candidate nodes whose metadata holds the serialized object.
            query_bundle: Must be provided, although its content is not read here.

        Returns:
            The ``top_n`` largest (``order == "desc"``) or smallest (``"asc"``) nodes
            by the value of ``field_name``; an empty list when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is None, or the first node's metadata
                does not contain a valid object with ``field_name``.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")

        if not nodes:
            return []

        # Only the first node is validated up front; the remaining nodes are
        # assumed to have the same shape and are read directly while sorting.
        self._check_metadata(nodes[0].node)

        return self._get_sort_func()(self.top_n, nodes, key=self._sort_key)

    def _sort_key(self, node: NodeWithScore):
        """Extract the value of ``field_name`` from the node's serialized object."""
        return json.loads(node.node.metadata["obj_json"])[self.field_name]

    def _check_metadata(self, node: ObjectNode):
        """Ensure ``node`` carries a parsable ``obj_json`` that contains ``field_name``.

        Raises:
            ValueError: If ``obj_json`` is missing or not valid JSON, or if the
                parsed object has no ``field_name`` entry.
        """
        try:
            obj_dict = json.loads(node.metadata.get("obj_json"))
        except Exception as e:
            # Chain the original error so the root cause stays visible in tracebacks.
            raise ValueError(f"Invalid object json in metadata: {node.metadata}, error: {e}") from e

        if self.field_name not in obj_dict:
            raise ValueError(f"Field '{self.field_name}' not found in object: {obj_dict}")

    def _get_sort_func(self):
        """Pick ``heapq.nlargest`` for descending order, ``heapq.nsmallest`` otherwise."""
        return heapq.nlargest if self.order == "desc" else heapq.nsmallest
|
||||
17
metagpt/rag/retrievers/es_retriever.py
Normal file
17
metagpt/rag/retrievers/es_retriever.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""Elasticsearch retriever."""
|
||||
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.schema import BaseNode
|
||||
|
||||
|
||||
class ElasticsearchRetriever(VectorIndexRetriever):
    """Vector index retriever for Elasticsearch, extended with add/persist hooks."""

    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
        """Insert ``nodes`` into the underlying index, forwarding ``kwargs`` to ``insert_nodes``."""
        self._index.insert_nodes(nodes, **kwargs)

    def persist(self, persist_dir: str, **kwargs) -> None:
        """Accept a persist request without doing anything.

        Elasticsearch automatically saves, so there is no need to implement.
        """
        return None
|
||||
|
|
@ -8,7 +8,7 @@ class FAISSRetriever(VectorIndexRetriever):
|
|||
"""FAISS retriever."""
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""Support add nodes"""
|
||||
"""Support add nodes."""
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
"""RAG schemas."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.core.vector_stores.types import VectorStoreQueryMode
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
|
||||
from metagpt.rag.interface import RAGObject
|
||||
|
|
@ -46,6 +47,35 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
|
|||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
|
||||
|
||||
class ElasticsearchStoreConfig(BaseModel):
    """Connection and indexing settings that are dumped into ``ElasticsearchStore``.

    The connection fields default to ``None``, so they are typed as optional;
    passing ``None`` explicitly is therefore valid as well.
    """

    index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.")
    es_url: Union[str, None] = Field(default=None, description="Elasticsearch URL.")
    es_cloud_id: Union[str, None] = Field(default=None, description="Elasticsearch cloud ID.")
    es_api_key: Union[str, None] = Field(default=None, description="Elasticsearch API key.")
    es_user: Union[str, None] = Field(default=None, description="Elasticsearch username.")
    es_password: Union[str, None] = Field(default=None, description="Elasticsearch password.")
    batch_size: int = Field(default=200, description="Batch size for bulk indexing.")
    distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.")
|
||||
|
||||
|
||||
class ElasticsearchRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for Elasticsearch-based retrievers. Support both vector and text."""
|
||||
|
||||
store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
|
||||
vector_store_query_mode: VectorStoreQueryMode = Field(
|
||||
default=VectorStoreQueryMode.DEFAULT, description="default is vector query."
|
||||
)
|
||||
|
||||
|
||||
class ElasticsearchKeywordRetrieverConfig(ElasticsearchRetrieverConfig):
|
||||
"""Config for Elasticsearch-based retrievers. Support text only."""
|
||||
|
||||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
vector_store_query_mode: Literal[VectorStoreQueryMode.TEXT_SEARCH] = Field(
|
||||
default=VectorStoreQueryMode.TEXT_SEARCH, description="text query only."
|
||||
)
|
||||
|
||||
|
||||
class BaseRankerConfig(BaseModel):
|
||||
"""Common config for rankers.
|
||||
|
||||
|
|
@ -53,7 +83,6 @@ class BaseRankerConfig(BaseModel):
|
|||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
top_n: int = Field(default=5, description="The number of top results to return.")
|
||||
|
||||
|
||||
|
|
@ -66,12 +95,24 @@ class LLMRankerConfig(BaseRankerConfig):
|
|||
)
|
||||
|
||||
|
||||
class ColbertRerankConfig(BaseRankerConfig):
|
||||
model: str = Field(default="colbert-ir/colbertv2.0", description="Colbert model name.")
|
||||
device: str = Field(default="cpu", description="Device to use for sentence transformer.")
|
||||
keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.")
|
||||
|
||||
|
||||
class ObjectRankerConfig(BaseRankerConfig):
|
||||
field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
|
||||
order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")
|
||||
|
||||
|
||||
class BaseIndexConfig(BaseModel):
|
||||
"""Common config for index.
|
||||
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
persist_path: Union[str, Path] = Field(description="The directory of saved data.")
|
||||
|
||||
|
||||
|
|
@ -97,6 +138,19 @@ class BM25IndexConfig(BaseIndexConfig):
|
|||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
|
||||
|
||||
class ElasticsearchIndexConfig(VectorIndexConfig):
|
||||
"""Config for es-based index."""
|
||||
|
||||
store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
|
||||
persist_path: Union[str, Path] = ""
|
||||
|
||||
|
||||
class ElasticsearchKeywordIndexConfig(ElasticsearchIndexConfig):
|
||||
"""Config for es-based index. no embedding."""
|
||||
|
||||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
|
||||
|
||||
class ObjectNodeMetadata(BaseModel):
|
||||
"""Metadata of ObjectNode."""
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from __future__ import annotations
|
|||
import json
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Set
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks
|
||||
from metagpt.actions.fix_bug import FixBug
|
||||
|
|
@ -30,6 +30,7 @@ from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST
|
|||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.actions.write_code_plan_and_change_an import WriteCodePlanAndChange
|
||||
from metagpt.const import (
|
||||
BUGFIX_FILENAME,
|
||||
CODE_PLAN_AND_CHANGE_FILE_REPO,
|
||||
REQUIREMENT_FILENAME,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
|
|
@ -45,7 +46,13 @@ from metagpt.schema import (
|
|||
Documents,
|
||||
Message,
|
||||
)
|
||||
from metagpt.utils.common import any_to_name, any_to_str, any_to_str_set
|
||||
from metagpt.utils.common import (
|
||||
any_to_name,
|
||||
any_to_str,
|
||||
any_to_str_set,
|
||||
get_project_srcs_path,
|
||||
init_python_folder,
|
||||
)
|
||||
|
||||
IS_PASS_PROMPT = """
|
||||
{context}
|
||||
|
|
@ -239,7 +246,7 @@ class Engineer(Role):
|
|||
|
||||
async def _think(self) -> Action | None:
|
||||
if not self.src_workspace:
|
||||
self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
|
||||
self.src_workspace = get_project_srcs_path(self.project_repo.workdir)
|
||||
write_plan_and_change_filters = any_to_str_set([WriteTasks, FixBug])
|
||||
write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode])
|
||||
summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
|
||||
|
|
@ -248,11 +255,11 @@ class Engineer(Role):
|
|||
msg = self.rc.news[0]
|
||||
if self.config.inc and msg.cause_by in write_plan_and_change_filters:
|
||||
logger.debug(f"TODO WriteCodePlanAndChange:{msg.model_dump_json()}")
|
||||
await self._new_code_plan_and_change_action()
|
||||
await self._new_code_plan_and_change_action(cause_by=msg.cause_by)
|
||||
return self.rc.todo
|
||||
if msg.cause_by in write_code_filters:
|
||||
logger.debug(f"TODO WriteCode:{msg.model_dump_json()}")
|
||||
await self._new_code_actions(bug_fix=msg.cause_by == any_to_str(FixBug))
|
||||
await self._new_code_actions()
|
||||
return self.rc.todo
|
||||
if msg.cause_by in summarize_code_filters and msg.sent_from == any_to_str(self):
|
||||
logger.debug(f"TODO SummarizeCode:{msg.model_dump_json()}")
|
||||
|
|
@ -260,14 +267,14 @@ class Engineer(Role):
|
|||
return self.rc.todo
|
||||
return None
|
||||
|
||||
async def _new_coding_context(self, filename, dependency) -> CodingContext:
|
||||
async def _new_coding_context(self, filename, dependency) -> Optional[CodingContext]:
|
||||
old_code_doc = await self.project_repo.srcs.get(filename)
|
||||
if not old_code_doc:
|
||||
old_code_doc = Document(root_path=str(self.project_repo.src_relative_path), filename=filename, content="")
|
||||
dependencies = {Path(i) for i in await dependency.get(old_code_doc.root_relative_path)}
|
||||
task_doc = None
|
||||
design_doc = None
|
||||
code_plan_and_change_doc = None
|
||||
code_plan_and_change_doc = await self._get_any_code_plan_and_change() if await self._is_fixbug() else None
|
||||
for i in dependencies:
|
||||
if str(i.parent) == TASK_FILE_REPO:
|
||||
task_doc = await self.project_repo.docs.task.get(i.name)
|
||||
|
|
@ -276,6 +283,8 @@ class Engineer(Role):
|
|||
elif str(i.parent) == CODE_PLAN_AND_CHANGE_FILE_REPO:
|
||||
code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(i.name)
|
||||
if not task_doc or not design_doc:
|
||||
if filename == "__init__.py": # `__init__.py` created by `init_python_folder`
|
||||
return None
|
||||
logger.error(f'Detected source code "{filename}" from an unknown origin.')
|
||||
raise ValueError(f'Detected source code "{filename}" from an unknown origin.')
|
||||
context = CodingContext(
|
||||
|
|
@ -287,14 +296,17 @@ class Engineer(Role):
|
|||
)
|
||||
return context
|
||||
|
||||
async def _new_coding_doc(self, filename, dependency):
|
||||
async def _new_coding_doc(self, filename, dependency) -> Optional[Document]:
|
||||
context = await self._new_coding_context(filename, dependency)
|
||||
if not context:
|
||||
return None # `__init__.py` created by `init_python_folder`
|
||||
coding_doc = Document(
|
||||
root_path=str(self.project_repo.src_relative_path), filename=filename, content=context.model_dump_json()
|
||||
)
|
||||
return coding_doc
|
||||
|
||||
async def _new_code_actions(self, bug_fix=False):
|
||||
async def _new_code_actions(self):
|
||||
bug_fix = await self._is_fixbug()
|
||||
# Prepare file repos
|
||||
changed_src_files = self.project_repo.srcs.all_files if bug_fix else self.project_repo.srcs.changed_files
|
||||
changed_task_files = self.project_repo.docs.task.changed_files
|
||||
|
|
@ -305,6 +317,7 @@ class Engineer(Role):
|
|||
task_doc = await self.project_repo.docs.task.get(filename)
|
||||
code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(filename)
|
||||
task_list = self._parse_tasks(task_doc)
|
||||
await self._init_python_folder(task_list)
|
||||
for task_filename in task_list:
|
||||
old_code_doc = await self.project_repo.srcs.get(task_filename)
|
||||
if not old_code_doc:
|
||||
|
|
@ -343,6 +356,8 @@ class Engineer(Role):
|
|||
if filename in changed_files.docs:
|
||||
continue
|
||||
coding_doc = await self._new_coding_doc(filename=filename, dependency=dependency)
|
||||
if not coding_doc:
|
||||
continue # `__init__.py` created by `init_python_folder`
|
||||
changed_files.docs[filename] = coding_doc
|
||||
self.code_todos.append(WriteCode(i_context=coding_doc, context=self.context, llm=self.llm))
|
||||
|
||||
|
|
@ -358,6 +373,8 @@ class Engineer(Role):
|
|||
ctx = CodeSummarizeContext.loads(filenames=list(dependencies))
|
||||
summarizations[ctx].append(filename)
|
||||
for ctx, filenames in summarizations.items():
|
||||
if not ctx.design_filename or not ctx.task_filename:
|
||||
continue # cause by `__init__.py` which is created by `init_python_folder`
|
||||
ctx.codes_filenames = filenames
|
||||
new_summarize = SummarizeCode(i_context=ctx, context=self.context, llm=self.llm)
|
||||
for i, act in enumerate(self.summarize_todos):
|
||||
|
|
@ -371,15 +388,40 @@ class Engineer(Role):
|
|||
self.set_todo(self.summarize_todos[0])
|
||||
self.summarize_todos.pop(0)
|
||||
|
||||
async def _new_code_plan_and_change_action(self):
|
||||
async def _new_code_plan_and_change_action(self, cause_by: str):
|
||||
"""Create a WriteCodePlanAndChange action for subsequent to-do actions."""
|
||||
files = self.project_repo.all_files
|
||||
requirement_doc = await self.project_repo.docs.get(REQUIREMENT_FILENAME)
|
||||
requirement = requirement_doc.content if requirement_doc else ""
|
||||
code_plan_and_change_ctx = CodePlanAndChangeContext.loads(files, requirement=requirement)
|
||||
options = {}
|
||||
if cause_by != any_to_str(FixBug):
|
||||
requirement_doc = await self.project_repo.docs.get(REQUIREMENT_FILENAME)
|
||||
options["requirement"] = requirement_doc.content
|
||||
else:
|
||||
fixbug_doc = await self.project_repo.docs.get(BUGFIX_FILENAME)
|
||||
options["issue"] = fixbug_doc.content
|
||||
code_plan_and_change_ctx = CodePlanAndChangeContext.loads(files, **options)
|
||||
self.rc.todo = WriteCodePlanAndChange(i_context=code_plan_and_change_ctx, context=self.context, llm=self.llm)
|
||||
|
||||
@property
|
||||
def action_description(self) -> str:
|
||||
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
|
||||
return self.next_todo_action
|
||||
|
||||
async def _init_python_folder(self, task_list: List[str]):
    """Call ``init_python_folder`` on the parent directory of every ``.py`` task file.

    Entries of ``task_list`` whose suffix is not ``.py`` are ignored. Directories
    are resolved relative to ``self.src_workspace``.
    """
    for source_path in (Path(task) for task in task_list):
        if source_path.suffix == ".py":
            await init_python_folder(self.src_workspace / source_path.parent)
|
||||
|
||||
async def _is_fixbug(self) -> bool:
    """Report whether the ``BUGFIX_FILENAME`` document exists and has non-empty content."""
    bugfix = await self.project_repo.docs.get(BUGFIX_FILENAME)
    if not bugfix:
        return False
    return bool(bugfix.content)
|
||||
|
||||
async def _get_any_code_plan_and_change(self) -> Optional[Document]:
    """Return the first changed code-plan-and-change document that has content.

    Walks the changed files of the code-plan-and-change repository in iteration
    order and loads each one. Returns ``None`` when none of them yields a document
    with non-empty content.
    """
    plan_repo = self.project_repo.docs.code_plan_and_change
    for name in plan_repo.changed_files:
        candidate = await plan_repo.get(name)
        if candidate and candidate.content:
            return candidate
    return None
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from metagpt.const import MESSAGE_ROUTE_TO_NONE
|
|||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Document, Message, RunCodeContext, TestingContext
|
||||
from metagpt.utils.common import any_to_str_set, parse_recipient
|
||||
from metagpt.utils.common import any_to_str_set, init_python_folder, parse_recipient
|
||||
|
||||
|
||||
class QaEngineer(Role):
|
||||
|
|
@ -141,6 +141,7 @@ class QaEngineer(Role):
|
|||
)
|
||||
|
||||
async def _act(self) -> Message:
|
||||
await init_python_folder(self.project_repo.tests.workdir)
|
||||
if self.test_round > self.test_round_allowed:
|
||||
result_msg = Message(
|
||||
content=f"Exceeding {self.test_round_allowed} rounds of tests, skip (writing code counts as a round, too)",
|
||||
|
|
|
|||
|
|
@ -677,13 +677,14 @@ class BugFixContext(BaseContext):
|
|||
|
||||
class CodePlanAndChangeContext(BaseModel):
|
||||
requirement: str = ""
|
||||
issue: str = ""
|
||||
prd_filename: str = ""
|
||||
design_filename: str = ""
|
||||
task_filename: str = ""
|
||||
|
||||
@staticmethod
|
||||
def loads(filenames: List, **kwargs) -> CodePlanAndChangeContext:
|
||||
ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", ""))
|
||||
ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", ""), issue=kwargs.get("issue", ""))
|
||||
for filename in filenames:
|
||||
filename = Path(filename)
|
||||
if filename.is_relative_to(PRDS_FILE_REPO):
|
||||
|
|
|
|||
|
|
@ -56,8 +56,10 @@ class Team(BaseModel):
|
|||
def serialize(self, stg_path: Path = None):
|
||||
stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path
|
||||
team_info_path = stg_path.joinpath("team.json")
|
||||
serialized_data = self.model_dump()
|
||||
serialized_data["context"] = self.env.context.serialize()
|
||||
|
||||
write_json_file(team_info_path, self.model_dump())
|
||||
write_json_file(team_info_path, serialized_data)
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path, context: Context = None) -> "Team":
|
||||
|
|
@ -71,6 +73,7 @@ class Team(BaseModel):
|
|||
|
||||
team_info: dict = read_json_file(team_info_path)
|
||||
ctx = context or Context()
|
||||
ctx.deserialize(team_info.pop("context", None))
|
||||
team = Team(**team_info, context=ctx)
|
||||
return team
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,15 @@ from metagpt.tools.libs import (
|
|||
web_scraping,
|
||||
email_login,
|
||||
)
|
||||
from metagpt.tools.libs.software_development import (
|
||||
write_prd,
|
||||
write_design,
|
||||
write_project_plan,
|
||||
write_codes,
|
||||
run_qa_test,
|
||||
fix_bug,
|
||||
git_archive,
|
||||
)
|
||||
|
||||
_ = (
|
||||
data_preprocess,
|
||||
|
|
@ -20,4 +29,11 @@ _ = (
|
|||
gpt_v_generator,
|
||||
web_scraping,
|
||||
email_login,
|
||||
write_prd,
|
||||
write_design,
|
||||
write_project_plan,
|
||||
write_codes,
|
||||
run_qa_test,
|
||||
fix_bug,
|
||||
git_archive,
|
||||
) # Avoid pre-commit error
|
||||
|
|
|
|||
65
metagpt/tools/libs/git.py
Normal file
65
metagpt/tools/libs/git.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
||||
@register_tool(tags=["git"])
async def git_clone(url: str, output_dir: str | Path = None) -> Path:
    """
    Clones a Git repository from the given URL.

    Args:
        url (str): The URL of the Git repository to clone.
        output_dir (str or Path, optional): The directory where the repository will be cloned.
            It is passed through to `GitRepository.clone_from`, which decides the
            destination when this is not provided.

    Returns:
        Path: The path to the cloned repository.

    Raises:
        ValueError: If the specified Git root is invalid.

    Example:
        >>> # git clone to /TO/PATH
        >>> url = 'https://github.com/geekan/MetaGPT.git'
        >>> output_dir = "/TO/PATH"
        >>> repo_dir = await git_clone(url=url, output_dir=output_dir)
        >>> print(repo_dir)
        /TO/PATH/MetaGPT

        >>> # git clone to default directory.
        >>> url = 'https://github.com/geekan/MetaGPT.git'
        >>> repo_dir = await git_clone(url)
        >>> print(repo_dir)
        /WORK_SPACE/downloads/MetaGPT
    """
    cloned = await GitRepository.clone_from(url, output_dir)
    return cloned.workdir
|
||||
|
||||
|
||||
async def git_checkout(repo_dir: str | Path, commit_id: str):
    """
    Checks out a specific commit in a Git repository.

    Args:
        repo_dir (str or Path): The directory containing the Git repository.
        commit_id (str): The ID of the commit to check out.

    Raises:
        ValueError: If the specified Git root is invalid.

    Example:
        >>> repo_dir = '/TO/GIT/REPO'
        >>> commit_id = 'main'
        >>> await git_checkout(repo_dir=repo_dir, commit_id=commit_id)
        git checkout main
    """
    # auto_init=False: never create a repository as a side effect of a checkout request.
    repo = GitRepository(local_path=repo_dir, auto_init=False)
    if not repo.is_valid:
        # The exception used to be constructed but not raised, so invalid
        # directories fell through to `checkout`.
        raise ValueError(f"Invalid git root: {repo_dir}")
    await repo.checkout(commit_id)
|
||||
63
metagpt/tools/libs/shell.py
Normal file
63
metagpt/tools/libs/shell.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
|
||||
|
||||
@register_tool(tags=["shell"])
async def shell_execute(
    command: Union[List[str], str], cwd: str | Path = None, env: Dict = None, timeout: int = 600
) -> Tuple[str, str, int]:
    """
    Execute a command asynchronously and return its standard output and standard error.

    Args:
        command (Union[List[str], str]): The command to execute and its arguments. It can be provided either as a list
            of strings or as a single string.
        cwd (str | Path, optional): The current working directory for the command. Defaults to None.
        env (Dict, optional): Environment variables to set for the command. Defaults to None.
        timeout (int, optional): Timeout for the command execution in seconds. Defaults to 600.

    Returns:
        Tuple[str, str, int]: A tuple containing the string type standard output and string type standard error of the executed command and int type return code.

    Raises:
        ValueError: If the command times out, this error is raised. The error message contains both standard output and
            standard error of the timed-out process.

    Example:
        >>> # command is a list
        >>> stdout, stderr, returncode = await shell_execute(command=["ls", "-l"], cwd="/home/user", env={"PATH": "/usr/bin"})
        >>> print(stdout)
        total 8
        -rw-r--r-- 1 user user    0 Mar 22 10:00 file1.txt
        -rw-r--r-- 1 user user    0 Mar 22 10:00 file2.txt
        ...

        >>> # command is a string of shell script
        >>> stdout, stderr, returncode = await shell_execute(command="ls -l", cwd="/home/user", env={"PATH": "/usr/bin"})
        >>> print(stdout)
        total 8
        -rw-r--r-- 1 user user    0 Mar 22 10:00 file1.txt
        -rw-r--r-- 1 user user    0 Mar 22 10:00 file2.txt
        ...

    References:
        This function uses `subprocess.Popen` for executing shell commands. The blocking wait for the process
        runs in the event loop's default executor so that other coroutines keep running meanwhile.
    """
    import asyncio  # function-scope import keeps this change self-contained

    cwd = str(cwd) if cwd else None
    # A plain string is interpreted by the shell; a list is executed directly.
    shell = isinstance(command, str)
    process = subprocess.Popen(command, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env, shell=shell)
    loop = asyncio.get_running_loop()
    try:
        # `communicate` blocks until exit or timeout; run it off the event loop thread.
        stdout, stderr = await loop.run_in_executor(None, lambda: process.communicate(timeout=timeout))
        return stdout.decode("utf-8"), stderr.decode("utf-8"), process.returncode
    except subprocess.TimeoutExpired:
        process.kill()  # Kill the process if it times out
        # Drain whatever the process produced before it was killed.
        stdout, stderr = await loop.run_in_executor(None, process.communicate)
        raise ValueError(f"{stdout.decode('utf-8')}\n{stderr.decode('utf-8')}")
|
||||
301
metagpt/tools/libs/software_development.py
Normal file
301
metagpt/tools/libs/software_development.py
Normal file
|
|
@ -0,0 +1,301 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.const import BUGFIX_FILENAME, REQUIREMENT_FILENAME
|
||||
from metagpt.schema import BugFixContext, Message
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
||||
|
||||
@register_tool(tags=["software development", "ProductManager"])
async def write_prd(idea: str, project_path: Optional[str | Path] = None) -> Path:
    """Writes a PRD based on user requirements.

    Args:
        idea (str): The idea or concept for the PRD.
        project_path (Optional[str|Path], optional): The path to an existing project directory.
            If it's None, a new project path will be created. Defaults to None.

    Returns:
        Path: The path to the PRD files under the project directory

    Example:
        >>> # Create a new project:
        >>> from metagpt.tools.libs.software_development import write_prd
        >>> prd_path = await write_prd("Create a new feature for the application")
        >>> print(prd_path)
        '/path/to/project_path/docs/prd/'

        >>> # Add user requirements to an existing project:
        >>> from metagpt.tools.libs.software_development import write_prd
        >>> project_path = '/path/to/exists_project_path'
        >>> prd_path = await write_prd("Create a new feature for the application", project_path=project_path)
        >>> print(prd_path)
        '/path/to/project_path/docs/prd/'
    """
    from metagpt.actions import UserRequirement
    from metagpt.context import Context
    from metagpt.roles import ProductManager

    ctx = Context()
    if project_path:
        # An existing project was given: point the config at it and switch to incremental mode.
        ctx.config.project_path = Path(project_path)
        ctx.config.inc = True
    role = ProductManager(context=ctx)
    # The role is run twice: first on the user requirement, then on the message that run returned.
    msg = await role.run(with_message=Message(content=idea, cause_by=UserRequirement))
    await role.run(with_message=msg)
    return ctx.repo.docs.prd.workdir
|
||||
|
||||
|
||||
@register_tool(tags=["software development", "Architect"])
async def write_design(prd_path: str | Path) -> Path:
    """Writes a design to the project repository, based on the PRD of the project.

    Args:
        prd_path (str|Path): The path to the PRD files under the project directory.

    Returns:
        Path: The path to the system design files under the project directory.

    Example:
        >>> from metagpt.tools.libs.software_development import write_design
        >>> prd_path = '/path/to/project_path/docs/prd' # Returned by `write_prd`
        >>> system_design_path = await write_design(prd_path)
        >>> print(system_design_path)
        '/path/to/project_path/docs/system_design/'

    """
    from metagpt.actions import WritePRD
    from metagpt.context import Context
    from metagpt.roles import Architect

    ctx = Context()
    # The project root is taken to be two levels above the PRD directory (<project>/docs/prd).
    project_path = Path(prd_path).parent.parent
    ctx.set_repo_dir(project_path)

    role = Architect(context=ctx)
    # An empty message marked as caused by WritePRD is what starts the Architect run.
    await role.run(with_message=Message(content="", cause_by=WritePRD))
    return ctx.repo.docs.system_design.workdir
|
||||
|
||||
|
||||
# NOTE(review): tagged "Architect" although the role run below is ProjectManager — confirm the tag is intended.
@register_tool(tags=["software development", "Architect"])
async def write_project_plan(system_design_path: str | Path) -> Path:
    """Writes a project plan to the project repository, based on the design of the project.

    Args:
        system_design_path (str|Path): The path to the system design files under the project directory.

    Returns:
        Path: The path to task files under the project directory.

    Example:
        >>> from metagpt.tools.libs.software_development import write_project_plan
        >>> system_design_path = '/path/to/project_path/docs/system_design/' # Returned by `write_design`
        >>> task_path = await write_project_plan(system_design_path)
        >>> print(task_path)
        '/path/to/project_path/docs/task'

    """
    from metagpt.actions import WriteDesign
    from metagpt.context import Context
    from metagpt.roles import ProjectManager

    ctx = Context()
    # The project root is taken to be two levels above the design directory (<project>/docs/system_design).
    project_path = Path(system_design_path).parent.parent
    ctx.set_repo_dir(project_path)

    role = ProjectManager(context=ctx)
    # An empty message marked as caused by WriteDesign is what starts the ProjectManager run.
    await role.run(with_message=Message(content="", cause_by=WriteDesign))
    return ctx.repo.docs.task.workdir
|
||||
|
||||
|
||||
@register_tool(tags=["software development", "Engineer"])
async def write_codes(task_path: str | Path, inc: bool = False) -> Path:
    """Writes codes to the project repository, based on the project plan of the project.

    Args:
        task_path (str|Path): The path to task files under the project directory.
        inc (bool, optional): Whether to write incremental codes. Defaults to False.

    Returns:
        Path: The path to the source code files under the project directory.

    Example:
        # Write codes to a new project
        >>> from metagpt.tools.libs.software_development import write_codes
        >>> task_path = '/path/to/project_path/docs/task' # Returned by `write_project_plan`
        >>> src_path = await write_codes(task_path)
        >>> print(src_path)
        '/path/to/project_path/src/'

        # Write incremental codes to an existing project
        >>> from metagpt.tools.libs.software_development import write_codes
        >>> task_path = '/path/to/project_path/docs/task' # Returned by `write_project_plan`
        >>> src_path = await write_codes(task_path, inc=True)
        >>> print(src_path)
        '/path/to/project_path/src/'
    """
    from metagpt.actions import WriteTasks
    from metagpt.context import Context
    from metagpt.roles import Engineer

    ctx = Context()
    ctx.config.inc = inc
    # The project root is taken to be two levels above the task directory (<project>/docs/task).
    project_path = Path(task_path).parent.parent
    ctx.set_repo_dir(project_path)

    role = Engineer(context=ctx)
    msg = Message(content="", cause_by=WriteTasks, send_to=role)
    # The role may be addressed either by its type string or by its name.
    me = {any_to_str(role), role.name}
    # Keep running the role for as long as the message it returns is addressed back to itself.
    # NOTE(review): assumes `role.run` always returns a message here — confirm it cannot return None.
    while me.intersection(msg.send_to):
        msg = await role.run(with_message=msg)
    return ctx.repo.srcs.workdir
|
||||
|
||||
|
||||
@register_tool(tags=["software development", "QaEngineer"])
async def run_qa_test(src_path: str | Path) -> Path:
    """Run QA test on the project repository.

    Args:
        src_path (str|Path): The path to the source code files under the project directory.

    Returns:
        Path: The path to the unit tests under the project directory

    Example:
        >>> from metagpt.tools.libs.software_development import run_qa_test
        >>> src_path = '/path/to/project_path/src/' # Returned by `write_codes`
        >>> test_path = await run_qa_test(src_path)
        >>> print(test_path)
        '/path/to/project_path/tests'
    """
    from metagpt.actions.summarize_code import SummarizeCode
    from metagpt.context import Context
    from metagpt.environment import Environment
    from metagpt.roles import QaEngineer

    ctx = Context()
    # The project root is taken to be the parent of the source directory.
    project_path = Path(src_path).parent
    ctx.set_repo_dir(project_path)
    # NOTE(review): the source workspace is derived from the repo directory name, not from `src_path` — confirm.
    ctx.src_workspace = ctx.git_repo.workdir / ctx.git_repo.workdir.name

    env = Environment(context=ctx)
    role = QaEngineer(context=ctx)
    env.add_role(role)

    # An empty message marked as caused by SummarizeCode and addressed to the QA role starts the run.
    msg = Message(content="", cause_by=SummarizeCode, send_to=role)
    env.publish_message(msg)

    # Drive the environment until it reports being idle.
    while not env.is_idle:
        await env.run()
    return ctx.repo.tests.workdir
|
||||
|
||||
|
||||
@register_tool(tags=["software development", "Engineer"])
async def fix_bug(project_path: str | Path, issue: str) -> Path:
    """Fix bugs in the project repository.

    Args:
        project_path (str|Path): The path to the project repository.
        issue (str): Description of the bug or issue.

    Returns:
        Path: The path to the project directory

    Example:
        >>> from metagpt.tools.libs.software_development import fix_bug
        >>> project_path = '/path/to/project_path' # The project repository root
        >>> issue = 'Exception: exception about ...; Bug: bug about ...; Issue: issue about ...'
        >>> project_path = await fix_bug(project_path=project_path, issue=issue)
        >>> print(project_path)
        '/path/to/project_path'
    """
    from metagpt.actions.fix_bug import FixBug
    from metagpt.context import Context
    from metagpt.roles import Engineer

    ctx = Context()
    ctx.set_repo_dir(project_path)
    ctx.src_workspace = ctx.git_repo.workdir / ctx.git_repo.workdir.name
    # Persist the issue text as the bug-fix document and blank out the requirement document.
    await ctx.repo.docs.save(filename=BUGFIX_FILENAME, content=issue)
    await ctx.repo.docs.save(filename=REQUIREMENT_FILENAME, content="")

    role = Engineer(context=ctx)
    bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
    msg = Message(
        content=bug_fix.model_dump_json(),
        instruct_content=bug_fix,
        role="",
        cause_by=FixBug,
        sent_from=role,
        send_to=role,
    )
    # The role may be addressed either by its type string or by its name.
    me = {any_to_str(role), role.name}
    # Keep running the role for as long as the message it returns is addressed back to itself.
    while me.intersection(msg.send_to):
        msg = await role.run(with_message=msg)
    # Honour the declared return type even when the caller passed a plain string.
    return Path(project_path)
|
||||
|
||||
|
||||
@register_tool(tags=["software development", "git"])
async def git_archive(project_path: str | Path) -> str:
    """Stage and commit changes for the project repository using Git.

    Args:
        project_path (str|Path): The path to the project repository.

    Returns:
        str: The `git log` output of the repository after archiving.

    Example:
        >>> from metagpt.tools.libs.software_development import git_archive
        >>> project_path = '/path/to/project_path' # The project repository root
        >>> git_log = await git_archive(project_path=project_path)
        >>> print(git_log)
        commit a221d1c418c07f2b4fc07001e486285ead1a520a (HEAD -> feature/toollib/software_company, geekan/main)
        Merge: e01afd09 4a72f398
        Author: Sirui Hong <x@xx.github.com>
        Date:   Tue Mar 19 15:16:03 2024 +0800
            Merge pull request #1037 from iorisa/fixbug/issues/1018
            fixbug: #1018

    """
    from metagpt.context import Context

    ctx = Context()
    ctx.set_repo_dir(project_path)
    # `archive()` does the staging/committing; the log is returned as the result.
    ctx.git_repo.archive()
    return ctx.git_repo.log()
|
||||
|
||||
|
||||
@register_tool(tags=["software development", "import git repo"])
async def import_git_repo(url: str) -> Path:
    """
    Imports a project from a Git website and formats it to MetaGPT project format to enable incremental appending requirements.

    Args:
        url (str): The Git project URL, such as "https://github.com/geekan/MetaGPT.git".

    Returns:
        Path: The path of the formatted project.

    Example:
        # The Git project URL to input
        >>> git_url = "https://github.com/geekan/MetaGPT.git"

        # Import the Git repository and get the formatted project path
        >>> formatted_project_path = await import_git_repo(git_url)
        >>> print("Formatted project path:", formatted_project_path)
        Formatted project path: /PATH/TO/THE/FORMATTED/PROJECT
    """
    from metagpt.actions.import_repo import ImportRepo
    from metagpt.context import Context

    ctx = Context()
    # The URL is handed to ImportRepo as `repo_path`; the result is read back from the shared context.
    action = ImportRepo(repo_path=url, context=ctx)
    await action.run()
    return ctx.repo.workdir
|
||||
|
|
@ -822,19 +822,78 @@ See FAQ 5.8
|
|||
raise retry_state.outcome.exception()
|
||||
|
||||
|
||||
def get_markdown_codeblock_type(filename: str) -> str:
|
||||
async def get_mime_type(filename: str | Path, force_read: bool = False) -> str:
    """Return the MIME type of a file.

    The type is first guessed from the file name. Unless `force_read` is set, a successful
    guess is returned immediately; otherwise the `file --mime-type` command inspects the file.

    Args:
        filename (str | Path): The file to classify.
        force_read (bool, optional): Run `file --mime-type` even when the name-based guess
            succeeded. Defaults to False.

    Returns:
        str: The detected MIME type, or "unknown" if the `file` command could not be run or
            produced no usable output.
    """
    # Accept plain strings as the annotation promises; `.name`/`.suffix` need a Path.
    filename = Path(filename)
    guess_mime_type, _ = mimetypes.guess_type(filename.name)
    if not guess_mime_type:
        # Fallback for YAML extensions when `mimetypes` has no entry for them.
        ext_mappings = {".yml": "text/yaml", ".yaml": "text/yaml"}
        guess_mime_type = ext_mappings.get(filename.suffix)
    if not force_read and guess_mime_type:
        return guess_mime_type

    from metagpt.tools.libs.shell import shell_execute  # avoid circular import

    # Types that `file` reports as generic "text/plain" but whose name-based guess is more specific.
    text_set = {
        "application/json",
        "application/vnd.chipnuts.karaoke-mmd",
        "application/javascript",
        "application/xml",
        "application/x-sh",
        "application/sql",
        "text/yaml",
    }

    try:
        # Argument list instead of a shell string: paths with spaces or shell
        # metacharacters are passed through verbatim.
        stdout, _, _ = await shell_execute(["file", "--mime-type", str(filename)])
        # Output looks like "<path>: <mime>"; the MIME type is the last space-separated token.
        ix = stdout.rfind(" ")
        mime_type = stdout[ix:].strip()
        if not mime_type:
            return "unknown"
        if mime_type == "text/plain" and guess_mime_type in text_set:
            return guess_mime_type
        return mime_type
    except Exception as e:
        # Best effort: a missing `file` binary or a failed run degrades to "unknown".
        logger.debug(f"file:{filename}, error:{e}")
        return "unknown"
|
||||
|
||||
|
||||
def get_markdown_codeblock_type(filename: str = None, mime_type: str = None) -> str:
    """Return the markdown code-block type for a file name or a MIME type.

    Args:
        filename (str, optional): File name used to guess the MIME type when `mime_type`
            is not given.
        mime_type (str, optional): An already-known MIME type; takes precedence over `filename`.

    Returns:
        str: The markdown code-block language tag, or "text" when the type is not mapped.

    Raises:
        ValueError: If neither `filename` nor `mime_type` is provided.
    """
    if not filename and not mime_type:
        raise ValueError("Either filename or mime_type must be valid.")

    # Only guess from the name when the caller did not supply a MIME type, so an
    # explicit `mime_type` is never overwritten.
    if not mime_type:
        mime_type, _ = mimetypes.guess_type(filename)
    mappings = {
        "text/x-shellscript": "bash",
        "text/x-c++src": "cpp",
        "text/css": "css",
        "text/html": "html",
        "text/x-java": "java",
        "text/x-python": "python",
        "text/x-ruby": "ruby",
        "text/x-c": "cpp",
        "text/yaml": "yaml",
        "application/javascript": "javascript",
        "application/json": "json",
        "application/sql": "sql",
        "application/vnd.chipnuts.karaoke-mmd": "mermaid",
        "application/x-sh": "bash",
        "application/xml": "xml",
    }
    return mappings.get(mime_type, "text")
|
||||
|
||||
|
||||
def get_project_srcs_path(workdir: str | Path) -> Path:
    """Return the source-code directory of a project.

    If `<workdir>/.src_workspace` exists, its content is used as the name of the source
    directory; otherwise the source directory is named after `workdir` itself.

    Args:
        workdir (str | Path): The project root directory.

    Returns:
        Path: `<workdir>/<source directory name>`.
    """
    # Normalise once so that plain strings work with the `/` operator below.
    workdir = Path(workdir)
    src_workdir_path = workdir / ".src_workspace"
    if src_workdir_path.exists():
        # NOTE(review): the content is used verbatim; a trailing newline in the
        # marker file would end up in the returned path — confirm how it is written.
        with open(src_workdir_path, "r") as file:
            src_name = file.read()
    else:
        src_name = workdir.name
    return workdir / src_name
|
||||
|
||||
|
||||
async def init_python_folder(workdir: str | Path):
    """Make sure `workdir` contains an `__init__.py`, creating an empty one when it is absent."""
    package_marker = Path(workdir) / "__init__.py"
    if not package_marker.exists():
        # Opening in append mode creates the file; `utime` then sets its timestamps to now.
        async with aiofiles.open(package_marker, "a"):
            os.utime(package_marker, None)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
|
@ -16,8 +17,10 @@ from typing import Dict, List
|
|||
from git.repo import Repo
|
||||
from git.repo.fun import is_git_dir
|
||||
from gitignore_parser import parse_gitignore
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.libs.shell import shell_execute
|
||||
from metagpt.utils.dependency_file import DependencyFile
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
|
@ -52,7 +55,7 @@ class GitRepository:
|
|||
self._dependency = None
|
||||
self._gitignore_rules = None
|
||||
if local_path:
|
||||
self.open(local_path=local_path, auto_init=auto_init)
|
||||
self.open(local_path=Path(local_path), auto_init=auto_init)
|
||||
|
||||
def open(self, local_path: Path, auto_init=False):
|
||||
"""Open an existing Git repository or initialize a new one if auto_init is True.
|
||||
|
|
@ -68,7 +71,7 @@ class GitRepository:
|
|||
if not auto_init:
|
||||
return
|
||||
local_path.mkdir(parents=True, exist_ok=True)
|
||||
return self._init(local_path)
|
||||
self._init(local_path)
|
||||
|
||||
def _init(self, local_path: Path):
|
||||
"""Initialize a new Git repository at the specified path.
|
||||
|
|
@ -248,6 +251,8 @@ class GitRepository:
|
|||
if not directory_path.exists():
|
||||
return []
|
||||
for file_path in directory_path.iterdir():
|
||||
if not file_path.is_relative_to(root_relative_path):
|
||||
continue
|
||||
if file_path.is_file():
|
||||
rpath = file_path.relative_to(root_relative_path)
|
||||
files.append(str(rpath))
|
||||
|
|
@ -283,3 +288,37 @@ class GitRepository:
|
|||
continue
|
||||
files.append(filename)
|
||||
return files
|
||||
|
||||
@classmethod
@retry(wait=wait_random_exponential(min=1, max=15), stop=stop_after_attempt(3))
async def clone_from(cls, url: str | Path, output_dir: str | Path = None) -> "GitRepository":
    """Clone a Git repository with the `git` command line tool and open the result.

    The call is wrapped in `retry` with `stop_after_attempt(3)` and
    `wait_random_exponential(min=1, max=15)`.

    Args:
        url (str | Path): The repository to clone; its stem names the target sub-directory.
        output_dir (str | Path, optional): The parent directory to clone into. Defaults to a
            `workspace/downloads/<uuid hex>` directory resolved relative to this file.

    Returns:
        GitRepository: The repository opened at `<output_dir>/<stem of url>`.

    Raises:
        ValueError: If the expected directory is not a Git directory after cloning; the
            message holds the command's stdout, stderr and exit code.
    """
    from metagpt.context import Context

    to_path = Path(output_dir or Path(__file__).parent / f"../../workspace/downloads/{uuid.uuid4().hex}").resolve()
    to_path.mkdir(parents=True, exist_ok=True)
    # Remove any previous clone at the target location before cloning again.
    repo_dir = to_path / Path(url).stem
    if repo_dir.exists():
        shutil.rmtree(repo_dir, ignore_errors=True)
    ctx = Context()
    env = ctx.new_environ()
    # Pass the configured proxy to git via `-c http.proxy=...` when one is set.
    proxy = ["-c", f"http.proxy={ctx.config.proxy}"] if ctx.config.proxy else []
    command = ["git", "clone"] + proxy + [str(url)]
    logger.info(" ".join(command))

    stdout, stderr, return_code = await shell_execute(command=command, cwd=str(to_path), env=env, timeout=600)
    info = f"{stdout}\n{stderr}\nexit: {return_code}\n"
    logger.info(info)
    dir_name = Path(url).stem
    to_path = to_path / dir_name
    # Success is judged by the presence of a Git directory, not by the exit code.
    if not cls.is_git_dir(to_path):
        raise ValueError(info)
    logger.info(f"git clone to {to_path}")
    return GitRepository(local_path=to_path, auto_init=False)
|
||||
|
||||
async def checkout(self, commit_id: str):
    """Check out `commit_id` in the underlying repository and log the action."""
    git_cmd = self._repository.git
    git_cmd.checkout(commit_id)
    logger.info(f"git checkout {commit_id}")
|
||||
|
||||
def log(self) -> str:
    """Return the output of `git log` for the underlying repository."""
    history = self._repository.git.log()
    return history
|
||||
|
|
|
|||
|
|
@ -49,6 +49,10 @@ class GraphKeyword:
|
|||
IS_COMPOSITE_OF = "is_composite_of"
|
||||
IS_AGGREGATE_OF = "is_aggregate_of"
|
||||
HAS_PARTICIPANT = "has_participant"
|
||||
HAS_SUMMARY = "has_summary"
|
||||
HAS_INSTALL = "has_install"
|
||||
HAS_CONFIG = "has_config"
|
||||
HAS_USAGE = "has_usage"
|
||||
|
||||
|
||||
class SPO(BaseModel):
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ from metagpt.const import (
|
|||
TEST_OUTPUTS_FILE_REPO,
|
||||
VISUAL_GRAPH_REPO_FILE_REPO,
|
||||
)
|
||||
from metagpt.utils.common import get_project_srcs_path
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
|
@ -129,11 +130,10 @@ class ProjectRepo(FileRepository):
|
|||
return self._git_repo.new_file_repository(self._srcs_path)
|
||||
|
||||
def code_files_exists(self) -> bool:
|
||||
git_workdir = self.git_repo.workdir
|
||||
src_workdir = git_workdir / git_workdir.name
|
||||
src_workdir = get_project_srcs_path(self.git_repo.workdir)
|
||||
if not src_workdir.exists():
|
||||
return False
|
||||
code_files = self.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
|
||||
code_files = self.with_src_path(path=src_workdir).srcs.all_files
|
||||
if not code_files:
|
||||
return False
|
||||
return bool(code_files)
|
||||
|
|
|
|||
|
|
@ -5,17 +5,24 @@ This file provides functionality to convert a local repository into a markdown r
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import mimetypes
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
from gitignore_parser import parse_gitignore
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import aread, awrite, get_markdown_codeblock_type, list_files
|
||||
from metagpt.utils.common import (
|
||||
aread,
|
||||
awrite,
|
||||
get_markdown_codeblock_type,
|
||||
get_mime_type,
|
||||
list_files,
|
||||
)
|
||||
from metagpt.utils.tree import tree
|
||||
|
||||
|
||||
async def repo_to_markdown(repo_path: str | Path, output: str | Path = None, gitignore: str | Path = None) -> str:
|
||||
async def repo_to_markdown(repo_path: str | Path, output: str | Path = None) -> str:
|
||||
"""
|
||||
Convert a local repository into a markdown representation.
|
||||
|
||||
|
|
@ -25,56 +32,108 @@ async def repo_to_markdown(repo_path: str | Path, output: str | Path = None, git
|
|||
Args:
|
||||
repo_path (str | Path): The path to the local repository.
|
||||
output (str | Path, optional): The path to save the generated markdown file. Defaults to None.
|
||||
gitignore (str | Path, optional): The path to the .gitignore file. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The markdown representation of the repository.
|
||||
"""
|
||||
repo_path = Path(repo_path)
|
||||
gitignore = Path(gitignore or Path(__file__).parent / "../../.gitignore").resolve()
|
||||
repo_path = Path(repo_path).resolve()
|
||||
gitignore_file = repo_path / ".gitignore"
|
||||
|
||||
markdown = await _write_dir_tree(repo_path=repo_path, gitignore=gitignore)
|
||||
markdown = await _write_dir_tree(repo_path=repo_path, gitignore=gitignore_file)
|
||||
|
||||
gitignore_rules = parse_gitignore(full_path=str(gitignore))
|
||||
gitignore_rules = parse_gitignore(full_path=str(gitignore_file)) if gitignore_file.exists() else None
|
||||
markdown += await _write_files(repo_path=repo_path, gitignore_rules=gitignore_rules)
|
||||
|
||||
if output:
|
||||
await awrite(filename=str(output), data=markdown, encoding="utf-8")
|
||||
output_file = Path(output).resolve()
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
await awrite(filename=str(output_file), data=markdown, encoding="utf-8")
|
||||
logger.info(f"save: {output_file}")
|
||||
return markdown
|
||||
|
||||
|
||||
async def _write_dir_tree(repo_path: Path, gitignore: Path) -> str:
|
||||
try:
|
||||
content = tree(repo_path, gitignore, run_command=True)
|
||||
content = await tree(repo_path, gitignore, run_command=True)
|
||||
except Exception as e:
|
||||
logger.info(f"{e}, using safe mode.")
|
||||
content = tree(repo_path, gitignore, run_command=False)
|
||||
content = await tree(repo_path, gitignore, run_command=False)
|
||||
|
||||
doc = f"## Directory Tree\n```text\n{content}\n```\n---\n\n"
|
||||
return doc
|
||||
|
||||
|
||||
async def _write_files(repo_path, gitignore_rules) -> str:
|
||||
async def _write_files(repo_path, gitignore_rules=None) -> str:
|
||||
filenames = list_files(repo_path)
|
||||
markdown = ""
|
||||
pattern = r"^\..*" # Hidden folders/files
|
||||
for filename in filenames:
|
||||
if gitignore_rules(str(filename)):
|
||||
if gitignore_rules and gitignore_rules(str(filename)):
|
||||
continue
|
||||
ignore = False
|
||||
for i in filename.parts:
|
||||
if re.match(pattern, i):
|
||||
ignore = True
|
||||
break
|
||||
if ignore:
|
||||
continue
|
||||
markdown += await _write_file(filename=filename, repo_path=repo_path)
|
||||
return markdown
|
||||
|
||||
|
||||
async def _write_file(filename: Path, repo_path: Path) -> str:
|
||||
relative_path = filename.relative_to(repo_path)
|
||||
markdown = f"## {relative_path}\n"
|
||||
|
||||
mime_type, _ = mimetypes.guess_type(filename.name)
|
||||
if "text/" not in mime_type:
|
||||
is_text, mime_type = await _is_text_file(filename)
|
||||
if not is_text:
|
||||
logger.info(f"Ignore content: {filename}")
|
||||
markdown += "<binary file>\n---\n\n"
|
||||
return ""
|
||||
|
||||
try:
|
||||
relative_path = filename.relative_to(repo_path)
|
||||
markdown = f"## {relative_path}\n"
|
||||
content = await aread(filename, encoding="utf-8")
|
||||
content = content.replace("```", "\\`\\`\\`").replace("---", "\\-\\-\\-")
|
||||
code_block_type = get_markdown_codeblock_type(filename.name)
|
||||
markdown += f"```{code_block_type}\n{content}\n```\n---\n\n"
|
||||
return markdown
|
||||
content = await aread(filename, encoding="utf-8")
|
||||
content = content.replace("```", "\\`\\`\\`").replace("---", "\\-\\-\\-")
|
||||
code_block_type = get_markdown_codeblock_type(filename.name)
|
||||
markdown += f"```{code_block_type}\n{content}\n```\n---\n\n"
|
||||
return markdown
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return ""
|
||||
|
||||
|
||||
async def _is_text_file(filename: Path) -> Tuple[bool, str]:
    """Decide whether a file should be treated as text.

    Args:
        filename (Path): The file to inspect.

    Returns:
        Tuple[bool, str]: `(is_text, mime_type)`, where `mime_type` comes from
            `get_mime_type(filename, force_read=True)`.
    """
    # Non-"text/" MIME types that are nevertheless treated as text.
    textual_types = {
        "application/json",
        "application/vnd.chipnuts.karaoke-mmd",
        "application/javascript",
        "application/xml",
        "application/x-sh",
        "application/sql",
    }
    # Types already known to be non-text; anything outside this set is logged when rejected.
    known_non_text_types = {
        "application/zlib",
        "application/octet-stream",
        "image/svg+xml",
        "application/pdf",
        "application/msword",
        "application/vnd.ms-excel",
        "audio/x-wav",
        "application/x-git",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/zip",
        "image/jpeg",
        "audio/mpeg",
        "video/mp2t",
        "inode/x-empty",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "image/png",
        "image/vnd.microsoft.icon",
        "video/mp4",
    }
    detected = await get_mime_type(filename, force_read=True)
    if "text/" in detected or detected in textual_types:
        return True, detected

    if detected not in known_non_text_types:
        logger.info(detected)
    return False, detected
|
||||
|
|
|
|||
|
|
@ -27,14 +27,15 @@
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from gitignore_parser import parse_gitignore
|
||||
|
||||
from metagpt.tools.libs.shell import shell_execute
|
||||
|
||||
def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = False) -> str:
|
||||
|
||||
async def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = False) -> str:
|
||||
"""
|
||||
Recursively traverses the directory structure and prints it out in a tree-like format.
|
||||
|
||||
|
|
@ -80,7 +81,7 @@ def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = Fal
|
|||
"""
|
||||
root = Path(root).resolve()
|
||||
if run_command:
|
||||
return _execute_tree(root, gitignore)
|
||||
return await _execute_tree(root, gitignore)
|
||||
|
||||
git_ignore_rules = parse_gitignore(gitignore) if gitignore else None
|
||||
dir_ = {root.name: _list_children(root=root, git_ignore_rules=git_ignore_rules)}
|
||||
|
|
@ -129,12 +130,7 @@ def _add_line(rows: List[str]) -> List[str]:
|
|||
return rows
|
||||
|
||||
|
||||
def _execute_tree(root: Path, gitignore: str | Path) -> str:
|
||||
async def _execute_tree(root: Path, gitignore: str | Path) -> str:
|
||||
args = ["--gitfile", str(gitignore)] if gitignore else []
|
||||
try:
|
||||
result = subprocess.run(["tree"] + args + [str(root)], capture_output=True, text=True, check=True)
|
||||
if result.returncode != 0:
|
||||
raise ValueError(f"tree exits with code {result.returncode}")
|
||||
return result.stdout
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise e
|
||||
stdout, _, _ = await shell_execute(["tree"] + args + [str(root)])
|
||||
return stdout
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ gitignore-parser==0.1.9
|
|||
# connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/openapi_v3_hello.py
|
||||
websockets~=11.0
|
||||
networkx~=3.2.1
|
||||
google-generativeai==0.3.2
|
||||
google-generativeai==0.4.1
|
||||
playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py
|
||||
anytree
|
||||
ipywidgets==8.1.1
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -36,6 +36,8 @@ extras_require = {
|
|||
"llama-index-readers-file==0.1.4",
|
||||
"llama-index-retrievers-bm25==0.1.3",
|
||||
"llama-index-vector-stores-faiss==0.1.1",
|
||||
"llama-index-vector-stores-elasticsearch==0.1.6",
|
||||
"llama-index-postprocessor-colbert-rerank==0.1.1",
|
||||
"chromadb==0.4.23",
|
||||
],
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
26
tests/metagpt/actions/test_extract_readme.py
Normal file
26
tests/metagpt/actions/test_extract_readme.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.extract_readme import ExtractReadMe
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_learn_readme(context):
|
||||
action = ExtractReadMe(
|
||||
name="RedBean",
|
||||
i_context=str(Path(__file__).parent.parent.parent.parent),
|
||||
llm=LLM(),
|
||||
context=context,
|
||||
)
|
||||
await action.run()
|
||||
rows = await action.graph_db.select()
|
||||
assert rows
|
||||
assert context.repo.docs.graph_repo.changed_files
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
33
tests/metagpt/actions/test_import_repo.py
Normal file
33
tests/metagpt/actions/test_import_repo.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.import_repo import ImportRepo
|
||||
from metagpt.context import Context
|
||||
from metagpt.utils.common import list_files
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"repo_path",
|
||||
[
|
||||
"https://github.com/spec-first/connexion.git",
|
||||
# "https://github.com/geekan/MetaGPT.git"
|
||||
],
|
||||
)
|
||||
@pytest.mark.skip
|
||||
async def test_import_repo(repo_path):
|
||||
context = Context()
|
||||
action = ImportRepo(repo_path=repo_path, context=context)
|
||||
await action.run()
|
||||
assert context.repo
|
||||
prd = list_files(context.repo.docs.prd.workdir)
|
||||
assert prd
|
||||
design = list_files(context.repo.docs.system_design.workdir)
|
||||
assert design
|
||||
assert prd[0].stem == design[0].stem
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -18,6 +18,7 @@ from metagpt.utils.git_repository import ChangeType
|
|||
from metagpt.utils.graph_repository import SPO
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_rebuild(context, mocker):
|
||||
# Mock
|
||||
|
|
@ -60,7 +61,7 @@ async def test_rebuild(context, mocker):
|
|||
],
|
||||
)
|
||||
def test_get_full_filename(root, pathname, want):
|
||||
res = RebuildSequenceView._get_full_filename(root=root, pathname=pathname)
|
||||
res = RebuildSequenceView.get_full_filename(root=root, pathname=pathname)
|
||||
assert res == want
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import pytest
|
|||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
default_resp_cont,
|
||||
get_part_chat_completion,
|
||||
|
|
@ -22,7 +23,7 @@ name = "GPT"
|
|||
|
||||
class MockBaseLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
pass
|
||||
self.config = config or mock_llm_config
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return get_part_chat_completion(name)
|
||||
|
|
|
|||
60
tests/metagpt/rag/rankers/test_object_ranker.py
Normal file
60
tests/metagpt/rag/rankers/test_object_ranker.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
|
||||
from metagpt.rag.schema import ObjectNode
|
||||
|
||||
|
||||
class Record(BaseModel):
|
||||
score: int
|
||||
|
||||
|
||||
class TestObjectSortPostprocessor:
|
||||
@pytest.fixture
|
||||
def nodes_with_scores(self):
|
||||
nodes = [
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10),
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20),
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=5).model_dump_json()}), score=5),
|
||||
]
|
||||
return nodes
|
||||
|
||||
@pytest.fixture
|
||||
def query_bundle(self, mocker):
|
||||
return mocker.MagicMock(spec=QueryBundle)
|
||||
|
||||
def test_sort_descending(self, nodes_with_scores, query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
assert [node.score for node in sorted_nodes] == [20, 10, 5]
|
||||
|
||||
def test_sort_ascending(self, nodes_with_scores, query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="asc")
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
assert [node.score for node in sorted_nodes] == [5, 10, 20]
|
||||
|
||||
def test_top_n_limit(self, nodes_with_scores, query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2)
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
assert len(sorted_nodes) == 2
|
||||
assert [node.score for node in sorted_nodes] == [20, 10]
|
||||
|
||||
def test_invalid_json_metadata(self, query_bundle):
|
||||
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)]
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes, query_bundle)
|
||||
|
||||
def test_missing_query_bundle(self, nodes_with_scores):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None)
|
||||
|
||||
def test_field_not_found_in_object(self):
|
||||
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)]
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes)
|
||||
|
|
@ -8,6 +8,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.context import Context
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Architect, ProductManager, ProjectManager
|
||||
from metagpt.team import Team
|
||||
|
|
@ -146,5 +147,21 @@ async def test_team_recover_multi_roles_save(mocker, context):
|
|||
await new_company.run(n_round=4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context(context):
|
||||
context.kwargs.set("a", "a")
|
||||
context.cost_manager.max_budget = 9
|
||||
company = Team(context=context)
|
||||
|
||||
save_to = context.repo.workdir / "serial"
|
||||
company.serialize(save_to)
|
||||
|
||||
company.deserialize(save_to, Context())
|
||||
assert company.env.context.repo
|
||||
assert company.env.context.repo.workdir == context.repo.workdir
|
||||
assert company.env.context.kwargs.a == "a"
|
||||
assert company.env.context.cost_manager.max_budget == context.cost_manager.max_budget
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
31
tests/metagpt/tools/libs/test_git.py
Normal file
31
tests/metagpt/tools/libs/test_git.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.tools.libs.git import git_checkout, git_clone
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
||||
class SWEBenchItem(BaseModel):
|
||||
base_commit: str
|
||||
repo: str
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
["url", "commit_id"], [("https://github.com/sqlfluff/sqlfluff.git", "d19de0ecd16d298f9e3bfb91da122734c40c01e5")]
|
||||
)
|
||||
async def test_git(url: str, commit_id: str):
|
||||
repo_dir = await git_clone(url)
|
||||
assert repo_dir
|
||||
|
||||
await git_checkout(repo_dir, commit_id)
|
||||
|
||||
repo = GitRepository(repo_dir, auto_init=False)
|
||||
repo.delete_repository()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
23
tests/metagpt/tools/libs/test_shell.py
Normal file
23
tests/metagpt/tools/libs/test_shell.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
import pytest
|
||||
|
||||
from metagpt.tools.libs.shell import execute
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
["command", "expect_stdout", "expect_stderr"],
|
||||
[
|
||||
(["file", f"{__file__}"], "Python script text executable, ASCII text", ""),
|
||||
(f"file {__file__}", "Python script text executable, ASCII text", ""),
|
||||
],
|
||||
)
|
||||
async def test_shell(command, expect_stdout, expect_stderr):
|
||||
stdout, stderr = await execute(command)
|
||||
assert expect_stdout in stdout
|
||||
assert stderr == expect_stderr
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
65
tests/metagpt/tools/libs/test_software_development.py
Normal file
65
tests/metagpt/tools/libs/test_software_development.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.tools.libs import (
|
||||
fix_bug,
|
||||
git_archive,
|
||||
run_qa_test,
|
||||
write_codes,
|
||||
write_design,
|
||||
write_prd,
|
||||
write_project_plan,
|
||||
)
|
||||
from metagpt.tools.libs.software_development import import_git_repo
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_software_team():
|
||||
path = await write_prd("snake game")
|
||||
assert path
|
||||
|
||||
path = await write_design(path)
|
||||
assert path
|
||||
|
||||
path = await write_project_plan(path)
|
||||
assert path
|
||||
|
||||
path = await write_codes(path)
|
||||
assert path
|
||||
|
||||
path = await run_qa_test(path)
|
||||
assert path
|
||||
|
||||
issue = """
|
||||
pygame 2.0.1 (SDL 2.0.14, Python 3.9.17)
|
||||
Hello from the pygame community. https://www.pygame.org/contribute.html
|
||||
Traceback (most recent call last):
|
||||
File "/Users/ix/github/bak/MetaGPT/workspace/snake_game/snake_game/main.py", line 10, in <module>
|
||||
main()
|
||||
File "/Users/ix/github/bak/MetaGPT/workspace/snake_game/snake_game/main.py", line 7, in main
|
||||
game.start_game()
|
||||
File "/Users/ix/github/bak/MetaGPT/workspace/snake_game/snake_game/game.py", line 81, in start_game
|
||||
x
|
||||
NameError: name 'x' is not defined
|
||||
"""
|
||||
path = await fix_bug(path, issue)
|
||||
assert path
|
||||
|
||||
new_path = await write_prd("snake game with moving enemy", path)
|
||||
assert new_path == path
|
||||
|
||||
git_log = await git_archive(new_path)
|
||||
assert git_log
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_import_repo():
|
||||
url = "https://github.com/spec-first/connexion.git"
|
||||
path = await import_git_repo(url)
|
||||
assert path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -10,7 +10,12 @@ from metagpt.utils.repo_to_markdown import repo_to_markdown
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
["repo_path", "output"],
|
||||
[(Path(__file__).parent.parent, Path(__file__).parent.parent.parent / f"workspace/unittest/{uuid.uuid4().hex}.md")],
|
||||
[
|
||||
(
|
||||
Path(__file__).parent.parent.parent,
|
||||
Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}.md",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_repo_to_markdown(repo_path: Path, output: Path):
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Optional, Union
|
|||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.const import LLM_API_TIMEOUT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
||||
|
|
@ -22,7 +23,7 @@ class MockLLM(OriginalLLM):
|
|||
self.rsp_cache: dict = {}
|
||||
self.rsp_candidates: list[dict] = [] # a test can have multiple calls with the same llm, thus a list
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=LLM_API_TIMEOUT) -> str:
|
||||
"""Overwrite original acompletion_text to cancel retry"""
|
||||
if stream:
|
||||
resp = await self._achat_completion_stream(messages, timeout=timeout)
|
||||
|
|
@ -37,7 +38,7 @@ class MockLLM(OriginalLLM):
|
|||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=3,
|
||||
timeout=LLM_API_TIMEOUT,
|
||||
stream=True,
|
||||
) -> str:
|
||||
if system_msgs:
|
||||
|
|
@ -56,7 +57,7 @@ class MockLLM(OriginalLLM):
|
|||
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
|
||||
return rsp
|
||||
|
||||
async def original_aask_batch(self, msgs: list, timeout=3) -> str:
|
||||
async def original_aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str:
|
||||
"""A copy of metagpt.provider.base_llm.BaseLLM.aask_batch, we can't use super().aask because it will be mocked"""
|
||||
context = []
|
||||
for msg in msgs:
|
||||
|
|
@ -83,7 +84,7 @@ class MockLLM(OriginalLLM):
|
|||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=3,
|
||||
timeout=LLM_API_TIMEOUT,
|
||||
stream=True,
|
||||
) -> str:
|
||||
# used to identify it a message has been called before
|
||||
|
|
@ -98,7 +99,7 @@ class MockLLM(OriginalLLM):
|
|||
rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, images, timeout, stream)
|
||||
return rsp
|
||||
|
||||
async def aask_batch(self, msgs: list, timeout=3) -> str:
|
||||
async def aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str:
|
||||
msg_key = "#MSG_SEP#".join([msg if isinstance(msg, str) else msg.content for msg in msgs])
|
||||
rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
|
||||
return rsp
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue