mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-03 12:52:37 +02:00
Merge branch 'feat_memory' of github.com:better629/MetaGPT into feat_memory
This commit is contained in:
commit
8411388ccf
3 changed files with 20 additions and 74 deletions
|
|
@ -1,25 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : memory mechanism including store/retrieval/rank
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.memory.memory_network import MemoryNetwork
|
||||
from metagpt.memory.schema import MemoryNode
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
mem_network: Optional[MemoryNetwork] = Field(
|
||||
default_factory=MemoryNetwork, description="the network to store memory"
|
||||
)
|
||||
|
||||
def add_msg(self, message: Message):
|
||||
mem_node = MemoryNode.create_mem_node_from_message(message)
|
||||
self.mem_network.add_mem(mem_node)
|
||||
|
||||
def add_msgs(self, messages: list[Message]):
|
||||
for msg in messages:
|
||||
self.add_msg(msg)
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the memory network to store memory segment
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.memory.schema import MemoryNode, MemorySegment
|
||||
|
||||
|
||||
class MemoryNetwork(BaseModel):
|
||||
mem_seg: MemorySegment = Field(
|
||||
default_factory=MemorySegment, description="the memory segment to store memory nodes"
|
||||
)
|
||||
|
||||
def add_mem(self, mem_node: MemoryNode):
|
||||
self.mem_seg.add_mem_node(mem_node)
|
||||
|
||||
def add_mems(self, mem_nodes: list[MemoryNode]):
|
||||
for mem_node in mem_nodes:
|
||||
self.add_mem(mem_node)
|
||||
|
|
@ -3,16 +3,17 @@
|
|||
"""
|
||||
@Desc : the implement of memory storage
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from llama_index.embeddings import BaseEmbedding
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.schema import QueryBundle, TextNode
|
||||
|
||||
from metagpt.const import DATA_PATH, MEM_TTL
|
||||
from metagpt.document_store.faiss_store import FaissStore
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.serialize import deserialize_message, serialize_message
|
||||
from metagpt.utils.embedding import get_embedding
|
||||
|
||||
|
||||
class MemoryStorage(FaissStore):
|
||||
|
|
@ -26,6 +27,7 @@ class MemoryStorage(FaissStore):
|
|||
self.mem_ttl: int = mem_ttl # later use
|
||||
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
|
||||
self._initialized: bool = False
|
||||
self.embedding = embedding or get_embedding()
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
|
|
@ -35,6 +37,7 @@ class MemoryStorage(FaissStore):
|
|||
self.role_id = role_id
|
||||
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
|
||||
self.role_mem_path.mkdir(parents=True, exist_ok=True)
|
||||
self.cache_dir = self.role_mem_path
|
||||
|
||||
self.store = self._load()
|
||||
messages = []
|
||||
|
|
@ -43,34 +46,22 @@ class MemoryStorage(FaissStore):
|
|||
pass
|
||||
else:
|
||||
for _id, document in self.store.docstore._dict.items():
|
||||
messages.append(deserialize_message(document.metadata.get("message_ser")))
|
||||
messages.append(Message(**document.metadata.get("obj_dict")))
|
||||
self._initialized = True
|
||||
|
||||
return messages
|
||||
|
||||
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
|
||||
if not self.role_mem_path:
|
||||
logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory")
|
||||
return None, None
|
||||
index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
|
||||
storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
|
||||
self.cache_dir = Path(self.role_mem_path).joinpath(self.role_id)
|
||||
return index_fpath, storage_fpath
|
||||
|
||||
def persist(self):
|
||||
self.store.save_local(self.role_mem_path, self.role_id)
|
||||
logger.debug(f"Agent {self.role_id} persist memory into local")
|
||||
|
||||
def add(self, message: Message) -> bool:
|
||||
"""add message into memory storage"""
|
||||
docs = [message.content]
|
||||
metadatas = [{"message_ser": serialize_message(message)}]
|
||||
metadatas = [{"obj_dict": message.model_dump()}]
|
||||
if not self.store:
|
||||
# init Faiss
|
||||
self.store = self._write(docs, metadatas)
|
||||
self._initialized = True
|
||||
else:
|
||||
self.store.add_texts(texts=docs, metadatas=metadatas)
|
||||
text_node = TextNode(text=message.content, metadata=metadatas[0])
|
||||
self.store.insert_nodes([text_node])
|
||||
self.persist()
|
||||
logger.info(f"Agent {self.role_id}'s memory_storage add a message")
|
||||
|
||||
|
|
@ -79,25 +70,25 @@ class MemoryStorage(FaissStore):
|
|||
if not self.store:
|
||||
return []
|
||||
|
||||
resp = self.store.similarity_search_with_score(query=message.content, k=k)
|
||||
retriever = self.store.as_retriever(similarity_top_k=k)
|
||||
resp = retriever.retrieve(
|
||||
QueryBundle(query_str=message.content, embedding=self.embedding.get_text_embedding(message.content))
|
||||
)
|
||||
# filter the result which score is smaller than the threshold
|
||||
filtered_resp = []
|
||||
for item, score in resp:
|
||||
for item in resp:
|
||||
# the smaller score means more similar relation
|
||||
if score < self.threshold:
|
||||
|
||||
if item.score < self.threshold:
|
||||
continue
|
||||
# convert search result into Memory
|
||||
metadata = item.metadata
|
||||
new_mem = deserialize_message(metadata.get("message_ser"))
|
||||
metadata = item.node.metadata
|
||||
new_mem = Message(**metadata.get("obj_dict", {}))
|
||||
filtered_resp.append(new_mem)
|
||||
return filtered_resp
|
||||
|
||||
def clean(self):
|
||||
index_fpath, storage_fpath = self._get_index_and_store_fname()
|
||||
if index_fpath and index_fpath.exists():
|
||||
index_fpath.unlink(missing_ok=True)
|
||||
if storage_fpath and storage_fpath.exists():
|
||||
storage_fpath.unlink(missing_ok=True)
|
||||
shutil.rmtree(self.cache_dir, ignore_errors=True)
|
||||
|
||||
self.store = None
|
||||
self._initialized = False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue