mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-02 14:45:17 +02:00
memory_storage use rag_engine
This commit is contained in:
parent
d289dad8b3
commit
57a1fac357
7 changed files with 59 additions and 73 deletions
|
|
@ -38,9 +38,9 @@ class LocalStore(BaseStore, ABC):
|
|||
if not self.store:
|
||||
self.store = self.write()
|
||||
|
||||
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
|
||||
index_file = self.cache_dir / "default__vector_store.json"
|
||||
store_file = self.cache_dir / "docstore.json"
|
||||
def _get_index_and_store_fname(self, index_ext=".json", docstore_ext=".json"):
|
||||
index_file = self.cache_dir / "default__vector_store" / index_ext
|
||||
store_file = self.cache_dir / "docstore" / docstore_ext
|
||||
return index_file, store_file
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class FaissStore(LocalStore):
|
|||
super().__init__(raw_data, cache_dir)
|
||||
|
||||
def _load(self) -> Optional["VectorStoreIndex"]:
|
||||
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # FAISS using .faiss
|
||||
index_file, store_file = self._get_index_and_store_fname()
|
||||
|
||||
if not (index_file.exists() and store_file.exists()):
|
||||
logger.info("Missing at least one of index_file/store_file, load failed and return None")
|
||||
|
|
@ -46,12 +46,8 @@ class FaissStore(LocalStore):
|
|||
|
||||
def _write(self, docs: list[str], metadatas: list[dict[str, Any]]) -> VectorStoreIndex:
|
||||
assert len(docs) == len(metadatas)
|
||||
texts_embeds = self.embedding.get_text_embedding_batch(docs)
|
||||
documents = [Document(text=doc, metadata=metadatas[idx]) for idx, doc in enumerate(docs)]
|
||||
|
||||
[TextNode(embedding=embed, metadata=metadatas[idx]) for idx, embed in enumerate(texts_embeds)]
|
||||
# doc_store = SimpleDocumentStore()
|
||||
# doc_store.add_documents(nodes)
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536))
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
index = VectorStoreIndex.from_documents(
|
||||
|
|
@ -90,7 +86,7 @@ class FaissStore(LocalStore):
|
|||
def add(self, texts: list[str], *args, **kwargs) -> list[str]:
|
||||
"""FIXME: Currently, the store is not updated after adding."""
|
||||
texts_embeds = self.embedding.get_text_embedding_batch(texts)
|
||||
nodes = [TextNode(embedding=embed) for embed in texts_embeds]
|
||||
nodes = [TextNode(text=texts[idx], embedding=embed) for idx, embed in enumerate(texts_embeds)]
|
||||
self.store.insert_nodes(nodes)
|
||||
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -29,16 +29,14 @@ class LongTermMemory(Memory):
|
|||
msg_from_recover: bool = False
|
||||
|
||||
def recover_memory(self, role_id: str, rc: RoleContext):
|
||||
messages = self.memory_storage.recover_memory(role_id)
|
||||
self.memory_storage.recover_memory(role_id)
|
||||
self.rc = rc
|
||||
if not self.memory_storage.is_initialized:
|
||||
logger.warning(f"It may the first time to run Agent {role_id}, the long-term memory is empty")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Agent {role_id} has existing memory storage with {len(messages)} messages " f"and has recovered them."
|
||||
)
|
||||
logger.warning(f"Role {role_id} has existing memory storage and has recovered them.")
|
||||
self.msg_from_recover = True
|
||||
self.add_batch(messages)
|
||||
# self.add_batch(messages) # TODO no need
|
||||
self.msg_from_recover = False
|
||||
|
||||
def add(self, message: Message):
|
||||
|
|
@ -49,7 +47,7 @@ class LongTermMemory(Memory):
|
|||
# and ignore adding messages from recover repeatedly
|
||||
self.memory_storage.add(message)
|
||||
|
||||
def find_news(self, observed: list[Message], k=0) -> list[Message]:
|
||||
async def find_news(self, observed: list[Message], k=0) -> list[Message]:
|
||||
"""
|
||||
find news (previously unseen messages) from the the most recent k memories, from all memories when k=0
|
||||
1. find the short-term memory(stm) news
|
||||
|
|
@ -63,7 +61,7 @@ class LongTermMemory(Memory):
|
|||
ltm_news: list[Message] = []
|
||||
for mem in stm_news:
|
||||
# filter out messages similar to those seen previously in ltm, only keep fresh news
|
||||
mem_searched = self.memory_storage.search_dissimilar(mem)
|
||||
mem_searched = await self.memory_storage.search_dissimilar(mem)
|
||||
if len(mem_searched) > 0:
|
||||
ltm_news.append(mem)
|
||||
return ltm_news[-k:]
|
||||
|
|
|
|||
|
|
@ -7,16 +7,16 @@ import shutil
|
|||
from pathlib import Path
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.schema import QueryBundle, TextNode
|
||||
|
||||
from metagpt.const import DATA_PATH, MEM_TTL
|
||||
from metagpt.document_store.faiss_store import FaissStore
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.engines.simple import SimpleEngine
|
||||
from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.embedding import get_embedding
|
||||
|
||||
|
||||
class MemoryStorage(FaissStore):
|
||||
class MemoryStorage(object):
|
||||
"""
|
||||
The memory storage with Faiss as ANN search engine
|
||||
"""
|
||||
|
|
@ -29,6 +29,8 @@ class MemoryStorage(FaissStore):
|
|||
self._initialized: bool = False
|
||||
self.embedding = embedding or get_embedding()
|
||||
|
||||
self.faiss_engine = None
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
|
@ -39,56 +41,35 @@ class MemoryStorage(FaissStore):
|
|||
self.role_mem_path.mkdir(parents=True, exist_ok=True)
|
||||
self.cache_dir = self.role_mem_path
|
||||
|
||||
self.store = self._load()
|
||||
messages = []
|
||||
if not self.store:
|
||||
# TODO init `self.store` under here with raw faiss api instead under `add`
|
||||
pass
|
||||
if self.role_mem_path.joinpath("default__vector_store.json").exists():
|
||||
self.faiss_engine = SimpleEngine.from_index(
|
||||
index_config=[FAISSIndexConfig(persist_path=self.cache_dir)],
|
||||
retriever_configs=[FAISSRetrieverConfig()],
|
||||
embed_model=self.embedding,
|
||||
)
|
||||
else:
|
||||
for _id, document in self.store.docstore._dict.items():
|
||||
messages.append(Message(**document.metadata.get("obj_dict")))
|
||||
self._initialized = True
|
||||
|
||||
return messages
|
||||
self.faiss_engine = SimpleEngine.from_objs(
|
||||
objs=[], retriever_configs=[FAISSRetrieverConfig()], embed_model=self.embedding
|
||||
)
|
||||
self._initialized = True
|
||||
|
||||
def add(self, message: Message) -> bool:
|
||||
"""add message into memory storage"""
|
||||
docs = [message.content]
|
||||
metadatas = [{"obj_dict": message.model_dump()}]
|
||||
if not self.store:
|
||||
# init Faiss
|
||||
self.store = self._write(docs, metadatas)
|
||||
self._initialized = True
|
||||
else:
|
||||
text_node = TextNode(text=message.content, metadata=metadatas[0])
|
||||
self.store.insert_nodes([text_node])
|
||||
self.persist()
|
||||
logger.info(f"Agent {self.role_id}'s memory_storage add a message")
|
||||
self.faiss_engine.add_objs([message])
|
||||
logger.info(f"Role {self.role_id}'s memory_storage add a message")
|
||||
|
||||
def search_dissimilar(self, message: Message, k=4) -> list[Message]:
|
||||
async def search_dissimilar(self, message: Message, k=4) -> list[Message]:
|
||||
"""search for dissimilar messages"""
|
||||
if not self.store:
|
||||
return []
|
||||
|
||||
retriever = self.store.as_retriever(similarity_top_k=k)
|
||||
resp = retriever.retrieve(
|
||||
QueryBundle(query_str=message.content, embedding=self.embedding.get_text_embedding(message.content))
|
||||
)
|
||||
# filter the result which score is smaller than the threshold
|
||||
filtered_resp = []
|
||||
resp = await self.faiss_engine.aretrieve(message.content)
|
||||
for item in resp:
|
||||
# the smaller score means more similar relation
|
||||
|
||||
print(" item.score ", item.score, item)
|
||||
if item.score < self.threshold:
|
||||
continue
|
||||
# convert search result into Memory
|
||||
metadata = item.node.metadata
|
||||
new_mem = Message(**metadata.get("obj_dict", {}))
|
||||
filtered_resp.append(new_mem)
|
||||
filtered_resp.append(item.metadata.get("obj"))
|
||||
return filtered_resp
|
||||
|
||||
def clean(self):
|
||||
shutil.rmtree(self.cache_dir, ignore_errors=True)
|
||||
|
||||
self.store = None
|
||||
self._initialized = False
|
||||
|
|
|
|||
|
|
@ -233,6 +233,10 @@ class Message(BaseModel):
|
|||
def check_send_to(cls, send_to: Any) -> set:
|
||||
return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})
|
||||
|
||||
@field_serializer("send_to", mode="plain")
|
||||
def ser_send_to(self, send_to: set) -> list:
|
||||
return list(send_to)
|
||||
|
||||
@field_serializer("instruct_content", mode="plain")
|
||||
def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]:
|
||||
ic_dict = None
|
||||
|
|
@ -276,6 +280,10 @@ class Message(BaseModel):
|
|||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def rag_key(self) -> str:
|
||||
"""For search"""
|
||||
return self.content
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Return a dict containing `role` and `content` for the LLM call.l"""
|
||||
return {"role": self.role, "content": self.content}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,8 @@ from tests.metagpt.memory.mock_text_embed import (
|
|||
)
|
||||
|
||||
|
||||
def test_ltm_search(mocker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_ltm_search(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
role_id = "UTUserLtm(Product Manager)"
|
||||
|
|
@ -31,36 +32,36 @@ def test_ltm_search(mocker):
|
|||
|
||||
idea = text_embed_arr[0].get("text", "Write a cli snake game")
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([message])
|
||||
news = await ltm.find_news([message])
|
||||
assert len(news) == 1
|
||||
ltm.add(message)
|
||||
|
||||
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
|
||||
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([sim_message])
|
||||
news = await ltm.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
ltm.add(sim_message)
|
||||
|
||||
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([new_message])
|
||||
news = await ltm.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
ltm.add(new_message)
|
||||
|
||||
# restore from local index
|
||||
ltm_new = LongTermMemory()
|
||||
ltm_new.recover_memory(role_id, rc)
|
||||
news = ltm_new.find_news([message])
|
||||
news = await ltm_new.find_news([message])
|
||||
assert len(news) == 0
|
||||
|
||||
ltm_new.recover_memory(role_id, rc)
|
||||
news = ltm_new.find_news([sim_message])
|
||||
news = await ltm_new.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
|
||||
new_idea = text_embed_arr[3].get("text", "Write a Battle City")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm_new.find_news([new_message])
|
||||
news = await ltm_new.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
|
||||
ltm_new.clear()
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ import shutil
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.const import DATA_PATH
|
||||
|
|
@ -19,7 +21,8 @@ from tests.metagpt.memory.mock_text_embed import (
|
|||
)
|
||||
|
||||
|
||||
def test_idea_message(mocker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_idea_message(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
idea = text_embed_arr[0].get("text", "Write a cli snake game")
|
||||
|
|
@ -29,27 +32,27 @@ def test_idea_message(mocker):
|
|||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
assert len(messages) == 0
|
||||
memory_storage.recover_memory(role_id)
|
||||
|
||||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
new_messages = await memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
new_messages = await memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
||||
memory_storage.clean()
|
||||
assert memory_storage.is_initialized is False
|
||||
|
||||
|
||||
def test_actionout_message(mocker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_actionout_message(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
|
|
@ -67,22 +70,21 @@ def test_actionout_message(mocker):
|
|||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
assert len(messages) == 0
|
||||
memory_storage.recover_memory(role_id)
|
||||
|
||||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game")
|
||||
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
new_messages = await memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_conent = text_embed_arr[6].get(
|
||||
"text", "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
)
|
||||
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
new_messages = await memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
||||
memory_storage.clean()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue