integrate ltm with stm if LONG_TERM_MEMORY is True

This commit is contained in:
betterwang 2023-07-24 14:55:25 +08:00
parent f6b55c8b3b
commit cddb3aa072
4 changed files with 72 additions and 11 deletions

View file

@ -43,17 +43,24 @@ class LongTermMemory(Memory):
# and ignore adding messages from recover repeatedly
self.memory_storage.add(message)
def remember(self, observed: list[Message], k=0) -> list[Message]:
"""remember the most similar k memories from observed Messages, return all when k=0"""
def remember(self, observed: list[Message], k=10) -> list[Message]:
"""
remember the most similar k memories from observed Messages, return all when k=0
1. remember the short-term memory(stm) news
2. integrate the stm news with ltm(long-term memory) news
"""
stm_news = super(LongTermMemory, self).remember(observed) # shot-term memory news
if not self.memory_storage.is_initialized:
# memory_storage hasn't initialized, use default `remember`
return super(LongTermMemory, self).remember(observed)
# memory_storage hasn't initialized, use default `remember` to get stm_news
return stm_news
news: list[Message] = []
for mem in observed:
ltm_news: list[Message] = []
for mem in stm_news:
# integrate stm & ltm
mem_searched = self.memory_storage.search(mem)
news.extend(mem_searched)
return news[-k:]
if len(mem_searched) > 0:
ltm_news.append(mem)
return ltm_news[-k:]
def delete(self, message: Message):
super(LongTermMemory, self).delete(message)

View file

@ -63,7 +63,7 @@ class Memory:
"""Return the most recent k memories, return all when k=0"""
return self.storage[-k:]
def remember(self, observed: list[Message], k=0) -> list[Message]:
def remember(self, observed: list[Message], k=10) -> list[Message]:
"""remember the most recent k memories from observed Messages, return all when k=0"""
already_observed = self.get(k)
news: list[Message] = []