refactor: subscription -> address

This commit is contained in:
莘权 马 2024-01-08 11:28:24 +08:00
parent fe07b37836
commit 3f2859b15d
4 changed files with 20 additions and 19 deletions

1
.gitignore vendored
View file

@ -175,3 +175,4 @@ htmlcov.*
*.pkl
*-structure.csv
*-structure.json
*.dot

View file

@ -21,7 +21,7 @@ from metagpt.context import Context
from metagpt.logs import logger
from metagpt.roles.role import Role
from metagpt.schema import Message
from metagpt.utils.common import is_subscribed, read_json_file, write_json_file
from metagpt.utils.common import is_send_to, read_json_file, write_json_file
class Environment(BaseModel):
@ -111,8 +111,8 @@ class Environment(BaseModel):
logger.debug(f"publish_message: {message.dump()}")
found = False
# According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
for role, subscription in self.members.items():
if is_subscribed(message, subscription):
for role, addrs in self.member_addrs.items():
if is_send_to(message, addrs):
role.put_message(message)
found = True
if not found:
@ -157,13 +157,13 @@ class Environment(BaseModel):
return False
return True
def get_subscription(self, obj):
"""Get the labels for messages to be consumed by the object."""
return self.members.get(obj, {})
def get_addresses(self, obj):
"""Get the addresses of the object."""
return self.member_addrs.get(obj, {})
def set_subscription(self, obj, tags):
"""Set the labels for message to be consumed by the object"""
self.members[obj] = tags
def set_addresses(self, obj, addresses):
"""Set the addresses of the object"""
self.member_addrs[obj] = addresses
def archive(self, auto_archive=True):
if auto_archive and self.context.git_repo:

View file

@ -145,7 +145,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
states: list[str] = []
actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)
rc: RoleContext = Field(default_factory=RoleContext)
subscription: set[str] = set()
addresses: set[str] = set()
# builtin variables
recovered: bool = False # to tag if a recovered role
@ -200,9 +200,9 @@ class Role(SerializationMixin, is_polymorphic_base=True):
return self.context.config.project_path
@model_validator(mode="after")
def check_subscription(self):
if not self.subscription:
self.subscription = {any_to_str(self), self.name} if self.name else {any_to_str(self)}
def check_addresses(self):
if not self.addresses:
self.addresses = {any_to_str(self), self.name} if self.name else {any_to_str(self)}
return self
def __init__(self, **data: Any):
@ -322,14 +322,14 @@ class Role(SerializationMixin, is_polymorphic_base=True):
def is_watch(self, caused_by: str):
return caused_by in self.rc.watch
def subscribe(self, tags: Set[str]):
def set_addresses(self, addresses: Set[str]):
"""Used to receive Messages with certain tags from the environment. Message will be put into personal message
buffer to be further processed in _observe. By default, a Role subscribes Messages with a tag of its own name
or profile.
"""
self.subscription = tags
self.addresses = addresses
if self.rc.env: # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
self.rc.env.set_subscription(self, self.subscription)
self.rc.env.set_addresses(self, self.addresses)
def _set_state(self, state: int):
"""Update the current state."""
@ -342,7 +342,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
messages by observing."""
self.rc.env = env
if env:
env.set_subscription(self, self.subscription)
env.set_addresses(self, self.addresses)
self.refresh_system_message() # add env message to system message
@property

View file

@ -381,12 +381,12 @@ def any_to_str_set(val) -> set:
return res
def is_subscribed(message: "Message", tags: set):
def is_send_to(message: "Message", addresses: set):
"""Return whether it's consumer"""
if MESSAGE_ROUTE_TO_ALL in message.send_to:
return True
for i in tags:
for i in addresses:
if i in message.send_to:
return True
return False