Merge branch 'feat_memory' of github.com:better629/MetaGPT into feat_memory

This commit is contained in:
seehi 2024-03-07 17:56:51 +08:00
commit 8411388ccf
3 changed files with 20 additions and 74 deletions

View file

@ -1,25 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : memory mechanism including store/retrieval/rank
from typing import Optional
from pydantic import BaseModel, Field
from metagpt.memory.memory_network import MemoryNetwork
from metagpt.memory.schema import MemoryNode
from metagpt.schema import Message
class Memory(BaseModel):
mem_network: Optional[MemoryNetwork] = Field(
default_factory=MemoryNetwork, description="the network to store memory"
)
def add_msg(self, message: Message):
mem_node = MemoryNode.create_mem_node_from_message(message)
self.mem_network.add_mem(mem_node)
def add_msgs(self, messages: list[Message]):
for msg in messages:
self.add_msg(msg)

View file

@ -1,20 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the memory network to store memory segment
from pydantic import BaseModel, Field
from metagpt.memory.schema import MemoryNode, MemorySegment
class MemoryNetwork(BaseModel):
mem_seg: MemorySegment = Field(
default_factory=MemorySegment, description="the memory segment to store memory nodes"
)
def add_mem(self, mem_node: MemoryNode):
self.mem_seg.add_mem_node(mem_node)
def add_mems(self, mem_nodes: list[MemoryNode]):
for mem_node in mem_nodes:
self.add_mem(mem_node)

View file

@ -3,16 +3,17 @@
"""
@Desc : the implement of memory storage
"""
import shutil
from pathlib import Path
from llama_index.embeddings import BaseEmbedding
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.schema import QueryBundle, TextNode
from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.serialize import deserialize_message, serialize_message
from metagpt.utils.embedding import get_embedding
class MemoryStorage(FaissStore):
@ -26,6 +27,7 @@ class MemoryStorage(FaissStore):
self.mem_ttl: int = mem_ttl # later use
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
self._initialized: bool = False
self.embedding = embedding or get_embedding()
@property
def is_initialized(self) -> bool:
@ -35,6 +37,7 @@ class MemoryStorage(FaissStore):
self.role_id = role_id
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
self.role_mem_path.mkdir(parents=True, exist_ok=True)
self.cache_dir = self.role_mem_path
self.store = self._load()
messages = []
@ -43,34 +46,22 @@ class MemoryStorage(FaissStore):
pass
else:
for _id, document in self.store.docstore._dict.items():
messages.append(deserialize_message(document.metadata.get("message_ser")))
messages.append(Message(**document.metadata.get("obj_dict")))
self._initialized = True
return messages
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
if not self.role_mem_path:
logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory")
return None, None
index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
self.cache_dir = Path(self.role_mem_path).joinpath(self.role_id)
return index_fpath, storage_fpath
def persist(self):
self.store.save_local(self.role_mem_path, self.role_id)
logger.debug(f"Agent {self.role_id} persist memory into local")
def add(self, message: Message) -> bool:
"""add message into memory storage"""
docs = [message.content]
metadatas = [{"message_ser": serialize_message(message)}]
metadatas = [{"obj_dict": message.model_dump()}]
if not self.store:
# init Faiss
self.store = self._write(docs, metadatas)
self._initialized = True
else:
self.store.add_texts(texts=docs, metadatas=metadatas)
text_node = TextNode(text=message.content, metadata=metadatas[0])
self.store.insert_nodes([text_node])
self.persist()
logger.info(f"Agent {self.role_id}'s memory_storage add a message")
@ -79,25 +70,25 @@ class MemoryStorage(FaissStore):
if not self.store:
return []
resp = self.store.similarity_search_with_score(query=message.content, k=k)
retriever = self.store.as_retriever(similarity_top_k=k)
resp = retriever.retrieve(
QueryBundle(query_str=message.content, embedding=self.embedding.get_text_embedding(message.content))
)
# filter the result which score is smaller than the threshold
filtered_resp = []
for item, score in resp:
for item in resp:
# the smaller score means more similar relation
if score < self.threshold:
if item.score < self.threshold:
continue
# convert search result into Memory
metadata = item.metadata
new_mem = deserialize_message(metadata.get("message_ser"))
metadata = item.node.metadata
new_mem = Message(**metadata.get("obj_dict", {}))
filtered_resp.append(new_mem)
return filtered_resp
def clean(self):
index_fpath, storage_fpath = self._get_index_and_store_fname()
if index_fpath and index_fpath.exists():
index_fpath.unlink(missing_ok=True)
if storage_fpath and storage_fpath.exists():
storage_fpath.unlink(missing_ok=True)
shutil.rmtree(self.cache_dir, ignore_errors=True)
self.store = None
self._initialized = False