From 10467379105a0f526ca75471942bd0c48ea2e512 Mon Sep 17 00:00:00 2001 From: betterwang Date: Thu, 7 Mar 2024 17:26:14 +0800 Subject: [PATCH] update memory_storage --- metagpt/memory/memory2.py | 25 ---------------- metagpt/memory/memory_network.py | 20 ------------- metagpt/memory/memory_storage.py | 49 +++++++++++++------------------- 3 files changed, 20 insertions(+), 74 deletions(-) delete mode 100644 metagpt/memory/memory2.py delete mode 100644 metagpt/memory/memory_network.py diff --git a/metagpt/memory/memory2.py b/metagpt/memory/memory2.py deleted file mode 100644 index 74f848278..000000000 --- a/metagpt/memory/memory2.py +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Desc : memory mechanism including store/retrieval/rank - -from typing import Optional - -from pydantic import BaseModel, Field - -from metagpt.memory.memory_network import MemoryNetwork -from metagpt.memory.schema import MemoryNode -from metagpt.schema import Message - - -class Memory(BaseModel): - mem_network: Optional[MemoryNetwork] = Field( - default_factory=MemoryNetwork, description="the network to store memory" - ) - - def add_msg(self, message: Message): - mem_node = MemoryNode.create_mem_node_from_message(message) - self.mem_network.add_mem(mem_node) - - def add_msgs(self, messages: list[Message]): - for msg in messages: - self.add_msg(msg) diff --git a/metagpt/memory/memory_network.py b/metagpt/memory/memory_network.py deleted file mode 100644 index f8f2244ed..000000000 --- a/metagpt/memory/memory_network.py +++ /dev/null @@ -1,20 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Desc : the memory network to store memory segment - -from pydantic import BaseModel, Field - -from metagpt.memory.schema import MemoryNode, MemorySegment - - -class MemoryNetwork(BaseModel): - mem_seg: MemorySegment = Field( - default_factory=MemorySegment, description="the memory segment to store memory nodes" - ) - - def add_mem(self, mem_node: MemoryNode): - self.mem_seg.add_mem_node(mem_node) - - def add_mems(self, mem_nodes: list[MemoryNode]): - for mem_node in mem_nodes: - self.add_mem(mem_node) diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py index f096cec72..756508f05 100644 --- a/metagpt/memory/memory_storage.py +++ b/metagpt/memory/memory_storage.py @@ -3,16 +3,17 @@ """ @Desc : the implement of memory storage """ - +import shutil from pathlib import Path -from llama_index.embeddings import BaseEmbedding +from llama_index.core.embeddings import BaseEmbedding +from llama_index.core.schema import QueryBundle, TextNode from metagpt.const import DATA_PATH, MEM_TTL from metagpt.document_store.faiss_store import FaissStore from metagpt.logs import logger from metagpt.schema import Message -from metagpt.utils.serialize import deserialize_message, serialize_message +from metagpt.utils.embedding import get_embedding class MemoryStorage(FaissStore): @@ -26,6 +27,7 @@ class MemoryStorage(FaissStore): self.mem_ttl: int = mem_ttl # later use self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories self._initialized: bool = False + self.embedding = embedding or get_embedding() @property def is_initialized(self) -> bool: @@ -35,6 +37,7 @@ class MemoryStorage(FaissStore): self.role_id = role_id self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/") self.role_mem_path.mkdir(parents=True, exist_ok=True) + self.cache_dir = self.role_mem_path self.store = self._load() messages = [] @@ -43,34 +46,22 @@ class MemoryStorage(FaissStore): pass else: for _id, document in self.store.docstore._dict.items(): - messages.append(deserialize_message(document.metadata.get("message_ser"))) + messages.append(Message(**document.metadata.get("obj_dict"))) self._initialized = True return messages - def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"): - if not self.role_mem_path: - logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory") - return None, None - index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}") - storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}") - self.cache_dir = Path(self.role_mem_path).joinpath(self.role_id) - return index_fpath, storage_fpath - - def persist(self): - self.store.save_local(self.role_mem_path, self.role_id) - logger.debug(f"Agent {self.role_id} persist memory into local") - def add(self, message: Message) -> bool: """add message into memory storage""" docs = [message.content] - metadatas = [{"message_ser": serialize_message(message)}] + metadatas = [{"obj_dict": message.model_dump()}] if not self.store: # init Faiss self.store = self._write(docs, metadatas) self._initialized = True else: - self.store.add_texts(texts=docs, metadatas=metadatas) + text_node = TextNode(text=message.content, metadata=metadatas[0]) + self.store.insert_nodes([text_node]) self.persist() logger.info(f"Agent {self.role_id}'s memory_storage add a message") @@ -79,25 +70,25 @@ class MemoryStorage(FaissStore): if not self.store: return [] - resp = self.store.similarity_search_with_score(query=message.content, k=k) + retriever = self.store.as_retriever(similarity_top_k=k) + resp = retriever.retrieve( + QueryBundle(query_str=message.content, embedding=self.embedding.get_text_embedding(message.content)) + ) # filter the result which score is smaller than the threshold filtered_resp = [] - for item, score in resp: + for item in resp: # the smaller score means more similar relation - if score < self.threshold: + + if item.score < self.threshold: continue # convert search result into Memory - metadata = item.metadata - new_mem = deserialize_message(metadata.get("message_ser")) + metadata = item.node.metadata + new_mem = Message(**metadata.get("obj_dict", {})) filtered_resp.append(new_mem) return filtered_resp def clean(self): - index_fpath, storage_fpath = self._get_index_and_store_fname() - if index_fpath and index_fpath.exists(): - index_fpath.unlink(missing_ok=True) - if storage_fpath and storage_fpath.exists(): - storage_fpath.unlink(missing_ok=True) + shutil.rmtree(self.cache_dir, ignore_errors=True) self.store = None self._initialized = False