mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-21 14:05:17 +02:00
Merge branch 'mgx_ops' into role_zero_draft
This commit is contained in:
commit
bbddbf4ef0
61 changed files with 1743 additions and 577 deletions
48
config/vault.example.yaml
Normal file
48
config/vault.example.yaml
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
# Usage:
|
||||
# 1. Get value.
|
||||
# >>> from metagpt.tools.libs.env import get_env
|
||||
# >>> access_token = await get_env(key="access_token", app_name="github")
|
||||
# >>> print(access_token)
|
||||
# YOUR_ACCESS_TOKEN
|
||||
#
|
||||
# 2. Get description for LLM understanding.
|
||||
# >>> from metagpt.tools.libs.env import get_env_description
|
||||
# >>> descriptions = await get_env_description
|
||||
# >>> for k, desc in descriptions.items():
|
||||
# >>> print(f"{key}:{desc}")
|
||||
# await get_env(key="access_token", app_name="github"):Get github access token
|
||||
# await get_env(key="access_token", app_name="gitlab"):Get gitlab access token
|
||||
# ...
|
||||
|
||||
vault:
|
||||
github:
|
||||
values:
|
||||
access_token: "YOUR_ACCESS_TOKEN"
|
||||
descriptions:
|
||||
access_token: "Get github access token"
|
||||
gitlab:
|
||||
values:
|
||||
access_token: "YOUR_ACCESS_TOKEN"
|
||||
descriptions:
|
||||
access_token: "Get gitlab access token"
|
||||
iflytek_tts:
|
||||
values:
|
||||
api_id: "YOUR_APP_ID"
|
||||
api_key: "YOUR_API_KEY"
|
||||
api_secret: "YOUR_API_SECRET"
|
||||
descriptions:
|
||||
api_id: "Get the API ID of IFlyTek Text to Speech"
|
||||
api_key: "Get the API KEY of IFlyTek Text to Speech"
|
||||
api_secret: "Get the API SECRET of IFlyTek Text to Speech"
|
||||
azure_tts:
|
||||
values:
|
||||
subscription_key: "YOUR_SUBSCRIPTION_KEY"
|
||||
region: "YOUR_REGION"
|
||||
descriptions:
|
||||
subscription_key: "Get the subscription key of Azure Text to Speech."
|
||||
region: "Get the region of Azure Text to Speech."
|
||||
default: # All key-value pairs whose app name is an empty string are placed below
|
||||
values:
|
||||
proxy: "YOUR_PROXY"
|
||||
descriptions:
|
||||
proxy: "Get proxy for tools like requests, playwright, selenium, etc."
|
||||
|
|
@ -6,16 +6,19 @@
|
|||
"""
|
||||
|
||||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
from metagpt.tools.libs.browser import Browser as _
|
||||
|
||||
|
||||
PAPER_LIST_REQ = """"
|
||||
Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/,
|
||||
and save it to a csv file. paper title must include `multiagent` or `large language model`. *notice: print key variables*
|
||||
and save it to a csv file. paper title must include `multiagent` or `large language model`.
|
||||
**Notice: view the page element before writing scraping code**
|
||||
"""
|
||||
|
||||
ECOMMERCE_REQ = """
|
||||
Get products data from website https://scrapeme.live/shop/ and save it as a csv file.
|
||||
**Notice: Firstly parse the web page encoding and the text HTML structure;
|
||||
The first page product name, price, product URL, and image URL must be saved in the csv;**
|
||||
The first page product name, price, product URL, and image URL must be saved in the csv.
|
||||
**Notice: view the page element before writing scraping code**
|
||||
"""
|
||||
|
||||
NEWS_36KR_REQ = """从36kr创投平台https://pitchhub.36kr.com/financing-flash 所有初创企业融资的信息, **注意: 这是一个中文网站**;
|
||||
|
|
@ -25,11 +28,12 @@ NEWS_36KR_REQ = """从36kr创投平台https://pitchhub.36kr.com/financing-flash
|
|||
3. 反思*快讯的html内容示例*中的规律, 设计正则匹配表达式来获取*`快讯`*的标题、链接、时间;
|
||||
4. 筛选最近3天的初创企业融资*`快讯`*, 以list[dict]形式打印前5个。
|
||||
5. 将全部结果存在本地csv中
|
||||
**Notice: view the page element before writing scraping code**
|
||||
"""
|
||||
|
||||
|
||||
async def main():
|
||||
di = DataInterpreter(tools=["scrape_web_playwright"])
|
||||
di = DataInterpreter(tools=["Browser"])
|
||||
|
||||
await di.run(ECOMMERCE_REQ)
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from metagpt.schema import (
|
|||
SerializationMixin,
|
||||
TestingContext,
|
||||
)
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
class Action(SerializationMixin, ContextMixin, BaseModel):
|
||||
|
|
@ -36,12 +35,6 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
|
|||
desc: str = "" # for skill manager
|
||||
node: ActionNode = Field(default=None, exclude=True)
|
||||
|
||||
@property
|
||||
def repo(self) -> ProjectRepo:
|
||||
if not self.context.repo:
|
||||
self.context.repo = ProjectRepo(self.context.git_repo)
|
||||
return self.context.repo
|
||||
|
||||
@property
|
||||
def prompt_schema(self):
|
||||
return self.config.prompt_schema
|
||||
|
|
|
|||
|
|
@ -9,13 +9,15 @@
|
|||
2. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name.
|
||||
"""
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
NOTICE
|
||||
|
|
@ -47,6 +49,8 @@ Now you should start rewriting the code:
|
|||
|
||||
class DebugError(Action):
|
||||
i_context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
|
||||
input_args: Optional[BaseModel] = Field(default=None, exclude=True)
|
||||
|
||||
async def run(self, *args, **kwargs) -> str:
|
||||
output_doc = await self.repo.test_outputs.get(filename=self.i_context.output_filename)
|
||||
|
|
@ -59,9 +63,7 @@ class DebugError(Action):
|
|||
return ""
|
||||
|
||||
logger.info(f"Debug and rewrite {self.i_context.test_filename}")
|
||||
code_doc = await self.repo.with_src_path(self.context.src_workspace).srcs.get(
|
||||
filename=self.i_context.code_filename
|
||||
)
|
||||
code_doc = await self.repo.srcs.get(filename=self.i_context.code_filename)
|
||||
if not code_doc:
|
||||
return ""
|
||||
test_doc = await self.repo.tests.get(filename=self.i_context.test_filename)
|
||||
|
|
|
|||
|
|
@ -8,10 +8,14 @@
|
|||
1. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name.
|
||||
2. According to the design in Section 2.2.3.5.3 of RFC 135, add incremental iteration functionality.
|
||||
@Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD.
|
||||
@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236.
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.design_api_an import (
|
||||
|
|
@ -22,10 +26,17 @@ from metagpt.actions.design_api_an import (
|
|||
REFINED_DESIGN_NODE,
|
||||
REFINED_PROGRAM_CALL_FLOW,
|
||||
)
|
||||
from metagpt.const import DATA_API_DESIGN_FILE_REPO, SEQ_FLOW_FILE_REPO
|
||||
from metagpt.const import (
|
||||
DATA_API_DESIGN_FILE_REPO,
|
||||
DEFAULT_WORKSPACE_ROOT,
|
||||
SEQ_FLOW_FILE_REPO,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import AIMessage, Document, Documents, Message
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.common import aread, awrite, to_markdown_code_block
|
||||
from metagpt.utils.mermaid import mermaid_to_file
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from metagpt.utils.report import DocsReporter, GalleryReporter
|
||||
|
||||
NEW_REQ_TEMPLATE = """
|
||||
|
|
@ -37,6 +48,7 @@ NEW_REQ_TEMPLATE = """
|
|||
"""
|
||||
|
||||
|
||||
@register_tool(tags=["software development", "write system design"])
|
||||
class WriteDesign(Action):
|
||||
name: str = ""
|
||||
i_context: Optional[str] = None
|
||||
|
|
@ -45,21 +57,134 @@ class WriteDesign(Action):
|
|||
"data structures, library tables, processes, and paths. Please provide your design, feedback "
|
||||
"clearly and in detail."
|
||||
)
|
||||
repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
|
||||
input_args: Optional[BaseModel] = Field(default=None, exclude=True)
|
||||
|
||||
async def run(self, with_messages: Message, schema: str = None):
|
||||
# Use `git status` to identify which PRD documents have been modified in the `docs/prd` directory.
|
||||
changed_prds = self.repo.docs.prd.changed_files
|
||||
# Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone
|
||||
# changes.
|
||||
changed_system_designs = self.repo.docs.system_design.changed_files
|
||||
async def run(
|
||||
self,
|
||||
with_messages: List[Message] = None,
|
||||
*,
|
||||
user_requirement: str = "",
|
||||
prd_filename: str = "",
|
||||
legacy_design_filename: str = "",
|
||||
extra_info: str = "",
|
||||
output_pathname: str = "",
|
||||
**kwargs,
|
||||
) -> AIMessage:
|
||||
"""
|
||||
Write a system design.
|
||||
|
||||
Args:
|
||||
user_requirement (str): The user's requirements for the system design.
|
||||
prd_filename (str, optional): The filename of the Product Requirement Document (PRD).
|
||||
legacy_design_filename (str, optional): The filename of the legacy design document.
|
||||
extra_info (str, optional): Additional information to be included in the system design.
|
||||
output_pathname (str, optional): The output path name of file that the system design should be saved to.
|
||||
|
||||
Returns:
|
||||
AIMessage: An AIMessage object containing the system design.
|
||||
|
||||
Example:
|
||||
# Write a new system design.
|
||||
>>> user_requirement = "Your user requirements"
|
||||
>>> extra_info = "Your extra information"
|
||||
>>> action = WriteDesign()
|
||||
>>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info)
|
||||
>>> print(result.content)
|
||||
System Design filename: "/path/to/design/filename"
|
||||
|
||||
# Modify an exists system design.
|
||||
>>> user_requirement = "Your user requirements"
|
||||
>>> extra_info = "Your extra information"
|
||||
>>> legacy_design_filename = "/path/to/exists/design/filename"
|
||||
>>> action = WriteDesign()
|
||||
>>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename)
|
||||
>>> print(result.content)
|
||||
System Design filename: "/path/to/design/filename"
|
||||
|
||||
# Write a new system design with the given PRD(Product Requirement Document).
|
||||
>>> user_requirement = "Your user requirements"
|
||||
>>> extra_info = "Your extra information"
|
||||
>>> prd_filename = "/path/to/prd/filename"
|
||||
>>> action = WriteDesign()
|
||||
>>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, prd_filename=prd_filename)
|
||||
>>> print(result.content)
|
||||
System Design filename: "/path/to/design/filename"
|
||||
|
||||
# Modify an exists system design with the given PRD(Product Requirement Document).
|
||||
>>> user_requirement = "Your user requirements"
|
||||
>>> extra_info = "Your extra information"
|
||||
>>> prd_filename = "/path/to/prd/filename"
|
||||
>>> legacy_design_filename = "/path/to/exists/design/filename"
|
||||
>>> action = WriteDesign()
|
||||
>>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename, prd_filename=prd_filename)
|
||||
>>> print(result.content)
|
||||
TSystem Design filename: "/path/to/design/filename"
|
||||
|
||||
# Write a new system design and save to the path name.
|
||||
>>> user_requirement = "Your user requirements"
|
||||
>>> extra_info = "Your extra information"
|
||||
>>> output_pathname = "/path/to/design/filename"
|
||||
>>> action = WriteDesign()
|
||||
>>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, output_pathname=output_pathname)
|
||||
>>> print(result.content)
|
||||
System Design filename: "/path/to/design/filename"
|
||||
|
||||
# Modify an exists system design and save to the path name.
|
||||
>>> user_requirement = "Your user requirements"
|
||||
>>> extra_info = "Your extra information"
|
||||
>>> legacy_design_filename = "/path/to/exists/design/filename"
|
||||
>>> output_pathname = "/path/to/design/filename"
|
||||
>>> action = WriteDesign()
|
||||
>>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename, output_pathname=output_pathname)
|
||||
>>> print(result.content)
|
||||
System Design filename: "/path/to/design/filename"
|
||||
|
||||
# Write a new system design with the given PRD(Product Requirement Document) and save to the path name.
|
||||
>>> user_requirement = "Your user requirements"
|
||||
>>> extra_info = "Your extra information"
|
||||
>>> prd_filename = "/path/to/prd/filename"
|
||||
>>> output_pathname = "/path/to/design/filename"
|
||||
>>> action = WriteDesign()
|
||||
>>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, prd_filename=prd_filename, output_pathname=output_pathname)
|
||||
>>> print(result.content)
|
||||
System Design filename: "/path/to/design/filename"
|
||||
|
||||
# Modify an exists system design with the given PRD(Product Requirement Document) and save to the path name.
|
||||
>>> user_requirement = "Your user requirements"
|
||||
>>> extra_info = "Your extra information"
|
||||
>>> prd_filename = "/path/to/prd/filename"
|
||||
>>> legacy_design_filename = "/path/to/exists/design/filename"
|
||||
>>> output_pathname = "/path/to/design/filename"
|
||||
>>> action = WriteDesign()
|
||||
>>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename, prd_filename=prd_filename, output_pathname=output_pathname)
|
||||
>>> print(result.content)
|
||||
System Design filename: "/path/to/design/filename"
|
||||
"""
|
||||
if not with_messages:
|
||||
return await self._execute_api(
|
||||
user_requirement=user_requirement,
|
||||
prd_filename=prd_filename,
|
||||
legacy_design_filename=legacy_design_filename,
|
||||
extra_info=extra_info,
|
||||
output_pathname=output_pathname,
|
||||
)
|
||||
|
||||
self.input_args = with_messages[-1].instruct_content
|
||||
self.repo = ProjectRepo(self.input_args.project_path)
|
||||
changed_prds = self.input_args.changed_prd_filenames
|
||||
changed_system_designs = [
|
||||
str(self.repo.docs.system_design.workdir / i)
|
||||
for i in list(self.repo.docs.system_design.changed_files.keys())
|
||||
]
|
||||
|
||||
# For those PRDs and design documents that have undergone changes, regenerate the design content.
|
||||
changed_files = Documents()
|
||||
for filename in changed_prds.keys():
|
||||
for filename in changed_prds:
|
||||
doc = await self._update_system_design(filename=filename)
|
||||
changed_files.docs[filename] = doc
|
||||
|
||||
for filename in changed_system_designs.keys():
|
||||
for filename in changed_system_designs:
|
||||
if filename in changed_files.docs:
|
||||
continue
|
||||
doc = await self._update_system_design(filename=filename)
|
||||
|
|
@ -68,6 +193,11 @@ class WriteDesign(Action):
|
|||
logger.info("Nothing has changed.")
|
||||
# Wait until all files under `docs/system_designs/` are processed before sending the publish message,
|
||||
# leaving room for global optimization in subsequent steps.
|
||||
kvs = self.input_args.model_dump()
|
||||
kvs["changed_system_design_filenames"] = [
|
||||
str(self.repo.docs.system_design.workdir / i)
|
||||
for i in list(self.repo.docs.system_design.changed_files.keys())
|
||||
]
|
||||
return AIMessage(
|
||||
content="Designing is complete. "
|
||||
+ "\n".join(
|
||||
|
|
@ -75,6 +205,7 @@ class WriteDesign(Action):
|
|||
+ list(self.repo.resources.data_api_design.changed_files.keys())
|
||||
+ list(self.repo.resources.seq_flow.changed_files.keys())
|
||||
),
|
||||
instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteDesignOutput"),
|
||||
cause_by=self,
|
||||
)
|
||||
|
||||
|
|
@ -89,14 +220,15 @@ class WriteDesign(Action):
|
|||
return system_design_doc
|
||||
|
||||
async def _update_system_design(self, filename) -> Document:
|
||||
prd = await self.repo.docs.prd.get(filename)
|
||||
old_system_design_doc = await self.repo.docs.system_design.get(filename)
|
||||
root_relative_path = Path(filename).relative_to(self.repo.workdir)
|
||||
prd = await Document.load(filename=filename, project_path=self.repo.workdir)
|
||||
old_system_design_doc = await self.repo.docs.system_design.get(root_relative_path.name)
|
||||
async with DocsReporter(enable_llm_stream=True) as reporter:
|
||||
await reporter.async_report({"type": "design"}, "meta")
|
||||
if not old_system_design_doc:
|
||||
system_design = await self._new_system_design(context=prd.content)
|
||||
doc = await self.repo.docs.system_design.save(
|
||||
filename=filename,
|
||||
filename=prd.filename,
|
||||
content=system_design.instruct_content.model_dump_json(),
|
||||
dependencies={prd.root_relative_path},
|
||||
)
|
||||
|
|
@ -133,3 +265,40 @@ class WriteDesign(Action):
|
|||
image_path = pathname.parent / f"{pathname.name}.png"
|
||||
if image_path.exists():
|
||||
await GalleryReporter().async_report(image_path, "path")
|
||||
|
||||
async def _execute_api(
|
||||
self,
|
||||
user_requirement: str = "",
|
||||
prd_filename: str = "",
|
||||
legacy_design_filename: str = "",
|
||||
extra_info: str = "",
|
||||
output_pathname: str = "",
|
||||
) -> AIMessage:
|
||||
prd_content = ""
|
||||
if prd_filename:
|
||||
prd_content = await aread(filename=prd_filename)
|
||||
context = "### User Requirements\n{user_requirement}\n### Extra_info\n{extra_info}\n### PRD\n{prd}\n".format(
|
||||
user_requirement=to_markdown_code_block(user_requirement),
|
||||
extra_info=to_markdown_code_block(extra_info),
|
||||
prd=to_markdown_code_block(prd_content),
|
||||
)
|
||||
if not legacy_design_filename:
|
||||
node = await self._new_system_design(context=context)
|
||||
design = Document(content=node.instruct_content.model_dump_json())
|
||||
else:
|
||||
old_design_content = await aread(filename=legacy_design_filename)
|
||||
design = await self._merge(
|
||||
prd_doc=Document(content=context), system_design_doc=Document(content=old_design_content)
|
||||
)
|
||||
|
||||
if not output_pathname:
|
||||
output_path = DEFAULT_WORKSPACE_ROOT
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
output_pathname = Path(output_path) / f"{uuid.uuid4().hex}.json"
|
||||
await awrite(filename=output_pathname, data=design.content)
|
||||
kvs = {"changed_system_design_filenames": [output_pathname]}
|
||||
|
||||
return AIMessage(
|
||||
content=f'System Design filename: "{str(output_pathname)}"',
|
||||
instruct_content=AIMessage.create_instruct_value(kvs=kvs),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ class ExecuteNbCode(Action):
|
|||
"""execute notebook code block, return result to llm, and display it."""
|
||||
|
||||
nb: NotebookNode
|
||||
nb_client: NotebookClient = None
|
||||
nb_client: RealtimeOutputNotebookClient = None
|
||||
console: Console
|
||||
interaction: str
|
||||
timeout: int = 600
|
||||
|
|
@ -78,11 +78,15 @@ class ExecuteNbCode(Action):
|
|||
interaction=("ipython" if self.is_ipython() else "terminal"),
|
||||
)
|
||||
self.reporter = NotebookReporter()
|
||||
self.set_nb_client()
|
||||
|
||||
def set_nb_client(self):
|
||||
self.nb_client = RealtimeOutputNotebookClient(
|
||||
nb,
|
||||
timeout=timeout,
|
||||
self.nb,
|
||||
timeout=self.timeout,
|
||||
resources={"metadata": {"path": DEFAULT_WORKSPACE_ROOT}},
|
||||
notebook_reporter=self.reporter,
|
||||
coalesce_streams=True,
|
||||
)
|
||||
|
||||
async def build(self):
|
||||
|
|
@ -118,7 +122,7 @@ class ExecuteNbCode(Action):
|
|||
# sleep 1s to wait for the kernel to be cleaned up completely
|
||||
await asyncio.sleep(1)
|
||||
await self.build()
|
||||
self.nb_client = NotebookClient(self.nb, timeout=self.timeout)
|
||||
self.set_nb_client()
|
||||
|
||||
def add_code_cell(self, code: str):
|
||||
self.nb.cells.append(new_code_cell(source=code))
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from metagpt.logs import logger
|
|||
from metagpt.schema import AIMessage
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
class PrepareDocuments(Action):
|
||||
|
|
@ -36,7 +37,7 @@ class PrepareDocuments(Action):
|
|||
def config(self):
|
||||
return self.context.config
|
||||
|
||||
def _init_repo(self):
|
||||
def _init_repo(self) -> ProjectRepo:
|
||||
"""Initialize the Git environment."""
|
||||
if not self.config.project_path:
|
||||
name = self.config.project_name or FileRepository.new_filename()
|
||||
|
|
@ -45,8 +46,9 @@ class PrepareDocuments(Action):
|
|||
path = Path(self.config.project_path)
|
||||
if path.exists() and not self.config.inc:
|
||||
shutil.rmtree(path)
|
||||
self.config.project_path = path
|
||||
self.context.set_repo_dir(path)
|
||||
self.context.kwargs.project_path = path
|
||||
self.context.kwargs.inc = self.config.inc
|
||||
return ProjectRepo(path)
|
||||
|
||||
async def run(self, with_messages, **kwargs):
|
||||
"""Create and initialize the workspace folder, initialize the Git environment."""
|
||||
|
|
@ -67,10 +69,22 @@ class PrepareDocuments(Action):
|
|||
max_auto_summarize_code=0,
|
||||
)
|
||||
|
||||
self._init_repo()
|
||||
repo = self._init_repo()
|
||||
|
||||
# Write the newly added requirements from the main parameter idea to `docs/requirement.txt`.
|
||||
doc = await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
|
||||
await repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
|
||||
# Send a Message notification to the WritePRD action, instructing it to process requirements using
|
||||
# `docs/requirement.txt` and `docs/prd/`.
|
||||
return AIMessage(content="", instruct_content=doc, cause_by=self, send_to=self.send_to)
|
||||
return AIMessage(
|
||||
content="",
|
||||
instruct_content=AIMessage.create_instruct_value(
|
||||
kvs={
|
||||
"project_path": str(repo.workdir),
|
||||
"requirements_filename": str(repo.docs.workdir / REQUIREMENT_FILENAME),
|
||||
"prd_filenames": [str(repo.docs.prd.workdir / i) for i in repo.docs.prd.all_files],
|
||||
},
|
||||
class_name="PrepareDocumentsOutput",
|
||||
),
|
||||
cause_by=self,
|
||||
send_to=self.send_to,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,16 +8,23 @@
|
|||
1. Divide the context into three components: legacy code, unit test code, and console log.
|
||||
2. Move the document storage operations related to WritePRD from the save operation of WriteDesign.
|
||||
3. According to the design in Section 2.2.3.5.4 of RFC 135, add incremental iteration functionality.
|
||||
@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODE
|
||||
from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import AIMessage, Document, Documents
|
||||
from metagpt.schema import AIMessage, Document, Documents, Message
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.common import aread, to_markdown_code_block
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from metagpt.utils.report import DocsReporter
|
||||
|
||||
NEW_REQ_TEMPLATE = """
|
||||
|
|
@ -29,19 +36,56 @@ NEW_REQ_TEMPLATE = """
|
|||
"""
|
||||
|
||||
|
||||
@register_tool(tags=["software development", "write a project schedule given a project system design file"])
|
||||
class WriteTasks(Action):
|
||||
name: str = "CreateTasks"
|
||||
i_context: Optional[str] = None
|
||||
repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
|
||||
input_args: Optional[BaseModel] = Field(default=None, exclude=True)
|
||||
|
||||
async def run(self, with_messages):
|
||||
changed_system_designs = self.repo.docs.system_design.changed_files
|
||||
changed_tasks = self.repo.docs.task.changed_files
|
||||
async def run(
|
||||
self, with_messages: List[Message] = None, *, user_requirement: str = "", design_filename: str = "", **kwargs
|
||||
) -> AIMessage:
|
||||
"""
|
||||
Write a project schedule given a project system design file.
|
||||
|
||||
Args:
|
||||
user_requirement (str, optional): A string specifying the user's requirements. Defaults to an empty string.
|
||||
design_filename (str): The filename of the project system design file. Defaults to an empty string.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
AIMessage: The generated project schedule.
|
||||
|
||||
Example:
|
||||
# Write a new project schedule.
|
||||
>>> design_filename = "/path/to/design/filename"
|
||||
>>> action = WriteTasks()
|
||||
>>> result = await action.run(design_filename=design_filename)
|
||||
>>> print(result.content)
|
||||
The project schedule is balabala...
|
||||
|
||||
# Write a new project schedule with the user requirement.
|
||||
>>> design_filename = "/path/to/design/filename"
|
||||
>>> user_requirement = "Your user requirements"
|
||||
>>> action = WriteTasks()
|
||||
>>> result = await action.run(design_filename=design_filename, user_requirement=user_requirement)
|
||||
>>> print(result.content)
|
||||
The project schedule is balabala...
|
||||
"""
|
||||
if not with_messages:
|
||||
return await self._execute_api(user_requirement=user_requirement, design_filename=design_filename)
|
||||
|
||||
self.input_args = with_messages[-1].instruct_content
|
||||
self.repo = ProjectRepo(self.input_args.project_path)
|
||||
changed_system_designs = self.input_args.changed_system_design_filenames
|
||||
changed_tasks = [str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys())]
|
||||
change_files = Documents()
|
||||
# Rewrite the system designs that have undergone changes based on the git head diff under
|
||||
# `docs/system_designs/`.
|
||||
for filename in changed_system_designs:
|
||||
task_doc = await self._update_tasks(filename=filename)
|
||||
change_files.docs[filename] = task_doc
|
||||
change_files.docs[str(self.repo.docs.task.workdir / task_doc.filename)] = task_doc
|
||||
|
||||
# Rewrite the task files that have undergone changes based on the git head diff under `docs/tasks/`.
|
||||
for filename in changed_tasks:
|
||||
|
|
@ -54,6 +98,11 @@ class WriteTasks(Action):
|
|||
logger.info("Nothing has changed.")
|
||||
# Wait until all files under `docs/tasks/` are processed before sending the publish_message, leaving room for
|
||||
# global optimization in subsequent steps.
|
||||
kvs = self.input_args.model_dump()
|
||||
kvs["changed_task_filenames"] = [
|
||||
str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys())
|
||||
]
|
||||
kvs["python_package_dependency_filename"] = str(self.repo.workdir / PACKAGE_REQUIREMENTS_FILENAME)
|
||||
return AIMessage(
|
||||
content="WBS is completed. "
|
||||
+ "\n".join(
|
||||
|
|
@ -61,12 +110,14 @@ class WriteTasks(Action):
|
|||
+ list(self.repo.docs.task.changed_files.keys())
|
||||
+ list(self.repo.resources.api_spec_and_task.changed_files.keys())
|
||||
),
|
||||
instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteTaskOutput"),
|
||||
cause_by=self,
|
||||
)
|
||||
|
||||
async def _update_tasks(self, filename):
|
||||
system_design_doc = await self.repo.docs.system_design.get(filename)
|
||||
task_doc = await self.repo.docs.task.get(filename)
|
||||
root_relative_path = Path(filename).relative_to(self.repo.workdir)
|
||||
system_design_doc = await Document.load(filename=filename, project_path=self.repo.workdir)
|
||||
task_doc = await self.repo.docs.task.get(root_relative_path.name)
|
||||
async with DocsReporter(enable_llm_stream=True) as reporter:
|
||||
await reporter.async_report({"type": "task"}, "meta")
|
||||
if task_doc:
|
||||
|
|
@ -75,7 +126,7 @@ class WriteTasks(Action):
|
|||
else:
|
||||
rsp = await self._run_new_tasks(context=system_design_doc.content)
|
||||
task_doc = await self.repo.docs.task.save(
|
||||
filename=filename,
|
||||
filename=system_design_doc.filename,
|
||||
content=rsp.instruct_content.model_dump_json(),
|
||||
dependencies={system_design_doc.root_relative_path},
|
||||
)
|
||||
|
|
@ -84,7 +135,7 @@ class WriteTasks(Action):
|
|||
await reporter.async_report(self.repo.workdir / md.root_relative_path, "path")
|
||||
return task_doc
|
||||
|
||||
async def _run_new_tasks(self, context):
|
||||
async def _run_new_tasks(self, context: str):
|
||||
node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
|
||||
return node
|
||||
|
||||
|
|
@ -106,3 +157,11 @@ class WriteTasks(Action):
|
|||
continue
|
||||
packages.add(pkg)
|
||||
await self.repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))
|
||||
|
||||
async def _execute_api(self, user_requirement: str = "", design_filename: str = ""):
|
||||
context = to_markdown_code_block(user_requirement)
|
||||
if not design_filename:
|
||||
content = await aread(filename=design_filename)
|
||||
context += to_markdown_code_block(content)
|
||||
node = await self._run_new_tasks(context)
|
||||
return AIMessage(content=node.instruct_content.model_dump_json())
|
||||
|
|
|
|||
|
|
@ -6,13 +6,16 @@
|
|||
@Modified By: mashenquan, 2023/12/5. Archive the summarization content of issue discovery for use in WriteCode.
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodeSummarizeContext
|
||||
from metagpt.utils.common import get_markdown_code_block_type
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
NOTICE
|
||||
|
|
@ -90,6 +93,8 @@ flowchart TB
|
|||
class SummarizeCode(Action):
|
||||
name: str = "SummarizeCode"
|
||||
i_context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
|
||||
repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
|
||||
input_args: Optional[BaseModel] = Field(default=None, exclude=True)
|
||||
|
||||
@retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60))
|
||||
async def summarize_code(self, prompt):
|
||||
|
|
@ -101,11 +106,10 @@ class SummarizeCode(Action):
|
|||
design_doc = await self.repo.docs.system_design.get(filename=design_pathname.name)
|
||||
task_pathname = Path(self.i_context.task_filename)
|
||||
task_doc = await self.repo.docs.task.get(filename=task_pathname.name)
|
||||
src_file_repo = self.repo.with_src_path(self.context.src_workspace).srcs
|
||||
code_blocks = []
|
||||
for filename in self.i_context.codes_filenames:
|
||||
code_doc = await src_file_repo.get(filename)
|
||||
code_block = f"```python\n{code_doc.content}\n```\n-----"
|
||||
code_doc = await self.repo.srcs.get(filename)
|
||||
code_block = f"```{get_markdown_code_block_type(filename)}\n{code_doc.content}\n```\n---\n"
|
||||
code_blocks.append(code_block)
|
||||
format_example = FORMAT_EXAMPLE
|
||||
prompt = PROMPT_TEMPLATE.format(
|
||||
|
|
|
|||
|
|
@ -16,17 +16,18 @@
|
|||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST
|
||||
from metagpt.actions.write_code_plan_and_change_an import REFINED_TEMPLATE
|
||||
from metagpt.const import BUGFIX_FILENAME, REQUIREMENT_FILENAME
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodingContext, Document, RunCodeResult
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.common import CodeParser, get_markdown_code_block_type
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from metagpt.utils.report import EditorReporter
|
||||
|
||||
|
|
@ -44,9 +45,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
|
|||
{task}
|
||||
|
||||
## Legacy Code
|
||||
```Code
|
||||
{code}
|
||||
```
|
||||
|
||||
## Debug logs
|
||||
```text
|
||||
|
|
@ -61,14 +60,14 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
|
|||
```
|
||||
|
||||
# Format example
|
||||
## Code: {filename}
|
||||
## Code: {demo_filename}.py
|
||||
```python
|
||||
## {filename}
|
||||
## {demo_filename}.py
|
||||
...
|
||||
```
|
||||
## Code: {filename}
|
||||
## Code: {demo_filename}.js
|
||||
```javascript
|
||||
// {filename}
|
||||
// {demo_filename}.js
|
||||
...
|
||||
```
|
||||
|
||||
|
|
@ -89,6 +88,8 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
|
|||
class WriteCode(Action):
|
||||
name: str = "WriteCode"
|
||||
i_context: Document = Field(default_factory=Document)
|
||||
repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
|
||||
input_args: Optional[BaseModel] = Field(default=None, exclude=True)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
async def write_code(self, prompt) -> str:
|
||||
|
|
@ -97,10 +98,16 @@ class WriteCode(Action):
|
|||
return code
|
||||
|
||||
async def run(self, *args, **kwargs) -> CodingContext:
|
||||
bug_feedback = await self.repo.docs.get(filename=BUGFIX_FILENAME)
|
||||
bug_feedback = None
|
||||
if self.input_args and hasattr(self.input_args, "issue_filename"):
|
||||
bug_feedback = await Document.load(self.input_args.issue_filename)
|
||||
coding_context = CodingContext.loads(self.i_context.content)
|
||||
if not coding_context.code_plan_and_change_doc:
|
||||
coding_context.code_plan_and_change_doc = await self.repo.docs.code_plan_and_change.get(
|
||||
filename=coding_context.task_doc.filename
|
||||
)
|
||||
test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
|
||||
requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
|
||||
requirement_doc = await Document.load(self.input_args.requirements_filename)
|
||||
summary_doc = None
|
||||
if coding_context.design_doc and coding_context.design_doc.filename:
|
||||
summary_doc = await self.repo.docs.code_summary.get(filename=coding_context.design_doc.filename)
|
||||
|
|
@ -109,29 +116,28 @@ class WriteCode(Action):
|
|||
test_detail = RunCodeResult.loads(test_doc.content)
|
||||
logs = test_detail.stderr
|
||||
|
||||
if bug_feedback:
|
||||
code_context = coding_context.code_doc.content
|
||||
elif self.config.inc:
|
||||
if self.config.inc or bug_feedback:
|
||||
code_context = await self.get_codes(
|
||||
coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo, use_inc=True
|
||||
)
|
||||
else:
|
||||
code_context = await self.get_codes(
|
||||
coding_context.task_doc,
|
||||
exclude=self.i_context.filename,
|
||||
project_repo=self.repo.with_src_path(self.context.src_workspace),
|
||||
coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo
|
||||
)
|
||||
|
||||
if self.config.inc:
|
||||
prompt = REFINED_TEMPLATE.format(
|
||||
user_requirement=requirement_doc.content if requirement_doc else "",
|
||||
code_plan_and_change=str(coding_context.code_plan_and_change_doc),
|
||||
code_plan_and_change=coding_context.code_plan_and_change_doc.content
|
||||
if coding_context.code_plan_and_change_doc
|
||||
else "",
|
||||
design=coding_context.design_doc.content if coding_context.design_doc else "",
|
||||
task=coding_context.task_doc.content if coding_context.task_doc else "",
|
||||
code=code_context,
|
||||
logs=logs,
|
||||
feedback=bug_feedback.content if bug_feedback else "",
|
||||
filename=self.i_context.filename,
|
||||
demo_filename=Path(self.i_context.filename).stem,
|
||||
summary_log=summary_doc.content if summary_doc else "",
|
||||
)
|
||||
else:
|
||||
|
|
@ -142,6 +148,7 @@ class WriteCode(Action):
|
|||
logs=logs,
|
||||
feedback=bug_feedback.content if bug_feedback else "",
|
||||
filename=self.i_context.filename,
|
||||
demo_filename=Path(self.i_context.filename).stem,
|
||||
summary_log=summary_doc.content if summary_doc else "",
|
||||
)
|
||||
logger.info(f"Writing {coding_context.filename}..")
|
||||
|
|
@ -150,10 +157,11 @@ class WriteCode(Action):
|
|||
code = await self.write_code(prompt)
|
||||
if not coding_context.code_doc:
|
||||
# avoid root_path pydantic ValidationError if use WriteCode alone
|
||||
root_path = self.context.src_workspace if self.context.src_workspace else ""
|
||||
coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path))
|
||||
coding_context.code_doc = Document(
|
||||
filename=coding_context.filename, root_path=str(self.repo.src_relative_path)
|
||||
)
|
||||
coding_context.code_doc.content = code
|
||||
await reporter.async_report(self.repo.workdir / coding_context.code_doc.root_relative_path, "path")
|
||||
await reporter.async_report(coding_context.code_doc, "document")
|
||||
return coding_context
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -178,35 +186,32 @@ class WriteCode(Action):
|
|||
code_filenames = m.get(TASK_LIST.key, []) if not use_inc else m.get(REFINED_TASK_LIST.key, [])
|
||||
codes = []
|
||||
src_file_repo = project_repo.srcs
|
||||
|
||||
# Incremental development scenario
|
||||
if use_inc:
|
||||
src_files = src_file_repo.all_files
|
||||
# Get the old workspace contained the old codes and old workspace are created in previous CodePlanAndChange
|
||||
old_file_repo = project_repo.git_repo.new_file_repository(relative_path=project_repo.old_workspace)
|
||||
old_files = old_file_repo.all_files
|
||||
# Get the union of the files in the src and old workspaces
|
||||
union_files_list = list(set(src_files) | set(old_files))
|
||||
for filename in union_files_list:
|
||||
for filename in src_file_repo.all_files:
|
||||
code_block_type = get_markdown_code_block_type(filename)
|
||||
# Exclude the current file from the all code snippets
|
||||
if filename == exclude:
|
||||
# If the file is in the old workspace, use the old code
|
||||
# Exclude unnecessary code to maintain a clean and focused main.py file, ensuring only relevant and
|
||||
# essential functionality is included for the project’s requirements
|
||||
if filename in old_files and filename != "main.py":
|
||||
if filename != "main.py":
|
||||
# Use old code
|
||||
doc = await old_file_repo.get(filename=filename)
|
||||
doc = await src_file_repo.get(filename=filename)
|
||||
# If the file is in the src workspace, skip it
|
||||
else:
|
||||
continue
|
||||
codes.insert(0, f"-----Now, {filename} to be rewritten\n```{doc.content}```\n=====")
|
||||
codes.insert(
|
||||
0, f"### The name of file to rewrite: `{filename}`\n```{code_block_type}\n{doc.content}```\n"
|
||||
)
|
||||
logger.info(f"Prepare to rewrite `{filename}`")
|
||||
# The code snippets are generated from the src workspace
|
||||
else:
|
||||
doc = await src_file_repo.get(filename=filename)
|
||||
# If the file does not exist in the src workspace, skip it
|
||||
if not doc:
|
||||
continue
|
||||
codes.append(f"----- {filename}\n```{doc.content}```")
|
||||
codes.append(f"### File Name: `{filename}`\n```{code_block_type}\n{doc.content}```\n\n")
|
||||
|
||||
# Normal scenario
|
||||
else:
|
||||
|
|
@ -217,6 +222,7 @@ class WriteCode(Action):
|
|||
doc = await src_file_repo.get(filename=filename)
|
||||
if not doc:
|
||||
continue
|
||||
codes.append(f"----- {filename}\n```{doc.content}```")
|
||||
code_block_type = get_markdown_code_block_type(filename)
|
||||
codes.append(f"### File Name: `{filename}`\n```{code_block_type}\n{doc.content}```\n\n")
|
||||
|
||||
return "\n".join(codes)
|
||||
|
|
|
|||
|
|
@ -5,15 +5,16 @@
|
|||
@Author : mannaandpoem
|
||||
@File : write_code_plan_and_change_an.py
|
||||
"""
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodePlanAndChangeContext
|
||||
from metagpt.schema import CodePlanAndChangeContext, Document
|
||||
from metagpt.utils.common import get_markdown_code_block_type
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
DEVELOPMENT_PLAN = ActionNode(
|
||||
key="Development Plan",
|
||||
|
|
@ -162,9 +163,8 @@ Role: You are a professional engineer; The main goal is to complete incremental
|
|||
{task}
|
||||
|
||||
## Legacy Code
|
||||
```Code
|
||||
{code}
|
||||
```
|
||||
|
||||
|
||||
## Debug logs
|
||||
```text
|
||||
|
|
@ -179,14 +179,14 @@ Role: You are a professional engineer; The main goal is to complete incremental
|
|||
```
|
||||
|
||||
# Format example
|
||||
## Code: {filename}
|
||||
## Code: {demo_filename}.py
|
||||
```python
|
||||
## {filename}
|
||||
## {demo_filename}.py
|
||||
...
|
||||
```
|
||||
## Code: {filename}
|
||||
## Code: {demo_filename}.js
|
||||
```javascript
|
||||
// {filename}
|
||||
// {demo_filename}.js
|
||||
...
|
||||
```
|
||||
|
||||
|
|
@ -211,13 +211,15 @@ WRITE_CODE_PLAN_AND_CHANGE_NODE = ActionNode.from_children("WriteCodePlanAndChan
|
|||
class WriteCodePlanAndChange(Action):
|
||||
name: str = "WriteCodePlanAndChange"
|
||||
i_context: CodePlanAndChangeContext = Field(default_factory=CodePlanAndChangeContext)
|
||||
repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
|
||||
input_args: Optional[BaseModel] = Field(default=None, exclude=True)
|
||||
|
||||
async def run(self, *args, **kwargs):
|
||||
self.llm.system_prompt = "You are a professional software engineer, your primary responsibility is to "
|
||||
"meticulously craft comprehensive incremental development plan and deliver detailed incremental change"
|
||||
prd_doc = await self.repo.docs.prd.get(filename=self.i_context.prd_filename)
|
||||
design_doc = await self.repo.docs.system_design.get(filename=self.i_context.design_filename)
|
||||
task_doc = await self.repo.docs.task.get(filename=self.i_context.task_filename)
|
||||
prd_doc = await Document.load(filename=self.i_context.prd_filename)
|
||||
design_doc = await Document.load(filename=self.i_context.design_filename)
|
||||
task_doc = await Document.load(filename=self.i_context.task_filename)
|
||||
context = CODE_PLAN_AND_CHANGE_CONTEXT.format(
|
||||
requirement=f"```text\n{self.i_context.requirement}\n```",
|
||||
issue=f"```text\n{self.i_context.issue}\n```",
|
||||
|
|
@ -230,8 +232,9 @@ class WriteCodePlanAndChange(Action):
|
|||
return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json")
|
||||
|
||||
async def get_old_codes(self) -> str:
|
||||
self.repo.old_workspace = self.repo.git_repo.workdir / os.path.basename(self.config.project_path)
|
||||
old_file_repo = self.repo.git_repo.new_file_repository(relative_path=self.repo.old_workspace)
|
||||
old_codes = await old_file_repo.get_all()
|
||||
codes = [f"----- {code.filename}\n```{code.content}```" for code in old_codes]
|
||||
old_codes = await self.repo.srcs.get_all()
|
||||
codes = [
|
||||
f"### File Name: `{code.filename}`\n```{get_markdown_code_block_type(code.filename)}\n{code.content}```\n"
|
||||
for code in old_codes
|
||||
]
|
||||
return "\n".join(codes)
|
||||
|
|
|
|||
|
|
@ -7,16 +7,18 @@
|
|||
@Modified By: mashenquan, 2023/11/27. Following the think-act principle, solidify the task parameters when creating the
|
||||
WriteCode object, rather than passing them in when calling the run function.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.const import REQUIREMENT_FILENAME
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodingContext
|
||||
from metagpt.schema import CodingContext, Document
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from metagpt.utils.report import EditorReporter
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
# System
|
||||
|
|
@ -126,18 +128,27 @@ or
|
|||
class WriteCodeReview(Action):
|
||||
name: str = "WriteCodeReview"
|
||||
i_context: CodingContext = Field(default_factory=CodingContext)
|
||||
repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
|
||||
input_args: Optional[BaseModel] = Field(default=None, exclude=True)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename):
|
||||
async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, doc):
|
||||
filename = doc.filename
|
||||
cr_rsp = await self._aask(context_prompt + cr_prompt)
|
||||
result = CodeParser.parse_block("Code Review Result", cr_rsp)
|
||||
if "LGTM" in result:
|
||||
return result, None
|
||||
|
||||
# if LBTM, rewrite code
|
||||
rewrite_prompt = f"{context_prompt}\n{cr_rsp}\n{REWRITE_CODE_TEMPLATE.format(filename=filename)}"
|
||||
code_rsp = await self._aask(rewrite_prompt)
|
||||
code = CodeParser.parse_code(text=code_rsp)
|
||||
async with EditorReporter(enable_llm_stream=True) as reporter:
|
||||
await reporter.async_report(
|
||||
{"type": "code", "filename": filename, "src_path": doc.root_relative_path}, "meta"
|
||||
)
|
||||
rewrite_prompt = f"{context_prompt}\n{cr_rsp}\n{REWRITE_CODE_TEMPLATE.format(filename=filename)}"
|
||||
code_rsp = await self._aask(rewrite_prompt)
|
||||
code = CodeParser.parse_code(text=code_rsp)
|
||||
doc.content = code
|
||||
await reporter.async_report(doc, "document")
|
||||
return result, code
|
||||
|
||||
async def run(self, *args, **kwargs) -> CodingContext:
|
||||
|
|
@ -150,7 +161,7 @@ class WriteCodeReview(Action):
|
|||
code_context = await WriteCode.get_codes(
|
||||
self.i_context.task_doc,
|
||||
exclude=self.i_context.filename,
|
||||
project_repo=self.repo.with_src_path(self.context.src_workspace),
|
||||
project_repo=self.repo,
|
||||
use_inc=self.config.inc,
|
||||
)
|
||||
|
||||
|
|
@ -160,7 +171,7 @@ class WriteCodeReview(Action):
|
|||
"## Code Files\n" + code_context + "\n",
|
||||
]
|
||||
if self.config.inc:
|
||||
requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
|
||||
requirement_doc = await Document.load(filename=self.input_args.requirements_filename)
|
||||
insert_ctx_list = [
|
||||
"## User New Requirements\n" + str(requirement_doc) + "\n",
|
||||
"## Code Plan And Change\n" + str(self.i_context.code_plan_and_change_doc) + "\n",
|
||||
|
|
@ -182,7 +193,7 @@ class WriteCodeReview(Action):
|
|||
f"len(self.i_context.code_doc.content)={len2}"
|
||||
)
|
||||
result, rewrited_code = await self.write_code_review_and_rewrite(
|
||||
context_prompt, cr_prompt, self.i_context.code_doc.filename
|
||||
context_prompt, cr_prompt, self.i_context.code_doc
|
||||
)
|
||||
if "LBTM" in result:
|
||||
iterative_code = rewrited_code
|
||||
|
|
|
|||
|
|
@ -9,12 +9,17 @@
|
|||
2. According to the design in Section 2.2.3.5.2 of RFC 135, add incremental iteration functionality.
|
||||
3. Move the document storage operations related to WritePRD from the save operation of WriteDesign.
|
||||
@Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD.
|
||||
@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
|
@ -30,13 +35,16 @@ from metagpt.actions.write_prd_an import (
|
|||
from metagpt.const import (
|
||||
BUGFIX_FILENAME,
|
||||
COMPETITIVE_ANALYSIS_FILE_REPO,
|
||||
DEFAULT_WORKSPACE_ROOT,
|
||||
REQUIREMENT_FILENAME,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import AIMessage, Document, Documents, Message
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.common import CodeParser, aread, awrite, to_markdown_code_block
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.mermaid import mermaid_to_file
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from metagpt.utils.report import DocsReporter, GalleryReporter
|
||||
|
||||
CONTEXT_TEMPLATE = """
|
||||
|
|
@ -59,6 +67,7 @@ NEW_REQ_TEMPLATE = """
|
|||
"""
|
||||
|
||||
|
||||
@register_tool(tags=["software development", "write product requirement documents"])
|
||||
class WritePRD(Action):
|
||||
"""WritePRD deal with the following situations:
|
||||
1. Bugfix: If the requirement is a bugfix, the bugfix document will be generated.
|
||||
|
|
@ -66,10 +75,97 @@ class WritePRD(Action):
|
|||
3. Requirement update: If the requirement is an update, the PRD document will be updated.
|
||||
"""
|
||||
|
||||
async def run(self, with_messages, *args, **kwargs) -> Message:
|
||||
"""Run the action."""
|
||||
req: Document = await self.repo.requirement
|
||||
docs: list[Document] = await self.repo.docs.prd.get_all()
|
||||
repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
|
||||
input_args: Optional[BaseModel] = Field(default=None, exclude=True)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
with_messages: List[Message] = None,
|
||||
*,
|
||||
user_requirement: str = "",
|
||||
output_pathname: str = "",
|
||||
legacy_prd_filename: str = "",
|
||||
extra_info: str = "",
|
||||
**kwargs,
|
||||
) -> AIMessage:
|
||||
"""
|
||||
Write a Product Requirement Document.
|
||||
|
||||
Args:
|
||||
user_requirement (str): A string detailing the user's requirements.
|
||||
output_pathname (str, optional): The path name of file that the output document should be saved to. Defaults to "".
|
||||
legacy_prd_filename (str, optional): The file path of the legacy Product Requirement Document to use as a reference. Defaults to "".
|
||||
extra_info (str, optional): Additional information to include in the document. Defaults to "".
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
AIMessage: The resulting message after generating the Product Requirement Document.
|
||||
|
||||
Example:
|
||||
# Write a new PRD(Product Requirement Document)
|
||||
>>> user_requirement = "YOUR REQUIREMENTS"
|
||||
>>> extra_info = "YOUR EXTRA INFO"
|
||||
>>> write_prd = WritePRD()
|
||||
>>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info)
|
||||
>>> print(result.content)
|
||||
PRD filename: "/path/to/prd/directory/213434ad.json"
|
||||
|
||||
# Modify a exists PRD(Product Requirement Document)
|
||||
>>> user_requirement = "YOUR REQUIREMENTS"
|
||||
>>> extra_info = "YOUR EXTRA INFO"
|
||||
>>> legacy_prd_filename = "/path/to/exists/prd_filename"
|
||||
>>> write_prd = WritePRD()
|
||||
>>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, legacy_prd_filename=legacy_prd_filename)
|
||||
>>> print(result.content)
|
||||
PRD filename: "/path/to/prd/directory/213434ad.json"
|
||||
|
||||
# Write and save a new PRD(Product Requirement Document) to the path name.
|
||||
>>> user_requirement = "YOUR REQUIREMENTS"
|
||||
>>> extra_info = "YOUR EXTRA INFO"
|
||||
>>> output_pathname = "/path/to/prd/directory/213434ad.json"
|
||||
>>> write_prd = WritePRD()
|
||||
>>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, output_pathname=output_pathname)
|
||||
>>> print(result.content)
|
||||
PRD filename: "/path/to/prd/directory/213434ad.json"
|
||||
|
||||
# Modify a exists PRD(Product Requirement Document) and save to the path name.
|
||||
>>> user_requirement = "YOUR REQUIREMENTS"
|
||||
>>> extra_info = "YOUR EXTRA INFO"
|
||||
>>> legacy_prd_filename = "/path/to/exists/prd_filename"
|
||||
>>> output_pathname = "/path/to/prd/directory/213434ad.json"
|
||||
>>> write_prd = WritePRD()
|
||||
>>> result = await write_prd.run(user_requirement=user_requirement, extra_info=extra_info, legacy_prd_filename=legacy_prd_filename, output_pathname=output_pathname)
|
||||
>>> print(result.content)
|
||||
PRD filename: "/path/to/prd/directory/213434ad.json"
|
||||
|
||||
"""
|
||||
if not with_messages:
|
||||
return await self._execute_api(
|
||||
user_requirement=user_requirement,
|
||||
output_pathname=output_pathname,
|
||||
legacy_prd_filename=legacy_prd_filename,
|
||||
extra_info=extra_info,
|
||||
)
|
||||
|
||||
self.input_args = with_messages[-1].instruct_content
|
||||
if not self.input_args:
|
||||
self.repo = ProjectRepo(self.context.kwargs.project_path)
|
||||
await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[-1].content)
|
||||
self.input_args = AIMessage.create_instruct_value(
|
||||
kvs={
|
||||
"project_path": self.context.kwargs.project_path,
|
||||
"requirements_filename": str(self.repo.docs.workdir / REQUIREMENT_FILENAME),
|
||||
"prd_filenames": [str(self.repo.docs.prd.workdir / i) for i in self.repo.docs.prd.all_files],
|
||||
},
|
||||
class_name="PrepareDocumentsOutput",
|
||||
)
|
||||
else:
|
||||
self.repo = ProjectRepo(self.input_args.project_path)
|
||||
req = await Document.load(filename=self.input_args.requirements_filename)
|
||||
docs: list[Document] = [
|
||||
await Document.load(filename=i, project_path=self.repo.workdir) for i in self.input_args.prd_filenames
|
||||
]
|
||||
|
||||
if not req:
|
||||
raise FileNotFoundError("No requirement document found.")
|
||||
|
||||
|
|
@ -82,10 +178,18 @@ class WritePRD(Action):
|
|||
# if requirement is related to other documents, update them, otherwise create a new one
|
||||
if related_docs := await self.get_related_docs(req, docs):
|
||||
logger.info(f"Requirement update detected: {req.content}")
|
||||
await self._handle_requirement_update(req, related_docs)
|
||||
await self._handle_requirement_update(req=req, related_docs=related_docs)
|
||||
else:
|
||||
logger.info(f"New requirement detected: {req.content}")
|
||||
await self._handle_new_requirement(req)
|
||||
|
||||
kvs = self.input_args.model_dump()
|
||||
kvs["changed_prd_filenames"] = [
|
||||
str(self.repo.docs.prd.workdir / i) for i in list(self.repo.docs.prd.changed_files.keys())
|
||||
]
|
||||
kvs["project_path"] = str(self.repo.workdir)
|
||||
kvs["requirements_filename"] = str(self.repo.docs.workdir / REQUIREMENT_FILENAME)
|
||||
self.context.kwargs.project_path = str(self.repo.workdir)
|
||||
return AIMessage(
|
||||
content="PRD is completed. "
|
||||
+ "\n".join(
|
||||
|
|
@ -93,6 +197,7 @@ class WritePRD(Action):
|
|||
+ list(self.repo.resources.prd.changed_files.keys())
|
||||
+ list(self.repo.resources.competitive_analysis.changed_files.keys())
|
||||
),
|
||||
instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WritePRDOutput"),
|
||||
cause_by=self,
|
||||
)
|
||||
|
||||
|
|
@ -103,19 +208,31 @@ class WritePRD(Action):
|
|||
return AIMessage(
|
||||
content=f"A new issue is received: {BUGFIX_FILENAME}",
|
||||
cause_by=FixBug,
|
||||
instruct_content=AIMessage.create_instruct_value(
|
||||
{
|
||||
"project_path": str(self.repo.workdir),
|
||||
"issue_filename": str(self.repo.docs.workdir / BUGFIX_FILENAME),
|
||||
"requirements_filename": str(self.repo.docs.workdir / REQUIREMENT_FILENAME),
|
||||
},
|
||||
class_name="IssueDetail",
|
||||
),
|
||||
send_to="Alex", # the name of Engineer
|
||||
)
|
||||
|
||||
async def _new_prd(self, requirement: str) -> ActionNode:
|
||||
project_name = self.project_name
|
||||
context = CONTEXT_TEMPLATE.format(requirements=requirement, project_name=project_name)
|
||||
exclude = [PROJECT_NAME.key] if project_name else []
|
||||
node = await WRITE_PRD_NODE.fill(
|
||||
context=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema
|
||||
) # schema=schema
|
||||
return node
|
||||
|
||||
async def _handle_new_requirement(self, req: Document) -> ActionOutput:
|
||||
"""handle new requirement"""
|
||||
async with DocsReporter(enable_llm_stream=True) as reporter:
|
||||
await reporter.async_report({"type": "prd"}, "meta")
|
||||
project_name = self.project_name
|
||||
context = CONTEXT_TEMPLATE.format(requirements=req, project_name=project_name)
|
||||
exclude = [PROJECT_NAME.key] if project_name else []
|
||||
node = await WRITE_PRD_NODE.fill(
|
||||
context=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema
|
||||
) # schema=schema
|
||||
node = await self._new_prd(req.content)
|
||||
await self._rename_workspace(node)
|
||||
new_prd_doc = await self.repo.docs.prd.save(
|
||||
filename=FileRepository.new_filename() + ".json", content=node.instruct_content.model_dump_json()
|
||||
|
|
@ -128,7 +245,7 @@ class WritePRD(Action):
|
|||
async def _handle_requirement_update(self, req: Document, related_docs: list[Document]) -> ActionOutput:
|
||||
# ... requirement update logic ...
|
||||
for doc in related_docs:
|
||||
await self._update_prd(req, doc)
|
||||
await self._update_prd(req=req, prd_doc=doc)
|
||||
return Documents.from_iterable(documents=related_docs).to_action_output()
|
||||
|
||||
async def _is_bugfix(self, context: str) -> bool:
|
||||
|
|
@ -159,7 +276,7 @@ class WritePRD(Action):
|
|||
async def _update_prd(self, req: Document, prd_doc: Document) -> Document:
|
||||
async with DocsReporter(enable_llm_stream=True) as reporter:
|
||||
await reporter.async_report({"type": "prd"}, "meta")
|
||||
new_prd_doc: Document = await self._merge(req, prd_doc)
|
||||
new_prd_doc: Document = await self._merge(req=req, related_doc=prd_doc)
|
||||
await self.repo.docs.prd.save_doc(doc=new_prd_doc)
|
||||
await self._save_competitive_analysis(new_prd_doc)
|
||||
md = await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
|
||||
|
|
@ -186,4 +303,29 @@ class WritePRD(Action):
|
|||
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
|
||||
if ws_name:
|
||||
self.project_name = ws_name
|
||||
self.repo.git_repo.rename_root(self.project_name)
|
||||
if self.repo:
|
||||
self.repo.git_repo.rename_root(self.project_name)
|
||||
|
||||
async def _execute_api(
|
||||
self, user_requirement: str, output_pathname: str, legacy_prd_filename: str, extra_info: str
|
||||
) -> AIMessage:
|
||||
content = "#### User Requirements\n{user_requirement}\n#### Extra Info\n{extra_info}\n".format(
|
||||
user_requirement=to_markdown_code_block(val=user_requirement),
|
||||
extra_info=to_markdown_code_block(val=extra_info),
|
||||
)
|
||||
req = Document(content=content)
|
||||
if not legacy_prd_filename:
|
||||
node = await self._new_prd(requirement=req.content)
|
||||
new_prd = Document(content=node.instruct_content.model_dump_json())
|
||||
else:
|
||||
content = await aread(filename=legacy_prd_filename)
|
||||
old_prd = Document(content=content)
|
||||
new_prd = await self._merge(req=req, related_doc=old_prd)
|
||||
|
||||
if not output_pathname:
|
||||
output_path = DEFAULT_WORKSPACE_ROOT
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
output_pathname = Path(output_path) / f"{uuid.uuid4().hex}.json"
|
||||
await awrite(filename=output_pathname, data=new_prd.content)
|
||||
kvs = AIMessage.create_instruct_value({"changed_prd_filenames": [str(output_pathname)]})
|
||||
return AIMessage(content=f'PRD filename: "{str(output_pathname)}"', instruct_content=kvs)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
@Author : alexanderwu
|
||||
@File : write_prd_an.py
|
||||
"""
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
||||
|
|
@ -132,7 +132,7 @@ REQUIREMENT_ANALYSIS = ActionNode(
|
|||
|
||||
REFINED_REQUIREMENT_ANALYSIS = ActionNode(
|
||||
key="Refined Requirement Analysis",
|
||||
expected_type=List[str],
|
||||
expected_type=Union[List[str], str],
|
||||
instruction="Review and refine the existing requirement analysis into a string list to align with the evolving needs of the project "
|
||||
"due to incremental development. Ensure the analysis comprehensively covers the new features and enhancements "
|
||||
"required for the refined project scope.",
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Literal, Optional
|
|||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from metagpt.configs.browser_config import BrowserConfig
|
||||
from metagpt.configs.embedding_config import EmbeddingConfig
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.configs.mermaid_config import MermaidConfig
|
||||
from metagpt.configs.redis_config import RedisConfig
|
||||
|
|
@ -48,6 +49,9 @@ class Config(CLIParams, YamlModel):
|
|||
# Key Parameters
|
||||
llm: LLMConfig
|
||||
|
||||
# RAG Embedding
|
||||
embedding: EmbeddingConfig = EmbeddingConfig()
|
||||
|
||||
# Global Proxy. Not used by LLM, but by other tools such as browsers.
|
||||
proxy: str = ""
|
||||
|
||||
|
|
|
|||
50
metagpt/configs/embedding_config.py
Normal file
50
metagpt/configs/embedding_config.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class EmbeddingType(Enum):
|
||||
OPENAI = "openai"
|
||||
AZURE = "azure"
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
class EmbeddingConfig(YamlModel):
|
||||
"""Config for Embedding.
|
||||
|
||||
Examples:
|
||||
---------
|
||||
api_type: "openai"
|
||||
api_key: "YOU_API_KEY"
|
||||
|
||||
api_type: "azure"
|
||||
api_key: "YOU_API_KEY"
|
||||
base_url: "YOU_BASE_URL"
|
||||
api_version: "YOU_API_VERSION"
|
||||
|
||||
api_type: "gemini"
|
||||
api_key: "YOU_API_KEY"
|
||||
|
||||
api_type: "ollama"
|
||||
base_url: "YOU_BASE_URL"
|
||||
model: "YOU_MODEL"
|
||||
"""
|
||||
|
||||
api_type: Optional[EmbeddingType] = None
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_version: Optional[str] = None
|
||||
|
||||
model: Optional[str] = None
|
||||
embed_batch_size: Optional[int] = None
|
||||
|
||||
@field_validator("api_type", mode="before")
|
||||
@classmethod
|
||||
def check_api_type(cls, v):
|
||||
if v == "":
|
||||
return None
|
||||
return v
|
||||
|
|
@ -8,7 +8,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
|
@ -22,8 +21,6 @@ from metagpt.utils.cost_manager import (
|
|||
FireworksCostManager,
|
||||
TokenCostManager,
|
||||
)
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
class AttrDict(BaseModel):
|
||||
|
|
@ -66,9 +63,6 @@ class Context(BaseModel):
|
|||
kwargs: AttrDict = AttrDict()
|
||||
config: Config = Config.default()
|
||||
|
||||
repo: Optional[ProjectRepo] = None
|
||||
git_repo: Optional[GitRepository] = None
|
||||
src_workspace: Optional[Path] = None
|
||||
cost_manager: CostManager = CostManager()
|
||||
|
||||
_llm: Optional[BaseLLM] = None
|
||||
|
|
@ -80,11 +74,6 @@ class Context(BaseModel):
|
|||
# env.update({k: v for k, v in i.items() if isinstance(v, str)})
|
||||
return env
|
||||
|
||||
def set_repo_dir(self, path: str | Path):
|
||||
repo_path = Path(path)
|
||||
self.git_repo = GitRepository(local_path=repo_path, auto_init=True)
|
||||
self.repo = ProjectRepo(self.git_repo)
|
||||
|
||||
def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
|
||||
"""Return a CostManager instance"""
|
||||
if llm_config.api_type == LLMType.FIREWORKS:
|
||||
|
|
@ -117,7 +106,6 @@ class Context(BaseModel):
|
|||
Dict[str, Any]: A dictionary containing serialized data.
|
||||
"""
|
||||
return {
|
||||
"workdir": str(self.repo.workdir) if self.repo else "",
|
||||
"kwargs": {k: v for k, v in self.kwargs.__dict__.items()},
|
||||
"cost_manager": self.cost_manager.model_dump_json(),
|
||||
}
|
||||
|
|
@ -130,13 +118,6 @@ class Context(BaseModel):
|
|||
"""
|
||||
if not serialized_data:
|
||||
return
|
||||
workdir = serialized_data.get("workdir")
|
||||
if workdir:
|
||||
self.git_repo = GitRepository(local_path=workdir, auto_init=True)
|
||||
self.repo = ProjectRepo(self.git_repo)
|
||||
src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
|
||||
if src_workspace.exists():
|
||||
self.src_workspace = src_workspace
|
||||
kwargs = serialized_data.get("kwargs")
|
||||
if kwargs:
|
||||
for k, v in kwargs.items():
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from metagpt.logs import logger
|
|||
from metagpt.memory import Memory
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from metagpt.roles.role import Role # noqa: F401
|
||||
|
|
@ -243,8 +244,9 @@ class Environment(ExtEnv):
|
|||
self.member_addrs[obj] = addresses
|
||||
|
||||
def archive(self, auto_archive=True):
|
||||
if auto_archive and self.context.git_repo:
|
||||
self.context.git_repo.archive()
|
||||
if auto_archive and self.context.kwargs.get("project_path"):
|
||||
git_repo = GitRepository(self.context.kwargs.project_path)
|
||||
git_repo.archive()
|
||||
|
||||
@classmethod
|
||||
def model_rebuild(cls, **kwargs):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import json
|
|||
import os
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
|
|
@ -63,7 +63,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
response_synthesizer: Optional[BaseSynthesizer] = None,
|
||||
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
index: Optional[BaseIndex] = None,
|
||||
transformations: Optional[list[TransformComponent]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
retriever=retriever,
|
||||
|
|
@ -71,7 +71,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
node_postprocessors=node_postprocessors,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
self.index = index
|
||||
self._transformations = transformations or self._default_transformations()
|
||||
|
||||
@classmethod
|
||||
def from_docs(
|
||||
|
|
@ -103,12 +103,17 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
|
||||
cls._fix_document_metadata(documents)
|
||||
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents=documents,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
|
||||
transformations = transformations or cls._default_transformations()
|
||||
nodes = run_transformations(documents, transformations=transformations)
|
||||
|
||||
return cls._from_nodes(
|
||||
nodes=nodes,
|
||||
transformations=transformations,
|
||||
embed_model=embed_model,
|
||||
llm=llm,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def from_objs(
|
||||
|
|
@ -137,12 +142,15 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
|
||||
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
index = VectorStoreIndex(
|
||||
|
||||
return cls._from_nodes(
|
||||
nodes=nodes,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
|
||||
transformations=transformations,
|
||||
embed_model=embed_model,
|
||||
llm=llm,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def from_index(
|
||||
|
|
@ -161,6 +169,13 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
"""Inplement tools.SearchInterface"""
|
||||
return await self.aquery(content)
|
||||
|
||||
def retrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
query_bundle = QueryBundle(query) if isinstance(query, str) else query
|
||||
|
||||
nodes = super().retrieve(query_bundle)
|
||||
self._try_reconstruct_obj(nodes)
|
||||
return nodes
|
||||
|
||||
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
"""Allow query to be str."""
|
||||
query_bundle = QueryBundle(query) if isinstance(query, str) else query
|
||||
|
|
@ -176,7 +191,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
documents = SimpleDirectoryReader(input_files=input_files).load_data()
|
||||
self._fix_document_metadata(documents)
|
||||
|
||||
nodes = run_transformations(documents, transformations=self.index._transformations)
|
||||
nodes = run_transformations(documents, transformations=self._transformations)
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def add_objs(self, objs: list[RAGObject]):
|
||||
|
|
@ -192,6 +207,29 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
|
||||
self._persist(str(persist_dir), **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _from_nodes(
|
||||
cls,
|
||||
nodes: list[BaseNode],
|
||||
transformations: Optional[list[TransformComponent]] = None,
|
||||
embed_model: BaseEmbedding = None,
|
||||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
embed_model = cls._resolve_embed_model(embed_model, retriever_configs)
|
||||
llm = llm or get_rag_llm()
|
||||
|
||||
retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model)
|
||||
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
|
||||
|
||||
return cls(
|
||||
retriever=retriever,
|
||||
node_postprocessors=rankers,
|
||||
response_synthesizer=get_response_synthesizer(llm=llm),
|
||||
transformations=transformations,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_index(
|
||||
cls,
|
||||
|
|
@ -201,6 +239,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
llm = llm or get_rag_llm()
|
||||
|
||||
retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
|
||||
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
|
||||
|
||||
|
|
@ -208,7 +247,6 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
retriever=retriever,
|
||||
node_postprocessors=rankers,
|
||||
response_synthesizer=get_response_synthesizer(llm=llm),
|
||||
index=index,
|
||||
)
|
||||
|
||||
def _ensure_retriever_modifiable(self):
|
||||
|
|
@ -259,3 +297,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
return embed_model or get_rag_embedding()
|
||||
|
||||
@staticmethod
|
||||
def _default_transformations():
|
||||
return [SentenceSplitter()]
|
||||
|
|
|
|||
|
|
@ -26,6 +26,9 @@ class GenericFactory:
|
|||
if creator:
|
||||
return creator(**kwargs)
|
||||
|
||||
self._raise_for_key(key)
|
||||
|
||||
def _raise_for_key(self, key: Any):
|
||||
raise ValueError(f"Creator not registered for key: {key}")
|
||||
|
||||
|
||||
|
|
@ -33,19 +36,26 @@ class ConfigBasedFactory(GenericFactory):
|
|||
"""Designed to get objects based on object type."""
|
||||
|
||||
def get_instance(self, key: Any, **kwargs) -> Any:
|
||||
"""Key is config, such as a pydantic model.
|
||||
"""Get instance by the type of key.
|
||||
|
||||
Call func by the type of key, and the key will be passed to func.
|
||||
Key is config, such as a pydantic model, call func by the type of key, and the key will be passed to func.
|
||||
Raise Exception if key not found.
|
||||
"""
|
||||
creator = self._creators.get(type(key))
|
||||
if creator:
|
||||
return creator(key, **kwargs)
|
||||
|
||||
self._raise_for_key(key)
|
||||
|
||||
def _raise_for_key(self, key: Any):
|
||||
raise ValueError(f"Unknown config: `{type(key)}`, {key}")
|
||||
|
||||
@staticmethod
|
||||
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
|
||||
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
|
||||
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.
|
||||
|
||||
Return None if not found.
|
||||
"""
|
||||
if config is not None and hasattr(config, key):
|
||||
val = getattr(config, key)
|
||||
if val is not None:
|
||||
|
|
@ -54,6 +64,4 @@ class ConfigBasedFactory(GenericFactory):
|
|||
if key in kwargs:
|
||||
return kwargs[key]
|
||||
|
||||
raise KeyError(
|
||||
f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
|
||||
)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,37 +1,103 @@
|
|||
"""RAG Embedding Factory."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
from llama_index.embeddings.gemini import GeminiEmbedding
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.base import GenericFactory
|
||||
|
||||
|
||||
class RAGEmbeddingFactory(GenericFactory):
|
||||
"""Create LlamaIndex Embedding with MetaGPT's config."""
|
||||
"""Create LlamaIndex Embedding with MetaGPT's embedding config."""
|
||||
|
||||
def __init__(self):
|
||||
creators = {
|
||||
EmbeddingType.OPENAI: self._create_openai,
|
||||
EmbeddingType.AZURE: self._create_azure,
|
||||
EmbeddingType.GEMINI: self._create_gemini,
|
||||
EmbeddingType.OLLAMA: self._create_ollama,
|
||||
# For backward compatibility
|
||||
LLMType.OPENAI: self._create_openai,
|
||||
LLMType.AZURE: self._create_azure,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding:
|
||||
"""Key is LLMType, default use config.llm.api_type."""
|
||||
return super().get_instance(key or config.llm.api_type)
|
||||
def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding:
|
||||
"""Key is EmbeddingType."""
|
||||
return super().get_instance(key or self._resolve_embedding_type())
|
||||
|
||||
def _create_openai(self):
|
||||
return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url)
|
||||
def _resolve_embedding_type(self) -> EmbeddingType | LLMType:
|
||||
"""Resolves the embedding type.
|
||||
|
||||
def _create_azure(self):
|
||||
return AzureOpenAIEmbedding(
|
||||
azure_endpoint=config.llm.base_url,
|
||||
api_key=config.llm.api_key,
|
||||
api_version=config.llm.api_version,
|
||||
If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE.
|
||||
Raise TypeError if embedding type not found.
|
||||
"""
|
||||
if config.embedding.api_type:
|
||||
return config.embedding.api_type
|
||||
|
||||
if config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]:
|
||||
return config.llm.api_type
|
||||
|
||||
raise TypeError("To use RAG, please set your embedding in config2.yaml.")
|
||||
|
||||
def _create_openai(self) -> OpenAIEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key or config.llm.api_key,
|
||||
api_base=config.embedding.base_url or config.llm.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return OpenAIEmbedding(**params)
|
||||
|
||||
def _create_azure(self) -> AzureOpenAIEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key or config.llm.api_key,
|
||||
azure_endpoint=config.embedding.base_url or config.llm.base_url,
|
||||
api_version=config.embedding.api_version or config.llm.api_version,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return AzureOpenAIEmbedding(**params)
|
||||
|
||||
def _create_gemini(self) -> GeminiEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key,
|
||||
api_base=config.embedding.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return GeminiEmbedding(**params)
|
||||
|
||||
def _create_ollama(self) -> OllamaEmbedding:
|
||||
params = dict(
|
||||
base_url=config.embedding.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return OllamaEmbedding(**params)
|
||||
|
||||
def _try_set_model_and_batch_size(self, params: dict):
|
||||
"""Set the model_name and embed_batch_size only when they are specified."""
|
||||
if config.embedding.model:
|
||||
params["model_name"] = config.embedding.model
|
||||
|
||||
if config.embedding.embed_batch_size:
|
||||
params["embed_batch_size"] = config.embedding.embed_batch_size
|
||||
|
||||
def _raise_for_key(self, key: Any):
|
||||
raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}")
|
||||
|
||||
|
||||
get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class RAGIndexFactory(ConfigBasedFactory):
|
|||
|
||||
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
db = chromadb.PersistentClient(str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name)
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
|
||||
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""RAG LLM."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW
|
||||
|
|
@ -15,7 +15,7 @@ from pydantic import Field
|
|||
from metagpt.config2 import config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.async_helper import run_coroutine_in_new_loop
|
||||
from metagpt.utils.async_helper import NestAsyncio
|
||||
from metagpt.utils.token_counter import TOKEN_MAX
|
||||
|
||||
|
||||
|
|
@ -39,7 +39,8 @@ class RAGLLM(CustomLLM):
|
|||
|
||||
@llm_completion_callback()
|
||||
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
|
||||
return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))
|
||||
NestAsyncio.apply_once()
|
||||
return asyncio.get_event_loop().run_until_complete(self.acomplete(prompt, **kwargs))
|
||||
|
||||
@llm_completion_callback()
|
||||
async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ from metagpt.rag.factories.base import ConfigBasedFactory
|
|||
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
|
||||
from metagpt.rag.schema import (
|
||||
BaseRankerConfig,
|
||||
BGERerankConfig,
|
||||
CohereRerankConfig,
|
||||
ColbertRerankConfig,
|
||||
LLMRankerConfig,
|
||||
ObjectRankerConfig,
|
||||
|
|
@ -22,6 +24,8 @@ class RankerFactory(ConfigBasedFactory):
|
|||
LLMRankerConfig: self._create_llm_ranker,
|
||||
ColbertRerankConfig: self._create_colbert_ranker,
|
||||
ObjectRankerConfig: self._create_object_ranker,
|
||||
CohereRerankConfig: self._create_cohere_rerank,
|
||||
BGERerankConfig: self._create_bge_rerank,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
|
|
@ -45,6 +49,26 @@ class RankerFactory(ConfigBasedFactory):
|
|||
)
|
||||
return ColbertRerank(**config.model_dump())
|
||||
|
||||
def _create_cohere_rerank(self, config: CohereRerankConfig, **kwargs) -> LLMRerank:
|
||||
try:
|
||||
from llama_index.postprocessor.cohere_rerank import CohereRerank
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`llama-index-postprocessor-cohere-rerank` package not found, please run `pip install llama-index-postprocessor-cohere-rerank`"
|
||||
)
|
||||
return CohereRerank(**config.model_dump())
|
||||
|
||||
def _create_bge_rerank(self, config: BGERerankConfig, **kwargs) -> LLMRerank:
|
||||
try:
|
||||
from llama_index.postprocessor.flag_embedding_reranker import (
|
||||
FlagEmbeddingReranker,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`llama-index-postprocessor-flag-embedding-reranker` package not found, please run `pip install llama-index-postprocessor-flag-embedding-reranker`"
|
||||
)
|
||||
return FlagEmbeddingReranker(**config.model_dump())
|
||||
|
||||
def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank:
|
||||
return ObjectSortPostprocessor(**config.model_dump())
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
"""RAG Retriever Factory."""
|
||||
|
||||
import copy
|
||||
|
||||
from functools import wraps
|
||||
|
||||
import chromadb
|
||||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
|
|
@ -24,10 +27,25 @@ from metagpt.rag.schema import (
|
|||
ElasticsearchKeywordRetrieverConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
IndexRetrieverConfig,
|
||||
)
|
||||
|
||||
|
||||
def get_or_build_index(build_index_func):
|
||||
"""Decorator to get or build an index.
|
||||
|
||||
Get index using `_extract_index` method, if not found, using build_index_func.
|
||||
"""
|
||||
|
||||
@wraps(build_index_func)
|
||||
def wrapper(self, config, **kwargs):
|
||||
index = self._extract_index(config, **kwargs)
|
||||
if index is not None:
|
||||
return index
|
||||
return build_index_func(self, config, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class RetrieverFactory(ConfigBasedFactory):
|
||||
"""Modify creators for dynamically instance implementation."""
|
||||
|
||||
|
|
@ -54,48 +72,79 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]
|
||||
|
||||
def _create_default(self, **kwargs) -> RAGRetriever:
|
||||
return self._extract_index(**kwargs).as_retriever()
|
||||
index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs)
|
||||
|
||||
return index.as_retriever()
|
||||
|
||||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
config.index = self._build_faiss_index(config, **kwargs)
|
||||
|
||||
return FAISSRetriever(**config.model_dump())
|
||||
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
|
||||
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
|
||||
index = self._extract_index(config, **kwargs)
|
||||
nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs)
|
||||
|
||||
return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
|
||||
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
|
||||
|
||||
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
|
||||
db = chromadb.PersistentClient(path=str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name)
|
||||
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
config.index = self._build_chroma_index(config, **kwargs)
|
||||
|
||||
return ChromaRetriever(**config.model_dump())
|
||||
|
||||
def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
|
||||
vector_store = ElasticsearchStore(**config.store_config.model_dump())
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
config.index = self._build_es_index(config, **kwargs)
|
||||
|
||||
return ElasticsearchRetriever(**config.model_dump())
|
||||
|
||||
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
|
||||
return self._val_from_config_or_kwargs("index", config, **kwargs)
|
||||
|
||||
def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[BaseNode]:
|
||||
return self._val_from_config_or_kwargs("nodes", config, **kwargs)
|
||||
|
||||
def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding:
|
||||
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
|
||||
|
||||
def _build_default_index(self, **kwargs) -> VectorStoreIndex:
|
||||
index = VectorStoreIndex(
|
||||
nodes=self._extract_nodes(**kwargs),
|
||||
embed_model=self._extract_embed_model(**kwargs),
|
||||
)
|
||||
|
||||
return index
|
||||
|
||||
@get_or_build_index
|
||||
def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
@get_or_build_index
|
||||
def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
db = chromadb.PersistentClient(path=str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
@get_or_build_index
|
||||
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
vector_store = ElasticsearchStore(**config.store_config.model_dump())
|
||||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
def _build_index_from_vector_store(
|
||||
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
|
||||
self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
|
||||
) -> VectorStoreIndex:
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
old_index = self._extract_index(config, **kwargs)
|
||||
new_index = VectorStoreIndex(
|
||||
nodes=list(old_index.docstore.docs.values()),
|
||||
index = VectorStoreIndex(
|
||||
nodes=self._extract_nodes(config, **kwargs),
|
||||
storage_context=storage_context,
|
||||
embed_model=old_index._embed_model,
|
||||
embed_model=self._extract_embed_model(config, **kwargs),
|
||||
)
|
||||
return new_index
|
||||
|
||||
return index
|
||||
|
||||
|
||||
get_retriever = RetrieverFactory().get_retriever
|
||||
|
|
|
|||
|
|
@ -40,8 +40,10 @@ class DynamicBM25Retriever(BM25Retriever):
|
|||
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
|
||||
self.bm25 = BM25Okapi(self._corpus)
|
||||
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
if self._index:
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
"""Support persist."""
|
||||
self._index.storage_context.persist(persist_dir)
|
||||
if self._index:
|
||||
self._index.storage_context.persist(persist_dir)
|
||||
|
|
@ -1,14 +1,17 @@
|
|||
"""RAG schemas."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Union
|
||||
from typing import Any, ClassVar, Literal, Optional, Union
|
||||
|
||||
from chromadb.api.types import CollectionMetadata
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.core.vector_stores.types import VectorStoreQueryMode
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.rag.interface import RAGObject
|
||||
|
||||
|
||||
|
|
@ -31,7 +34,19 @@ class IndexRetrieverConfig(BaseRetrieverConfig):
|
|||
class FAISSRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for FAISS-based retrievers."""
|
||||
|
||||
dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.")
|
||||
dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.")
|
||||
|
||||
_embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
|
||||
EmbeddingType.GEMINI: 768,
|
||||
EmbeddingType.OLLAMA: 4096,
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_dimensions(self):
|
||||
if self.dimensions == 0:
|
||||
self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BM25RetrieverConfig(IndexRetrieverConfig):
|
||||
|
|
@ -45,6 +60,9 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
|
|||
|
||||
persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
metadata: Optional[CollectionMetadata] = Field(
|
||||
default=None, description="Optional metadata to associate with the collection"
|
||||
)
|
||||
|
||||
|
||||
class ElasticsearchStoreConfig(BaseModel):
|
||||
|
|
@ -101,6 +119,16 @@ class ColbertRerankConfig(BaseRankerConfig):
|
|||
keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.")
|
||||
|
||||
|
||||
class CohereRerankConfig(BaseRankerConfig):
|
||||
model: str = Field(default="rerank-english-v3.0")
|
||||
api_key: str = Field(default="YOUR_COHERE_API")
|
||||
|
||||
|
||||
class BGERerankConfig(BaseRankerConfig):
|
||||
model: str = Field(default="BAAI/bge-reranker-large", description="BAAI Reranker model name.")
|
||||
use_fp16: bool = Field(default=True, description="Whether to use fp16 for inference.")
|
||||
|
||||
|
||||
class ObjectRankerConfig(BaseRankerConfig):
|
||||
field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
|
||||
order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")
|
||||
|
|
@ -130,6 +158,9 @@ class ChromaIndexConfig(VectorIndexConfig):
|
|||
"""Config for chroma-based index."""
|
||||
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
metadata: Optional[CollectionMetadata] = Field(
|
||||
default=None, description="Optional metadata to associate with the collection"
|
||||
)
|
||||
|
||||
|
||||
class BM25IndexConfig(BaseIndexConfig):
|
||||
|
|
|
|||
|
|
@ -6,11 +6,9 @@
|
|||
@File : architect.py
|
||||
"""
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
||||
|
||||
class Architect(Role):
|
||||
|
|
@ -36,22 +34,7 @@ class Architect(Role):
|
|||
super().__init__(**kwargs)
|
||||
self.enable_memory = False
|
||||
# Initialize actions specific to the Architect role
|
||||
self.set_actions([PrepareDocuments(send_to=any_to_str(self), context=self.context), WriteDesign])
|
||||
self.set_actions([WriteDesign])
|
||||
|
||||
# Set events or actions the Architect should watch or be aware of
|
||||
self._watch({UserRequirement, PrepareDocuments, WritePRD})
|
||||
|
||||
async def _think(self) -> bool:
|
||||
"""Decide what to do"""
|
||||
mappings = {
|
||||
any_to_str(UserRequirement): 0,
|
||||
any_to_str(PrepareDocuments): 1,
|
||||
any_to_str(WritePRD): 1,
|
||||
}
|
||||
for i in self.rc.news:
|
||||
idx = mappings.get(i.cause_by, -1)
|
||||
if idx < 0:
|
||||
continue
|
||||
self.rc.todo = self.actions[idx]
|
||||
return bool(self.rc.todo)
|
||||
return False
|
||||
self._watch({WritePRD})
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from metagpt.strategy.thinking_command import (
|
|||
)
|
||||
from metagpt.tools.tool_recommend import BM25ToolRecommender
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.report import ThoughtReporter
|
||||
|
||||
|
||||
class DataAnalyst(DataInterpreter):
|
||||
|
|
@ -82,8 +83,8 @@ class DataAnalyst(DataInterpreter):
|
|||
available_commands=prepare_command_prompt(self.available_commands),
|
||||
)
|
||||
context = self.llm.format_msg(self.working_memory.get() + [Message(content=prompt, role="user")])
|
||||
|
||||
rsp = await self.llm.aask(context)
|
||||
async with ThoughtReporter():
|
||||
rsp = await self.llm.aask(context)
|
||||
self.commands = json.loads(CodeParser.parse_code(block=None, text=rsp))
|
||||
self.rc.memory.add(Message(content=rsp, role="assistant"))
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from metagpt.schema import Message, Task, TaskResult
|
|||
from metagpt.strategy.task_type import TaskType
|
||||
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.report import ThoughtReporter
|
||||
|
||||
REACT_THINK_PROMPT = """
|
||||
# User Requirement
|
||||
|
|
@ -73,7 +74,8 @@ class DataInterpreter(Role):
|
|||
return True
|
||||
|
||||
prompt = REACT_THINK_PROMPT.format(user_requirement=self.user_requirement, context=context)
|
||||
rsp = await self.llm.aask(prompt)
|
||||
async with ThoughtReporter():
|
||||
rsp = await self.llm.aask(prompt)
|
||||
rsp_dict = json.loads(CodeParser.parse_code(text=rsp))
|
||||
self.working_memory.add(Message(content=rsp_dict["thoughts"], role="assistant"))
|
||||
need_action = rsp_dict["state"]
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from metagpt.strategy.planner import Planner
|
|||
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.report import ThoughtReporter
|
||||
|
||||
|
||||
@register_tool(include_functions=["ask_human", "reply_to_human"])
|
||||
|
|
@ -118,7 +119,8 @@ class RoleZero(Role):
|
|||
)
|
||||
context = self.llm.format_msg(self.rc.memory.get(self.memory_k) + [Message(content=prompt, role="user")])
|
||||
print(*context, sep="\n" + "*" * 5 + "\n")
|
||||
self.command_rsp = await self.llm.aask(context, system_msgs=self.system_msg)
|
||||
async with ThoughtReporter():
|
||||
self.command_rsp = await self.llm.aask(context, system_msgs=self.system_msg)
|
||||
self.rc.memory.add(Message(content=self.command_rsp, role="assistant"))
|
||||
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -24,20 +24,15 @@ from collections import defaultdict
|
|||
from pathlib import Path
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from metagpt.actions import (
|
||||
Action,
|
||||
UserRequirement,
|
||||
WriteCode,
|
||||
WriteCodeReview,
|
||||
WriteTasks,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions import WriteCode, WriteCodeReview, WriteTasks
|
||||
from metagpt.actions.fix_bug import FixBug
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.actions.write_code_plan_and_change_an import WriteCodePlanAndChange
|
||||
from metagpt.const import (
|
||||
BUGFIX_FILENAME,
|
||||
CODE_PLAN_AND_CHANGE_FILE_REPO,
|
||||
MESSAGE_ROUTE_TO_SELF,
|
||||
REQUIREMENT_FILENAME,
|
||||
|
|
@ -63,6 +58,7 @@ from metagpt.utils.common import (
|
|||
init_python_folder,
|
||||
)
|
||||
from metagpt.utils.git_repository import ChangeType
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
IS_PASS_PROMPT = """
|
||||
{context}
|
||||
|
|
@ -100,23 +96,14 @@ class Engineer(Role):
|
|||
summarize_todos: list = []
|
||||
next_todo_action: str = ""
|
||||
n_summarize: int = 0
|
||||
input_args: Optional[BaseModel] = Field(default=None, exclude=True)
|
||||
repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.enable_memory = False
|
||||
self.set_actions([WriteCode])
|
||||
self._watch(
|
||||
[
|
||||
UserRequirement,
|
||||
PrepareDocuments,
|
||||
WriteTasks,
|
||||
SummarizeCode,
|
||||
WriteCode,
|
||||
WriteCodeReview,
|
||||
FixBug,
|
||||
WriteCodePlanAndChange,
|
||||
]
|
||||
)
|
||||
self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug, WriteCodePlanAndChange])
|
||||
self.code_todos = []
|
||||
self.summarize_todos = []
|
||||
self.next_todo_action = any_to_name(WriteCode)
|
||||
|
|
@ -139,14 +126,20 @@ class Engineer(Role):
|
|||
coding_context = await todo.run()
|
||||
# Code review
|
||||
if review:
|
||||
action = WriteCodeReview(i_context=coding_context, context=self.context, llm=self.llm)
|
||||
action = WriteCodeReview(
|
||||
i_context=coding_context,
|
||||
repo=self.repo,
|
||||
input_args=self.input_args,
|
||||
context=self.context,
|
||||
llm=self.llm,
|
||||
)
|
||||
self._init_action(action)
|
||||
coding_context = await action.run()
|
||||
|
||||
dependencies = {coding_context.design_doc.root_relative_path, coding_context.task_doc.root_relative_path}
|
||||
if self.config.inc:
|
||||
dependencies.add(coding_context.code_plan_and_change_doc.root_relative_path)
|
||||
await self.project_repo.srcs.save(
|
||||
await self.repo.srcs.save(
|
||||
filename=coding_context.filename,
|
||||
dependencies=list(dependencies),
|
||||
content=coding_context.code_doc.content,
|
||||
|
|
@ -186,9 +179,9 @@ class Engineer(Role):
|
|||
summary_filename = Path(todo.i_context.design_filename).with_suffix(".md").name
|
||||
dependencies = {todo.i_context.design_filename, todo.i_context.task_filename}
|
||||
for filename in todo.i_context.codes_filenames:
|
||||
rpath = self.project_repo.src_relative_path / filename
|
||||
rpath = self.repo.src_relative_path / filename
|
||||
dependencies.add(str(rpath))
|
||||
await self.project_repo.resources.code_summary.save(
|
||||
await self.repo.resources.code_summary.save(
|
||||
filename=summary_filename, content=summary, dependencies=dependencies
|
||||
)
|
||||
is_pass, reason = await self._is_pass(summary)
|
||||
|
|
@ -196,23 +189,39 @@ class Engineer(Role):
|
|||
todo.i_context.reason = reason
|
||||
tasks.append(todo.i_context.model_dump())
|
||||
|
||||
await self.project_repo.docs.code_summary.save(
|
||||
await self.repo.docs.code_summary.save(
|
||||
filename=Path(todo.i_context.design_filename).name,
|
||||
content=todo.i_context.model_dump_json(),
|
||||
dependencies=dependencies,
|
||||
)
|
||||
else:
|
||||
await self.project_repo.docs.code_summary.delete(filename=Path(todo.i_context.design_filename).name)
|
||||
await self.repo.docs.code_summary.delete(filename=Path(todo.i_context.design_filename).name)
|
||||
self.summarize_todos = []
|
||||
logger.info(f"--max-auto-summarize-code={self.config.max_auto_summarize_code}")
|
||||
if not tasks or self.config.max_auto_summarize_code == 0:
|
||||
self.n_summarize = 0
|
||||
kvs = self.input_args.model_dump()
|
||||
kvs["changed_src_filenames"] = [
|
||||
str(self.repo.srcs.workdir / i) for i in list(self.repo.srcs.changed_files.keys())
|
||||
]
|
||||
if self.repo.docs.code_plan_and_change.changed_files:
|
||||
kvs["changed_code_plan_and_change_filenames"] = [
|
||||
str(self.repo.docs.code_plan_and_change.workdir / i)
|
||||
for i in list(self.repo.docs.code_plan_and_change.changed_files.keys())
|
||||
]
|
||||
if self.repo.docs.code_summary.changed_files:
|
||||
kvs["changed_code_summary_filenames"] = [
|
||||
str(self.repo.docs.code_summary.workdir / i)
|
||||
for i in list(self.repo.docs.code_summary.changed_files.keys())
|
||||
]
|
||||
return AIMessage(
|
||||
content=f"Coding is complete. The source code is at {self.project_repo.workdir.name}/{self.project_repo.srcs.root_path}, containing: "
|
||||
content=f"Coding is complete. The source code is at {self.repo.workdir.name}/{self.repo.srcs.root_path}, containing: "
|
||||
+ "\n".join(
|
||||
list(self.project_repo.resources.code_summary.changed_files.keys())
|
||||
+ list(self.project_repo.srcs.changed_files.keys())
|
||||
list(self.repo.resources.code_summary.changed_files.keys())
|
||||
+ list(self.repo.srcs.changed_files.keys())
|
||||
+ list(self.repo.resources.code_plan_and_change.changed_files.keys())
|
||||
),
|
||||
instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="SummarizeCodeOutput"),
|
||||
cause_by=SummarizeCode,
|
||||
send_to="Edward", # The name of QaEngineer
|
||||
)
|
||||
|
|
@ -227,15 +236,15 @@ class Engineer(Role):
|
|||
code_plan_and_change = node.instruct_content.model_dump_json()
|
||||
dependencies = {
|
||||
REQUIREMENT_FILENAME,
|
||||
str(self.project_repo.docs.prd.root_path / self.rc.todo.i_context.prd_filename),
|
||||
str(self.project_repo.docs.system_design.root_path / self.rc.todo.i_context.design_filename),
|
||||
str(self.project_repo.docs.task.root_path / self.rc.todo.i_context.task_filename),
|
||||
str(Path(self.rc.todo.i_context.prd_filename).relative_to(self.repo.workdir)),
|
||||
str(Path(self.rc.todo.i_context.design_filename).relative_to(self.repo.workdir)),
|
||||
str(Path(self.rc.todo.i_context.task_filename).relative_to(self.repo.workdir)),
|
||||
}
|
||||
code_plan_and_change_filepath = Path(self.rc.todo.i_context.design_filename)
|
||||
await self.project_repo.docs.code_plan_and_change.save(
|
||||
await self.repo.docs.code_plan_and_change.save(
|
||||
filename=code_plan_and_change_filepath.name, content=code_plan_and_change, dependencies=dependencies
|
||||
)
|
||||
await self.project_repo.resources.code_plan_and_change.save(
|
||||
await self.repo.resources.code_plan_and_change.save(
|
||||
filename=code_plan_and_change_filepath.with_suffix(".md").name,
|
||||
content=node.content,
|
||||
dependencies=dependencies,
|
||||
|
|
@ -250,55 +259,49 @@ class Engineer(Role):
|
|||
return True, rsp
|
||||
return False, rsp
|
||||
|
||||
async def _think(self) -> Action | None:
|
||||
async def _think(self) -> bool:
|
||||
if not self.rc.news:
|
||||
return None
|
||||
return False
|
||||
msg = self.rc.news[0]
|
||||
if msg.cause_by == any_to_str(UserRequirement):
|
||||
self.rc.todo = PrepareDocuments(
|
||||
key_descriptions={
|
||||
"project_path": 'the project path if exists in "Original Requirement"',
|
||||
"src_filename": 'the file name of the source code file explicitly requested for modification if exists in "Original Requirement"',
|
||||
},
|
||||
context=self.context,
|
||||
send_to=any_to_str(self),
|
||||
)
|
||||
return self.rc.todo
|
||||
|
||||
if not self.src_workspace:
|
||||
self.src_workspace = get_project_srcs_path(self.project_repo.workdir)
|
||||
input_args = msg.instruct_content
|
||||
if msg.cause_by in {any_to_str(WriteTasks), any_to_str(FixBug)}:
|
||||
self.input_args = input_args
|
||||
self.repo = ProjectRepo(input_args.project_path)
|
||||
if self.repo.src_relative_path is None:
|
||||
path = get_project_srcs_path(self.repo.workdir)
|
||||
self.repo.with_src_path(path)
|
||||
write_plan_and_change_filters = any_to_str_set([PrepareDocuments, WriteTasks, FixBug])
|
||||
write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode])
|
||||
summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
|
||||
if self.config.inc and msg.cause_by in write_plan_and_change_filters:
|
||||
logger.debug(f"TODO WriteCodePlanAndChange:{msg.model_dump_json()}")
|
||||
await self._new_code_plan_and_change_action(cause_by=msg.cause_by)
|
||||
return self.rc.todo
|
||||
return bool(self.rc.todo)
|
||||
if msg.cause_by in write_code_filters:
|
||||
logger.debug(f"TODO WriteCode:{msg.model_dump_json()}")
|
||||
await self._new_code_actions()
|
||||
return self.rc.todo
|
||||
return bool(self.rc.todo)
|
||||
if msg.cause_by in summarize_code_filters and msg.sent_from == any_to_str(self):
|
||||
logger.debug(f"TODO SummarizeCode:{msg.model_dump_json()}")
|
||||
await self._new_summarize_actions()
|
||||
return self.rc.todo
|
||||
return None
|
||||
return bool(self.rc.todo)
|
||||
return False
|
||||
|
||||
async def _new_coding_context(self, filename, dependency) -> Optional[CodingContext]:
|
||||
old_code_doc = await self.project_repo.srcs.get(filename)
|
||||
old_code_doc = await self.repo.srcs.get(filename)
|
||||
if not old_code_doc:
|
||||
old_code_doc = Document(root_path=str(self.project_repo.src_relative_path), filename=filename, content="")
|
||||
old_code_doc = Document(root_path=str(self.repo.src_relative_path), filename=filename, content="")
|
||||
dependencies = {Path(i) for i in await dependency.get(old_code_doc.root_relative_path)}
|
||||
task_doc = None
|
||||
design_doc = None
|
||||
code_plan_and_change_doc = await self._get_any_code_plan_and_change() if await self._is_fixbug() else None
|
||||
for i in dependencies:
|
||||
if str(i.parent) == TASK_FILE_REPO:
|
||||
task_doc = await self.project_repo.docs.task.get(i.name)
|
||||
task_doc = await self.repo.docs.task.get(i.name)
|
||||
elif str(i.parent) == SYSTEM_DESIGN_FILE_REPO:
|
||||
design_doc = await self.project_repo.docs.system_design.get(i.name)
|
||||
design_doc = await self.repo.docs.system_design.get(i.name)
|
||||
elif str(i.parent) == CODE_PLAN_AND_CHANGE_FILE_REPO:
|
||||
code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(i.name)
|
||||
code_plan_and_change_doc = await self.repo.docs.code_plan_and_change.get(i.name)
|
||||
if not task_doc or not design_doc:
|
||||
if filename == "__init__.py": # `__init__.py` created by `init_python_folder`
|
||||
return None
|
||||
|
|
@ -318,34 +321,66 @@ class Engineer(Role):
|
|||
if not context:
|
||||
return None # `__init__.py` created by `init_python_folder`
|
||||
coding_doc = Document(
|
||||
root_path=str(self.project_repo.src_relative_path), filename=filename, content=context.model_dump_json()
|
||||
root_path=str(self.repo.src_relative_path), filename=filename, content=context.model_dump_json()
|
||||
)
|
||||
return coding_doc
|
||||
|
||||
async def _new_code_actions(self):
|
||||
bug_fix = await self._is_fixbug()
|
||||
# Prepare file repos
|
||||
changed_src_files = self.project_repo.srcs.changed_files
|
||||
changed_src_files = self.repo.srcs.changed_files
|
||||
if self.context.kwargs.src_filename:
|
||||
changed_src_files = {self.context.kwargs.src_filename: ChangeType.UNTRACTED}
|
||||
if bug_fix:
|
||||
changed_src_files = self.project_repo.srcs.all_files
|
||||
changed_task_files = self.project_repo.docs.task.changed_files
|
||||
changed_src_files = self.repo.srcs.all_files
|
||||
changed_files = Documents()
|
||||
# Recode caused by upstream changes.
|
||||
for filename in changed_task_files:
|
||||
design_doc = await self.project_repo.docs.system_design.get(filename)
|
||||
task_doc = await self.project_repo.docs.task.get(filename)
|
||||
code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(filename)
|
||||
if hasattr(self.input_args, "changed_task_filenames"):
|
||||
changed_task_filenames = self.input_args.changed_task_filenames
|
||||
else:
|
||||
changed_task_filenames = [
|
||||
str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys())
|
||||
]
|
||||
for filename in changed_task_filenames:
|
||||
task_filename = Path(filename)
|
||||
design_filename = None
|
||||
if hasattr(self.input_args, "changed_system_design_filenames"):
|
||||
changed_system_design_filenames = self.input_args.changed_system_design_filenames
|
||||
else:
|
||||
changed_system_design_filenames = [
|
||||
str(self.repo.docs.system_design.workdir / i)
|
||||
for i in list(self.repo.docs.system_design.changed_files.keys())
|
||||
]
|
||||
for i in changed_system_design_filenames:
|
||||
if task_filename.name == Path(i).name:
|
||||
design_filename = Path(i)
|
||||
break
|
||||
code_plan_and_change_filename = None
|
||||
if hasattr(self.input_args, "changed_code_plan_and_change_filenames"):
|
||||
changed_code_plan_and_change_filenames = self.input_args.changed_code_plan_and_change_filenames
|
||||
else:
|
||||
changed_code_plan_and_change_filenames = [
|
||||
str(self.repo.docs.code_plan_and_change.workdir / i)
|
||||
for i in list(self.repo.docs.code_plan_and_change.changed_files.keys())
|
||||
]
|
||||
for i in changed_code_plan_and_change_filenames:
|
||||
if task_filename.name == Path(i).name:
|
||||
code_plan_and_change_filename = Path(i)
|
||||
break
|
||||
design_doc = await Document.load(filename=design_filename, project_path=self.repo.workdir)
|
||||
task_doc = await Document.load(filename=task_filename, project_path=self.repo.workdir)
|
||||
code_plan_and_change_doc = await Document.load(
|
||||
filename=code_plan_and_change_filename, project_path=self.repo.workdir
|
||||
)
|
||||
task_list = self._parse_tasks(task_doc)
|
||||
await self._init_python_folder(task_list)
|
||||
for task_filename in task_list:
|
||||
if self.context.kwargs.src_filename and task_filename != self.context.kwargs.src_filename:
|
||||
continue
|
||||
old_code_doc = await self.project_repo.srcs.get(task_filename)
|
||||
old_code_doc = await self.repo.srcs.get(task_filename)
|
||||
if not old_code_doc:
|
||||
old_code_doc = Document(
|
||||
root_path=str(self.project_repo.src_relative_path), filename=task_filename, content=""
|
||||
root_path=str(self.repo.src_relative_path), filename=task_filename, content=""
|
||||
)
|
||||
if not code_plan_and_change_doc:
|
||||
context = CodingContext(
|
||||
|
|
@ -360,7 +395,7 @@ class Engineer(Role):
|
|||
code_plan_and_change_doc=code_plan_and_change_doc,
|
||||
)
|
||||
coding_doc = Document(
|
||||
root_path=str(self.project_repo.src_relative_path),
|
||||
root_path=str(self.repo.src_relative_path),
|
||||
filename=task_filename,
|
||||
content=context.model_dump_json(),
|
||||
)
|
||||
|
|
@ -371,10 +406,11 @@ class Engineer(Role):
|
|||
)
|
||||
changed_files.docs[task_filename] = coding_doc
|
||||
self.code_todos = [
|
||||
WriteCode(i_context=i, context=self.context, llm=self.llm) for i in changed_files.docs.values()
|
||||
WriteCode(i_context=i, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm)
|
||||
for i in changed_files.docs.values()
|
||||
]
|
||||
# Code directly modified by the user.
|
||||
dependency = await self.git_repo.get_dependency()
|
||||
dependency = await self.repo.git_repo.get_dependency()
|
||||
for filename in changed_src_files:
|
||||
if filename in changed_files.docs:
|
||||
continue
|
||||
|
|
@ -382,24 +418,30 @@ class Engineer(Role):
|
|||
if not coding_doc:
|
||||
continue # `__init__.py` created by `init_python_folder`
|
||||
changed_files.docs[filename] = coding_doc
|
||||
self.code_todos.append(WriteCode(i_context=coding_doc, context=self.context, llm=self.llm))
|
||||
self.code_todos.append(
|
||||
WriteCode(
|
||||
i_context=coding_doc, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm
|
||||
)
|
||||
)
|
||||
|
||||
if self.code_todos:
|
||||
self.set_todo(self.code_todos[0])
|
||||
|
||||
async def _new_summarize_actions(self):
|
||||
src_files = self.project_repo.srcs.all_files
|
||||
src_files = self.repo.srcs.all_files
|
||||
# Generate a SummarizeCode action for each pair of (system_design_doc, task_doc).
|
||||
summarizations = defaultdict(list)
|
||||
for filename in src_files:
|
||||
dependencies = await self.project_repo.srcs.get_dependency(filename=filename)
|
||||
dependencies = await self.repo.srcs.get_dependency(filename=filename)
|
||||
ctx = CodeSummarizeContext.loads(filenames=list(dependencies))
|
||||
summarizations[ctx].append(filename)
|
||||
for ctx, filenames in summarizations.items():
|
||||
if not ctx.design_filename or not ctx.task_filename:
|
||||
continue # cause by `__init__.py` which is created by `init_python_folder`
|
||||
ctx.codes_filenames = filenames
|
||||
new_summarize = SummarizeCode(i_context=ctx, context=self.context, llm=self.llm)
|
||||
new_summarize = SummarizeCode(
|
||||
i_context=ctx, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm
|
||||
)
|
||||
for i, act in enumerate(self.summarize_todos):
|
||||
if act.i_context.task_filename == new_summarize.i_context.task_filename:
|
||||
self.summarize_todos[i] = new_summarize
|
||||
|
|
@ -412,16 +454,37 @@ class Engineer(Role):
|
|||
|
||||
async def _new_code_plan_and_change_action(self, cause_by: str):
|
||||
"""Create a WriteCodePlanAndChange action for subsequent to-do actions."""
|
||||
files = self.project_repo.all_files
|
||||
options = {}
|
||||
if cause_by != any_to_str(FixBug):
|
||||
requirement_doc = await self.project_repo.docs.get(REQUIREMENT_FILENAME)
|
||||
requirement_doc = await Document.load(filename=self.input_args.requirements_filename)
|
||||
options["requirement"] = requirement_doc.content
|
||||
else:
|
||||
fixbug_doc = await self.project_repo.docs.get(BUGFIX_FILENAME)
|
||||
fixbug_doc = await Document.load(filename=self.input_args.issue_filename)
|
||||
options["issue"] = fixbug_doc.content
|
||||
code_plan_and_change_ctx = CodePlanAndChangeContext.loads(files, **options)
|
||||
self.rc.todo = WriteCodePlanAndChange(i_context=code_plan_and_change_ctx, context=self.context, llm=self.llm)
|
||||
# The code here is flawed: if there are multiple unrelated requirements, this piece of logic will break
|
||||
if hasattr(self.input_args, "changed_prd_filenames"):
|
||||
code_plan_and_change_ctx = CodePlanAndChangeContext(
|
||||
requirement=options.get("requirement", ""),
|
||||
issue=options.get("issue", ""),
|
||||
prd_filename=self.input_args.changed_prd_filenames[0],
|
||||
design_filename=self.input_args.changed_system_design_filenames[0],
|
||||
task_filename=self.input_args.changed_task_filenames[0],
|
||||
)
|
||||
else:
|
||||
code_plan_and_change_ctx = CodePlanAndChangeContext(
|
||||
requirement=options.get("requirement", ""),
|
||||
issue=options.get("issue", ""),
|
||||
prd_filename=str(self.repo.docs.prd.workdir / self.repo.docs.prd.all_files[0]),
|
||||
design_filename=str(self.repo.docs.system_design.workdir / self.repo.docs.system_design.all_files[0]),
|
||||
task_filename=str(self.repo.docs.task.workdir / self.repo.docs.task.all_files[0]),
|
||||
)
|
||||
self.rc.todo = WriteCodePlanAndChange(
|
||||
i_context=code_plan_and_change_ctx,
|
||||
repo=self.repo,
|
||||
input_args=self.input_args,
|
||||
context=self.context,
|
||||
llm=self.llm,
|
||||
)
|
||||
|
||||
@property
|
||||
def action_description(self) -> str:
|
||||
|
|
@ -433,17 +496,16 @@ class Engineer(Role):
|
|||
filename = Path(i)
|
||||
if filename.suffix != ".py":
|
||||
continue
|
||||
workdir = self.src_workspace / filename.parent
|
||||
workdir = self.repo.srcs.workdir / filename.parent
|
||||
await init_python_folder(workdir)
|
||||
|
||||
async def _is_fixbug(self) -> bool:
|
||||
fixbug_doc = await self.project_repo.docs.get(BUGFIX_FILENAME)
|
||||
return bool(fixbug_doc and fixbug_doc.content)
|
||||
return bool(self.input_args and hasattr(self.input_args, "issue_filename"))
|
||||
|
||||
async def _get_any_code_plan_and_change(self) -> Optional[Document]:
|
||||
changed_files = self.project_repo.docs.code_plan_and_change.changed_files
|
||||
changed_files = self.repo.docs.code_plan_and_change.changed_files
|
||||
for filename in changed_files.keys():
|
||||
doc = await self.project_repo.docs.code_plan_and_change.get(filename)
|
||||
doc = await self.repo.docs.code_plan_and_change.get(filename)
|
||||
if doc and doc.content:
|
||||
return doc
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -7,10 +7,12 @@
|
|||
@Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135.
|
||||
"""
|
||||
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.roles.role import Role, RoleReactMode
|
||||
from metagpt.utils.common import any_to_name, any_to_str
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
||||
class ProductManager(Role):
|
||||
|
|
@ -40,7 +42,7 @@ class ProductManager(Role):
|
|||
|
||||
async def _think(self) -> bool:
|
||||
"""Decide what to do"""
|
||||
if self.git_repo and not self.config.git_reinit:
|
||||
if GitRepository.is_git_dir(self.config.project_path) and not self.config.git_reinit:
|
||||
self._set_state(1)
|
||||
else:
|
||||
self._set_state(0)
|
||||
|
|
|
|||
|
|
@ -6,11 +6,9 @@
|
|||
@File : project_manager.py
|
||||
"""
|
||||
|
||||
from metagpt.actions import UserRequirement, WriteTasks
|
||||
from metagpt.actions import WriteTasks
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
||||
|
||||
class ProjectManager(Role):
|
||||
|
|
@ -35,20 +33,5 @@ class ProjectManager(Role):
|
|||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.enable_memory = False
|
||||
self.set_actions([PrepareDocuments(send_to=any_to_str(self), context=self.context), WriteTasks])
|
||||
self._watch([UserRequirement, PrepareDocuments, WriteDesign])
|
||||
|
||||
async def _think(self) -> bool:
|
||||
"""Decide what to do"""
|
||||
mappings = {
|
||||
any_to_str(UserRequirement): 0,
|
||||
any_to_str(PrepareDocuments): 1,
|
||||
any_to_str(WriteDesign): 1,
|
||||
}
|
||||
for i in self.rc.news:
|
||||
idx = mappings.get(i.cause_by, -1)
|
||||
if idx < 0:
|
||||
continue
|
||||
self.rc.todo = self.actions[idx]
|
||||
return bool(self.rc.todo)
|
||||
return False
|
||||
self.set_actions([WriteTasks])
|
||||
self._watch([WriteDesign])
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@
|
|||
@Modified By: mashenquan, 2023-12-5. Enhance the workflow to navigate to WriteCode or QaEngineer based on the results
|
||||
of SummarizeCode.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions import DebugError, RunCode, UserRequirement, WriteTest
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
|
|
@ -25,9 +28,11 @@ from metagpt.schema import AIMessage, Document, Message, RunCodeContext, Testing
|
|||
from metagpt.utils.common import (
|
||||
any_to_str,
|
||||
any_to_str_set,
|
||||
get_project_srcs_path,
|
||||
init_python_folder,
|
||||
parse_recipient,
|
||||
)
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from metagpt.utils.report import EditorReporter
|
||||
|
||||
|
||||
|
|
@ -41,6 +46,8 @@ class QaEngineer(Role):
|
|||
)
|
||||
test_round_allowed: int = 5
|
||||
test_round: int = 0
|
||||
repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
|
||||
input_args: Optional[BaseModel] = Field(default=None, exclude=True)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
|
@ -48,31 +55,26 @@ class QaEngineer(Role):
|
|||
|
||||
# FIXME: a bit hack here, only init one action to circumvent _think() logic,
|
||||
# will overwrite _think() in future updates
|
||||
self.set_actions(
|
||||
[
|
||||
WriteTest,
|
||||
]
|
||||
)
|
||||
self._watch([UserRequirement, PrepareDocuments, SummarizeCode, WriteTest, RunCode, DebugError])
|
||||
self.set_actions([WriteTest])
|
||||
self._watch([SummarizeCode, WriteTest, RunCode, DebugError])
|
||||
self.test_round = 0
|
||||
|
||||
async def _write_test(self, message: Message) -> None:
|
||||
src_file_repo = self.project_repo.with_src_path(self.context.src_workspace).srcs
|
||||
reqa_file = self.context.kwargs.reqa_file or self.config.reqa_file
|
||||
changed_files = {reqa_file} if reqa_file else set(src_file_repo.changed_files.keys())
|
||||
changed_files = {reqa_file} if reqa_file else set(self.repo.srcs.changed_files.keys())
|
||||
for filename in changed_files:
|
||||
# write tests
|
||||
if not filename or "test" in filename:
|
||||
continue
|
||||
code_doc = await src_file_repo.get(filename)
|
||||
if not code_doc:
|
||||
code_doc = await self.repo.srcs.get(filename)
|
||||
if not code_doc or not code_doc.content:
|
||||
continue
|
||||
if not code_doc.filename.endswith(".py"):
|
||||
continue
|
||||
test_doc = await self.project_repo.tests.get("test_" + code_doc.filename)
|
||||
test_doc = await self.repo.tests.get("test_" + code_doc.filename)
|
||||
if not test_doc:
|
||||
test_doc = Document(
|
||||
root_path=str(self.project_repo.tests.root_path), filename="test_" + code_doc.filename, content=""
|
||||
root_path=str(self.repo.tests.root_path), filename="test_" + code_doc.filename, content=""
|
||||
)
|
||||
logger.info(f"Writing {test_doc.filename}..")
|
||||
context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc)
|
||||
|
|
@ -81,40 +83,38 @@ class QaEngineer(Role):
|
|||
async with EditorReporter(enable_llm_stream=True) as reporter:
|
||||
await reporter.async_report({"type": "test", "filename": test_doc.filename}, "meta")
|
||||
|
||||
doc = await self.project_repo.tests.save_doc(
|
||||
doc = await self.repo.tests.save_doc(
|
||||
doc=context.test_doc, dependencies={context.code_doc.root_relative_path}
|
||||
)
|
||||
await reporter.async_report(self.project_repo.workdir / doc.root_relative_path, "path")
|
||||
await reporter.async_report(self.repo.workdir / doc.root_relative_path, "path")
|
||||
|
||||
# prepare context for run tests in next round
|
||||
run_code_context = RunCodeContext(
|
||||
command=["python", context.test_doc.root_relative_path],
|
||||
code_filename=context.code_doc.filename,
|
||||
test_filename=context.test_doc.filename,
|
||||
working_directory=str(self.project_repo.workdir),
|
||||
additional_python_paths=[str(self.context.src_workspace)],
|
||||
working_directory=str(self.repo.workdir),
|
||||
additional_python_paths=[str(self.repo.srcs.workdir)],
|
||||
)
|
||||
self.publish_message(
|
||||
AIMessage(content=run_code_context.model_dump_json(), cause_by=WriteTest, send_to=MESSAGE_ROUTE_TO_SELF)
|
||||
)
|
||||
|
||||
logger.info(f"Done {str(self.project_repo.tests.workdir)} generating.")
|
||||
logger.info(f"Done {str(self.repo.tests.workdir)} generating.")
|
||||
|
||||
async def _run_code(self, msg):
|
||||
run_code_context = RunCodeContext.loads(msg.content)
|
||||
src_doc = await self.project_repo.with_src_path(self.context.src_workspace).srcs.get(
|
||||
run_code_context.code_filename
|
||||
)
|
||||
src_doc = await self.repo.srcs.get(run_code_context.code_filename)
|
||||
if not src_doc:
|
||||
return
|
||||
test_doc = await self.project_repo.tests.get(run_code_context.test_filename)
|
||||
test_doc = await self.repo.tests.get(run_code_context.test_filename)
|
||||
if not test_doc:
|
||||
return
|
||||
run_code_context.code = src_doc.content
|
||||
run_code_context.test_code = test_doc.content
|
||||
result = await RunCode(i_context=run_code_context, context=self.context, llm=self.llm).run()
|
||||
run_code_context.output_filename = run_code_context.test_filename + ".json"
|
||||
await self.project_repo.test_outputs.save(
|
||||
await self.repo.test_outputs.save(
|
||||
filename=run_code_context.output_filename,
|
||||
content=result.model_dump_json(),
|
||||
dependencies={src_doc.root_relative_path, test_doc.root_relative_path},
|
||||
|
|
@ -124,31 +124,53 @@ class QaEngineer(Role):
|
|||
# the recipient might be Engineer or myself
|
||||
recipient = parse_recipient(result.summary)
|
||||
mappings = {"Engineer": "Alex", "QaEngineer": "Edward"}
|
||||
self.publish_message(
|
||||
AIMessage(
|
||||
content=run_code_context.model_dump_json(),
|
||||
cause_by=RunCode,
|
||||
send_to=mappings.get(recipient, MESSAGE_ROUTE_TO_NONE),
|
||||
if recipient != "Engineer":
|
||||
self.publish_message(
|
||||
AIMessage(
|
||||
content=run_code_context.model_dump_json(),
|
||||
cause_by=RunCode,
|
||||
instruct_content=self.input_args,
|
||||
send_to=MESSAGE_ROUTE_TO_SELF,
|
||||
)
|
||||
)
|
||||
else:
|
||||
kvs = self.input_args.model_dump()
|
||||
kvs["changed_test_filenames"] = [
|
||||
str(self.repo.tests.workdir / i) for i in list(self.repo.tests.changed_files.keys())
|
||||
]
|
||||
self.publish_message(
|
||||
AIMessage(
|
||||
content=run_code_context.model_dump_json(),
|
||||
cause_by=RunCode,
|
||||
instruct_content=self.input_args,
|
||||
send_to=mappings.get(recipient, MESSAGE_ROUTE_TO_NONE),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def _debug_error(self, msg):
|
||||
run_code_context = RunCodeContext.loads(msg.content)
|
||||
code = await DebugError(i_context=run_code_context, context=self.context, llm=self.llm).run()
|
||||
await self.project_repo.tests.save(filename=run_code_context.test_filename, content=code)
|
||||
code = await DebugError(
|
||||
i_context=run_code_context, repo=self.repo, input_args=self.input_args, context=self.context, llm=self.llm
|
||||
).run()
|
||||
await self.repo.tests.save(filename=run_code_context.test_filename, content=code)
|
||||
run_code_context.output = None
|
||||
self.publish_message(
|
||||
AIMessage(content=run_code_context.model_dump_json(), cause_by=DebugError, send_to=MESSAGE_ROUTE_TO_SELF)
|
||||
)
|
||||
|
||||
async def _act(self) -> Message:
|
||||
if self.project_path:
|
||||
await init_python_folder(self.project_repo.tests.workdir)
|
||||
if self.input_args.project_path:
|
||||
await init_python_folder(self.repo.tests.workdir)
|
||||
if self.test_round > self.test_round_allowed:
|
||||
kvs = self.input_args.model_dump()
|
||||
kvs["changed_test_filenames"] = [
|
||||
str(self.repo.tests.workdir / i) for i in list(self.repo.tests.changed_files.keys())
|
||||
]
|
||||
result_msg = AIMessage(
|
||||
content=f"Exceeding {self.test_round_allowed} rounds of tests, stop. "
|
||||
+ "\n".join(list(self.project_repo.tests.changed_files.keys())),
|
||||
+ "\n".join(list(self.repo.tests.changed_files.keys())),
|
||||
cause_by=WriteTest,
|
||||
instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteTestOutput"),
|
||||
send_to=MESSAGE_ROUTE_TO_NONE,
|
||||
)
|
||||
return result_msg
|
||||
|
|
@ -171,8 +193,13 @@ class QaEngineer(Role):
|
|||
elif msg.cause_by == any_to_str(UserRequirement):
|
||||
return await self._parse_user_requirement(msg)
|
||||
self.test_round += 1
|
||||
kvs = self.input_args.model_dump()
|
||||
kvs["changed_test_filenames"] = [
|
||||
str(self.repo.tests.workdir / i) for i in list(self.repo.tests.changed_files.keys())
|
||||
]
|
||||
return AIMessage(
|
||||
content=f"Round {self.test_round} of tests done",
|
||||
instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteTestOutput"),
|
||||
cause_by=WriteTest,
|
||||
send_to=MESSAGE_ROUTE_TO_NONE,
|
||||
)
|
||||
|
|
@ -190,3 +217,15 @@ class QaEngineer(Role):
|
|||
if not self.src_workspace:
|
||||
self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
|
||||
return rsp
|
||||
|
||||
async def _think(self) -> bool:
|
||||
if not self.rc.news:
|
||||
return False
|
||||
msg = self.rc.news[0]
|
||||
if msg.cause_by == any_to_str(SummarizeCode):
|
||||
self.input_args = msg.instruct_content
|
||||
self.repo = ProjectRepo(self.input_args.project_path)
|
||||
if self.repo.src_relative_path is None:
|
||||
path = get_project_srcs_path(self.repo.workdir)
|
||||
self.repo.with_src_path(path)
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -45,7 +45,6 @@ from metagpt.schema import (
|
|||
)
|
||||
from metagpt.strategy.planner import Planner
|
||||
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -196,29 +195,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
value.context = self.context
|
||||
self.rc.todo = value
|
||||
|
||||
@property
|
||||
def git_repo(self):
|
||||
"""Git repo"""
|
||||
return self.context.git_repo
|
||||
|
||||
@git_repo.setter
|
||||
def git_repo(self, value):
|
||||
self.context.git_repo = value
|
||||
|
||||
@property
|
||||
def src_workspace(self):
|
||||
"""Source workspace under git repo"""
|
||||
return self.context.src_workspace
|
||||
|
||||
@src_workspace.setter
|
||||
def src_workspace(self, value):
|
||||
self.context.src_workspace = value
|
||||
|
||||
@property
|
||||
def project_repo(self) -> ProjectRepo:
|
||||
project_repo = ProjectRepo(self.context.git_repo)
|
||||
return project_repo.with_src_path(self.context.src_workspace) if self.context.src_workspace else project_repo
|
||||
|
||||
@property
|
||||
def prompt_schema(self):
|
||||
"""Prompt schema: json/markdown"""
|
||||
|
|
@ -410,8 +386,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
msg = response
|
||||
else:
|
||||
msg = AIMessage(content=response or "", cause_by=self.rc.todo, sent_from=self)
|
||||
if self.enable_memory:
|
||||
self.rc.memory.add(msg)
|
||||
self.rc.memory.add(msg)
|
||||
|
||||
return msg
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from pydantic import (
|
|||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
create_model,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
|
|
@ -43,14 +44,19 @@ from metagpt.const import (
|
|||
MESSAGE_ROUTE_FROM,
|
||||
MESSAGE_ROUTE_TO,
|
||||
MESSAGE_ROUTE_TO_ALL,
|
||||
PRDS_FILE_REPO,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.repo_parser import DotClassInfo
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.common import CodeParser, any_to_str, any_to_str_set, import_class
|
||||
from metagpt.utils.common import (
|
||||
CodeParser,
|
||||
any_to_str,
|
||||
any_to_str_set,
|
||||
aread,
|
||||
import_class,
|
||||
)
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.report import TaskReporter
|
||||
from metagpt.utils.serialize import (
|
||||
|
|
@ -158,6 +164,30 @@ class Document(BaseModel):
|
|||
def __repr__(self):
|
||||
return self.content
|
||||
|
||||
@classmethod
|
||||
async def load(
|
||||
cls, filename: Union[str, Path], project_path: Optional[Union[str, Path]] = None
|
||||
) -> Optional["Document"]:
|
||||
"""
|
||||
Load a document from a file.
|
||||
|
||||
Args:
|
||||
filename (Union[str, Path]): The path to the file to load.
|
||||
project_path (Optional[Union[str, Path]], optional): The path to the project. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Optional[Document]: The loaded document, or None if the file does not exist.
|
||||
|
||||
"""
|
||||
if not filename or not Path(filename).exists():
|
||||
return None
|
||||
content = await aread(filename=filename)
|
||||
doc = cls(content=content, filename=str(filename))
|
||||
if project_path and Path(filename).is_relative_to(project_path):
|
||||
doc.root_path = Path(filename).relative_to(project_path).parent
|
||||
doc.filename = Path(filename).name
|
||||
return doc
|
||||
|
||||
|
||||
class Documents(BaseModel):
|
||||
"""A class representing a collection of documents.
|
||||
|
|
@ -361,6 +391,22 @@ class Message(BaseModel):
|
|||
def add_metadata(self, key: str, value: str):
|
||||
self.metadata[key] = value
|
||||
|
||||
@staticmethod
|
||||
def create_instruct_value(kvs: Dict[str, Any], class_name: str = "") -> BaseModel:
|
||||
"""
|
||||
Dynamically creates a Pydantic BaseModel subclass based on a given dictionary.
|
||||
|
||||
Parameters:
|
||||
- data: A dictionary from which to create the BaseModel subclass.
|
||||
|
||||
Returns:
|
||||
- A Pydantic BaseModel subclass instance populated with the given data.
|
||||
"""
|
||||
if not class_name:
|
||||
class_name = "DM" + uuid.uuid4().hex[0:8]
|
||||
dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()})
|
||||
return dynamic_class.model_validate(kvs)
|
||||
|
||||
|
||||
class UserMessage(Message):
|
||||
"""便于支持OpenAI的消息
|
||||
|
|
@ -787,22 +833,6 @@ class CodePlanAndChangeContext(BaseModel):
|
|||
design_filename: str = ""
|
||||
task_filename: str = ""
|
||||
|
||||
@staticmethod
|
||||
def loads(filenames: List, **kwargs) -> CodePlanAndChangeContext:
|
||||
ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", ""), issue=kwargs.get("issue", ""))
|
||||
for filename in filenames:
|
||||
filename = Path(filename)
|
||||
if filename.is_relative_to(PRDS_FILE_REPO):
|
||||
ctx.prd_filename = filename.name
|
||||
continue
|
||||
if filename.is_relative_to(SYSTEM_DESIGN_FILE_REPO):
|
||||
ctx.design_filename = filename.name
|
||||
continue
|
||||
if filename.is_relative_to(TASK_FILE_REPO):
|
||||
ctx.task_filename = filename.name
|
||||
continue
|
||||
return ctx
|
||||
|
||||
|
||||
# mermaid class view
|
||||
class UMLClassMeta(BaseModel):
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ def generate_repo(
|
|||
company.run_project(idea, send_to=any_to_str(ProductManager))
|
||||
asyncio.run(company.run(n_round=n_round))
|
||||
|
||||
return ctx.repo
|
||||
return ctx.kwargs.get("project_path")
|
||||
|
||||
|
||||
@app.command("", help="Start a new project.")
|
||||
|
|
|
|||
|
|
@ -1,9 +1,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from uuid import uuid4
|
||||
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.file import MemoryFileSystem
|
||||
from metagpt.utils.parse_html import simplify_html
|
||||
from metagpt.utils.report import BrowserReporter
|
||||
|
||||
|
||||
|
|
@ -35,16 +40,48 @@ class Browser:
|
|||
print("Now on page ", url)
|
||||
await self._view()
|
||||
|
||||
async def open_new_page(self, url: str):
|
||||
async def open_new_page(self, url: str, timeout: float = 30000):
|
||||
"""open a new page in the browser and view the page"""
|
||||
async with self.reporter as reporter:
|
||||
page = await self.browser.new_page()
|
||||
await reporter.async_report(url, "url")
|
||||
await page.goto(url)
|
||||
await page.goto(url, timeout=timeout)
|
||||
self.pages[url] = page
|
||||
await self._set_current_page(page, url)
|
||||
await reporter.async_report(page, "page")
|
||||
|
||||
async def view_page_element_to_scrape(self, requirement: str, keep_links: bool = False) -> None:
|
||||
"""view the HTML content of current page to understand the structure. When executed, the content will be printed out
|
||||
|
||||
Args:
|
||||
requirement (str): Providing a clear and detailed requirement helps in focusing the inspection on the desired elements.
|
||||
keep_links (bool): Whether to keep the hyperlinks in the HTML content. Set to True if links are required
|
||||
"""
|
||||
html = await self.current_page.content()
|
||||
html = simplify_html(html, url=self.current_page.url, keep_links=keep_links)
|
||||
mem_fs = MemoryFileSystem()
|
||||
filename = f"{uuid4().hex}.html"
|
||||
with mem_fs.open(filename, "w") as f:
|
||||
f.write(html)
|
||||
|
||||
# Since RAG is an optional optimization, if it fails, the simplified HTML can be used as a fallback.
|
||||
with contextlib.suppress(Exception):
|
||||
from metagpt.rag.engines import SimpleEngine # avoid circular import
|
||||
|
||||
# TODO make `from_docs` asynchronous
|
||||
engine = SimpleEngine.from_docs(input_files=[filename], fs=mem_fs)
|
||||
nodes = await engine.aretrieve(requirement)
|
||||
html = "\n".join(i.text for i in nodes)
|
||||
|
||||
mem_fs.rm_file(filename)
|
||||
print(html)
|
||||
|
||||
async def get_page_content(self) -> str:
|
||||
"""Get the HTML content of current page."""
|
||||
html = await self.current_page.content()
|
||||
html_content = html.strip()
|
||||
return html_content
|
||||
|
||||
async def switch_page(self, url: str):
|
||||
"""switch to an opened page in the browser and view the page"""
|
||||
if url in self.pages:
|
||||
|
|
@ -152,8 +189,8 @@ class Browser:
|
|||
|
||||
async def _view(self, keep_len: int = 5000) -> str:
|
||||
"""simulate human viewing the current page, return the visible text with links"""
|
||||
visible_text_with_links = await self.current_page.evaluate(VIEW_CONTENT_JS)
|
||||
print("The visible text and their links (if any): ", visible_text_with_links[:keep_len])
|
||||
# visible_text_with_links = await self.current_page.evaluate(VIEW_CONTENT_JS)
|
||||
# print("The visible text and their links (if any): ", visible_text_with_links[:keep_len])
|
||||
# html_content = await self._view_page_html(keep_len=keep_len)
|
||||
# print("The html content: ", html_content)
|
||||
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ class Editor:
|
|||
file_path=file_path,
|
||||
block_content=block_content,
|
||||
)
|
||||
self.resource.report(result.file_path, "path")
|
||||
self.resource.report(result.file_path, "path", extra={"type": "search", "line": i, "symbol": symbol})
|
||||
return result
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from github.Issue import Issue
|
|||
from github.PullRequest import PullRequest
|
||||
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.git_repository import GitBranch, GitRepository
|
||||
|
||||
|
||||
@register_tool(tags=["software development", "git", "Commit the changes and push to remote git repository."])
|
||||
|
|
@ -18,7 +17,7 @@ async def git_push(
|
|||
access_token: str,
|
||||
comments: str = "Commit",
|
||||
new_branch: str = "",
|
||||
) -> GitBranch:
|
||||
) -> "GitBranch":
|
||||
"""
|
||||
Pushes changes from a local Git repository to its remote counterpart.
|
||||
|
||||
|
|
@ -49,6 +48,8 @@ async def git_push(
|
|||
base branch:'master', head branch:'feature/new', repo_name:'iorisa/snake-game'
|
||||
|
||||
"""
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
if not GitRepository.is_git_dir(local_path):
|
||||
raise ValueError("Invalid local git repository")
|
||||
|
||||
|
|
|
|||
|
|
@ -20,3 +20,18 @@ def run_coroutine_in_new_loop(coroutine) -> Any:
|
|||
new_loop.call_soon_threadsafe(new_loop.stop)
|
||||
t.join()
|
||||
new_loop.close()
|
||||
|
||||
|
||||
class NestAsyncio:
|
||||
"""Make asyncio event loop reentrant."""
|
||||
|
||||
is_applied = False
|
||||
|
||||
@classmethod
|
||||
def apply_once(cls):
|
||||
"""Ensures `nest_asyncio.apply()` is called only once."""
|
||||
if not cls.is_applied:
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
cls.is_applied = True
|
||||
|
|
|
|||
|
|
@ -646,7 +646,7 @@ def role_raise_decorator(func):
|
|||
raise Exception(format_trackback_info(limit=None))
|
||||
except Exception as e:
|
||||
if self.latest_observed_msg:
|
||||
logger.warning(
|
||||
logger.exception(
|
||||
"There is a exception in role's execution, in order to resume, "
|
||||
"we delete the newest role communication message in the role's memory."
|
||||
)
|
||||
|
|
@ -667,6 +667,8 @@ def role_raise_decorator(func):
|
|||
@handle_exception
|
||||
async def aread(filename: str | Path, encoding="utf-8") -> str:
|
||||
"""Read file asynchronously."""
|
||||
if not filename or not Path(filename).exists():
|
||||
return ""
|
||||
try:
|
||||
async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
|
||||
content = await reader.read()
|
||||
|
|
@ -899,3 +901,51 @@ async def init_python_folder(workdir: str | Path):
|
|||
return
|
||||
async with aiofiles.open(init_filename, "a"):
|
||||
os.utime(init_filename, None)
|
||||
|
||||
|
||||
def get_markdown_code_block_type(filename: str) -> str:
|
||||
if not filename:
|
||||
return ""
|
||||
ext = Path(filename).suffix
|
||||
types = {
|
||||
".py": "python",
|
||||
".js": "javascript",
|
||||
".java": "java",
|
||||
".cpp": "cpp",
|
||||
".c": "c",
|
||||
".html": "html",
|
||||
".css": "css",
|
||||
".xml": "xml",
|
||||
".json": "json",
|
||||
".yaml": "yaml",
|
||||
".md": "markdown",
|
||||
".sql": "sql",
|
||||
".rb": "ruby",
|
||||
".php": "php",
|
||||
".sh": "bash",
|
||||
".swift": "swift",
|
||||
".go": "go",
|
||||
".rs": "rust",
|
||||
".pl": "perl",
|
||||
".asm": "assembly",
|
||||
".r": "r",
|
||||
".scss": "scss",
|
||||
".sass": "sass",
|
||||
".lua": "lua",
|
||||
".ts": "typescript",
|
||||
".tsx": "tsx",
|
||||
".jsx": "jsx",
|
||||
".yml": "yaml",
|
||||
".ini": "ini",
|
||||
".toml": "toml",
|
||||
".svg": "xml", # SVG can often be treated as XML
|
||||
# Add more file extensions and corresponding code block types as needed
|
||||
}
|
||||
return types.get(ext, "")
|
||||
|
||||
|
||||
def to_markdown_code_block(val: str, type_: str = "") -> str:
|
||||
if not val:
|
||||
return val or ""
|
||||
val = val.replace("```", "\\`\\`\\`")
|
||||
return f"\n```{type_}\n{val}\n```\n"
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
from fsspec.implementations.memory import MemoryFileSystem as _MemoryFileSystem
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
|
@ -68,3 +69,9 @@ class File:
|
|||
content = b"".join(chunks)
|
||||
logger.debug(f"Successfully read file, the path of file: {file_path}")
|
||||
return content
|
||||
|
||||
|
||||
class MemoryFileSystem(_MemoryFileSystem):
|
||||
@classmethod
|
||||
def _strip_protocol(cls, path):
|
||||
return super()._strip_protocol(str(path))
|
||||
|
|
|
|||
|
|
@ -156,6 +156,8 @@ class GitRepository:
|
|||
:param local_path: The local path to check.
|
||||
:return: True if the directory is a Git repository, False otherwise.
|
||||
"""
|
||||
if not local_path:
|
||||
return False
|
||||
git_dir = Path(local_path) / ".git"
|
||||
if git_dir.exists() and is_git_dir(git_dir):
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||
from typing import Generator, Optional
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import htmlmin
|
||||
from bs4 import BeautifulSoup
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
|
|
@ -38,6 +39,22 @@ class WebPage(BaseModel):
|
|||
elif url.startswith(("http://", "https://")):
|
||||
yield urljoin(self.url, url)
|
||||
|
||||
def get_slim_soup(self, keep_links: bool = False):
|
||||
soup = _get_soup(self.html)
|
||||
keep_attrs = ["class"]
|
||||
if keep_links:
|
||||
keep_attrs.append("href")
|
||||
|
||||
for i in soup.find_all(True):
|
||||
for name in list(i.attrs):
|
||||
if i[name] and name not in keep_attrs:
|
||||
del i[name]
|
||||
|
||||
for i in soup.find_all(["svg", "img", "video", "audio"]):
|
||||
i.decompose()
|
||||
|
||||
return soup
|
||||
|
||||
|
||||
def get_html_content(page: str, base: str):
|
||||
soup = _get_soup(page)
|
||||
|
|
@ -48,7 +65,12 @@ def get_html_content(page: str, base: str):
|
|||
def _get_soup(page: str):
|
||||
soup = BeautifulSoup(page, "html.parser")
|
||||
# https://stackoverflow.com/questions/1936466/how-to-scrape-only-visible-webpage-text-with-beautifulsoup
|
||||
for s in soup(["style", "script", "[document]", "head", "title"]):
|
||||
for s in soup(["style", "script", "[document]", "head", "title", "footer"]):
|
||||
s.extract()
|
||||
|
||||
return soup
|
||||
|
||||
|
||||
def simplify_html(html: str, url: str, keep_links: bool = False):
|
||||
html = WebPage(inner_text="", html=html, url=url).get_slim_soup(keep_links).decode()
|
||||
return htmlmin.minify(html, remove_comments=True, remove_empty_space=True)
|
||||
|
|
|
|||
|
|
@ -140,10 +140,11 @@ class ProjectRepo(FileRepository):
|
|||
return bool(code_files)
|
||||
|
||||
def with_src_path(self, path: str | Path) -> ProjectRepo:
|
||||
try:
|
||||
self._srcs_path = Path(path).relative_to(self.workdir)
|
||||
except ValueError:
|
||||
self._srcs_path = Path(path)
|
||||
path = Path(path)
|
||||
if path.is_relative_to(self.workdir):
|
||||
self._srcs_path = path.relative_to(self.workdir)
|
||||
else:
|
||||
self._srcs_path = path
|
||||
return self
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ class BlockType(str, Enum):
|
|||
GALLERY = "Gallery"
|
||||
NOTEBOOK = "Notebook"
|
||||
DOCS = "Docs"
|
||||
THOUGHT = "Thought"
|
||||
|
||||
|
||||
END_MARKER_NAME = "end_marker"
|
||||
|
|
@ -55,23 +56,23 @@ class ResourceReporter(BaseModel):
|
|||
callback_url: str = Field(METAGPT_REPORTER_DEFAULT_URL, description="The URL to which the report should be sent")
|
||||
_llm_task: Optional[asyncio.Task] = PrivateAttr(None)
|
||||
|
||||
def report(self, value: Any, name: str):
|
||||
def report(self, value: Any, name: str, extra: Optional[dict] = None):
|
||||
"""Synchronously report resource observation data.
|
||||
|
||||
Args:
|
||||
value: The data to report.
|
||||
name: The type name of the data.
|
||||
"""
|
||||
return self._report(value, name)
|
||||
return self._report(value, name, extra)
|
||||
|
||||
async def async_report(self, value: Any, name: str):
|
||||
async def async_report(self, value: Any, name: str, extra: Optional[dict] = None):
|
||||
"""Asynchronously report resource observation data.
|
||||
|
||||
Args:
|
||||
value: The data to report.
|
||||
name: The type name of the data.
|
||||
"""
|
||||
return await self._async_report(value, name)
|
||||
return await self._async_report(value, name, extra)
|
||||
|
||||
@classmethod
|
||||
def set_report_fn(cls, fn: Callable):
|
||||
|
|
@ -100,20 +101,20 @@ class ResourceReporter(BaseModel):
|
|||
"""
|
||||
cls._async_report = fn
|
||||
|
||||
def _report(self, value: Any, name: str):
|
||||
def _report(self, value: Any, name: str, extra: Optional[dict] = None):
|
||||
if not self.callback_url:
|
||||
return
|
||||
|
||||
data = self._format_data(value, name)
|
||||
data = self._format_data(value, name, extra)
|
||||
resp = requests.post(self.callback_url, json=data)
|
||||
resp.raise_for_status()
|
||||
return resp.text
|
||||
|
||||
async def _async_report(self, value: Any, name: str):
|
||||
async def _async_report(self, value: Any, name: str, extra: Optional[dict] = None):
|
||||
if not self.callback_url:
|
||||
return
|
||||
|
||||
data = self._format_data(value, name)
|
||||
data = self._format_data(value, name, extra)
|
||||
url = self.callback_url
|
||||
_result = urlparse(url)
|
||||
sessiion_kwargs = {}
|
||||
|
|
@ -129,9 +130,16 @@ class ResourceReporter(BaseModel):
|
|||
resp.raise_for_status()
|
||||
return await resp.text()
|
||||
|
||||
def _format_data(self, value, name):
|
||||
def _format_data(self, value, name, extra):
|
||||
data = self.model_dump(mode="json", exclude=("callback_url", "llm_stream"))
|
||||
data["value"] = str(value) if isinstance(value, Path) else value
|
||||
if isinstance(value, BaseModel):
|
||||
value = value.model_dump(mode="json")
|
||||
elif isinstance(value, Path):
|
||||
value = str(value)
|
||||
|
||||
if name == "path":
|
||||
value = os.path.abspath(value)
|
||||
data["value"] = value
|
||||
data["name"] = name
|
||||
role = CURRENT_ROLE.get(None)
|
||||
if role:
|
||||
|
|
@ -139,6 +147,8 @@ class ResourceReporter(BaseModel):
|
|||
else:
|
||||
role_name = os.environ.get("METAGPT_ROLE")
|
||||
data["role"] = role_name
|
||||
if extra:
|
||||
data["extra"] = extra
|
||||
return data
|
||||
|
||||
def __enter__(self):
|
||||
|
|
@ -252,6 +262,16 @@ class TaskReporter(ObjectReporter):
|
|||
block: Literal[BlockType.TASK] = BlockType.TASK
|
||||
|
||||
|
||||
class ThoughtReporter(ObjectReporter):
|
||||
"""Reporter for object resources to Task Block."""
|
||||
|
||||
block: Literal[BlockType.THOUGHT] = BlockType.THOUGHT
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.async_report({})
|
||||
return await super().__aenter__()
|
||||
|
||||
|
||||
class FileReporter(ResourceReporter):
|
||||
"""File resource callback for reporting complete file paths.
|
||||
|
||||
|
|
@ -259,13 +279,23 @@ class FileReporter(ResourceReporter):
|
|||
if the file can be partially output for display first, use streaming callback.
|
||||
"""
|
||||
|
||||
def report(self, value: Union[Path, dict, Any], name: Literal["path", "meta", "content"] = "path"):
|
||||
def report(
|
||||
self,
|
||||
value: Union[Path, dict, Any],
|
||||
name: Literal["path", "meta", "content"] = "path",
|
||||
extra: Optional[dict] = None,
|
||||
):
|
||||
"""Report file resource synchronously."""
|
||||
return super().report(value, name)
|
||||
return super().report(value, name, extra)
|
||||
|
||||
async def async_report(self, value: Path, name: Literal["path", "meta", "content"] = "path"):
|
||||
async def async_report(
|
||||
self,
|
||||
value: Union[Path, dict, Any],
|
||||
name: Literal["path", "meta", "content"] = "path",
|
||||
extra: Optional[dict] = None,
|
||||
):
|
||||
"""Report file resource asynchronously."""
|
||||
return await super().async_report(value, name)
|
||||
return await super().async_report(value, name, extra)
|
||||
|
||||
|
||||
class NotebookReporter(FileReporter):
|
||||
|
|
|
|||
|
|
@ -71,4 +71,6 @@ dashscope==1.14.1
|
|||
rank-bm25==0.2.2 # for tool recommendation
|
||||
gymnasium==0.29.1
|
||||
pylint~=3.0.3
|
||||
pygithub~=2.3
|
||||
pygithub~=2.3
|
||||
htmlmin
|
||||
fsspec
|
||||
|
|
|
|||
3
setup.py
3
setup.py
|
|
@ -32,12 +32,15 @@ extras_require = {
|
|||
"llama-index-core==0.10.15",
|
||||
"llama-index-embeddings-azure-openai==0.1.6",
|
||||
"llama-index-embeddings-openai==0.1.5",
|
||||
"llama-index-embeddings-gemini==0.1.6",
|
||||
"llama-index-embeddings-ollama==0.1.2",
|
||||
"llama-index-llms-azure-openai==0.1.4",
|
||||
"llama-index-readers-file==0.1.4",
|
||||
"llama-index-retrievers-bm25==0.1.3",
|
||||
"llama-index-vector-stores-faiss==0.1.1",
|
||||
"llama-index-vector-stores-elasticsearch==0.1.6",
|
||||
"llama-index-vector-stores-chroma==0.1.6",
|
||||
"docx2txt==0.8",
|
||||
],
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import aiohttp.web
|
||||
|
|
@ -23,7 +22,6 @@ from metagpt.context import Context as MetagptContext
|
|||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from tests.mock.mock_aiohttp import MockAioResponse
|
||||
from tests.mock.mock_curl_cffi import MockCurlCffiResponse
|
||||
from tests.mock.mock_httplib2 import MockHttplib2Response
|
||||
|
|
@ -149,13 +147,14 @@ def loguru_caplog(caplog):
|
|||
@pytest.fixture(scope="function")
|
||||
def context(request):
|
||||
ctx = MetagptContext()
|
||||
ctx.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}")
|
||||
ctx.repo = ProjectRepo(ctx.git_repo)
|
||||
repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}")
|
||||
ctx.config.project_path = str(repo.workdir)
|
||||
|
||||
# Destroy git repo at the end of the test session.
|
||||
def fin():
|
||||
if ctx.git_repo:
|
||||
ctx.git_repo.delete_repository()
|
||||
if ctx.config.project_path:
|
||||
git_repo = GitRepository(ctx.config.project_path)
|
||||
git_repo.delete_repository()
|
||||
|
||||
# Register the function for destroying the environment.
|
||||
request.addfinalizer(fin)
|
||||
|
|
@ -279,6 +278,6 @@ def mermaid_mocker(aiohttp_mocker, mermaid_rsp_cache):
|
|||
@pytest.fixture
|
||||
def git_dir():
|
||||
"""Fixture to get the unittest directory."""
|
||||
git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}"
|
||||
git_dir = DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}"
|
||||
git_dir.mkdir(parents=True, exist_ok=True)
|
||||
return git_dir
|
||||
|
|
|
|||
|
|
@ -303,5 +303,4 @@ def test_action_node_from_pydantic_and_print_everything():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_create_model_class()
|
||||
test_create_model_class_with_mapping()
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -6,37 +6,104 @@
|
|||
@File : test_design_api.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.schema import AIMessage, Message
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from tests.data.incremental_dev_project.mock import DESIGN_SAMPLE, REFINED_PRD_JSON
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api(context):
|
||||
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"] # PRD_SAMPLE
|
||||
for prd in inputs:
|
||||
await context.repo.docs.prd.save(filename="new_prd.txt", content=prd)
|
||||
async def test_design(context):
|
||||
# Mock new design env
|
||||
prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"
|
||||
context.kwargs.project_path = context.config.project_path
|
||||
context.kwargs.inc = False
|
||||
filename = "prd.txt"
|
||||
repo = ProjectRepo(context.kwargs.project_path)
|
||||
await repo.docs.prd.save(filename=filename, content=prd)
|
||||
kvs = {
|
||||
"project_path": str(context.kwargs.project_path),
|
||||
"changed_prd_filenames": [str(repo.docs.prd.workdir / filename)],
|
||||
}
|
||||
instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WritePRDOutput")
|
||||
|
||||
design_api = WriteDesign(context=context)
|
||||
|
||||
result = await design_api.run(Message(content=prd, instruct_content=None))
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refined_design_api(context):
|
||||
await context.repo.docs.prd.save(filename="1.txt", content=str(REFINED_PRD_JSON))
|
||||
await context.repo.docs.system_design.save(filename="1.txt", content=DESIGN_SAMPLE)
|
||||
|
||||
design_api = WriteDesign(context=context, llm=LLM())
|
||||
|
||||
result = await design_api.run(Message(content="", instruct_content=None))
|
||||
design_api = WriteDesign(context=context)
|
||||
result = await design_api.run([Message(content=prd, instruct_content=instruct_content)])
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.instruct_content
|
||||
assert repo.docs.system_design.changed_files
|
||||
|
||||
# Mock incremental design env
|
||||
context.kwargs.inc = True
|
||||
await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON))
|
||||
await repo.docs.system_design.save(filename=filename, content=DESIGN_SAMPLE)
|
||||
|
||||
result = await design_api.run([Message(content="", instruct_content=instruct_content)])
|
||||
logger.info(result)
|
||||
assert result
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.instruct_content
|
||||
assert repo.docs.system_design.changed_files
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("user_requirement", "prd_filename", "legacy_design_filename"),
|
||||
[
|
||||
("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None),
|
||||
("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None),
|
||||
(
|
||||
"write 2048 game",
|
||||
str(METAGPT_ROOT / "tests/data/prd.json"),
|
||||
str(METAGPT_ROOT / "tests/data/system_design.json"),
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api(context, user_requirement, prd_filename, legacy_design_filename):
|
||||
action = WriteDesign()
|
||||
result = await action.run(
|
||||
user_requirement=user_requirement, prd_filename=prd_filename, legacy_design_filename=legacy_design_filename
|
||||
)
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content
|
||||
assert str(DEFAULT_WORKSPACE_ROOT) in result.content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("user_requirement", "prd_filename", "legacy_design_filename"),
|
||||
[
|
||||
("我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", None, None),
|
||||
("write 2048 game", str(METAGPT_ROOT / "tests/data/prd.json"), None),
|
||||
(
|
||||
"write 2048 game",
|
||||
str(METAGPT_ROOT / "tests/data/prd.json"),
|
||||
str(METAGPT_ROOT / "tests/data/system_design.json"),
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api_dir(context, user_requirement, prd_filename, legacy_design_filename):
|
||||
action = WriteDesign()
|
||||
result = await action.run(
|
||||
user_requirement=user_requirement,
|
||||
prd_filename=prd_filename,
|
||||
legacy_design_filename=legacy_design_filename,
|
||||
output_pathname=str(Path(context.config.project_path) / "1.txt"),
|
||||
)
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content
|
||||
assert str(context.config.project_path) in result.content
|
||||
assert result.instruct_content
|
||||
assert result.instruct_content.changed_system_design_filenames
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -5,13 +5,15 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_project_management.py
|
||||
"""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.schema import AIMessage, Message
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from tests.data.incremental_dev_project.mock import (
|
||||
REFINED_DESIGN_JSON,
|
||||
REFINED_PRD_JSON,
|
||||
|
|
@ -22,29 +24,46 @@ from tests.metagpt.actions.mock_json import DESIGN, PRD
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task(context):
|
||||
await context.repo.docs.prd.save("1.txt", content=str(PRD))
|
||||
await context.repo.docs.system_design.save("1.txt", content=str(DESIGN))
|
||||
logger.info(context.git_repo)
|
||||
# Mock write tasks env
|
||||
context.kwargs.project_path = context.config.project_path
|
||||
context.kwargs.inc = False
|
||||
repo = ProjectRepo(context.kwargs.project_path)
|
||||
filename = "1.txt"
|
||||
await repo.docs.prd.save(filename=filename, content=str(PRD))
|
||||
await repo.docs.system_design.save(filename=filename, content=str(DESIGN))
|
||||
kvs = {
|
||||
"project_path": context.kwargs.project_path,
|
||||
"changed_system_design_filenames": [str(repo.docs.system_design.workdir / filename)],
|
||||
}
|
||||
instruct_content = AIMessage.create_instruct_value(kvs=kvs, class_name="WriteDesignOutput")
|
||||
|
||||
action = WriteTasks(context=context)
|
||||
|
||||
result = await action.run(Message(content="", instruct_content=None))
|
||||
result = await action.run([Message(content="", instruct_content=instruct_content)])
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
assert result.instruct_content.changed_task_filenames
|
||||
|
||||
# Mock incremental env
|
||||
context.kwargs.inc = True
|
||||
await repo.docs.prd.save(filename=filename, content=str(REFINED_PRD_JSON))
|
||||
await repo.docs.system_design.save(filename=filename, content=str(REFINED_DESIGN_JSON))
|
||||
await repo.docs.task.save(filename=filename, content=TASK_SAMPLE)
|
||||
|
||||
result = await action.run([Message(content="", instruct_content=instruct_content)])
|
||||
logger.info(result)
|
||||
assert result
|
||||
assert result.instruct_content.changed_task_filenames
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refined_task(context):
|
||||
await context.repo.docs.prd.save("2.txt", content=str(REFINED_PRD_JSON))
|
||||
await context.repo.docs.system_design.save("2.txt", content=str(REFINED_DESIGN_JSON))
|
||||
await context.repo.docs.task.save("2.txt", content=TASK_SAMPLE)
|
||||
|
||||
logger.info(context.git_repo)
|
||||
|
||||
action = WriteTasks(context=context, llm=LLM())
|
||||
|
||||
result = await action.run(Message(content="", instruct_content=None))
|
||||
logger.info(result)
|
||||
|
||||
async def test_task_api(context):
|
||||
action = WriteTasks()
|
||||
result = await action.run(design_filename=str(METAGPT_ROOT / "tests/data/system_design.json"))
|
||||
assert result
|
||||
assert result.content
|
||||
m = json.loads(result.content)
|
||||
assert m
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -26,12 +26,7 @@ from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPL
|
|||
|
||||
def setup_inc_workdir(context, inc: bool = False):
|
||||
"""setup incremental workdir for testing"""
|
||||
context.src_workspace = context.git_repo.workdir / "src"
|
||||
if inc:
|
||||
context.config.inc = inc
|
||||
context.repo.old_workspace = context.repo.git_repo.workdir / "old"
|
||||
context.config.project_path = "old"
|
||||
|
||||
context.config.inc = inc
|
||||
return context
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,25 +6,26 @@
|
|||
@File : test_write_prd.py
|
||||
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, replace `handle` with `run`.
|
||||
"""
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.const import REQUIREMENT_FILENAME
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT, REQUIREMENT_FILENAME
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.roles.role import RoleReactMode
|
||||
from metagpt.schema import Message
|
||||
from metagpt.schema import AIMessage, Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE, PRD_SAMPLE
|
||||
from tests.metagpt.actions.test_write_code import setup_inc_workdir
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from tests.data.incremental_dev_project.mock import NEW_REQUIREMENT_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_prd(new_filename, context):
|
||||
product_manager = ProductManager(context=context)
|
||||
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"
|
||||
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
|
||||
product_manager.rc.react_mode = RoleReactMode.BY_ORDER
|
||||
prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement))
|
||||
assert prd.cause_by == any_to_str(WritePRD)
|
||||
|
|
@ -34,38 +35,39 @@ async def test_write_prd(new_filename, context):
|
|||
# Assert the prd is not None or empty
|
||||
assert prd is not None
|
||||
assert prd.content != ""
|
||||
assert product_manager.context.repo.docs.prd.changed_files
|
||||
repo = ProjectRepo(context.kwargs.project_path)
|
||||
assert repo.docs.prd.changed_files
|
||||
repo.git_repo.archive()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_prd_inc(new_filename, context, git_dir):
|
||||
context = setup_inc_workdir(context, inc=True)
|
||||
await context.repo.docs.prd.save("1.txt", PRD_SAMPLE)
|
||||
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE)
|
||||
# Mock incremental requirement
|
||||
context.config.inc = True
|
||||
context.config.project_path = context.kwargs.project_path
|
||||
repo = ProjectRepo(context.config.project_path)
|
||||
await repo.docs.save(filename=REQUIREMENT_FILENAME, content=NEW_REQUIREMENT_SAMPLE)
|
||||
|
||||
action = WritePRD(context=context)
|
||||
prd = await action.run(Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None))
|
||||
prd = await action.run([Message(content=NEW_REQUIREMENT_SAMPLE, instruct_content=None)])
|
||||
logger.info(NEW_REQUIREMENT_SAMPLE)
|
||||
logger.info(prd)
|
||||
|
||||
# Assert the prd is not None or empty
|
||||
assert prd is not None
|
||||
assert prd.content != ""
|
||||
assert "Refined Requirements" in prd.content
|
||||
assert repo.git_repo.changed_files
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fix_debug(new_filename, context, git_dir):
|
||||
context.src_workspace = context.git_repo.workdir / context.git_repo.workdir.name
|
||||
# Mock legacy project
|
||||
context.kwargs.project_path = str(git_dir)
|
||||
repo = ProjectRepo(context.kwargs.project_path)
|
||||
repo.with_src_path(git_dir.name)
|
||||
await repo.srcs.save(filename="main.py", content='if __name__ == "__main__":\nmain()')
|
||||
requirements = "ValueError: undefined variable `st`."
|
||||
await repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
|
||||
|
||||
await context.repo.with_src_path(context.src_workspace).srcs.save(
|
||||
filename="main.py", content='if __name__ == "__main__":\nmain()'
|
||||
)
|
||||
requirements = "Please fix the bug in the code."
|
||||
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
|
||||
action = WritePRD(context=context)
|
||||
|
||||
prd = await action.run(Message(content=requirements, instruct_content=None))
|
||||
prd = await action.run([Message(content=requirements, instruct_content=None)])
|
||||
logger.info(prd)
|
||||
|
||||
# Assert the prd is not None or empty
|
||||
|
|
@ -73,5 +75,40 @@ async def test_fix_debug(new_filename, context, git_dir):
|
|||
assert prd.content != ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_prd_api(context):
|
||||
action = WritePRD()
|
||||
result = await action.run(user_requirement="write a snake game.")
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content
|
||||
assert str(DEFAULT_WORKSPACE_ROOT) in result.content
|
||||
|
||||
result = await action.run(
|
||||
user_requirement="write a snake game.",
|
||||
output_pathname=str(Path(context.config.project_path) / f"{uuid.uuid4().hex}.json"),
|
||||
)
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content
|
||||
assert result.instruct_content
|
||||
assert str(context.config.project_path) in result.content
|
||||
|
||||
legacy_prd_filename = result.instruct_content.changed_prd_filenames[-1]
|
||||
|
||||
result = await action.run(user_requirement="Add moving enemy.", legacy_prd_filename=legacy_prd_filename)
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content
|
||||
assert str(DEFAULT_WORKSPACE_ROOT) in result.content
|
||||
|
||||
result = await action.run(
|
||||
user_requirement="Add moving enemy.",
|
||||
output_pathname=str(Path(context.config.project_path) / f"{uuid.uuid4().hex}.json"),
|
||||
legacy_prd_filename=legacy_prd_filename,
|
||||
)
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content
|
||||
assert result.instruct_content
|
||||
assert str(context.config.project_path) in result.content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -25,10 +25,6 @@ class TestSimpleEngine:
|
|||
def mock_simple_directory_reader(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_retriever(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
|
||||
|
|
@ -45,7 +41,6 @@ class TestSimpleEngine:
|
|||
self,
|
||||
mocker,
|
||||
mock_simple_directory_reader,
|
||||
mock_vector_store_index,
|
||||
mock_get_retriever,
|
||||
mock_get_rankers,
|
||||
mock_get_response_synthesizer,
|
||||
|
|
@ -81,11 +76,8 @@ class TestSimpleEngine:
|
|||
|
||||
# Assert
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_vector_store_index.assert_called_once()
|
||||
mock_get_retriever.assert_called_once_with(
|
||||
configs=retriever_configs, index=mock_vector_store_index.return_value
|
||||
)
|
||||
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
|
||||
mock_get_retriever.assert_called_once()
|
||||
mock_get_rankers.assert_called_once()
|
||||
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
|
||||
|
|
@ -119,7 +111,7 @@ class TestSimpleEngine:
|
|||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is not None
|
||||
assert engine._transformations is not None
|
||||
|
||||
def test_from_objs_with_bm25_config(self):
|
||||
# Setup
|
||||
|
|
@ -137,6 +129,7 @@ class TestSimpleEngine:
|
|||
def test_from_index(self, mocker, mock_llm, mock_embedding):
|
||||
# Mock
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index.as_retriever.return_value = "retriever"
|
||||
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
|
||||
mock_get_index.return_value = mock_index
|
||||
|
||||
|
|
@ -149,7 +142,7 @@ class TestSimpleEngine:
|
|||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is mock_index
|
||||
assert engine._retriever == "retriever"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch(self, mocker):
|
||||
|
|
@ -200,14 +193,11 @@ class TestSimpleEngine:
|
|||
|
||||
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
|
||||
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index._transformations = mocker.MagicMock()
|
||||
|
||||
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
|
||||
mock_run_transformations.return_value = ["node1", "node2"]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
|
||||
# Exec
|
||||
|
|
@ -230,7 +220,7 @@ class TestSimpleEngine:
|
|||
return ""
|
||||
|
||||
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
|
||||
# Exec
|
||||
engine.add_objs(objs=objs)
|
||||
|
|
|
|||
|
|
@ -97,6 +97,5 @@ class TestConfigBasedFactory:
|
|||
def test_val_from_config_or_kwargs_key_error(self):
|
||||
# Test KeyError when the key is not found in both config object and kwargs
|
||||
config = DummyConfig(name=None)
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
|
||||
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
|
||||
val = ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
|
||||
assert val is None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
|
||||
|
||||
|
|
@ -10,30 +11,51 @@ class TestRAGEmbeddingFactory:
|
|||
self.embedding_factory = RAGEmbeddingFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_embedding(self, mocker):
|
||||
def mock_config(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
|
||||
@staticmethod
|
||||
def mock_openai_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_embedding(self, mocker):
|
||||
@staticmethod
|
||||
def mock_azure_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding")
|
||||
|
||||
def test_get_rag_embedding_openai(self, mock_openai_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.OPENAI)
|
||||
@staticmethod
|
||||
def mock_gemini_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding")
|
||||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
@staticmethod
|
||||
def mock_ollama_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding")
|
||||
|
||||
def test_get_rag_embedding_azure(self, mock_azure_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.AZURE)
|
||||
|
||||
# Assert
|
||||
mock_azure_embedding.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
|
||||
@pytest.mark.parametrize(
|
||||
("mock_func", "embedding_type"),
|
||||
[
|
||||
(mock_openai_embedding, LLMType.OPENAI),
|
||||
(mock_azure_embedding, LLMType.AZURE),
|
||||
(mock_openai_embedding, EmbeddingType.OPENAI),
|
||||
(mock_azure_embedding, EmbeddingType.AZURE),
|
||||
(mock_gemini_embedding, EmbeddingType.GEMINI),
|
||||
(mock_ollama_embedding, EmbeddingType.OLLAMA),
|
||||
],
|
||||
)
|
||||
def test_get_rag_embedding(self, mock_func, embedding_type, mocker):
|
||||
# Mock
|
||||
mock_config = mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
mock = mock_func(mocker)
|
||||
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(embedding_type)
|
||||
|
||||
# Assert
|
||||
mock.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_default(self, mocker, mock_config):
|
||||
# Mock
|
||||
mock_openai_embedding = self.mock_openai_embedding(mocker)
|
||||
|
||||
mock_config.embedding.api_type = None
|
||||
mock_config.llm.api_type = LLMType.OPENAI
|
||||
|
||||
# Exec
|
||||
|
|
@ -41,3 +63,44 @@ class TestRAGEmbeddingFactory:
|
|||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, embed_batch_size, expected_params",
|
||||
[("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})],
|
||||
)
|
||||
def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params):
|
||||
# Mock
|
||||
mock_config.embedding.model = model
|
||||
mock_config.embedding.embed_batch_size = embed_batch_size
|
||||
|
||||
# Setup
|
||||
test_params = {}
|
||||
|
||||
# Exec
|
||||
self.embedding_factory._try_set_model_and_batch_size(test_params)
|
||||
|
||||
# Assert
|
||||
assert test_params == expected_params
|
||||
|
||||
def test_resolve_embedding_type(self, mock_config):
|
||||
# Mock
|
||||
mock_config.embedding.api_type = EmbeddingType.OPENAI
|
||||
|
||||
# Exec
|
||||
embedding_type = self.embedding_factory._resolve_embedding_type()
|
||||
|
||||
# Assert
|
||||
assert embedding_type == EmbeddingType.OPENAI
|
||||
|
||||
def test_resolve_embedding_type_exception(self, mock_config):
|
||||
# Mock
|
||||
mock_config.embedding.api_type = None
|
||||
mock_config.llm.api_type = LLMType.GEMINI
|
||||
|
||||
# Assert
|
||||
with pytest.raises(TypeError):
|
||||
self.embedding_factory._resolve_embedding_type()
|
||||
|
||||
def test_raise_for_key(self):
|
||||
with pytest.raises(ValueError):
|
||||
self.embedding_factory._raise_for_key("key")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import faiss
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.embeddings import MockEmbedding
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
|
||||
|
|
@ -43,6 +45,14 @@ class TestRetrieverFactory:
|
|||
def mock_es_vector_store(self, mocker):
|
||||
return mocker.MagicMock(spec=ElasticsearchStore)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_nodes(self, mocker):
|
||||
return [TextNode(text="msg")]
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding(self):
|
||||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
|
||||
mock_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
|
||||
|
|
@ -52,42 +62,40 @@ class TestRetrieverFactory:
|
|||
|
||||
assert isinstance(retriever, FAISSRetriever)
|
||||
|
||||
def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
|
||||
def test_get_retriever_with_bm25_config(self, mocker, mock_nodes):
|
||||
mock_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=mock_nodes)
|
||||
|
||||
assert isinstance(retriever, DynamicBM25Retriever)
|
||||
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_nodes, mock_embedding):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=1)
|
||||
mock_bm25_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
|
||||
retriever = self.retriever_factory.get_retriever(
|
||||
configs=[mock_faiss_config, mock_bm25_config], nodes=mock_nodes, embed_model=mock_embedding
|
||||
)
|
||||
|
||||
assert isinstance(retriever, SimpleHybridRetriever)
|
||||
|
||||
def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
|
||||
def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store, mock_embedding):
|
||||
mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
|
||||
mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
|
||||
mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
|
||||
mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
|
||||
|
||||
assert isinstance(retriever, ChromaRetriever)
|
||||
|
||||
def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
|
||||
def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding):
|
||||
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
|
||||
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
|
||||
|
||||
assert isinstance(retriever, ElasticsearchRetriever)
|
||||
|
||||
|
|
@ -111,3 +119,19 @@ class TestRetrieverFactory:
|
|||
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
||||
def test_get_or_build_when_get(self, mocker):
|
||||
want = "existing_index"
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=want)
|
||||
|
||||
got = self.retriever_factory._build_es_index(None)
|
||||
|
||||
assert got == want
|
||||
|
||||
def test_get_or_build_when_build(self, mocker):
|
||||
want = "call_build_es_index"
|
||||
mocker.patch.object(self.retriever_factory, "_build_es_index", return_value=want)
|
||||
|
||||
got = self.retriever_factory._build_es_index(None)
|
||||
|
||||
assert got == want
|
||||
|
|
|
|||
|
|
@ -392,5 +392,11 @@ async def test_parse_resources(context, content: str, key_descriptions):
|
|||
assert k in result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("name", "value"), [("c1", {"age": 10, "name": "Alice"}), ("", {"path": __file__})])
|
||||
def test_create_instruct_value(name, value):
|
||||
obj = Message.create_instruct_value(kvs=value, class_name=name)
|
||||
assert obj.model_dump() == value
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue