Merge pull request #724 from iorisa/feature/project_repo

feat: Implementation of ProjectRepo
This commit is contained in:
geekan 2024-01-12 17:54:32 +08:00 committed by GitHub
commit 1edff983f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
47 changed files with 690 additions and 533 deletions

View file

@ -21,7 +21,7 @@ from metagpt.schema import (
SerializationMixin,
TestingContext,
)
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.project_repo import ProjectRepo
class Action(SerializationMixin, ContextMixin, BaseModel):
@ -34,16 +34,8 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
node: ActionNode = Field(default=None, exclude=True)
@property
def git_repo(self):
return self.context.git_repo
@property
def file_repo(self):
return FileRepository(self.context.git_repo)
@property
def src_workspace(self):
return self.context.src_workspace
def project_repo(self):
return ProjectRepo(self.context.git_repo)
@property
def prompt_schema(self):

View file

@ -13,7 +13,6 @@ import re
from pydantic import Field
from metagpt.actions.action import Action
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import RunCodeContext, RunCodeResult
from metagpt.utils.common import CodeParser
@ -50,9 +49,7 @@ class DebugError(Action):
i_context: RunCodeContext = Field(default_factory=RunCodeContext)
async def run(self, *args, **kwargs) -> str:
output_doc = await self.file_repo.get_file(
filename=self.i_context.output_filename, relative_path=TEST_OUTPUTS_FILE_REPO
)
output_doc = await self.project_repo.test_outputs.get(filename=self.i_context.output_filename)
if not output_doc:
return ""
output_detail = RunCodeResult.loads(output_doc.content)
@ -62,14 +59,12 @@ class DebugError(Action):
return ""
logger.info(f"Debug and rewrite {self.i_context.test_filename}")
code_doc = await self.file_repo.get_file(
filename=self.i_context.code_filename, relative_path=self.context.src_workspace
code_doc = await self.project_repo.with_src_path(self.context.src_workspace).srcs.get(
filename=self.i_context.code_filename
)
if not code_doc:
return ""
test_doc = await self.file_repo.get_file(
filename=self.i_context.test_filename, relative_path=TEST_CODES_FILE_REPO
)
test_doc = await self.project_repo.tests.get(filename=self.i_context.test_filename)
if not test_doc:
return ""
prompt = PROMPT_TEMPLATE.format(code=code_doc.content, test_code=test_doc.content, logs=output_detail.stderr)

View file

@ -15,13 +15,7 @@ from typing import Optional
from metagpt.actions import Action, ActionOutput
from metagpt.actions.design_api_an import DESIGN_API_NODE
from metagpt.const import (
DATA_API_DESIGN_FILE_REPO,
PRDS_FILE_REPO,
SEQ_FLOW_FILE_REPO,
SYSTEM_DESIGN_FILE_REPO,
SYSTEM_DESIGN_PDF_FILE_REPO,
)
from metagpt.const import DATA_API_DESIGN_FILE_REPO, SEQ_FLOW_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import Document, Documents, Message
from metagpt.utils.mermaid import mermaid_to_file
@ -46,27 +40,21 @@ class WriteDesign(Action):
async def run(self, with_messages: Message, schema: str = None):
# Use `git status` to identify which PRD documents have been modified in the `docs/prds` directory.
prds_file_repo = self.git_repo.new_file_repository(PRDS_FILE_REPO)
changed_prds = prds_file_repo.changed_files
changed_prds = self.project_repo.docs.prd.changed_files
# Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone
# changes.
system_design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
changed_system_designs = system_design_file_repo.changed_files
changed_system_designs = self.project_repo.docs.system_design.changed_files
# For those PRDs and design documents that have undergone changes, regenerate the design content.
changed_files = Documents()
for filename in changed_prds.keys():
doc = await self._update_system_design(
filename=filename, prds_file_repo=prds_file_repo, system_design_file_repo=system_design_file_repo
)
doc = await self._update_system_design(filename=filename)
changed_files.docs[filename] = doc
for filename in changed_system_designs.keys():
if filename in changed_files.docs:
continue
doc = await self._update_system_design(
filename=filename, prds_file_repo=prds_file_repo, system_design_file_repo=system_design_file_repo
)
doc = await self._update_system_design(filename=filename)
changed_files.docs[filename] = doc
if not changed_files.docs:
logger.info("Nothing has changed.")
@ -84,24 +72,22 @@ class WriteDesign(Action):
system_design_doc.content = node.instruct_content.model_dump_json()
return system_design_doc
async def _update_system_design(self, filename, prds_file_repo, system_design_file_repo) -> Document:
prd = await prds_file_repo.get(filename)
old_system_design_doc = await system_design_file_repo.get(filename)
async def _update_system_design(self, filename) -> Document:
prd = await self.project_repo.docs.prd.get(filename)
old_system_design_doc = await self.project_repo.docs.system_design.get(filename)
if not old_system_design_doc:
system_design = await self._new_system_design(context=prd.content)
doc = Document(
root_path=SYSTEM_DESIGN_FILE_REPO,
doc = await self.project_repo.docs.system_design.save(
filename=filename,
content=system_design.instruct_content.model_dump_json(),
dependencies={prd.root_relative_path},
)
else:
doc = await self._merge(prd_doc=prd, system_design_doc=old_system_design_doc)
await system_design_file_repo.save(
filename=filename, content=doc.content, dependencies={prd.root_relative_path}
)
await self.project_repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path})
await self._save_data_api_design(doc)
await self._save_seq_flow(doc)
await self._save_pdf(doc)
await self.project_repo.resources.system_design.save_pdf(doc=doc)
return doc
async def _save_data_api_design(self, design_doc):
@ -109,7 +95,7 @@ class WriteDesign(Action):
data_api_design = m.get("Data structures and interfaces")
if not data_api_design:
return
pathname = self.git_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
pathname = self.project_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
await self._save_mermaid_file(data_api_design, pathname)
logger.info(f"Save class view to {str(pathname)}")
@ -118,13 +104,10 @@ class WriteDesign(Action):
seq_flow = m.get("Program call flow")
if not seq_flow:
return
pathname = self.git_repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
pathname = self.project_repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
await self._save_mermaid_file(seq_flow, pathname)
logger.info(f"Saving sequence flow to {str(pathname)}")
async def _save_pdf(self, design_doc):
await self.file_repo.save_as(doc=design_doc, with_suffix=".md", relative_path=SYSTEM_DESIGN_PDF_FILE_REPO)
async def _save_mermaid_file(self, data: str, pathname: Path):
pathname.parent.mkdir(parents=True, exist_ok=True)
await mermaid_to_file(self.config.mermaid_engine, data, pathname)

View file

@ -12,8 +12,7 @@ from pathlib import Path
from typing import Optional
from metagpt.actions import Action, ActionOutput
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
from metagpt.schema import Document
from metagpt.const import REQUIREMENT_FILENAME
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
@ -38,7 +37,6 @@ class PrepareDocuments(Action):
if path.exists() and not self.config.inc:
shutil.rmtree(path)
self.config.project_path = path
self.config.project_name = path.name
self.context.git_repo = GitRepository(local_path=path, auto_init=True)
async def run(self, with_messages, **kwargs):
@ -46,9 +44,7 @@ class PrepareDocuments(Action):
self._init_repo()
# Write the newly added requirements from the main parameter idea to `docs/requirement.txt`.
doc = Document(root_path=DOCS_FILE_REPO, filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
await self.file_repo.save_file(filename=REQUIREMENT_FILENAME, content=doc.content, relative_path=DOCS_FILE_REPO)
doc = await self.project_repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
# Send a Message notification to the WritePRD action, instructing it to process requirements using
# `docs/requirement.txt` and `docs/prds/`.
return ActionOutput(content=doc.content, instruct_content=doc)

View file

@ -16,12 +16,7 @@ from typing import Optional
from metagpt.actions import ActionOutput
from metagpt.actions.action import Action
from metagpt.actions.project_management_an import PM_NODE
from metagpt.const import (
PACKAGE_REQUIREMENTS_FILENAME,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
TASK_PDF_FILE_REPO,
)
from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME
from metagpt.logs import logger
from metagpt.schema import Document, Documents
@ -39,27 +34,20 @@ class WriteTasks(Action):
i_context: Optional[str] = None
async def run(self, with_messages):
system_design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
changed_system_designs = system_design_file_repo.changed_files
tasks_file_repo = self.git_repo.new_file_repository(TASK_FILE_REPO)
changed_tasks = tasks_file_repo.changed_files
changed_system_designs = self.project_repo.docs.system_design.changed_files
changed_tasks = self.project_repo.docs.task.changed_files
change_files = Documents()
# Rewrite the system designs that have undergone changes based on the git head diff under
# `docs/system_designs/`.
for filename in changed_system_designs:
task_doc = await self._update_tasks(
filename=filename, system_design_file_repo=system_design_file_repo, tasks_file_repo=tasks_file_repo
)
task_doc = await self._update_tasks(filename=filename)
change_files.docs[filename] = task_doc
# Rewrite the task files that have undergone changes based on the git head diff under `docs/tasks/`.
for filename in changed_tasks:
if filename in change_files.docs:
continue
task_doc = await self._update_tasks(
filename=filename, system_design_file_repo=system_design_file_repo, tasks_file_repo=tasks_file_repo
)
task_doc = await self._update_tasks(filename=filename)
change_files.docs[filename] = task_doc
if not change_files.docs:
@ -68,21 +56,22 @@ class WriteTasks(Action):
# global optimization in subsequent steps.
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
async def _update_tasks(self, filename, system_design_file_repo, tasks_file_repo):
system_design_doc = await system_design_file_repo.get(filename)
task_doc = await tasks_file_repo.get(filename)
async def _update_tasks(self, filename):
system_design_doc = await self.project_repo.docs.system_design.get(filename)
task_doc = await self.project_repo.docs.task.get(filename)
if task_doc:
task_doc = await self._merge(system_design_doc=system_design_doc, task_doc=task_doc)
await self.project_repo.docs.task.save_doc(
doc=task_doc, dependencies={system_design_doc.root_relative_path}
)
else:
rsp = await self._run_new_tasks(context=system_design_doc.content)
task_doc = Document(
root_path=TASK_FILE_REPO, filename=filename, content=rsp.instruct_content.model_dump_json()
task_doc = await self.project_repo.docs.task.save(
filename=filename,
content=rsp.instruct_content.model_dump_json(),
dependencies={system_design_doc.root_relative_path},
)
await tasks_file_repo.save(
filename=filename, content=task_doc.content, dependencies={system_design_doc.root_relative_path}
)
await self._update_requirements(task_doc)
await self._save_pdf(task_doc=task_doc)
return task_doc
async def _run_new_tasks(self, context):
@ -98,8 +87,7 @@ class WriteTasks(Action):
async def _update_requirements(self, doc):
m = json.loads(doc.content)
packages = set(m.get("Required Python third-party packages", set()))
file_repo = self.git_repo.new_file_repository()
requirement_doc = await file_repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
requirement_doc = await self.project_repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
if not requirement_doc:
requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")
lines = requirement_doc.content.splitlines()
@ -107,7 +95,4 @@ class WriteTasks(Action):
if pkg == "":
continue
packages.add(pkg)
await file_repo.save(PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))
async def _save_pdf(self, task_doc):
await self.file_repo.save_as(doc=task_doc, with_suffix=".md", relative_path=TASK_PDF_FILE_REPO)
await self.project_repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))

View file

@ -29,9 +29,7 @@ class ArgumentsParingAction(Action):
@property
def prompt(self):
prompt = "You are a function parser. You can convert spoken words into function parameters.\n"
prompt += "\n---\n"
prompt += f"{self.skill.name} function parameters description:\n"
prompt = f"{self.skill.name} function parameters description:\n"
for k, v in self.skill.arguments.items():
prompt += f"parameter `{k}`: {v}\n"
prompt += "\n---\n"
@ -49,7 +47,10 @@ class ArgumentsParingAction(Action):
async def run(self, with_message=None, **kwargs) -> Message:
prompt = self.prompt
rsp = await self.llm.aask(msg=prompt, system_msgs=[])
rsp = await self.llm.aask(
msg=prompt,
system_msgs=["You are a function parser.", "You can convert spoken words into function parameters."],
)
logger.debug(f"SKILL:{prompt}\n, RESULT:{rsp}")
self.args = ArgumentsParingAction.parse_arguments(skill_name=self.skill.name, txt=rsp)
self.rsp = Message(content=rsp, role="assistant", instruct_content=self.args, cause_by=self)

