From 780caf011d9b2147455983ce6f7a912016f9f979 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Mon, 25 Dec 2023 12:42:23 +0800 Subject: [PATCH] =?UTF-8?q?fixbug:=20=E5=9F=BA=E4=BA=8E=E5=85=A8memory?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=AD=98=E5=82=A8=E7=9A=84=E6=B5=81=E7=A8=8B?= =?UTF-8?q?=E5=BC=82=E5=B8=B8=E6=81=A2=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- metagpt/const.py | 3 ++ metagpt/memory/memory.py | 6 ++++ metagpt/roles/role.py | 35 +++++++++++++++---- .../serialize_deserialize/test_role.py | 6 +++- .../test_serdeser_base.py | 1 + 5 files changed, 44 insertions(+), 7 deletions(-) diff --git a/metagpt/const.py b/metagpt/const.py index 7de360daf..012c84542 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -121,3 +121,6 @@ BASE64_FORMAT = "base64" # REDIS REDIS_KEY = "REDIS_KEY" LLM_API_TIMEOUT = 300 + +# Message id +IGNORED_MESSAGE_ID = "0" diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index d964cc1dc..8761af83c 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -12,6 +12,7 @@ from typing import Iterable, Set from pydantic import BaseModel, Field +from metagpt.const import IGNORED_MESSAGE_ID from metagpt.schema import Message from metagpt.utils.common import ( any_to_str, @@ -26,6 +27,7 @@ class Memory(BaseModel): storage: list[Message] = [] index: dict[str, list[Message]] = Field(default_factory=defaultdict(list)) + ignore_id: bool = False def __init__(self, **kwargs): index = kwargs.get("index", {}) @@ -54,6 +56,8 @@ class Memory(BaseModel): def add(self, message: Message): """Add a new message to storage, while updating the index""" + if self.ignore_id: + message.id = IGNORED_MESSAGE_ID if message in self.storage: return self.storage.append(message) @@ -84,6 +88,8 @@ class Memory(BaseModel): def delete(self, message: Message): """Delete the specified message from storage, while updating the index""" + if self.ignore_id: + message.id = IGNORED_MESSAGE_ID self.storage.remove(message) if message.cause_by and message in self.index[message.cause_by]: self.index[message.cause_by].remove(message) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 992ff83d2..23a7faaae 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -376,7 +376,7 @@ class Role(BaseModel): if self.recovered and self._rc.state >= 0: self._set_state(self._rc.state) # action to run from recovered state - self.recovered = False # avoid max_react_loop out of work + self.set_recovered(False) # avoid max_react_loop out of work return True prompt = self._get_prefix() @@ -433,17 +433,17 @@ class Role(BaseModel): async def _observe(self, ignore_memory=False) -> int: """Prepare new messages for processing from the message buffer and other sources.""" # Read unprocessed messages from the msg buffer. - news = self._rc.msg_buffer.pop_all() + news = [] if self.recovered: news = [self.latest_observed_msg] if self.latest_observed_msg else [] - else: - self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg - + if not news: + news = self._rc.msg_buffer.pop_all() # Store the read messages in your own memory to prevent duplicate processing. old_messages = [] if ignore_memory else self._rc.memory.get() self._rc.memory.add_batch(news) # Filter out messages of interest. - self._rc.news = self._find_news(news, old_messages) + self._rc.news = [n for n in news if n.cause_by in self._rc.watch and n not in old_messages] + self.latest_observed_msg = self._rc.news[-1] if self._rc.news else None # record the latest observed msg # Design Rules: # If you need to further categorize Message objects, you can do so using the Message.set_meta function. @@ -453,6 +453,29 @@ class Role(BaseModel): logger.debug(f"{self._setting} observed: {news_text}") return len(self._rc.news) + # async def _observe(self, ignore_memory=False) -> int: + # """Prepare new messages for processing from the message buffer and other sources.""" + # # Read unprocessed messages from the msg buffer. + # news = self._rc.msg_buffer.pop_all() + # if self.recovered: + # news = [self.latest_observed_msg] if self.latest_observed_msg else [] + # else: + # self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg + # + # # Store the read messages in your own memory to prevent duplicate processing. + # old_messages = [] if ignore_memory else self._rc.memory.get() + # self._rc.memory.add_batch(news) + # # Filter out messages of interest. + # self._rc.news = self._find_news(news, old_messages) + # + # # Design Rules: + # # If you need to further categorize Message objects, you can do so using the Message.set_meta function. + # # msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer. + # news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news] + # if news_text: + # logger.debug(f"{self._setting} observed: {news_text}") + # return len(self._rc.news) + def publish_message(self, msg): """If the role belongs to env, then the role's messages will be broadcast to env""" if not msg: diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index 72da8a6fc..343f01ace 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -93,4 +93,8 @@ async def test_role_serdeser_interrupt(): assert new_role_a._rc.state == 1 with pytest.raises(Exception): - await role_c.run(with_message=Message(content="demo", cause_by=UserRequirement)) + await new_role_a.run(with_message=Message(content="demo", cause_by=UserRequirement)) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index a66813489..23c14e851 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -85,3 +85,4 @@ class RoleC(Role): self._init_actions([ActionOK, ActionRaise]) self._watch([UserRequirement]) self._rc.react_mode = RoleReactMode.BY_ORDER + self._rc.memory.ignore_id = True