mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
Merge branch 'dev' of https://github.com/geekan/MetaGPT into geekan/dev
This commit is contained in:
commit
fe1d60f111
5 changed files with 97 additions and 24 deletions
|
|
@ -135,11 +135,3 @@ class Memory(BaseModel):
|
|||
continue
|
||||
rsp += self.index[action]
|
||||
return rsp
|
||||
|
||||
def get_by_tags(self, tags: list) -> list[Message]:
|
||||
"""Return messages with specified tags"""
|
||||
result = []
|
||||
for m in self.storage:
|
||||
if m.is_contain_tags(tags):
|
||||
result.append(m)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -6,9 +6,11 @@
|
|||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from metagpt.const import DATA_PATH, MEM_TTL
|
||||
from metagpt.document_store.faiss_store import FaissStore
|
||||
|
|
@ -22,20 +24,30 @@ class MemoryStorage(FaissStore):
|
|||
The memory storage with Faiss as ANN search engine
|
||||
"""
|
||||
|
||||
def __init__(self, mem_ttl: int = MEM_TTL):
|
||||
def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None):
|
||||
self.role_id: str = None
|
||||
self.role_mem_path: str = None
|
||||
self.mem_ttl: int = mem_ttl # later use
|
||||
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
|
||||
self._initialized: bool = False
|
||||
|
||||
self.embedding = embedding or OpenAIEmbeddings()
|
||||
self.store: FAISS = None # Faiss engine
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
def recover_memory(self, role_id: str) -> List[Message]:
|
||||
def _load(self) -> Optional["FaissStore"]:
|
||||
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
|
||||
|
||||
if not (index_file.exists() and store_file.exists()):
|
||||
logger.info("Missing at least one of index_file/store_file, load failed and return None")
|
||||
return None
|
||||
|
||||
return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id)
|
||||
|
||||
def recover_memory(self, role_id: str) -> list[Message]:
|
||||
self.role_id = role_id
|
||||
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
|
||||
self.role_mem_path.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -52,16 +64,16 @@ class MemoryStorage(FaissStore):
|
|||
|
||||
return messages
|
||||
|
||||
def _get_index_and_store_fname(self):
|
||||
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
|
||||
if not self.role_mem_path:
|
||||
logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory")
|
||||
return None, None
|
||||
index_fpath = Path(self.role_mem_path / f"{self.role_id}.index")
|
||||
storage_fpath = Path(self.role_mem_path / f"{self.role_id}.pkl")
|
||||
index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
|
||||
storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
|
||||
return index_fpath, storage_fpath
|
||||
|
||||
def persist(self):
|
||||
super().persist()
|
||||
self.store.save_local(self.role_mem_path, self.role_id)
|
||||
logger.debug(f"Agent {self.role_id} persist memory into local")
|
||||
|
||||
def add(self, message: Message) -> bool:
|
||||
|
|
@ -77,7 +89,7 @@ class MemoryStorage(FaissStore):
|
|||
self.persist()
|
||||
logger.info(f"Agent {self.role_id}'s memory_storage add a message")
|
||||
|
||||
def search_dissimilar(self, message: Message, k=4) -> List[Message]:
|
||||
def search_dissimilar(self, message: Message, k=4) -> list[Message]:
|
||||
"""search for dissimilar messages"""
|
||||
if not self.store:
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.memory.longterm_memory import LongTermMemory
|
||||
|
|
@ -14,11 +16,11 @@ from metagpt.schema import Message
|
|||
|
||||
def test_ltm_search():
|
||||
assert hasattr(CONFIG, "long_term_memory") is True
|
||||
openai_api_key = CONFIG.openai_api_key
|
||||
assert len(openai_api_key) > 20
|
||||
os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key)
|
||||
assert len(CONFIG.openai_api_key) > 20
|
||||
|
||||
role_id = "UTUserLtm(Product Manager)"
|
||||
rc = RoleContext(watch=[UserRequirement])
|
||||
rc = RoleContext(watch={"metagpt.actions.add_requirement.UserRequirement"})
|
||||
ltm = LongTermMemory()
|
||||
ltm.recover_memory(role_id, rc)
|
||||
|
||||
|
|
|
|||
57
tests/metagpt/memory/test_memory.py
Normal file
57
tests/metagpt/memory/test_memory.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of Memory
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.memory.memory import Memory
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_memory():
|
||||
memory = Memory()
|
||||
|
||||
message1 = Message(content="test message1", role="user1")
|
||||
message2 = Message(content="test message2", role="user2")
|
||||
message3 = Message(content="test message3", role="user1")
|
||||
memory.add(message1)
|
||||
assert memory.count() == 1
|
||||
|
||||
memory.delete_newest()
|
||||
assert memory.count() == 0
|
||||
|
||||
memory.add_batch([message1, message2])
|
||||
assert memory.count() == 2
|
||||
assert len(memory.index.get(message1.cause_by)) == 2
|
||||
|
||||
messages = memory.get_by_role("user1")
|
||||
assert messages[0].content == message1.content
|
||||
|
||||
messages = memory.get_by_content("test message")
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.get_by_action(UserRequirement)
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.get_by_actions([UserRequirement])
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.try_remember("test message")
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.get(k=1)
|
||||
assert len(messages) == 1
|
||||
|
||||
messages = memory.get(k=5)
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.find_news([message3])
|
||||
assert len(messages) == 1
|
||||
|
||||
memory.delete(message1)
|
||||
assert memory.count() == 1
|
||||
messages = memory.get_by_role("user2")
|
||||
assert messages[0].content == message2.content
|
||||
|
||||
memory.clear()
|
||||
assert memory.count() == 0
|
||||
assert len(memory.index) == 0
|
||||
|
|
@ -4,20 +4,28 @@
|
|||
@Desc : the unittests of metagpt/memory/memory_storage.py
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
|
||||
os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key)
|
||||
|
||||
|
||||
def test_idea_message():
|
||||
idea = "Write a cli snake game"
|
||||
role_id = "UTUser1(Product Manager)"
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
|
||||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"))
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
assert len(messages) == 0
|
||||
|
|
@ -27,12 +35,12 @@ def test_idea_message():
|
|||
|
||||
sim_idea = "Write a game of cli snake"
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search(sim_message)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search(new_message)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
||||
memory_storage.clean()
|
||||
|
|
@ -50,6 +58,8 @@ def test_actionout_message():
|
|||
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
|
||||
) # WritePRD as test action
|
||||
|
||||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"))
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
assert len(messages) == 0
|
||||
|
|
@ -59,12 +69,12 @@ def test_actionout_message():
|
|||
|
||||
sim_conent = "The request is command-line interface (CLI) snake game"
|
||||
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search(sim_message)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search(new_message)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
||||
memory_storage.clean()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue