diff --git a/metagpt/schema.py b/metagpt/schema.py index bb8d8b42c..52020c468 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -24,6 +24,7 @@ from metagpt.const import ( MESSAGE_ROUTE_TO, ) from metagpt.logs import logger +from metagpt.utils.common import get_class_name, get_object_name class RawMessage(TypedDict): @@ -87,6 +88,14 @@ class Routes(BaseModel): route = self._get_route() return route.get(MESSAGE_ROUTE_TO) + def replace(self, old_val, new_val): + """Replace old value with new value""" + route = self._get_route() + tags = route.get(MESSAGE_ROUTE_TO, set()) + tags.discard(old_val) + tags.add(new_val) + route[MESSAGE_ROUTE_TO] = tags + class Message(BaseModel): """list[: ]""" @@ -147,6 +156,26 @@ class Message(BaseModel): """Labels for the consumer to filter its subscribed messages, also serving as meta info.""" return self.get_meta(MESSAGE_ROUTE_CAUSE_BY) + def __setattr__(self, key, val): + """Override `@property.setter`""" + if key == MESSAGE_ROUTE_CAUSE_BY: + self.set_cause_by(val) + return + super().__setattr__(key, val) + + def set_cause_by(self, val): + """Update the value of `cause_by` in the `meta_info` and `routes` attributes.""" + old_value = self.get_meta(MESSAGE_ROUTE_CAUSE_BY) + new_value = None + if isinstance(val, str): + new_value = val + elif not callable(val): + new_value = get_object_name(val) + else: + new_value = get_class_name(val) + self.set_meta(MESSAGE_ROUTE_CAUSE_BY, new_value) + self.route.replace(old_value, new_value) + @property def tx_from(self): """Message route info tells who sent this message.""" @@ -301,6 +330,10 @@ if __name__ == "__main__": m.set_role("v2") v = m.dump() m = Message.load(v) + m.cause_by = "Message" + m.cause_by = Routes + m.cause_by = Routes() + m.content = "b" test_content = "test_message" msgs = [ diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 21ba3fd14..e4aa0c0dd 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -11,6 +11,7 @@ import json import pytest +from metagpt.actions import Action from metagpt.schema import AIMessage, Message, Routes, SystemMessage, UserMessage @@ -50,6 +51,15 @@ def test_message(): assert m.cause_by == "c" assert m.get_meta("x") == "d" + m.cause_by = "Message" + assert m.cause_by == "Message" + m.cause_by = Action + assert m.cause_by == Action.get_class_name() + m.cause_by = Action() + assert m.cause_by == Action.get_class_name() + m.content = "b" + assert m.content == "b" + @pytest.mark.asyncio def test_routes():