update ltm unittest

This commit is contained in:
better629 2023-12-25 14:38:20 +08:00
parent 94a0699ec4
commit 6a65639cd7
2 changed files with 25 additions and 11 deletions

View file

@ -6,9 +6,11 @@
"""
from pathlib import Path
from typing import List
from typing import Optional
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from langchain_core.embeddings import Embeddings
from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
@ -22,20 +24,30 @@ class MemoryStorage(FaissStore):
The memory storage with Faiss as ANN search engine
"""
def __init__(self, mem_ttl: int = MEM_TTL):
def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None):
self.role_id: str = None
self.role_mem_path: str = None
self.mem_ttl: int = mem_ttl # later use
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
self._initialized: bool = False
self.embedding = embedding or OpenAIEmbeddings()
self.store: FAISS = None # Faiss engine
@property
def is_initialized(self) -> bool:
return self._initialized
def recover_memory(self, role_id: str) -> List[Message]:
def _load(self) -> Optional["FaissStore"]:
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
if not (index_file.exists() and store_file.exists()):
logger.info("Missing at least one of index_file/store_file, load failed and return None")
return None
return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id)
def recover_memory(self, role_id: str) -> list[Message]:
self.role_id = role_id
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
self.role_mem_path.mkdir(parents=True, exist_ok=True)
@ -52,16 +64,16 @@ class MemoryStorage(FaissStore):
return messages
def _get_index_and_store_fname(self):
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
if not self.role_mem_path:
logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory")
return None, None
index_fpath = Path(self.role_mem_path / f"{self.role_id}.index")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}.pkl")
index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
return index_fpath, storage_fpath
def persist(self):
super().persist()
self.store.save_local(self.role_mem_path, self.role_id)
logger.debug(f"Agent {self.role_id} persist memory into local")
def add(self, message: Message) -> bool:
@ -77,7 +89,7 @@ class MemoryStorage(FaissStore):
self.persist()
logger.info(f"Agent {self.role_id}'s memory_storage add a message")
def search_dissimilar(self, message: Message, k=4) -> List[Message]:
def search_dissimilar(self, message: Message, k=4) -> list[Message]:
"""search for dissimilar messages"""
if not self.store:
return []

View file

@ -5,6 +5,8 @@
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
import os
from metagpt.actions import UserRequirement
from metagpt.config import CONFIG
from metagpt.memory.longterm_memory import LongTermMemory
@ -14,11 +16,11 @@ from metagpt.schema import Message
def test_ltm_search():
assert hasattr(CONFIG, "long_term_memory") is True
openai_api_key = CONFIG.openai_api_key
assert len(openai_api_key) > 20
os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key)
assert len(CONFIG.openai_api_key) > 20
role_id = "UTUserLtm(Product Manager)"
rc = RoleContext(watch=[UserRequirement])
rc = RoleContext(watch={"metagpt.actions.add_requirement.UserRequirement"})
ltm = LongTermMemory()
ltm.recover_memory(role_id, rc)