mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
refactor: Simplify the Message class.
This commit is contained in:
parent
af4c87e123
commit
c18bc7c876
4 changed files with 63 additions and 215 deletions
|
|
@ -44,8 +44,8 @@ SKILL_DIRECTORY = PROJECT_ROOT / "metagpt/skills"
|
|||
|
||||
MEM_TTL = 24 * 30 * 3600
|
||||
|
||||
MESSAGE_ROUTE_FROM = "msg_from"
|
||||
MESSAGE_ROUTE_TO = "msg_to"
|
||||
MESSAGE_ROUTE_FROM = "sent_from"
|
||||
MESSAGE_ROUTE_TO = "send_to"
|
||||
MESSAGE_ROUTE_CAUSE_BY = "cause_by"
|
||||
MESSAGE_META_ROLE = "role"
|
||||
MESSAGE_ROUTE_TO_ALL = "<all>"
|
||||
|
|
|
|||
|
|
@ -13,19 +13,18 @@ import asyncio
|
|||
import json
|
||||
from asyncio import Queue, QueueEmpty, wait_for
|
||||
from json import JSONDecodeError
|
||||
from typing import Dict, List, Set, TypedDict
|
||||
from typing import List, Set, TypedDict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.const import (
|
||||
MESSAGE_META_ROLE,
|
||||
MESSAGE_ROUTE_CAUSE_BY,
|
||||
MESSAGE_ROUTE_FROM,
|
||||
MESSAGE_ROUTE_TO,
|
||||
MESSAGE_ROUTE_TO_ALL,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set
|
||||
|
||||
|
||||
class RawMessage(TypedDict):
|
||||
|
|
@ -33,182 +32,56 @@ class RawMessage(TypedDict):
|
|||
role: str
|
||||
|
||||
|
||||
class Routes(BaseModel):
|
||||
"""Responsible for managing routing information for the Message class."""
|
||||
|
||||
routes: List[Dict] = Field(default_factory=list)
|
||||
|
||||
def set_from(self, value):
|
||||
"""Set the label of the message sender."""
|
||||
route = self._get_route()
|
||||
route[MESSAGE_ROUTE_FROM] = value
|
||||
|
||||
def set_to(self, tags: Set):
|
||||
"""Set the labels of the message recipient."""
|
||||
route = self._get_route()
|
||||
if tags:
|
||||
route[MESSAGE_ROUTE_TO] = tags
|
||||
return
|
||||
|
||||
if MESSAGE_ROUTE_TO in route:
|
||||
del route[MESSAGE_ROUTE_TO]
|
||||
|
||||
def add_to(self, tag: str):
|
||||
"""Add a label of the message recipient."""
|
||||
route = self._get_route()
|
||||
tags = route.get(MESSAGE_ROUTE_TO, set())
|
||||
tags.add(tag)
|
||||
route[MESSAGE_ROUTE_TO] = tags
|
||||
|
||||
def _get_route(self) -> Dict:
|
||||
if not self.routes:
|
||||
self.routes.append({})
|
||||
return self.routes[0]
|
||||
|
||||
def contain_any(self, tags: Set) -> bool:
|
||||
"""Check if this object contains these tags."""
|
||||
route = self._get_route()
|
||||
to_tags = route.get(MESSAGE_ROUTE_TO)
|
||||
if not to_tags:
|
||||
return True
|
||||
|
||||
if MESSAGE_ROUTE_TO_ALL in to_tags:
|
||||
return True
|
||||
for k in tags:
|
||||
if k in to_tags:
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def msg_from(self):
|
||||
"""Message route info tells who sent this message."""
|
||||
route = self._get_route()
|
||||
return route.get(MESSAGE_ROUTE_FROM)
|
||||
|
||||
@property
|
||||
def msg_to(self):
|
||||
"""Labels for the consumer to filter its subscribed messages."""
|
||||
route = self._get_route()
|
||||
return route.get(MESSAGE_ROUTE_TO)
|
||||
|
||||
def replace(self, old_val, new_val):
|
||||
"""Replace old value with new value"""
|
||||
route = self._get_route()
|
||||
tags = route.get(MESSAGE_ROUTE_TO, set())
|
||||
tags.discard(old_val)
|
||||
tags.add(new_val)
|
||||
route[MESSAGE_ROUTE_TO] = tags
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""list[<role>: <content>]"""
|
||||
|
||||
content: str
|
||||
instruct_content: BaseModel = None
|
||||
meta_info: Dict = Field(default_factory=dict)
|
||||
route: Routes = Field(default_factory=Routes)
|
||||
instruct_content: BaseModel = Field(default=None)
|
||||
role: str = "user" # system / user / assistant
|
||||
cause_by: str = ""
|
||||
sent_from: str = ""
|
||||
send_to: Set = Field(default_factory=set)
|
||||
|
||||
def __init__(self, content, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
content,
|
||||
instruct_content=None,
|
||||
role="user",
|
||||
cause_by="",
|
||||
sent_from="",
|
||||
send_to=MESSAGE_ROUTE_TO_ALL,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters not listed below will be stored as meta info, including custom parameters.
|
||||
:param content: Message content.
|
||||
:param instruct_content: Message content struct.
|
||||
:param meta_info: Message meta info.
|
||||
:param route: Message route configuration.
|
||||
:param msg_from: Message route info tells who sent this message.
|
||||
:param msg_to: Labels for the consumer to filter its subscribed messages.
|
||||
:param cause_by: Labels for the consumer to filter its subscribed messages, also serving as meta info.
|
||||
:param cause_by: Message producer
|
||||
:param sent_from: Message route info tells who sent this message.
|
||||
:param send_to: Labels for the consumer to filter its subscribed messages.
|
||||
:param role: Message meta info tells who sent this message.
|
||||
"""
|
||||
super(Message, self).__init__(
|
||||
content=content or kwargs.get("content"),
|
||||
instruct_content=kwargs.get("instruct_content"),
|
||||
meta_info=kwargs.get("meta_info", {}),
|
||||
route=kwargs.get("route", Routes()),
|
||||
super().__init__(
|
||||
content=content,
|
||||
instruct_content=instruct_content,
|
||||
role=role,
|
||||
cause_by=any_to_str(cause_by),
|
||||
sent_from=any_to_str(sent_from),
|
||||
send_to=any_to_str_set(send_to),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attribute_names = Message.__annotations__.keys()
|
||||
for k, v in kwargs.items():
|
||||
if k in attribute_names:
|
||||
continue
|
||||
if k == MESSAGE_ROUTE_FROM:
|
||||
self.set_from(any_to_str(v))
|
||||
continue
|
||||
if k == MESSAGE_ROUTE_CAUSE_BY:
|
||||
self.set_cause_by(v)
|
||||
continue
|
||||
if k == MESSAGE_ROUTE_TO:
|
||||
if isinstance(v, tuple) or isinstance(v, list) or isinstance(v, set):
|
||||
for i in v:
|
||||
self.add_to(any_to_str(i))
|
||||
else:
|
||||
self.add_to(any_to_str(v))
|
||||
continue
|
||||
self.meta_info[k] = v
|
||||
|
||||
def get_meta(self, key):
|
||||
"""Get meta info"""
|
||||
return self.meta_info.get(key)
|
||||
|
||||
def set_meta(self, key, value):
|
||||
"""Set meta info"""
|
||||
self.meta_info[key] = value
|
||||
|
||||
@property
|
||||
def role(self):
|
||||
"""Message meta info tells who sent this message."""
|
||||
return self.get_meta(MESSAGE_META_ROLE)
|
||||
|
||||
@property
|
||||
def cause_by(self):
|
||||
"""Labels for the consumer to filter its subscribed messages, also serving as meta info."""
|
||||
return self.get_meta(MESSAGE_ROUTE_CAUSE_BY)
|
||||
|
||||
def __setattr__(self, key, val):
|
||||
"""Override `@property.setter`"""
|
||||
"""Override `@property.setter`, convert non-string parameters into string parameters."""
|
||||
if key == MESSAGE_ROUTE_CAUSE_BY:
|
||||
self.set_cause_by(val)
|
||||
return
|
||||
if key == MESSAGE_ROUTE_FROM:
|
||||
self.set_from(any_to_str(val))
|
||||
super().__setattr__(key, val)
|
||||
|
||||
def set_cause_by(self, val):
|
||||
"""Update the value of `cause_by` in the `meta_info` and `routes` attributes."""
|
||||
old_value = self.get_meta(MESSAGE_ROUTE_CAUSE_BY)
|
||||
new_value = any_to_str(val)
|
||||
self.set_meta(MESSAGE_ROUTE_CAUSE_BY, new_value)
|
||||
self.route.replace(old_value, new_value)
|
||||
|
||||
@property
|
||||
def msg_from(self):
|
||||
"""Message route info tells who sent this message."""
|
||||
return self.route.msg_from
|
||||
|
||||
@property
|
||||
def msg_to(self):
|
||||
"""Labels for the consumer to filter its subscribed messages."""
|
||||
return self.route.msg_to
|
||||
|
||||
def set_role(self, v):
|
||||
"""Set the message's meta info indicating the sender."""
|
||||
self.set_meta(MESSAGE_META_ROLE, v)
|
||||
|
||||
def set_from(self, v):
|
||||
"""Set the message's meta info indicating the sender."""
|
||||
self.route.set_from(v)
|
||||
|
||||
def set_to(self, tags: Set):
|
||||
"""Set the message's meta info indicating the sender."""
|
||||
self.route.set_to(tags)
|
||||
|
||||
def add_to(self, tag: str):
|
||||
"""Add a subscription label for the recipients."""
|
||||
self.route.add_to(tag)
|
||||
|
||||
def contain_any(self, tags: Set):
|
||||
"""Return true if any input label exists in the message's subscription labels."""
|
||||
return self.route.contain_any(tags)
|
||||
new_val = any_to_str(val)
|
||||
elif key == MESSAGE_ROUTE_FROM:
|
||||
new_val = any_to_str(val)
|
||||
elif key == MESSAGE_ROUTE_TO:
|
||||
new_val = any_to_str_set(val)
|
||||
else:
|
||||
new_val = val
|
||||
super().__setattr__(key, new_val)
|
||||
|
||||
def __str__(self):
|
||||
# prefix = '-'.join([self.role, str(self.cause_by)])
|
||||
|
|
@ -226,13 +99,13 @@ class Message(BaseModel):
|
|||
return self.json(exclude_none=True)
|
||||
|
||||
@staticmethod
|
||||
def load(v):
|
||||
def load(val):
|
||||
"""Convert the json string to object."""
|
||||
try:
|
||||
d = json.loads(v)
|
||||
d = json.loads(val)
|
||||
return Message(**d)
|
||||
except JSONDecodeError as err:
|
||||
logger.error(f"parse json failed: {v}, error:{err}")
|
||||
logger.error(f"parse json failed: {val}, error:{err}")
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -327,31 +200,3 @@ class MessageQueue:
|
|||
logger.warning(f"JSON load failed: {v}, error:{e}")
|
||||
|
||||
return q
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
m = Message("a", role="v1")
|
||||
m.set_role("v2")
|
||||
v = m.dump()
|
||||
m = Message.load(v)
|
||||
m.cause_by = "Message"
|
||||
m.cause_by = Routes
|
||||
m.cause_by = Routes()
|
||||
m.content = "b"
|
||||
|
||||
test_content = "test_message"
|
||||
msgs = [
|
||||
UserMessage(test_content),
|
||||
SystemMessage(test_content),
|
||||
AIMessage(test_content),
|
||||
Message(test_content, role="QA"),
|
||||
]
|
||||
logger.info(msgs)
|
||||
|
||||
jsons = [
|
||||
UserMessage(test_content).dump(),
|
||||
SystemMessage(test_content).dump(),
|
||||
AIMessage(test_content).dump(),
|
||||
Message(test_content, role="QA").dump(),
|
||||
]
|
||||
logger.info(jsons)
|
||||
|
|
|
|||
|
|
@ -325,3 +325,14 @@ def any_to_str(val) -> str:
|
|||
return get_object_name(val)
|
||||
|
||||
return get_class_name(val)
|
||||
|
||||
|
||||
def any_to_str_set(val) -> set:
|
||||
"""Convert any type to string set."""
|
||||
res = set()
|
||||
if isinstance(val, dict) or isinstance(val, list) or isinstance(val, set) or isinstance(val, tuple):
|
||||
for i in val:
|
||||
res.add(any_to_str(i))
|
||||
else:
|
||||
res.add(any_to_str(val))
|
||||
return res
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import json
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.schema import AIMessage, Message, Routes, SystemMessage, UserMessage
|
||||
from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage
|
||||
from metagpt.utils.common import get_class_name
|
||||
|
||||
|
||||
|
|
@ -37,20 +37,19 @@ def test_message():
|
|||
d = json.loads(v)
|
||||
assert d
|
||||
assert d.get("content") == "a"
|
||||
assert d.get("meta_info") == {"role": "v1"}
|
||||
m.set_role("v2")
|
||||
assert d.get("role") == "v1"
|
||||
m.role = "v2"
|
||||
v = m.dump()
|
||||
assert v
|
||||
m = Message.load(v)
|
||||
assert m.content == "a"
|
||||
assert m.role == "v2"
|
||||
|
||||
m = Message("a", role="b", cause_by="c", x="d")
|
||||
m = Message("a", role="b", cause_by="c", x="d", send_to="c")
|
||||
assert m.content == "a"
|
||||
assert m.role == "b"
|
||||
assert m.contain_any({"c"})
|
||||
assert m.send_to == {"c"}
|
||||
assert m.cause_by == "c"
|
||||
assert m.get_meta("x") == "d"
|
||||
|
||||
m.cause_by = "Message"
|
||||
assert m.cause_by == "Message"
|
||||
|
|
@ -64,18 +63,11 @@ def test_message():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
def test_routes():
|
||||
route = Routes()
|
||||
route.set_from("a")
|
||||
assert route.msg_from == "a"
|
||||
route.add_to("b")
|
||||
assert route.msg_to == {"b"}
|
||||
route.add_to("c")
|
||||
assert route.msg_to == {"b", "c"}
|
||||
route.set_to({"e", "f"})
|
||||
assert route.msg_to == {"e", "f"}
|
||||
assert route.contain_any({"e"})
|
||||
assert route.contain_any({"f"})
|
||||
assert not route.contain_any({"a"})
|
||||
m = Message("a", role="b", cause_by="c", x="d", send_to="c")
|
||||
m.send_to = "b"
|
||||
assert m.send_to == {"b"}
|
||||
m.send_to = {"e", Action}
|
||||
assert m.send_to == {"e", get_class_name(Action)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue