mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
memory_storage use rag_pipeline
This commit is contained in:
parent
93e61ec2da
commit
f14fee9b0d
7 changed files with 111 additions and 19 deletions
|
|
@ -32,7 +32,7 @@ class LongTermMemory(Memory):
|
|||
self.memory_storage.recover_memory(role_id)
|
||||
self.rc = rc
|
||||
if not self.memory_storage.is_initialized:
|
||||
logger.warning(f"It may the first time to run Agent {role_id}, the long-term memory is empty")
|
||||
logger.warning(f"It may the first time to run Role {role_id}, the long-term memory is empty")
|
||||
else:
|
||||
logger.warning(f"Role {role_id} has existing memory storage and has recovered them.")
|
||||
self.msg_from_recover = True
|
||||
|
|
@ -66,6 +66,9 @@ class LongTermMemory(Memory):
|
|||
ltm_news.append(mem)
|
||||
return ltm_news[-k:]
|
||||
|
||||
def persit(self):
|
||||
self.memory_storage.persit()
|
||||
|
||||
def delete(self, message: Message):
|
||||
super().delete(message)
|
||||
# TODO delete message in memory_storage
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ class MemoryStorage(object):
|
|||
|
||||
if self.role_mem_path.joinpath("default__vector_store.json").exists():
|
||||
self.faiss_engine = SimpleEngine.from_index(
|
||||
index_config=[FAISSIndexConfig(persist_path=self.cache_dir)],
|
||||
index_config=FAISSIndexConfig(persist_path=self.cache_dir),
|
||||
retriever_configs=[FAISSRetrieverConfig()],
|
||||
embed_model=self.embedding,
|
||||
)
|
||||
|
|
@ -73,3 +73,7 @@ class MemoryStorage(object):
|
|||
def clean(self):
|
||||
shutil.rmtree(self.cache_dir, ignore_errors=True)
|
||||
self._initialized = False
|
||||
|
||||
def persit(self):
|
||||
if self.faiss_engine:
|
||||
self.faiss_engine.index.storage_context.persist(self.cache_dir)
|
||||
|
|
|
|||
|
|
@ -104,6 +104,7 @@ class ObjectNode(TextNode):
|
|||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys())
|
||||
self.excluded_embed_metadata_keys = self.excluded_llm_metadata_keys
|
||||
|
||||
@staticmethod
|
||||
def get_obj_metadata(obj: RAGObject) -> dict:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@
|
|||
@File : test_faiss_store.py
|
||||
"""
|
||||
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from metagpt.const import EXAMPLE_PATH
|
||||
|
|
@ -14,8 +16,23 @@ from metagpt.logs import logger
|
|||
from metagpt.roles import Sales
|
||||
|
||||
|
||||
def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]:
|
||||
num = len(texts)
|
||||
embeds = np.random.randint(1, 100, size=(num, 1536)) # 1536: openai embedding dim
|
||||
embeds = (embeds - embeds.mean(axis=0)) / embeds.std(axis=0)
|
||||
return embeds.tolist()
|
||||
|
||||
|
||||
def mock_openai_embed_document(self, text: str) -> list[float]:
|
||||
embeds = mock_openai_embed_documents(self, [text])
|
||||
return embeds[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_json():
|
||||
async def test_search_json(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.json")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
|
|
@ -24,7 +41,10 @@ async def test_search_json():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_xlsx():
|
||||
async def test_search_xlsx(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
|
|
@ -33,7 +53,10 @@ async def test_search_xlsx():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write():
|
||||
async def test_write(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question")
|
||||
_faiss_store = store.write()
|
||||
assert _faiss_store.storage_context.docstore
|
||||
|
|
|
|||
42
tests/metagpt/memory/mock_text_embed.py
Normal file
42
tests/metagpt/memory/mock_text_embed.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import numpy as np
|
||||
|
||||
dim = 1536 # openai embedding dim
|
||||
embed_zeros_arrr = np.zeros(shape=[1, dim]).tolist()
|
||||
embed_ones_arrr = np.ones(shape=[1, dim]).tolist()
|
||||
|
||||
text_embed_arr = [
|
||||
{"text": "Write a cli snake game", "embed": embed_zeros_arrr}, # mock data, same as below
|
||||
{"text": "Write a game of cli snake", "embed": embed_zeros_arrr},
|
||||
{"text": "Write a 2048 web game", "embed": embed_ones_arrr},
|
||||
{"text": "Write a Battle City", "embed": embed_ones_arrr},
|
||||
{
|
||||
"text": "The user has requested the creation of a command-line interface (CLI) snake game",
|
||||
"embed": embed_zeros_arrr,
|
||||
},
|
||||
{"text": "The request is command-line interface (CLI) snake game", "embed": embed_zeros_arrr},
|
||||
{
|
||||
"text": "Incorporate basic features of a snake game such as scoring and increasing difficulty",
|
||||
"embed": embed_ones_arrr,
|
||||
},
|
||||
]
|
||||
|
||||
text_idx_dict = {item["text"]: idx for idx, item in enumerate(text_embed_arr)}
|
||||
|
||||
|
||||
def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]:
|
||||
idx = text_idx_dict.get(texts[0])
|
||||
embed = text_embed_arr[idx].get("embed")
|
||||
return embed
|
||||
|
||||
|
||||
def mock_openai_embed_document(self, text: str) -> list[float]:
|
||||
embeds = mock_openai_embed_documents(self, [text])
|
||||
return embeds[0]
|
||||
|
||||
|
||||
async def mock_openai_aembed_document(self, text: str) -> list[float]:
|
||||
return mock_openai_embed_document(self, text)
|
||||
|
|
@ -13,12 +13,17 @@ from metagpt.config2 import config
|
|||
from metagpt.memory.longterm_memory import LongTermMemory
|
||||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.schema import Message
|
||||
from tests.metagpt.memory.mock_text_embed import text_embed_arr
|
||||
|
||||
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ltm_search():
|
||||
async def test_ltm_search(mocker):
|
||||
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
|
||||
|
||||
role_id = "UTUserLtm(Product Manager)"
|
||||
from metagpt.environment import Environment
|
||||
|
||||
|
|
@ -28,24 +33,25 @@ async def test_ltm_search():
|
|||
ltm = LongTermMemory()
|
||||
ltm.recover_memory(role_id, rc)
|
||||
|
||||
idea = "Write a cli snake game"
|
||||
idea = text_embed_arr[0].get("text", "Write a cli snake game")
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
news = await ltm.find_news([message])
|
||||
assert len(news) == 1
|
||||
assert len(news) == 0
|
||||
ltm.add(message)
|
||||
|
||||
sim_idea = "Write a game of cli snake"
|
||||
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
|
||||
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
news = await ltm.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
ltm.add(sim_message)
|
||||
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = await ltm.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
ltm.add(new_message)
|
||||
ltm.persit()
|
||||
|
||||
# restore from local index
|
||||
ltm_new = LongTermMemory()
|
||||
|
|
@ -57,7 +63,7 @@ async def test_ltm_search():
|
|||
news = await ltm_new.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
|
||||
new_idea = "Write a Battle City"
|
||||
new_idea = text_embed_arr[3].get("text", "Write a Battle City")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = await ltm_new.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
|
|
|
|||
|
|
@ -17,13 +17,18 @@ from metagpt.config2 import config
|
|||
from metagpt.const import DATA_PATH
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
from tests.metagpt.memory.mock_text_embed import text_embed_arr
|
||||
|
||||
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_idea_message():
|
||||
idea = "Write a cli snake game"
|
||||
async def test_idea_message(mocker):
|
||||
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
|
||||
|
||||
idea = text_embed_arr[0].get("text", "Write a cli snake game")
|
||||
role_id = "UTUser1(Product Manager)"
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
|
||||
|
|
@ -35,12 +40,12 @@ async def test_idea_message():
|
|||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_idea = idea # "Write a game of cli snake"
|
||||
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
new_messages = await memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
new_messages = await memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
|
@ -50,13 +55,19 @@ async def test_idea_message():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_actionout_message():
|
||||
async def test_actionout_message(mocker):
|
||||
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
|
||||
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("prd", out_mapping)
|
||||
|
||||
role_id = "UTUser2(Architect)"
|
||||
content = "The user has requested the creation of a command-line interface (CLI) snake game"
|
||||
content = text_embed_arr[4].get(
|
||||
"text", "The user has requested the creation of a command-line interface (CLI) snake game"
|
||||
)
|
||||
message = Message(
|
||||
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
|
||||
) # WritePRD as test action
|
||||
|
|
@ -69,12 +80,14 @@ async def test_actionout_message():
|
|||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_conent = "The request is command-line interface (CLI) snake game"
|
||||
sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game")
|
||||
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = await memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
new_conent = text_embed_arr[6].get(
|
||||
"text", "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
)
|
||||
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = await memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue