diff --git a/metagpt/const.py b/metagpt/const.py index 2ba875543..fa0ccc536 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -44,8 +44,8 @@ SKILL_DIRECTORY = PROJECT_ROOT / "metagpt/skills" MEM_TTL = 24 * 30 * 3600 -MESSAGE_ROUTE_FROM = "msg_from" -MESSAGE_ROUTE_TO = "msg_to" +MESSAGE_ROUTE_FROM = "sent_from" +MESSAGE_ROUTE_TO = "send_to" MESSAGE_ROUTE_CAUSE_BY = "cause_by" MESSAGE_META_ROLE = "role" MESSAGE_ROUTE_TO_ALL = "" diff --git a/metagpt/schema.py b/metagpt/schema.py index 1b00843a6..7fdcef2ed 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -13,19 +13,18 @@ import asyncio import json from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError -from typing import Dict, List, Set, TypedDict +from typing import List, Set, TypedDict from pydantic import BaseModel, Field from metagpt.const import ( - MESSAGE_META_ROLE, MESSAGE_ROUTE_CAUSE_BY, MESSAGE_ROUTE_FROM, MESSAGE_ROUTE_TO, MESSAGE_ROUTE_TO_ALL, ) from metagpt.logs import logger -from metagpt.utils.common import any_to_str +from metagpt.utils.common import any_to_str, any_to_str_set class RawMessage(TypedDict): @@ -33,182 +32,56 @@ class RawMessage(TypedDict): role: str -class Routes(BaseModel): - """Responsible for managing routing information for the Message class.""" - - routes: List[Dict] = Field(default_factory=list) - - def set_from(self, value): - """Set the label of the message sender.""" - route = self._get_route() - route[MESSAGE_ROUTE_FROM] = value - - def set_to(self, tags: Set): - """Set the labels of the message recipient.""" - route = self._get_route() - if tags: - route[MESSAGE_ROUTE_TO] = tags - return - - if MESSAGE_ROUTE_TO in route: - del route[MESSAGE_ROUTE_TO] - - def add_to(self, tag: str): - """Add a label of the message recipient.""" - route = self._get_route() - tags = route.get(MESSAGE_ROUTE_TO, set()) - tags.add(tag) - route[MESSAGE_ROUTE_TO] = tags - - def _get_route(self) -> Dict: - if not self.routes: - self.routes.append({}) - return self.routes[0] - - def contain_any(self, tags: Set) -> bool: - """Check if this object contains these tags.""" - route = self._get_route() - to_tags = route.get(MESSAGE_ROUTE_TO) - if not to_tags: - return True - - if MESSAGE_ROUTE_TO_ALL in to_tags: - return True - for k in tags: - if k in to_tags: - return True - return False - - @property - def msg_from(self): - """Message route info tells who sent this message.""" - route = self._get_route() - return route.get(MESSAGE_ROUTE_FROM) - - @property - def msg_to(self): - """Labels for the consumer to filter its subscribed messages.""" - route = self._get_route() - return route.get(MESSAGE_ROUTE_TO) - - def replace(self, old_val, new_val): - """Replace old value with new value""" - route = self._get_route() - tags = route.get(MESSAGE_ROUTE_TO, set()) - tags.discard(old_val) - tags.add(new_val) - route[MESSAGE_ROUTE_TO] = tags - - class Message(BaseModel): """list[: ]""" content: str - instruct_content: BaseModel = None - meta_info: Dict = Field(default_factory=dict) - route: Routes = Field(default_factory=Routes) + instruct_content: BaseModel = Field(default=None) + role: str = "user" # system / user / assistant + cause_by: str = "" + sent_from: str = "" + send_to: Set = Field(default_factory=set) - def __init__(self, content, **kwargs): + def __init__( + self, + content, + instruct_content=None, + role="user", + cause_by="", + sent_from="", + send_to=MESSAGE_ROUTE_TO_ALL, + **kwargs, + ): """ Parameters not listed below will be stored as meta info, including custom parameters. :param content: Message content. :param instruct_content: Message content struct. - :param meta_info: Message meta info. - :param route: Message route configuration. - :param msg_from: Message route info tells who sent this message. - :param msg_to: Labels for the consumer to filter its subscribed messages. - :param cause_by: Labels for the consumer to filter its subscribed messages, also serving as meta info. + :param cause_by: Message producer + :param sent_from: Message route info tells who sent this message. + :param send_to: Labels for the consumer to filter its subscribed messages. :param role: Message meta info tells who sent this message. """ - super(Message, self).__init__( - content=content or kwargs.get("content"), - instruct_content=kwargs.get("instruct_content"), - meta_info=kwargs.get("meta_info", {}), - route=kwargs.get("route", Routes()), + super().__init__( + content=content, + instruct_content=instruct_content, + role=role, + cause_by=any_to_str(cause_by), + sent_from=any_to_str(sent_from), + send_to=any_to_str_set(send_to), + **kwargs, ) - attribute_names = Message.__annotations__.keys() - for k, v in kwargs.items(): - if k in attribute_names: - continue - if k == MESSAGE_ROUTE_FROM: - self.set_from(any_to_str(v)) - continue - if k == MESSAGE_ROUTE_CAUSE_BY: - self.set_cause_by(v) - continue - if k == MESSAGE_ROUTE_TO: - if isinstance(v, tuple) or isinstance(v, list) or isinstance(v, set): - for i in v: - self.add_to(any_to_str(i)) - else: - self.add_to(any_to_str(v)) - continue - self.meta_info[k] = v - - def get_meta(self, key): - """Get meta info""" - return self.meta_info.get(key) - - def set_meta(self, key, value): - """Set meta info""" - self.meta_info[key] = value - - @property - def role(self): - """Message meta info tells who sent this message.""" - return self.get_meta(MESSAGE_META_ROLE) - - @property - def cause_by(self): - """Labels for the consumer to filter its subscribed messages, also serving as meta info.""" - return self.get_meta(MESSAGE_ROUTE_CAUSE_BY) - def __setattr__(self, key, val): - """Override `@property.setter`""" + """Override `@property.setter`, convert non-string parameters into string parameters.""" if key == MESSAGE_ROUTE_CAUSE_BY: - self.set_cause_by(val) - return - if key == MESSAGE_ROUTE_FROM: - self.set_from(any_to_str(val)) - super().__setattr__(key, val) - - def set_cause_by(self, val): - """Update the value of `cause_by` in the `meta_info` and `routes` attributes.""" - old_value = self.get_meta(MESSAGE_ROUTE_CAUSE_BY) - new_value = any_to_str(val) - self.set_meta(MESSAGE_ROUTE_CAUSE_BY, new_value) - self.route.replace(old_value, new_value) - - @property - def msg_from(self): - """Message route info tells who sent this message.""" - return self.route.msg_from - - @property - def msg_to(self): - """Labels for the consumer to filter its subscribed messages.""" - return self.route.msg_to - - def set_role(self, v): - """Set the message's meta info indicating the sender.""" - self.set_meta(MESSAGE_META_ROLE, v) - - def set_from(self, v): - """Set the message's meta info indicating the sender.""" - self.route.set_from(v) - - def set_to(self, tags: Set): - """Set the message's meta info indicating the sender.""" - self.route.set_to(tags) - - def add_to(self, tag: str): - """Add a subscription label for the recipients.""" - self.route.add_to(tag) - - def contain_any(self, tags: Set): - """Return true if any input label exists in the message's subscription labels.""" - return self.route.contain_any(tags) + new_val = any_to_str(val) + elif key == MESSAGE_ROUTE_FROM: + new_val = any_to_str(val) + elif key == MESSAGE_ROUTE_TO: + new_val = any_to_str_set(val) + else: + new_val = val + super().__setattr__(key, new_val) def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) @@ -226,13 +99,13 @@ class Message(BaseModel): return self.json(exclude_none=True) @staticmethod - def load(v): + def load(val): """Convert the json string to object.""" try: - d = json.loads(v) + d = json.loads(val) return Message(**d) except JSONDecodeError as err: - logger.error(f"parse json failed: {v}, error:{err}") + logger.error(f"parse json failed: {val}, error:{err}") return None @@ -327,31 +200,3 @@ class MessageQueue: logger.warning(f"JSON load failed: {v}, error:{e}") return q - - -if __name__ == "__main__": - m = Message("a", role="v1") - m.set_role("v2") - v = m.dump() - m = Message.load(v) - m.cause_by = "Message" - m.cause_by = Routes - m.cause_by = Routes() - m.content = "b" - - test_content = "test_message" - msgs = [ - UserMessage(test_content), - SystemMessage(test_content), - AIMessage(test_content), - Message(test_content, role="QA"), - ] - logger.info(msgs) - - jsons = [ - UserMessage(test_content).dump(), - SystemMessage(test_content).dump(), - AIMessage(test_content).dump(), - Message(test_content, role="QA").dump(), - ] - logger.info(jsons) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index b372f0d8d..cd42b1412 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -325,3 +325,14 @@ def any_to_str(val) -> str: return get_object_name(val) return get_class_name(val) + + +def any_to_str_set(val) -> set: + """Convert any type to string set.""" + res = set() + if isinstance(val, dict) or isinstance(val, list) or isinstance(val, set) or isinstance(val, tuple): + for i in val: + res.add(any_to_str(i)) + else: + res.add(any_to_str(val)) + return res diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 05127362b..51ebd5baa 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -12,7 +12,7 @@ import json import pytest from metagpt.actions import Action -from metagpt.schema import AIMessage, Message, Routes, SystemMessage, UserMessage +from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage from metagpt.utils.common import get_class_name @@ -37,20 +37,19 @@ def test_message(): d = json.loads(v) assert d assert d.get("content") == "a" - assert d.get("meta_info") == {"role": "v1"} - m.set_role("v2") + assert d.get("role") == "v1" + m.role = "v2" v = m.dump() assert v m = Message.load(v) assert m.content == "a" assert m.role == "v2" - m = Message("a", role="b", cause_by="c", x="d") + m = Message("a", role="b", cause_by="c", x="d", send_to="c") assert m.content == "a" assert m.role == "b" - assert m.contain_any({"c"}) + assert m.send_to == {"c"} assert m.cause_by == "c" - assert m.get_meta("x") == "d" m.cause_by = "Message" assert m.cause_by == "Message" @@ -64,18 +63,11 @@ def test_message(): @pytest.mark.asyncio def test_routes(): - route = Routes() - route.set_from("a") - assert route.msg_from == "a" - route.add_to("b") - assert route.msg_to == {"b"} - route.add_to("c") - assert route.msg_to == {"b", "c"} - route.set_to({"e", "f"}) - assert route.msg_to == {"e", "f"} - assert route.contain_any({"e"}) - assert route.contain_any({"f"}) - assert not route.contain_any({"a"}) + m = Message("a", role="b", cause_by="c", x="d", send_to="c") + m.send_to = "b" + assert m.send_to == {"b"} + m.send_to = {"e", Action} + assert m.send_to == {"e", get_class_name(Action)} if __name__ == "__main__":