diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index 926c845cb..8521c046b 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -43,17 +43,24 @@ class LongTermMemory(Memory): # and ignore adding messages from recover repeatedly self.memory_storage.add(message) - def remember(self, observed: list[Message], k=0) -> list[Message]: - """remember the most similar k memories from observed Messages, return all when k=0""" + def remember(self, observed: list[Message], k=10) -> list[Message]: + """ + remember the most similar k memories from observed Messages, return all when k=0 + 1. remember the short-term memory(stm) news + 2. integrate the stm news with ltm(long-term memory) news + """ + stm_news = super(LongTermMemory, self).remember(observed) # shot-term memory news if not self.memory_storage.is_initialized: - # memory_storage hasn't initialized, use default `remember` - return super(LongTermMemory, self).remember(observed) + # memory_storage hasn't initialized, use default `remember` to get stm_news + return stm_news - news: list[Message] = [] - for mem in observed: + ltm_news: list[Message] = [] + for mem in stm_news: + # integrate stm & ltm mem_searched = self.memory_storage.search(mem) - news.extend(mem_searched) - return news[-k:] + if len(mem_searched) > 0: + ltm_news.append(mem) + return ltm_news[-k:] def delete(self, message: Message): super(LongTermMemory, self).delete(message) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index a96aaf1be..5d3b736a3 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -63,7 +63,7 @@ class Memory: """Return the most recent k memories, return all when k=0""" return self.storage[-k:] - def remember(self, observed: list[Message], k=0) -> list[Message]: + def remember(self, observed: list[Message], k=10) -> list[Message]: """remember the most recent k memories from observed Messages, return all when k=0""" already_observed = self.get(k) news: list[Message] = [] diff --git a/tests/metagpt/memory/test_longterm_memory.py b/tests/metagpt/memory/test_longterm_memory.py new file mode 100644 index 000000000..62a3a2361 --- /dev/null +++ b/tests/metagpt/memory/test_longterm_memory.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of `metagpt/memory/longterm_memory.py` + +from metagpt.config import CONFIG +from metagpt.schema import Message +from metagpt.actions import BossRequirement +from metagpt.roles.role import RoleContext +from metagpt.memory import LongTermMemory + + +def test_ltm_search(): + assert hasattr(CONFIG, "long_term_memory") is True + openai_api_key = CONFIG.openai_api_key + assert len(openai_api_key) > 20 + + role_id = 'UTUserLtm(Product Manager)' + rc = RoleContext(watch=[BossRequirement]) + ltm = LongTermMemory() + ltm.recover_memory(role_id, rc) + + idea = 'Write a cli snake game' + message = Message(role='BOSS', content=idea, cause_by=BossRequirement) + news = ltm.remember([message]) + assert len(news) == 1 + ltm.add(message) + + sim_idea = 'Write a game of cli snake' + sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement) + news = ltm.remember([sim_message]) + assert len(news) == 0 + ltm.add(sim_message) + + new_idea = 'Write a 2048 web game' + new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement) + news = ltm.remember([new_message]) + assert len(news) == 1 + ltm.add(new_message) + + # restore from local index + ltm_new = LongTermMemory() + ltm_new.recover_memory(role_id, rc) + news = ltm_new.remember([message]) + assert len(news) == 0 + + ltm_new.recover_memory(role_id, rc) + news = ltm_new.remember([sim_message]) + assert len(news) == 0 + + new_idea = 'Write a Battle City' + new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement) + news = ltm_new.remember([new_message]) + assert len(news) == 1 + + ltm_new.clear() diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index 4e59fb003..6bb3e8f1d 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -2,8 +2,7 @@ # -*- coding: utf-8 -*- # @Desc : the unittests of metagpt/memory/memory_storage.py -from typing import List, Tuple -import pytest +from typing import List from metagpt.memory.memory_storage import MemoryStorage from metagpt.schema import Message