Merge branch 'feature/import_repo' into featur/intent_detect

This commit is contained in:
莘权 马 2024-03-29 10:53:24 +08:00
commit 2e82a16e74
54 changed files with 1736 additions and 142 deletions

View file

@ -0,0 +1,39 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import fire
from metagpt.roles.di.data_interpreter import DataInterpreter
async def main():
    """Ask a DataInterpreter to deliver a snake game using the software-engineering tools."""
    # Names of the tools the interpreter is allowed to call.
    tool_names = [
        "write_prd",
        "write_design",
        "write_project_plan",
        "write_codes",
        "run_qa_test",
        "fix_bug",
        "git_archive",
    ]
    # The requirement text followed by the ordered steps the interpreter should carry out.
    requirement = """
This is a software requirement:
```text
write a snake game
```
---
1. Writes a PRD based on software requirements.
2. Writes a design to the project repository, based on the PRD of the project.
3. Writes a project plan to the project repository, based on the design of the project.
4. Writes codes to the project repository, based on the project plan of the project.
5. Run QA test on the project repository.
6. Stage and commit changes for the project repository using Git.
Note: All required dependencies and environments have been fully installed and configured.
"""
    interpreter = DataInterpreter(tools=tool_names)
    await interpreter.run(requirement)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -11,9 +11,13 @@ from metagpt.rag.schema import (
BM25RetrieverConfig,
ChromaIndexConfig,
ChromaRetrieverConfig,
ElasticsearchIndexConfig,
ElasticsearchRetrieverConfig,
ElasticsearchStoreConfig,
FAISSRetrieverConfig,
LLMRankerConfig,
)
from metagpt.utils.exceptions import handle_exception
DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
QUESTION = "What are key qualities to be a good writer?"
@ -39,12 +43,22 @@ class Player(BaseModel):
class RAGExample:
"""Show how to use RAG."""
def __init__(self, engine: SimpleEngine = None):
    """Create the example, optionally around an already-built engine.

    Args:
        engine: Engine to use. When omitted, a default engine is built the first
            time the `engine` property is read, so constructing the example does
            not build an index.
    """
    self._engine = engine

@property
def engine(self):
    """Return the engine, building the default one on first access.

    The default engine is created from `DOC_PATH` with FAISS and BM25 retriever
    configs and an LLM ranker config.
    """
    # Compare against None explicitly so an injected engine is never replaced
    # merely because it evaluates as falsy.
    if self._engine is None:
        self._engine = SimpleEngine.from_docs(
            input_files=[DOC_PATH],
            retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
            ranker_configs=[LLMRankerConfig()],
        )
    return self._engine

@engine.setter
def engine(self, value: SimpleEngine):
    """Replace the engine used by the examples."""
    self._engine = value
async def run_pipeline(self, question=QUESTION, print_title=True):
"""This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like:
@ -97,6 +111,7 @@ class RAGExample:
self.engine.add_docs([travel_filepath])
await self.run_pipeline(question=travel_question, print_title=False)
@handle_exception
async def add_objects(self, print_title=True):
"""This example show how to add objects.
@ -154,20 +169,41 @@ class RAGExample:
"""
self._print_title("Init And Query ChromaDB")
# save index
# 1. save index
output_dir = DATA_PATH / "rag"
SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)],
)
# load index
engine = SimpleEngine.from_index(
index_config=ChromaIndexConfig(persist_path=output_dir),
# 2. load index
engine = SimpleEngine.from_index(index_config=ChromaIndexConfig(persist_path=output_dir))
# 3. query
answer = await engine.aquery(TRAVEL_QUESTION)
self._print_query_result(answer)
@handle_exception
async def init_and_query_es(self):
"""This example show how to use es. how to save and load index. will print something like:
Query Result:
Bob likes traveling.
"""
self._print_title("Init And Query Elasticsearch")
# 1. create es index and save docs
store_config = ElasticsearchStoreConfig(index_name="travel", es_url="http://127.0.0.1:9200")
engine = SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
retriever_configs=[ElasticsearchRetrieverConfig(store_config=store_config)],
)
# query
answer = engine.query(TRAVEL_QUESTION)
# 2. load index
engine = SimpleEngine.from_index(index_config=ElasticsearchIndexConfig(store_config=store_config))
# 3. query
answer = await engine.aquery(TRAVEL_QUESTION)
self._print_query_result(answer)
@staticmethod
@ -205,6 +241,7 @@ async def main():
await e.add_objects()
await e.init_objects()
await e.init_and_query_chromadb()
await e.init_and_query_es()
if __name__ == "__main__":

View file

@ -104,3 +104,8 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
if self.node:
return await self._run_action_node(*args, **kwargs)
raise NotImplementedError("The run method should be implemented in a subclass.")
def override_context(self):
    """Use `context` as this action's `private_context` when no private one is set.

    An existing (truthy) `private_context` is left untouched.
    """
    if self.private_context:
        return
    self.private_context = self.context

View file

@ -340,10 +340,7 @@ class ActionNode:
def tagging(self, text, schema, tag="") -> str:
    """Wrap `text` between `[tag]` and `[/tag]` marker lines.

    Args:
        text: The content to wrap.
        schema: Output schema name. The wrapping is the same for every schema;
            the parameter is kept so existing callers keep working.
        tag: Marker name. When empty, `text` is returned unchanged.

    Returns:
        The tagged text, or `text` itself when no tag is given.
    """
    if not tag:
        return text
    return f"[{tag}]\n{text}\n[/{tag}]"
def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str:
nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude)
@ -375,7 +372,7 @@ class ActionNode:
schema="markdown": 编译context, example(markdown), instruction(markdown), constraint, action
"""
if schema == "raw":
return context + "\n\n## Actions\n" + LANGUAGE_CONSTRAINT + "\n" + self.instruction
return f"{context}\n\n## Actions\n{LANGUAGE_CONSTRAINT}\n{self.instruction}"
### 直接使用 pydantic BaseModel 生成 instruction 与 example仅限 JSON
# child_class = self._create_children_class()

View file

