Merge branch 'dev' of https://github.com/geekan/MetaGPT into geekan/dev

This commit is contained in:
莘权 马 2024-01-16 09:59:15 +08:00
commit aab27a7c4e
39 changed files with 1325 additions and 188 deletions

Binary file not shown.

View file

@ -34,8 +34,10 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
node: ActionNode = Field(default=None, exclude=True)
@property
def project_repo(self):
return ProjectRepo(self.context.git_repo)
def repo(self) -> ProjectRepo:
if not self.context.repo:
self.context.repo = ProjectRepo(self.context.git_repo)
return self.context.repo
@property
def prompt_schema(self):

View file

@ -61,7 +61,7 @@ Follow instructions of nodes, generate output and make sure it follows the forma
REVIEW_TEMPLATE = """
## context
Compare the keys of nodes_output and the corresponding requirements one by one. If a key that does not match the requirement is found, provide the comment content on how to modify it. No output is required for matching keys.
Compare the key's value of nodes_output and the corresponding requirements one by one. If a key's value that does not match the requirement is found, provide the comment content on how to modify it. No output is required for matching keys.
### nodes_output
{nodes_output}
@ -86,7 +86,7 @@ Compare the keys of nodes_output and the corresponding requirements one by one.
{constraint}
## action
generate output and make sure it follows the format example.
Follow format example's {prompt_schema} format, generate output and make sure it follows the format example.
"""
REVISE_TEMPLATE = """
@ -108,7 +108,7 @@ change the nodes_output key's value to meet its comment and no need to add extra
{constraint}
## action
generate output and make sure it follows the format example.
Follow format example's {prompt_schema} format, generate output and make sure it follows the format example.
"""
@ -469,7 +469,10 @@ class ActionNode:
return dict()
prompt = template.format(
nodes_output=json.dumps(nodes_output, ensure_ascii=False, indent=4), tag=TAG, constraint=FORMAT_CONSTRAINT
nodes_output=json.dumps(nodes_output, ensure_ascii=False),
tag=TAG,
constraint=FORMAT_CONSTRAINT,
prompt_schema="json",
)
content = await self.llm.aask(prompt)
@ -563,10 +566,11 @@ class ActionNode:
instruction = self.compile_instruction(schema="markdown", mode="auto", exclude=exclude_keys)
prompt = template.format(
nodes_output=json.dumps(nodes_output, ensure_ascii=False, indent=4),
nodes_output=json.dumps(nodes_output, ensure_ascii=False),
example=example,
instruction=instruction,
constraint=FORMAT_CONSTRAINT,
prompt_schema="json",
)
# step2, use `_aask_v1` to get revise structure result

View file

@ -49,7 +49,7 @@ class DebugError(Action):
i_context: RunCodeContext = Field(default_factory=RunCodeContext)
async def run(self, *args, **kwargs) -> str:
output_doc = await self.project_repo.test_outputs.get(filename=self.i_context.output_filename)
output_doc = await self.repo.test_outputs.get(filename=self.i_context.output_filename)
if not output_doc:
return ""
output_detail = RunCodeResult.loads(output_doc.content)
@ -59,12 +59,12 @@ class DebugError(Action):
return ""
logger.info(f"Debug and rewrite {self.i_context.test_filename}")
code_doc = await self.project_repo.with_src_path(self.context.src_workspace).srcs.get(
code_doc = await self.repo.with_src_path(self.context.src_workspace).srcs.get(
filename=self.i_context.code_filename
)
if not code_doc:
return ""
test_doc = await self.project_repo.tests.get(filename=self.i_context.test_filename)
test_doc = await self.repo.tests.get(filename=self.i_context.test_filename)
if not test_doc:
return ""
prompt = PROMPT_TEMPLATE.format(code=code_doc.content, test_code=test_doc.content, logs=output_detail.stderr)

View file

@ -40,10 +40,10 @@ class WriteDesign(Action):
async def run(self, with_messages: Message, schema: str = None):
# Use `git status` to identify which PRD documents have been modified in the `docs/prds` directory.
changed_prds = self.project_repo.docs.prd.changed_files
changed_prds = self.repo.docs.prd.changed_files
# Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone
# changes.
changed_system_designs = self.project_repo.docs.system_design.changed_files
changed_system_designs = self.repo.docs.system_design.changed_files
# For those PRDs and design documents that have undergone changes, regenerate the design content.
changed_files = Documents()
@ -73,21 +73,21 @@ class WriteDesign(Action):
return system_design_doc
async def _update_system_design(self, filename) -> Document:
prd = await self.project_repo.docs.prd.get(filename)
old_system_design_doc = await self.project_repo.docs.system_design.get(filename)
prd = await self.repo.docs.prd.get(filename)
old_system_design_doc = await self.repo.docs.system_design.get(filename)
if not old_system_design_doc:
system_design = await self._new_system_design(context=prd.content)
doc = await self.project_repo.docs.system_design.save(
doc = await self.repo.docs.system_design.save(
filename=filename,
content=system_design.instruct_content.model_dump_json(),
dependencies={prd.root_relative_path},
)
else:
doc = await self._merge(prd_doc=prd, system_design_doc=old_system_design_doc)
await self.project_repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path})
await self.repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path})
await self._save_data_api_design(doc)
await self._save_seq_flow(doc)
await self.project_repo.resources.system_design.save_pdf(doc=doc)
await self.repo.resources.system_design.save_pdf(doc=doc)
return doc
async def _save_data_api_design(self, design_doc):
@ -95,7 +95,7 @@ class WriteDesign(Action):
data_api_design = m.get("Data structures and interfaces")
if not data_api_design:
return
pathname = self.project_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
pathname = self.repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
await self._save_mermaid_file(data_api_design, pathname)
logger.info(f"Save class view to {str(pathname)}")
@ -104,7 +104,7 @@ class WriteDesign(Action):
seq_flow = m.get("Program call flow")
if not seq_flow:
return
pathname = self.project_repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
pathname = self.repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
await self._save_mermaid_file(seq_flow, pathname)
logger.info(f"Saving sequence flow to {str(pathname)}")

View file

@ -15,6 +15,7 @@ from metagpt.actions import Action, ActionOutput
from metagpt.const import REQUIREMENT_FILENAME
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
class PrepareDocuments(Action):
@ -38,13 +39,14 @@ class PrepareDocuments(Action):
shutil.rmtree(path)
self.config.project_path = path
self.context.git_repo = GitRepository(local_path=path, auto_init=True)
self.context.repo = ProjectRepo(self.context.git_repo)
async def run(self, with_messages, **kwargs):
"""Create and initialize the workspace folder, initialize the Git environment."""
self._init_repo()
# Write the newly added requirements from the main parameter idea to `docs/requirement.txt`.
doc = await self.project_repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
doc = await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
# Send a Message notification to the WritePRD action, instructing it to process requirements using
# `docs/requirement.txt` and `docs/prds/`.
return ActionOutput(content=doc.content, instruct_content=doc)

View file

@ -13,8 +13,8 @@
import json
from typing import Optional
from metagpt.actions import ActionOutput
from metagpt.actions.action import Action
from metagpt.actions.action_output import ActionOutput
from metagpt.actions.project_management_an import PM_NODE
from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME
from metagpt.logs import logger
@ -34,8 +34,8 @@ class WriteTasks(Action):
i_context: Optional[str] = None
async def run(self, with_messages):
changed_system_designs = self.project_repo.docs.system_design.changed_files
changed_tasks = self.project_repo.docs.task.changed_files
changed_system_designs = self.repo.docs.system_design.changed_files
changed_tasks = self.repo.docs.task.changed_files
change_files = Documents()
# Rewrite the system designs that have undergone changes based on the git head diff under
# `docs/system_designs/`.
@ -57,16 +57,14 @@ class WriteTasks(Action):
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
async def _update_tasks(self, filename):
system_design_doc = await self.project_repo.docs.system_design.get(filename)
task_doc = await self.project_repo.docs.task.get(filename)
system_design_doc = await self.repo.docs.system_design.get(filename)
task_doc = await self.repo.docs.task.get(filename)
if task_doc:
task_doc = await self._merge(system_design_doc=system_design_doc, task_doc=task_doc)
await self.project_repo.docs.task.save_doc(
doc=task_doc, dependencies={system_design_doc.root_relative_path}
)
await self.repo.docs.task.save_doc(doc=task_doc, dependencies={system_design_doc.root_relative_path})
else:
rsp = await self._run_new_tasks(context=system_design_doc.content)
task_doc = await self.project_repo.docs.task.save(
task_doc = await self.repo.docs.task.save(
filename=filename,
content=rsp.instruct_content.model_dump_json(),
dependencies={system_design_doc.root_relative_path},
@ -87,7 +85,7 @@ class WriteTasks(Action):
async def _update_requirements(self, doc):
m = json.loads(doc.content)
packages = set(m.get("Required Python third-party packages", set()))
requirement_doc = await self.project_repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
requirement_doc = await self.repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
if not requirement_doc:
requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")
lines = requirement_doc.content.splitlines()
@ -95,4 +93,4 @@ class WriteTasks(Action):
if pkg == "":
continue
packages.add(pkg)
await self.project_repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))
await self.repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))

View file

@ -98,10 +98,10 @@ class SummarizeCode(Action):
async def run(self):
design_pathname = Path(self.i_context.design_filename)
design_doc = await self.project_repo.docs.system_design.get(filename=design_pathname.name)
design_doc = await self.repo.docs.system_design.get(filename=design_pathname.name)
task_pathname = Path(self.i_context.task_filename)
task_doc = await self.project_repo.docs.task.get(filename=task_pathname.name)
src_file_repo = self.project_repo.with_src_path(self.context.src_workspace).srcs
task_doc = await self.repo.docs.task.get(filename=task_pathname.name)
src_file_repo = self.repo.with_src_path(self.context.src_workspace).srcs
code_blocks = []
for filename in self.i_context.codes_filenames:
code_doc = await src_file_repo.get(filename)

