diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index efe3bcbd4..3a8721004 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -27,15 +27,15 @@ from typing import Iterable, Set, Type, Any from pydantic import BaseModel, Field + from metagpt.actions.action import Action, ActionOutput, action_subclass_registry from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement from metagpt.const import SERDESER_PATH -from metagpt.llm import LLM +from metagpt.llm import LLM, HumanProvider from metagpt.logs import logger from metagpt.memory import Memory from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.provider.human_provider import HumanProvider from metagpt.schema import Message, MessageQueue from metagpt.utils.common import any_to_str, read_json_file, write_json_file, import_class, role_raise_decorator from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output @@ -293,8 +293,7 @@ class Role(BaseModel): """Watch Actions of interest. Role will select Messages caused by these Actions from its personal message buffer during _observe. """ - tags = {any_to_str(t) for t in actions} - self._rc.watch.update(tags) + self._rc.watch = {any_to_str(t) for t in actions} # check RoleContext after adding watch actions self._rc.check(self._role_id) @@ -509,6 +508,8 @@ class Role(BaseModel): msg = with_message elif isinstance(with_message, list): msg = Message(content="\n".join(with_message)) + if not msg.cause_by: + msg.cause_by = UserRequirement self.put_message(msg) if not await self._observe(): diff --git a/tests/metagpt/test_role.py b/tests/metagpt/test_role.py index 8fac2503c..611d321fc 100644 --- a/tests/metagpt/test_role.py +++ b/tests/metagpt/test_role.py @@ -14,11 +14,11 @@ import uuid import pytest from pydantic import BaseModel -from metagpt.actions import Action, ActionOutput +from metagpt.actions import Action, ActionOutput, UserRequirement from metagpt.environment import Environment from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import get_class_name +from metagpt.utils.common import any_to_str, get_class_name class MockAction(Action): @@ -60,7 +60,7 @@ async def test_react(): name=seed.name, profile=seed.profile, goal=seed.goal, constraints=seed.constraints, desc=seed.desc ) role.subscribe({seed.subscription}) - assert role._rc.watch == set({}) + assert role._rc.watch == {any_to_str(UserRequirement)} assert role.name == seed.name assert role.profile == seed.profile assert role._setting.goal == seed.goal