integrate ltm with stm if LONG_TERM_MEMORY is True

This commit is contained in:
betterwang 2023-07-24 14:55:25 +08:00
parent f6b55c8b3b
commit cddb3aa072
4 changed files with 72 additions and 11 deletions

View file

@ -43,17 +43,24 @@ class LongTermMemory(Memory):
# and ignore adding messages from recover repeatedly
self.memory_storage.add(message)
def remember(self, observed: list[Message], k=0) -> list[Message]:
"""remember the most similar k memories from observed Messages, return all when k=0"""
def remember(self, observed: list[Message], k=10) -> list[Message]:
"""
remember the most similar k memories from observed Messages, return all when k=0
1. remember the short-term memory(stm) news
2. integrate the stm news with ltm(long-term memory) news
"""
stm_news = super(LongTermMemory, self).remember(observed) # shot-term memory news
if not self.memory_storage.is_initialized:
# memory_storage hasn't initialized, use default `remember`
return super(LongTermMemory, self).remember(observed)
# memory_storage hasn't initialized, use default `remember` to get stm_news
return stm_news
news: list[Message] = []
for mem in observed:
ltm_news: list[Message] = []
for mem in stm_news:
# integrate stm & ltm
mem_searched = self.memory_storage.search(mem)
news.extend(mem_searched)
return news[-k:]
if len(mem_searched) > 0:
ltm_news.append(mem)
return ltm_news[-k:]
def delete(self, message: Message):
super(LongTermMemory, self).delete(message)

View file

@ -63,7 +63,7 @@ class Memory:
"""Return the most recent k memories, return all when k=0"""
return self.storage[-k:]
def remember(self, observed: list[Message], k=0) -> list[Message]:
def remember(self, observed: list[Message], k=10) -> list[Message]:
"""remember the most recent k memories from observed Messages, return all when k=0"""
already_observed = self.get(k)
news: list[Message] = []

View file

@ -0,0 +1,55 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of `metagpt/memory/longterm_memory.py`
from metagpt.config import CONFIG
from metagpt.schema import Message
from metagpt.actions import BossRequirement
from metagpt.roles.role import RoleContext
from metagpt.memory import LongTermMemory
def test_ltm_search():
assert hasattr(CONFIG, "long_term_memory") is True
openai_api_key = CONFIG.openai_api_key
assert len(openai_api_key) > 20
role_id = 'UTUserLtm(Product Manager)'
rc = RoleContext(watch=[BossRequirement])
ltm = LongTermMemory()
ltm.recover_memory(role_id, rc)
idea = 'Write a cli snake game'
message = Message(role='BOSS', content=idea, cause_by=BossRequirement)
news = ltm.remember([message])
assert len(news) == 1
ltm.add(message)
sim_idea = 'Write a game of cli snake'
sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement)
news = ltm.remember([sim_message])
assert len(news) == 0
ltm.add(sim_message)
new_idea = 'Write a 2048 web game'
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
news = ltm.remember([new_message])
assert len(news) == 1
ltm.add(new_message)
# restore from local index
ltm_new = LongTermMemory()
ltm_new.recover_memory(role_id, rc)
news = ltm_new.remember([message])
assert len(news) == 0
ltm_new.recover_memory(role_id, rc)
news = ltm_new.remember([sim_message])
assert len(news) == 0
new_idea = 'Write a Battle City'
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
news = ltm_new.remember([new_message])
assert len(news) == 1
ltm_new.clear()

View file

@ -2,8 +2,7 @@
# -*- coding: utf-8 -*-
# @Desc : the unittests of metagpt/memory/memory_storage.py
from typing import List, Tuple
import pytest
from typing import List
from metagpt.memory.memory_storage import MemoryStorage
from metagpt.schema import Message