View file

@ -88,12 +88,12 @@ class WriteCode(Action):
return code
async def run(self, *args, **kwargs) -> CodingContext:
bug_feedback = await self.project_repo.docs.get(filename=BUGFIX_FILENAME)
bug_feedback = await self.repo.docs.get(filename=BUGFIX_FILENAME)
coding_context = CodingContext.loads(self.i_context.content)
test_doc = await self.project_repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
summary_doc = None
if coding_context.design_doc and coding_context.design_doc.filename:
summary_doc = await self.project_repo.docs.code_summary.get(filename=coding_context.design_doc.filename)
summary_doc = await self.repo.docs.code_summary.get(filename=coding_context.design_doc.filename)
logs = ""
if test_doc:
test_detail = RunCodeResult.loads(test_doc.content)
@ -105,7 +105,7 @@ class WriteCode(Action):
code_context = await self.get_codes(
coding_context.task_doc,
exclude=self.i_context.filename,
project_repo=self.project_repo.with_src_path(self.context.src_workspace),
project_repo=self.repo.with_src_path(self.context.src_workspace),
)
prompt = PROMPT_TEMPLATE.format(

View file

@ -143,7 +143,7 @@ class WriteCodeReview(Action):
code_context = await WriteCode.get_codes(
self.i_context.task_doc,
exclude=self.i_context.filename,
project_repo=self.project_repo.with_src_path(self.context.src_workspace),
project_repo=self.repo.with_src_path(self.context.src_workspace),
)
context = "\n".join(
[

View file

@ -15,7 +15,6 @@ from __future__ import annotations
import json
from pathlib import Path
from typing import Optional
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
@ -58,96 +57,106 @@ NEW_REQ_TEMPLATE = """
class WritePRD(Action):
name: str = "WritePRD"
content: Optional[str] = None
"""WritePRD deal with the following situations:
1. Bugfix: If the requirement is a bugfix, the bugfix document will be generated.
2. New requirement: If the requirement is a new requirement, the PRD document will be generated.
3. Requirement update: If the requirement is an update, the PRD document will be updated.
"""
async def run(self, with_messages, *args, **kwargs) -> ActionOutput | Message:
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
# related to the PRD. If they are related, rewrite the PRD.
requirement_doc = await self.project_repo.docs.get(filename=REQUIREMENT_FILENAME)
if requirement_doc and await self._is_bugfix(requirement_doc.content):
await self.project_repo.docs.save(filename=BUGFIX_FILENAME, content=requirement_doc.content)
await self.project_repo.docs.save(filename=REQUIREMENT_FILENAME, content="")
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
return Message(
content=bug_fix.model_dump_json(),
instruct_content=bug_fix,
role="",
cause_by=FixBug,
sent_from=self,
send_to="Alex", # the name of Engineer
)
"""Run the action."""
req: Document = await self.repo.requirement
docs: list[Document] = await self.repo.docs.prd.get_all()
if not req:
raise FileNotFoundError("No requirement document found.")
if await self._is_bugfix(req.content):
logger.info(f"Bugfix detected: {req.content}")
return await self._handle_bugfix(req)
# remove bugfix file from last round in case of conflict
await self.repo.docs.delete(filename=BUGFIX_FILENAME)
# if requirement is related to other documents, update them, otherwise create a new one
if related_docs := await self.get_related_docs(req, docs):
logger.info(f"Requirement update detected: {req.content}")
return await self._handle_requirement_update(req, related_docs)
else:
await self.project_repo.docs.delete(filename=BUGFIX_FILENAME)
logger.info(f"New requirement detected: {req.content}")
return await self._handle_new_requirement(req)
prd_docs = await self.project_repo.docs.prd.get_all()
change_files = Documents()
for prd_doc in prd_docs:
prd_doc = await self._update_prd(requirement_doc=requirement_doc, prd_doc=prd_doc, *args, **kwargs)
if not prd_doc:
continue
change_files.docs[prd_doc.filename] = prd_doc
logger.info(f"rewrite prd: {prd_doc.filename}")
# If there is no existing PRD, generate one using 'docs/requirement.txt'.
if not change_files.docs:
prd_doc = await self._update_prd(requirement_doc=requirement_doc, *args, **kwargs)
if prd_doc:
change_files.docs[prd_doc.filename] = prd_doc
logger.debug(f"new prd: {prd_doc.filename}")
# Once all files under 'docs/prds/' have been compared with the newly added requirements, trigger the
# 'publish' message to transition the workflow to the next stage. This design allows room for global
# optimization in subsequent steps.
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
async def _handle_bugfix(self, req: Document) -> Message:
# ... bugfix logic ...
await self.repo.docs.save(filename=BUGFIX_FILENAME, content=req.content)
await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content="")
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
return Message(
content=bug_fix.model_dump_json(),
instruct_content=bug_fix,
role="",
cause_by=FixBug,
sent_from=self,
send_to="Alex", # the name of Engineer
)
async def _run_new_requirement(self, requirements) -> ActionOutput:
async def _handle_new_requirement(self, req: Document) -> ActionOutput:
"""handle new requirement"""
project_name = self.project_name
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
context = CONTEXT_TEMPLATE.format(requirements=req, project_name=project_name)
exclude = [PROJECT_NAME.key] if project_name else []
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema
await self._rename_workspace(node)
return node
new_prd_doc = await self.repo.docs.prd.save(
filename=FileRepository.new_filename() + ".json", content=node.instruct_content.model_dump_json()
)
await self._save_competitive_analysis(new_prd_doc)
await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
return Documents.from_iterable(documents=[new_prd_doc]).to_action_output()
async def _is_relative(self, new_requirement_doc, old_prd_doc) -> bool:
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd_doc.content, requirements=new_requirement_doc.content)
async def _handle_requirement_update(self, req: Document, related_docs: list[Document]) -> ActionOutput:
# ... requirement update logic ...
for doc in related_docs:
await self._update_prd(req, doc)
return Documents.from_iterable(documents=related_docs).to_action_output()
async def _is_bugfix(self, context: str) -> bool:
if not self.repo.code_files_exists():
return False
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
return node.get("issue_type") == "BUG"
async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Document]:
"""get the related documents"""
# refine: use gather to speed up
return [i for i in docs if await self._is_related(req, i)]
async def _is_related(self, req: Document, old_prd: Document) -> bool:
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd.content, requirements=req.content)
node = await WP_IS_RELATIVE_NODE.fill(context, self.llm)
return node.get("is_relative") == "YES"
async def _merge(self, new_requirement_doc, prd_doc) -> Document:
async def _merge(self, req: Document, related_doc: Document) -> Document:
if not self.project_name:
self.project_name = Path(self.project_path).name
prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content)
prompt = NEW_REQ_TEMPLATE.format(requirements=req.content, old_prd=related_doc.content)
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema)
prd_doc.content = node.instruct_content.model_dump_json()
related_doc.content = node.instruct_content.model_dump_json()
await self._rename_workspace(node)
return prd_doc
return related_doc
async def _update_prd(self, requirement_doc, prd_doc=None, *args, **kwargs) -> Document | None:
if not prd_doc:
prd = await self._run_new_requirement(
requirements=[requirement_doc.content if requirement_doc else ""], *args, **kwargs
)
new_prd_doc = await self.project_repo.docs.prd.save(
filename=FileRepository.new_filename() + ".json", content=prd.instruct_content.model_dump_json()
)
elif await self._is_relative(requirement_doc, prd_doc):
new_prd_doc = await self._merge(requirement_doc, prd_doc)
self.project_repo.docs.prd.save_doc(doc=new_prd_doc)
else:
return None
async def _update_prd(self, req: Document, prd_doc: Document) -> Document:
new_prd_doc: Document = await self._merge(req, prd_doc)
self.repo.docs.prd.save_doc(doc=new_prd_doc)
await self._save_competitive_analysis(new_prd_doc)
await self.project_repo.resources.prd.save_pdf(doc=new_prd_doc)
await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
return new_prd_doc
async def _save_competitive_analysis(self, prd_doc):
async def _save_competitive_analysis(self, prd_doc: Document):
m = json.loads(prd_doc.content)
quadrant_chart = m.get("Competitive Quadrant Chart")
if not quadrant_chart:
return
pathname = (
self.project_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("")
)
if not pathname.parent.exists():
pathname.parent.mkdir(parents=True, exist_ok=True)
pathname = self.repo.workdir / COMPETITIVE_ANALYSIS_FILE_REPO / Path(prd_doc.filename).stem
pathname.parent.mkdir(parents=True, exist_ok=True)
await mermaid_to_file(self.config.mermaid_engine, quadrant_chart, pathname)
async def _rename_workspace(self, prd):
@ -158,15 +167,4 @@ class WritePRD(Action):
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
if ws_name:
self.project_name = ws_name
self.project_repo.git_repo.rename_root(self.project_name)
async def _is_bugfix(self, context) -> bool:
git_workdir = self.project_repo.git_repo.workdir
src_workdir = git_workdir / git_workdir.name
if not src_workdir.exists():
return False
code_files = self.project_repo.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
if not code_files:
return False
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
return node.get("issue_type") == "BUG"
self.repo.git_repo.rename_root(self.project_name)

View file

@ -17,6 +17,7 @@ from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import create_llm_instance
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
class AttrDict(BaseModel):
@ -58,6 +59,8 @@ class Context(BaseModel):
kwargs: AttrDict = AttrDict()
config: Config = Config.default()
repo: Optional[ProjectRepo] = None
git_repo: Optional[GitRepository] = None
src_workspace: Optional[Path] = None
cost_manager: CostManager = CostManager()
@ -67,8 +70,8 @@ class Context(BaseModel):
def new_environ(self):
"""Return a new os.environ object"""
env = os.environ.copy()
i = self.options
env.update({k: v for k, v in i.items() if isinstance(v, str)})
# i = self.options
# env.update({k: v for k, v in i.items() if isinstance(v, str)})
return env
# def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:

View file

@ -235,3 +235,7 @@ class OpenAILLM(BaseLLM):
async def amoderation(self, content: Union[str, list[str]]):
"""Moderate content."""
return await self.aclient.moderations.create(input=content)
async def atext_to_speech(self, **kwargs):
"""text to speech"""
return await self.aclient.audio.speech.create(**kwargs)

View file

@ -481,6 +481,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
rsp = await self._act_by_order()
elif self.rc.react_mode == RoleReactMode.PLAN_AND_ACT:
rsp = await self._plan_and_act()
else:
raise ValueError(f"Unsupported react mode: {self.rc.react_mode}")
self._set_state(state=-1) # current reaction is complete, reset state to -1 and todo back to None
return rsp

View file

@ -10,8 +10,9 @@
from pydantic import Field
from metagpt.actions import ActionOutput, SearchAndSummarize
from metagpt.actions import SearchAndSummarize
from metagpt.actions.action_node import ActionNode
from metagpt.actions.action_output import ActionOutput
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message

View file

@ -23,7 +23,7 @@ from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
from pydantic import (
BaseModel,
@ -162,6 +162,26 @@ class Documents(BaseModel):
docs: Dict[str, Document] = Field(default_factory=dict)
@classmethod
def from_iterable(cls, documents: Iterable[Document]) -> Documents:
"""Create a Documents instance from a list of Document instances.
:param documents: A list of Document instances.
:return: A Documents instance.
"""
docs = {doc.filename: doc for doc in documents}
return Documents(docs=docs)
def to_action_output(self) -> "ActionOutput":
"""Convert to action output string.
:return: A string representing action output.
"""
from metagpt.actions.action_output import ActionOutput
return ActionOutput(content=self.model_dump_json(), instruct_content=self)
class Message(BaseModel):
"""list[<role>: <content>]"""
@ -212,7 +232,7 @@ class Message(BaseModel):
return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})
@field_serializer("instruct_content", mode="plain")
def ser_instruct_content(self, ic: BaseModel) -> Union[str, None]:
def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]:
ic_dict = None
if ic:
# compatible with custom-defined ActionOutput

View file

@ -44,19 +44,20 @@ class SearchEngine:
self,
engine: Optional[SearchEngineType] = SearchEngineType.SERPER_GOOGLE,
run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None,
**kwargs,
):
if engine == SearchEngineType.SERPAPI_GOOGLE:
module = "metagpt.tools.search_engine_serpapi"
run_func = importlib.import_module(module).SerpAPIWrapper().run
run_func = importlib.import_module(module).SerpAPIWrapper(**kwargs).run
elif engine == SearchEngineType.SERPER_GOOGLE:
module = "metagpt.tools.search_engine_serper"
run_func = importlib.import_module(module).SerperWrapper().run
run_func = importlib.import_module(module).SerperWrapper(**kwargs).run
elif engine == SearchEngineType.DIRECT_GOOGLE:
module = "metagpt.tools.search_engine_googleapi"
run_func = importlib.import_module(module).GoogleAPIWrapper().run
run_func = importlib.import_module(module).GoogleAPIWrapper(**kwargs).run
elif engine == SearchEngineType.DUCK_DUCK_GO:
module = "metagpt.tools.search_engine_ddg"
run_func = importlib.import_module(module).DDGAPIWrapper().run
run_func = importlib.import_module(module).DDGAPIWrapper(**kwargs).run
elif engine == SearchEngineType.CUSTOM_ENGINE:
pass # run_func = run_func
else:

View file

@ -13,7 +13,7 @@ from metagpt.utils.parse_html import WebPage
class WebBrowserEngine:
def __init__(
self,
engine: WebBrowserEngineType | None = WebBrowserEngineType.PLAYWRIGHT,
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT,
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
):
if engine is None:

View file

@ -33,7 +33,7 @@ class SeleniumWrapper:
def __init__(
self,
browser_type: Literal["chrome", "firefox", "edge", "ie"] | None = None,
browser_type: Literal["chrome", "firefox", "edge", "ie"] = "chrome",
launch_kwargs: dict | None = None,
*,
loop: asyncio.AbstractEventLoop | None = None,

View file

@ -21,6 +21,7 @@ from metagpt.const import (
GRAPH_REPO_FILE_REPO,
PRD_PDF_FILE_REPO,
PRDS_FILE_REPO,
REQUIREMENT_FILENAME,
RESOURCES_FILE_REPO,
SD_OUTPUT_FILE_REPO,
SEQ_FLOW_FILE_REPO,
@ -93,6 +94,10 @@ class ProjectRepo(FileRepository):
self.test_outputs = self._git_repo.new_file_repository(relative_path=TEST_OUTPUTS_FILE_REPO)
self._srcs_path = None
@property
async def requirement(self):
return await self.docs.get(filename=REQUIREMENT_FILENAME)
@property
def git_repo(self) -> GitRepository:
return self._git_repo
@ -107,6 +112,15 @@ class ProjectRepo(FileRepository):
raise ValueError("Call with_srcs first.")
return self._git_repo.new_file_repository(self._srcs_path)
def code_files_exists(self) -> bool:
git_workdir = self.git_repo.workdir
src_workdir = git_workdir / git_workdir.name
if not src_workdir.exists():
return False
code_files = self.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
if not code_files:
return False
def with_src_path(self, path: str | Path) -> ProjectRepo:
try:
self._srcs_path = Path(path).relative_to(self.workdir)

View file

@ -12,6 +12,7 @@ import logging
import os
import re
import uuid
from typing import Callable
import pytest
@ -20,6 +21,9 @@ from metagpt.context import CONTEXT
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.utils.git_repository import GitRepository
from tests.mock.mock_aiohttp import MockAioResponse
from tests.mock.mock_curl_cffi import MockCurlCffiResponse
from tests.mock.mock_httplib2 import MockHttplib2Response
from tests.mock.mock_llm import MockLLM
RSP_CACHE_NEW = {} # used globally for producing new and useful only response cache
@ -123,7 +127,7 @@ def proxy():
server = await asyncio.start_server(handle_client, "127.0.0.1", 0)
return server, "http://{}:{}".format(*server.sockets[0].getsockname())
return proxy_func()
return proxy_func
# see https://github.com/Delgan/loguru/issues/59#issuecomment-466591978
@ -164,39 +168,63 @@ def new_filename(mocker):
yield mocker
@pytest.fixture(scope="session")
def search_rsp_cache():
rsp_cache_file_path = TEST_DATA_PATH / "search_rsp_cache.json" # read repo-provided
if os.path.exists(rsp_cache_file_path):
with open(rsp_cache_file_path, "r") as f1:
rsp_cache_json = json.load(f1)
else:
rsp_cache_json = {}
yield rsp_cache_json
with open(rsp_cache_file_path, "w") as f2:
json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
@pytest.fixture
def aiohttp_mocker(mocker):
class MockAioResponse:
async def json(self, *args, **kwargs):
return self._json
def set_json(self, json):
self._json = json
response = MockAioResponse()
class MockCTXMng:
async def __aenter__(self):
return response
async def __aexit__(self, *args, **kwargs):
pass
def __await__(self):
yield
return response
def mock_request(self, method, url, **kwargs):
return MockCTXMng()
MockResponse = type("MockResponse", (MockAioResponse,), {})
def wrap(method):
def run(self, url, **kwargs):
return mock_request(self, method, url, **kwargs)
return MockResponse(self, method, url, **kwargs)
return run
mocker.patch("aiohttp.ClientSession.request", mock_request)
mocker.patch("aiohttp.ClientSession.request", MockResponse)
for i in ["get", "post", "delete", "patch"]:
mocker.patch(f"aiohttp.ClientSession.{i}", wrap(i))
yield MockResponse
yield response
@pytest.fixture
def curl_cffi_mocker(mocker):
MockResponse = type("MockResponse", (MockCurlCffiResponse,), {})
def request(self, *args, **kwargs):
return MockResponse(self, *args, **kwargs)
mocker.patch("curl_cffi.requests.Session.request", request)
yield MockResponse
@pytest.fixture
def httplib2_mocker(mocker):
MockResponse = type("MockResponse", (MockHttplib2Response,), {})
def request(self, *args, **kwargs):
return MockResponse(self, *args, **kwargs)
mocker.patch("httplib2.Http.request", request)
yield MockResponse
@pytest.fixture
def search_engine_mocker(aiohttp_mocker, curl_cffi_mocker, httplib2_mocker, search_rsp_cache):
# aiohttp_mocker: serpapi/serper
# httplib2_mocker: google
# curl_cffi_mocker: ddg
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
aiohttp_mocker.rsp_cache = httplib2_mocker.rsp_cache = curl_cffi_mocker.rsp_cache = search_rsp_cache
aiohttp_mocker.check_funcs = httplib2_mocker.check_funcs = curl_cffi_mocker.check_funcs = check_funcs
yield check_funcs

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -9,10 +9,12 @@
import pytest
from metagpt.actions import research
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
@pytest.mark.asyncio
async def test_collect_links(mocker):
async def test_collect_links(mocker, search_engine_mocker):
async def mock_llm_ask(self, prompt: str, system_msgs):
if "Please provide up to 2 necessary keywords" in prompt:
return '["metagpt", "llm"]'
@ -26,13 +28,15 @@ async def test_collect_links(mocker):
return "[1,2]"
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
resp = await research.CollectLinks().run("The application of MetaGPT")
resp = await research.CollectLinks(search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO)).run(
"The application of MetaGPT"
)
for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
assert i in resp
@pytest.mark.asyncio
async def test_collect_links_with_rank_func(mocker):
async def test_collect_links_with_rank_func(mocker, search_engine_mocker):
rank_before = []
rank_after = []
url_per_query = 4
@ -45,7 +49,9 @@ async def test_collect_links_with_rank_func(mocker):
return results
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_collect_links_llm_ask)
resp = await research.CollectLinks(rank_func=rank_func).run("The application of MetaGPT")
resp = await research.CollectLinks(
search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), rank_func=rank_func
).run("The application of MetaGPT")
for x, y, z in zip(rank_before, rank_after, resp.values()):
assert x[::-1] == y
assert [i["link"] for i in y] == z

View file

@ -11,7 +11,7 @@ from pathlib import Path
import pytest
from metagpt.config2 import config
from metagpt.config2 import Config
from metagpt.learn.text_to_embedding import text_to_embedding
from metagpt.utils.common import aread
@ -19,13 +19,14 @@ from metagpt.utils.common import aread
@pytest.mark.asyncio
async def test_text_to_embedding(mocker):
# mock
config = Config.default()
mock_post = mocker.patch("aiohttp.ClientSession.post")
mock_response = mocker.AsyncMock()
mock_response.status = 200
data = await aread(Path(__file__).parent / "../../data/openai/embedding.json")
mock_response.json.return_value = json.loads(data)
mock_post.return_value.__aenter__.return_value = mock_response
type(config.get_openai_llm()).proxy = mocker.PropertyMock(return_value="http://mock.proxy")
config.get_openai_llm().proxy = mocker.PropertyMock(return_value="http://mock.proxy")
# Prerequisites
assert config.get_openai_llm().api_key

View file

@ -42,11 +42,23 @@ async def test_aask_code_message():
assert len(rsp["code"]) > 0
@pytest.mark.asyncio
async def test_text_to_speech():
llm = LLM()
resp = await llm.atext_to_speech(
model="tts-1",
voice="alloy",
input="人生说起来长,但知道一个岁月回头看,许多事件仅是仓促的。一段一段拼凑一起,合成了人生。苦难当头时,当下不免觉得是折磨;回头看,也不够是一段短短的人生旅程。",
)
assert 200 == resp.response.status_code
class TestOpenAI:
def test_make_client_kwargs_without_proxy(self):
instance = OpenAILLM(mock_llm_config)
kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "mock_api_key", "base_url": "mock_base_url"}
assert kwargs["api_key"] == "mock_api_key"
assert kwargs["base_url"] == "mock_base_url"
assert "http_client" not in kwargs
def test_make_client_kwargs_with_proxy(self):

View file

@ -4,7 +4,10 @@ from tempfile import TemporaryDirectory
import pytest
from metagpt.actions.research import CollectLinks
from metagpt.roles import researcher
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
async def mock_llm_ask(self, prompt: str, system_msgs):
@ -25,12 +28,16 @@ async def mock_llm_ask(self, prompt: str, system_msgs):
@pytest.mark.asyncio
async def test_researcher(mocker):
async def test_researcher(mocker, search_engine_mocker):
with TemporaryDirectory() as dirname:
topic = "dataiku vs. datarobot"
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
researcher.RESEARCH_PATH = Path(dirname)
await researcher.Researcher().run(topic)
role = researcher.Researcher()
for i in role.actions:
if isinstance(i, CollectLinks):
i.search_engine = SearchEngine(SearchEngineType.DUCK_DUCK_GO)
await role.run(topic)
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")

View file

@ -18,5 +18,5 @@ async def test_action_serdeser(new_filename):
new_action = WritePRD(**ser_action_dict)
assert new_action.name == "WritePRD"
action_output = await new_action.run(with_messages=Message(content="write a cli snake game"))
assert len(action_output.content) > 0
with pytest.raises(FileNotFoundError):
await new_action.run(with_messages=Message(content="write a cli snake game"))

View file

@ -48,7 +48,6 @@ def test_context_1():
assert ctx.git_repo is None
assert ctx.src_workspace is None
assert ctx.cost_manager is not None
assert ctx.options is not None
def test_context_2():

View file

@ -95,7 +95,7 @@ def test_config_mixin_4_multi_inheritance_override_config():
print(obj.__dict__.keys())
assert "private_config" in obj.__dict__.keys()
assert obj.llm.model == "mock_zhipu_model"
assert obj.config.llm.model == "mock_zhipu_model"
@pytest.mark.asyncio

View file

@ -10,7 +10,7 @@ from pathlib import Path
import pytest
from metagpt.config2 import config
from metagpt.config2 import Config
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding
from metagpt.utils.common import aread
@ -18,6 +18,7 @@ from metagpt.utils.common import aread
@pytest.mark.asyncio
async def test_embedding(mocker):
# mock
config = Config.default()
mock_post = mocker.patch("aiohttp.ClientSession.post")
mock_response = mocker.AsyncMock()
mock_response.status = 200

View file

@ -7,20 +7,15 @@
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Callable
import pytest
import tests.data.search
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
search_cache_path = Path(tests.data.search.__path__[0])
class MockSearchEnine:
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
@ -46,24 +41,28 @@ class MockSearchEnine:
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
],
)
async def test_search_engine(search_engine_type, run_func: Callable, max_results: int, as_string: bool, aiohttp_mocker):
async def test_search_engine(
search_engine_type,
run_func: Callable,
max_results: int,
as_string: bool,
search_engine_mocker,
):
# Prerequisites
cache_json_path = None
# FIXME: 不能使用全局的config而是要自己实例化对应的config
search_engine_config = {}
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
assert config.search
cache_json_path = search_cache_path / f"serpapi-metagpt-{max_results}.json"
search_engine_config["serpapi_api_key"] = "mock-serpapi-key"
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
assert config.search
search_engine_config["google_api_key"] = "mock-google-key"
search_engine_config["google_cse_id"] = "mock-google-cse"
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
assert config.search
cache_json_path = search_cache_path / f"serper-metagpt-{max_results}.json"
search_engine_config["serper_api_key"] = "mock-serper-key"
if cache_json_path:
with open(cache_json_path) as f:
data = json.load(f)
aiohttp_mocker.set_json(data)
search_engine = SearchEngine(search_engine_type, run_func)
search_engine = SearchEngine(search_engine_type, run_func, **search_engine_config)
rsp = await search_engine.run("metagpt", max_results, as_string)
logger.info(rsp)
if as_string:

View file

@ -22,8 +22,8 @@ async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy
global_proxy = config.proxy
try:
if use_proxy:
server, proxy = await proxy
config.proxy = proxy
server, proxy_url = await proxy()
config.proxy = proxy_url
browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, **kwagrs)
result = await browser.run(url)
assert isinstance(result, WebPage)

View file

@ -25,8 +25,8 @@ async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd)
global_proxy = config.proxy
try:
if use_proxy:
server, proxy = await proxy
config.proxy = proxy
server, proxy_url = await proxy()
config.proxy = proxy_url
browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type)
result = await browser.run(url)
assert isinstance(result, WebPage)

View file

@ -0,0 +1,41 @@
import json
from typing import Callable
from aiohttp.client import ClientSession
origin_request = ClientSession.request
class MockAioResponse:
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
rsp_cache: dict[str, str] = {}
name = "aiohttp"
def __init__(self, session, method, url, **kwargs) -> None:
fn = self.check_funcs.get((method, url))
self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}"
self.mng = self.response = None
if self.key not in self.rsp_cache:
self.mng = origin_request(session, method, url, **kwargs)
async def __aenter__(self):
if self.response:
await self.response.__aenter__()
elif self.mng:
self.response = await self.mng.__aenter__()
return self
async def __aexit__(self, *args, **kwargs):
if self.response:
await self.response.__aexit__(*args, **kwargs)
self.response = None
elif self.mng:
await self.mng.__aexit__(*args, **kwargs)
self.mng = None
async def json(self, *args, **kwargs):
if self.key in self.rsp_cache:
return self.rsp_cache[self.key]
data = await self.response.json(*args, **kwargs)
self.rsp_cache[self.key] = data
return data

View file

@ -0,0 +1,22 @@
import json
from typing import Callable
from curl_cffi import requests
origin_request = requests.Session.request
class MockCurlCffiResponse(requests.Response):
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
rsp_cache: dict[str, str] = {}
name = "curl-cffi"
def __init__(self, session, method, url, **kwargs) -> None:
super().__init__()
fn = self.check_funcs.get((method, url))
self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}"
self.response = None
if self.key not in self.rsp_cache:
response = origin_request(session, method, url, **kwargs)
self.rsp_cache[self.key] = response.content.decode()
self.content = self.rsp_cache[self.key].encode()

View file

@ -0,0 +1,29 @@
import json
from typing import Callable
from urllib.parse import parse_qsl, urlparse
import httplib2
origin_request = httplib2.Http.request
class MockHttplib2Response(httplib2.Response):
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
rsp_cache: dict[str, str] = {}
name = "httplib2"
def __init__(self, http, uri, method="GET", **kwargs) -> None:
url = uri.split("?")[0]
result = urlparse(uri)
params = dict(parse_qsl(result.query))
fn = self.check_funcs.get((method, uri))
new_kwargs = {"params": params}
key = f"{self.name}-{method}-{url}-{fn(new_kwargs) if fn else json.dumps(new_kwargs)}"
if key not in self.rsp_cache:
_, self.content = origin_request(http, uri, method, **kwargs)
self.rsp_cache[key] = self.content.decode()
self.content = self.rsp_cache[key]
def __iter__(self):
yield self
yield self.content.encode()