use context instead of FileRepo...

This commit is contained in:
geekan 2024-01-05 00:41:00 +08:00
parent 5c1f3a4b91
commit 3cd881de56
19 changed files with 78 additions and 93 deletions

View file

@ -12,6 +12,7 @@ from typing import Optional, Union
from pydantic import ConfigDict, Field, model_validator
import metagpt
from metagpt.actions.action_node import ActionNode
from metagpt.context import Context
from metagpt.llm import LLM
@ -35,7 +36,7 @@ class Action(SerializationMixin, is_polymorphic_base=True):
prefix: str = "" # aask*时会加上prefix作为system_message
desc: str = "" # for skill manager
node: ActionNode = Field(default=None, exclude=True)
g_context: Optional[Context] = Field(default=None, exclude=True)
g_context: Optional[Context] = Field(default=metagpt.context.context, exclude=True)
@property
def git_repo(self):

View file

@ -9,17 +9,14 @@
2. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name.
"""
import re
from typing import Optional
from pydantic import Field
from metagpt.actions.action import Action
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
from metagpt.context import Context
from metagpt.logs import logger
from metagpt.schema import RunCodeContext, RunCodeResult
from metagpt.utils.common import CodeParser
from metagpt.utils.file_repository import FileRepository
PROMPT_TEMPLATE = """
NOTICE
@ -51,10 +48,9 @@ Now you should start rewriting the code:
class DebugError(Action):
context: RunCodeContext = Field(default_factory=RunCodeContext)
g_context: Optional[Context] = None
async def run(self, *args, **kwargs) -> str:
output_doc = await FileRepository.get_file(
output_doc = await self.file_repo.get_file(
filename=self.context.output_filename, relative_path=TEST_OUTPUTS_FILE_REPO
)
if not output_doc:
@ -66,12 +62,12 @@ class DebugError(Action):
return ""
logger.info(f"Debug and rewrite {self.context.test_filename}")
code_doc = await FileRepository.get_file(
code_doc = await self.file_repo.get_file(
filename=self.context.code_filename, relative_path=self.g_context.src_workspace
)
if not code_doc:
return ""
test_doc = await FileRepository.get_file(
test_doc = await self.file_repo.get_file(
filename=self.context.test_filename, relative_path=TEST_CODES_FILE_REPO
)
if not test_doc:

View file

@ -24,7 +24,6 @@ from metagpt.const import (
)
from metagpt.logs import logger
from metagpt.schema import Document, Documents, Message
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.mermaid import mermaid_to_file
NEW_REQ_TEMPLATE = """
@ -75,13 +74,13 @@ class WriteDesign(Action):
# leaving room for global optimization in subsequent steps.
return ActionOutput(content=changed_files.model_dump_json(), instruct_content=changed_files)
async def _new_system_design(self, context, schema=None):
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
async def _new_system_design(self, context):
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm)
return node
async def _merge(self, prd_doc, system_design_doc, schema=None):
async def _merge(self, prd_doc, system_design_doc):
context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm)
system_design_doc.content = node.instruct_content.model_dump_json()
return system_design_doc
@ -123,9 +122,8 @@ class WriteDesign(Action):
await WriteDesign._save_mermaid_file(seq_flow, pathname)
logger.info(f"Saving sequence flow to {str(pathname)}")
@staticmethod
async def _save_pdf(design_doc):
await FileRepository.save_as(doc=design_doc, with_suffix=".md", relative_path=SYSTEM_DESIGN_PDF_FILE_REPO)
async def _save_pdf(self, design_doc):
await self.file_repo.save_as(doc=design_doc, with_suffix=".md", relative_path=SYSTEM_DESIGN_PDF_FILE_REPO)
@staticmethod
async def _save_mermaid_file(data: str, pathname: Path):

View file

@ -24,7 +24,6 @@ from metagpt.const import (
)
from metagpt.logs import logger
from metagpt.schema import Document, Documents
from metagpt.utils.file_repository import FileRepository
NEW_REQ_TEMPLATE = """
### Legacy Content
@ -39,11 +38,7 @@ class WriteTasks(Action):
name: str = "CreateTasks"
context: Optional[str] = None
@property
def prompt_schema(self):
return self.g_context.config.prompt_schema
async def run(self, with_messages, schema=None):
async def run(self, with_messages):
system_design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
changed_system_designs = system_design_file_repo.changed_files
@ -114,6 +109,5 @@ class WriteTasks(Action):
packages.add(pkg)
await file_repo.save(PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))
@staticmethod
async def _save_pdf(task_doc):
await FileRepository.save_as(doc=task_doc, with_suffix=".md", relative_path=TASK_PDF_FILE_REPO)
async def _save_pdf(self, task_doc):
await self.file_repo.save_as(doc=task_doc, with_suffix=".md", relative_path=TASK_PDF_FILE_REPO)

View file

@ -14,7 +14,6 @@ from metagpt.actions.action import Action
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import CodeSummarizeContext
from metagpt.utils.file_repository import FileRepository
PROMPT_TEMPLATE = """
NOTICE
@ -89,7 +88,6 @@ flowchart TB
"""
# TOTEST
class SummarizeCode(Action):
name: str = "SummarizeCode"
context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
@ -101,9 +99,10 @@ class SummarizeCode(Action):
async def run(self):
design_pathname = Path(self.context.design_filename)
design_doc = await FileRepository.get_file(filename=design_pathname.name, relative_path=SYSTEM_DESIGN_FILE_REPO)
repo = self.file_repo
design_doc = await repo.get_file(filename=design_pathname.name, relative_path=SYSTEM_DESIGN_FILE_REPO)
task_pathname = Path(self.context.task_filename)
task_doc = await FileRepository.get_file(filename=task_pathname.name, relative_path=TASK_FILE_REPO)
task_doc = await repo.get_file(filename=task_pathname.name, relative_path=TASK_FILE_REPO)
src_file_repo = self.git_repo.new_file_repository(relative_path=self.g_context.src_workspace)
code_blocks = []
for filename in self.context.codes_filenames:

View file

@ -31,7 +31,6 @@ from metagpt.const import (
from metagpt.logs import logger
from metagpt.schema import CodingContext, Document, RunCodeResult
from metagpt.utils.common import CodeParser
from metagpt.utils.file_repository import FileRepository
PROMPT_TEMPLATE = """
NOTICE
@ -138,12 +137,11 @@ class WriteCode(Action):
coding_context.code_doc.content = code
return coding_context
@staticmethod
async def get_codes(task_doc, exclude, git_repo, src_workspace) -> str:
async def get_codes(self, task_doc, exclude, git_repo, src_workspace) -> str:
if not task_doc:
return ""
if not task_doc.content:
task_doc.content = FileRepository.get_file(filename=task_doc.filename, relative_path=TASK_FILE_REPO)
task_doc.content = self.file_repo.get_file(filename=task_doc.filename, relative_path=TASK_FILE_REPO)
m = json.loads(task_doc.content)
code_filenames = m.get("Task list", [])
codes = []

View file

@ -166,9 +166,8 @@ class WritePRD(Action):
pathname.parent.mkdir(parents=True, exist_ok=True)
await mermaid_to_file(quadrant_chart, pathname)
@staticmethod
async def _save_pdf(prd_doc):
await FileRepository.save_as(doc=prd_doc, with_suffix=".md", relative_path=PRD_PDF_FILE_REPO)
async def _save_pdf(self, prd_doc):
await self.file_repo.save_as(doc=prd_doc, with_suffix=".md", relative_path=PRD_PDF_FILE_REPO)
async def _rename_workspace(self, prd):
if not self.project_name:

View file

@ -24,6 +24,10 @@ class Context:
src_workspace: Optional[Path] = None
cost_manager: CostManager = CostManager()
@property
def file_repo(self):
return self.git_repo.new_file_repository()
@property
def options(self):
"""Return all key-values"""

View file

@ -27,7 +27,6 @@ from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Document, Message, RunCodeContext, TestingContext
from metagpt.utils.common import any_to_str_set, parse_recipient
from metagpt.utils.file_repository import FileRepository
class QaEngineer(Role):
@ -138,7 +137,7 @@ class QaEngineer(Role):
async def _debug_error(self, msg):
run_code_context = RunCodeContext.loads(msg.content)
code = await DebugError(context=run_code_context, g_context=self.context, llm=self.llm).run()
await FileRepository.save_file(
await self.context.file_repo.save_file(
filename=run_code_context.test_filename, content=code, relative_path=TEST_CODES_FILE_REPO
)
run_code_context.output = None

View file

@ -32,7 +32,7 @@ from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.const import SERDESER_PATH
from metagpt.context import Context
from metagpt.context import Context, context
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.memory import Memory
@ -150,7 +150,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
# builtin variables
recovered: bool = False # to tag if a recovered role
latest_observed_msg: Optional[Message] = None # record the latest observed message when interrupted
context: Optional[Context] = Field(default=None, exclude=True)
context: Optional[Context] = Field(default=context, exclude=True)
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`

View file

@ -11,10 +11,9 @@ import uuid
import pytest
from metagpt.actions.debug_error import DebugError
from metagpt.config import CONFIG
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
from metagpt.context import context
from metagpt.schema import RunCodeContext, RunCodeResult
from metagpt.utils.file_repository import FileRepository
CODE_CONTENT = '''
from typing import List
@ -119,7 +118,7 @@ if __name__ == '__main__':
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_debug_error():
CONFIG.src_workspace = CONFIG.git_repo.workdir / uuid.uuid4().hex
context.src_workspace = context.git_repo.workdir / uuid.uuid4().hex
ctx = RunCodeContext(
code_filename="player.py",
test_filename="test_player.py",
@ -127,8 +126,9 @@ async def test_debug_error():
output_filename="output.log",
)
await FileRepository.save_file(filename=ctx.code_filename, content=CODE_CONTENT, relative_path=CONFIG.src_workspace)
await FileRepository.save_file(filename=ctx.test_filename, content=TEST_CONTENT, relative_path=TEST_CODES_FILE_REPO)
repo = context.file_repo
await repo.save_file(filename=ctx.code_filename, content=CODE_CONTENT, relative_path=context.src_workspace)
await repo.save_file(filename=ctx.test_filename, content=TEST_CONTENT, relative_path=TEST_CODES_FILE_REPO)
output_data = RunCodeResult(
stdout=";",
stderr="",
@ -142,7 +142,7 @@ async def test_debug_error():
"----------------------------------------------------------------------\n"
"Ran 5 tests in 0.007s\n\nFAILED (failures=1)\n;\n",
)
await FileRepository.save_file(
await repo.save_file(
filename=ctx.output_filename, content=output_data.model_dump_json(), relative_path=TEST_OUTPUTS_FILE_REPO
)
debug_error = DebugError(context=ctx)
@ -151,4 +151,4 @@ async def test_debug_error():
assert "class Player" in rsp # rewrite the same class
# a key logic to rewrite to (original one is "if self.score > 12")
assert "while self.score > 21" in rsp
assert "self.score" in rsp

View file

@ -10,18 +10,18 @@ import pytest
from metagpt.actions.design_api import WriteDesign
from metagpt.const import PRDS_FILE_REPO
from metagpt.context import context
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.file_repository import FileRepository
from tests.metagpt.actions.mock_markdown import PRD_SAMPLE
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_design_api():
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", PRD_SAMPLE]
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"] # PRD_SAMPLE
repo = context.file_repo
for prd in inputs:
await FileRepository.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO)
await repo.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO)
design_api = WriteDesign()

View file

@ -12,7 +12,6 @@ from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
from metagpt.context import context
from metagpt.schema import Message
from metagpt.utils.file_repository import FileRepository
@pytest.mark.asyncio
@ -25,6 +24,6 @@ async def test_prepare_documents():
await PrepareDocuments(g_context=context).run(with_messages=[msg])
assert context.git_repo
doc = await FileRepository(context.git_repo).get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO)
doc = await context.file_repo.get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO)
assert doc
assert doc.content == msg.content

View file

@ -9,20 +9,19 @@
import pytest
from metagpt.actions.project_management import WriteTasks
from metagpt.config import CONFIG
from metagpt.const import PRDS_FILE_REPO, SYSTEM_DESIGN_FILE_REPO
from metagpt.context import context
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.file_repository import FileRepository
from tests.metagpt.actions.mock_json import DESIGN, PRD
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_design_api():
await FileRepository.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO)
await FileRepository.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO)
logger.info(CONFIG.git_repo)
await context.file_repo.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO)
await context.file_repo.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO)
logger.info(context.git_repo)
action = WriteTasks()

View file

@ -11,9 +11,9 @@ import pytest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.config import CONFIG
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.context import context
from metagpt.logs import logger
from metagpt.schema import CodeSummarizeContext
from metagpt.utils.file_repository import FileRepository
DESIGN_CONTENT = """
{"Implementation approach": "To develop this snake game, we will use the Python language and choose the Pygame library. Pygame is an open-source Python module collection specifically designed for writing video games. It provides functionalities such as displaying images and playing sounds, making it suitable for creating intuitive and responsive user interfaces. We will ensure efficient game logic to prevent any delays during gameplay. The scoring system will be simple, with the snake gaining points for each food it eats. We will use Pygame's event handling system to implement pause and resume functionality, as well as high-score tracking. The difficulty will increase by speeding up the snake's movement. In the initial version, we will focus on single-player mode and consider adding multiplayer mode and customizable skins in future updates. Based on the new requirement, we will also add a moving obstacle that appears randomly. If the snake eats this obstacle, the game will end. If the snake does not eat the obstacle, it will disappear after 5 seconds. For this, we need to add mechanisms for obstacle generation, movement, and disappearance in the game logic.", "Project_name": "snake_game", "File list": ["main.py", "game.py", "snake.py", "food.py", "obstacle.py", "scoreboard.py", "constants.py", "assets/styles.css", "assets/index.html"], "Data structures and interfaces": "```mermaid\n classDiagram\n class Game{\n +int score\n +int speed\n +bool game_over\n +bool paused\n +Snake snake\n +Food food\n +Obstacle obstacle\n +Scoreboard scoreboard\n +start_game() void\n +pause_game() void\n +resume_game() void\n +end_game() void\n +increase_difficulty() void\n +update() void\n +render() void\n Game()\n }\n class Snake{\n +list body_parts\n +str direction\n +bool grow\n +move() void\n +grow() void\n +check_collision() bool\n Snake()\n }\n class Food{\n +tuple position\n +spawn() void\n Food()\n }\n class Obstacle{\n +tuple position\n +int lifetime\n +bool active\n +spawn() void\n +move() void\n +check_collision() bool\n +disappear() void\n Obstacle()\n }\n class Scoreboard{\n +int high_score\n +update_score(int) void\n +reset_score() void\n +load_high_score() void\n +save_high_score() void\n Scoreboard()\n }\n class Constants{\n }\n Game \"1\" -- \"1\" Snake: has\n Game \"1\" -- \"1\" Food: has\n Game \"1\" -- \"1\" Obstacle: has\n Game \"1\" -- \"1\" Scoreboard: has\n ```", "Program call flow": "```sequenceDiagram\n participant M as Main\n participant G as Game\n participant S as Snake\n participant F as Food\n participant O as Obstacle\n participant SB as Scoreboard\n M->>G: start_game()\n loop game loop\n G->>S: move()\n G->>S: check_collision()\n G->>F: spawn()\n G->>O: spawn()\n G->>O: move()\n G->>O: check_collision()\n G->>O: disappear()\n G->>SB: update_score(score)\n G->>G: update()\n G->>G: render()\n alt if paused\n M->>G: pause_game()\n M->>G: resume_game()\n end\n alt if game_over\n G->>M: end_game()\n end\n end\n```", "Anything UNCLEAR": "There is no need for further clarification as the requirements are already clear."}
@ -179,15 +179,15 @@ class Snake:
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_summarize_code():
CONFIG.src_workspace = CONFIG.git_repo.workdir / "src"
await FileRepository.save_file(filename="1.json", relative_path=SYSTEM_DESIGN_FILE_REPO, content=DESIGN_CONTENT)
await FileRepository.save_file(filename="1.json", relative_path=TASK_FILE_REPO, content=TASK_CONTENT)
await FileRepository.save_file(filename="food.py", relative_path=CONFIG.src_workspace, content=FOOD_PY)
await FileRepository.save_file(filename="game.py", relative_path=CONFIG.src_workspace, content=GAME_PY)
await FileRepository.save_file(filename="main.py", relative_path=CONFIG.src_workspace, content=MAIN_PY)
await FileRepository.save_file(filename="snake.py", relative_path=CONFIG.src_workspace, content=SNAKE_PY)
context.src_workspace = context.git_repo.workdir / "src"
await context.file_repo.save_file(filename="1.json", relative_path=SYSTEM_DESIGN_FILE_REPO, content=DESIGN_CONTENT)
await context.file_repo.save_file(filename="1.json", relative_path=TASK_FILE_REPO, content=TASK_CONTENT)
await context.file_repo.save_file(filename="food.py", relative_path=CONFIG.src_workspace, content=FOOD_PY)
await context.file_repo.save_file(filename="game.py", relative_path=CONFIG.src_workspace, content=GAME_PY)
await context.file_repo.save_file(filename="main.py", relative_path=CONFIG.src_workspace, content=MAIN_PY)
await context.file_repo.save_file(filename="snake.py", relative_path=CONFIG.src_workspace, content=SNAKE_PY)
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
src_file_repo = context.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
all_files = src_file_repo.all_files
ctx = CodeSummarizeContext(design_filename="1.json", task_filename="1.json", codes_filenames=all_files)
action = SummarizeCode(context=ctx)

View file

@ -12,28 +12,27 @@ from pathlib import Path
import pytest
from metagpt.actions.write_code import WriteCode
from metagpt.config import CONFIG
from metagpt.const import (
CODE_SUMMARIES_FILE_REPO,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
TEST_OUTPUTS_FILE_REPO,
)
from metagpt.context import context
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAILLM as LLM
from metagpt.schema import CodingContext, Document
from metagpt.utils.common import aread
from metagpt.utils.file_repository import FileRepository
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_write_code():
context = CodingContext(
ccontext = CodingContext(
filename="task_filename.py", design_doc=Document(content="设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。")
)
doc = Document(content=context.model_dump_json())
doc = Document(content=ccontext.model_dump_json())
write_code = WriteCode(context=doc)
code = await write_code.run()
@ -57,36 +56,38 @@ async def test_write_code_directly():
@pytest.mark.usefixtures("llm_mock")
async def test_write_code_deps():
# Prerequisites
CONFIG.src_workspace = CONFIG.git_repo.workdir / "snake1/snake1"
context.src_workspace = context.git_repo.workdir / "snake1/snake1"
demo_path = Path(__file__).parent / "../../data/demo_project"
await FileRepository.save_file(
await context.file_repo.save_file(
filename="test_game.py.json",
content=await aread(str(demo_path / "test_game.py.json")),
relative_path=TEST_OUTPUTS_FILE_REPO,
)
await FileRepository.save_file(
await context.file_repo.save_file(
filename="20231221155954.json",
content=await aread(str(demo_path / "code_summaries.json")),
relative_path=CODE_SUMMARIES_FILE_REPO,
)
await FileRepository.save_file(
await context.file_repo.save_file(
filename="20231221155954.json",
content=await aread(str(demo_path / "system_design.json")),
relative_path=SYSTEM_DESIGN_FILE_REPO,
)
await FileRepository.save_file(
await context.file_repo.save_file(
filename="20231221155954.json", content=await aread(str(demo_path / "tasks.json")), relative_path=TASK_FILE_REPO
)
await FileRepository.save_file(
filename="main.py", content='if __name__ == "__main__":\nmain()', relative_path=CONFIG.src_workspace
await context.file_repo.save_file(
filename="main.py", content='if __name__ == "__main__":\nmain()', relative_path=context.src_workspace
)
context = CodingContext(
ccontext = CodingContext(
filename="game.py",
design_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=SYSTEM_DESIGN_FILE_REPO),
task_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=TASK_FILE_REPO),
design_doc=await context.file_repo.get_file(
filename="20231221155954.json", relative_path=SYSTEM_DESIGN_FILE_REPO
),
task_doc=await context.file_repo.get_file(filename="20231221155954.json", relative_path=TASK_FILE_REPO),
code_doc=Document(filename="game.py", content="", root_path="snake1"),
)
coding_doc = Document(root_path="snake1", filename="game.py", content=context.json())
coding_doc = Document(root_path="snake1", filename="game.py", content=ccontext.json())
action = WriteCode(context=coding_doc)
rsp = await action.run()

View file

@ -9,12 +9,11 @@
import pytest
from metagpt.actions import UserRequirement
from metagpt.config import CONFIG
from metagpt.const import DOCS_FILE_REPO, PRDS_FILE_REPO, REQUIREMENT_FILENAME
from metagpt.context import context
from metagpt.logs import logger
from metagpt.roles.product_manager import ProductManager
from metagpt.schema import Message
from metagpt.utils.file_repository import FileRepository
@pytest.mark.asyncio
@ -22,7 +21,7 @@ from metagpt.utils.file_repository import FileRepository
async def test_write_prd():
product_manager = ProductManager()
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"
await FileRepository.save_file(filename=REQUIREMENT_FILENAME, content=requirements, relative_path=DOCS_FILE_REPO)
await context.file_repo.save_file(filename=REQUIREMENT_FILENAME, content=requirements, relative_path=DOCS_FILE_REPO)
prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement))
logger.info(requirements)
logger.info(prd)
@ -30,4 +29,4 @@ async def test_write_prd():
# Assert the prd is not None or empty
assert prd is not None
assert prd.content != ""
assert CONFIG.git_repo.new_file_repository(relative_path=PRDS_FILE_REPO).changed_files
assert context.git_repo.new_file_repository(relative_path=PRDS_FILE_REPO).changed_files

View file

@ -20,11 +20,11 @@ from metagpt.const import (
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
)
from metagpt.context import context
from metagpt.logs import logger
from metagpt.roles.engineer import Engineer
from metagpt.schema import CodingContext, Message
from metagpt.utils.common import CodeParser, any_to_name, any_to_str, aread, awrite
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import ChangeType
from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages
@ -34,12 +34,12 @@ from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages
async def test_engineer():
# Prerequisites
rqno = "20231221155954.json"
await FileRepository.save_file(REQUIREMENT_FILENAME, content=MockMessages.req.content)
await FileRepository.save_file(rqno, relative_path=PRDS_FILE_REPO, content=MockMessages.prd.content)
await FileRepository.save_file(
await context.file_repo.save_file(REQUIREMENT_FILENAME, content=MockMessages.req.content)
await context.file_repo.save_file(rqno, relative_path=PRDS_FILE_REPO, content=MockMessages.prd.content)
await context.file_repo.save_file(
rqno, relative_path=SYSTEM_DESIGN_FILE_REPO, content=MockMessages.system_design.content
)
await FileRepository.save_file(rqno, relative_path=TASK_FILE_REPO, content=MockMessages.json_tasks.content)
await context.file_repo.save_file(rqno, relative_path=TASK_FILE_REPO, content=MockMessages.json_tasks.content)
engineer = Engineer()
rsp = await engineer.run(Message(content="", cause_by=WriteTasks))

View file

@ -7,7 +7,6 @@
"""
import pytest
from metagpt.context import context
from metagpt.logs import logger
from metagpt.roles import ProductManager
from tests.metagpt.roles.mock import MockMessages
@ -16,7 +15,7 @@ from tests.metagpt.roles.mock import MockMessages
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_mock")
async def test_product_manager():
product_manager = ProductManager(context=context)
product_manager = ProductManager()
rsp = await product_manager.run(MockMessages.req)
logger.info(rsp)
assert len(rsp.content) > 0