From 57a1fac357db6d099bded9afa5098e35e3bce2c8 Mon Sep 17 00:00:00 2001 From: betterwang Date: Thu, 7 Mar 2024 19:05:46 +0800 Subject: [PATCH] memory_storage use rag_engine --- metagpt/document_store/base_store.py | 6 +- metagpt/document_store/faiss_store.py | 8 +-- metagpt/memory/longterm_memory.py | 12 ++-- metagpt/memory/memory_storage.py | 61 +++++++------------- metagpt/schema.py | 8 +++ tests/metagpt/memory/test_longterm_memory.py | 15 ++--- tests/metagpt/memory/test_memory_storage.py | 22 +++---- 7 files changed, 59 insertions(+), 73 deletions(-) diff --git a/metagpt/document_store/base_store.py b/metagpt/document_store/base_store.py index 129da4f4f..6aafc57bb 100644 --- a/metagpt/document_store/base_store.py +++ b/metagpt/document_store/base_store.py @@ -38,9 +38,9 @@ class LocalStore(BaseStore, ABC): if not self.store: self.store = self.write() - def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"): - index_file = self.cache_dir / "default__vector_store.json" - store_file = self.cache_dir / "docstore.json" + def _get_index_and_store_fname(self, index_ext=".json", docstore_ext=".json"): + index_file = self.cache_dir / "default__vector_store" / index_ext + store_file = self.cache_dir / "docstore" / docstore_ext return index_file, store_file @abstractmethod diff --git a/metagpt/document_store/faiss_store.py b/metagpt/document_store/faiss_store.py index 25d1211b3..b196bef27 100644 --- a/metagpt/document_store/faiss_store.py +++ b/metagpt/document_store/faiss_store.py @@ -33,7 +33,7 @@ class FaissStore(LocalStore): super().__init__(raw_data, cache_dir) def _load(self) -> Optional["VectorStoreIndex"]: - index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # FAISS using .faiss + index_file, store_file = self._get_index_and_store_fname() if not (index_file.exists() and store_file.exists()): logger.info("Missing at least one of index_file/store_file, load failed and return None") @@ -46,12 +46,8 @@ class FaissStore(LocalStore): def _write(self, docs: list[str], metadatas: list[dict[str, Any]]) -> VectorStoreIndex: assert len(docs) == len(metadatas) - texts_embeds = self.embedding.get_text_embedding_batch(docs) documents = [Document(text=doc, metadata=metadatas[idx]) for idx, doc in enumerate(docs)] - [TextNode(embedding=embed, metadata=metadatas[idx]) for idx, embed in enumerate(texts_embeds)] - # doc_store = SimpleDocumentStore() - # doc_store.add_documents(nodes) vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536)) storage_context = StorageContext.from_defaults(vector_store=vector_store) index = VectorStoreIndex.from_documents( @@ -90,7 +86,7 @@ class FaissStore(LocalStore): def add(self, texts: list[str], *args, **kwargs) -> list[str]: """FIXME: Currently, the store is not updated after adding.""" texts_embeds = self.embedding.get_text_embedding_batch(texts) - nodes = [TextNode(embedding=embed) for embed in texts_embeds] + nodes = [TextNode(text=texts[idx], embedding=embed) for idx, embed in enumerate(texts_embeds)] self.store.insert_nodes(nodes) return [] diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index 5a139a93b..e90413085 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -29,16 +29,14 @@ class LongTermMemory(Memory): msg_from_recover: bool = False def recover_memory(self, role_id: str, rc: RoleContext): - messages = self.memory_storage.recover_memory(role_id) + self.memory_storage.recover_memory(role_id) self.rc = rc if not self.memory_storage.is_initialized: logger.warning(f"It may the first time to run Agent {role_id}, the long-term memory is empty") else: - logger.warning( - f"Agent {role_id} has existing memory storage with {len(messages)} messages " f"and has recovered them." - ) + logger.warning(f"Role {role_id} has existing memory storage and has recovered them.") self.msg_from_recover = True - self.add_batch(messages) + # self.add_batch(messages) # TODO no need self.msg_from_recover = False def add(self, message: Message): @@ -49,7 +47,7 @@ class LongTermMemory(Memory): # and ignore adding messages from recover repeatedly self.memory_storage.add(message) - def find_news(self, observed: list[Message], k=0) -> list[Message]: + async def find_news(self, observed: list[Message], k=0) -> list[Message]: """ find news (previously unseen messages) from the the most recent k memories, from all memories when k=0 1. find the short-term memory(stm) news @@ -63,7 +61,7 @@ class LongTermMemory(Memory): ltm_news: list[Message] = [] for mem in stm_news: # filter out messages similar to those seen previously in ltm, only keep fresh news - mem_searched = self.memory_storage.search_dissimilar(mem) + mem_searched = await self.memory_storage.search_dissimilar(mem) if len(mem_searched) > 0: ltm_news.append(mem) return ltm_news[-k:] diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py index 756508f05..b7d49e1c3 100644 --- a/metagpt/memory/memory_storage.py +++ b/metagpt/memory/memory_storage.py @@ -7,16 +7,16 @@ import shutil from pathlib import Path from llama_index.core.embeddings import BaseEmbedding -from llama_index.core.schema import QueryBundle, TextNode from metagpt.const import DATA_PATH, MEM_TTL -from metagpt.document_store.faiss_store import FaissStore from metagpt.logs import logger +from metagpt.rag.engines.simple import SimpleEngine +from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig from metagpt.schema import Message from metagpt.utils.embedding import get_embedding -class MemoryStorage(FaissStore): +class MemoryStorage(object): """ The memory storage with Faiss as ANN search engine """ @@ -29,6 +29,8 @@ class MemoryStorage(FaissStore): self._initialized: bool = False self.embedding = embedding or get_embedding() + self.faiss_engine = None + @property def is_initialized(self) -> bool: return self._initialized @@ -39,56 +41,35 @@ class MemoryStorage(FaissStore): self.role_mem_path.mkdir(parents=True, exist_ok=True) self.cache_dir = self.role_mem_path - self.store = self._load() - messages = [] - if not self.store: - # TODO init `self.store` under here with raw faiss api instead under `add` - pass + if self.role_mem_path.joinpath("default__vector_store.json").exists(): + self.faiss_engine = SimpleEngine.from_index( + index_config=[FAISSIndexConfig(persist_path=self.cache_dir)], + retriever_configs=[FAISSRetrieverConfig()], + embed_model=self.embedding, + ) else: - for _id, document in self.store.docstore._dict.items(): - messages.append(Message(**document.metadata.get("obj_dict"))) - self._initialized = True - - return messages + self.faiss_engine = SimpleEngine.from_objs( + objs=[], retriever_configs=[FAISSRetrieverConfig()], embed_model=self.embedding + ) + self._initialized = True def add(self, message: Message) -> bool: """add message into memory storage""" - docs = [message.content] - metadatas = [{"obj_dict": message.model_dump()}] - if not self.store: - # init Faiss - self.store = self._write(docs, metadatas) - self._initialized = True - else: - text_node = TextNode(text=message.content, metadata=metadatas[0]) - self.store.insert_nodes([text_node]) - self.persist() - logger.info(f"Agent {self.role_id}'s memory_storage add a message") + self.faiss_engine.add_objs([message]) + logger.info(f"Role {self.role_id}'s memory_storage add a message") - def search_dissimilar(self, message: Message, k=4) -> list[Message]: + async def search_dissimilar(self, message: Message, k=4) -> list[Message]: """search for dissimilar messages""" - if not self.store: - return [] - - retriever = self.store.as_retriever(similarity_top_k=k) - resp = retriever.retrieve( - QueryBundle(query_str=message.content, embedding=self.embedding.get_text_embedding(message.content)) - ) # filter the result which score is smaller than the threshold filtered_resp = [] + resp = await self.faiss_engine.aretrieve(message.content) for item in resp: - # the smaller score means more similar relation - + print(" item.score ", item.score, item) if item.score < self.threshold: continue - # convert search result into Memory - metadata = item.node.metadata - new_mem = Message(**metadata.get("obj_dict", {})) - filtered_resp.append(new_mem) + filtered_resp.append(item.metadata.get("obj")) return filtered_resp def clean(self): shutil.rmtree(self.cache_dir, ignore_errors=True) - - self.store = None self._initialized = False diff --git a/metagpt/schema.py b/metagpt/schema.py index 7906febe0..45c7480f9 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -233,6 +233,10 @@ class Message(BaseModel): def check_send_to(cls, send_to: Any) -> set: return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL}) + @field_serializer("send_to", mode="plain") + def ser_send_to(self, send_to: set) -> list: + return list(send_to) + @field_serializer("instruct_content", mode="plain") def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]: ic_dict = None @@ -276,6 +280,10 @@ class Message(BaseModel): def __repr__(self): return self.__str__() + def rag_key(self) -> str: + """For search""" + return self.content + def to_dict(self) -> dict: """Return a dict containing `role` and `content` for the LLM call.l""" return {"role": self.role, "content": self.content} diff --git a/tests/metagpt/memory/test_longterm_memory.py b/tests/metagpt/memory/test_longterm_memory.py index f7e652758..d9eb5e67f 100644 --- a/tests/metagpt/memory/test_longterm_memory.py +++ b/tests/metagpt/memory/test_longterm_memory.py @@ -17,7 +17,8 @@ from tests.metagpt.memory.mock_text_embed import ( ) -def test_ltm_search(mocker): +@pytest.mark.asyncio +async def test_ltm_search(mocker): mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) role_id = "UTUserLtm(Product Manager)" @@ -31,36 +32,36 @@ def test_ltm_search(mocker): idea = text_embed_arr[0].get("text", "Write a cli snake game") message = Message(role="User", content=idea, cause_by=UserRequirement) - news = ltm.find_news([message]) + news = await ltm.find_news([message]) assert len(news) == 1 ltm.add(message) sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake") sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement) - news = ltm.find_news([sim_message]) + news = await ltm.find_news([sim_message]) assert len(news) == 0 ltm.add(sim_message) new_idea = text_embed_arr[2].get("text", "Write a 2048 web game") new_message = Message(role="User", content=new_idea, cause_by=UserRequirement) - news = ltm.find_news([new_message]) + news = await ltm.find_news([new_message]) assert len(news) == 1 ltm.add(new_message) # restore from local index ltm_new = LongTermMemory() ltm_new.recover_memory(role_id, rc) - news = ltm_new.find_news([message]) + news = await ltm_new.find_news([message]) assert len(news) == 0 ltm_new.recover_memory(role_id, rc) - news = ltm_new.find_news([sim_message]) + news = await ltm_new.find_news([sim_message]) assert len(news) == 0 new_idea = text_embed_arr[3].get("text", "Write a Battle City") new_message = Message(role="User", content=new_idea, cause_by=UserRequirement) - news = ltm_new.find_news([new_message]) + news = await ltm_new.find_news([new_message]) assert len(news) == 1 ltm_new.clear() diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index 28a73276b..35f2309c5 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -8,6 +8,8 @@ import shutil from pathlib import Path from typing import List +import pytest + from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.action_node import ActionNode from metagpt.const import DATA_PATH @@ -19,7 +21,8 @@ from tests.metagpt.memory.mock_text_embed import ( ) -def test_idea_message(mocker): +@pytest.mark.asyncio +async def test_idea_message(mocker): mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) idea = text_embed_arr[0].get("text", "Write a cli snake game") @@ -29,27 +32,27 @@ def test_idea_message(mocker): shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True) memory_storage: MemoryStorage = MemoryStorage() - messages = memory_storage.recover_memory(role_id) - assert len(messages) == 0 + memory_storage.recover_memory(role_id) memory_storage.add(message) assert memory_storage.is_initialized is True sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake") sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement) - new_messages = memory_storage.search_dissimilar(sim_message) + new_messages = await memory_storage.search_dissimilar(sim_message) assert len(new_messages) == 0 # similar, return [] new_idea = text_embed_arr[2].get("text", "Write a 2048 web game") new_message = Message(role="User", content=new_idea, cause_by=UserRequirement) - new_messages = memory_storage.search_dissimilar(new_message) + new_messages = await memory_storage.search_dissimilar(new_message) assert new_messages[0].content == message.content memory_storage.clean() assert memory_storage.is_initialized is False -def test_actionout_message(mocker): +@pytest.mark.asyncio +async def test_actionout_message(mocker): mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} @@ -67,22 +70,21 @@ def test_actionout_message(mocker): shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True) memory_storage: MemoryStorage = MemoryStorage() - messages = memory_storage.recover_memory(role_id) - assert len(messages) == 0 + memory_storage.recover_memory(role_id) memory_storage.add(message) assert memory_storage.is_initialized is True sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game") sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD) - new_messages = memory_storage.search_dissimilar(sim_message) + new_messages = await memory_storage.search_dissimilar(sim_message) assert len(new_messages) == 0 # similar, return [] new_conent = text_embed_arr[6].get( "text", "Incorporate basic features of a snake game such as scoring and increasing difficulty" ) new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD) - new_messages = memory_storage.search_dissimilar(new_message) + new_messages = await memory_storage.search_dissimilar(new_message) assert new_messages[0].content == message.content memory_storage.clean()