View file

@ -11,7 +11,6 @@ from pydantic import Field
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action import Action
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import CodeSummarizeContext
@ -99,11 +98,10 @@ class SummarizeCode(Action):
async def run(self):
design_pathname = Path(self.i_context.design_filename)
repo = self.file_repo
design_doc = await repo.get_file(filename=design_pathname.name, relative_path=SYSTEM_DESIGN_FILE_REPO)
design_doc = await self.project_repo.docs.system_design.get(filename=design_pathname.name)
task_pathname = Path(self.i_context.task_filename)
task_doc = await repo.get_file(filename=task_pathname.name, relative_path=TASK_FILE_REPO)
src_file_repo = self.git_repo.new_file_repository(relative_path=self.context.src_workspace)
task_doc = await self.project_repo.docs.task.get(filename=task_pathname.name)
src_file_repo = self.project_repo.with_src_path(self.context.src_workspace).srcs
code_blocks = []
for filename in self.i_context.codes_filenames:
code_doc = await src_file_repo.get(filename)

View file

@ -21,13 +21,7 @@ from pydantic import Field
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action import Action
from metagpt.const import (
BUGFIX_FILENAME,
CODE_SUMMARIES_FILE_REPO,
DOCS_FILE_REPO,
TASK_FILE_REPO,
TEST_OUTPUTS_FILE_REPO,
)
from metagpt.const import BUGFIX_FILENAME
from metagpt.logs import logger
from metagpt.schema import CodingContext, Document, RunCodeResult
from metagpt.utils.common import CodeParser
@ -94,16 +88,12 @@ class WriteCode(Action):
return code
async def run(self, *args, **kwargs) -> CodingContext:
bug_feedback = await self.file_repo.get_file(filename=BUGFIX_FILENAME, relative_path=DOCS_FILE_REPO)
bug_feedback = await self.project_repo.docs.get(filename=BUGFIX_FILENAME)
coding_context = CodingContext.loads(self.i_context.content)
test_doc = await self.file_repo.get_file(
filename="test_" + coding_context.filename + ".json", relative_path=TEST_OUTPUTS_FILE_REPO
)
test_doc = await self.project_repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
summary_doc = None
if coding_context.design_doc and coding_context.design_doc.filename:
summary_doc = await self.file_repo.get_file(
filename=coding_context.design_doc.filename, relative_path=CODE_SUMMARIES_FILE_REPO
)
summary_doc = await self.project_repo.docs.code_summary.get(filename=coding_context.design_doc.filename)
logs = ""
if test_doc:
test_detail = RunCodeResult.loads(test_doc.content)
@ -115,8 +105,7 @@ class WriteCode(Action):
code_context = await self.get_codes(
coding_context.task_doc,
exclude=self.i_context.filename,
git_repo=self.git_repo,
src_workspace=self.context.src_workspace,
project_repo=self.project_repo.with_src_path(self.context.src_workspace),
)
prompt = PROMPT_TEMPLATE.format(
@ -138,16 +127,15 @@ class WriteCode(Action):
return coding_context
@staticmethod
async def get_codes(task_doc, exclude, git_repo, src_workspace) -> str:
async def get_codes(task_doc, exclude, project_repo) -> str:
if not task_doc:
return ""
if not task_doc.content:
file_repo = git_repo.new_file_repository()
task_doc.content = file_repo.get_file(filename=task_doc.filename, relative_path=TASK_FILE_REPO)
task_doc = project_repo.docs.task.get(filename=task_doc.filename)
m = json.loads(task_doc.content)
code_filenames = m.get("Task list", [])
codes = []
src_file_repo = git_repo.new_file_repository(relative_path=src_workspace)
src_file_repo = project_repo.srcs
for filename in code_filenames:
if filename == exclude:
continue

View file

@ -143,8 +143,7 @@ class WriteCodeReview(Action):
code_context = await WriteCode.get_codes(
self.i_context.task_doc,
exclude=self.i_context.filename,
git_repo=self.context.git_repo,
src_workspace=self.src_workspace,
project_repo=self.project_repo.with_src_path(self.context.src_workspace),
)
context = "\n".join(
[

View file

@ -29,9 +29,6 @@ from metagpt.actions.write_prd_an import (
from metagpt.const import (
BUGFIX_FILENAME,
COMPETITIVE_ANALYSIS_FILE_REPO,
DOCS_FILE_REPO,
PRD_PDF_FILE_REPO,
PRDS_FILE_REPO,
REQUIREMENT_FILENAME,
)
from metagpt.logs import logger
@ -67,11 +64,10 @@ class WritePRD(Action):
async def run(self, with_messages, *args, **kwargs) -> ActionOutput | Message:
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
# related to the PRD. If they are related, rewrite the PRD.
docs_file_repo = self.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO)
requirement_doc = await docs_file_repo.get(filename=REQUIREMENT_FILENAME)
requirement_doc = await self.project_repo.docs.get(filename=REQUIREMENT_FILENAME)
if requirement_doc and await self._is_bugfix(requirement_doc.content):
await docs_file_repo.save(filename=BUGFIX_FILENAME, content=requirement_doc.content)
await docs_file_repo.save(filename=REQUIREMENT_FILENAME, content="")
await self.project_repo.docs.save(filename=BUGFIX_FILENAME, content=requirement_doc.content)
await self.project_repo.docs.save(filename=REQUIREMENT_FILENAME, content="")
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
return Message(
content=bug_fix.model_dump_json(),
@ -82,24 +78,19 @@ class WritePRD(Action):
send_to="Alex", # the name of Engineer
)
else:
await docs_file_repo.delete(filename=BUGFIX_FILENAME)
await self.project_repo.docs.delete(filename=BUGFIX_FILENAME)
prds_file_repo = self.git_repo.new_file_repository(PRDS_FILE_REPO)
prd_docs = await prds_file_repo.get_all()
prd_docs = await self.project_repo.docs.prd.get_all()
change_files = Documents()
for prd_doc in prd_docs:
prd_doc = await self._update_prd(
requirement_doc=requirement_doc, prd_doc=prd_doc, prds_file_repo=prds_file_repo, *args, **kwargs
)
prd_doc = await self._update_prd(requirement_doc=requirement_doc, prd_doc=prd_doc, *args, **kwargs)
if not prd_doc:
continue
change_files.docs[prd_doc.filename] = prd_doc
logger.info(f"rewrite prd: {prd_doc.filename}")
# If there is no existing PRD, generate one using 'docs/requirement.txt'.
if not change_files.docs:
prd_doc = await self._update_prd(
requirement_doc=requirement_doc, prd_doc=None, prds_file_repo=prds_file_repo, *args, **kwargs
)
prd_doc = await self._update_prd(requirement_doc=requirement_doc, *args, **kwargs)
if prd_doc:
change_files.docs[prd_doc.filename] = prd_doc
logger.debug(f"new prd: {prd_doc.filename}")
@ -109,13 +100,6 @@ class WritePRD(Action):
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
async def _run_new_requirement(self, requirements) -> ActionOutput:
# sas = SearchAndSummarize()
# # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US)
# rsp = ""
# info = f"### Search Results\n{sas.result}\n\n### Search Summary\n{rsp}"
# if sas.result:
# logger.info(sas.result)
# logger.info(rsp)
project_name = self.project_name
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
exclude = [PROJECT_NAME.key] if project_name else []
@ -137,23 +121,21 @@ class WritePRD(Action):
await self._rename_workspace(node)
return prd_doc
async def _update_prd(self, requirement_doc, prd_doc, prds_file_repo, *args, **kwargs) -> Document | None:
async def _update_prd(self, requirement_doc, prd_doc=None, *args, **kwargs) -> Document | None:
if not prd_doc:
prd = await self._run_new_requirement(
requirements=[requirement_doc.content if requirement_doc else ""], *args, **kwargs
)
new_prd_doc = Document(
root_path=PRDS_FILE_REPO,
filename=FileRepository.new_filename() + ".json",
content=prd.instruct_content.model_dump_json(),
new_prd_doc = await self.project_repo.docs.prd.save(
filename=FileRepository.new_filename() + ".json", content=prd.instruct_content.model_dump_json()
)
elif await self._is_relative(requirement_doc, prd_doc):
new_prd_doc = await self._merge(requirement_doc, prd_doc)
self.project_repo.docs.prd.save_doc(doc=new_prd_doc)
else:
return None
await prds_file_repo.save(filename=new_prd_doc.filename, content=new_prd_doc.content)
await self._save_competitive_analysis(new_prd_doc)
await self._save_pdf(new_prd_doc)
await self.project_repo.resources.prd.save_pdf(doc=new_prd_doc)
return new_prd_doc
async def _save_competitive_analysis(self, prd_doc):
@ -161,14 +143,13 @@ class WritePRD(Action):
quadrant_chart = m.get("Competitive Quadrant Chart")
if not quadrant_chart:
return
pathname = self.git_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("")
pathname = (
self.project_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("")
)
if not pathname.parent.exists():
pathname.parent.mkdir(parents=True, exist_ok=True)
await mermaid_to_file(self.config.mermaid_engine, quadrant_chart, pathname)
async def _save_pdf(self, prd_doc):
await self.file_repo.save_as(doc=prd_doc, with_suffix=".md", relative_path=PRD_PDF_FILE_REPO)
async def _rename_workspace(self, prd):
if not self.project_name:
if isinstance(prd, (ActionOutput, ActionNode)):
@ -177,11 +158,14 @@ class WritePRD(Action):
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
if ws_name:
self.project_name = ws_name
self.git_repo.rename_root(self.project_name)
self.project_repo.git_repo.rename_root(self.project_name)
async def _is_bugfix(self, context) -> bool:
src_workspace_path = self.git_repo.workdir / self.git_repo.workdir.name
code_files = self.git_repo.get_files(relative_path=src_workspace_path)
git_workdir = self.project_repo.git_repo.workdir
src_workdir = git_workdir / git_workdir.name
if not src_workdir.exists():
return False
code_files = self.project_repo.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
if not code_files:
return False
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)

View file

@ -8,7 +8,7 @@
from typing import Optional
from metagpt.actions import Action
from metagpt.context import CONTEXT
from metagpt.context import Context
from metagpt.logs import logger
@ -24,7 +24,7 @@ class WriteTeachingPlanPart(Action):
statement_patterns = TeachingPlanBlock.TOPIC_STATEMENTS.get(self.topic, [])
statements = []
for p in statement_patterns:
s = self.format_value(p)
s = self.format_value(p, context=self.context)
statements.append(s)
formatter = (
TeachingPlanBlock.PROMPT_TITLE_TEMPLATE
@ -68,21 +68,23 @@ class WriteTeachingPlanPart(Action):
return self.topic
@staticmethod
def format_value(value):
def format_value(value, context: Context):
"""Fill parameters inside `value` with `options`."""
if not isinstance(value, str):
return value
if "{" not in value:
return value
# FIXME: 从Context中获取参数而非从options
merged_opts = CONTEXT.options or {}
options = context.config.model_dump()
for k, v in context.kwargs:
options[k] = v # None value is allowed to override and disable the value from config.
opts = {k: v for k, v in options.items() if v is not None}
try:
return value.format(**merged_opts)
return value.format(**opts)
except KeyError as e:
logger.warning(f"Parameter is missing:{e}")
for k, v in merged_opts.items():
for k, v in opts.items():
value = value.replace("{" + f"{k}" + "}", str(v))
return value

View file

@ -89,23 +89,23 @@ BUGFIX_FILENAME = "bugfix.txt"
PACKAGE_REQUIREMENTS_FILENAME = "requirements.txt"
DOCS_FILE_REPO = "docs"
PRDS_FILE_REPO = "docs/prds"
PRDS_FILE_REPO = "docs/prd"
SYSTEM_DESIGN_FILE_REPO = "docs/system_design"
TASK_FILE_REPO = "docs/tasks"
TASK_FILE_REPO = "docs/task"
COMPETITIVE_ANALYSIS_FILE_REPO = "resources/competitive_analysis"
DATA_API_DESIGN_FILE_REPO = "resources/data_api_design"
SEQ_FLOW_FILE_REPO = "resources/seq_flow"
SYSTEM_DESIGN_PDF_FILE_REPO = "resources/system_design"
PRD_PDF_FILE_REPO = "resources/prd"
TASK_PDF_FILE_REPO = "resources/api_spec_and_tasks"
TASK_PDF_FILE_REPO = "resources/api_spec_and_task"
TEST_CODES_FILE_REPO = "tests"
TEST_OUTPUTS_FILE_REPO = "test_outputs"
CODE_SUMMARIES_FILE_REPO = "docs/code_summaries"
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summaries"
CODE_SUMMARIES_FILE_REPO = "docs/code_summary"
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summary"
RESOURCES_FILE_REPO = "resources"
SD_OUTPUT_FILE_REPO = "resources/SD_Output"
SD_OUTPUT_FILE_REPO = "resources/sd_output"
GRAPH_REPO_FILE_REPO = "docs/graph_repo"
CLASS_VIEW_FILE_REPO = "docs/class_views"
CLASS_VIEW_FILE_REPO = "docs/class_view"
YAPI_URL = "http://yapi.deepwisdomai.com/"

View file

@ -7,13 +7,12 @@
"""
import os
from pathlib import Path
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict
from metagpt.config2 import Config
from metagpt.configs.llm_config import LLMConfig
from metagpt.const import OPTIONS
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import create_llm_instance
from metagpt.utils.cost_manager import CostManager
@ -41,6 +40,16 @@ class AttrDict(BaseModel):
else:
raise AttributeError(f"No such attribute: {key}")
def set(self, key, val: Any):
self.__dict__[key] = val
def get(self, key, default: Any = None):
return self.__dict__.get(key, default)
def remove(self, key):
if key in self.__dict__:
self.__delattr__(key)
class Context(BaseModel):
"""Env context for MetaGPT"""
@ -55,15 +64,6 @@ class Context(BaseModel):
_llm: Optional[BaseLLM] = None
@property
def file_repo(self):
return self.git_repo.new_file_repository()
@property
def options(self):
"""Return all key-values"""
return OPTIONS.get()
def new_environ(self):
"""Return a new os.environ object"""
env = os.environ.copy()

View file

@ -6,16 +6,19 @@
@File : text_to_embedding.py
@Desc : Text-to-Embedding skill, which provides text-to-embedding functionality.
"""
import metagpt.config2
from metagpt.config2 import Config
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding
async def text_to_embedding(text, model="text-embedding-ada-002", openai_api_key="", **kwargs):
async def text_to_embedding(text, model="text-embedding-ada-002", config: Config = metagpt.config2.config):
"""Text to embedding
:param text: The text used for embedding.
:param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`.
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
:param config: OpenAI config with API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
"""
return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key)
openai_api_key = config.get_openai_llm().api_key
proxy = config.get_openai_llm().proxy
return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key, proxy=proxy)

View file

@ -8,6 +8,7 @@
"""
import base64
import metagpt.config2
from metagpt.config2 import Config
from metagpt.const import BASE64_FORMAT
from metagpt.llm import LLM
@ -16,27 +17,26 @@ from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image
from metagpt.utils.s3 import S3
async def text_to_image(text, size_type: str = "512x512", model_url="", config: Config = None):
async def text_to_image(text, size_type: str = "512x512", config: Config = metagpt.config2.config):
"""Text to image
:param text: The text used for image conversion.
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
:param size_type: If using OPENAI, the available size options are ['256x256', '512x512', '1024x1024'], while for MetaGPT, the options are ['512x512', '512x768'].
:param model_url: MetaGPT model url
:param config: Config
:return: The image data is returned in Base64 encoding.
"""
image_declaration = "data:image/png;base64,"
model_url = config.METAGPT_TEXT_TO_IMAGE_MODEL_URL
if model_url:
binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url)
elif oai_llm := config.get_openai_llm():
binary_data = await oas3_openai_text_to_image(text, size_type, LLM(oai_llm))
elif config.get_openai_llm():
llm = LLM(llm_config=config.get_openai_llm())
binary_data = await oas3_openai_text_to_image(text, size_type, llm=llm)
else:
raise ValueError("Missing necessary parameters.")
base64_data = base64.b64encode(binary_data).decode("utf-8")
assert config.s3, "S3 config is required."
s3 = S3(config.s3)
url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT)
if url:

View file

@ -6,8 +6,8 @@
@File : text_to_speech.py
@Desc : Text-to-Speech skill, which provides text-to-speech functionality
"""
from metagpt.config2 import config
import metagpt.config2
from metagpt.config2 import Config
from metagpt.const import BASE64_FORMAT
from metagpt.tools.azure_tts import oas3_azsure_tts
from metagpt.tools.iflytek_tts import oas3_iflytek_tts
@ -20,12 +20,7 @@ async def text_to_speech(
voice="zh-CN-XiaomoNeural",
style="affectionate",
role="Girl",
subscription_key="",
region="",
iflytek_app_id="",
iflytek_api_key="",
iflytek_api_secret="",
**kwargs,
config: Config = metagpt.config2.config,
):
"""Text to speech
For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
@ -44,6 +39,8 @@ async def text_to_speech(
"""
subscription_key = config.AZURE_TTS_SUBSCRIPTION_KEY
region = config.AZURE_TTS_REGION
if subscription_key and region:
audio_declaration = "data:audio/wav;base64,"
base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
@ -52,6 +49,10 @@ async def text_to_speech(
if url:
return f"[{text}]({url})"
return audio_declaration + base64_data if base64_data else base64_data
iflytek_app_id = config.IFLYTEK_APP_ID
iflytek_api_key = config.IFLYTEK_API_KEY
iflytek_api_secret = config.IFLYTEK_API_SECRET
if iflytek_app_id and iflytek_api_key and iflytek_api_secret:
audio_declaration = "data:audio/mp3;base64,"
base64_data = await oas3_iflytek_tts(

View file

@ -65,7 +65,7 @@ class Assistant(Role):
prompt += f"If the text explicitly want you to {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: {name}\n"
prompt += 'Otherwise, return `[TALK]: {talk}` brief and clear. For instance: if {talk} is "xxxx" return [TALK]: xxxx\n\n'
prompt += f"Now what specific action is explicitly mentioned in the text: {last_talk}\n"
rsp = await self.llm.aask(prompt, [])
rsp = await self.llm.aask(prompt, ["You are an action classifier"])
logger.info(f"THINK: {prompt}\n, THINK RESULT: {rsp}\n")
return await self._plan(rsp, last_talk=last_talk)
@ -98,9 +98,7 @@ class Assistant(Role):
history = self.memory.history_text
text = kwargs.get("last_talk") or text
self.set_todo(
TalkAction(
context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm, **kwargs
)
TalkAction(i_context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm)
)
return True
@ -110,7 +108,7 @@ class Assistant(Role):
if not skill:
logger.info(f"skill not found: {text}")
return await self.talk_handler(text=last_talk, **kwargs)
action = ArgumentsParingAction(skill=skill, llm=self.llm, ask=last_talk, **kwargs)
action = ArgumentsParingAction(skill=skill, llm=self.llm, ask=last_talk)
await action.run(**kwargs)
if action.args is None:
return await self.talk_handler(text=last_talk, **kwargs)

View file

@ -27,12 +27,7 @@ from typing import Set
from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks
from metagpt.actions.fix_bug import FixBug
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.const import (
CODE_SUMMARIES_FILE_REPO,
CODE_SUMMARIES_PDF_FILE_REPO,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
)
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import (
@ -97,7 +92,6 @@ class Engineer(Role):
async def _act_sp_with_cr(self, review=False) -> Set[str]:
changed_files = set()
src_file_repo = self.git_repo.new_file_repository(self.src_workspace)
for todo in self.code_todos:
"""
# Select essential information from the historical data to reduce the length of the prompt (summarized from human experience):
@ -112,8 +106,8 @@ class Engineer(Role):
action = WriteCodeReview(i_context=coding_context, context=self.context, llm=self.llm)
self._init_action(action)
coding_context = await action.run()
await src_file_repo.save(
coding_context.filename,
await self.project_repo.srcs.save(
filename=coding_context.filename,
dependencies={coding_context.design_doc.root_relative_path, coding_context.task_doc.root_relative_path},
content=coding_context.code_doc.content,
)
@ -153,31 +147,29 @@ class Engineer(Role):
)
async def _act_summarize(self):
code_summaries_file_repo = self.git_repo.new_file_repository(CODE_SUMMARIES_FILE_REPO)
code_summaries_pdf_file_repo = self.git_repo.new_file_repository(CODE_SUMMARIES_PDF_FILE_REPO)
tasks = []
src_relative_path = self.src_workspace.relative_to(self.git_repo.workdir)
for todo in self.summarize_todos:
summary = await todo.run()
summary_filename = Path(todo.i_context.design_filename).with_suffix(".md").name
dependencies = {todo.i_context.design_filename, todo.i_context.task_filename}
for filename in todo.i_context.codes_filenames:
rpath = src_relative_path / filename
rpath = self.project_repo.src_relative_path / filename
dependencies.add(str(rpath))
await code_summaries_pdf_file_repo.save(
await self.project_repo.resources.code_summary.save(
filename=summary_filename, content=summary, dependencies=dependencies
)
is_pass, reason = await self._is_pass(summary)
if not is_pass:
todo.i_context.reason = reason
tasks.append(todo.i_context.dict())
await code_summaries_file_repo.save(
await self.project_repo.docs.code_summary.save(
filename=Path(todo.i_context.design_filename).name,
content=todo.i_context.model_dump_json(),
dependencies=dependencies,
)
else:
await code_summaries_file_repo.delete(filename=Path(todo.i_context.design_filename).name)
await self.project_repo.docs.code_summary.delete(filename=Path(todo.i_context.design_filename).name)
logger.info(f"--max-auto-summarize-code={self.config.max_auto_summarize_code}")
if not tasks or self.config.max_auto_summarize_code == 0:
@ -220,60 +212,54 @@ class Engineer(Role):
return self.rc.todo
return None
@staticmethod
async def _new_coding_context(
filename, src_file_repo, task_file_repo, design_file_repo, dependency
) -> CodingContext:
old_code_doc = await src_file_repo.get(filename)
async def _new_coding_context(self, filename, dependency) -> CodingContext:
old_code_doc = await self.project_repo.srcs.get(filename)
if not old_code_doc:
old_code_doc = Document(root_path=str(src_file_repo.root_path), filename=filename, content="")
old_code_doc = Document(root_path=str(self.project_repo.src_relative_path), filename=filename, content="")
dependencies = {Path(i) for i in await dependency.get(old_code_doc.root_relative_path)}
task_doc = None
design_doc = None
for i in dependencies:
if str(i.parent) == TASK_FILE_REPO:
task_doc = await task_file_repo.get(i.name)
task_doc = await self.project_repo.docs.task.get(i.name)
elif str(i.parent) == SYSTEM_DESIGN_FILE_REPO:
design_doc = await design_file_repo.get(i.name)
design_doc = await self.project_repo.docs.system_design.get(i.name)
if not task_doc or not design_doc:
logger.error(f'Detected source code "{filename}" from an unknown origin.')
raise ValueError(f'Detected source code "{filename}" from an unknown origin.')
context = CodingContext(filename=filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc)
return context
@staticmethod
async def _new_coding_doc(filename, src_file_repo, task_file_repo, design_file_repo, dependency):
context = await Engineer._new_coding_context(
filename, src_file_repo, task_file_repo, design_file_repo, dependency
)
async def _new_coding_doc(self, filename, dependency):
context = await self._new_coding_context(filename, dependency)
coding_doc = Document(
root_path=str(src_file_repo.root_path), filename=filename, content=context.model_dump_json()
root_path=str(self.project_repo.src_relative_path), filename=filename, content=context.model_dump_json()
)
return coding_doc
async def _new_code_actions(self, bug_fix=False):
# Prepare file repos
src_file_repo = self.git_repo.new_file_repository(self.src_workspace)
changed_src_files = src_file_repo.all_files if bug_fix else src_file_repo.changed_files
task_file_repo = self.git_repo.new_file_repository(TASK_FILE_REPO)
changed_task_files = task_file_repo.changed_files
design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
changed_src_files = self.project_repo.srcs.all_files if bug_fix else self.project_repo.srcs.changed_files
changed_task_files = self.project_repo.docs.task.changed_files
changed_files = Documents()
# Recode caused by upstream changes.
for filename in changed_task_files:
design_doc = await design_file_repo.get(filename)
task_doc = await task_file_repo.get(filename)
design_doc = await self.project_repo.docs.system_design.get(filename)
task_doc = await self.project_repo.docs.task.get(filename)
task_list = self._parse_tasks(task_doc)
for task_filename in task_list:
old_code_doc = await src_file_repo.get(task_filename)
old_code_doc = await self.project_repo.srcs.get(task_filename)
if not old_code_doc:
old_code_doc = Document(root_path=str(src_file_repo.root_path), filename=task_filename, content="")
old_code_doc = Document(
root_path=str(self.project_repo.src_relative_path), filename=task_filename, content=""
)
context = CodingContext(
filename=task_filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc
)
coding_doc = Document(
root_path=str(src_file_repo.root_path), filename=task_filename, content=context.model_dump_json()
root_path=str(self.project_repo.src_relative_path),
filename=task_filename,
content=context.model_dump_json(),
)
if task_filename in changed_files.docs:
logger.warning(
@ -289,13 +275,7 @@ class Engineer(Role):
for filename in changed_src_files:
if filename in changed_files.docs:
continue
coding_doc = await self._new_coding_doc(
filename=filename,
src_file_repo=src_file_repo,
task_file_repo=task_file_repo,
design_file_repo=design_file_repo,
dependency=dependency,
)
coding_doc = await self._new_coding_doc(filename=filename, dependency=dependency)
changed_files.docs[filename] = coding_doc
self.code_todos.append(WriteCode(i_context=coding_doc, context=self.context, llm=self.llm))
@ -303,13 +283,12 @@ class Engineer(Role):
self.set_todo(self.code_todos[0])
async def _new_summarize_actions(self):
src_file_repo = self.git_repo.new_file_repository(self.src_workspace)
src_files = src_file_repo.all_files
src_files = self.project_repo.srcs.all_files
# Generate a SummarizeCode action for each pair of (system_design_doc, task_doc).
summarizations = defaultdict(list)
for filename in src_files:
dependencies = await src_file_repo.get_dependency(filename=filename)
ctx = CodeSummarizeContext.loads(filenames=dependencies)
dependencies = await self.project_repo.srcs.get_dependency(filename=filename)
ctx = CodeSummarizeContext.loads(filenames=list(dependencies))
summarizations[ctx].append(filename)
for ctx, filenames in summarizations.items():
ctx.codes_filenames = filenames

View file

@ -17,11 +17,7 @@
from metagpt.actions import DebugError, RunCode, WriteTest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.const import (
MESSAGE_ROUTE_TO_NONE,
TEST_CODES_FILE_REPO,
TEST_OUTPUTS_FILE_REPO,
)
from metagpt.const import MESSAGE_ROUTE_TO_NONE
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Document, Message, RunCodeContext, TestingContext
@ -48,29 +44,26 @@ class QaEngineer(Role):
self.test_round = 0
async def _write_test(self, message: Message) -> None:
src_file_repo = self.context.git_repo.new_file_repository(self.context.src_workspace)
src_file_repo = self.project_repo.with_src_path(self.context.src_workspace).srcs
changed_files = set(src_file_repo.changed_files.keys())
# Unit tests only.
if self.config.reqa_file and self.config.reqa_file not in changed_files:
changed_files.add(self.config.reqa_file)
tests_file_repo = self.context.git_repo.new_file_repository(TEST_CODES_FILE_REPO)
for filename in changed_files:
# write tests
if not filename or "test" in filename:
continue
code_doc = await src_file_repo.get(filename)
test_doc = await tests_file_repo.get("test_" + code_doc.filename)
test_doc = await self.project_repo.tests.get("test_" + code_doc.filename)
if not test_doc:
test_doc = Document(
root_path=str(tests_file_repo.root_path), filename="test_" + code_doc.filename, content=""
root_path=str(self.project_repo.tests.root_path), filename="test_" + code_doc.filename, content=""
)
logger.info(f"Writing {test_doc.filename}..")
context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc)
context = await WriteTest(i_context=context, context=self.context, llm=self.llm).run()
await tests_file_repo.save(
filename=context.test_doc.filename,
content=context.test_doc.content,
dependencies={context.code_doc.root_relative_path},
await self.project_repo.tests.save_doc(
doc=context.test_doc, dependencies={context.code_doc.root_relative_path}
)
# prepare context for run tests in next round
@ -78,7 +71,7 @@ class QaEngineer(Role):
command=["python", context.test_doc.root_relative_path],
code_filename=context.code_doc.filename,
test_filename=context.test_doc.filename,
working_directory=str(self.context.git_repo.workdir),
working_directory=str(self.project_repo.workdir),
additional_python_paths=[str(self.context.src_workspace)],
)
self.publish_message(
@ -91,25 +84,23 @@ class QaEngineer(Role):
)
)
logger.info(f"Done {str(tests_file_repo.workdir)} generating.")
logger.info(f"Done {str(self.project_repo.tests.workdir)} generating.")
async def _run_code(self, msg):
run_code_context = RunCodeContext.loads(msg.content)
src_doc = await self.context.git_repo.new_file_repository(self.context.src_workspace).get(
src_doc = await self.project_repo.with_src_path(self.context.src_workspace).srcs.get(
run_code_context.code_filename
)
if not src_doc:
return
test_doc = await self.context.git_repo.new_file_repository(TEST_CODES_FILE_REPO).get(
run_code_context.test_filename
)
test_doc = await self.project_repo.tests.get(run_code_context.test_filename)
if not test_doc:
return
run_code_context.code = src_doc.content
run_code_context.test_code = test_doc.content
result = await RunCode(i_context=run_code_context, context=self.context, llm=self.llm).run()
run_code_context.output_filename = run_code_context.test_filename + ".json"
await self.context.git_repo.new_file_repository(TEST_OUTPUTS_FILE_REPO).save(
await self.project_repo.test_outputs.save(
filename=run_code_context.output_filename,
content=result.model_dump_json(),
dependencies={src_doc.root_relative_path, test_doc.root_relative_path},
@ -132,9 +123,7 @@ class QaEngineer(Role):
async def _debug_error(self, msg):
run_code_context = RunCodeContext.loads(msg.content)
code = await DebugError(i_context=run_code_context, context=self.context, llm=self.llm).run()
await self.context.file_repo.save_file(
filename=run_code_context.test_filename, content=code, relative_path=TEST_CODES_FILE_REPO
)
await self.project_repo.tests.save(filename=run_code_context.test_filename, content=code)
run_code_context.output = None
self.publish_message(
Message(

View file

@ -36,6 +36,7 @@ from metagpt.memory import Memory
from metagpt.provider import HumanProvider
from metagpt.schema import Message, MessageQueue, SerializationMixin
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
from metagpt.utils.project_repo import ProjectRepo
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}. """
@ -199,6 +200,11 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
def src_workspace(self, value):
self.context.src_workspace = value
@property
def project_repo(self) -> ProjectRepo:
project_repo = ProjectRepo(self.context.git_repo)
return project_repo.with_src_path(self.context.src_workspace) if self.context.src_workspace else project_repo
@property
def prompt_schema(self):
"""Prompt schema: json/markdown"""
@ -449,7 +455,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
break
# act
logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}")
rsp = await self._act() # 这个rsp是否需要publish_message
rsp = await self._act()
actions_taken += 1
return rsp # return output from the last action

View file

@ -31,11 +31,11 @@ class Teacher(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = WriteTeachingPlanPart.format_value(self.name)
self.profile = WriteTeachingPlanPart.format_value(self.profile)
self.goal = WriteTeachingPlanPart.format_value(self.goal)
self.constraints = WriteTeachingPlanPart.format_value(self.constraints)
self.desc = WriteTeachingPlanPart.format_value(self.desc)
self.name = WriteTeachingPlanPart.format_value(self.name, self.context)
self.profile = WriteTeachingPlanPart.format_value(self.profile, self.context)
self.goal = WriteTeachingPlanPart.format_value(self.goal, self.context)
self.constraints = WriteTeachingPlanPart.format_value(self.constraints, self.context)
self.desc = WriteTeachingPlanPart.format_value(self.desc, self.context)
async def _think(self) -> bool:
"""Everything will be done part by part."""

View file

@ -13,7 +13,6 @@ import aiohttp
import requests
from pydantic import BaseModel, Field
from metagpt.config2 import config
from metagpt.logs import logger
@ -43,12 +42,12 @@ class ResultEmbedding(BaseModel):
class OpenAIText2Embedding:
def __init__(self, openai_api_key):
def __init__(self, api_key: str, proxy: str):
"""
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
"""
self.openai_llm = config.get_openai_llm()
self.openai_api_key = openai_api_key or self.openai_llm.api_key
self.api_key = api_key
self.proxy = proxy
async def text_2_embedding(self, text, model="text-embedding-ada-002"):
"""Text to embedding
@ -58,8 +57,8 @@ class OpenAIText2Embedding:
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
"""
proxies = {"proxy": self.openai_llm.proxy} if self.openai_llm.proxy else {}
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.openai_api_key}"}
proxies = {"proxy": self.proxy} if self.proxy else {}
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
data = {"input": text, "model": model}
url = "https://api.openai.com/v1/embeddings"
try:
@ -73,16 +72,14 @@ class OpenAIText2Embedding:
# Export
async def oas3_openai_text_to_embedding(text, model="text-embedding-ada-002", openai_api_key=""):
async def oas3_openai_text_to_embedding(text, openai_api_key: str, model="text-embedding-ada-002", proxy: str = ""):
"""Text to embedding
:param text: The text used for embedding.
:param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`.
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
:param config: OpenAI config with API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
"""
if not text:
return ""
if not openai_api_key:
openai_api_key = config.get_openai_llm().api_key
return await OpenAIText2Embedding(openai_api_key).text_2_embedding(text, model=model)
return await OpenAIText2Embedding(api_key=openai_api_key, proxy=proxy).text_2_embedding(text, model=model)

View file

@ -45,7 +45,7 @@ class FileRepository:
# Initializing
self.workdir.mkdir(parents=True, exist_ok=True)
async def save(self, filename: Path | str, content, dependencies: List[str] = None):
async def save(self, filename: Path | str, content, dependencies: List[str] = None) -> Document:
"""Save content to a file and update its dependencies.
:param filename: The filename or path within the repository.
@ -63,6 +63,8 @@ class FileRepository:
await dependency_file.update(pathname, set(dependencies))
logger.info(f"update dependency: {str(pathname)}:{dependencies}")
return Document(root_path=str(self._relative_path), filename=filename, content=content)
async def get_dependency(self, filename: Path | str) -> Set[str]:
"""Get the dependencies of a file.
@ -181,10 +183,20 @@ class FileRepository:
"""
current_time = datetime.now().strftime("%Y%m%d%H%M%S")
return current_time
# guid_suffix = str(uuid.uuid4())[:8]
# return f"{current_time}x{guid_suffix}"
async def save_doc(self, doc: Document, with_suffix: str = None, dependencies: List[str] = None):
async def save_doc(self, doc: Document, dependencies: List[str] = None):
"""Save content to a file and update its dependencies.
:param doc: The Document instance to be saved.
:type doc: Document
:param dependencies: A list of dependencies for the saved file.
:type dependencies: List[str], optional
"""
await self.save(filename=doc.filename, content=doc.content, dependencies=dependencies)
logger.debug(f"File Saved: {str(doc.filename)}")
async def save_pdf(self, doc: Document, with_suffix: str = ".md", dependencies: List[str] = None):
"""Save a Document instance as a PDF file.
This method converts the content of the Document instance to Markdown,
@ -202,68 +214,6 @@ class FileRepository:
await self.save(filename=str(filename), content=json_to_markdown(m), dependencies=dependencies)
logger.debug(f"File Saved: {str(filename)}")
async def get_file(self, filename: Path | str, relative_path: Path | str = ".") -> Document | None:
"""Retrieve a specific file from the file repository.
:param filename: The name or path of the file to retrieve.
:type filename: Path or str
:param relative_path: The relative path within the file repository.
:type relative_path: Path or str, optional
:return: The document representing the file, or None if not found.
:rtype: Document or None
"""
file_repo = self._git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.get(filename=filename)
async def get_all_files(self, relative_path: Path | str = ".") -> List[Document]:
"""Retrieve all files from the file repository.
:param relative_path: The relative path within the file repository.
:type relative_path: Path or str, optional
:return: A list of documents representing all files in the repository.
:rtype: List[Document]
"""
file_repo = self._git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.get_all()
async def save_file(
self, filename: Path | str, content, dependencies: List[str] = None, relative_path: Path | str = "."
):
"""Save a file to the file repository.
:param filename: The name or path of the file to save.
:type filename: Path or str
:param content: The content of the file.
:param dependencies: A list of dependencies for the file.
:type dependencies: List[str], optional
:param relative_path: The relative path within the file repository.
:type relative_path: Path or str, optional
"""
file_repo = self._git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.save(filename=filename, content=content, dependencies=dependencies)
async def save_as(
self, doc: Document, with_suffix: str = None, dependencies: List[str] = None, relative_path: Path | str = "."
):
"""Save a Document instance with optional modifications.
This static method creates a new FileRepository, saves the Document instance
with optional modifications (such as a suffix), and logs the saved file.
:param doc: The Document instance to be saved.
:type doc: Document
:param with_suffix: An optional suffix to append to the saved file's name.
:type with_suffix: str, optional
:param dependencies: A list of dependencies for the saved file.
:type dependencies: List[str], optional
:param relative_path: The relative path within the file repository.
:type relative_path: Path or str, optional
:return: A boolean indicating whether the save operation was successful.
:rtype: bool
"""
file_repo = self._git_repo.new_file_repository(relative_path=relative_path)
return await file_repo.save_doc(doc=doc, with_suffix=with_suffix, dependencies=dependencies)
async def delete(self, filename: Path | str):
"""Delete a file from the file repository.
@ -280,7 +230,3 @@ class FileRepository:
dependency_file = await self._git_repo.get_dependency()
await dependency_file.update(filename=pathname, dependencies=None)
logger.info(f"remove dependency key: {str(pathname)}")
async def delete_file(self, filename: Path | str, relative_path: Path | str = "."):
file_repo = self._git_repo.new_file_repository(relative_path=relative_path)
await file_repo.delete(filename=filename)

View file

@ -107,7 +107,10 @@ class GitRepository:
def delete_repository(self):
"""Delete the entire repository directory."""
if self.is_valid:
shutil.rmtree(self._repository.working_dir)
try:
shutil.rmtree(self._repository.working_dir)
except Exception as e:
logger.exception(f"Failed delete git repo:{self.workdir}, error:{e}")
@property
def changed_files(self) -> Dict[str, str]:
@ -199,10 +202,17 @@ class GitRepository:
if new_path.exists():
logger.info(f"Delete directory {str(new_path)}")
shutil.rmtree(new_path)
if new_path.exists(): # Recheck for windows os
logger.warning(f"Failed to delete directory {str(new_path)}")
return
try:
shutil.move(src=str(self.workdir), dst=str(new_path))
except Exception as e:
logger.warning(f"Move {str(self.workdir)} to {str(new_path)} error: {e}")
finally:
if not new_path.exists(): # Recheck for windows os
logger.warning(f"Failed to move {str(self.workdir)} to {str(new_path)}")
return
logger.info(f"Rename directory {str(self.workdir)} to {str(new_path)}")
self._repository = Repo(new_path)
self._gitignore_rules = parse_gitignore(full_path=str(new_path / ".gitignore"))

View file

@ -0,0 +1,119 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/8
@Author : mashenquan
@File : project_repo.py
@Desc : Wrapper for GitRepository and FileRepository of project.
Implementation of Chapter 4.6 of https://deepwisdom.feishu.cn/wiki/CUK4wImd7id9WlkQBNscIe9cnqh
"""
from __future__ import annotations
from pathlib import Path
from metagpt.const import (
CLASS_VIEW_FILE_REPO,
CODE_SUMMARIES_FILE_REPO,
CODE_SUMMARIES_PDF_FILE_REPO,
COMPETITIVE_ANALYSIS_FILE_REPO,
DATA_API_DESIGN_FILE_REPO,
DOCS_FILE_REPO,
GRAPH_REPO_FILE_REPO,
PRD_PDF_FILE_REPO,
PRDS_FILE_REPO,
RESOURCES_FILE_REPO,
SD_OUTPUT_FILE_REPO,
SEQ_FLOW_FILE_REPO,
SYSTEM_DESIGN_FILE_REPO,
SYSTEM_DESIGN_PDF_FILE_REPO,
TASK_FILE_REPO,
TASK_PDF_FILE_REPO,
TEST_CODES_FILE_REPO,
TEST_OUTPUTS_FILE_REPO,
)
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
class DocFileRepositories(FileRepository):
prd: FileRepository
system_design: FileRepository
task: FileRepository
code_summary: FileRepository
graph_repo: FileRepository
class_view: FileRepository
def __init__(self, git_repo):
super().__init__(git_repo=git_repo, relative_path=DOCS_FILE_REPO)
self.prd = git_repo.new_file_repository(relative_path=PRDS_FILE_REPO)
self.system_design = git_repo.new_file_repository(relative_path=SYSTEM_DESIGN_FILE_REPO)
self.task = git_repo.new_file_repository(relative_path=TASK_FILE_REPO)
self.code_summary = git_repo.new_file_repository(relative_path=CODE_SUMMARIES_FILE_REPO)
self.graph_repo = git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
self.class_view = git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
class ResourceFileRepositories(FileRepository):
competitive_analysis: FileRepository
data_api_design: FileRepository
seq_flow: FileRepository
system_design: FileRepository
prd: FileRepository
api_spec_and_task: FileRepository
code_summary: FileRepository
sd_output: FileRepository
def __init__(self, git_repo):
super().__init__(git_repo=git_repo, relative_path=RESOURCES_FILE_REPO)
self.competitive_analysis = git_repo.new_file_repository(relative_path=COMPETITIVE_ANALYSIS_FILE_REPO)
self.data_api_design = git_repo.new_file_repository(relative_path=DATA_API_DESIGN_FILE_REPO)
self.seq_flow = git_repo.new_file_repository(relative_path=SEQ_FLOW_FILE_REPO)
self.system_design = git_repo.new_file_repository(relative_path=SYSTEM_DESIGN_PDF_FILE_REPO)
self.prd = git_repo.new_file_repository(relative_path=PRD_PDF_FILE_REPO)
self.api_spec_and_task = git_repo.new_file_repository(relative_path=TASK_PDF_FILE_REPO)
self.code_summary = git_repo.new_file_repository(relative_path=CODE_SUMMARIES_PDF_FILE_REPO)
self.sd_output = git_repo.new_file_repository(relative_path=SD_OUTPUT_FILE_REPO)
class ProjectRepo(FileRepository):
def __init__(self, root: str | Path | GitRepository):
if isinstance(root, str) or isinstance(root, Path):
git_repo_ = GitRepository(local_path=Path(root))
elif isinstance(root, GitRepository):
git_repo_ = root
else:
raise ValueError("Invalid root")
super().__init__(git_repo=git_repo_, relative_path=Path("."))
self._git_repo = git_repo_
self.docs = DocFileRepositories(self._git_repo)
self.resources = ResourceFileRepositories(self._git_repo)
self.tests = self._git_repo.new_file_repository(relative_path=TEST_CODES_FILE_REPO)
self.test_outputs = self._git_repo.new_file_repository(relative_path=TEST_OUTPUTS_FILE_REPO)
self._srcs_path = None
@property
def git_repo(self) -> GitRepository:
return self._git_repo
@property
def workdir(self) -> Path:
return Path(self.git_repo.workdir)
@property
def srcs(self) -> FileRepository:
if not self._srcs_path:
raise ValueError("Call with_srcs first.")
return self._git_repo.new_file_repository(self._srcs_path)
def with_src_path(self, path: str | Path) -> ProjectRepo:
try:
self._srcs_path = Path(path).relative_to(self.workdir)
except ValueError:
self._srcs_path = Path(path)
return self
@property
def src_relative_path(self) -> Path | None:
return self._srcs_path

View file

@ -1 +1 @@
{"docs/system_design/20231221155954.json": ["docs/prds/20231221155954.json"], "docs/tasks/20231221155954.json": ["docs/system_design/20231221155954.json"], "game_2048/game.py": ["docs/tasks/20231221155954.json", "docs/system_design/20231221155954.json"], "game_2048/main.py": ["docs/tasks/20231221155954.json", "docs/system_design/20231221155954.json"], "resources/code_summaries/20231221155954.md": ["docs/tasks/20231221155954.json", "game_2048/game.py", "docs/system_design/20231221155954.json", "game_2048/main.py"], "docs/code_summaries/20231221155954.json": ["docs/tasks/20231221155954.json", "game_2048/game.py", "docs/system_design/20231221155954.json", "game_2048/main.py"], "tests/test_main.py": ["game_2048/main.py"], "tests/test_game.py": ["game_2048/game.py"], "test_outputs/test_main.py.json": ["game_2048/main.py", "tests/test_main.py"], "test_outputs/test_game.py.json": ["game_2048/game.py", "tests/test_game.py"]}
{"docs/system_design/20231221155954.json": ["docs/prd/20231221155954.json"], "docs/task/20231221155954.json": ["docs/system_design/20231221155954.json"], "game_2048/game.py": ["docs/task/20231221155954.json", "docs/system_design/20231221155954.json"], "game_2048/main.py": ["docs/task/20231221155954.json", "docs/system_design/20231221155954.json"], "resources/code_summary/20231221155954.md": ["docs/task/20231221155954.json", "game_2048/game.py", "docs/system_design/20231221155954.json", "game_2048/main.py"], "docs/code_summary/20231221155954.json": ["docs/task/20231221155954.json", "game_2048/game.py", "docs/system_design/20231221155954.json", "game_2048/main.py"], "tests/test_main.py": ["game_2048/main.py"], "tests/test_game.py": ["game_2048/game.py"], "test_outputs/test_main.py.json": ["game_2048/main.py", "tests/test_main.py"], "test_outputs/test_game.py.json": ["game_2048/game.py", "tests/test_game.py"]}

File diff suppressed because one or more lines are too long

View file

@ -11,9 +11,9 @@ import uuid
import pytest
from metagpt.actions.debug_error import DebugError
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
from metagpt.context import CONTEXT
from metagpt.schema import RunCodeContext, RunCodeResult
from metagpt.utils.project_repo import ProjectRepo
CODE_CONTENT = '''
from typing import List
@ -118,6 +118,7 @@ if __name__ == '__main__':
@pytest.mark.asyncio
async def test_debug_error():
CONTEXT.src_workspace = CONTEXT.git_repo.workdir / uuid.uuid4().hex
project_repo = ProjectRepo(CONTEXT.git_repo)
ctx = RunCodeContext(
code_filename="player.py",
test_filename="test_player.py",
@ -125,9 +126,8 @@ async def test_debug_error():
output_filename="output.log",
)
repo = CONTEXT.file_repo
await repo.save_file(filename=ctx.code_filename, content=CODE_CONTENT, relative_path=CONTEXT.src_workspace)
await repo.save_file(filename=ctx.test_filename, content=TEST_CONTENT, relative_path=TEST_CODES_FILE_REPO)
await project_repo.with_src_path(CONTEXT.src_workspace).srcs.save(filename=ctx.code_filename, content=CODE_CONTENT)
await project_repo.tests.save(filename=ctx.test_filename, content=TEST_CONTENT)
output_data = RunCodeResult(
stdout=";",
stderr="",
@ -141,9 +141,7 @@ async def test_debug_error():
"----------------------------------------------------------------------\n"
"Ran 5 tests in 0.007s\n\nFAILED (failures=1)\n;\n",
)
await repo.save_file(
filename=ctx.output_filename, content=output_data.model_dump_json(), relative_path=TEST_OUTPUTS_FILE_REPO
)
await project_repo.test_outputs.save(filename=ctx.output_filename, content=output_data.model_dump_json())
debug_error = DebugError(i_context=ctx)
rsp = await debug_error.run()

View file

@ -9,18 +9,18 @@
import pytest
from metagpt.actions.design_api import WriteDesign
from metagpt.const import PRDS_FILE_REPO
from metagpt.context import CONTEXT
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.project_repo import ProjectRepo
@pytest.mark.asyncio
async def test_design_api():
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"] # PRD_SAMPLE
repo = CONTEXT.file_repo
project_repo = ProjectRepo(CONTEXT.git_repo)
for prd in inputs:
await repo.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO)
await project_repo.docs.prd.save(filename="new_prd.txt", content=prd)
design_api = WriteDesign()

View file

@ -9,9 +9,10 @@
import pytest
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
from metagpt.const import REQUIREMENT_FILENAME
from metagpt.context import CONTEXT
from metagpt.schema import Message
from metagpt.utils.project_repo import ProjectRepo
@pytest.mark.asyncio
@ -24,6 +25,6 @@ async def test_prepare_documents():
await PrepareDocuments(context=CONTEXT).run(with_messages=[msg])
assert CONTEXT.git_repo
doc = await CONTEXT.file_repo.get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO)
doc = await ProjectRepo(CONTEXT.git_repo).docs.get(filename=REQUIREMENT_FILENAME)
assert doc
assert doc.content == msg.content

View file

@ -9,17 +9,18 @@
import pytest
from metagpt.actions.project_management import WriteTasks
from metagpt.const import PRDS_FILE_REPO, SYSTEM_DESIGN_FILE_REPO
from metagpt.context import CONTEXT
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.project_repo import ProjectRepo
from tests.metagpt.actions.mock_json import DESIGN, PRD
@pytest.mark.asyncio
async def test_design_api():
await CONTEXT.file_repo.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO)
await CONTEXT.file_repo.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO)
project_repo = ProjectRepo(CONTEXT.git_repo)
await project_repo.docs.prd.save("1.txt", content=str(PRD))
await project_repo.docs.system_design.save("1.txt", content=str(DESIGN))
logger.info(CONTEXT.git_repo)
action = WriteTasks()

View file

@ -15,6 +15,7 @@ from metagpt.context import CONTEXT
from metagpt.llm import LLM
from metagpt.utils.common import aread
from metagpt.utils.git_repository import ChangeType
from metagpt.utils.project_repo import ProjectRepo
@pytest.mark.asyncio
@ -22,12 +23,8 @@ async def test_rebuild():
# Mock
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.json")
graph_db_filename = Path(CONTEXT.git_repo.workdir.name).with_suffix(".json")
repo = CONTEXT.file_repo
await repo.save_file(
filename=str(graph_db_filename),
relative_path=GRAPH_REPO_FILE_REPO,
content=data,
)
project_repo = ProjectRepo(CONTEXT.git_repo)
await project_repo.docs.graph_repo.save(filename=str(graph_db_filename), content=data)
CONTEXT.git_repo.add_change({f"{GRAPH_REPO_FILE_REPO}/{graph_db_filename}": ChangeType.UNTRACTED})
CONTEXT.git_repo.commit("commit1")
@ -35,8 +32,7 @@ async def test_rebuild():
name="RedBean", i_context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
)
await action.run()
graph_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
assert graph_file_repo.changed_files
assert project_repo.docs.graph_repo.changed_files
@pytest.mark.parametrize(

View file

@ -9,10 +9,10 @@
import pytest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.context import CONTEXT
from metagpt.logs import logger
from metagpt.schema import CodeSummarizeContext
from metagpt.utils.project_repo import ProjectRepo
DESIGN_CONTENT = """
{"Implementation approach": "To develop this snake game, we will use the Python language and choose the Pygame library. Pygame is an open-source Python module collection specifically designed for writing video games. It provides functionalities such as displaying images and playing sounds, making it suitable for creating intuitive and responsive user interfaces. We will ensure efficient game logic to prevent any delays during gameplay. The scoring system will be simple, with the snake gaining points for each food it eats. We will use Pygame's event handling system to implement pause and resume functionality, as well as high-score tracking. The difficulty will increase by speeding up the snake's movement. In the initial version, we will focus on single-player mode and consider adding multiplayer mode and customizable skins in future updates. Based on the new requirement, we will also add a moving obstacle that appears randomly. If the snake eats this obstacle, the game will end. If the snake does not eat the obstacle, it will disappear after 5 seconds. For this, we need to add mechanisms for obstacle generation, movement, and disappearance in the game logic.", "Project_name": "snake_game", "File list": ["main.py", "game.py", "snake.py", "food.py", "obstacle.py", "scoreboard.py", "constants.py", "assets/styles.css", "assets/index.html"], "Data structures and interfaces": "```mermaid\n classDiagram\n class Game{\n +int score\n +int speed\n +bool game_over\n +bool paused\n +Snake snake\n +Food food\n +Obstacle obstacle\n +Scoreboard scoreboard\n +start_game() void\n +pause_game() void\n +resume_game() void\n +end_game() void\n +increase_difficulty() void\n +update() void\n +render() void\n Game()\n }\n class Snake{\n +list body_parts\n +str direction\n +bool grow\n +move() void\n +grow() void\n +check_collision() bool\n Snake()\n }\n class Food{\n +tuple position\n +spawn() void\n Food()\n }\n class Obstacle{\n +tuple position\n +int lifetime\n +bool active\n +spawn() void\n +move() void\n +check_collision() bool\n +disappear() void\n Obstacle()\n }\n class Scoreboard{\n +int high_score\n +update_score(int) void\n +reset_score() void\n +load_high_score() void\n +save_high_score() void\n Scoreboard()\n }\n class Constants{\n }\n Game \"1\" -- \"1\" Snake: has\n Game \"1\" -- \"1\" Food: has\n Game \"1\" -- \"1\" Obstacle: has\n Game \"1\" -- \"1\" Scoreboard: has\n ```", "Program call flow": "```sequenceDiagram\n participant M as Main\n participant G as Game\n participant S as Snake\n participant F as Food\n participant O as Obstacle\n participant SB as Scoreboard\n M->>G: start_game()\n loop game loop\n G->>S: move()\n G->>S: check_collision()\n G->>F: spawn()\n G->>O: spawn()\n G->>O: move()\n G->>O: check_collision()\n G->>O: disappear()\n G->>SB: update_score(score)\n G->>G: update()\n G->>G: render()\n alt if paused\n M->>G: pause_game()\n M->>G: resume_game()\n end\n alt if game_over\n G->>M: end_game()\n end\n end\n```", "Anything UNCLEAR": "There is no need for further clarification as the requirements are already clear."}
@ -178,17 +178,22 @@ class Snake:
@pytest.mark.asyncio
async def test_summarize_code():
CONTEXT.src_workspace = CONTEXT.git_repo.workdir / "src"
await CONTEXT.file_repo.save_file(filename="1.json", relative_path=SYSTEM_DESIGN_FILE_REPO, content=DESIGN_CONTENT)
await CONTEXT.file_repo.save_file(filename="1.json", relative_path=TASK_FILE_REPO, content=TASK_CONTENT)
await CONTEXT.file_repo.save_file(filename="food.py", relative_path=CONTEXT.src_workspace, content=FOOD_PY)
await CONTEXT.file_repo.save_file(filename="game.py", relative_path=CONTEXT.src_workspace, content=GAME_PY)
await CONTEXT.file_repo.save_file(filename="main.py", relative_path=CONTEXT.src_workspace, content=MAIN_PY)
await CONTEXT.file_repo.save_file(filename="snake.py", relative_path=CONTEXT.src_workspace, content=SNAKE_PY)
project_repo = ProjectRepo(CONTEXT.git_repo)
await project_repo.docs.system_design.save(filename="1.json", content=DESIGN_CONTENT)
await project_repo.docs.task.save(filename="1.json", content=TASK_CONTENT)
await project_repo.with_src_path(CONTEXT.src_workspace).srcs.save(filename="food.py", content=FOOD_PY)
assert project_repo.srcs.workdir == CONTEXT.src_workspace
await project_repo.srcs.save(filename="game.py", content=GAME_PY)
await project_repo.srcs.save(filename="main.py", content=MAIN_PY)
await project_repo.srcs.save(filename="snake.py", content=SNAKE_PY)
src_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=CONTEXT.src_workspace)
all_files = src_file_repo.all_files
all_files = project_repo.srcs.all_files
ctx = CodeSummarizeContext(design_filename="1.json", task_filename="1.json", codes_filenames=all_files)
action = SummarizeCode(i_context=ctx)
rsp = await action.run()
assert rsp
logger.info(rsp)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -12,26 +12,24 @@ from pathlib import Path
import pytest
from metagpt.actions.write_code import WriteCode
from metagpt.const import (
CODE_SUMMARIES_FILE_REPO,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
TEST_OUTPUTS_FILE_REPO,
)
from metagpt.context import CONTEXT
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.schema import CodingContext, Document
from metagpt.utils.common import aread
from metagpt.utils.project_repo import ProjectRepo
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
@pytest.mark.asyncio
async def test_write_code():
ccontext = CodingContext(
# Prerequisites
CONTEXT.src_workspace = CONTEXT.git_repo.workdir / "writecode"
coding_ctx = CodingContext(
filename="task_filename.py", design_doc=Document(content="设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。")
)
doc = Document(content=ccontext.model_dump_json())
doc = Document(content=coding_ctx.model_dump_json())
write_code = WriteCode(i_context=doc)
code = await write_code.run()
@ -55,33 +53,28 @@ async def test_write_code_deps():
# Prerequisites
CONTEXT.src_workspace = CONTEXT.git_repo.workdir / "snake1/snake1"
demo_path = Path(__file__).parent / "../../data/demo_project"
await CONTEXT.file_repo.save_file(
filename="test_game.py.json",
content=await aread(str(demo_path / "test_game.py.json")),
relative_path=TEST_OUTPUTS_FILE_REPO,
project_repo = ProjectRepo(CONTEXT.git_repo)
await project_repo.test_outputs.save(
filename="test_game.py.json", content=await aread(str(demo_path / "test_game.py.json"))
)
await CONTEXT.file_repo.save_file(
await project_repo.docs.code_summary.save(
filename="20231221155954.json",
content=await aread(str(demo_path / "code_summaries.json")),
relative_path=CODE_SUMMARIES_FILE_REPO,
)
await CONTEXT.file_repo.save_file(
await project_repo.docs.system_design.save(
filename="20231221155954.json",
content=await aread(str(demo_path / "system_design.json")),
relative_path=SYSTEM_DESIGN_FILE_REPO,
)
await CONTEXT.file_repo.save_file(
filename="20231221155954.json", content=await aread(str(demo_path / "tasks.json")), relative_path=TASK_FILE_REPO
await project_repo.docs.task.save(
filename="20231221155954.json", content=await aread(str(demo_path / "tasks.json"))
)
await CONTEXT.file_repo.save_file(
filename="main.py", content='if __name__ == "__main__":\nmain()', relative_path=CONTEXT.src_workspace
await project_repo.with_src_path(CONTEXT.src_workspace).srcs.save(
filename="main.py", content='if __name__ == "__main__":\nmain()'
)
ccontext = CodingContext(
filename="game.py",
design_doc=await CONTEXT.file_repo.get_file(
filename="20231221155954.json", relative_path=SYSTEM_DESIGN_FILE_REPO
),
task_doc=await CONTEXT.file_repo.get_file(filename="20231221155954.json", relative_path=TASK_FILE_REPO),
design_doc=await project_repo.docs.system_design.get(filename="20231221155954.json"),
task_doc=await project_repo.docs.task.get(filename="20231221155954.json"),
code_doc=Document(filename="game.py", content="", root_path="snake1"),
)
coding_doc = Document(root_path="snake1", filename="game.py", content=ccontext.json())

View file

@ -9,21 +9,22 @@
import pytest
from metagpt.actions import UserRequirement, WritePRD
from metagpt.const import DOCS_FILE_REPO, PRDS_FILE_REPO, REQUIREMENT_FILENAME
from metagpt.const import REQUIREMENT_FILENAME
from metagpt.context import CONTEXT
from metagpt.logs import logger
from metagpt.roles.product_manager import ProductManager
from metagpt.roles.role import RoleReactMode
from metagpt.schema import Message
from metagpt.utils.common import any_to_str
from metagpt.utils.project_repo import ProjectRepo
@pytest.mark.asyncio
async def test_write_prd(new_filename):
product_manager = ProductManager()
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"
repo = CONTEXT.file_repo
await repo.save_file(filename=REQUIREMENT_FILENAME, content=requirements, relative_path=DOCS_FILE_REPO)
project_repo = ProjectRepo(CONTEXT.git_repo)
await project_repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
product_manager.rc.react_mode = RoleReactMode.BY_ORDER
prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement))
assert prd.cause_by == any_to_str(WritePRD)
@ -33,7 +34,7 @@ async def test_write_prd(new_filename):
# Assert the prd is not None or empty
assert prd is not None
assert prd.content != ""
assert CONTEXT.git_repo.new_file_repository(relative_path=PRDS_FILE_REPO).changed_files
assert ProjectRepo(product_manager.context.git_repo).docs.prd.changed_files
if __name__ == "__main__":

View file

@ -6,19 +6,32 @@
@File : test_text_to_embedding.py
@Desc : Unit tests.
"""
import json
from pathlib import Path
import pytest
from metagpt.config2 import config
from metagpt.learn.text_to_embedding import text_to_embedding
from metagpt.utils.common import aread
@pytest.mark.asyncio
async def test_text_to_embedding():
# Prerequisites
assert config.get_openai_llm()
async def test_text_to_embedding(mocker):
# mock
mock_post = mocker.patch("aiohttp.ClientSession.post")
mock_response = mocker.AsyncMock()
mock_response.status = 200
data = await aread(Path(__file__).parent / "../../data/openai/embedding.json")
mock_response.json.return_value = json.loads(data)
mock_post.return_value.__aenter__.return_value = mock_response
type(config.get_openai_llm()).proxy = mocker.PropertyMock(return_value="http://mock.proxy")
v = await text_to_embedding(text="Panda emoji")
# Prerequisites
assert config.get_openai_llm().api_key
assert config.get_openai_llm().proxy
v = await text_to_embedding(text="Panda emoji", config=config)
assert len(v.data) > 0

View file

@ -6,9 +6,11 @@
@File : test_text_to_image.py
@Desc : Unit tests.
"""
import base64
import openai
import pytest
from pydantic import BaseModel
from metagpt.config2 import Config
from metagpt.learn.text_to_image import text_to_image
@ -27,15 +29,30 @@ async def test_text_to_image(mocker):
config = Config.default()
assert config.METAGPT_TEXT_TO_IMAGE_MODEL_URL
data = await text_to_image(
"Panda emoji", size_type="512x512", model_url=config.METAGPT_TEXT_TO_IMAGE_MODEL_URL, config=config
)
data = await text_to_image("Panda emoji", size_type="512x512", config=config)
assert "base64" in data or "http" in data
@pytest.mark.asyncio
async def test_openai_text_to_image():
async def test_openai_text_to_image(mocker):
# mocker
mock_url = mocker.Mock()
mock_url.url.return_value = "http://mock.com/0.png"
class _MockData(BaseModel):
data: list
mock_data = _MockData(data=[mock_url])
mocker.patch.object(openai.resources.images.AsyncImages, "generate", return_value=mock_data)
mock_post = mocker.patch("aiohttp.ClientSession.get")
mock_response = mocker.AsyncMock()
mock_response.status = 200
mock_response.read.return_value = base64.b64encode(b"success")
mock_post.return_value.__aenter__.return_value = mock_response
mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/0.png")
config = Config.default()
config.METAGPT_TEXT_TO_IMAGE_MODEL_URL = None
assert config.get_openai_llm()
data = await text_to_image("Panda emoji", size_type="512x512", config=config)

View file

@ -8,43 +8,64 @@
"""
import pytest
from azure.cognitiveservices.speech import ResultReason, SpeechSynthesizer
from metagpt.config2 import config
from metagpt.config2 import Config
from metagpt.learn.text_to_speech import text_to_speech
from metagpt.tools.iflytek_tts import IFlyTekTTS
from metagpt.utils.s3 import S3
@pytest.mark.asyncio
async def test_text_to_speech():
async def test_azure_text_to_speech(mocker):
# mock
config = Config.default()
config.IFLYTEK_API_KEY = None
config.IFLYTEK_API_SECRET = None
config.IFLYTEK_APP_ID = None
mock_result = mocker.Mock()
mock_result.audio_data = b"mock audio data"
mock_result.reason = ResultReason.SynthesizingAudioCompleted
mock_data = mocker.Mock()
mock_data.get.return_value = mock_result
mocker.patch.object(SpeechSynthesizer, "speak_ssml_async", return_value=mock_data)
mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/1.wav")
# Prerequisites
assert not config.IFLYTEK_APP_ID
assert not config.IFLYTEK_API_KEY
assert not config.IFLYTEK_API_SECRET
assert config.AZURE_TTS_SUBSCRIPTION_KEY and config.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY"
assert config.AZURE_TTS_REGION
config.copy()
# test azure
data = await text_to_speech("panda emoji", config=config)
assert "base64" in data or "http" in data
@pytest.mark.asyncio
async def test_iflytek_text_to_speech(mocker):
# mock
config = Config.default()
config.AZURE_TTS_SUBSCRIPTION_KEY = None
config.AZURE_TTS_REGION = None
mocker.patch.object(IFlyTekTTS, "synthesize_speech", return_value=None)
mock_data = mocker.AsyncMock()
mock_data.read.return_value = b"mock iflytek"
mock_reader = mocker.patch("aiofiles.open")
mock_reader.return_value.__aenter__.return_value = mock_data
mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/1.mp3")
# Prerequisites
assert config.IFLYTEK_APP_ID
assert config.IFLYTEK_API_KEY
assert config.IFLYTEK_API_SECRET
assert config.AZURE_TTS_SUBSCRIPTION_KEY and config.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY"
assert config.AZURE_TTS_REGION
assert not config.AZURE_TTS_SUBSCRIPTION_KEY or config.AZURE_TTS_SUBSCRIPTION_KEY == "YOUR_API_KEY"
assert not config.AZURE_TTS_REGION
i = config.copy()
# test azure
data = await text_to_speech(
"panda emoji",
subscription_key=i.AZURE_TTS_SUBSCRIPTION_KEY,
region=i.AZURE_TTS_REGION,
iflytek_api_key=i.IFLYTEK_API_KEY,
iflytek_api_secret=i.IFLYTEK_API_SECRET,
iflytek_app_id=i.IFLYTEK_APP_ID,
)
assert "base64" in data or "http" in data
# test iflytek
## Mock session env
i.AZURE_TTS_SUBSCRIPTION_KEY = ""
data = await text_to_speech(
"panda emoji",
subscription_key=i.AZURE_TTS_SUBSCRIPTION_KEY,
region=i.AZURE_TTS_REGION,
iflytek_api_key=i.IFLYTEK_API_KEY,
iflytek_api_secret=i.IFLYTEK_API_SECRET,
iflytek_app_id=i.IFLYTEK_APP_ID,
)
data = await text_to_speech("panda emoji", config=config)
assert "base64" in data or "http" in data

View file

@ -20,7 +20,10 @@ from metagpt.utils.common import any_to_str
@pytest.mark.asyncio
async def test_run():
async def test_run(mocker):
# mock
mocker.patch("metagpt.learn.text_to_image", return_value="http://mock.com/1.png")
CONTEXT.kwargs.language = "Chinese"
class Input(BaseModel):
@ -65,7 +68,7 @@ async def test_run():
"cause_by": any_to_str(SkillAction),
},
]
CONTEXT.kwargs.agent_skills = [
agent_skills = [
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
@ -77,9 +80,11 @@ async def test_run():
for i in inputs:
seed = Input(**i)
CONTEXT.kwargs.language = seed.language
CONTEXT.kwargs.agent_description = seed.agent_description
role = Assistant(language="Chinese")
role.context.kwargs.language = seed.language
role.context.kwargs.agent_description = seed.agent_description
role.context.kwargs.agent_skills = agent_skills
role.memory = seed.memory # Restore historical conversation content.
while True:
has_action = await role.think()
@ -112,6 +117,7 @@ async def test_run():
@pytest.mark.asyncio
async def test_memory(memory):
role = Assistant()
role.context.kwargs.agent_skills = []
role.load_memory(memory)
val = role.get_memory()

View file

@ -8,23 +8,25 @@
distribution feature for message handling.
"""
import json
import uuid
from pathlib import Path
import pytest
from metagpt.actions import WriteCode, WriteTasks
from metagpt.const import (
PRDS_FILE_REPO,
DEFAULT_WORKSPACE_ROOT,
REQUIREMENT_FILENAME,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
)
from metagpt.context import CONTEXT
from metagpt.context import CONTEXT, Context
from metagpt.logs import logger
from metagpt.roles.engineer import Engineer
from metagpt.schema import CodingContext, Message
from metagpt.utils.common import CodeParser, any_to_name, any_to_str, aread, awrite
from metagpt.utils.git_repository import ChangeType
from metagpt.utils.git_repository import ChangeType, GitRepository
from metagpt.utils.project_repo import ProjectRepo
from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages
@ -32,20 +34,18 @@ from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages
async def test_engineer():
# Prerequisites
rqno = "20231221155954.json"
await CONTEXT.file_repo.save_file(REQUIREMENT_FILENAME, content=MockMessages.req.content)
await CONTEXT.file_repo.save_file(rqno, relative_path=PRDS_FILE_REPO, content=MockMessages.prd.content)
await CONTEXT.file_repo.save_file(
rqno, relative_path=SYSTEM_DESIGN_FILE_REPO, content=MockMessages.system_design.content
)
await CONTEXT.file_repo.save_file(rqno, relative_path=TASK_FILE_REPO, content=MockMessages.json_tasks.content)
project_repo = ProjectRepo(CONTEXT.git_repo)
await project_repo.save(REQUIREMENT_FILENAME, content=MockMessages.req.content)
await project_repo.docs.prd.save(rqno, content=MockMessages.prd.content)
await project_repo.docs.system_design.save(rqno, content=MockMessages.system_design.content)
await project_repo.docs.task.save(rqno, content=MockMessages.json_tasks.content)
engineer = Engineer()
rsp = await engineer.run(Message(content="", cause_by=WriteTasks))
logger.info(rsp)
assert rsp.cause_by == any_to_str(WriteCode)
src_file_repo = CONTEXT.git_repo.new_file_repository(CONTEXT.src_workspace)
assert src_file_repo.changed_files
assert project_repo.with_src_path(CONTEXT.src_workspace).srcs.changed_files
def test_parse_str():
@ -114,48 +114,50 @@ def test_todo():
@pytest.mark.asyncio
async def test_new_coding_context():
# Prerequisites
context = Context()
context.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}")
demo_path = Path(__file__).parent / "../../data/demo_project"
deps = json.loads(await aread(demo_path / "dependencies.json"))
dependency = await CONTEXT.git_repo.get_dependency()
dependency = await context.git_repo.get_dependency()
for k, v in deps.items():
await dependency.update(k, set(v))
data = await aread(demo_path / "system_design.json")
rqno = "20231221155954.json"
await awrite(CONTEXT.git_repo.workdir / SYSTEM_DESIGN_FILE_REPO / rqno, data)
await awrite(context.git_repo.workdir / SYSTEM_DESIGN_FILE_REPO / rqno, data)
data = await aread(demo_path / "tasks.json")
await awrite(CONTEXT.git_repo.workdir / TASK_FILE_REPO / rqno, data)
await awrite(context.git_repo.workdir / TASK_FILE_REPO / rqno, data)
CONTEXT.src_workspace = Path(CONTEXT.git_repo.workdir) / "game_2048"
src_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=CONTEXT.src_workspace)
task_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=TASK_FILE_REPO)
design_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=SYSTEM_DESIGN_FILE_REPO)
context.src_workspace = Path(context.git_repo.workdir) / "game_2048"
filename = "game.py"
ctx_doc = await Engineer._new_coding_doc(
filename=filename,
src_file_repo=src_file_repo,
task_file_repo=task_file_repo,
design_file_repo=design_file_repo,
dependency=dependency,
)
assert ctx_doc
assert ctx_doc.filename == filename
assert ctx_doc.content
ctx = CodingContext.model_validate_json(ctx_doc.content)
assert ctx.filename == filename
assert ctx.design_doc
assert ctx.design_doc.content
assert ctx.task_doc
assert ctx.task_doc.content
assert ctx.code_doc
try:
filename = "game.py"
engineer = Engineer(context=context)
ctx_doc = await engineer._new_coding_doc(
filename=filename,
dependency=dependency,
)
assert ctx_doc
assert ctx_doc.filename == filename
assert ctx_doc.content
ctx = CodingContext.model_validate_json(ctx_doc.content)
assert ctx.filename == filename
assert ctx.design_doc
assert ctx.design_doc.content
assert ctx.task_doc
assert ctx.task_doc.content
assert ctx.code_doc
CONTEXT.git_repo.add_change({f"{TASK_FILE_REPO}/{rqno}": ChangeType.UNTRACTED})
CONTEXT.git_repo.commit("mock env")
await src_file_repo.save(filename=filename, content="content")
role = Engineer()
assert not role.code_todos
await role._new_code_actions()
assert role.code_todos
context.git_repo.add_change({f"{TASK_FILE_REPO}/{rqno}": ChangeType.UNTRACTED})
context.git_repo.commit("mock env")
await ProjectRepo(context.git_repo).with_src_path(context.src_workspace).srcs.save(
filename=filename, content="content"
)
role = Engineer(context=context)
assert not role.code_todos
await role._new_code_actions()
assert role.code_todos
finally:
context.git_repo.delete_repository()
if __name__ == "__main__":

View file

@ -8,15 +8,14 @@
from typing import Dict, Optional
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Field
from metagpt.context import CONTEXT
from metagpt.context import Context
from metagpt.roles.teacher import Teacher
from metagpt.schema import Message
@pytest.mark.asyncio
@pytest.mark.skip
async def test_init():
class Inputs(BaseModel):
name: str
@ -30,6 +29,7 @@ async def test_init():
expect_goal: str
expect_constraints: str
expect_desc: str
exclude: list = Field(default_factory=list)
inputs = [
{
@ -44,6 +44,7 @@ async def test_init():
"kwargs": {},
"desc": "aaa{language}",
"expect_desc": "aaa{language}",
"exclude": ["language", "key1", "something_big", "teaching_language"],
},
{
"name": "Lily{language}",
@ -57,13 +58,21 @@ async def test_init():
"kwargs": {"language": "CN", "key1": "HaHa", "something_big": "sleep", "teaching_language": "EN"},
"desc": "aaa{language}",
"expect_desc": "aaaCN",
"language": "CN",
"teaching_language": "EN",
},
]
for i in inputs:
seed = Inputs(**i)
context = Context()
for k in seed.exclude:
context.kwargs.set(k, None)
for k, v in seed.kwargs.items():
context.kwargs.set(k, v)
teacher = Teacher(
context=context,
name=seed.name,
profile=seed.profile,
goal=seed.goal,
@ -97,8 +106,6 @@ async def test_new_file_name():
@pytest.mark.asyncio
async def test_run():
CONTEXT.kwargs.language = "Chinese"
CONTEXT.kwargs.teaching_language = "English"
lesson = """
UNIT 1 Making New Friends
TOPIC 1 Welcome to China!
@ -142,7 +149,10 @@ async def test_run():
3c Match the big letters with the small ones. Then write them on the lines.
"""
teacher = Teacher()
context = Context()
context.kwargs.language = "Chinese"
context.kwargs.teaching_language = "English"
teacher = Teacher(context=context)
rsp = await teacher.run(Message(content=lesson))
assert rsp

View file

@ -7,21 +7,31 @@
@Modified By: mashenquan, 2023-8-9, add more text formatting options
@Modified By: mashenquan, 2023-8-17, move to `tools` folder.
"""
from pathlib import Path
import pytest
from azure.cognitiveservices.speech import ResultReason
from azure.cognitiveservices.speech import ResultReason, SpeechSynthesizer
from metagpt.config2 import config
from metagpt.tools.azure_tts import AzureTTS
@pytest.mark.asyncio
async def test_azure_tts():
async def test_azure_tts(mocker):
# mock
mock_result = mocker.Mock()
mock_result.audio_data = b"mock audio data"
mock_result.reason = ResultReason.SynthesizingAudioCompleted
mock_data = mocker.Mock()
mock_data.get.return_value = mock_result
mocker.patch.object(SpeechSynthesizer, "speak_ssml_async", return_value=mock_data)
mocker.patch.object(Path, "exists", return_value=True)
# Prerequisites
assert config.AZURE_TTS_SUBSCRIPTION_KEY and config.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY"
assert config.AZURE_TTS_REGION
azure_tts = AzureTTS(subscription_key="", region="")
azure_tts = AzureTTS(subscription_key=config.AZURE_TTS_SUBSCRIPTION_KEY, region=config.AZURE_TTS_REGION)
text = """
女儿看见父亲走了进来问道
<mstts:express-as role="YoungAdultFemale" style="calm">

View file

@ -7,12 +7,22 @@
"""
import pytest
from metagpt.config2 import config
from metagpt.tools.iflytek_tts import oas3_iflytek_tts
from metagpt.config2 import Config
from metagpt.tools.iflytek_tts import IFlyTekTTS, oas3_iflytek_tts
@pytest.mark.asyncio
async def test_tts():
async def test_iflytek_tts(mocker):
# mock
config = Config.default()
config.AZURE_TTS_SUBSCRIPTION_KEY = None
config.AZURE_TTS_REGION = None
mocker.patch.object(IFlyTekTTS, "synthesize_speech", return_value=None)
mock_data = mocker.AsyncMock()
mock_data.read.return_value = b"mock iflytek"
mock_reader = mocker.patch("aiofiles.open")
mock_reader.return_value.__aenter__.return_value = mock_data
# Prerequisites
assert config.IFLYTEK_APP_ID
assert config.IFLYTEK_API_KEY

View file

@ -5,19 +5,35 @@
@Author : mashenquan
@File : test_openai_text_to_embedding.py
"""
import json
from pathlib import Path
import pytest
from metagpt.config2 import config
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding
from metagpt.utils.common import aread
@pytest.mark.asyncio
async def test_embedding():
# Prerequisites
assert config.get_openai_llm()
async def test_embedding(mocker):
# mock
mock_post = mocker.patch("aiohttp.ClientSession.post")
mock_response = mocker.AsyncMock()
mock_response.status = 200
data = await aread(Path(__file__).parent / "../../data/openai/embedding.json")
mock_response.json.return_value = json.loads(data)
mock_post.return_value.__aenter__.return_value = mock_response
type(config.get_openai_llm()).proxy = mocker.PropertyMock(return_value="http://mock.proxy")
result = await oas3_openai_text_to_embedding("Panda emoji")
# Prerequisites
llm_config = config.get_openai_llm()
assert llm_config
assert llm_config.proxy
result = await oas3_openai_text_to_embedding(
"Panda emoji", openai_api_key=llm_config.api_key, proxy=llm_config.proxy
)
assert result
assert result.model
assert len(result.data) > 0

View file

@ -5,22 +5,44 @@
@Author : mashenquan
@File : test_openai_text_to_image.py
"""
import base64
import openai
import pytest
from pydantic import BaseModel
from metagpt.config2 import config
from metagpt.llm import LLM
from metagpt.tools.openai_text_to_image import (
OpenAIText2Image,
oas3_openai_text_to_image,
)
from metagpt.utils.s3 import S3
@pytest.mark.asyncio
async def test_draw():
# Prerequisites
assert config.get_openai_llm()
async def test_draw(mocker):
# mock
mock_url = mocker.Mock()
mock_url.url.return_value = "http://mock.com/0.png"
binary_data = await oas3_openai_text_to_image("Panda emoji")
class _MockData(BaseModel):
data: list
mock_data = _MockData(data=[mock_url])
mocker.patch.object(openai.resources.images.AsyncImages, "generate", return_value=mock_data)
mock_post = mocker.patch("aiohttp.ClientSession.get")
mock_response = mocker.AsyncMock()
mock_response.status = 200
mock_response.read.return_value = base64.b64encode(b"success")
mock_post.return_value.__aenter__.return_value = mock_response
mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/0.png")
# Prerequisites
llm_config = config.get_openai_llm()
assert llm_config
binary_data = await oas3_openai_text_to_image("Panda emoji", llm=LLM(llm_config=llm_config))
assert binary_data

View file

@ -0,0 +1,64 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/8
@Author : mashenquan
"""
import uuid
from pathlib import Path
import pytest
from metagpt.const import (
BUGFIX_FILENAME,
PACKAGE_REQUIREMENTS_FILENAME,
PRDS_FILE_REPO,
REQUIREMENT_FILENAME,
)
from metagpt.utils.project_repo import ProjectRepo
async def test_project_repo():
root = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}"
root = root.resolve()
pr = ProjectRepo(root=str(root))
assert pr.git_repo.workdir == root
assert pr.workdir == pr.git_repo.workdir
await pr.save(filename=REQUIREMENT_FILENAME, content=REQUIREMENT_FILENAME)
doc = await pr.get(filename=REQUIREMENT_FILENAME)
assert doc.content == REQUIREMENT_FILENAME
await pr.save(filename=BUGFIX_FILENAME, content=BUGFIX_FILENAME)
doc = await pr.get(filename=BUGFIX_FILENAME)
assert doc.content == BUGFIX_FILENAME
await pr.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content=PACKAGE_REQUIREMENTS_FILENAME)
doc = await pr.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
assert doc.content == PACKAGE_REQUIREMENTS_FILENAME
await pr.docs.prd.save(filename="1.prd", content="1.prd", dependencies=[REQUIREMENT_FILENAME])
doc = await pr.docs.prd.get(filename="1.prd")
assert doc.content == "1.prd"
await pr.resources.prd.save(
filename="1.prd",
content="1.prd",
dependencies=[REQUIREMENT_FILENAME, f"{PRDS_FILE_REPO}/1.prd"],
)
doc = await pr.resources.prd.get(filename="1.prd")
assert doc.content == "1.prd"
dependencies = await pr.resources.prd.get_dependency(filename="1.prd")
assert len(dependencies) == 2
assert pr.changed_files
assert pr.docs.prd.changed_files
assert not pr.tests.changed_files
with pytest.raises(ValueError):
pr.srcs
assert pr.with_src_path("test_src").srcs.root_path == Path("test_src")
assert pr.src_relative_path == Path("test_src")
pr.git_repo.delete_repository()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -8,7 +8,6 @@
from unittest.mock import AsyncMock
import pytest
from pytest_mock import mocker
from metagpt.config2 import Config
from metagpt.utils.redis import Redis
@ -22,7 +21,7 @@ async def async_mock_from_url(*args, **kwargs):
@pytest.mark.asyncio
async def test_redis(i):
async def test_redis(mocker):
redis = Config.default().redis
mocker.patch("aioredis.from_url", return_value=async_mock_from_url())