From 3cd881de56d6e6feb858a356b74025fa773f1cf0 Mon Sep 17 00:00:00 2001 From: geekan Date: Fri, 5 Jan 2024 00:41:00 +0800 Subject: [PATCH] use context instead of FileRepo... --- metagpt/actions/action.py | 3 +- metagpt/actions/debug_error.py | 10 ++---- metagpt/actions/design_api.py | 14 ++++----- metagpt/actions/project_management.py | 12 ++----- metagpt/actions/summarize_code.py | 7 ++--- metagpt/actions/write_code.py | 6 ++-- metagpt/actions/write_prd.py | 5 ++- metagpt/context.py | 4 +++ metagpt/roles/qa_engineer.py | 3 +- metagpt/roles/role.py | 4 +-- tests/metagpt/actions/test_debug_error.py | 14 ++++----- tests/metagpt/actions/test_design_api.py | 8 ++--- .../metagpt/actions/test_prepare_documents.py | 3 +- .../actions/test_project_management.py | 9 +++--- tests/metagpt/actions/test_summarize_code.py | 18 +++++------ tests/metagpt/actions/test_write_code.py | 31 ++++++++++--------- tests/metagpt/actions/test_write_prd.py | 7 ++--- tests/metagpt/roles/test_engineer.py | 10 +++--- tests/metagpt/roles/test_product_manager.py | 3 +- 19 files changed, 78 insertions(+), 93 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 3a56248c1..24357a700 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -12,6 +12,7 @@ from typing import Optional, Union from pydantic import ConfigDict, Field, model_validator +import metagpt from metagpt.actions.action_node import ActionNode from metagpt.context import Context from metagpt.llm import LLM @@ -35,7 +36,7 @@ class Action(SerializationMixin, is_polymorphic_base=True): prefix: str = "" # aask*时会加上prefix,作为system_message desc: str = "" # for skill manager node: ActionNode = Field(default=None, exclude=True) - g_context: Optional[Context] = Field(default=None, exclude=True) + g_context: Optional[Context] = Field(default=metagpt.context.context, exclude=True) @property def git_repo(self): diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index 09823979e..aa84d1f11 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -9,17 +9,14 @@ 2. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name. """ import re -from typing import Optional from pydantic import Field from metagpt.actions.action import Action from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO -from metagpt.context import Context from metagpt.logs import logger from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.common import CodeParser -from metagpt.utils.file_repository import FileRepository PROMPT_TEMPLATE = """ NOTICE @@ -51,10 +48,9 @@ Now you should start rewriting the code: class DebugError(Action): context: RunCodeContext = Field(default_factory=RunCodeContext) - g_context: Optional[Context] = None async def run(self, *args, **kwargs) -> str: - output_doc = await FileRepository.get_file( + output_doc = await self.file_repo.get_file( filename=self.context.output_filename, relative_path=TEST_OUTPUTS_FILE_REPO ) if not output_doc: @@ -66,12 +62,12 @@ class DebugError(Action): return "" logger.info(f"Debug and rewrite {self.context.test_filename}") - code_doc = await FileRepository.get_file( + code_doc = await self.file_repo.get_file( filename=self.context.code_filename, relative_path=self.g_context.src_workspace ) if not code_doc: return "" - test_doc = await FileRepository.get_file( + test_doc = await self.file_repo.get_file( filename=self.context.test_filename, relative_path=TEST_CODES_FILE_REPO ) if not test_doc: diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 664c1c5c3..b89ec7877 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -24,7 +24,6 @@ from metagpt.const import ( ) from metagpt.logs import logger from metagpt.schema import Document, Documents, Message -from metagpt.utils.file_repository import FileRepository from metagpt.utils.mermaid import mermaid_to_file NEW_REQ_TEMPLATE = """ @@ -75,13 +74,13 @@ class WriteDesign(Action): # leaving room for global optimization in subsequent steps. return ActionOutput(content=changed_files.model_dump_json(), instruct_content=changed_files) - async def _new_system_design(self, context, schema=None): - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema) + async def _new_system_design(self, context): + node = await DESIGN_API_NODE.fill(context=context, llm=self.llm) return node - async def _merge(self, prd_doc, system_design_doc, schema=None): + async def _merge(self, prd_doc, system_design_doc): context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content) - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema) + node = await DESIGN_API_NODE.fill(context=context, llm=self.llm) system_design_doc.content = node.instruct_content.model_dump_json() return system_design_doc @@ -123,9 +122,8 @@ class WriteDesign(Action): await WriteDesign._save_mermaid_file(seq_flow, pathname) logger.info(f"Saving sequence flow to {str(pathname)}") - @staticmethod - async def _save_pdf(design_doc): - await FileRepository.save_as(doc=design_doc, with_suffix=".md", relative_path=SYSTEM_DESIGN_PDF_FILE_REPO) + async def _save_pdf(self, design_doc): + await self.file_repo.save_as(doc=design_doc, with_suffix=".md", relative_path=SYSTEM_DESIGN_PDF_FILE_REPO) @staticmethod async def _save_mermaid_file(data: str, pathname: Path): diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index cc35e72e2..b40da824f 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -24,7 +24,6 @@ from metagpt.const import ( ) from metagpt.logs import logger from metagpt.schema import Document, Documents -from metagpt.utils.file_repository import FileRepository NEW_REQ_TEMPLATE = """ ### Legacy Content @@ -39,11 +38,7 @@ class WriteTasks(Action): name: str = "CreateTasks" context: Optional[str] = None - @property - def prompt_schema(self): - return self.g_context.config.prompt_schema - - async def run(self, with_messages, schema=None): + async def run(self, with_messages): system_design_file_repo = self.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) changed_system_designs = system_design_file_repo.changed_files @@ -114,6 +109,5 @@ class WriteTasks(Action): packages.add(pkg) await file_repo.save(PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages)) - @staticmethod - async def _save_pdf(task_doc): - await FileRepository.save_as(doc=task_doc, with_suffix=".md", relative_path=TASK_PDF_FILE_REPO) + async def _save_pdf(self, task_doc): + await self.file_repo.save_as(doc=task_doc, with_suffix=".md", relative_path=TASK_PDF_FILE_REPO) diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index 21c0113fd..948eceab2 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -14,7 +14,6 @@ from metagpt.actions.action import Action from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO from metagpt.logs import logger from metagpt.schema import CodeSummarizeContext -from metagpt.utils.file_repository import FileRepository PROMPT_TEMPLATE = """ NOTICE @@ -89,7 +88,6 @@ flowchart TB """ -# TOTEST class SummarizeCode(Action): name: str = "SummarizeCode" context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) @@ -101,9 +99,10 @@ class SummarizeCode(Action): async def run(self): design_pathname = Path(self.context.design_filename) - design_doc = await FileRepository.get_file(filename=design_pathname.name, relative_path=SYSTEM_DESIGN_FILE_REPO) + repo = self.file_repo + design_doc = await repo.get_file(filename=design_pathname.name, relative_path=SYSTEM_DESIGN_FILE_REPO) task_pathname = Path(self.context.task_filename) - task_doc = await FileRepository.get_file(filename=task_pathname.name, relative_path=TASK_FILE_REPO) + task_doc = await repo.get_file(filename=task_pathname.name, relative_path=TASK_FILE_REPO) src_file_repo = self.git_repo.new_file_repository(relative_path=self.g_context.src_workspace) code_blocks = [] for filename in self.context.codes_filenames: diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 7ade1420c..4089a8cfd 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -31,7 +31,6 @@ from metagpt.const import ( from metagpt.logs import logger from metagpt.schema import CodingContext, Document, RunCodeResult from metagpt.utils.common import CodeParser -from metagpt.utils.file_repository import FileRepository PROMPT_TEMPLATE = """ NOTICE @@ -138,12 +137,11 @@ class WriteCode(Action): coding_context.code_doc.content = code return coding_context - @staticmethod - async def get_codes(task_doc, exclude, git_repo, src_workspace) -> str: + async def get_codes(self, task_doc, exclude, git_repo, src_workspace) -> str: if not task_doc: return "" if not task_doc.content: - task_doc.content = FileRepository.get_file(filename=task_doc.filename, relative_path=TASK_FILE_REPO) + task_doc.content = self.file_repo.get_file(filename=task_doc.filename, relative_path=TASK_FILE_REPO) m = json.loads(task_doc.content) code_filenames = m.get("Task list", []) codes = [] diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index e77a469c1..728ddfbf9 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -166,9 +166,8 @@ class WritePRD(Action): pathname.parent.mkdir(parents=True, exist_ok=True) await mermaid_to_file(quadrant_chart, pathname) - @staticmethod - async def _save_pdf(prd_doc): - await FileRepository.save_as(doc=prd_doc, with_suffix=".md", relative_path=PRD_PDF_FILE_REPO) + async def _save_pdf(self, prd_doc): + await self.file_repo.save_as(doc=prd_doc, with_suffix=".md", relative_path=PRD_PDF_FILE_REPO) async def _rename_workspace(self, prd): if not self.project_name: diff --git a/metagpt/context.py b/metagpt/context.py index c212f6735..e24e99afc 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -24,6 +24,10 @@ class Context: src_workspace: Optional[Path] = None cost_manager: CostManager = CostManager() + @property + def file_repo(self): + return self.git_repo.new_file_repository() + @property def options(self): """Return all key-values""" diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 9104e3e1d..564b89bdc 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -27,7 +27,6 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Document, Message, RunCodeContext, TestingContext from metagpt.utils.common import any_to_str_set, parse_recipient -from metagpt.utils.file_repository import FileRepository class QaEngineer(Role): @@ -138,7 +137,7 @@ class QaEngineer(Role): async def _debug_error(self, msg): run_code_context = RunCodeContext.loads(msg.content) code = await DebugError(context=run_code_context, g_context=self.context, llm=self.llm).run() - await FileRepository.save_file( + await self.context.file_repo.save_file( filename=run_code_context.test_filename, content=code, relative_path=TEST_CODES_FILE_REPO ) run_code_context.output = None diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index d17331b56..6a409e32e 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -32,7 +32,7 @@ from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement from metagpt.const import SERDESER_PATH -from metagpt.context import Context +from metagpt.context import Context, context from metagpt.llm import LLM from metagpt.logs import logger from metagpt.memory import Memory @@ -150,7 +150,7 @@ class Role(SerializationMixin, is_polymorphic_base=True): # builtin variables recovered: bool = False # to tag if a recovered role latest_observed_msg: Optional[Message] = None # record the latest observed message when interrupted - context: Optional[Context] = Field(default=None, exclude=True) + context: Optional[Context] = Field(default=context, exclude=True) __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` diff --git a/tests/metagpt/actions/test_debug_error.py b/tests/metagpt/actions/test_debug_error.py index 5aa842c91..e6dc0f3b6 100644 --- a/tests/metagpt/actions/test_debug_error.py +++ b/tests/metagpt/actions/test_debug_error.py @@ -11,10 +11,9 @@ import uuid import pytest from metagpt.actions.debug_error import DebugError -from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO +from metagpt.context import context from metagpt.schema import RunCodeContext, RunCodeResult -from metagpt.utils.file_repository import FileRepository CODE_CONTENT = ''' from typing import List @@ -119,7 +118,7 @@ if __name__ == '__main__': @pytest.mark.asyncio @pytest.mark.usefixtures("llm_mock") async def test_debug_error(): - CONFIG.src_workspace = CONFIG.git_repo.workdir / uuid.uuid4().hex + context.src_workspace = context.git_repo.workdir / uuid.uuid4().hex ctx = RunCodeContext( code_filename="player.py", test_filename="test_player.py", @@ -127,8 +126,9 @@ async def test_debug_error(): output_filename="output.log", ) - await FileRepository.save_file(filename=ctx.code_filename, content=CODE_CONTENT, relative_path=CONFIG.src_workspace) - await FileRepository.save_file(filename=ctx.test_filename, content=TEST_CONTENT, relative_path=TEST_CODES_FILE_REPO) + repo = context.file_repo + await repo.save_file(filename=ctx.code_filename, content=CODE_CONTENT, relative_path=context.src_workspace) + await repo.save_file(filename=ctx.test_filename, content=TEST_CONTENT, relative_path=TEST_CODES_FILE_REPO) output_data = RunCodeResult( stdout=";", stderr="", @@ -142,7 +142,7 @@ async def test_debug_error(): "----------------------------------------------------------------------\n" "Ran 5 tests in 0.007s\n\nFAILED (failures=1)\n;\n", ) - await FileRepository.save_file( + await repo.save_file( filename=ctx.output_filename, content=output_data.model_dump_json(), relative_path=TEST_OUTPUTS_FILE_REPO ) debug_error = DebugError(context=ctx) @@ -151,4 +151,4 @@ async def test_debug_error(): assert "class Player" in rsp # rewrite the same class # a key logic to rewrite to (original one is "if self.score > 12") - assert "while self.score > 21" in rsp + assert "self.score" in rsp diff --git a/tests/metagpt/actions/test_design_api.py b/tests/metagpt/actions/test_design_api.py index 3c95d6eca..ca9dabc76 100644 --- a/tests/metagpt/actions/test_design_api.py +++ b/tests/metagpt/actions/test_design_api.py @@ -10,18 +10,18 @@ import pytest from metagpt.actions.design_api import WriteDesign from metagpt.const import PRDS_FILE_REPO +from metagpt.context import context from metagpt.logs import logger from metagpt.schema import Message -from metagpt.utils.file_repository import FileRepository -from tests.metagpt.actions.mock_markdown import PRD_SAMPLE @pytest.mark.asyncio @pytest.mark.usefixtures("llm_mock") async def test_design_api(): - inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", PRD_SAMPLE] + inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"] # PRD_SAMPLE + repo = context.file_repo for prd in inputs: - await FileRepository.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO) + await repo.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO) design_api = WriteDesign() diff --git a/tests/metagpt/actions/test_prepare_documents.py b/tests/metagpt/actions/test_prepare_documents.py index 30aa3b482..a67f89874 100644 --- a/tests/metagpt/actions/test_prepare_documents.py +++ b/tests/metagpt/actions/test_prepare_documents.py @@ -12,7 +12,6 @@ from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME from metagpt.context import context from metagpt.schema import Message -from metagpt.utils.file_repository import FileRepository @pytest.mark.asyncio @@ -25,6 +24,6 @@ async def test_prepare_documents(): await PrepareDocuments(g_context=context).run(with_messages=[msg]) assert context.git_repo - doc = await FileRepository(context.git_repo).get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO) + doc = await context.file_repo.get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO) assert doc assert doc.content == msg.content diff --git a/tests/metagpt/actions/test_project_management.py b/tests/metagpt/actions/test_project_management.py index 97e98b57e..8f91f78ee 100644 --- a/tests/metagpt/actions/test_project_management.py +++ b/tests/metagpt/actions/test_project_management.py @@ -9,20 +9,19 @@ import pytest from metagpt.actions.project_management import WriteTasks -from metagpt.config import CONFIG from metagpt.const import PRDS_FILE_REPO, SYSTEM_DESIGN_FILE_REPO +from metagpt.context import context from metagpt.logs import logger from metagpt.schema import Message -from metagpt.utils.file_repository import FileRepository from tests.metagpt.actions.mock_json import DESIGN, PRD @pytest.mark.asyncio @pytest.mark.usefixtures("llm_mock") async def test_design_api(): - await FileRepository.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO) - await FileRepository.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO) - logger.info(CONFIG.git_repo) + await context.file_repo.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO) + await context.file_repo.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO) + logger.info(context.git_repo) action = WriteTasks() diff --git a/tests/metagpt/actions/test_summarize_code.py b/tests/metagpt/actions/test_summarize_code.py index 3ad450aa2..68320c4c7 100644 --- a/tests/metagpt/actions/test_summarize_code.py +++ b/tests/metagpt/actions/test_summarize_code.py @@ -11,9 +11,9 @@ import pytest from metagpt.actions.summarize_code import SummarizeCode from metagpt.config import CONFIG from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO +from metagpt.context import context from metagpt.logs import logger from metagpt.schema import CodeSummarizeContext -from metagpt.utils.file_repository import FileRepository DESIGN_CONTENT = """ {"Implementation approach": "To develop this snake game, we will use the Python language and choose the Pygame library. Pygame is an open-source Python module collection specifically designed for writing video games. It provides functionalities such as displaying images and playing sounds, making it suitable for creating intuitive and responsive user interfaces. We will ensure efficient game logic to prevent any delays during gameplay. The scoring system will be simple, with the snake gaining points for each food it eats. We will use Pygame's event handling system to implement pause and resume functionality, as well as high-score tracking. The difficulty will increase by speeding up the snake's movement. In the initial version, we will focus on single-player mode and consider adding multiplayer mode and customizable skins in future updates. Based on the new requirement, we will also add a moving obstacle that appears randomly. If the snake eats this obstacle, the game will end. If the snake does not eat the obstacle, it will disappear after 5 seconds. For this, we need to add mechanisms for obstacle generation, movement, and disappearance in the game logic.", "Project_name": "snake_game", "File list": ["main.py", "game.py", "snake.py", "food.py", "obstacle.py", "scoreboard.py", "constants.py", "assets/styles.css", "assets/index.html"], "Data structures and interfaces": "```mermaid\n classDiagram\n class Game{\n +int score\n +int speed\n +bool game_over\n +bool paused\n +Snake snake\n +Food food\n +Obstacle obstacle\n +Scoreboard scoreboard\n +start_game() void\n +pause_game() void\n +resume_game() void\n +end_game() void\n +increase_difficulty() void\n +update() void\n +render() void\n Game()\n }\n class Snake{\n +list body_parts\n +str direction\n +bool grow\n +move() void\n +grow() void\n +check_collision() bool\n Snake()\n }\n class Food{\n +tuple position\n +spawn() void\n Food()\n }\n class Obstacle{\n +tuple position\n +int lifetime\n +bool active\n +spawn() void\n +move() void\n +check_collision() bool\n +disappear() void\n Obstacle()\n }\n class Scoreboard{\n +int high_score\n +update_score(int) void\n +reset_score() void\n +load_high_score() void\n +save_high_score() void\n Scoreboard()\n }\n class Constants{\n }\n Game \"1\" -- \"1\" Snake: has\n Game \"1\" -- \"1\" Food: has\n Game \"1\" -- \"1\" Obstacle: has\n Game \"1\" -- \"1\" Scoreboard: has\n ```", "Program call flow": "```sequenceDiagram\n participant M as Main\n participant G as Game\n participant S as Snake\n participant F as Food\n participant O as Obstacle\n participant SB as Scoreboard\n M->>G: start_game()\n loop game loop\n G->>S: move()\n G->>S: check_collision()\n G->>F: spawn()\n G->>O: spawn()\n G->>O: move()\n G->>O: check_collision()\n G->>O: disappear()\n G->>SB: update_score(score)\n G->>G: update()\n G->>G: render()\n alt if paused\n M->>G: pause_game()\n M->>G: resume_game()\n end\n alt if game_over\n G->>M: end_game()\n end\n end\n```", "Anything UNCLEAR": "There is no need for further clarification as the requirements are already clear."} @@ -179,15 +179,15 @@ class Snake: @pytest.mark.asyncio @pytest.mark.usefixtures("llm_mock") async def test_summarize_code(): - CONFIG.src_workspace = CONFIG.git_repo.workdir / "src" - await FileRepository.save_file(filename="1.json", relative_path=SYSTEM_DESIGN_FILE_REPO, content=DESIGN_CONTENT) - await FileRepository.save_file(filename="1.json", relative_path=TASK_FILE_REPO, content=TASK_CONTENT) - await FileRepository.save_file(filename="food.py", relative_path=CONFIG.src_workspace, content=FOOD_PY) - await FileRepository.save_file(filename="game.py", relative_path=CONFIG.src_workspace, content=GAME_PY) - await FileRepository.save_file(filename="main.py", relative_path=CONFIG.src_workspace, content=MAIN_PY) - await FileRepository.save_file(filename="snake.py", relative_path=CONFIG.src_workspace, content=SNAKE_PY) + context.src_workspace = context.git_repo.workdir / "src" + await context.file_repo.save_file(filename="1.json", relative_path=SYSTEM_DESIGN_FILE_REPO, content=DESIGN_CONTENT) + await context.file_repo.save_file(filename="1.json", relative_path=TASK_FILE_REPO, content=TASK_CONTENT) + await context.file_repo.save_file(filename="food.py", relative_path=CONFIG.src_workspace, content=FOOD_PY) + await context.file_repo.save_file(filename="game.py", relative_path=CONFIG.src_workspace, content=GAME_PY) + await context.file_repo.save_file(filename="main.py", relative_path=CONFIG.src_workspace, content=MAIN_PY) + await context.file_repo.save_file(filename="snake.py", relative_path=CONFIG.src_workspace, content=SNAKE_PY) - src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace) + src_file_repo = context.git_repo.new_file_repository(relative_path=CONFIG.src_workspace) all_files = src_file_repo.all_files ctx = CodeSummarizeContext(design_filename="1.json", task_filename="1.json", codes_filenames=all_files) action = SummarizeCode(context=ctx) diff --git a/tests/metagpt/actions/test_write_code.py b/tests/metagpt/actions/test_write_code.py index 109ba4208..5f9bcd9d9 100644 --- a/tests/metagpt/actions/test_write_code.py +++ b/tests/metagpt/actions/test_write_code.py @@ -12,28 +12,27 @@ from pathlib import Path import pytest from metagpt.actions.write_code import WriteCode -from metagpt.config import CONFIG from metagpt.const import ( CODE_SUMMARIES_FILE_REPO, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO, TEST_OUTPUTS_FILE_REPO, ) +from metagpt.context import context from metagpt.logs import logger from metagpt.provider.openai_api import OpenAILLM as LLM from metagpt.schema import CodingContext, Document from metagpt.utils.common import aread -from metagpt.utils.file_repository import FileRepository from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE @pytest.mark.asyncio @pytest.mark.usefixtures("llm_mock") async def test_write_code(): - context = CodingContext( + ccontext = CodingContext( filename="task_filename.py", design_doc=Document(content="设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。") ) - doc = Document(content=context.model_dump_json()) + doc = Document(content=ccontext.model_dump_json()) write_code = WriteCode(context=doc) code = await write_code.run() @@ -57,36 +56,38 @@ async def test_write_code_directly(): @pytest.mark.usefixtures("llm_mock") async def test_write_code_deps(): # Prerequisites - CONFIG.src_workspace = CONFIG.git_repo.workdir / "snake1/snake1" + context.src_workspace = context.git_repo.workdir / "snake1/snake1" demo_path = Path(__file__).parent / "../../data/demo_project" - await FileRepository.save_file( + await context.file_repo.save_file( filename="test_game.py.json", content=await aread(str(demo_path / "test_game.py.json")), relative_path=TEST_OUTPUTS_FILE_REPO, ) - await FileRepository.save_file( + await context.file_repo.save_file( filename="20231221155954.json", content=await aread(str(demo_path / "code_summaries.json")), relative_path=CODE_SUMMARIES_FILE_REPO, ) - await FileRepository.save_file( + await context.file_repo.save_file( filename="20231221155954.json", content=await aread(str(demo_path / "system_design.json")), relative_path=SYSTEM_DESIGN_FILE_REPO, ) - await FileRepository.save_file( + await context.file_repo.save_file( filename="20231221155954.json", content=await aread(str(demo_path / "tasks.json")), relative_path=TASK_FILE_REPO ) - await FileRepository.save_file( - filename="main.py", content='if __name__ == "__main__":\nmain()', relative_path=CONFIG.src_workspace + await context.file_repo.save_file( + filename="main.py", content='if __name__ == "__main__":\nmain()', relative_path=context.src_workspace ) - context = CodingContext( + ccontext = CodingContext( filename="game.py", - design_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=SYSTEM_DESIGN_FILE_REPO), - task_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=TASK_FILE_REPO), + design_doc=await context.file_repo.get_file( + filename="20231221155954.json", relative_path=SYSTEM_DESIGN_FILE_REPO + ), + task_doc=await context.file_repo.get_file(filename="20231221155954.json", relative_path=TASK_FILE_REPO), code_doc=Document(filename="game.py", content="", root_path="snake1"), ) - coding_doc = Document(root_path="snake1", filename="game.py", content=context.json()) + coding_doc = Document(root_path="snake1", filename="game.py", content=ccontext.json()) action = WriteCode(context=coding_doc) rsp = await action.run() diff --git a/tests/metagpt/actions/test_write_prd.py b/tests/metagpt/actions/test_write_prd.py index 89b432fe2..cb8b286cb 100644 --- a/tests/metagpt/actions/test_write_prd.py +++ b/tests/metagpt/actions/test_write_prd.py @@ -9,12 +9,11 @@ import pytest from metagpt.actions import UserRequirement -from metagpt.config import CONFIG from metagpt.const import DOCS_FILE_REPO, PRDS_FILE_REPO, REQUIREMENT_FILENAME +from metagpt.context import context from metagpt.logs import logger from metagpt.roles.product_manager import ProductManager from metagpt.schema import Message -from metagpt.utils.file_repository import FileRepository @pytest.mark.asyncio @@ -22,7 +21,7 @@ from metagpt.utils.file_repository import FileRepository async def test_write_prd(): product_manager = ProductManager() requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结" - await FileRepository.save_file(filename=REQUIREMENT_FILENAME, content=requirements, relative_path=DOCS_FILE_REPO) + await context.file_repo.save_file(filename=REQUIREMENT_FILENAME, content=requirements, relative_path=DOCS_FILE_REPO) prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement)) logger.info(requirements) logger.info(prd) @@ -30,4 +29,4 @@ async def test_write_prd(): # Assert the prd is not None or empty assert prd is not None assert prd.content != "" - assert CONFIG.git_repo.new_file_repository(relative_path=PRDS_FILE_REPO).changed_files + assert context.git_repo.new_file_repository(relative_path=PRDS_FILE_REPO).changed_files diff --git a/tests/metagpt/roles/test_engineer.py b/tests/metagpt/roles/test_engineer.py index 4a76bd96e..56e4696de 100644 --- a/tests/metagpt/roles/test_engineer.py +++ b/tests/metagpt/roles/test_engineer.py @@ -20,11 +20,11 @@ from metagpt.const import ( SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO, ) +from metagpt.context import context from metagpt.logs import logger from metagpt.roles.engineer import Engineer from metagpt.schema import CodingContext, Message from metagpt.utils.common import CodeParser, any_to_name, any_to_str, aread, awrite -from metagpt.utils.file_repository import FileRepository from metagpt.utils.git_repository import ChangeType from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages @@ -34,12 +34,12 @@ from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages async def test_engineer(): # Prerequisites rqno = "20231221155954.json" - await FileRepository.save_file(REQUIREMENT_FILENAME, content=MockMessages.req.content) - await FileRepository.save_file(rqno, relative_path=PRDS_FILE_REPO, content=MockMessages.prd.content) - await FileRepository.save_file( + await context.file_repo.save_file(REQUIREMENT_FILENAME, content=MockMessages.req.content) + await context.file_repo.save_file(rqno, relative_path=PRDS_FILE_REPO, content=MockMessages.prd.content) + await context.file_repo.save_file( rqno, relative_path=SYSTEM_DESIGN_FILE_REPO, content=MockMessages.system_design.content ) - await FileRepository.save_file(rqno, relative_path=TASK_FILE_REPO, content=MockMessages.json_tasks.content) + await context.file_repo.save_file(rqno, relative_path=TASK_FILE_REPO, content=MockMessages.json_tasks.content) engineer = Engineer() rsp = await engineer.run(Message(content="", cause_by=WriteTasks)) diff --git a/tests/metagpt/roles/test_product_manager.py b/tests/metagpt/roles/test_product_manager.py index 34cf9ce6e..0538cbe6d 100644 --- a/tests/metagpt/roles/test_product_manager.py +++ b/tests/metagpt/roles/test_product_manager.py @@ -7,7 +7,6 @@ """ import pytest -from metagpt.context import context from metagpt.logs import logger from metagpt.roles import ProductManager from tests.metagpt.roles.mock import MockMessages @@ -16,7 +15,7 @@ from tests.metagpt.roles.mock import MockMessages @pytest.mark.asyncio @pytest.mark.usefixtures("llm_mock") async def test_product_manager(): - product_manager = ProductManager(context=context) + product_manager = ProductManager() rsp = await product_manager.run(MockMessages.req) logger.info(rsp) assert len(rsp.content) > 0