diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index d964cc1dc..e9891ed00 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -129,11 +129,3 @@ class Memory(BaseModel): continue rsp += self.index[action] return rsp - - def get_by_tags(self, tags: list) -> list[Message]: - """Return messages with specified tags""" - result = [] - for m in self.storage: - if m.is_contain_tags(tags): - result.append(m) - return result diff --git a/tests/metagpt/memory/test_memory.py b/tests/metagpt/memory/test_memory.py new file mode 100644 index 000000000..36d7ad488 --- /dev/null +++ b/tests/metagpt/memory/test_memory.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of Memory + +from metagpt.actions import UserRequirement +from metagpt.memory.memory import Memory +from metagpt.schema import Message + + +def test_memory(): + memory = Memory() + + message1 = Message(content="test message1", role="user1") + message2 = Message(content="test message2", role="user2") + message3 = Message(content="test message3", role="user1") + memory.add(message1) + assert memory.count() == 1 + + memory.delete_newest() + assert memory.count() == 0 + + memory.add_batch([message1, message2]) + assert memory.count() == 2 + assert len(memory.index.get(message1.cause_by)) == 2 + + messages = memory.get_by_role("user1") + assert messages[0].content == message1.content + + messages = memory.get_by_content("test message") + assert len(messages) == 2 + + messages = memory.get_by_action(UserRequirement) + assert len(messages) == 2 + + messages = memory.get_by_actions([UserRequirement]) + assert len(messages) == 2 + + messages = memory.try_remember("test message") + assert len(messages) == 2 + + messages = memory.get(k=1) + assert len(messages) == 1 + + messages = memory.get(k=5) + assert len(messages) == 2 + + messages = memory.find_news([message3]) + assert len(messages) == 1 + + memory.delete(message1) + assert memory.count() == 1 + messages = memory.get_by_role("user2") + assert messages[0].content == message2.content + + memory.clear() + assert memory.count() == 0 + assert len(memory.index) == 0