refactor: rename is_recipient

This commit is contained in:
莘权 马 2023-11-08 13:42:08 +08:00
parent 93ebe8c103
commit af4c87e123
8 changed files with 19 additions and 19 deletions

View file

@ -60,7 +60,7 @@ class Trump(Role):
async def _observe(self) -> int:
await super()._observe()
# accept messages sent (from opponent) to self, disregard own messages from the last round
self._rc.news = [msg for msg in self._rc.news if msg.is_recipient({self.name})]
self._rc.news = [msg for msg in self._rc.news if msg.contain_any({self.name})]
return len(self._rc.news)
async def _act(self) -> Message:
@ -103,7 +103,7 @@ class Biden(Role):
# accept the very first human instruction (the debate topic) or messages sent (from opponent) to self,
# disregard own messages from the last round
message_filter = {BossRequirement, self.name}
self._rc.news = [msg for msg in self._rc.news if msg.is_recipient(message_filter)]
self._rc.news = [msg for msg in self._rc.news if msg.contain_any(message_filter)]
return len(self._rc.news)
async def _act(self) -> Message:

View file

@ -59,7 +59,7 @@ class WriteCode(Action):
return
message_filter = {WriteDesign}
design = [i for i in context if i.is_recipient(message_filter)][0]
design = [i for i in context if i.contain_any(message_filter)][0]
ws_name = CodeParser.parse_str(block="Python package name", text=design.content)
ws_path = WORKSPACE_ROOT / ws_name

View file

@ -63,7 +63,7 @@ class Environment(BaseModel):
found = False
# According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
for obj, subscribed_tags in self.consumers.items():
if message.is_recipient(subscribed_tags):
if message.contain_any(subscribed_tags):
obj.put_message(message)
found = True
if not found:

View file

@ -3,7 +3,7 @@
"""
@Desc : the implement of Long-term memory
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116:
1. Replace code related to message filtering with the `Message.is_recipient` function.
1. Replace code related to message filtering with the `Message.contain_any` function.
"""
from metagpt.logs import logger
@ -40,7 +40,7 @@ class LongTermMemory(Memory):
def add(self, message: Message):
super(LongTermMemory, self).add(message)
if message.is_recipient(self.rc.watch) and not self.msg_from_recover:
if message.contain_any(self.rc.watch) and not self.msg_from_recover:
# currently, only add role's watching messages to its memory_storage
# and ignore adding messages from recover repeatedly
self.memory_storage.add(message)

View file

@ -233,7 +233,7 @@ class Engineer(Role):
# Parse task lists
message_filter = {WriteTasks}
for message in self._rc.news:
if not message.is_recipient(message_filter):
if not message.contain_any(message_filter):
continue
self.todos = self.parse_tasks(message)

View file

@ -154,7 +154,7 @@ class QaEngineer(Role):
async def _observe(self) -> int:
await super()._observe()
self._rc.news = [
msg for msg in self._rc.news if msg.is_recipient({self.profile})
msg for msg in self._rc.news if msg.contain_any({self.profile})
] # only relevant msgs count as observed news
return len(self._rc.news)
@ -174,13 +174,13 @@ class QaEngineer(Role):
for msg in self._rc.news:
# Decide what to do based on observed msg type, currently defined by human,
# might potentially be moved to _think, that is, let the agent decides for itself
if msg.is_recipient(code_filters):
if msg.contain_any(code_filters):
# engineer wrote a code, time to write a test for it
await self._write_test(msg)
elif msg.is_recipient(test_filters):
elif msg.contain_any(test_filters):
# I wrote or debugged my test code, time to run it
await self._run_code(msg)
elif msg.is_recipient(run_filters):
elif msg.contain_any(run_filters):
# I ran my test code, time to fix bugs, if any
await self._debug_error(msg)
self.test_round += 1

View file

@ -65,8 +65,8 @@ class Routes(BaseModel):
self.routes.append({})
return self.routes[0]
def is_recipient(self, tags: Set) -> bool:
"""Check if it is the message recipient."""
def contain_any(self, tags: Set) -> bool:
"""Check if this object contains these tags."""
route = self._get_route()
to_tags = route.get(MESSAGE_ROUTE_TO)
if not to_tags:
@ -206,9 +206,9 @@ class Message(BaseModel):
"""Add a subscription label for the recipients."""
self.route.add_to(tag)
def is_recipient(self, tags: Set):
def contain_any(self, tags: Set):
"""Return true if any input label exists in the message's subscription labels."""
return self.route.is_recipient(tags)
return self.route.contain_any(tags)
def __str__(self):
# prefix = '-'.join([self.role, str(self.cause_by)])

View file

@ -48,7 +48,7 @@ def test_message():
m = Message("a", role="b", cause_by="c", x="d")
assert m.content == "a"
assert m.role == "b"
assert m.is_recipient({"c"})
assert m.contain_any({"c"})
assert m.cause_by == "c"
assert m.get_meta("x") == "d"
@ -73,9 +73,9 @@ def test_routes():
assert route.msg_to == {"b", "c"}
route.set_to({"e", "f"})
assert route.msg_to == {"e", "f"}
assert route.is_recipient({"e"})
assert route.is_recipient({"f"})
assert not route.is_recipient({"a"})
assert route.contain_any({"e"})
assert route.contain_any({"f"})
assert not route.contain_any({"a"})
if __name__ == "__main__":