Merge branch 'dev' of https://github.com/geekan/MetaGPT into geekan/dev

This commit is contained in:
莘权 马 2023-12-25 16:15:04 +08:00
commit fe1d60f111
5 changed files with 97 additions and 24 deletions

View file

@ -135,11 +135,3 @@ class Memory(BaseModel):
continue
rsp += self.index[action]
return rsp
def get_by_tags(self, tags: list) -> list[Message]:
"""Return messages with specified tags"""
result = []
for m in self.storage:
if m.is_contain_tags(tags):
result.append(m)
return result

View file

@ -6,9 +6,11 @@
"""
from pathlib import Path
from typing import List
from typing import Optional
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from langchain_core.embeddings import Embeddings
from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
@ -22,20 +24,30 @@ class MemoryStorage(FaissStore):
The memory storage with Faiss as ANN search engine
"""
def __init__(self, mem_ttl: int = MEM_TTL):
def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None):
self.role_id: str = None
self.role_mem_path: str = None
self.mem_ttl: int = mem_ttl # later use
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
self._initialized: bool = False
self.embedding = embedding or OpenAIEmbeddings()
self.store: FAISS = None # Faiss engine
@property
def is_initialized(self) -> bool:
return self._initialized
def recover_memory(self, role_id: str) -> List[Message]:
def _load(self) -> Optional["FaissStore"]:
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
if not (index_file.exists() and store_file.exists()):
logger.info("Missing at least one of index_file/store_file, load failed and return None")
return None
return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id)
def recover_memory(self, role_id: str) -> list[Message]:
self.role_id = role_id
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
self.role_mem_path.mkdir(parents=True, exist_ok=True)
@ -52,16 +64,16 @@ class MemoryStorage(FaissStore):
return messages
def _get_index_and_store_fname(self):
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
if not self.role_mem_path:
logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory")
return None, None
index_fpath = Path(self.role_mem_path / f"{self.role_id}.index")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}.pkl")
index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
return index_fpath, storage_fpath
def persist(self):
super().persist()
self.store.save_local(self.role_mem_path, self.role_id)
logger.debug(f"Agent {self.role_id} persist memory into local")
def add(self, message: Message) -> bool:
@ -77,7 +89,7 @@ class MemoryStorage(FaissStore):
self.persist()
logger.info(f"Agent {self.role_id}'s memory_storage add a message")
def search_dissimilar(self, message: Message, k=4) -> List[Message]:
def search_dissimilar(self, message: Message, k=4) -> list[Message]:
"""search for dissimilar messages"""
if not self.store:
return []