From a27fcd7e524b96cd279df520807d4db86ed5498f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Sat, 30 Mar 2024 20:19:46 +0800 Subject: [PATCH] fixbug: di run --- metagpt/roles/di/data_interpreter.py | 10 +++++++++- tests/metagpt/roles/di/test_mgx.py | 10 ++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/metagpt/roles/di/data_interpreter.py b/metagpt/roles/di/data_interpreter.py index a8534b710..35c6d1297 100644 --- a/metagpt/roles/di/data_interpreter.py +++ b/metagpt/roles/di/data_interpreter.py @@ -14,7 +14,7 @@ from metagpt.roles import Role from metagpt.schema import Message, Task, TaskResult from metagpt.strategy.task_type import TaskType from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender -from metagpt.utils.common import CodeParser +from metagpt.utils.common import CodeParser, role_raise_decorator REACT_THINK_PROMPT = """ # User Requirement @@ -182,3 +182,11 @@ class DataInterpreter(Role): print(result) data_info = DATA_INFO.format(info=result) self.working_memory.add(Message(content=data_info, role="user", cause_by=CheckData)) + + @role_raise_decorator + async def run(self, with_message=None) -> Message | None: + if not self.rc.todo: + self.set_actions([WriteAnalysisCode]) + self._set_state(0) + + return await super().run(with_message) diff --git a/tests/metagpt/roles/di/test_mgx.py b/tests/metagpt/roles/di/test_mgx.py index a835be414..2e67113b9 100644 --- a/tests/metagpt/roles/di/test_mgx.py +++ b/tests/metagpt/roles/di/test_mgx.py @@ -7,11 +7,17 @@ import pytest from metagpt.context import Context from metagpt.roles.di.mgx import MGX from metagpt.schema import Message -from tests.metagpt.actions.test_intent_detect import DEMO_CONTENT +from tests.metagpt.actions.test_intent_detect import DEMO1_CONTENT, DEMO_CONTENT @pytest.mark.asyncio -@pytest.mark.parametrize("user_messages", [[Message.model_validate(i) for i in DEMO_CONTENT if i["role"] == "user"]]) +@pytest.mark.parametrize( + "user_messages", + [ + [Message.model_validate(i) for i in DEMO_CONTENT if i["role"] == "user"], + [Message.model_validate(i) for i in DEMO1_CONTENT if i["role"] == "user"], + ], +) async def test_mgx(user_messages: List[Message]): ctx = Context() mgx = MGX(context=ctx, tools=[""])