mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
Merge branch 'dev' of https://github.com/geekan/MetaGPT into geekan/dev
This commit is contained in:
commit
aab27a7c4e
39 changed files with 1325 additions and 188 deletions
Binary file not shown.
|
|
@ -34,8 +34,10 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
|
|||
node: ActionNode = Field(default=None, exclude=True)
|
||||
|
||||
@property
|
||||
def project_repo(self):
|
||||
return ProjectRepo(self.context.git_repo)
|
||||
def repo(self) -> ProjectRepo:
|
||||
if not self.context.repo:
|
||||
self.context.repo = ProjectRepo(self.context.git_repo)
|
||||
return self.context.repo
|
||||
|
||||
@property
|
||||
def prompt_schema(self):
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ Follow instructions of nodes, generate output and make sure it follows the forma
|
|||
|
||||
REVIEW_TEMPLATE = """
|
||||
## context
|
||||
Compare the keys of nodes_output and the corresponding requirements one by one. If a key that does not match the requirement is found, provide the comment content on how to modify it. No output is required for matching keys.
|
||||
Compare the key's value of nodes_output and the corresponding requirements one by one. If a key's value that does not match the requirement is found, provide the comment content on how to modify it. No output is required for matching keys.
|
||||
|
||||
### nodes_output
|
||||
{nodes_output}
|
||||
|
|
@ -86,7 +86,7 @@ Compare the keys of nodes_output and the corresponding requirements one by one.
|
|||
{constraint}
|
||||
|
||||
## action
|
||||
generate output and make sure it follows the format example.
|
||||
Follow format example's {prompt_schema} format, generate output and make sure it follows the format example.
|
||||
"""
|
||||
|
||||
REVISE_TEMPLATE = """
|
||||
|
|
@ -108,7 +108,7 @@ change the nodes_output key's value to meet its comment and no need to add extra
|
|||
{constraint}
|
||||
|
||||
## action
|
||||
generate output and make sure it follows the format example.
|
||||
Follow format example's {prompt_schema} format, generate output and make sure it follows the format example.
|
||||
"""
|
||||
|
||||
|
||||
|
|
@ -469,7 +469,10 @@ class ActionNode:
|
|||
return dict()
|
||||
|
||||
prompt = template.format(
|
||||
nodes_output=json.dumps(nodes_output, ensure_ascii=False, indent=4), tag=TAG, constraint=FORMAT_CONSTRAINT
|
||||
nodes_output=json.dumps(nodes_output, ensure_ascii=False),
|
||||
tag=TAG,
|
||||
constraint=FORMAT_CONSTRAINT,
|
||||
prompt_schema="json",
|
||||
)
|
||||
|
||||
content = await self.llm.aask(prompt)
|
||||
|
|
@ -563,10 +566,11 @@ class ActionNode:
|
|||
instruction = self.compile_instruction(schema="markdown", mode="auto", exclude=exclude_keys)
|
||||
|
||||
prompt = template.format(
|
||||
nodes_output=json.dumps(nodes_output, ensure_ascii=False, indent=4),
|
||||
nodes_output=json.dumps(nodes_output, ensure_ascii=False),
|
||||
example=example,
|
||||
instruction=instruction,
|
||||
constraint=FORMAT_CONSTRAINT,
|
||||
prompt_schema="json",
|
||||
)
|
||||
|
||||
# step2, use `_aask_v1` to get revise structure result
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class DebugError(Action):
|
|||
i_context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
|
||||
async def run(self, *args, **kwargs) -> str:
|
||||
output_doc = await self.project_repo.test_outputs.get(filename=self.i_context.output_filename)
|
||||
output_doc = await self.repo.test_outputs.get(filename=self.i_context.output_filename)
|
||||
if not output_doc:
|
||||
return ""
|
||||
output_detail = RunCodeResult.loads(output_doc.content)
|
||||
|
|
@ -59,12 +59,12 @@ class DebugError(Action):
|
|||
return ""
|
||||
|
||||
logger.info(f"Debug and rewrite {self.i_context.test_filename}")
|
||||
code_doc = await self.project_repo.with_src_path(self.context.src_workspace).srcs.get(
|
||||
code_doc = await self.repo.with_src_path(self.context.src_workspace).srcs.get(
|
||||
filename=self.i_context.code_filename
|
||||
)
|
||||
if not code_doc:
|
||||
return ""
|
||||
test_doc = await self.project_repo.tests.get(filename=self.i_context.test_filename)
|
||||
test_doc = await self.repo.tests.get(filename=self.i_context.test_filename)
|
||||
if not test_doc:
|
||||
return ""
|
||||
prompt = PROMPT_TEMPLATE.format(code=code_doc.content, test_code=test_doc.content, logs=output_detail.stderr)
|
||||
|
|
|
|||
|
|
@ -40,10 +40,10 @@ class WriteDesign(Action):
|
|||
|
||||
async def run(self, with_messages: Message, schema: str = None):
|
||||
# Use `git status` to identify which PRD documents have been modified in the `docs/prds` directory.
|
||||
changed_prds = self.project_repo.docs.prd.changed_files
|
||||
changed_prds = self.repo.docs.prd.changed_files
|
||||
# Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone
|
||||
# changes.
|
||||
changed_system_designs = self.project_repo.docs.system_design.changed_files
|
||||
changed_system_designs = self.repo.docs.system_design.changed_files
|
||||
|
||||
# For those PRDs and design documents that have undergone changes, regenerate the design content.
|
||||
changed_files = Documents()
|
||||
|
|
@ -73,21 +73,21 @@ class WriteDesign(Action):
|
|||
return system_design_doc
|
||||
|
||||
async def _update_system_design(self, filename) -> Document:
|
||||
prd = await self.project_repo.docs.prd.get(filename)
|
||||
old_system_design_doc = await self.project_repo.docs.system_design.get(filename)
|
||||
prd = await self.repo.docs.prd.get(filename)
|
||||
old_system_design_doc = await self.repo.docs.system_design.get(filename)
|
||||
if not old_system_design_doc:
|
||||
system_design = await self._new_system_design(context=prd.content)
|
||||
doc = await self.project_repo.docs.system_design.save(
|
||||
doc = await self.repo.docs.system_design.save(
|
||||
filename=filename,
|
||||
content=system_design.instruct_content.model_dump_json(),
|
||||
dependencies={prd.root_relative_path},
|
||||
)
|
||||
else:
|
||||
doc = await self._merge(prd_doc=prd, system_design_doc=old_system_design_doc)
|
||||
await self.project_repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path})
|
||||
await self.repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path})
|
||||
await self._save_data_api_design(doc)
|
||||
await self._save_seq_flow(doc)
|
||||
await self.project_repo.resources.system_design.save_pdf(doc=doc)
|
||||
await self.repo.resources.system_design.save_pdf(doc=doc)
|
||||
return doc
|
||||
|
||||
async def _save_data_api_design(self, design_doc):
|
||||
|
|
@ -95,7 +95,7 @@ class WriteDesign(Action):
|
|||
data_api_design = m.get("Data structures and interfaces")
|
||||
if not data_api_design:
|
||||
return
|
||||
pathname = self.project_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
|
||||
pathname = self.repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
|
||||
await self._save_mermaid_file(data_api_design, pathname)
|
||||
logger.info(f"Save class view to {str(pathname)}")
|
||||
|
||||
|
|
@ -104,7 +104,7 @@ class WriteDesign(Action):
|
|||
seq_flow = m.get("Program call flow")
|
||||
if not seq_flow:
|
||||
return
|
||||
pathname = self.project_repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
|
||||
pathname = self.repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
|
||||
await self._save_mermaid_file(seq_flow, pathname)
|
||||
logger.info(f"Saving sequence flow to {str(pathname)}")
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from metagpt.actions import Action, ActionOutput
|
|||
from metagpt.const import REQUIREMENT_FILENAME
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
class PrepareDocuments(Action):
|
||||
|
|
@ -38,13 +39,14 @@ class PrepareDocuments(Action):
|
|||
shutil.rmtree(path)
|
||||
self.config.project_path = path
|
||||
self.context.git_repo = GitRepository(local_path=path, auto_init=True)
|
||||
self.context.repo = ProjectRepo(self.context.git_repo)
|
||||
|
||||
async def run(self, with_messages, **kwargs):
|
||||
"""Create and initialize the workspace folder, initialize the Git environment."""
|
||||
self._init_repo()
|
||||
|
||||
# Write the newly added requirements from the main parameter idea to `docs/requirement.txt`.
|
||||
doc = await self.project_repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
|
||||
doc = await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
|
||||
# Send a Message notification to the WritePRD action, instructing it to process requirements using
|
||||
# `docs/requirement.txt` and `docs/prds/`.
|
||||
return ActionOutput(content=doc.content, instruct_content=doc)
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.actions import ActionOutput
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.actions.project_management_an import PM_NODE
|
||||
from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -34,8 +34,8 @@ class WriteTasks(Action):
|
|||
i_context: Optional[str] = None
|
||||
|
||||
async def run(self, with_messages):
|
||||
changed_system_designs = self.project_repo.docs.system_design.changed_files
|
||||
changed_tasks = self.project_repo.docs.task.changed_files
|
||||
changed_system_designs = self.repo.docs.system_design.changed_files
|
||||
changed_tasks = self.repo.docs.task.changed_files
|
||||
change_files = Documents()
|
||||
# Rewrite the system designs that have undergone changes based on the git head diff under
|
||||
# `docs/system_designs/`.
|
||||
|
|
@ -57,16 +57,14 @@ class WriteTasks(Action):
|
|||
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
|
||||
|
||||
async def _update_tasks(self, filename):
|
||||
system_design_doc = await self.project_repo.docs.system_design.get(filename)
|
||||
task_doc = await self.project_repo.docs.task.get(filename)
|
||||
system_design_doc = await self.repo.docs.system_design.get(filename)
|
||||
task_doc = await self.repo.docs.task.get(filename)
|
||||
if task_doc:
|
||||
task_doc = await self._merge(system_design_doc=system_design_doc, task_doc=task_doc)
|
||||
await self.project_repo.docs.task.save_doc(
|
||||
doc=task_doc, dependencies={system_design_doc.root_relative_path}
|
||||
)
|
||||
await self.repo.docs.task.save_doc(doc=task_doc, dependencies={system_design_doc.root_relative_path})
|
||||
else:
|
||||
rsp = await self._run_new_tasks(context=system_design_doc.content)
|
||||
task_doc = await self.project_repo.docs.task.save(
|
||||
task_doc = await self.repo.docs.task.save(
|
||||
filename=filename,
|
||||
content=rsp.instruct_content.model_dump_json(),
|
||||
dependencies={system_design_doc.root_relative_path},
|
||||
|
|
@ -87,7 +85,7 @@ class WriteTasks(Action):
|
|||
async def _update_requirements(self, doc):
|
||||
m = json.loads(doc.content)
|
||||
packages = set(m.get("Required Python third-party packages", set()))
|
||||
requirement_doc = await self.project_repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
|
||||
requirement_doc = await self.repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
|
||||
if not requirement_doc:
|
||||
requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")
|
||||
lines = requirement_doc.content.splitlines()
|
||||
|
|
@ -95,4 +93,4 @@ class WriteTasks(Action):
|
|||
if pkg == "":
|
||||
continue
|
||||
packages.add(pkg)
|
||||
await self.project_repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))
|
||||
await self.repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))
|
||||
|
|
|
|||
|
|
@ -98,10 +98,10 @@ class SummarizeCode(Action):
|
|||
|
||||
async def run(self):
|
||||
design_pathname = Path(self.i_context.design_filename)
|
||||
design_doc = await self.project_repo.docs.system_design.get(filename=design_pathname.name)
|
||||
design_doc = await self.repo.docs.system_design.get(filename=design_pathname.name)
|
||||
task_pathname = Path(self.i_context.task_filename)
|
||||
task_doc = await self.project_repo.docs.task.get(filename=task_pathname.name)
|
||||
src_file_repo = self.project_repo.with_src_path(self.context.src_workspace).srcs
|
||||
task_doc = await self.repo.docs.task.get(filename=task_pathname.name)
|
||||
src_file_repo = self.repo.with_src_path(self.context.src_workspace).srcs
|
||||
code_blocks = []
|
||||
for filename in self.i_context.codes_filenames:
|
||||
code_doc = await src_file_repo.get(filename)
|
||||
|
|
|
|||
|
|
@ -88,12 +88,12 @@ class WriteCode(Action):
|
|||
return code
|
||||
|
||||
async def run(self, *args, **kwargs) -> CodingContext:
|
||||
bug_feedback = await self.project_repo.docs.get(filename=BUGFIX_FILENAME)
|
||||
bug_feedback = await self.repo.docs.get(filename=BUGFIX_FILENAME)
|
||||
coding_context = CodingContext.loads(self.i_context.content)
|
||||
test_doc = await self.project_repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
|
||||
test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
|
||||
summary_doc = None
|
||||
if coding_context.design_doc and coding_context.design_doc.filename:
|
||||
summary_doc = await self.project_repo.docs.code_summary.get(filename=coding_context.design_doc.filename)
|
||||
summary_doc = await self.repo.docs.code_summary.get(filename=coding_context.design_doc.filename)
|
||||
logs = ""
|
||||
if test_doc:
|
||||
test_detail = RunCodeResult.loads(test_doc.content)
|
||||
|
|
@ -105,7 +105,7 @@ class WriteCode(Action):
|
|||
code_context = await self.get_codes(
|
||||
coding_context.task_doc,
|
||||
exclude=self.i_context.filename,
|
||||
project_repo=self.project_repo.with_src_path(self.context.src_workspace),
|
||||
project_repo=self.repo.with_src_path(self.context.src_workspace),
|
||||
)
|
||||
|
||||
prompt = PROMPT_TEMPLATE.format(
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ class WriteCodeReview(Action):
|
|||
code_context = await WriteCode.get_codes(
|
||||
self.i_context.task_doc,
|
||||
exclude=self.i_context.filename,
|
||||
project_repo=self.project_repo.with_src_path(self.context.src_workspace),
|
||||
project_repo=self.repo.with_src_path(self.context.src_workspace),
|
||||
)
|
||||
context = "\n".join(
|
||||
[
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
|
@ -58,96 +57,106 @@ NEW_REQ_TEMPLATE = """
|
|||
|
||||
|
||||
class WritePRD(Action):
|
||||
name: str = "WritePRD"
|
||||
content: Optional[str] = None
|
||||
"""WritePRD deal with the following situations:
|
||||
1. Bugfix: If the requirement is a bugfix, the bugfix document will be generated.
|
||||
2. New requirement: If the requirement is a new requirement, the PRD document will be generated.
|
||||
3. Requirement update: If the requirement is an update, the PRD document will be updated.
|
||||
"""
|
||||
|
||||
async def run(self, with_messages, *args, **kwargs) -> ActionOutput | Message:
|
||||
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
|
||||
# related to the PRD. If they are related, rewrite the PRD.
|
||||
requirement_doc = await self.project_repo.docs.get(filename=REQUIREMENT_FILENAME)
|
||||
if requirement_doc and await self._is_bugfix(requirement_doc.content):
|
||||
await self.project_repo.docs.save(filename=BUGFIX_FILENAME, content=requirement_doc.content)
|
||||
await self.project_repo.docs.save(filename=REQUIREMENT_FILENAME, content="")
|
||||
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
|
||||
return Message(
|
||||
content=bug_fix.model_dump_json(),
|
||||
instruct_content=bug_fix,
|
||||
role="",
|
||||
cause_by=FixBug,
|
||||
sent_from=self,
|
||||
send_to="Alex", # the name of Engineer
|
||||
)
|
||||
"""Run the action."""
|
||||
req: Document = await self.repo.requirement
|
||||
docs: list[Document] = await self.repo.docs.prd.get_all()
|
||||
if not req:
|
||||
raise FileNotFoundError("No requirement document found.")
|
||||
|
||||
if await self._is_bugfix(req.content):
|
||||
logger.info(f"Bugfix detected: {req.content}")
|
||||
return await self._handle_bugfix(req)
|
||||
# remove bugfix file from last round in case of conflict
|
||||
await self.repo.docs.delete(filename=BUGFIX_FILENAME)
|
||||
|
||||
# if requirement is related to other documents, update them, otherwise create a new one
|
||||
if related_docs := await self.get_related_docs(req, docs):
|
||||
logger.info(f"Requirement update detected: {req.content}")
|
||||
return await self._handle_requirement_update(req, related_docs)
|
||||
else:
|
||||
await self.project_repo.docs.delete(filename=BUGFIX_FILENAME)
|
||||
logger.info(f"New requirement detected: {req.content}")
|
||||
return await self._handle_new_requirement(req)
|
||||
|
||||
prd_docs = await self.project_repo.docs.prd.get_all()
|
||||
change_files = Documents()
|
||||
for prd_doc in prd_docs:
|
||||
prd_doc = await self._update_prd(requirement_doc=requirement_doc, prd_doc=prd_doc, *args, **kwargs)
|
||||
if not prd_doc:
|
||||
continue
|
||||
change_files.docs[prd_doc.filename] = prd_doc
|
||||
logger.info(f"rewrite prd: {prd_doc.filename}")
|
||||
# If there is no existing PRD, generate one using 'docs/requirement.txt'.
|
||||
if not change_files.docs:
|
||||
prd_doc = await self._update_prd(requirement_doc=requirement_doc, *args, **kwargs)
|
||||
if prd_doc:
|
||||
change_files.docs[prd_doc.filename] = prd_doc
|
||||
logger.debug(f"new prd: {prd_doc.filename}")
|
||||
# Once all files under 'docs/prds/' have been compared with the newly added requirements, trigger the
|
||||
# 'publish' message to transition the workflow to the next stage. This design allows room for global
|
||||
# optimization in subsequent steps.
|
||||
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
|
||||
async def _handle_bugfix(self, req: Document) -> Message:
|
||||
# ... bugfix logic ...
|
||||
await self.repo.docs.save(filename=BUGFIX_FILENAME, content=req.content)
|
||||
await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content="")
|
||||
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
|
||||
return Message(
|
||||
content=bug_fix.model_dump_json(),
|
||||
instruct_content=bug_fix,
|
||||
role="",
|
||||
cause_by=FixBug,
|
||||
sent_from=self,
|
||||
send_to="Alex", # the name of Engineer
|
||||
)
|
||||
|
||||
async def _run_new_requirement(self, requirements) -> ActionOutput:
|
||||
async def _handle_new_requirement(self, req: Document) -> ActionOutput:
|
||||
"""handle new requirement"""
|
||||
project_name = self.project_name
|
||||
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
|
||||
context = CONTEXT_TEMPLATE.format(requirements=req, project_name=project_name)
|
||||
exclude = [PROJECT_NAME.key] if project_name else []
|
||||
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema
|
||||
await self._rename_workspace(node)
|
||||
return node
|
||||
new_prd_doc = await self.repo.docs.prd.save(
|
||||
filename=FileRepository.new_filename() + ".json", content=node.instruct_content.model_dump_json()
|
||||
)
|
||||
await self._save_competitive_analysis(new_prd_doc)
|
||||
await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
|
||||
return Documents.from_iterable(documents=[new_prd_doc]).to_action_output()
|
||||
|
||||
async def _is_relative(self, new_requirement_doc, old_prd_doc) -> bool:
|
||||
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd_doc.content, requirements=new_requirement_doc.content)
|
||||
async def _handle_requirement_update(self, req: Document, related_docs: list[Document]) -> ActionOutput:
|
||||
# ... requirement update logic ...
|
||||
for doc in related_docs:
|
||||
await self._update_prd(req, doc)
|
||||
return Documents.from_iterable(documents=related_docs).to_action_output()
|
||||
|
||||
async def _is_bugfix(self, context: str) -> bool:
|
||||
if not self.repo.code_files_exists():
|
||||
return False
|
||||
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
|
||||
return node.get("issue_type") == "BUG"
|
||||
|
||||
async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Document]:
|
||||
"""get the related documents"""
|
||||
# refine: use gather to speed up
|
||||
return [i for i in docs if await self._is_related(req, i)]
|
||||
|
||||
async def _is_related(self, req: Document, old_prd: Document) -> bool:
|
||||
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd.content, requirements=req.content)
|
||||
node = await WP_IS_RELATIVE_NODE.fill(context, self.llm)
|
||||
return node.get("is_relative") == "YES"
|
||||
|
||||
async def _merge(self, new_requirement_doc, prd_doc) -> Document:
|
||||
async def _merge(self, req: Document, related_doc: Document) -> Document:
|
||||
if not self.project_name:
|
||||
self.project_name = Path(self.project_path).name
|
||||
prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content)
|
||||
prompt = NEW_REQ_TEMPLATE.format(requirements=req.content, old_prd=related_doc.content)
|
||||
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema)
|
||||
prd_doc.content = node.instruct_content.model_dump_json()
|
||||
related_doc.content = node.instruct_content.model_dump_json()
|
||||
await self._rename_workspace(node)
|
||||
return prd_doc
|
||||
return related_doc
|
||||
|
||||
async def _update_prd(self, requirement_doc, prd_doc=None, *args, **kwargs) -> Document | None:
|
||||
if not prd_doc:
|
||||
prd = await self._run_new_requirement(
|
||||
requirements=[requirement_doc.content if requirement_doc else ""], *args, **kwargs
|
||||
)
|
||||
new_prd_doc = await self.project_repo.docs.prd.save(
|
||||
filename=FileRepository.new_filename() + ".json", content=prd.instruct_content.model_dump_json()
|
||||
)
|
||||
elif await self._is_relative(requirement_doc, prd_doc):
|
||||
new_prd_doc = await self._merge(requirement_doc, prd_doc)
|
||||
self.project_repo.docs.prd.save_doc(doc=new_prd_doc)
|
||||
else:
|
||||
return None
|
||||
async def _update_prd(self, req: Document, prd_doc: Document) -> Document:
|
||||
new_prd_doc: Document = await self._merge(req, prd_doc)
|
||||
self.repo.docs.prd.save_doc(doc=new_prd_doc)
|
||||
await self._save_competitive_analysis(new_prd_doc)
|
||||
await self.project_repo.resources.prd.save_pdf(doc=new_prd_doc)
|
||||
await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
|
||||
return new_prd_doc
|
||||
|
||||
async def _save_competitive_analysis(self, prd_doc):
|
||||
async def _save_competitive_analysis(self, prd_doc: Document):
|
||||
m = json.loads(prd_doc.content)
|
||||
quadrant_chart = m.get("Competitive Quadrant Chart")
|
||||
if not quadrant_chart:
|
||||
return
|
||||
pathname = (
|
||||
self.project_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("")
|
||||
)
|
||||
if not pathname.parent.exists():
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
pathname = self.repo.workdir / COMPETITIVE_ANALYSIS_FILE_REPO / Path(prd_doc.filename).stem
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
await mermaid_to_file(self.config.mermaid_engine, quadrant_chart, pathname)
|
||||
|
||||
async def _rename_workspace(self, prd):
|
||||
|
|
@ -158,15 +167,4 @@ class WritePRD(Action):
|
|||
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
|
||||
if ws_name:
|
||||
self.project_name = ws_name
|
||||
self.project_repo.git_repo.rename_root(self.project_name)
|
||||
|
||||
async def _is_bugfix(self, context) -> bool:
|
||||
git_workdir = self.project_repo.git_repo.workdir
|
||||
src_workdir = git_workdir / git_workdir.name
|
||||
if not src_workdir.exists():
|
||||
return False
|
||||
code_files = self.project_repo.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
|
||||
if not code_files:
|
||||
return False
|
||||
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
|
||||
return node.get("issue_type") == "BUG"
|
||||
self.repo.git_repo.rename_root(self.project_name)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from metagpt.provider.base_llm import BaseLLM
|
|||
from metagpt.provider.llm_provider_registry import create_llm_instance
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
||||
class AttrDict(BaseModel):
|
||||
|
|
@ -58,6 +59,8 @@ class Context(BaseModel):
|
|||
|
||||
kwargs: AttrDict = AttrDict()
|
||||
config: Config = Config.default()
|
||||
|
||||
repo: Optional[ProjectRepo] = None
|
||||
git_repo: Optional[GitRepository] = None
|
||||
src_workspace: Optional[Path] = None
|
||||
cost_manager: CostManager = CostManager()
|
||||
|
|
@ -67,8 +70,8 @@ class Context(BaseModel):
|
|||
def new_environ(self):
|
||||
"""Return a new os.environ object"""
|
||||
env = os.environ.copy()
|
||||
i = self.options
|
||||
env.update({k: v for k, v in i.items() if isinstance(v, str)})
|
||||
# i = self.options
|
||||
# env.update({k: v for k, v in i.items() if isinstance(v, str)})
|
||||
return env
|
||||
|
||||
# def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
|
||||
|
|
|
|||
|
|
@ -235,3 +235,7 @@ class OpenAILLM(BaseLLM):
|
|||
async def amoderation(self, content: Union[str, list[str]]):
|
||||
"""Moderate content."""
|
||||
return await self.aclient.moderations.create(input=content)
|
||||
|
||||
async def atext_to_speech(self, **kwargs):
|
||||
"""text to speech"""
|
||||
return await self.aclient.audio.speech.create(**kwargs)
|
||||
|
|
|
|||
|
|
@ -481,6 +481,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
rsp = await self._act_by_order()
|
||||
elif self.rc.react_mode == RoleReactMode.PLAN_AND_ACT:
|
||||
rsp = await self._plan_and_act()
|
||||
else:
|
||||
raise ValueError(f"Unsupported react mode: {self.rc.react_mode}")
|
||||
self._set_state(state=-1) # current reaction is complete, reset state to -1 and todo back to None
|
||||
return rsp
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,9 @@
|
|||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import ActionOutput, SearchAndSummarize
|
||||
from metagpt.actions import SearchAndSummarize
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from abc import ABC
|
|||
from asyncio import Queue, QueueEmpty, wait_for
|
||||
from json import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
|
|
@ -162,6 +162,26 @@ class Documents(BaseModel):
|
|||
|
||||
docs: Dict[str, Document] = Field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_iterable(cls, documents: Iterable[Document]) -> Documents:
|
||||
"""Create a Documents instance from a list of Document instances.
|
||||
|
||||
:param documents: A list of Document instances.
|
||||
:return: A Documents instance.
|
||||
"""
|
||||
|
||||
docs = {doc.filename: doc for doc in documents}
|
||||
return Documents(docs=docs)
|
||||
|
||||
def to_action_output(self) -> "ActionOutput":
|
||||
"""Convert to action output string.
|
||||
|
||||
:return: A string representing action output.
|
||||
"""
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
|
||||
return ActionOutput(content=self.model_dump_json(), instruct_content=self)
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""list[<role>: <content>]"""
|
||||
|
|
@ -212,7 +232,7 @@ class Message(BaseModel):
|
|||
return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})
|
||||
|
||||
@field_serializer("instruct_content", mode="plain")
|
||||
def ser_instruct_content(self, ic: BaseModel) -> Union[str, None]:
|
||||
def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]:
|
||||
ic_dict = None
|
||||
if ic:
|
||||
# compatible with custom-defined ActionOutput
|
||||
|
|
|
|||
|
|
@ -44,19 +44,20 @@ class SearchEngine:
|
|||
self,
|
||||
engine: Optional[SearchEngineType] = SearchEngineType.SERPER_GOOGLE,
|
||||
run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if engine == SearchEngineType.SERPAPI_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_serpapi"
|
||||
run_func = importlib.import_module(module).SerpAPIWrapper().run
|
||||
run_func = importlib.import_module(module).SerpAPIWrapper(**kwargs).run
|
||||
elif engine == SearchEngineType.SERPER_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_serper"
|
||||
run_func = importlib.import_module(module).SerperWrapper().run
|
||||
run_func = importlib.import_module(module).SerperWrapper(**kwargs).run
|
||||
elif engine == SearchEngineType.DIRECT_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_googleapi"
|
||||
run_func = importlib.import_module(module).GoogleAPIWrapper().run
|
||||
run_func = importlib.import_module(module).GoogleAPIWrapper(**kwargs).run
|
||||
elif engine == SearchEngineType.DUCK_DUCK_GO:
|
||||
module = "metagpt.tools.search_engine_ddg"
|
||||
run_func = importlib.import_module(module).DDGAPIWrapper().run
|
||||
run_func = importlib.import_module(module).DDGAPIWrapper(**kwargs).run
|
||||
elif engine == SearchEngineType.CUSTOM_ENGINE:
|
||||
pass # run_func = run_func
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from metagpt.utils.parse_html import WebPage
|
|||
class WebBrowserEngine:
|
||||
def __init__(
|
||||
self,
|
||||
engine: WebBrowserEngineType | None = WebBrowserEngineType.PLAYWRIGHT,
|
||||
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT,
|
||||
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
|
||||
):
|
||||
if engine is None:
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class SeleniumWrapper:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
browser_type: Literal["chrome", "firefox", "edge", "ie"] | None = None,
|
||||
browser_type: Literal["chrome", "firefox", "edge", "ie"] = "chrome",
|
||||
launch_kwargs: dict | None = None,
|
||||
*,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from metagpt.const import (
|
|||
GRAPH_REPO_FILE_REPO,
|
||||
PRD_PDF_FILE_REPO,
|
||||
PRDS_FILE_REPO,
|
||||
REQUIREMENT_FILENAME,
|
||||
RESOURCES_FILE_REPO,
|
||||
SD_OUTPUT_FILE_REPO,
|
||||
SEQ_FLOW_FILE_REPO,
|
||||
|
|
@ -93,6 +94,10 @@ class ProjectRepo(FileRepository):
|
|||
self.test_outputs = self._git_repo.new_file_repository(relative_path=TEST_OUTPUTS_FILE_REPO)
|
||||
self._srcs_path = None
|
||||
|
||||
@property
|
||||
async def requirement(self):
|
||||
return await self.docs.get(filename=REQUIREMENT_FILENAME)
|
||||
|
||||
@property
|
||||
def git_repo(self) -> GitRepository:
|
||||
return self._git_repo
|
||||
|
|
@ -107,6 +112,15 @@ class ProjectRepo(FileRepository):
|
|||
raise ValueError("Call with_srcs first.")
|
||||
return self._git_repo.new_file_repository(self._srcs_path)
|
||||
|
||||
def code_files_exists(self) -> bool:
|
||||
git_workdir = self.git_repo.workdir
|
||||
src_workdir = git_workdir / git_workdir.name
|
||||
if not src_workdir.exists():
|
||||
return False
|
||||
code_files = self.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
|
||||
if not code_files:
|
||||
return False
|
||||
|
||||
def with_src_path(self, path: str | Path) -> ProjectRepo:
|
||||
try:
|
||||
self._srcs_path = Path(path).relative_to(self.workdir)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -20,6 +21,9 @@ from metagpt.context import CONTEXT
|
|||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
from tests.mock.mock_aiohttp import MockAioResponse
|
||||
from tests.mock.mock_curl_cffi import MockCurlCffiResponse
|
||||
from tests.mock.mock_httplib2 import MockHttplib2Response
|
||||
from tests.mock.mock_llm import MockLLM
|
||||
|
||||
RSP_CACHE_NEW = {} # used globally for producing new and useful only response cache
|
||||
|
|
@ -123,7 +127,7 @@ def proxy():
|
|||
server = await asyncio.start_server(handle_client, "127.0.0.1", 0)
|
||||
return server, "http://{}:{}".format(*server.sockets[0].getsockname())
|
||||
|
||||
return proxy_func()
|
||||
return proxy_func
|
||||
|
||||
|
||||
# see https://github.com/Delgan/loguru/issues/59#issuecomment-466591978
|
||||
|
|
@ -164,39 +168,63 @@ def new_filename(mocker):
|
|||
yield mocker
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_rsp_cache():
|
||||
rsp_cache_file_path = TEST_DATA_PATH / "search_rsp_cache.json" # read repo-provided
|
||||
if os.path.exists(rsp_cache_file_path):
|
||||
with open(rsp_cache_file_path, "r") as f1:
|
||||
rsp_cache_json = json.load(f1)
|
||||
else:
|
||||
rsp_cache_json = {}
|
||||
yield rsp_cache_json
|
||||
with open(rsp_cache_file_path, "w") as f2:
|
||||
json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aiohttp_mocker(mocker):
|
||||
class MockAioResponse:
|
||||
async def json(self, *args, **kwargs):
|
||||
return self._json
|
||||
|
||||
def set_json(self, json):
|
||||
self._json = json
|
||||
|
||||
response = MockAioResponse()
|
||||
|
||||
class MockCTXMng:
|
||||
async def __aenter__(self):
|
||||
return response
|
||||
|
||||
async def __aexit__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __await__(self):
|
||||
yield
|
||||
return response
|
||||
|
||||
def mock_request(self, method, url, **kwargs):
|
||||
return MockCTXMng()
|
||||
MockResponse = type("MockResponse", (MockAioResponse,), {})
|
||||
|
||||
def wrap(method):
|
||||
def run(self, url, **kwargs):
|
||||
return mock_request(self, method, url, **kwargs)
|
||||
return MockResponse(self, method, url, **kwargs)
|
||||
|
||||
return run
|
||||
|
||||
mocker.patch("aiohttp.ClientSession.request", mock_request)
|
||||
mocker.patch("aiohttp.ClientSession.request", MockResponse)
|
||||
for i in ["get", "post", "delete", "patch"]:
|
||||
mocker.patch(f"aiohttp.ClientSession.{i}", wrap(i))
|
||||
yield MockResponse
|
||||
|
||||
yield response
|
||||
|
||||
@pytest.fixture
|
||||
def curl_cffi_mocker(mocker):
|
||||
MockResponse = type("MockResponse", (MockCurlCffiResponse,), {})
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
return MockResponse(self, *args, **kwargs)
|
||||
|
||||
mocker.patch("curl_cffi.requests.Session.request", request)
|
||||
yield MockResponse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def httplib2_mocker(mocker):
|
||||
MockResponse = type("MockResponse", (MockHttplib2Response,), {})
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
return MockResponse(self, *args, **kwargs)
|
||||
|
||||
mocker.patch("httplib2.Http.request", request)
|
||||
yield MockResponse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def search_engine_mocker(aiohttp_mocker, curl_cffi_mocker, httplib2_mocker, search_rsp_cache):
|
||||
# aiohttp_mocker: serpapi/serper
|
||||
# httplib2_mocker: google
|
||||
# curl_cffi_mocker: ddg
|
||||
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
|
||||
aiohttp_mocker.rsp_cache = httplib2_mocker.rsp_cache = curl_cffi_mocker.rsp_cache = search_rsp_cache
|
||||
aiohttp_mocker.check_funcs = httplib2_mocker.check_funcs = curl_cffi_mocker.check_funcs = check_funcs
|
||||
yield check_funcs
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
879
tests/data/search_rsp_cache.json
Normal file
879
tests/data/search_rsp_cache.json
Normal file
File diff suppressed because one or more lines are too long
|
|
@ -9,10 +9,12 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import research
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collect_links(mocker):
|
||||
async def test_collect_links(mocker, search_engine_mocker):
|
||||
async def mock_llm_ask(self, prompt: str, system_msgs):
|
||||
if "Please provide up to 2 necessary keywords" in prompt:
|
||||
return '["metagpt", "llm"]'
|
||||
|
|
@ -26,13 +28,15 @@ async def test_collect_links(mocker):
|
|||
return "[1,2]"
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
resp = await research.CollectLinks().run("The application of MetaGPT")
|
||||
resp = await research.CollectLinks(search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO)).run(
|
||||
"The application of MetaGPT"
|
||||
)
|
||||
for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
|
||||
assert i in resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collect_links_with_rank_func(mocker):
|
||||
async def test_collect_links_with_rank_func(mocker, search_engine_mocker):
|
||||
rank_before = []
|
||||
rank_after = []
|
||||
url_per_query = 4
|
||||
|
|
@ -45,7 +49,9 @@ async def test_collect_links_with_rank_func(mocker):
|
|||
return results
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_collect_links_llm_ask)
|
||||
resp = await research.CollectLinks(rank_func=rank_func).run("The application of MetaGPT")
|
||||
resp = await research.CollectLinks(
|
||||
search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), rank_func=rank_func
|
||||
).run("The application of MetaGPT")
|
||||
for x, y, z in zip(rank_before, rank_after, resp.values()):
|
||||
assert x[::-1] == y
|
||||
assert [i["link"] for i in y] == z
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.learn.text_to_embedding import text_to_embedding
|
||||
from metagpt.utils.common import aread
|
||||
|
||||
|
|
@ -19,13 +19,14 @@ from metagpt.utils.common import aread
|
|||
@pytest.mark.asyncio
|
||||
async def test_text_to_embedding(mocker):
|
||||
# mock
|
||||
config = Config.default()
|
||||
mock_post = mocker.patch("aiohttp.ClientSession.post")
|
||||
mock_response = mocker.AsyncMock()
|
||||
mock_response.status = 200
|
||||
data = await aread(Path(__file__).parent / "../../data/openai/embedding.json")
|
||||
mock_response.json.return_value = json.loads(data)
|
||||
mock_post.return_value.__aenter__.return_value = mock_response
|
||||
type(config.get_openai_llm()).proxy = mocker.PropertyMock(return_value="http://mock.proxy")
|
||||
config.get_openai_llm().proxy = mocker.PropertyMock(return_value="http://mock.proxy")
|
||||
|
||||
# Prerequisites
|
||||
assert config.get_openai_llm().api_key
|
||||
|
|
|
|||
|
|
@ -42,11 +42,23 @@ async def test_aask_code_message():
|
|||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_speech():
|
||||
llm = LLM()
|
||||
resp = await llm.atext_to_speech(
|
||||
model="tts-1",
|
||||
voice="alloy",
|
||||
input="人生说起来长,但知道一个岁月回头看,许多事件仅是仓促的。一段一段拼凑一起,合成了人生。苦难当头时,当下不免觉得是折磨;回头看,也不够是一段短短的人生旅程。",
|
||||
)
|
||||
assert 200 == resp.response.status_code
|
||||
|
||||
|
||||
class TestOpenAI:
|
||||
def test_make_client_kwargs_without_proxy(self):
|
||||
instance = OpenAILLM(mock_llm_config)
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "mock_api_key", "base_url": "mock_base_url"}
|
||||
assert kwargs["api_key"] == "mock_api_key"
|
||||
assert kwargs["base_url"] == "mock_base_url"
|
||||
assert "http_client" not in kwargs
|
||||
|
||||
def test_make_client_kwargs_with_proxy(self):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,10 @@ from tempfile import TemporaryDirectory
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.research import CollectLinks
|
||||
from metagpt.roles import researcher
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
|
||||
async def mock_llm_ask(self, prompt: str, system_msgs):
|
||||
|
|
@ -25,12 +28,16 @@ async def mock_llm_ask(self, prompt: str, system_msgs):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_researcher(mocker):
|
||||
async def test_researcher(mocker, search_engine_mocker):
|
||||
with TemporaryDirectory() as dirname:
|
||||
topic = "dataiku vs. datarobot"
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
researcher.RESEARCH_PATH = Path(dirname)
|
||||
await researcher.Researcher().run(topic)
|
||||
role = researcher.Researcher()
|
||||
for i in role.actions:
|
||||
if isinstance(i, CollectLinks):
|
||||
i.search_engine = SearchEngine(SearchEngineType.DUCK_DUCK_GO)
|
||||
await role.run(topic)
|
||||
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -18,5 +18,5 @@ async def test_action_serdeser(new_filename):
|
|||
|
||||
new_action = WritePRD(**ser_action_dict)
|
||||
assert new_action.name == "WritePRD"
|
||||
action_output = await new_action.run(with_messages=Message(content="write a cli snake game"))
|
||||
assert len(action_output.content) > 0
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await new_action.run(with_messages=Message(content="write a cli snake game"))
|
||||
|
|
|
|||
|
|
@ -48,7 +48,6 @@ def test_context_1():
|
|||
assert ctx.git_repo is None
|
||||
assert ctx.src_workspace is None
|
||||
assert ctx.cost_manager is not None
|
||||
assert ctx.options is not None
|
||||
|
||||
|
||||
def test_context_2():
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ def test_config_mixin_4_multi_inheritance_override_config():
|
|||
|
||||
print(obj.__dict__.keys())
|
||||
assert "private_config" in obj.__dict__.keys()
|
||||
assert obj.llm.model == "mock_zhipu_model"
|
||||
assert obj.config.llm.model == "mock_zhipu_model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding
|
||||
from metagpt.utils.common import aread
|
||||
|
||||
|
|
@ -18,6 +18,7 @@ from metagpt.utils.common import aread
|
|||
@pytest.mark.asyncio
|
||||
async def test_embedding(mocker):
|
||||
# mock
|
||||
config = Config.default()
|
||||
mock_post = mocker.patch("aiohttp.ClientSession.post")
|
||||
mock_response = mocker.AsyncMock()
|
||||
mock_response.status = 200
|
||||
|
|
|
|||
|
|
@ -7,20 +7,15 @@
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
import tests.data.search
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
search_cache_path = Path(tests.data.search.__path__[0])
|
||||
|
||||
|
||||
class MockSearchEnine:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
|
||||
|
|
@ -46,24 +41,28 @@ class MockSearchEnine:
|
|||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
|
||||
],
|
||||
)
|
||||
async def test_search_engine(search_engine_type, run_func: Callable, max_results: int, as_string: bool, aiohttp_mocker):
|
||||
async def test_search_engine(
|
||||
search_engine_type,
|
||||
run_func: Callable,
|
||||
max_results: int,
|
||||
as_string: bool,
|
||||
search_engine_mocker,
|
||||
):
|
||||
# Prerequisites
|
||||
cache_json_path = None
|
||||
# FIXME: 不能使用全局的config,而是要自己实例化对应的config
|
||||
search_engine_config = {}
|
||||
|
||||
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
|
||||
assert config.search
|
||||
cache_json_path = search_cache_path / f"serpapi-metagpt-{max_results}.json"
|
||||
search_engine_config["serpapi_api_key"] = "mock-serpapi-key"
|
||||
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["google_api_key"] = "mock-google-key"
|
||||
search_engine_config["google_cse_id"] = "mock-google-cse"
|
||||
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
|
||||
assert config.search
|
||||
cache_json_path = search_cache_path / f"serper-metagpt-{max_results}.json"
|
||||
search_engine_config["serper_api_key"] = "mock-serper-key"
|
||||
|
||||
if cache_json_path:
|
||||
with open(cache_json_path) as f:
|
||||
data = json.load(f)
|
||||
aiohttp_mocker.set_json(data)
|
||||
search_engine = SearchEngine(search_engine_type, run_func)
|
||||
search_engine = SearchEngine(search_engine_type, run_func, **search_engine_config)
|
||||
rsp = await search_engine.run("metagpt", max_results, as_string)
|
||||
logger.info(rsp)
|
||||
if as_string:
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy
|
|||
global_proxy = config.proxy
|
||||
try:
|
||||
if use_proxy:
|
||||
server, proxy = await proxy
|
||||
config.proxy = proxy
|
||||
server, proxy_url = await proxy()
|
||||
config.proxy = proxy_url
|
||||
browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, **kwagrs)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd)
|
|||
global_proxy = config.proxy
|
||||
try:
|
||||
if use_proxy:
|
||||
server, proxy = await proxy
|
||||
config.proxy = proxy
|
||||
server, proxy_url = await proxy()
|
||||
config.proxy = proxy_url
|
||||
browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
|
|
|
|||
41
tests/mock/mock_aiohttp.py
Normal file
41
tests/mock/mock_aiohttp.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import json
|
||||
from typing import Callable
|
||||
|
||||
from aiohttp.client import ClientSession
|
||||
|
||||
origin_request = ClientSession.request
|
||||
|
||||
|
||||
class MockAioResponse:
|
||||
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
|
||||
rsp_cache: dict[str, str] = {}
|
||||
name = "aiohttp"
|
||||
|
||||
def __init__(self, session, method, url, **kwargs) -> None:
|
||||
fn = self.check_funcs.get((method, url))
|
||||
self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}"
|
||||
self.mng = self.response = None
|
||||
if self.key not in self.rsp_cache:
|
||||
self.mng = origin_request(session, method, url, **kwargs)
|
||||
|
||||
async def __aenter__(self):
|
||||
if self.response:
|
||||
await self.response.__aenter__()
|
||||
elif self.mng:
|
||||
self.response = await self.mng.__aenter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args, **kwargs):
|
||||
if self.response:
|
||||
await self.response.__aexit__(*args, **kwargs)
|
||||
self.response = None
|
||||
elif self.mng:
|
||||
await self.mng.__aexit__(*args, **kwargs)
|
||||
self.mng = None
|
||||
|
||||
async def json(self, *args, **kwargs):
|
||||
if self.key in self.rsp_cache:
|
||||
return self.rsp_cache[self.key]
|
||||
data = await self.response.json(*args, **kwargs)
|
||||
self.rsp_cache[self.key] = data
|
||||
return data
|
||||
22
tests/mock/mock_curl_cffi.py
Normal file
22
tests/mock/mock_curl_cffi.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
import json
|
||||
from typing import Callable
|
||||
|
||||
from curl_cffi import requests
|
||||
|
||||
origin_request = requests.Session.request
|
||||
|
||||
|
||||
class MockCurlCffiResponse(requests.Response):
|
||||
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
|
||||
rsp_cache: dict[str, str] = {}
|
||||
name = "curl-cffi"
|
||||
|
||||
def __init__(self, session, method, url, **kwargs) -> None:
|
||||
super().__init__()
|
||||
fn = self.check_funcs.get((method, url))
|
||||
self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}"
|
||||
self.response = None
|
||||
if self.key not in self.rsp_cache:
|
||||
response = origin_request(session, method, url, **kwargs)
|
||||
self.rsp_cache[self.key] = response.content.decode()
|
||||
self.content = self.rsp_cache[self.key].encode()
|
||||
29
tests/mock/mock_httplib2.py
Normal file
29
tests/mock/mock_httplib2.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
import json
|
||||
from typing import Callable
|
||||
from urllib.parse import parse_qsl, urlparse
|
||||
|
||||
import httplib2
|
||||
|
||||
origin_request = httplib2.Http.request
|
||||
|
||||
|
||||
class MockHttplib2Response(httplib2.Response):
|
||||
check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
|
||||
rsp_cache: dict[str, str] = {}
|
||||
name = "httplib2"
|
||||
|
||||
def __init__(self, http, uri, method="GET", **kwargs) -> None:
|
||||
url = uri.split("?")[0]
|
||||
result = urlparse(uri)
|
||||
params = dict(parse_qsl(result.query))
|
||||
fn = self.check_funcs.get((method, uri))
|
||||
new_kwargs = {"params": params}
|
||||
key = f"{self.name}-{method}-{url}-{fn(new_kwargs) if fn else json.dumps(new_kwargs)}"
|
||||
if key not in self.rsp_cache:
|
||||
_, self.content = origin_request(http, uri, method, **kwargs)
|
||||
self.rsp_cache[key] = self.content.decode()
|
||||
self.content = self.rsp_cache[key]
|
||||
|
||||
def __iter__(self):
|
||||
yield self
|
||||
yield self.content.encode()
|
||||
Loading…
Add table
Add a link
Reference in a new issue