fixbug: di run

This commit is contained in:
莘权 马 2024-03-30 20:19:46 +08:00
parent eedba3038d
commit a27fcd7e52
2 changed files with 17 additions and 3 deletions

View file

@ -14,7 +14,7 @@ from metagpt.roles import Role
from metagpt.schema import Message, Task, TaskResult
from metagpt.strategy.task_type import TaskType
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
from metagpt.utils.common import CodeParser
from metagpt.utils.common import CodeParser, role_raise_decorator
REACT_THINK_PROMPT = """
# User Requirement
@ -182,3 +182,11 @@ class DataInterpreter(Role):
print(result)
data_info = DATA_INFO.format(info=result)
self.working_memory.add(Message(content=data_info, role="user", cause_by=CheckData))
@role_raise_decorator
async def run(self, with_message=None) -> Message | None:
if not self.rc.todo:
self.set_actions([WriteAnalysisCode])
self._set_state(0)
return await super().run(with_message)

View file

@ -7,11 +7,17 @@ import pytest
from metagpt.context import Context
from metagpt.roles.di.mgx import MGX
from metagpt.schema import Message
from tests.metagpt.actions.test_intent_detect import DEMO_CONTENT
from tests.metagpt.actions.test_intent_detect import DEMO1_CONTENT, DEMO_CONTENT
@pytest.mark.asyncio
@pytest.mark.parametrize("user_messages", [[Message.model_validate(i) for i in DEMO_CONTENT if i["role"] == "user"]])
@pytest.mark.parametrize(
"user_messages",
[
[Message.model_validate(i) for i in DEMO_CONTENT if i["role"] == "user"],
[Message.model_validate(i) for i in DEMO1_CONTENT if i["role"] == "user"],
],
)
async def test_mgx(user_messages: List[Message]):
ctx = Context()
mgx = MGX(context=ctx, tools=["<all>"])