From 94a0699ec4c71a29359981bbd39fc90a92a1cbb8 Mon Sep 17 00:00:00 2001 From: better629 Date: Mon, 25 Dec 2023 13:50:47 +0800 Subject: [PATCH 1/3] add memory unittest --- metagpt/memory/memory.py | 8 ---- tests/metagpt/memory/test_memory.py | 57 +++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 8 deletions(-) create mode 100644 tests/metagpt/memory/test_memory.py diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index d964cc1dc..e9891ed00 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -129,11 +129,3 @@ class Memory(BaseModel): continue rsp += self.index[action] return rsp - - def get_by_tags(self, tags: list) -> list[Message]: - """Return messages with specified tags""" - result = [] - for m in self.storage: - if m.is_contain_tags(tags): - result.append(m) - return result diff --git a/tests/metagpt/memory/test_memory.py b/tests/metagpt/memory/test_memory.py new file mode 100644 index 000000000..36d7ad488 --- /dev/null +++ b/tests/metagpt/memory/test_memory.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of Memory + +from metagpt.actions import UserRequirement +from metagpt.memory.memory import Memory +from metagpt.schema import Message + + +def test_memory(): + memory = Memory() + + message1 = Message(content="test message1", role="user1") + message2 = Message(content="test message2", role="user2") + message3 = Message(content="test message3", role="user1") + memory.add(message1) + assert memory.count() == 1 + + memory.delete_newest() + assert memory.count() == 0 + + memory.add_batch([message1, message2]) + assert memory.count() == 2 + assert len(memory.index.get(message1.cause_by)) == 2 + + messages = memory.get_by_role("user1") + assert messages[0].content == message1.content + + messages = memory.get_by_content("test message") + assert len(messages) == 2 + + messages = memory.get_by_action(UserRequirement) + assert len(messages) == 2 + + messages = memory.get_by_actions([UserRequirement]) + assert len(messages) == 2 + + messages = memory.try_remember("test message") + assert len(messages) == 2 + + messages = memory.get(k=1) + assert len(messages) == 1 + + messages = memory.get(k=5) + assert len(messages) == 2 + + messages = memory.find_news([message3]) + assert len(messages) == 1 + + memory.delete(message1) + assert memory.count() == 1 + messages = memory.get_by_role("user2") + assert messages[0].content == message2.content + + memory.clear() + assert memory.count() == 0 + assert len(memory.index) == 0 From 6a65639cd7790f55dab143886f60aec8e0a032c1 Mon Sep 17 00:00:00 2001 From: better629 Date: Mon, 25 Dec 2023 14:38:20 +0800 Subject: [PATCH 2/3] update ltm unittest --- metagpt/memory/memory_storage.py | 28 ++++++++++++++------ tests/metagpt/memory/test_longterm_memory.py | 8 +++--- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py index 3017c23ad..1850e0ea0 100644 --- a/metagpt/memory/memory_storage.py +++ b/metagpt/memory/memory_storage.py @@ -6,9 +6,11 @@ """ from pathlib import Path -from typing import List +from typing import Optional +from langchain.embeddings import OpenAIEmbeddings from langchain.vectorstores.faiss import FAISS +from langchain_core.embeddings import Embeddings from metagpt.const import DATA_PATH, MEM_TTL from metagpt.document_store.faiss_store import FaissStore @@ -22,20 +24,30 @@ class MemoryStorage(FaissStore): The memory storage with Faiss as ANN search engine """ - def __init__(self, mem_ttl: int = MEM_TTL): + def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None): self.role_id: str = None self.role_mem_path: str = None self.mem_ttl: int = mem_ttl # later use self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories self._initialized: bool = False + self.embedding = embedding or OpenAIEmbeddings() self.store: FAISS = None # Faiss engine @property def is_initialized(self) -> bool: return self._initialized - def recover_memory(self, role_id: str) -> List[Message]: + def _load(self) -> Optional["FaissStore"]: + index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss + + if not (index_file.exists() and store_file.exists()): + logger.info("Missing at least one of index_file/store_file, load failed and return None") + return None + + return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id) + + def recover_memory(self, role_id: str) -> list[Message]: self.role_id = role_id self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/") self.role_mem_path.mkdir(parents=True, exist_ok=True) @@ -52,16 +64,16 @@ class MemoryStorage(FaissStore): return messages - def _get_index_and_store_fname(self): + def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"): if not self.role_mem_path: logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory") return None, None - index_fpath = Path(self.role_mem_path / f"{self.role_id}.index") - storage_fpath = Path(self.role_mem_path / f"{self.role_id}.pkl") + index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}") + storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}") return index_fpath, storage_fpath def persist(self): - super().persist() + self.store.save_local(self.role_mem_path, self.role_id) logger.debug(f"Agent {self.role_id} persist memory into local") def add(self, message: Message) -> bool: @@ -77,7 +89,7 @@ class MemoryStorage(FaissStore): self.persist() logger.info(f"Agent {self.role_id}'s memory_storage add a message") - def search_dissimilar(self, message: Message, k=4) -> List[Message]: + def search_dissimilar(self, message: Message, k=4) -> list[Message]: """search for dissimilar messages""" if not self.store: return [] diff --git a/tests/metagpt/memory/test_longterm_memory.py b/tests/metagpt/memory/test_longterm_memory.py index 1f07d74e3..ac33552b3 100644 --- a/tests/metagpt/memory/test_longterm_memory.py +++ b/tests/metagpt/memory/test_longterm_memory.py @@ -5,6 +5,8 @@ @Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. """ +import os + from metagpt.actions import UserRequirement from metagpt.config import CONFIG from metagpt.memory.longterm_memory import LongTermMemory @@ -14,11 +16,11 @@ from metagpt.schema import Message def test_ltm_search(): assert hasattr(CONFIG, "long_term_memory") is True - openai_api_key = CONFIG.openai_api_key - assert len(openai_api_key) > 20 + os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key) + assert len(CONFIG.openai_api_key) > 20 role_id = "UTUserLtm(Product Manager)" - rc = RoleContext(watch=[UserRequirement]) + rc = RoleContext(watch={"metagpt.actions.add_requirement.UserRequirement"}) ltm = LongTermMemory() ltm.recover_memory(role_id, rc) From 44eec631ea18575b79b7e4638c5f244c1151400d Mon Sep 17 00:00:00 2001 From: better629 Date: Mon, 25 Dec 2023 14:47:42 +0800 Subject: [PATCH 3/3] update MemoryStorage unittest --- tests/metagpt/memory/test_memory_storage.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index 7b74eb512..f1cc12aac 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -4,20 +4,28 @@ @Desc : the unittests of metagpt/memory/memory_storage.py """ - +import os +import shutil +from pathlib import Path from typing import List from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.action_node import ActionNode +from metagpt.config import CONFIG +from metagpt.const import DATA_PATH from metagpt.memory.memory_storage import MemoryStorage from metagpt.schema import Message +os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key) + def test_idea_message(): idea = "Write a cli snake game" role_id = "UTUser1(Product Manager)" message = Message(role="User", content=idea, cause_by=UserRequirement) + shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/")) + memory_storage: MemoryStorage = MemoryStorage() messages = memory_storage.recover_memory(role_id) assert len(messages) == 0 @@ -27,12 +35,12 @@ def test_idea_message(): sim_idea = "Write a game of cli snake" sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement) - new_messages = memory_storage.search(sim_message) + new_messages = memory_storage.search_dissimilar(sim_message) assert len(new_messages) == 0 # similar, return [] new_idea = "Write a 2048 web game" new_message = Message(role="User", content=new_idea, cause_by=UserRequirement) - new_messages = memory_storage.search(new_message) + new_messages = memory_storage.search_dissimilar(new_message) assert new_messages[0].content == message.content memory_storage.clean() @@ -50,6 +58,8 @@ def test_actionout_message(): content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD ) # WritePRD as test action + shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/")) + memory_storage: MemoryStorage = MemoryStorage() messages = memory_storage.recover_memory(role_id) assert len(messages) == 0 @@ -59,12 +69,12 @@ def test_actionout_message(): sim_conent = "The request is command-line interface (CLI) snake game" sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD) - new_messages = memory_storage.search(sim_message) + new_messages = memory_storage.search_dissimilar(sim_message) assert len(new_messages) == 0 # similar, return [] new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty" new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD) - new_messages = memory_storage.search(new_message) + new_messages = memory_storage.search_dissimilar(new_message) assert new_messages[0].content == message.content memory_storage.clean()