This commit is contained in:
betterwang 2024-03-07 23:02:06 +08:00
parent 716cb1a0c5
commit fb6b9e2928
5 changed files with 34 additions and 44 deletions

View file

@ -61,8 +61,8 @@ class LongTermMemory(Memory):
ltm_news: list[Message] = []
for mem in stm_news:
# filter out messages similar to those seen previously in ltm, only keep fresh news
mem_searched = await self.memory_storage.search_dissimilar(mem)
if len(mem_searched) > 0:
mem_searched = await self.memory_storage.search_similar(mem)
if len(mem_searched) == 0:
ltm_news.append(mem)
return ltm_news[-k:]

View file

@ -58,16 +58,14 @@ class MemoryStorage(object):
self.faiss_engine.add_objs([message])
logger.info(f"Role {self.role_id}'s memory_storage add a message")
async def search_dissimilar(self, message: Message, k=4) -> list[Message]:
"""search for dissimilar messages"""
async def search_similar(self, message: Message, k=4) -> list[Message]:
"""search for similar messages"""
# filter the result which score is smaller than the threshold
filtered_resp = []
resp = await self.faiss_engine.aretrieve(message.content)
for item in resp:
print(" item.score ", item.score, item)
if item.score < self.threshold:
continue
filtered_resp.append(item.metadata.get("obj"))
filtered_resp.append(item.metadata.get("obj"))
return filtered_resp
def clean(self):
@ -76,4 +74,4 @@ class MemoryStorage(object):
def persit(self):
if self.faiss_engine:
self.faiss_engine.index.storage_context.persist(self.cache_dir)
self.faiss_engine.retriever._index.storage_context.persist(self.cache_dir)

View file

@ -29,10 +29,8 @@ class RAGIndexFactory(ConfigFactory):
embed_model = self.extract_embed_model(config, **kwargs)
vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
storage_context = StorageContext.from_defaults(
vector_store=vector_store, persist_dir=config.persist_path, embed_mode=embed_model
)
index = load_index_from_storage(storage_context=storage_context)
storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
return index
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex: