diff --git a/tests/metagpt/actions/test_write_prd.py b/tests/metagpt/actions/test_write_prd.py index 6ba879990..7317bba76 100644 --- a/tests/metagpt/actions/test_write_prd.py +++ b/tests/metagpt/actions/test_write_prd.py @@ -8,12 +8,12 @@ """ import pytest -from metagpt.actions import UserRequirement -from metagpt.actions.prepare_documents import PrepareDocuments +from metagpt.actions import UserRequirement, WritePRD from metagpt.config import CONFIG from metagpt.const import DOCS_FILE_REPO, PRDS_FILE_REPO, REQUIREMENT_FILENAME from metagpt.logs import logger from metagpt.roles.product_manager import ProductManager +from metagpt.roles.role import RoleReactMode from metagpt.schema import Message from metagpt.utils.common import any_to_str from metagpt.utils.file_repository import FileRepository @@ -24,9 +24,9 @@ async def test_write_prd(new_filename): product_manager = ProductManager() requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结" await FileRepository.save_file(filename=REQUIREMENT_FILENAME, content=requirements, relative_path=DOCS_FILE_REPO) - prepare = await product_manager.run(Message(content=requirements, cause_by=UserRequirement)) - assert prepare.cause_by == any_to_str(PrepareDocuments) - prd = await product_manager.run(with_message=prepare) + product_manager.rc.react_mode = RoleReactMode.BY_ORDER + prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement)) + assert prd.cause_by == any_to_str(WritePRD) logger.info(requirements) logger.info(prd)