@ -0,0 +1,123 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Module Description: This script defines the ExtractReadMe class, which is an action to extract the summary,
installation, configuration and usage information from the contents of a README.md file.
Author: mashenquan
Date: 2024-3-20
"""
from pathlib import Path
from typing import Optional
from pydantic import Field
from metagpt.actions import Action
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.schema import Message
from metagpt.utils.common import aread
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class ExtractReadMe(Action):
    """
    An action to extract summary, installation, configuration, usages from the contents of a README.md file.

    Attributes:
        graph_db (Optional[GraphRepository]): A graph database repository.
        install_to_path (Optional[str]): The path where the repository to install to.
    """

    graph_db: Optional[GraphRepository] = None
    install_to_path: Optional[str] = Field(default="/TO/PATH")
    # Cached README text; None means "not looked up yet", "" means "no README found".
    _readme: Optional[str] = None
    # Path of the README the text was read from; stays None when no README exists.
    _filename: Optional[str] = None

    async def run(self, with_messages=None, **kwargs):
        """
        Implementation of `Action`'s `run` method.

        Loads the graph repository of the current git repo, stores the README's summary, installation,
        configuration and usage extractions in it (keyed by the README's path), and saves it.

        Args:
            with_messages (Optional[Type]): An optional argument specifying messages to react to.

        Returns:
            Message: An empty message caused by this action.
        """
        graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
        self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
        await self._get()
        # Without a README there is no subject to attach results to and nothing to send to the LLM,
        # so the extraction steps are skipped instead of inserting triples whose subject is None.
        if self._filename is not None:
            steps = (
                (GraphKeyword.HAS_SUMMARY, self._summarize),
                (GraphKeyword.HAS_INSTALL, self._extract_install),
                (GraphKeyword.HAS_CONFIG, self._extract_configuration),
                (GraphKeyword.HAS_USAGE, self._extract_usage),
            )
            for predicate, extract in steps:
                extracted = await extract()
                await self.graph_db.insert(subject=self._filename, predicate=predicate, object_=extracted)
        await self.graph_db.save()
        return Message(content="", cause_by=self)

    async def _summarize(self) -> str:
        """Ask the LLM what the repository is about, based on the README text."""
        readme = await self._get()
        summary = await self.llm.aask(
            readme,
            system_msgs=[
                "You are a tool can summarize git repository README.md file.",
                "Return the summary about what is the repository.",
            ],
            stream=False,
        )
        return summary

    async def _extract_install(self) -> str:
        """Ask the LLM for a bash block that clones and installs the repository into `install_to_path`."""
        # Use the returned text rather than `self._readme`, which is not guaranteed to be a string.
        readme = await self._get()
        install = await self.llm.aask(
            readme,
            system_msgs=[
                "You are a tool can install git repository according to README.md file.",
                "Return a bash code block of markdown including:\n"
                f"1. git clone the repository to the directory `{self.install_to_path}`;\n"
                f"2. cd `{self.install_to_path}`;\n"
                "3. install the repository.",
            ],
            stream=False,
        )
        return install

    async def _extract_configuration(self) -> str:
        """Ask the LLM for a bash block that configures the repository (an empty block if none is needed)."""
        readme = await self._get()
        configuration = await self.llm.aask(
            readme,
            system_msgs=[
                "You are a tool can configure git repository according to README.md file.",
                "Return a bash code block of markdown object to configure the repository if necessary, otherwise return"
                " a empty bash code block of markdown object",
            ],
            stream=False,
        )
        return configuration

    async def _extract_usage(self) -> str:
        """Ask the LLM for markdown code blocks demonstrating how the repository is used."""
        readme = await self._get()
        usage = await self.llm.aask(
            readme,
            system_msgs=[
                "You are a tool can summarize all usages of git repository according to README.md file.",
                "Return a list of code block of markdown objects to demonstrates the usage of the repository.",
            ],
            stream=False,
        )
        return usage

    async def _get(self) -> str:
        """
        Return the README text found directly under `i_context`, reading it at most once.

        A file counts as the README when its stem is "README", whatever its suffix.

        Returns:
            str: The README content, or an empty string when the directory has no README file.
        """
        if self._readme is not None:
            return self._readme
        root = Path(self.i_context).resolve()
        filename = next((p for p in root.iterdir() if p.is_file() and p.stem == "README"), None)
        if filename is None:
            # Remember the miss so later calls neither rescan the directory nor see a None text.
            self._readme = ""
            return self._readme
        self._readme = await aread(filename=filename, encoding="utf-8")
        self._filename = str(filename)
        return self._readme

View file

@ -0,0 +1,226 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This script defines an action to import a Git repository into the MetaGPT project format, enabling incremental
appending of requirements.
The MetaGPT project format encompasses a structured representation of project data compatible with MetaGPT's
capabilities, facilitating the integration of Git repositories into MetaGPT workflows while allowing for the gradual
addition of requirements.
"""
import json
import re
from pathlib import Path
from typing import List, Optional
from pydantic import BaseModel
from metagpt.actions import Action
from metagpt.actions.extract_readme import ExtractReadMe
from metagpt.actions.rebuild_class_view import RebuildClassView
from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.tools.libs.git import git_clone
from metagpt.utils.common import (
aread,
awrite,
list_files,
parse_json_code_block,
split_namespace,
)
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
from metagpt.utils.project_repo import ProjectRepo
class ImportRepo(Action):
    """
    An action to import a Git repository into a graph database and create related artifacts.

    Attributes:
        repo_path (str): The URL of the Git repository to import. After the clone it holds the local path.
        graph_db (Optional[GraphRepository]): The output graph database of the Git repository.
        rid (str): The output requirement ID.
    """

    repo_path: str  # input, git repo url.
    graph_db: Optional[GraphRepository] = None  # output. graph db of the git repository
    rid: str = ""  # output, requirement ID.

    async def run(self, with_messages: List[Message] = None, **kwargs) -> Message:
        """
        Runs the import process for the Git repository.

        Clones the repository, writes the PRD and the system design documents, then archives the
        changes with the comment "Import".

        Args:
            with_messages (List[Message], optional): Additional messages to include.
            **kwargs: Additional keyword arguments.

        Returns:
            Message: A message indicating the completion of the import process.
        """
        await self._create_repo()
        await self._create_prd()
        await self._create_system_design()
        self.context.git_repo.archive(comments="Import")
        # Return the message the signature promises instead of falling through with None.
        return Message(content="", cause_by=self)

    async def _create_repo(self):
        """Clone the repository into the workspace and point the context at the local copy."""
        path = await git_clone(url=self.repo_path, output_dir=self.config.workspace.path)
        self.repo_path = str(path)
        self.config.project_path = path
        self.context.git_repo = GitRepository(local_path=path, auto_init=True)
        self.context.repo = ProjectRepo(self.context.git_repo)
        self.context.src_workspace = await self._guess_src_workspace()
        # Record the source directory, relative to the repository root, in `.src_workspace`.
        await awrite(
            filename=self.context.repo.workdir / ".src_workspace",
            data=str(self.context.src_workspace.relative_to(self.context.repo.workdir)),
        )

    async def _create_prd(self):
        """Extract the README into the graph database and save a PRD built from its summary."""
        action = ExtractReadMe(i_context=str(self.context.repo.workdir), context=self.context)
        await action.run()
        graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
        self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
        rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SUMMARY)
        prd = {"Project Name": self.context.repo.workdir.name}
        for r in rows:
            if Path(r.subject).stem == "README":
                prd["Original Requirements"] = r.object_
                break
        self.rid = FileRepository.new_filename()
        await self.repo.docs.prd.save(filename=self.rid + ".json", content=json.dumps(prd))

    async def _create_system_design(self):
        """Rebuild the class and sequence views of the source tree and save the system design."""
        action = RebuildClassView(
            name="ReverseEngineering", i_context=str(self.context.src_workspace), context=self.context
        )
        await action.run()
        rows = await action.graph_db.select(predicate="hasMermaidClassDiagramFile")
        # The file name is only logged, so an empty result must not abort the import.
        if rows:
            logger.info(f"class view:{rows[0].object_}")
        else:
            logger.warning("class view: no mermaid class diagram file found in the graph database.")

        # Collect files containing a `__main__` section that live under the source workspace.
        rows = await action.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
        tag = "__name__:__main__"
        entries = []
        src_workspace = self.context.src_workspace.relative_to(self.context.repo.workdir)
        for r in rows:
            if tag in r.subject:
                path = split_namespace(r.subject)[0]
            elif tag in r.object_:
                path = split_namespace(r.object_)[0]
            else:
                continue
            if Path(path).is_relative_to(src_workspace):
                entries.append(Path(path))
        main_entry = await self._guess_main_entry(entries)
        full_path = RebuildSequenceView.get_full_filename(self.context.repo.workdir, main_entry)
        action = RebuildSequenceView(context=self.context, i_context=str(full_path))
        try:
            await action.run()
        except Exception as e:
            # Best effort: fall back to whatever sequence diagram is already on disk.
            logger.warning(f"{e}, use the last successful version.")
        files = list_files(self.context.repo.resources.data_api_design.workdir)
        # Select the sequence-diagram file whose name contains the sanitized entry path.
        name = re.sub(r"[^a-zA-Z0-9]", "_", str(main_entry))
        postfix = str(Path(name).with_suffix(".sequence_diagram.mmd"))
        sequence_files = [i for i in files if postfix in str(i)]
        content = await aread(filename=sequence_files[0])
        await self.context.repo.resources.data_api_design.save(
            filename=self.repo.workdir.stem + ".sequence_diagram.mmd", content=content
        )
        await self._save_system_design()

    async def _save_system_design(self):
        """Save the class view, sequence view and source file list as the system design document."""
        class_view = await self.context.repo.resources.data_api_design.get(
            filename=self.repo.workdir.stem + ".class_diagram.mmd"
        )
        sequence_view = await self.context.repo.resources.data_api_design.get(
            filename=self.repo.workdir.stem + ".sequence_diagram.mmd"
        )
        file_list = self.context.git_repo.get_files(relative_path=".", root_relative_path=self.context.src_workspace)
        data = {
            "Data structures and interfaces": class_view.content,
            "Program call flow": sequence_view.content,
            "File list": [str(i) for i in file_list],
        }
        await self.context.repo.docs.system_design.save(filename=self.rid + ".json", content=json.dumps(data))

    async def _guess_src_workspace(self) -> Path:
        """
        Pick the directory that holds the project's source code.

        Candidates are the outermost directories containing an `__init__.py`. A single candidate is
        returned directly; otherwise the LLM is asked to choose among them.

        Returns:
            Path: The chosen source directory.
        """
        files = list_files(self.context.repo.workdir)
        dirs = [i.parent for i in files if i.name == "__init__.py"]
        # Keep only outermost package directories: no kept path lies inside another kept path.
        distinct = set()
        for i in dirs:
            if any(i.is_relative_to(j) for j in distinct):
                continue
            distinct = {j for j in distinct if not j.is_relative_to(i)}
            distinct.add(i)
        if len(distinct) == 1:
            return next(iter(distinct))
        # Sort the candidates so the prompt does not depend on set iteration order.
        prompt = "\n".join([f"- {str(i)}" for i in sorted(distinct)])
        rsp = await self.llm.aask(
            prompt,
            system_msgs=[
                "You are a tool to choose the source code path from a list of paths based on the directory name.",
                "You should identify the source code path among paths such as unit test path, examples path, etc.",
                "Return a markdown JSON object containing:\n"
                '- a "src" field containing the source code path;\n'
                '- a "reason" field containing explaining why other paths is not the source code path\n',
            ],
        )
        logger.debug(rsp)
        json_blocks = parse_json_code_block(rsp)

        class Data(BaseModel):
            src: str
            reason: str

        data = Data.model_validate_json(json_blocks[0])
        logger.info(f"src_workspace: {data.src}")
        return Path(data.src)

    async def _guess_main_entry(self, entries: List[Path]) -> Path:
        """
        Pick the main entry file among `entries`.

        A single entry is returned directly; otherwise the LLM chooses using the README usage
        stored in the graph database.

        Args:
            entries (List[Path]): Candidate source files containing a `__main__` section.

        Returns:
            Path: The chosen entry file.
        """
        if len(entries) == 1:
            return entries[0]

        file_list = "## File List\n"
        file_list += "\n".join([f"- {i}" for i in entries])

        rows = await self.graph_db.select(predicate=GraphKeyword.HAS_USAGE)
        usage = "## Usage\n"
        for r in rows:
            if Path(r.subject).stem == "README":
                usage += r.object_

        prompt = file_list + "\n---\n" + usage
        rsp = await self.llm.aask(
            prompt,
            system_msgs=[
                'You are a tool to choose the source file path from "File List" which is used in "Usage".',
                'You choose the source file path based on the name of file and the class name and package name used in "Usage".',
                "Return a markdown JSON object containing:\n"
                '- a "filename" field containing the chosen source file path from "File List" which is used in "Usage";\n'
                '- a "reason" field explaining why.',
            ],
            stream=False,
        )
        logger.debug(rsp)
        json_blocks = parse_json_code_block(rsp)

        class Data(BaseModel):
            filename: str
            reason: str

        data = Data.model_validate_json(json_blocks[0])
        logger.info(f"main: {data.filename}")
        return Path(data.filename)

View file

@ -14,8 +14,6 @@ from typing import Optional
from metagpt.actions import Action, ActionOutput
from metagpt.const import REQUIREMENT_FILENAME
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
class PrepareDocuments(Action):
@ -38,8 +36,7 @@ class PrepareDocuments(Action):
if path.exists() and not self.config.inc:
shutil.rmtree(path)
self.config.project_path = path
self.context.git_repo = GitRepository(local_path=path, auto_init=True)
self.context.repo = ProjectRepo(self.context.git_repo)
self.context.set_repo_dir(path)
async def run(self, with_messages, **kwargs):
"""Create and initialize the workspace folder, initialize the Git environment."""

View file

@ -244,15 +244,6 @@ class RebuildSequenceView(Action):
class_view = await self._get_uml_class_view(ns_class_name)
source_code = await self._get_source_code(ns_class_name)
# prompt_blocks = [
# "## Instruction\n"
# "You are a python code to UML 2.0 Use Case translator.\n"
# 'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".\n'
# "The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not "
# 'conflict with the information in "Mermaid Class Views".\n'
# 'The section under `if __name__ == "__main__":` of "Source Code" contains information about external '
# "system interactions with the internal system.\n"
# ]
prompt_blocks = []
block = "## Participants\n"
for p in participants:
@ -340,6 +331,7 @@ class RebuildSequenceView(Action):
system_msgs=[
"You are a Mermaid Sequence Diagram translator in function detail.",
"Translate the markdown text to a Mermaid Sequence Diagram.",
"Response must be concise.",
"Return a markdown mermaid code block.",
],
stream=False,
@ -440,7 +432,7 @@ class RebuildSequenceView(Action):
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_PAGE_INFO)
filename = split_namespace(ns_class_name=ns_class_name)[0]
if not rows:
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
src_filename = RebuildSequenceView.get_full_filename(root=self.i_context, pathname=filename)
if not src_filename:
return ""
return await aread(filename=src_filename, encoding="utf-8")
@ -450,7 +442,7 @@ class RebuildSequenceView(Action):
)
@staticmethod
def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
def get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
"""
Convert package name to the full path of the module.
@ -466,7 +458,7 @@ class RebuildSequenceView(Action):
"metagpt/management/skill_manager.py", then the returned value will be
"/User/xxx/github/MetaGPT/metagpt/management/skill_manager.py"
"""
if re.match(r"^/.+", pathname):
if re.match(r"^/.+", str(pathname)):
return pathname
files = list_files(root=root)
postfix = "/" + str(pathname)

View file

@ -128,6 +128,9 @@ CODE_PLAN_AND_CHANGE_CONTEXT = """
## User New Requirements
{requirement}
## Issue
{issue}
## PRD
{prd}
@ -211,7 +214,8 @@ class WriteCodePlanAndChange(Action):
design_doc = await self.repo.docs.system_design.get(filename=self.i_context.design_filename)
task_doc = await self.repo.docs.task.get(filename=self.i_context.task_filename)
context = CODE_PLAN_AND_CHANGE_CONTEXT.format(
requirement=self.i_context.requirement,
requirement=f"```text\n{self.i_context.requirement}\n```",
issue=f"```text\n{self.i_context.issue}\n```",
prd=prd_doc.content,
design=design_doc.content,
task=task_doc.content,

View file

@ -133,10 +133,10 @@ REQUIREMENT_ANALYSIS = ActionNode(
# Node that asks for the requirement analysis to be refined for incremental development.
# The instruction and the example both spell out that the answer is a list of strings,
# matching `expected_type`.
REFINED_REQUIREMENT_ANALYSIS = ActionNode(
    key="Refined Requirement Analysis",
    expected_type=List[str],
    instruction="Review and refine the existing requirement analysis into a string list to align with the evolving needs of the project "
    "due to incremental development. Ensure the analysis comprehensively covers the new features and enhancements "
    "required for the refined project scope.",
    example=["Require add ...", "Require modify ..."],
)
REQUIREMENT_POOL = ActionNode(

View file

@ -47,7 +47,7 @@ class Config(CLIParams, YamlModel):
# Key Parameters
llm: LLMConfig
# Global Proxy. Will be used if llm.proxy is not set
# Global Proxy. Not used by LLM, but by other tools such as browsers.
proxy: str = ""
# Tool Parameters

View file

@ -5,9 +5,11 @@
@Author : alexanderwu
@File : context.py
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Optional
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict
@ -78,11 +80,10 @@ class Context(BaseModel):
# env.update({k: v for k, v in i.items() if isinstance(v, str)})
return env
# def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
# """Use a LLM instance"""
# self._llm_config = self.config.get_llm_config(name, provider)
# self._llm = None
# return self._llm
def set_repo_dir(self, path: str | Path):
    """Point this context at the project directory `path`.

    Creates a `GitRepository` for the directory (with `auto_init=True`) and a
    `ProjectRepo` on top of it, storing them as `git_repo` and `repo`.

    Args:
        path: Location of the project directory.
    """
    self.git_repo = GitRepository(local_path=Path(path), auto_init=True)
    self.repo = ProjectRepo(self.git_repo)
def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
"""Return a CostManager instance"""
@ -108,3 +109,38 @@ class Context(BaseModel):
if llm.cost_manager is None:
llm.cost_manager = self._select_costmanager(llm_config)
return llm
def serialize(self) -> Dict[str, Any]:
    """Collect the restorable parts of this context in a dictionary.

    Returns:
        Dict[str, Any]: The repository working directory as a string (empty when
        no repository is set), a copy of the attributes held by `kwargs`, and the
        cost manager dumped as JSON.
    """
    workdir = str(self.repo.workdir) if self.repo else ""
    extra_attributes = dict(self.kwargs.__dict__)
    return {
        "workdir": workdir,
        "kwargs": extra_attributes,
        "cost_manager": self.cost_manager.model_dump_json(),
    }
def deserialize(self, serialized_data: Dict[str, Any]):
    """Deserialize the given serialized data and update the object's attributes accordingly.

    Args:
        serialized_data (Dict[str, Any]): A dictionary containing serialized data
            with the optional keys "workdir", "kwargs" and "cost_manager". Empty
            input leaves the object untouched.
    """
    if not serialized_data:
        return
    workdir = serialized_data.get("workdir")
    if workdir:
        # Reuse the shared helper so repository setup lives in one place.
        self.set_repo_dir(workdir)
        # The source workspace is taken to be the sub-directory named after the
        # repository directory, and is only adopted when it exists on disk.
        src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
        if src_workspace.exists():
            self.src_workspace = src_workspace
    for k, v in (serialized_data.get("kwargs") or {}).items():
        self.kwargs.set(k, v)
    cost_manager = serialized_data.get("cost_manager")
    if cost_manager:
        # `model_validate_json` returns a new model instead of updating the one it
        # is called on, so the result has to be stored to take effect.
        # NOTE(review): objects that already hold the previous cost manager keep
        # that reference - confirm deserialize runs before any LLM is created.
        self.cost_manager = self.cost_manager.model_validate_json(cost_manager)

View file

@ -1,9 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart
import json
import os
from typing import Optional, Union
from dataclasses import asdict
from typing import List, Optional, Union
import google.generativeai as genai
from google.ai import generativelanguage as glm
@ -11,6 +12,7 @@ from google.generativeai.generative_models import GenerativeModel
from google.generativeai.types import content_types
from google.generativeai.types.generation_types import (
AsyncGenerateContentResponse,
BlockedPromptException,
GenerateContentResponse,
GenerationConfig,
)
@ -141,7 +143,11 @@ class GeminiLLM(BaseLLM):
)
collected_content = []
async for chunk in resp:
content = chunk.text
try:
content = chunk.text
except Exception as e:
logger.warning(f"messages: {messages}\nerrors: {e}\n{BlockedPromptException(str(chunk))}")
raise BlockedPromptException(str(chunk))
log_llm_stream(content)
collected_content.append(content)
log_llm_stream("\n")
@ -150,3 +156,10 @@ class GeminiLLM(BaseLLM):
usage = await self.aget_usage(messages, full_content)
self._update_costs(usage)
return full_content
def list_models(self) -> List:
    """List the models reported by `genai.list_models`.

    Each model is converted to a dict with `dataclasses.asdict`; the resulting
    list is logged as JSON and returned.

    Returns:
        List: One dict per model.
    """
    models = [asdict(model) for model in genai.list_models(page_size=100)]
    logger.info(json.dumps(models))
    return models

View file

@ -17,7 +17,7 @@ class HumanProvider(BaseLLM):
"""
def __init__(self, config: LLMConfig):
pass
self.config = config
def ask(self, msg: str, timeout=USE_CONFIG_TIMEOUT) -> str:
logger.info("It's your turn, please type in your response. You may also refer to the context below")

View file

@ -1,5 +1,6 @@
"""Engines init"""
from metagpt.rag.engines.simple import SimpleEngine
from metagpt.rag.engines.flare import FLAREEngine
__all__ = ["SimpleEngine"]
__all__ = ["SimpleEngine", "FLAREEngine"]

View file

@ -0,0 +1,9 @@
"""FLARE Engine.
Use llamaindex's FLAREInstructQueryEngine as FLAREEngine, which accepts other engines as parameters.
For example, create a simple engine, and then pass it to FLAREEngine.
"""
from llama_index.core.query_engine import ( # noqa: F401
FLAREInstructQueryEngine as FLAREEngine,
)

View file

@ -130,10 +130,12 @@ class SimpleEngine(RetrieverQueryEngine):
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
ranker_configs: Configuration for rankers.
"""
objs = objs or []
retriever_configs = retriever_configs or []
if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
objs = objs or []
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
index = VectorStoreIndex(
nodes=nodes,

View file

@ -41,7 +41,7 @@ class ConfigBasedFactory(GenericFactory):
if creator:
return creator(key, **kwargs)
raise ValueError(f"Unknown config: {key}")
raise ValueError(f"Unknown config: `{type(key)}`, {key}")
@staticmethod
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:

View file

@ -4,6 +4,8 @@ import chromadb
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
@ -11,6 +13,8 @@ from metagpt.rag.schema import (
BaseIndexConfig,
BM25IndexConfig,
ChromaIndexConfig,
ElasticsearchIndexConfig,
ElasticsearchKeywordIndexConfig,
FAISSIndexConfig,
)
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
@ -22,6 +26,8 @@ class RAGIndexFactory(ConfigBasedFactory):
FAISSIndexConfig: self._create_faiss,
ChromaIndexConfig: self._create_chroma,
BM25IndexConfig: self._create_bm25,
ElasticsearchIndexConfig: self._create_es,
ElasticsearchKeywordIndexConfig: self._create_es,
}
super().__init__(creators)
@ -30,31 +36,44 @@ class RAGIndexFactory(ConfigBasedFactory):
return super().get_instance(config, **kwargs)
def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex:
    """Load an index whose FAISS vector store is persisted under `config.persist_path`."""
    vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
    storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)
    return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)

def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
    """Load an index from the storage persisted under `config.persist_path`."""
    storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)
    return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)

def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
    """Build an index on the Chroma collection `config.collection_name` stored at `config.persist_path`."""
    db = chromadb.PersistentClient(str(config.persist_path))
    chroma_collection = db.get_or_create_collection(config.collection_name)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)

def _create_es(self, config: ElasticsearchIndexConfig, **kwargs) -> VectorStoreIndex:
    """Build an index on an Elasticsearch store created from the fields of `config.store_config`."""
    vector_store = ElasticsearchStore(**config.store_config.model_dump())
    return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)

def _index_from_storage(
    self, storage_context: StorageContext, config: BaseIndexConfig, **kwargs
) -> VectorStoreIndex:
    """Load an index from `storage_context` using the embed model from `config` or `kwargs`.

    The caller's storage context is used as given, so a vector store attached by
    the caller is preserved.
    """
    embed_model = self._extract_embed_model(config, **kwargs)
    return load_index_from_storage(storage_context=storage_context, embed_model=embed_model)

def _index_from_vector_store(
    self, vector_store: BasePydanticVectorStore, config: BaseIndexConfig, **kwargs
) -> VectorStoreIndex:
    """Wrap `vector_store` in a `VectorStoreIndex` using the embed model from `config` or `kwargs`."""
    embed_model = self._extract_embed_model(config, **kwargs)
    return VectorStoreIndex.from_vector_store(
        vector_store=vector_store,
        embed_model=embed_model,
    )

def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
    """Look up "embed_model" through `_val_from_config_or_kwargs`."""
    return self._val_from_config_or_kwargs("embed_model", config, **kwargs)

View file

@ -33,7 +33,9 @@ class RAGLLM(CustomLLM):
@property
def metadata(self) -> LLMMetadata:
    """Get LLM metadata.

    Returns:
        LLMMetadata: Built from `context_window`, `num_output` and `model_name`.
    """
    # Report "unknown" when no model name is configured instead of passing an empty value on.
    return LLMMetadata(
        context_window=self.context_window, num_output=self.num_output, model_name=self.model_name or "unknown"
    )
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:

View file

@ -3,9 +3,16 @@
from llama_index.core.llms import LLM
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.postprocessor.colbert_rerank import ColbertRerank
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
from metagpt.rag.schema import (
BaseRankerConfig,
ColbertRerankConfig,
LLMRankerConfig,
ObjectRankerConfig,
)
class RankerFactory(ConfigBasedFactory):
@ -14,6 +21,8 @@ class RankerFactory(ConfigBasedFactory):
def __init__(self):
    """Register the creator method for each supported ranker config class."""
    # Maps a config class to the method that builds its ranker.
    creators = {
        LLMRankerConfig: self._create_llm_ranker,
        ColbertRerankConfig: self._create_colbert_ranker,
        ObjectRankerConfig: self._create_object_ranker,
    }
    super().__init__(creators)
@ -28,6 +37,12 @@ class RankerFactory(ConfigBasedFactory):
config.llm = self._extract_llm(config, **kwargs)
return LLMRerank(**config.model_dump())
def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> ColbertRerank:
    """Build a ``ColbertRerank`` from the dumped fields of ``config``.

    ``kwargs`` is accepted for signature parity with the other creators and is not used.
    """
    return ColbertRerank(**config.model_dump())
def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> ObjectSortPostprocessor:
    """Build an ``ObjectSortPostprocessor`` from the dumped fields of ``config``.

    ``kwargs`` is accepted for signature parity with the other creators and is not used.
    """
    return ObjectSortPostprocessor(**config.model_dump())
def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
return self._val_from_config_or_kwargs("llm", config, **kwargs)

View file

@ -6,18 +6,22 @@ import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
ChromaRetrieverConfig,
ElasticsearchKeywordRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
IndexRetrieverConfig,
)
@ -32,6 +36,8 @@ class RetrieverFactory(ConfigBasedFactory):
FAISSRetrieverConfig: self._create_faiss_retriever,
BM25RetrieverConfig: self._create_bm25_retriever,
ChromaRetrieverConfig: self._create_chroma_retriever,
ElasticsearchRetrieverConfig: self._create_es_retriever,
ElasticsearchKeywordRetrieverConfig: self._create_es_retriever,
}
super().__init__(creators)
@ -53,20 +59,29 @@ class RetrieverFactory(ConfigBasedFactory):
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
    """Build a FAISS retriever on a flat L2 index sized by ``config.dimensions``.

    The created index is stored on ``config.index`` before the config is dumped
    into the retriever's constructor.
    """
    flat_index = faiss.IndexFlatL2(config.dimensions)
    store = FaissVectorStore(faiss_index=flat_index)
    config.index = self._build_index_from_vector_store(config, store, **kwargs)
    retriever_kwargs = config.model_dump()
    return FAISSRetriever(**retriever_kwargs)
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
    """Build a BM25 retriever over the nodes of the index found in ``config`` / ``kwargs``.

    The index is deep-copied, so the retriever gets its own index object.
    """
    config.index = copy.deepcopy(self._extract_index(config, **kwargs))
    # A single return: the duplicated second return statement was unreachable.
    return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
    """Build a Chroma retriever backed by a persistent client at ``config.persist_path``.

    The collection named ``config.collection_name`` is fetched or created, and
    the resulting index is stored on ``config.index``.
    """
    client = chromadb.PersistentClient(path=str(config.persist_path))
    collection = client.get_or_create_collection(config.collection_name)
    store = ChromaVectorStore(chroma_collection=collection)
    config.index = self._build_index_from_vector_store(config, store, **kwargs)
    retriever_kwargs = config.model_dump()
    return ChromaRetriever(**retriever_kwargs)
def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
    """Build an Elasticsearch retriever.

    The store is created from the dumped ``config.store_config`` and the
    resulting index is stored on ``config.index``.
    """
    store_kwargs = config.store_config.model_dump()
    store = ElasticsearchStore(**store_kwargs)
    config.index = self._build_index_from_vector_store(config, store, **kwargs)
    retriever_kwargs = config.model_dump()
    return ElasticsearchRetriever(**retriever_kwargs)
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
return self._val_from_config_or_kwargs("index", config, **kwargs)

View file

@ -0,0 +1,55 @@
"""Object ranker."""
import heapq
import json
from typing import Literal, Optional
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle
from pydantic import Field
from metagpt.rag.schema import ObjectNode
class ObjectSortPostprocessor(BaseNodePostprocessor):
    """Sorted by object's field, desc or asc.

    Assumes nodes is list of ObjectNode with score.
    Each node is read through ``metadata["obj_json"]``, a JSON document that
    must contain ``field_name``.
    """

    field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
    order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")
    top_n: int = 5  # maximum number of nodes returned

    @classmethod
    def class_name(cls) -> str:
        """Return the name of this postprocessor class."""
        return "ObjectSortPostprocessor"

    def _postprocess_nodes(
        self,
        nodes: list[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> list[NodeWithScore]:
        """Return at most ``top_n`` nodes ordered by the object's ``field_name``.

        Raises:
            ValueError: If ``query_bundle`` is missing, or the first node's
                metadata has no valid object JSON containing ``field_name``.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")

        if not nodes:
            return []

        # Only the first node is validated up front; the others are parsed while sorting.
        self._check_metadata(nodes[0].node)

        def sort_key(node: NodeWithScore):
            return json.loads(node.node.metadata["obj_json"])[self.field_name]

        return self._get_sort_func()(self.top_n, nodes, key=sort_key)

    def _check_metadata(self, node: ObjectNode):
        """Validate that ``node`` carries object JSON containing ``field_name``.

        Raises:
            ValueError: If the JSON cannot be loaded or the field is absent.
        """
        try:
            obj_dict = json.loads(node.metadata.get("obj_json"))
        except Exception as e:
            # Chain the cause so the original parsing error stays visible in tracebacks.
            raise ValueError(f"Invalid object json in metadata: {node.metadata}, error: {e}") from e

        if self.field_name not in obj_dict:
            raise ValueError(f"Field '{self.field_name}' not found in object: {obj_dict}")

    def _get_sort_func(self):
        """Pick ``heapq.nlargest`` for "desc" and ``heapq.nsmallest`` otherwise."""
        return heapq.nlargest if self.order == "desc" else heapq.nsmallest

View file

@ -0,0 +1,17 @@
"""Elasticsearch retriever."""
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
class ElasticsearchRetriever(VectorIndexRetriever):
    """Elasticsearch retriever."""

    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
        """Insert ``nodes`` into the underlying index, forwarding ``kwargs``."""
        index = self._index
        index.insert_nodes(nodes, **kwargs)

    def persist(self, persist_dir: str, **kwargs) -> None:
        """Support persist.

        Elasticsearch automatically saves, so there is no need to implement."""
        return None

View file

@ -8,7 +8,7 @@ class FAISSRetriever(VectorIndexRetriever):
"""FAISS retriever."""
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
    """Support add nodes.

    Inserts ``nodes`` into the underlying index, forwarding ``kwargs``.
    """
    self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:

View file

@ -1,11 +1,12 @@
"""RAG schemas."""
from pathlib import Path
from typing import Any, Union
from typing import Any, Literal, Union
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from metagpt.rag.interface import RAGObject
@ -46,6 +47,35 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
collection_name: str = Field(default="metagpt", description="The name of the collection.")
class ElasticsearchStoreConfig(BaseModel):
    """Settings describing how to reach and write to an Elasticsearch index.

    The connection fields are nullable: each defaults to ``None`` and may also
    be passed as ``None`` explicitly.
    """

    index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.")
    es_url: Union[str, None] = Field(default=None, description="Elasticsearch URL.")
    es_cloud_id: Union[str, None] = Field(default=None, description="Elasticsearch cloud ID.")
    es_api_key: Union[str, None] = Field(default=None, description="Elasticsearch API key.")
    es_user: Union[str, None] = Field(default=None, description="Elasticsearch username.")
    es_password: Union[str, None] = Field(default=None, description="Elasticsearch password.")
    batch_size: int = Field(default=200, description="Batch size for bulk indexing.")
    distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.")
class ElasticsearchRetrieverConfig(IndexRetrieverConfig):
    """Config for Elasticsearch-based retrievers. Support both vector and text."""

    # Required settings of the Elasticsearch store.
    store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
    # Query mode handed to the vector store; defaults to the library's DEFAULT mode.
    vector_store_query_mode: VectorStoreQueryMode = Field(
        default=VectorStoreQueryMode.DEFAULT,
        description="default is vector query.",
    )
class ElasticsearchKeywordRetrieverConfig(ElasticsearchRetrieverConfig):
    """Config for Elasticsearch-based retrievers. Support text only."""

    # Private flag set to True for this text-only config.
    _no_embedding: bool = PrivateAttr(default=True)
    # The query mode is pinned to text search by the Literal type.
    vector_store_query_mode: Literal[VectorStoreQueryMode.TEXT_SEARCH] = Field(
        default=VectorStoreQueryMode.TEXT_SEARCH,
        description="text query only.",
    )
class BaseRankerConfig(BaseModel):
"""Common config for rankers.
@ -53,7 +83,6 @@ class BaseRankerConfig(BaseModel):
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
top_n: int = Field(default=5, description="The number of top results to return.")
@ -66,12 +95,24 @@ class LLMRankerConfig(BaseRankerConfig):
)
class ColbertRerankConfig(BaseRankerConfig):
    # Options for the ColBERT-based reranker.
    model: str = Field(
        default="colbert-ir/colbertv2.0",
        description="Colbert model name.",
    )
    device: str = Field(
        default="cpu",
        description="Device to use for sentence transformer.",
    )
    keep_retrieval_score: bool = Field(
        default=False,
        description="Whether to keep the retrieval score in metadata.",
    )
class ObjectRankerConfig(BaseRankerConfig):
    # Options for ranking by a field of the stored object.
    field_name: str = Field(
        ...,
        description="field name of the object, field's value must can be compared.",
    )
    order: Literal["desc", "asc"] = Field(
        default="desc",
        description="the direction of order.",
    )
class BaseIndexConfig(BaseModel):
"""Common config for index.
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
persist_path: Union[str, Path] = Field(description="The directory of saved data.")
@ -97,6 +138,19 @@ class BM25IndexConfig(BaseIndexConfig):
_no_embedding: bool = PrivateAttr(default=True)
class ElasticsearchIndexConfig(VectorIndexConfig):
    """Config for es-based index."""

    # Required settings of the Elasticsearch store.
    store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
    # Given an empty default here, so callers need not supply a path.
    persist_path: Union[str, Path] = ""
class ElasticsearchKeywordIndexConfig(ElasticsearchIndexConfig):
    """Config for es-based index. no embedding."""

    # Private flag set to True for this keyword-only index config.
    _no_embedding: bool = PrivateAttr(default=True)
class ObjectNodeMetadata(BaseModel):
"""Metadata of ObjectNode."""

View file

