diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py
index 27a737e6c..62d1dfd76 100644
--- a/metagpt/memory/longterm_memory.py
+++ b/metagpt/memory/longterm_memory.py
@@ -61,8 +61,8 @@ class LongTermMemory(Memory):
         ltm_news: list[Message] = []
         for mem in stm_news:
             # filter out messages similar to those seen previously in ltm, only keep fresh news
-            mem_searched = await self.memory_storage.search_dissimilar(mem)
-            if len(mem_searched) > 0:
+            mem_searched = await self.memory_storage.search_similar(mem)
+            if len(mem_searched) == 0:
                 ltm_news.append(mem)
         return ltm_news[-k:]
 
diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py
index 706e75c5a..44b03cda3 100644
--- a/metagpt/memory/memory_storage.py
+++ b/metagpt/memory/memory_storage.py
@@ -58,16 +58,14 @@ class MemoryStorage(object):
         self.faiss_engine.add_objs([message])
         logger.info(f"Role {self.role_id}'s memory_storage add a message")
 
-    async def search_dissimilar(self, message: Message, k=4) -> list[Message]:
-        """search for dissimilar messages"""
+    async def search_similar(self, message: Message, k=4) -> list[Message]:
+        """search for similar messages"""
         # filter the result which score is smaller than the threshold
         filtered_resp = []
         resp = await self.faiss_engine.aretrieve(message.content)
         for item in resp:
-            print(" item.score ", item.score, item)
             if item.score < self.threshold:
-                continue
-            filtered_resp.append(item.metadata.get("obj"))
+                filtered_resp.append(item.metadata.get("obj"))
         return filtered_resp
 
     def clean(self):
@@ -76,4 +74,4 @@ class MemoryStorage(object):
 
     def persit(self):
         if self.faiss_engine:
-            self.faiss_engine.index.storage_context.persist(self.cache_dir)
+            self.faiss_engine.retriever._index.storage_context.persist(self.cache_dir)
diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py
index d1008081c..50b286cdc 100644
--- a/metagpt/rag/factories/index.py
+++ b/metagpt/rag/factories/index.py
@@ -29,10 +29,8 @@ class RAGIndexFactory(ConfigFactory):
         embed_model = self.extract_embed_model(config, **kwargs)
         vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
 
-        storage_context = StorageContext.from_defaults(
-            vector_store=vector_store, persist_dir=config.persist_path, embed_mode=embed_model
-        )
-        index = load_index_from_storage(storage_context=storage_context)
+        storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)
+        index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
         return index
 
     def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
diff --git a/tests/metagpt/memory/test_longterm_memory.py b/tests/metagpt/memory/test_longterm_memory.py
index 8af0fb5cf..990017fee 100644
--- a/tests/metagpt/memory/test_longterm_memory.py
+++ b/tests/metagpt/memory/test_longterm_memory.py
@@ -4,25 +4,28 @@
 @Desc   : unittest of `metagpt/memory/longterm_memory.py`
 """
 
-import os
-
 import pytest
 
 from metagpt.actions import UserRequirement
-from metagpt.config2 import config
 from metagpt.memory.longterm_memory import LongTermMemory
 from metagpt.roles.role import RoleContext
 from metagpt.schema import Message
-from tests.metagpt.memory.mock_text_embed import text_embed_arr
-
-os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
+from tests.metagpt.memory.mock_text_embed import (
+    mock_openai_aembed_document,
+    mock_openai_embed_document,
+    mock_openai_embed_documents,
+    text_embed_arr,
+)
 
 
 @pytest.mark.asyncio
 async def test_ltm_search(mocker):
-    # mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
-    # mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
-    # mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
+    mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
+    mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
+    mocker.patch(
+        "llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
+    )
     role_id = "UTUserLtm(Product Manager)"
     from metagpt.environment import Environment
 
@@ -36,7 +39,7 @@ async def test_ltm_search(mocker):
     idea = text_embed_arr[0].get("text", "Write a cli snake game")
     message = Message(role="User", content=idea, cause_by=UserRequirement)
     news = await ltm.find_news([message])
-    assert len(news) == 0
+    assert len(news) == 1
     ltm.add(message)
 
     sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
@@ -51,24 +54,8 @@ async def test_ltm_search(mocker):
     news = await ltm.find_news([new_message])
     assert len(news) == 1
     ltm.add(new_message)
-    ltm.persit()
 
-    # restore from local index
-    ltm_new = LongTermMemory()
-    ltm_new.recover_memory(role_id, rc)
-    news = await ltm_new.find_news([message])
-    assert len(news) == 0
-
-    ltm_new.recover_memory(role_id, rc)
-    news = await ltm_new.find_news([sim_message])
-    assert len(news) == 0
-
-    new_idea = text_embed_arr[3].get("text", "Write a Battle City")
-    new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
-    news = await ltm_new.find_news([new_message])
-    assert len(news) == 1
-
-    ltm_new.clear()
+    ltm.clear()
 
 
 if __name__ == "__main__":
diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py
index efb2b4eed..09671aaab 100644
--- a/tests/metagpt/memory/test_memory_storage.py
+++ b/tests/metagpt/memory/test_memory_storage.py
@@ -4,7 +4,6 @@
 @Desc   : the unittests of metagpt/memory/memory_storage.py
 """
 
-import os
 import shutil
 from pathlib import Path
 from typing import List
@@ -13,20 +12,24 @@
 import pytest
 
 from metagpt.actions import UserRequirement, WritePRD
 from metagpt.actions.action_node import ActionNode
-from metagpt.config2 import config
 from metagpt.const import DATA_PATH
 from metagpt.memory.memory_storage import MemoryStorage
 from metagpt.schema import Message
-from tests.metagpt.memory.mock_text_embed import text_embed_arr
-
-os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
+from tests.metagpt.memory.mock_text_embed import (
+    mock_openai_aembed_document,
+    mock_openai_embed_document,
+    mock_openai_embed_documents,
+    text_embed_arr,
+)
 
 
 @pytest.mark.asyncio
 async def test_idea_message(mocker):
-    # mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
-    # mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
-    # mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
+    mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
+    mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
+    mocker.patch(
+        "llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
+    )
     idea = text_embed_arr[0].get("text", "Write a cli snake game")
     role_id = "UTUser1(Product Manager)"
 
@@ -42,13 +45,13 @@ async def test_idea_message(mocker):
 
     sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
     sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
-    new_messages = await memory_storage.search_dissimilar(sim_message)
-    assert len(new_messages) == 0  # similar, return []
+    new_messages = await memory_storage.search_similar(sim_message)
+    assert len(new_messages) == 1  # similar, the stored match is returned
 
     new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
     new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
-    new_messages = await memory_storage.search_dissimilar(new_message)
-    assert new_messages[0].content == message.content
+    new_messages = await memory_storage.search_similar(new_message)
+    assert len(new_messages) == 0
 
     memory_storage.clean()
     assert memory_storage.is_initialized is False
@@ -56,9 +59,11 @@ async def test_idea_message(mocker):
 
 @pytest.mark.asyncio
 async def test_actionout_message(mocker):
-    # mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
-    # mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
-    # mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
+    mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
+    mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
+    mocker.patch(
+        "llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
+    )
 
     out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
     out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
@@ -82,15 +87,15 @@ async def test_actionout_message(mocker):
 
     sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game")
     sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
-    new_messages = await memory_storage.search_dissimilar(sim_message)
-    assert len(new_messages) == 0  # similar, return []
+    new_messages = await memory_storage.search_similar(sim_message)
+    assert len(new_messages) == 1  # similar, the stored match is returned
 
     new_conent = text_embed_arr[6].get(
         "text", "Incorporate basic features of a snake game such as scoring and increasing difficulty"
     )
     new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
-    new_messages = await memory_storage.search_dissimilar(new_message)
-    assert new_messages[0].content == message.content
+    new_messages = await memory_storage.search_similar(new_message)
+    assert len(new_messages) == 0
 
     memory_storage.clean()
     assert memory_storage.is_initialized is False
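
The first two hunks above invert the retrieval contract: `search_dissimilar` returned the messages far from the query, whereas `search_similar` returns the stored messages whose retrieval score falls below `self.threshold` (a distance-like score, so lower means more similar), and `find_news` now treats a message as news only when no similar match comes back. Below is a minimal, self-contained sketch of that contract; `Hit`, the scores, and the threshold value are illustrative stand-ins, not MetaGPT's real classes or configuration.

```python
import asyncio
from dataclasses import dataclass


# Illustrative stand-in for a retrieval hit; in MetaGPT the engine returns
# llama_index nodes whose .score is distance-like (lower means more similar).
@dataclass
class Hit:
    score: float
    obj: str  # the stored message content


THRESHOLD = 0.1  # hypothetical value; MemoryStorage reads its own self.threshold


async def search_similar(hits: list[Hit], threshold: float = THRESHOLD) -> list[str]:
    """Keep only hits scoring below the threshold, i.e. the similar ones."""
    return [h.obj for h in hits if h.score < threshold]


async def find_news(candidates: dict[str, list[Hit]]) -> list[str]:
    """A candidate counts as news only when no similar match is found for it."""
    news = []
    for content, hits in candidates.items():
        if len(await search_similar(hits)) == 0:
            news.append(content)
    return news


async def main():
    candidates = {
        "Write a cli snake game": [Hit(0.04, "Write a game of cli snake")],  # near-duplicate
        "Write a 2048 web game": [Hit(0.93, "Write a game of cli snake")],  # unrelated
    }
    print(await find_news(candidates))  # ['Write a 2048 web game']


asyncio.run(main())
```

This is also why the test assertions flip: a freshly added idea now yields `len(news) == 1` from `find_news` (it is news), while `search_similar` on a near-duplicate returns the match instead of an empty list.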
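The test hunks activate the previously commented-out `mocker.patch` lines, so the suites no longer need a live `OPENAI_API_KEY`: every `OpenAIEmbedding` call is replaced by a deterministic lookup. The diff only imports the helpers from `tests/metagpt/memory/mock_text_embed`, so the sketch below is a hypothetical reconstruction of what that module plausibly provides; the function names come from the imports in the diff, but the vectors, dimension, and lookup logic are assumptions.

```python
# Hypothetical sketch of tests/metagpt/memory/mock_text_embed (not shown in the
# diff). The patched targets are instance methods, so each mock accepts `self`.
EMBED_DIM = 1536  # assumed; matches OpenAI text-embedding-ada-002

# The real text_embed_arr pairs each test sentence with a precomputed vector;
# here the vectors are fabricated for illustration.
text_embed_arr: list[dict] = [
    {"text": "Write a cli snake game", "embed": [0.1] * EMBED_DIM},
    {"text": "Write a game of cli snake", "embed": [0.1] * EMBED_DIM},
    {"text": "Write a 2048 web game", "embed": [0.9] * EMBED_DIM},
]


def _lookup(text: str) -> list[float]:
    for item in text_embed_arr:
        if item["text"] == text:
            return item["embed"]
    return [0.0] * EMBED_DIM  # fallback for unseen text


def mock_openai_embed_document(self, text: str) -> list[float]:
    return _lookup(text)


def mock_openai_embed_documents(self, texts: list[str]) -> list[list[float]]:
    return [_lookup(t) for t in texts]


async def mock_openai_aembed_document(self, query: str) -> list[float]:
    return _lookup(query)
```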
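The remaining two fixes both concern persistence. `persit()` (the pre-existing method-name typo is left untouched by the patch) now reaches the index through `retriever._index`, a private attribute of the engine's retriever, rather than a non-existent `.index` on the engine itself. In `RAGIndexFactory`, the old code passed `embed_mode=embed_model` to `StorageContext.from_defaults`, which does not take such a parameter; the fix hands `embed_model` to `load_index_from_storage`, which forwards it to the reconstructed index. A round-trip sketch of the resulting pattern follows, assuming a llama-index install with the FAISS vector store; the path, dimension, and `MockEmbedding` stand-in are illustrative, not MetaGPT's configuration.

```python
import faiss
from llama_index.core import MockEmbedding, StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.core.schema import TextNode
from llama_index.vector_stores.faiss import FaissVectorStore

PERSIST_DIR = "./.faiss_cache"  # illustrative; MemoryStorage uses its cache_dir
embed_model = MockEmbedding(embed_dim=1536)  # stand-in for the configured model

# Build an index over one node and persist it, mirroring what persit() does
# via retriever._index.storage_context.persist(...).
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536))
index = VectorStoreIndex(
    [TextNode(text="Write a cli snake game")],
    storage_context=StorageContext.from_defaults(vector_store=vector_store),
    embed_model=embed_model,
)
index.storage_context.persist(PERSIST_DIR)

# Reload: restore the vector store first, then pass embed_model to
# load_index_from_storage (not to StorageContext.from_defaults), as the
# corrected RAGIndexFactory hunk does.
vector_store = FaissVectorStore.from_persist_dir(PERSIST_DIR)
storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=PERSIST_DIR)
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
print(type(index).__name__)  # VectorStoreIndex
```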