From f14fee9b0d0d2d35736ab3fbd14f7896a990209c Mon Sep 17 00:00:00 2001 From: betterwang Date: Thu, 7 Mar 2024 22:07:04 +0800 Subject: [PATCH] memory_storage use rag_pipeline --- metagpt/memory/longterm_memory.py | 5 ++- metagpt/memory/memory_storage.py | 6 ++- metagpt/rag/schema.py | 1 + .../document_store/test_faiss_store.py | 29 +++++++++++-- tests/metagpt/memory/mock_text_embed.py | 42 +++++++++++++++++++ tests/metagpt/memory/test_longterm_memory.py | 18 +++++--- tests/metagpt/memory/test_memory_storage.py | 29 +++++++++---- 7 files changed, 111 insertions(+), 19 deletions(-) create mode 100644 tests/metagpt/memory/mock_text_embed.py diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index e90413085..27a737e6c 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -32,7 +32,7 @@ class LongTermMemory(Memory): self.memory_storage.recover_memory(role_id) self.rc = rc if not self.memory_storage.is_initialized: - logger.warning(f"It may the first time to run Agent {role_id}, the long-term memory is empty") + logger.warning(f"It may the first time to run Role {role_id}, the long-term memory is empty") else: logger.warning(f"Role {role_id} has existing memory storage and has recovered them.") self.msg_from_recover = True @@ -66,6 +66,9 @@ class LongTermMemory(Memory): ltm_news.append(mem) return ltm_news[-k:] + def persit(self): + self.memory_storage.persit() + def delete(self, message: Message): super().delete(message) # TODO delete message in memory_storage diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py index b7d49e1c3..706e75c5a 100644 --- a/metagpt/memory/memory_storage.py +++ b/metagpt/memory/memory_storage.py @@ -43,7 +43,7 @@ class MemoryStorage(object): if self.role_mem_path.joinpath("default__vector_store.json").exists(): self.faiss_engine = SimpleEngine.from_index( - index_config=[FAISSIndexConfig(persist_path=self.cache_dir)], + 
index_config=FAISSIndexConfig(persist_path=self.cache_dir), retriever_configs=[FAISSRetrieverConfig()], embed_model=self.embedding, ) @@ -73,3 +73,7 @@ class MemoryStorage(object): def clean(self): shutil.rmtree(self.cache_dir, ignore_errors=True) self._initialized = False + + def persit(self): + if self.faiss_engine: + self.faiss_engine.index.storage_context.persist(self.cache_dir) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 9657ae846..8f5828233 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -104,6 +104,7 @@ class ObjectNode(TextNode): def __init__(self, **kwargs): super().__init__(**kwargs) self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys()) + self.excluded_embed_metadata_keys = self.excluded_llm_metadata_keys @staticmethod def get_obj_metadata(obj: RAGObject) -> dict: diff --git a/tests/metagpt/document_store/test_faiss_store.py b/tests/metagpt/document_store/test_faiss_store.py index f7032be9f..6443a179c 100644 --- a/tests/metagpt/document_store/test_faiss_store.py +++ b/tests/metagpt/document_store/test_faiss_store.py @@ -6,6 +6,8 @@ @File : test_faiss_store.py """ + +import numpy as np import pytest from metagpt.const import EXAMPLE_PATH @@ -14,8 +16,23 @@ from metagpt.logs import logger from metagpt.roles import Sales +def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]: + num = len(texts) + embeds = np.random.randint(1, 100, size=(num, 1536)) # 1536: openai embedding dim + embeds = (embeds - embeds.mean(axis=0)) / (embeds.std(axis=0) + 1e-8) # epsilon: std is 0 when num == 1, avoid NaN vectors + return embeds.tolist() + + +def mock_openai_embed_document(self, text: str) -> list[float]: + embeds = mock_openai_embed_documents(self, [text]) + return embeds[0] + + +@pytest.mark.asyncio -async def test_search_json(): +async def test_search_json(mocker): + mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents) + 
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document) + store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.json") role = Sales(profile="Sales", store=store) query = "Which facial cleanser is good for oily skin?" @@ -24,7 +41,10 @@ async def test_search_json(): @pytest.mark.asyncio -async def test_search_xlsx(): +async def test_search_xlsx(mocker): + mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents) + mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document) + store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question") role = Sales(profile="Sales", store=store) query = "Which facial cleanser is good for oily skin?" @@ -33,7 +53,10 @@ async def test_search_xlsx(): @pytest.mark.asyncio -async def test_write(): +async def test_write(mocker): + mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents) + mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document) + store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question") _faiss_store = store.write() assert _faiss_store.storage_context.docstore diff --git a/tests/metagpt/memory/mock_text_embed.py b/tests/metagpt/memory/mock_text_embed.py new file mode 100644 index 000000000..2f3ffc434 --- /dev/null +++ b/tests/metagpt/memory/mock_text_embed.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import numpy as np + +dim = 1536 # openai embedding dim +embed_zeros_arrr = np.zeros(shape=[1, dim]).tolist() +embed_ones_arrr = np.ones(shape=[1, dim]).tolist() + +text_embed_arr = [ + {"text": "Write a cli snake game", "embed": embed_zeros_arrr}, # mock data, same as below + {"text": "Write a game of cli 
snake", "embed": embed_zeros_arrr}, + {"text": "Write a 2048 web game", "embed": embed_ones_arrr}, + {"text": "Write a Battle City", "embed": embed_ones_arrr}, + { + "text": "The user has requested the creation of a command-line interface (CLI) snake game", + "embed": embed_zeros_arrr, + }, + {"text": "The request is command-line interface (CLI) snake game", "embed": embed_zeros_arrr}, + { + "text": "Incorporate basic features of a snake game such as scoring and increasing difficulty", + "embed": embed_ones_arrr, + }, +] + + text_idx_dict = {item["text"]: idx for idx, item in enumerate(text_embed_arr)} + + + def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]: + # one deterministic embedding per input text (previously only texts[0] was looked up) + idxs = [text_idx_dict[text] for text in texts] + return [text_embed_arr[i]["embed"][0] for i in idxs] + + + def mock_openai_embed_document(self, text: str) -> list[float]: + embeds = mock_openai_embed_documents(self, [text]) + return embeds[0] + + + async def mock_openai_aembed_document(self, text: str) -> list[float]: + return mock_openai_embed_document(self, text) diff --git a/tests/metagpt/memory/test_longterm_memory.py b/tests/metagpt/memory/test_longterm_memory.py index 08bae4d91..8af0fb5cf 100644 --- a/tests/metagpt/memory/test_longterm_memory.py +++ b/tests/metagpt/memory/test_longterm_memory.py @@ -13,12 +13,17 @@ from metagpt.config2 import config from metagpt.memory.longterm_memory import LongTermMemory from metagpt.roles.role import RoleContext from metagpt.schema import Message +from tests.metagpt.memory.mock_text_embed import mock_openai_aembed_document, mock_openai_embed_document, mock_openai_embed_documents, text_embed_arr os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key) @pytest.mark.asyncio -async def test_ltm_search(): +async def test_ltm_search(mocker): + mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents) + mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document) + 
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document) + role_id = "UTUserLtm(Product Manager)" from metagpt.environment import Environment @@ -28,24 +33,25 @@ async def test_ltm_search(): ltm = LongTermMemory() ltm.recover_memory(role_id, rc) - idea = "Write a cli snake game" + idea = text_embed_arr[0].get("text", "Write a cli snake game") message = Message(role="User", content=idea, cause_by=UserRequirement) news = await ltm.find_news([message]) - assert len(news) == 1 + assert len(news) == 0 ltm.add(message) - sim_idea = "Write a game of cli snake" + sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake") sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement) news = await ltm.find_news([sim_message]) assert len(news) == 0 ltm.add(sim_message) - new_idea = "Write a 2048 web game" + new_idea = text_embed_arr[2].get("text", "Write a 2048 web game") new_message = Message(role="User", content=new_idea, cause_by=UserRequirement) news = await ltm.find_news([new_message]) assert len(news) == 1 ltm.add(new_message) + ltm.persit() # restore from local index ltm_new = LongTermMemory() @@ -57,7 +63,7 @@ async def test_ltm_search(): news = await ltm_new.find_news([sim_message]) assert len(news) == 0 - new_idea = "Write a Battle City" + new_idea = text_embed_arr[3].get("text", "Write a Battle City") new_message = Message(role="User", content=new_idea, cause_by=UserRequirement) news = await ltm_new.find_news([new_message]) assert len(news) == 1 diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index b989df2fb..efb2b4eed 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -17,13 +17,18 @@ from metagpt.config2 import config from metagpt.const import DATA_PATH from metagpt.memory.memory_storage import MemoryStorage from metagpt.schema import Message +from 
tests.metagpt.memory.mock_text_embed import mock_openai_aembed_document, mock_openai_embed_document, mock_openai_embed_documents, text_embed_arr os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key) @pytest.mark.asyncio -async def test_idea_message(): - idea = "Write a cli snake game" +async def test_idea_message(mocker): + mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents) + mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document) + mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document) + + idea = text_embed_arr[0].get("text", "Write a cli snake game") role_id = "UTUser1(Product Manager)" message = Message(role="User", content=idea, cause_by=UserRequirement) @@ -35,12 +40,12 @@ memory_storage.add(message) assert memory_storage.is_initialized is True - sim_idea = idea # "Write a game of cli snake" + sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake") sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement) new_messages = await memory_storage.search_dissimilar(sim_message) assert len(new_messages) == 0 # similar, return [] - new_idea = "Write a 2048 web game" + new_idea = text_embed_arr[2].get("text", "Write a 2048 web game") new_message = Message(role="User", content=new_idea, cause_by=UserRequirement) new_messages = await memory_storage.search_dissimilar(new_message) assert new_messages[0].content == message.content @@ -50,13 +55,19 @@ @pytest.mark.asyncio -async def test_actionout_message(): +async def test_actionout_message(mocker): + mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents) + mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document) + 
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document) + out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]} ic_obj = ActionNode.create_model_class("prd", out_mapping) role_id = "UTUser2(Architect)" - content = "The user has requested the creation of a command-line interface (CLI) snake game" + content = text_embed_arr[4].get( + "text", "The user has requested the creation of a command-line interface (CLI) snake game" + ) message = Message( content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD ) # WritePRD as test action @@ -69,12 +80,14 @@ async def test_actionout_message(): memory_storage.add(message) assert memory_storage.is_initialized is True - sim_conent = "The request is command-line interface (CLI) snake game" + sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game") sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD) new_messages = await memory_storage.search_dissimilar(sim_message) assert len(new_messages) == 0 # similar, return [] - new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty" + new_conent = text_embed_arr[6].get( + "text", "Incorporate basic features of a snake game such as scoring and increasing difficulty" + ) new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD) new_messages = await memory_storage.search_dissimilar(new_message) assert new_messages[0].content == message.content