From 6a65639cd7790f55dab143886f60aec8e0a032c1 Mon Sep 17 00:00:00 2001 From: better629 Date: Mon, 25 Dec 2023 14:38:20 +0800 Subject: [PATCH] update ltm unittest --- metagpt/memory/memory_storage.py | 28 ++++++++++++++------ tests/metagpt/memory/test_longterm_memory.py | 8 +++--- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py index 3017c23ad..1850e0ea0 100644 --- a/metagpt/memory/memory_storage.py +++ b/metagpt/memory/memory_storage.py @@ -6,9 +6,11 @@ """ from pathlib import Path -from typing import List +from typing import Optional +from langchain.embeddings import OpenAIEmbeddings from langchain.vectorstores.faiss import FAISS +from langchain_core.embeddings import Embeddings from metagpt.const import DATA_PATH, MEM_TTL from metagpt.document_store.faiss_store import FaissStore @@ -22,20 +24,30 @@ class MemoryStorage(FaissStore): The memory storage with Faiss as ANN search engine """ - def __init__(self, mem_ttl: int = MEM_TTL): + def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None): self.role_id: str = None self.role_mem_path: str = None self.mem_ttl: int = mem_ttl # later use self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories self._initialized: bool = False + self.embedding = embedding or OpenAIEmbeddings() self.store: FAISS = None # Faiss engine @property def is_initialized(self) -> bool: return self._initialized - def recover_memory(self, role_id: str) -> List[Message]: + def _load(self) -> Optional["FaissStore"]: + index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss + + if not (index_file.exists() and store_file.exists()): + logger.info("Missing at least one of index_file/store_file, load failed and return None") + return None + + return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id) + + def recover_memory(self, role_id: str) -> list[Message]: self.role_id = role_id self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/") self.role_mem_path.mkdir(parents=True, exist_ok=True) @@ -52,16 +64,16 @@ class MemoryStorage(FaissStore): return messages - def _get_index_and_store_fname(self): + def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"): if not self.role_mem_path: logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory") return None, None - index_fpath = Path(self.role_mem_path / f"{self.role_id}.index") - storage_fpath = Path(self.role_mem_path / f"{self.role_id}.pkl") + index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}") + storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}") return index_fpath, storage_fpath def persist(self): - super().persist() + self.store.save_local(self.role_mem_path, self.role_id) logger.debug(f"Agent {self.role_id} persist memory into local") def add(self, message: Message) -> bool: @@ -77,7 +89,7 @@ class MemoryStorage(FaissStore): self.persist() logger.info(f"Agent {self.role_id}'s memory_storage add a message") - def search_dissimilar(self, message: Message, k=4) -> List[Message]: + def search_dissimilar(self, message: Message, k=4) -> list[Message]: """search for dissimilar messages""" if not self.store: return [] diff --git a/tests/metagpt/memory/test_longterm_memory.py b/tests/metagpt/memory/test_longterm_memory.py index 1f07d74e3..ac33552b3 100644 --- a/tests/metagpt/memory/test_longterm_memory.py +++ b/tests/metagpt/memory/test_longterm_memory.py @@ -5,6 +5,8 @@ @Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. """ +import os + from metagpt.actions import UserRequirement from metagpt.config import CONFIG from metagpt.memory.longterm_memory import LongTermMemory @@ -14,11 +16,11 @@ from metagpt.schema import Message def test_ltm_search(): assert hasattr(CONFIG, "long_term_memory") is True - openai_api_key = CONFIG.openai_api_key - assert len(openai_api_key) > 20 + os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key) + assert len(CONFIG.openai_api_key) > 20 role_id = "UTUserLtm(Product Manager)" - rc = RoleContext(watch=[UserRequirement]) + rc = RoleContext(watch={"metagpt.actions.add_requirement.UserRequirement"}) ltm = LongTermMemory() ltm.recover_memory(role_id, rc)