@ -22,7 +22,7 @@ from __future__ import annotations
import json
from collections import defaultdict
from pathlib import Path
from typing import Set
from typing import List, Optional, Set
from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks
from metagpt.actions.fix_bug import FixBug
@ -30,6 +30,7 @@ from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.actions.write_code_plan_and_change_an import WriteCodePlanAndChange
from metagpt.const import (
BUGFIX_FILENAME,
CODE_PLAN_AND_CHANGE_FILE_REPO,
REQUIREMENT_FILENAME,
SYSTEM_DESIGN_FILE_REPO,
@ -45,7 +46,13 @@ from metagpt.schema import (
Documents,
Message,
)
from metagpt.utils.common import any_to_name, any_to_str, any_to_str_set
from metagpt.utils.common import (
any_to_name,
any_to_str,
any_to_str_set,
get_project_srcs_path,
init_python_folder,
)
IS_PASS_PROMPT = """
{context}
@ -239,7 +246,7 @@ class Engineer(Role):
async def _think(self) -> Action | None:
if not self.src_workspace:
self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
self.src_workspace = get_project_srcs_path(self.project_repo.workdir)
write_plan_and_change_filters = any_to_str_set([WriteTasks, FixBug])
write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode])
summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
@ -248,11 +255,11 @@ class Engineer(Role):
msg = self.rc.news[0]
if self.config.inc and msg.cause_by in write_plan_and_change_filters:
logger.debug(f"TODO WriteCodePlanAndChange:{msg.model_dump_json()}")
await self._new_code_plan_and_change_action()
await self._new_code_plan_and_change_action(cause_by=msg.cause_by)
return self.rc.todo
if msg.cause_by in write_code_filters:
logger.debug(f"TODO WriteCode:{msg.model_dump_json()}")
await self._new_code_actions(bug_fix=msg.cause_by == any_to_str(FixBug))
await self._new_code_actions()
return self.rc.todo
if msg.cause_by in summarize_code_filters and msg.sent_from == any_to_str(self):
logger.debug(f"TODO SummarizeCode:{msg.model_dump_json()}")
@ -260,14 +267,14 @@ class Engineer(Role):
return self.rc.todo
return None
async def _new_coding_context(self, filename, dependency) -> CodingContext:
async def _new_coding_context(self, filename, dependency) -> Optional[CodingContext]:
old_code_doc = await self.project_repo.srcs.get(filename)
if not old_code_doc:
old_code_doc = Document(root_path=str(self.project_repo.src_relative_path), filename=filename, content="")
dependencies = {Path(i) for i in await dependency.get(old_code_doc.root_relative_path)}
task_doc = None
design_doc = None
code_plan_and_change_doc = None
code_plan_and_change_doc = await self._get_any_code_plan_and_change() if await self._is_fixbug() else None
for i in dependencies:
if str(i.parent) == TASK_FILE_REPO:
task_doc = await self.project_repo.docs.task.get(i.name)
@ -276,6 +283,8 @@ class Engineer(Role):
elif str(i.parent) == CODE_PLAN_AND_CHANGE_FILE_REPO:
code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(i.name)
if not task_doc or not design_doc:
if filename == "__init__.py": # `__init__.py` created by `init_python_folder`
return None
logger.error(f'Detected source code "{filename}" from an unknown origin.')
raise ValueError(f'Detected source code "{filename}" from an unknown origin.')
context = CodingContext(
@ -287,14 +296,17 @@ class Engineer(Role):
)
return context
async def _new_coding_doc(self, filename, dependency) -> Optional[Document]:
    """Wrap the coding context of ``filename`` into a ``Document``.

    Returns:
        The document whose content is the JSON dump of the coding context, or
        ``None`` when ``_new_coding_context`` yields no context.
    """
    context = await self._new_coding_context(filename, dependency)
    if not context:
        return None  # `__init__.py` created by `init_python_folder`
    return Document(
        root_path=str(self.project_repo.src_relative_path), filename=filename, content=context.model_dump_json()
    )
async def _new_code_actions(self, bug_fix=False):
async def _new_code_actions(self):
bug_fix = await self._is_fixbug()
# Prepare file repos
changed_src_files = self.project_repo.srcs.all_files if bug_fix else self.project_repo.srcs.changed_files
changed_task_files = self.project_repo.docs.task.changed_files
@ -305,6 +317,7 @@ class Engineer(Role):
task_doc = await self.project_repo.docs.task.get(filename)
code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(filename)
task_list = self._parse_tasks(task_doc)
await self._init_python_folder(task_list)
for task_filename in task_list:
old_code_doc = await self.project_repo.srcs.get(task_filename)
if not old_code_doc:
@ -343,6 +356,8 @@ class Engineer(Role):
if filename in changed_files.docs:
continue
coding_doc = await self._new_coding_doc(filename=filename, dependency=dependency)
if not coding_doc:
continue # `__init__.py` created by `init_python_folder`
changed_files.docs[filename] = coding_doc
self.code_todos.append(WriteCode(i_context=coding_doc, context=self.context, llm=self.llm))
@ -358,6 +373,8 @@ class Engineer(Role):
ctx = CodeSummarizeContext.loads(filenames=list(dependencies))
summarizations[ctx].append(filename)
for ctx, filenames in summarizations.items():
if not ctx.design_filename or not ctx.task_filename:
continue # cause by `__init__.py` which is created by `init_python_folder`
ctx.codes_filenames = filenames
new_summarize = SummarizeCode(i_context=ctx, context=self.context, llm=self.llm)
for i, act in enumerate(self.summarize_todos):
@ -371,15 +388,40 @@ class Engineer(Role):
self.set_todo(self.summarize_todos[0])
self.summarize_todos.pop(0)
async def _new_code_plan_and_change_action(self, cause_by: str):
    """Create a WriteCodePlanAndChange action for subsequent to-do actions.

    Args:
        cause_by: The ``cause_by`` of the triggering message. When it is
            ``FixBug`` the bug-fix document is passed as ``issue``; otherwise
            the requirement document is passed as ``requirement``.
    """
    files = self.project_repo.all_files
    if cause_by != any_to_str(FixBug):
        option_name, doc_filename = "requirement", REQUIREMENT_FILENAME
    else:
        option_name, doc_filename = "issue", BUGFIX_FILENAME
    doc = await self.project_repo.docs.get(doc_filename)
    # Fall back to an empty string when the document is absent instead of
    # failing on `None.content`.
    options = {option_name: doc.content if doc else ""}
    code_plan_and_change_ctx = CodePlanAndChangeContext.loads(files, **options)
    self.rc.todo = WriteCodePlanAndChange(i_context=code_plan_and_change_ctx, context=self.context, llm=self.llm)
@property
def action_description(self) -> str:
    """AgentStore uses this attribute to display to the user what actions the current role should take."""
    description = self.next_todo_action
    return description
async def _init_python_folder(self, task_list: List[str]):
for i in task_list:
filename = Path(i)
if filename.suffix != ".py":
continue
workdir = self.src_workspace / filename.parent
await init_python_folder(workdir)
async def _is_fixbug(self) -> bool:
    """Tell whether the bug-fix document exists and has non-empty content."""
    doc = await self.project_repo.docs.get(BUGFIX_FILENAME)
    if not doc:
        return False
    return bool(doc.content)
async def _get_any_code_plan_and_change(self) -> Optional[Document]:
changed_files = self.project_repo.docs.code_plan_and_change.changed_files
for filename in changed_files.keys():
doc = await self.project_repo.docs.code_plan_and_change.get(filename)
if doc and doc.content:
return doc
return None

View file

@ -21,7 +21,7 @@ from metagpt.const import MESSAGE_ROUTE_TO_NONE
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Document, Message, RunCodeContext, TestingContext
from metagpt.utils.common import any_to_str_set, parse_recipient
from metagpt.utils.common import any_to_str_set, init_python_folder, parse_recipient
class QaEngineer(Role):
@ -141,6 +141,7 @@ class QaEngineer(Role):
)
async def _act(self) -> Message:
await init_python_folder(self.project_repo.tests.workdir)
if self.test_round > self.test_round_allowed:
result_msg = Message(
content=f"Exceeding {self.test_round_allowed} rounds of tests, skip (writing code counts as a round, too)",

View file

@ -677,13 +677,14 @@ class BugFixContext(BaseContext):
class CodePlanAndChangeContext(BaseModel):
requirement: str = ""
issue: str = ""
prd_filename: str = ""
design_filename: str = ""
task_filename: str = ""
@staticmethod
def loads(filenames: List, **kwargs) -> CodePlanAndChangeContext:
ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", ""))
ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", ""), issue=kwargs.get("issue", ""))
for filename in filenames:
filename = Path(filename)
if filename.is_relative_to(PRDS_FILE_REPO):

View file

@ -56,8 +56,10 @@ class Team(BaseModel):
def serialize(self, stg_path: Path = None):
    """Write the team state, including its context, to ``team.json``.

    Args:
        stg_path: Target directory. Defaults to ``SERDESER_PATH / "team"``.
    """
    stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path
    team_info_path = stg_path.joinpath("team.json")
    serialized_data = self.model_dump()
    serialized_data["context"] = self.env.context.serialize()

    # Write once: an earlier write without the context was immediately overwritten.
    write_json_file(team_info_path, serialized_data)
@classmethod
def deserialize(cls, stg_path: Path, context: Context = None) -> "Team":
@ -71,6 +73,7 @@ class Team(BaseModel):
team_info: dict = read_json_file(team_info_path)
ctx = context or Context()
ctx.deserialize(team_info.pop("context", None))
team = Team(**team_info, context=ctx)
return team

View file

@ -12,6 +12,15 @@ from metagpt.tools.libs import (
web_scraping,
email_login,
)
from metagpt.tools.libs.software_development import (
write_prd,
write_design,
write_project_plan,
write_codes,
run_qa_test,
fix_bug,
git_archive,
)
_ = (
data_preprocess,
@ -20,4 +29,11 @@ _ = (
gpt_v_generator,
web_scraping,
email_login,
write_prd,
write_design,
write_project_plan,
write_codes,
run_qa_test,
fix_bug,
git_archive,
) # Avoid pre-commit error

65
metagpt/tools/libs/git.py Normal file
View file

@ -0,0 +1,65 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import annotations
from pathlib import Path
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.git_repository import GitRepository
@register_tool(tags=["git"])
async def git_clone(url: str, output_dir: str | Path = None) -> Path:
    """
    Clones a Git repository from the given URL.

    Args:
        url (str): The URL of the Git repository to clone.
        output_dir (str or Path, optional): The directory where the repository will be cloned.
            When omitted, the destination is chosen by `GitRepository.clone_from`.

    Returns:
        Path: The path to the cloned repository.

    Example:
        >>> # git clone to /TO/PATH
        >>> url = 'https://github.com/geekan/MetaGPT.git'
        >>> output_dir = "/TO/PATH"
        >>> repo_dir = await git_clone(url=url, output_dir=output_dir)
        >>> print(repo_dir)
        /TO/PATH/MetaGPT

        >>> # git clone to default directory.
        >>> url = 'https://github.com/geekan/MetaGPT.git'
        >>> repo_dir = await git_clone(url)
        >>> print(repo_dir)
        /WORK_SPACE/downloads/MetaGPT
    """
    cloned_repo = await GitRepository.clone_from(url, output_dir)
    workdir = cloned_repo.workdir
    return workdir
async def git_checkout(repo_dir: str | Path, commit_id: str):
    """
    Checks out a specific commit in a Git repository.

    Args:
        repo_dir (str or Path): The directory containing the Git repository.
        commit_id (str): The ID of the commit to check out.

    Raises:
        ValueError: If the specified Git root is invalid.

    Example:
        >>> repo_dir = '/TO/GIT/REPO'
        >>> commit_id = 'main'
        >>> await git_checkout(repo_dir=repo_dir, commit_id=commit_id)
        git checkout main
    """
    repo = GitRepository(local_path=repo_dir, auto_init=False)
    if not repo.is_valid:
        # The exception was previously constructed but never raised, so an
        # invalid root fell through to `checkout`.
        raise ValueError(f"Invalid git root: {repo_dir}")
    await repo.checkout(commit_id)

View file

@ -0,0 +1,63 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import annotations
import subprocess
from pathlib import Path
from typing import Dict, List, Tuple, Union
from metagpt.tools.tool_registry import register_tool
@register_tool(tags=["shell"])
async def shell_execute(
    command: Union[List[str], str], cwd: str | Path = None, env: Dict = None, timeout: int = 600
) -> Tuple[str, str, int]:
    """
    Execute a command and return its standard output, standard error and return code.

    Args:
        command (Union[List[str], str]): The command to execute and its arguments. A list of strings is executed
            directly; a single string is run through the shell.
        cwd (str | Path, optional): The current working directory for the command. Defaults to None.
        env (Dict, optional): Environment variables to set for the command. Defaults to None.
        timeout (int, optional): Timeout for the command execution in seconds. Defaults to 600.

    Returns:
        Tuple[str, str, int]: The UTF-8 decoded standard output, the UTF-8 decoded standard error, and the
            return code of the executed command.

    Raises:
        ValueError: If the command times out. The process is killed and the message contains both the standard
            output and the standard error collected from it.

    Example:
        >>> # command is a list
        >>> stdout, stderr, returncode = await shell_execute(command=["ls", "-l"], cwd="/home/user", env={"PATH": "/usr/bin"})
        >>> print(stdout)
        total 8
        -rw-r--r-- 1 user user    0 Mar 22 10:00 file1.txt
        -rw-r--r-- 1 user user    0 Mar 22 10:00 file2.txt
        ...

        >>> # command is a string of shell script
        >>> stdout, stderr, returncode = await shell_execute(command="ls -l", cwd="/home/user", env={"PATH": "/usr/bin"})
        >>> print(stdout)
        total 8
        -rw-r--r-- 1 user user    0 Mar 22 10:00 file1.txt
        -rw-r--r-- 1 user user    0 Mar 22 10:00 file2.txt
        ...

    References:
        This function uses `subprocess.Popen` and waits for the process with `communicate`.
    """
    working_dir = str(cwd) if cwd else None
    use_shell = isinstance(command, str)
    proc = subprocess.Popen(
        command, cwd=working_dir, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env, shell=use_shell
    )
    try:
        out, err = proc.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        # Kill the process, then collect whatever output it produced.
        proc.kill()
        out, err = proc.communicate()
        raise ValueError(f"{out.decode('utf-8')}\n{err.decode('utf-8')}")
    return out.decode("utf-8"), err.decode("utf-8"), proc.returncode

View file

@ -0,0 +1,301 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import annotations
from pathlib import Path
from typing import Optional
from metagpt.const import BUGFIX_FILENAME, REQUIREMENT_FILENAME
from metagpt.schema import BugFixContext, Message
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import any_to_str
@register_tool(tags=["software development", "ProductManager"])
async def write_prd(idea: str, project_path: Optional[str | Path] = None) -> Path:
    """Writes a PRD based on user requirements.

    Args:
        idea (str): The idea or concept for the PRD.
        project_path (Optional[str|Path], optional): The path to an existing project directory.
            If it's None, a new project path will be created. Defaults to None.

    Returns:
        Path: The path to the PRD files under the project directory

    Example:
        >>> # Create a new project:
        >>> from metagpt.tools.libs.software_development import write_prd
        >>> prd_path = await write_prd("Create a new feature for the application")
        >>> print(prd_path)
        '/path/to/project_path/docs/prd/'

        >>> # Add user requirements to the existing project:
        >>> from metagpt.tools.libs.software_development import write_prd
        >>> project_path = '/path/to/exists_project_path'
        >>> prd_path = await write_prd("Create a new feature for the application", project_path=project_path)
        >>> print(prd_path)
        '/path/to/project_path/docs/prd/'
    """
    from metagpt.actions import UserRequirement
    from metagpt.context import Context
    from metagpt.roles import ProductManager

    ctx = Context()
    if project_path:
        # An explicit project path also turns on the incremental (`inc`) flag.
        ctx.config.project_path = Path(project_path)
        ctx.config.inc = True
    role = ProductManager(context=ctx)
    # The role is run twice: first with the user requirement, then with its own reply.
    msg = await role.run(with_message=Message(content=idea, cause_by=UserRequirement))
    await role.run(with_message=msg)
    return ctx.repo.docs.prd.workdir
@register_tool(tags=["software development", "Architect"])
async def write_design(prd_path: str | Path) -> Path:
    """Writes a design to the project repository, based on the PRD of the project.

    Args:
        prd_path (str|Path): The path to the PRD files under the project directory.

    Returns:
        Path: The path to the system design files under the project directory.

    Example:
        >>> from metagpt.tools.libs.software_development import write_design
        >>> prd_path = '/path/to/project_path/docs/prd' # Returned by `write_prd`
        >>> system_design_path = await write_design(prd_path)
        >>> print(system_design_path)
        '/path/to/project_path/docs/system_design/'
    """
    from metagpt.actions import WritePRD
    from metagpt.context import Context
    from metagpt.roles import Architect

    ctx = Context()
    # The project root is taken as two levels above the PRD directory.
    project_path = Path(prd_path).parent.parent
    ctx.set_repo_dir(project_path)

    role = Architect(context=ctx)
    await role.run(with_message=Message(content="", cause_by=WritePRD))
    return ctx.repo.docs.system_design.workdir
@register_tool(tags=["software development", "Architect"])
async def write_project_plan(system_design_path: str | Path) -> Path:
    """Writes a project plan to the project repository, based on the design of the project.

    Args:
        system_design_path (str|Path): The path to the system design files under the project directory.

    Returns:
        Path: The path to task files under the project directory.

    Example:
        >>> from metagpt.tools.libs.software_development import write_project_plan
        >>> system_design_path = '/path/to/project_path/docs/system_design/' # Returned by `write_design`
        >>> task_path = await write_project_plan(system_design_path)
        >>> print(task_path)
        '/path/to/project_path/docs/task'
    """
    from metagpt.actions import WriteDesign
    from metagpt.context import Context
    from metagpt.roles import ProjectManager

    context = Context()
    # The project root is taken as two levels above the system design directory.
    context.set_repo_dir(Path(system_design_path).parent.parent)

    project_manager = ProjectManager(context=context)
    trigger = Message(content="", cause_by=WriteDesign)
    await project_manager.run(with_message=trigger)
    return context.repo.docs.task.workdir
@register_tool(tags=["software development", "Engineer"])
async def write_codes(task_path: str | Path, inc: bool = False) -> Path:
    """Writes codes to the project repository, based on the project plan of the project.

    Args:
        task_path (str|Path): The path to task files under the project directory.
        inc (bool, optional): Whether to write incremental codes. Defaults to False.

    Returns:
        Path: The path to the source code files under the project directory.

    Example:
        # Write codes to a new project
        >>> from metagpt.tools.libs.software_development import write_codes
        >>> task_path = '/path/to/project_path/docs/task' # Returned by `write_project_plan`
        >>> src_path = await write_codes(task_path)
        >>> print(src_path)
        '/path/to/project_path/src/'

        # Write increment codes to the existing project
        >>> from metagpt.tools.libs.software_development import write_codes
        >>> task_path = '/path/to/project_path/docs/task' # Returned by `write_project_plan`
        >>> src_path = await write_codes(task_path, inc=True)
        >>> print(src_path)
        '/path/to/project_path/src/'
    """
    from metagpt.actions import WriteTasks
    from metagpt.context import Context
    from metagpt.roles import Engineer

    ctx = Context()
    ctx.config.inc = inc
    # The project root is taken as two levels above the task directory.
    project_path = Path(task_path).parent.parent
    ctx.set_repo_dir(project_path)

    role = Engineer(context=ctx)
    msg = Message(content="", cause_by=WriteTasks, send_to=role)
    me = {any_to_str(role), role.name}
    # Keep feeding the role its own output while it is addressed to itself.
    # The `msg` check ends the loop cleanly in case `role.run` returns nothing.
    while msg and me.intersection(msg.send_to):
        msg = await role.run(with_message=msg)
    return ctx.repo.srcs.workdir
@register_tool(tags=["software development", "QaEngineer"])
async def run_qa_test(src_path: str | Path) -> Path:
    """Run QA test on the project repository.

    Args:
        src_path (str|Path): The path to the source code files under the project directory.

    Returns:
        Path: The path to the unit tests under the project directory

    Example:
        >>> from metagpt.tools.libs.software_development import run_qa_test
        >>> src_path = '/path/to/project_path/src/' # Returned by `write_codes`
        >>> test_path = await run_qa_test(src_path)
        >>> print(test_path)
        '/path/to/project_path/tests'
    """
    from metagpt.actions.summarize_code import SummarizeCode
    from metagpt.context import Context
    from metagpt.environment import Environment
    from metagpt.roles import QaEngineer

    context = Context()
    # The project root is taken as the parent of the source directory.
    context.set_repo_dir(Path(src_path).parent)
    # The source workspace is the folder named after the repository, inside it.
    context.src_workspace = context.git_repo.workdir / context.git_repo.workdir.name

    environment = Environment(context=context)
    qa_engineer = QaEngineer(context=context)
    environment.add_role(qa_engineer)

    trigger = Message(content="", cause_by=SummarizeCode, send_to=qa_engineer)
    environment.publish_message(trigger)

    # Run rounds until the environment reports it is idle.
    while not environment.is_idle:
        await environment.run()
    return context.repo.tests.workdir
@register_tool(tags=["software development", "Engineer"])
async def fix_bug(project_path: str | Path, issue: str) -> Path:
    """Fix bugs in the project repository.

    Args:
        project_path (str|Path): The path to the project repository.
        issue (str): Description of the bug or issue.

    Returns:
        Path: The path to the project directory

    Example:
        >>> from metagpt.tools.libs.software_development import fix_bug
        >>> project_path = '/path/to/project_path' # Returned by `write_codes`
        >>> issue = 'Exception: exception about ...; Bug: bug about ...; Issue: issue about ...'
        >>> project_path = await fix_bug(project_path=project_path, issue=issue)
        >>> print(project_path)
        '/path/to/project_path'
    """
    from metagpt.actions.fix_bug import FixBug
    from metagpt.context import Context
    from metagpt.roles import Engineer

    ctx = Context()
    ctx.set_repo_dir(project_path)
    ctx.src_workspace = ctx.git_repo.workdir / ctx.git_repo.workdir.name
    # Store the issue as the bug-fix document and blank out the requirement document.
    await ctx.repo.docs.save(filename=BUGFIX_FILENAME, content=issue)
    await ctx.repo.docs.save(filename=REQUIREMENT_FILENAME, content="")

    role = Engineer(context=ctx)
    bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
    msg = Message(
        content=bug_fix.model_dump_json(),
        instruct_content=bug_fix,
        role="",
        cause_by=FixBug,
        sent_from=role,
        send_to=role,
    )
    me = {any_to_str(role), role.name}
    # Keep feeding the role its own output while it is addressed to itself.
    # The `msg` check ends the loop cleanly in case `role.run` returns nothing.
    while msg and me.intersection(msg.send_to):
        msg = await role.run(with_message=msg)
    # Always hand back a `Path`, as annotated, even when a string was passed in.
    return Path(project_path)
@register_tool(tags=["software development", "git"])
async def git_archive(project_path: str | Path) -> str:
    """Stage and commit changes for the project repository using Git.

    Args:
        project_path (str|Path): The path to the project repository.

    Returns:
        git log

    Example:
        >>> from metagpt.tools.libs.software_development import git_archive
        >>> project_path = '/path/to/project_path' # Returned by `write_prd`
        >>> git_log = await git_archive(project_path=project_path)
        >>> print(git_log)
        commit a221d1c418c07f2b4fc07001e486285ead1a520a (HEAD -> feature/toollib/software_company, geekan/main)
        Merge: e01afd09 4a72f398
        Author: Sirui Hong <x@xx.github.com>
        Date:   Tue Mar 19 15:16:03 2024 +0800
            Merge pull request #1037 from iorisa/fixbug/issues/1018
            fixbug: #1018
    """
    from metagpt.context import Context

    context = Context()
    context.set_repo_dir(project_path)

    # Archive first, then report the resulting history.
    repository = context.git_repo
    repository.archive()
    return repository.log()
@register_tool(tags=["software development", "import git repo"])
async def import_git_repo(url: str) -> Path:
    """
    Imports a project from a Git website and formats it to MetaGPT project format to enable incremental appending requirements.

    Args:
        url (str): The Git project URL, such as "https://github.com/geekan/MetaGPT.git".

    Returns:
        Path: The path of the formatted project.

    Example:
        # The Git project URL to input
        >>> git_url = "https://github.com/geekan/MetaGPT.git"

        # Import the Git repository and get the formatted project path
        >>> formatted_project_path = await import_git_repo(git_url)
        >>> print("Formatted project path:", formatted_project_path)
        /PATH/TO/THE/FORMATTED/PROJECT
    """
    from metagpt.actions.import_repo import ImportRepo
    from metagpt.context import Context

    context = Context()
    importer = ImportRepo(repo_path=url, context=context)
    await importer.run()
    # The action is given this context; the formatted project location is read back from it.
    return context.repo.workdir

View file

@ -822,19 +822,78 @@ See FAQ 5.8
raise retry_state.outcome.exception()
def get_markdown_codeblock_type(filename: str) -> str:
async def get_mime_type(filename: str | Path, force_read: bool = False) -> str:
guess_mime_type, _ = mimetypes.guess_type(filename.name)
if not guess_mime_type:
ext_mappings = {".yml": "text/yaml", ".yaml": "text/yaml"}
guess_mime_type = ext_mappings.get(filename.suffix)
if not force_read and guess_mime_type:
return guess_mime_type
from metagpt.tools.libs.shell import shell_execute # avoid circular import
text_set = {
"application/json",
"application/vnd.chipnuts.karaoke-mmd",
"application/javascript",
"application/xml",
"application/x-sh",
"application/sql",
"text/yaml",
}
try:
stdout, _, _ = await shell_execute(f"file --mime-type {str(filename)}")
ix = stdout.rfind(" ")
mime_type = stdout[ix:].strip()
if mime_type == "text/plain" and guess_mime_type in text_set:
return guess_mime_type
return mime_type
except Exception as e:
logger.debug(f"file:{filename}, error:{e}")
return "unknown"
def get_markdown_codeblock_type(filename: str = None, mime_type: str = None) -> str:
    """Return the markdown code-block type for a file name or a MIME type.

    Args:
        filename (str, optional): File name used to guess the MIME type when
            `mime_type` is not supplied.
        mime_type (str, optional): An already-known MIME type; takes
            precedence over `filename`.

    Returns:
        str: The code-block language tag, or "text" for unmapped types.

    Raises:
        ValueError: If neither `filename` nor `mime_type` is provided.
    """
    if not filename and not mime_type:
        raise ValueError("Either filename or mime_type must be valid.")

    # Only guess from the name when the caller did not supply a type;
    # an explicit `mime_type` must never be overwritten.
    if not mime_type:
        mime_type, _ = mimetypes.guess_type(filename)

    mappings = {
        "text/x-shellscript": "bash",
        "text/x-c++src": "cpp",
        "text/css": "css",
        "text/html": "html",
        "text/x-java": "java",
        "text/x-python": "python",
        "text/x-ruby": "ruby",
        "text/x-c": "cpp",
        "text/yaml": "yaml",
        "application/javascript": "javascript",
        "application/json": "json",
        "application/sql": "sql",
        "application/vnd.chipnuts.karaoke-mmd": "mermaid",
        "application/x-sh": "bash",
        "application/xml": "xml",
    }
    return mappings.get(mime_type, "text")
def get_project_srcs_path(workdir: str | Path) -> Path:
src_workdir_path = workdir / ".src_workspace"
if src_workdir_path.exists():
with open(src_workdir_path, "r") as file:
src_name = file.read()
else:
src_name = Path(workdir).name
return Path(workdir) / src_name
async def init_python_folder(workdir: str | Path):
    """Ensure `workdir` contains an `__init__.py`, creating an empty one if it is missing.

    Args:
        workdir (str | Path): Directory that should hold the `__init__.py` file.
    """
    package_marker = Path(workdir) / "__init__.py"
    if not package_marker.exists():
        # Opening in append mode creates the file; utime then stamps it with the current time.
        async with aiofiles.open(package_marker, "a"):
            os.utime(package_marker, None)

View file

@ -9,6 +9,7 @@
from __future__ import annotations
import shutil
import uuid
from enum import Enum
from pathlib import Path
from typing import Dict, List
@ -16,8 +17,10 @@ from typing import Dict, List
from git.repo import Repo
from git.repo.fun import is_git_dir
from gitignore_parser import parse_gitignore
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.logs import logger
from metagpt.tools.libs.shell import shell_execute
from metagpt.utils.dependency_file import DependencyFile
from metagpt.utils.file_repository import FileRepository
@ -52,7 +55,7 @@ class GitRepository:
self._dependency = None
self._gitignore_rules = None
if local_path:
self.open(local_path=local_path, auto_init=auto_init)
self.open(local_path=Path(local_path), auto_init=auto_init)
def open(self, local_path: Path, auto_init=False):
"""Open an existing Git repository or initialize a new one if auto_init is True.
@ -68,7 +71,7 @@ class GitRepository:
if not auto_init:
return
local_path.mkdir(parents=True, exist_ok=True)
return self._init(local_path)
self._init(local_path)
def _init(self, local_path: Path):
"""Initialize a new Git repository at the specified path.
@ -248,6 +251,8 @@ class GitRepository:
if not directory_path.exists():
return []
for file_path in directory_path.iterdir():
if not file_path.is_relative_to(root_relative_path):
continue
if file_path.is_file():
rpath = file_path.relative_to(root_relative_path)
files.append(str(rpath))
@ -283,3 +288,37 @@ class GitRepository:
continue
files.append(filename)
return files
@classmethod
@retry(wait=wait_random_exponential(min=1, max=15), stop=stop_after_attempt(3))
async def clone_from(cls, url: str | Path, output_dir: str | Path = None) -> "GitRepository":
    """Clone a Git repository and open the resulting working copy.

    Args:
        url (str | Path): The repository URL or path to clone from.
        output_dir (str | Path, optional): Parent directory for the clone.
            When omitted, a fresh uniquely-named directory under
            `workspace/downloads` is used.

    Returns:
        GitRepository: The opened clone.

    Raises:
        ValueError: If the expected directory is not a Git repository after
            cloning; the message carries the command output.
    """
    from metagpt.context import Context

    if output_dir:
        parent_dir = Path(output_dir)
    else:
        parent_dir = Path(__file__).parent / f"../../workspace/downloads/{uuid.uuid4().hex}"
    parent_dir = parent_dir.resolve()
    parent_dir.mkdir(parents=True, exist_ok=True)

    # `git clone` names the checkout after the last URL component without its suffix.
    clone_dir = parent_dir / Path(url).stem
    if clone_dir.exists():
        # Remove any previous checkout at the same location before cloning.
        shutil.rmtree(clone_dir, ignore_errors=True)

    ctx = Context()
    env = ctx.new_environ()
    command = ["git", "clone"]
    if ctx.config.proxy:
        command += ["-c", f"http.proxy={ctx.config.proxy}"]
    command.append(str(url))
    logger.info(" ".join(command))

    stdout, stderr, return_code = await shell_execute(command=command, cwd=str(parent_dir), env=env, timeout=600)
    info = f"{stdout}\n{stderr}\nexit: {return_code}\n"
    logger.info(info)

    if not cls.is_git_dir(clone_dir):
        raise ValueError(info)
    logger.info(f"git clone to {clone_dir}")
    return GitRepository(local_path=clone_dir, auto_init=False)
async def checkout(self, commit_id: str):
    """Check out `commit_id` in the underlying repository and log the action.

    Args:
        commit_id (str): The revision passed to `git checkout`.
    """
    git_cmd = self._repository.git
    git_cmd.checkout(commit_id)
    logger.info(f"git checkout {commit_id}")
def log(self) -> str:
    """Return the output of `git log` for the underlying repository."""
    git_cmd = self._repository.git
    history = git_cmd.log()
    return history

View file

@ -49,6 +49,10 @@ class GraphKeyword:
IS_COMPOSITE_OF = "is_composite_of"
IS_AGGREGATE_OF = "is_aggregate_of"
HAS_PARTICIPANT = "has_participant"
HAS_SUMMARY = "has_summary"
HAS_INSTALL = "has_install"
HAS_CONFIG = "has_config"
HAS_USAGE = "has_usage"
class SPO(BaseModel):

View file

@ -35,6 +35,7 @@ from metagpt.const import (
TEST_OUTPUTS_FILE_REPO,
VISUAL_GRAPH_REPO_FILE_REPO,
)
from metagpt.utils.common import get_project_srcs_path
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
@ -129,11 +130,10 @@ class ProjectRepo(FileRepository):
return self._git_repo.new_file_repository(self._srcs_path)
def code_files_exists(self) -> bool:
    """Tell whether the project's source directory exists and contains files.

    The source directory is resolved with `get_project_srcs_path`, so the
    answer follows the same rule as the rest of the project tooling.

    Returns:
        bool: True if the source directory exists and lists at least one file.
    """
    src_workdir = get_project_srcs_path(self.git_repo.workdir)
    if not src_workdir.exists():
        return False
    code_files = self.with_src_path(path=src_workdir).srcs.all_files
    return bool(code_files)

View file

@ -5,17 +5,24 @@ This file provides functionality to convert a local repository into a markdown r
"""
from __future__ import annotations
import mimetypes
import re
from pathlib import Path
from typing import Tuple
from gitignore_parser import parse_gitignore
from metagpt.logs import logger
from metagpt.utils.common import aread, awrite, get_markdown_codeblock_type, list_files
from metagpt.utils.common import (
aread,
awrite,
get_markdown_codeblock_type,
get_mime_type,
list_files,
)
from metagpt.utils.tree import tree
async def repo_to_markdown(repo_path: str | Path, output: str | Path = None, gitignore: str | Path = None) -> str:
async def repo_to_markdown(repo_path: str | Path, output: str | Path = None) -> str:
"""
Convert a local repository into a markdown representation.
@ -25,56 +32,108 @@ async def repo_to_markdown(repo_path: str | Path, output: str | Path = None, git
Args:
repo_path (str | Path): The path to the local repository.
output (str | Path, optional): The path to save the generated markdown file. Defaults to None.
gitignore (str | Path, optional): The path to the .gitignore file. Defaults to None.
Returns:
str: The markdown representation of the repository.
"""
repo_path = Path(repo_path)
gitignore = Path(gitignore or Path(__file__).parent / "../../.gitignore").resolve()
repo_path = Path(repo_path).resolve()
gitignore_file = repo_path / ".gitignore"
markdown = await _write_dir_tree(repo_path=repo_path, gitignore=gitignore)
markdown = await _write_dir_tree(repo_path=repo_path, gitignore=gitignore_file)
gitignore_rules = parse_gitignore(full_path=str(gitignore))
gitignore_rules = parse_gitignore(full_path=str(gitignore_file)) if gitignore_file.exists() else None
markdown += await _write_files(repo_path=repo_path, gitignore_rules=gitignore_rules)
if output:
await awrite(filename=str(output), data=markdown, encoding="utf-8")
output_file = Path(output).resolve()
output_file.parent.mkdir(parents=True, exist_ok=True)
await awrite(filename=str(output_file), data=markdown, encoding="utf-8")
logger.info(f"save: {output_file}")
return markdown
async def _write_dir_tree(repo_path: Path, gitignore: Path) -> str:
    """Render the directory tree of `repo_path` as a markdown section.

    The `tree` helper is first asked to run the external command; if that
    raises, the pure-Python listing ("safe mode") is used instead.

    Args:
        repo_path (Path): Root of the repository to list.
        gitignore (Path): Location of the ignore file; may not exist.

    Returns:
        str: A "## Directory Tree" markdown section.
    """
    # The caller passes `<repo>/.gitignore` unconditionally. Drop it when the
    # file is absent so the safe-mode fallback does not try to parse a
    # missing file.
    if gitignore and not Path(gitignore).exists():
        gitignore = None
    try:
        content = await tree(repo_path, gitignore, run_command=True)
    except Exception as e:
        logger.info(f"{e}, using safe mode.")
        content = await tree(repo_path, gitignore, run_command=False)

    doc = f"## Directory Tree\n```text\n{content}\n```\n---\n\n"
    return doc
async def _write_files(repo_path, gitignore_rules=None) -> str:
    """Concatenate the markdown sections of every eligible file in the repository.

    A file is skipped when the ignore rules match it or when any component of
    its path inside the repository starts with a dot (hidden file or folder).

    Args:
        repo_path: Root of the repository.
        gitignore_rules (callable, optional): Predicate returning a truthy
            value for paths that must be ignored. None disables the check.

    Returns:
        str: The concatenated markdown of all files that were kept.
    """
    root = Path(repo_path)
    sections = []
    for filename in list_files(repo_path):
        if gitignore_rules and gitignore_rules(str(filename)):
            continue
        # Judge "hidden" on the path inside the repository only. Checking the
        # full path would drop every file when the repository itself lives
        # below a dot-directory.
        try:
            parts = Path(filename).relative_to(root).parts
        except ValueError:
            # Not under `root`; fall back to inspecting the whole path.
            parts = Path(filename).parts
        if any(part.startswith(".") for part in parts):
            continue
        sections.append(await _write_file(filename=filename, repo_path=repo_path))
    return "".join(sections)
async def _write_file(filename: Path, repo_path: Path) -> str:
    """Render one file as a markdown section with a fenced code block.

    Args:
        filename (Path): The file to render.
        repo_path (Path): Repository root, used to compute the section title.

    Returns:
        str: The markdown section, or an empty string when the file is not
            text or could not be read.
    """
    is_text, mime_type = await _is_text_file(filename)
    if not is_text:
        logger.info(f"Ignore content: {filename}")
        return ""

    try:
        relative_path = filename.relative_to(repo_path)
        markdown = f"## {relative_path}\n"
        content = await aread(filename, encoding="utf-8")
        # Escape sequences that would otherwise terminate the fence or be
        # read as a section separator in the generated document.
        content = content.replace("```", "\\`\\`\\`").replace("---", "\\-\\-\\-")
        code_block_type = get_markdown_codeblock_type(filename.name)
        markdown += f"```{code_block_type}\n{content}\n```\n---\n\n"
        return markdown
    except Exception as e:
        # Best effort: one unreadable file must not abort the whole export.
        logger.error(e)
        return ""
async def _is_text_file(filename: Path) -> Tuple[bool, str]:
    """Decide whether a file should be treated as text.

    Args:
        filename (Path): The file to classify.

    Returns:
        Tuple[bool, str]: Whether the file is text, and the MIME type that
            led to the decision.
    """
    # Non-"text/" types that are nevertheless accepted as text.
    text_like_types = {
        "application/json",
        "application/vnd.chipnuts.karaoke-mmd",
        "application/javascript",
        "application/xml",
        "application/x-sh",
        "application/sql",
    }
    # Types rejected silently; any other rejected type is logged.
    known_non_text_types = {
        "application/zlib",
        "application/octet-stream",
        "image/svg+xml",
        "application/pdf",
        "application/msword",
        "application/vnd.ms-excel",
        "audio/x-wav",
        "application/x-git",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/zip",
        "image/jpeg",
        "audio/mpeg",
        "video/mp2t",
        "inode/x-empty",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "image/png",
        "image/vnd.microsoft.icon",
        "video/mp4",
    }

    mime_type = await get_mime_type(filename, force_read=True)
    if "text/" in mime_type or mime_type in text_like_types:
        return True, mime_type

    if mime_type not in known_non_text_types:
        logger.info(mime_type)
    return False, mime_type

View file

@ -27,14 +27,15 @@
"""
from __future__ import annotations
import subprocess
from pathlib import Path
from typing import Callable, Dict, List
from gitignore_parser import parse_gitignore
from metagpt.tools.libs.shell import shell_execute
def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = False) -> str:
async def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = False) -> str:
"""
Recursively traverses the directory structure and prints it out in a tree-like format.
@ -80,7 +81,7 @@ def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = Fal
"""
root = Path(root).resolve()
if run_command:
return _execute_tree(root, gitignore)
return await _execute_tree(root, gitignore)
git_ignore_rules = parse_gitignore(gitignore) if gitignore else None
dir_ = {root.name: _list_children(root=root, git_ignore_rules=git_ignore_rules)}
@ -129,12 +130,7 @@ def _add_line(rows: List[str]) -> List[str]:
return rows
async def _execute_tree(root: Path, gitignore: str | Path) -> str:
    """Run the external `tree` command on `root` and return its output.

    Args:
        root (Path): Directory to list.
        gitignore (str | Path): Ignore file passed through `--gitfile`; a
            falsy value omits the option.

    Returns:
        str: The standard output of `tree`.

    Raises:
        ValueError: If `tree` exits with a non-zero code. Callers rely on an
            exception to switch to the pure-Python listing.
    """
    args = ["--gitfile", str(gitignore)] if gitignore else []
    stdout, stderr, return_code = await shell_execute(["tree"] + args + [str(root)])
    if return_code != 0:
        raise ValueError(f"tree exits with code {return_code}: {stderr}")
    return stdout

View file

@ -60,7 +60,7 @@ gitignore-parser==0.1.9
# connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/openapi_v3_hello.py
websockets~=11.0
networkx~=3.2.1
google-generativeai==0.3.2
google-generativeai==0.4.1
playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py
anytree
ipywidgets==8.1.1

View file

@ -36,6 +36,8 @@ extras_require = {
"llama-index-readers-file==0.1.4",
"llama-index-retrievers-bm25==0.1.3",
"llama-index-vector-stores-faiss==0.1.1",
"llama-index-vector-stores-elasticsearch==0.1.6",
"llama-index-postprocessor-colbert-rerank==0.1.1",
"chromadb==0.4.23",
],
}

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,26 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from pathlib import Path
import pytest
from metagpt.actions.extract_readme import ExtractReadMe
from metagpt.llm import LLM
@pytest.mark.asyncio
async def test_learn_readme(context):
    """Run ExtractReadMe on the directory three levels above this file's folder and check that graph data is produced."""
    target_dir = Path(__file__).parent.parent.parent.parent
    extractor = ExtractReadMe(
        name="RedBean",
        i_context=str(target_dir),
        llm=LLM(),
        context=context,
    )
    await extractor.run()

    # The graph database must hold rows and the graph repo must report changes.
    graph_rows = await extractor.graph_db.select()
    assert graph_rows
    assert context.repo.docs.graph_repo.changed_files
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest
from metagpt.actions.import_repo import ImportRepo
from metagpt.context import Context
from metagpt.utils.common import list_files
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "repo_path",
    [
        "https://github.com/spec-first/connexion.git",
        # "https://github.com/geekan/MetaGPT.git"
    ],
)
@pytest.mark.skip
async def test_import_repo(repo_path):
    """Import a remote repository and check that matching PRD and design documents exist."""
    ctx = Context()
    importer = ImportRepo(repo_path=repo_path, context=ctx)
    await importer.run()
    assert ctx.repo

    prd_files = list_files(ctx.repo.docs.prd.workdir)
    assert prd_files
    design_files = list_files(ctx.repo.docs.system_design.workdir)
    assert design_files
    # The first PRD and the first design document are expected to share a base name.
    assert prd_files[0].stem == design_files[0].stem
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -18,6 +18,7 @@ from metagpt.utils.git_repository import ChangeType
from metagpt.utils.graph_repository import SPO
@pytest.mark.skip
@pytest.mark.asyncio
async def test_rebuild(context, mocker):
# Mock
@ -60,7 +61,7 @@ async def test_rebuild(context, mocker):
],
)
def test_get_full_filename(root, pathname, want):
res = RebuildSequenceView._get_full_filename(root=root, pathname=pathname)
res = RebuildSequenceView.get_full_filename(root=root, pathname=pathname)
assert res == want

View file

@ -11,6 +11,7 @@ import pytest
from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
default_resp_cont,
get_part_chat_completion,
@ -22,7 +23,7 @@ name = "GPT"
class MockBaseLLM(BaseLLM):
def __init__(self, config: LLMConfig = None):
pass
self.config = config or mock_llm_config
def completion(self, messages: list[dict], timeout=3):
return get_part_chat_completion(name)

View file

@ -0,0 +1,60 @@
import json
import pytest
from llama_index.core.schema import NodeWithScore, QueryBundle
from pydantic import BaseModel
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
from metagpt.rag.schema import ObjectNode
class Record(BaseModel):
    """Minimal object whose JSON dump is stored in node metadata by the tests below."""

    # The value the tests sort on.
    score: int
class TestObjectSortPostprocessor:
    """Tests for ObjectSortPostprocessor sorting, truncation, and input validation."""

    @pytest.fixture
    def nodes_with_scores(self):
        """Three nodes whose stored object score equals the node score: 10, 20, 5."""
        nodes = [
            NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10),
            NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20),
            NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=5).model_dump_json()}), score=5),
        ]
        return nodes

    @pytest.fixture
    def query_bundle(self, mocker):
        """A stand-in query bundle; only its presence matters to these tests."""
        return mocker.MagicMock(spec=QueryBundle)

    def test_sort_descending(self, nodes_with_scores, query_bundle):
        postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
        sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
        assert [node.score for node in sorted_nodes] == [20, 10, 5]

    def test_sort_ascending(self, nodes_with_scores, query_bundle):
        postprocessor = ObjectSortPostprocessor(field_name="score", order="asc")
        sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
        assert [node.score for node in sorted_nodes] == [5, 10, 20]

    def test_top_n_limit(self, nodes_with_scores, query_bundle):
        postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2)
        sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
        assert len(sorted_nodes) == 2
        assert [node.score for node in sorted_nodes] == [20, 10]

    def test_invalid_json_metadata(self, query_bundle):
        nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)]
        postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
        with pytest.raises(ValueError):
            postprocessor._postprocess_nodes(nodes, query_bundle)

    def test_missing_query_bundle(self, nodes_with_scores):
        postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
        with pytest.raises(ValueError):
            postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None)

    def test_field_not_found_in_object(self, query_bundle):
        nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)]
        postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
        # Supply a query bundle so the ValueError comes from the missing field.
        # Without one, the missing-bundle check covered by
        # `test_missing_query_bundle` would satisfy this assertion instead.
        with pytest.raises(ValueError):
            postprocessor._postprocess_nodes(nodes, query_bundle)

View file

@ -8,6 +8,7 @@ from pathlib import Path
import pytest
from metagpt.context import Context
from metagpt.logs import logger
from metagpt.roles import Architect, ProductManager, ProjectManager
from metagpt.team import Team
@ -146,5 +147,21 @@ async def test_team_recover_multi_roles_save(mocker, context):
await new_company.run(n_round=4)
@pytest.mark.asyncio
async def test_context(context):
    """Serialize a team, deserialize it with a fresh Context, and check the context values."""
    context.kwargs.set("a", "a")
    context.cost_manager.max_budget = 9
    team = Team(context=context)

    snapshot_dir = context.repo.workdir / "serial"
    team.serialize(snapshot_dir)
    # NOTE(review): the assertions below inspect `team`, not a value returned by
    # `deserialize` — confirm that `deserialize` updates the instance in place.
    team.deserialize(snapshot_dir, Context())

    assert team.env.context.repo
    assert team.env.context.repo.workdir == context.repo.workdir
    assert team.env.context.kwargs.a == "a"
    assert team.env.context.cost_manager.max_budget == context.cost_manager.max_budget
assert company.env.context.cost_manager.max_budget == context.cost_manager.max_budget
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest
from pydantic import BaseModel
from metagpt.tools.libs.git import git_checkout, git_clone
from metagpt.utils.git_repository import GitRepository
class SWEBenchItem(BaseModel):
    """A pair of string fields: a repository and a base commit."""

    # presumably the commit id to check out — TODO confirm against the dataset schema
    base_commit: str
    # presumably the repository identifier or URL — TODO confirm
    repo: str
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ["url", "commit_id"], [("https://github.com/sqlfluff/sqlfluff.git", "d19de0ecd16d298f9e3bfb91da122734c40c01e5")]
)
async def test_git(url: str, commit_id: str):
    """Clone a repository, check out a fixed commit, then delete the local copy."""
    cloned_dir = await git_clone(url)
    assert cloned_dir
    await git_checkout(cloned_dir, commit_id)

    # Clean up the checkout so repeated runs start from scratch.
    GitRepository(cloned_dir, auto_init=False).delete_repository()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest
from metagpt.tools.libs.shell import execute
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ["command", "expect_stdout", "expect_stderr"],
    [
        (["file", f"{__file__}"], "Python script text executable, ASCII text", ""),
        (f"file {__file__}", "Python script text executable, ASCII text", ""),
    ],
)
async def test_shell(command, expect_stdout, expect_stderr):
    """Run the `file` command in list form and in string form and compare the outputs.

    NOTE(review): other modules in this change import `shell_execute` from the
    same package and unpack three values — confirm `execute` still exists with
    a two-value return.
    """
    result = await execute(command)
    stdout, stderr = result
    assert expect_stdout in stdout
    assert stderr == expect_stderr
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,65 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest
from metagpt.tools.libs import (
fix_bug,
git_archive,
run_qa_test,
write_codes,
write_design,
write_prd,
write_project_plan,
)
from metagpt.tools.libs.software_development import import_git_repo
@pytest.mark.asyncio
async def test_software_team():
    """Drive the full tool chain: PRD, design, plan, code, QA, bug fix, incremental PRD, archive."""
    path = await write_prd("snake game")
    assert path

    # Each later stage takes the previous stage's path and must return a truthy path.
    for stage in (write_design, write_project_plan, write_codes, run_qa_test):
        path = await stage(path)
        assert path

    issue = """
pygame 2.0.1 (SDL 2.0.14, Python 3.9.17)
Hello from the pygame community. https://www.pygame.org/contribute.html
Traceback (most recent call last):
File "/Users/ix/github/bak/MetaGPT/workspace/snake_game/snake_game/main.py", line 10, in <module>
main()
File "/Users/ix/github/bak/MetaGPT/workspace/snake_game/snake_game/main.py", line 7, in main
game.start_game()
File "/Users/ix/github/bak/MetaGPT/workspace/snake_game/snake_game/game.py", line 81, in start_game
x
NameError: name 'x' is not defined
"""
    path = await fix_bug(path, issue)
    assert path

    # A follow-up requirement on the same project must keep the same path.
    new_path = await write_prd("snake game with moving enemy", path)
    assert new_path == path

    git_log = await git_archive(new_path)
    assert git_log
@pytest.mark.asyncio
async def test_import_repo():
    """Import a public repository through the tool wrapper and expect a truthy path."""
    imported_path = await import_git_repo("https://github.com/spec-first/connexion.git")
    assert imported_path
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -10,7 +10,12 @@ from metagpt.utils.repo_to_markdown import repo_to_markdown
@pytest.mark.parametrize(
["repo_path", "output"],
[(Path(__file__).parent.parent, Path(__file__).parent.parent.parent / f"workspace/unittest/{uuid.uuid4().hex}.md")],
[
(
Path(__file__).parent.parent.parent,
Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}.md",
),
],
)
@pytest.mark.asyncio
async def test_repo_to_markdown(repo_path: Path, output: Path):

View file

@ -3,6 +3,7 @@ from typing import Optional, Union
from metagpt.config2 import config
from metagpt.configs.llm_config import LLMType
from metagpt.const import LLM_API_TIMEOUT
from metagpt.logs import logger
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
@ -22,7 +23,7 @@ class MockLLM(OriginalLLM):
self.rsp_cache: dict = {}
self.rsp_candidates: list[dict] = [] # a test can have multiple calls with the same llm, thus a list
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=LLM_API_TIMEOUT) -> str:
"""Overwrite original acompletion_text to cancel retry"""
if stream:
resp = await self._achat_completion_stream(messages, timeout=timeout)
@ -37,7 +38,7 @@ class MockLLM(OriginalLLM):
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=3,
timeout=LLM_API_TIMEOUT,
stream=True,
) -> str:
if system_msgs:
@ -56,7 +57,7 @@ class MockLLM(OriginalLLM):
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
return rsp
async def original_aask_batch(self, msgs: list, timeout=3) -> str:
async def original_aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str:
"""A copy of metagpt.provider.base_llm.BaseLLM.aask_batch, we can't use super().aask because it will be mocked"""
context = []
for msg in msgs:
@ -83,7 +84,7 @@ class MockLLM(OriginalLLM):
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=3,
timeout=LLM_API_TIMEOUT,
stream=True,
) -> str:
# used to identify it a message has been called before
@ -98,7 +99,7 @@ class MockLLM(OriginalLLM):
rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, images, timeout, stream)
return rsp
async def aask_batch(self, msgs: list, timeout=3) -> str:
async def aask_batch(self, msgs: list, timeout=LLM_API_TIMEOUT) -> str:
msg_key = "#MSG_SEP#".join([msg if isinstance(msg, str) else msg.content for msg in msgs])
rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
return rsp