Merge branch 'dev' of https://github.com/geekan/MetaGPT into geekan/dev

This commit is contained in:
莘权 马 2023-12-25 16:15:04 +08:00
commit fe1d60f111
5 changed files with 97 additions and 24 deletions

View file

@ -135,11 +135,3 @@ class Memory(BaseModel):
continue
rsp += self.index[action]
return rsp
def get_by_tags(self, tags: list) -> list[Message]:
"""Return messages with specified tags"""
result = []
for m in self.storage:
if m.is_contain_tags(tags):
result.append(m)
return result

View file

@ -6,9 +6,11 @@
"""
from pathlib import Path
from typing import List
from typing import Optional
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from langchain_core.embeddings import Embeddings
from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
@ -22,20 +24,30 @@ class MemoryStorage(FaissStore):
The memory storage with Faiss as ANN search engine
"""
def __init__(self, mem_ttl: int = MEM_TTL):
def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None):
self.role_id: str = None
self.role_mem_path: str = None
self.mem_ttl: int = mem_ttl # later use
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
self._initialized: bool = False
self.embedding = embedding or OpenAIEmbeddings()
self.store: FAISS = None # Faiss engine
@property
def is_initialized(self) -> bool:
return self._initialized
def recover_memory(self, role_id: str) -> List[Message]:
def _load(self) -> Optional["FaissStore"]:
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
if not (index_file.exists() and store_file.exists()):
logger.info("Missing at least one of index_file/store_file, load failed and return None")
return None
return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id)
def recover_memory(self, role_id: str) -> list[Message]:
self.role_id = role_id
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
self.role_mem_path.mkdir(parents=True, exist_ok=True)
@ -52,16 +64,16 @@ class MemoryStorage(FaissStore):
return messages
def _get_index_and_store_fname(self):
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
if not self.role_mem_path:
logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory")
return None, None
index_fpath = Path(self.role_mem_path / f"{self.role_id}.index")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}.pkl")
index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
return index_fpath, storage_fpath
def persist(self):
super().persist()
self.store.save_local(self.role_mem_path, self.role_id)
logger.debug(f"Agent {self.role_id} persist memory into local")
def add(self, message: Message) -> bool:
@ -77,7 +89,7 @@ class MemoryStorage(FaissStore):
self.persist()
logger.info(f"Agent {self.role_id}'s memory_storage add a message")
def search_dissimilar(self, message: Message, k=4) -> List[Message]:
def search_dissimilar(self, message: Message, k=4) -> list[Message]:
"""search for dissimilar messages"""
if not self.store:
return []

View file

@ -5,6 +5,8 @@
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
import os
from metagpt.actions import UserRequirement
from metagpt.config import CONFIG
from metagpt.memory.longterm_memory import LongTermMemory
@ -14,11 +16,11 @@ from metagpt.schema import Message
def test_ltm_search():
assert hasattr(CONFIG, "long_term_memory") is True
openai_api_key = CONFIG.openai_api_key
assert len(openai_api_key) > 20
os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key)
assert len(CONFIG.openai_api_key) > 20
role_id = "UTUserLtm(Product Manager)"
rc = RoleContext(watch=[UserRequirement])
rc = RoleContext(watch={"metagpt.actions.add_requirement.UserRequirement"})
ltm = LongTermMemory()
ltm.recover_memory(role_id, rc)

View file

@ -0,0 +1,57 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of Memory
from metagpt.actions import UserRequirement
from metagpt.memory.memory import Memory
from metagpt.schema import Message
def test_memory():
memory = Memory()
message1 = Message(content="test message1", role="user1")
message2 = Message(content="test message2", role="user2")
message3 = Message(content="test message3", role="user1")
memory.add(message1)
assert memory.count() == 1
memory.delete_newest()
assert memory.count() == 0
memory.add_batch([message1, message2])
assert memory.count() == 2
assert len(memory.index.get(message1.cause_by)) == 2
messages = memory.get_by_role("user1")
assert messages[0].content == message1.content
messages = memory.get_by_content("test message")
assert len(messages) == 2
messages = memory.get_by_action(UserRequirement)
assert len(messages) == 2
messages = memory.get_by_actions([UserRequirement])
assert len(messages) == 2
messages = memory.try_remember("test message")
assert len(messages) == 2
messages = memory.get(k=1)
assert len(messages) == 1
messages = memory.get(k=5)
assert len(messages) == 2
messages = memory.find_news([message3])
assert len(messages) == 1
memory.delete(message1)
assert memory.count() == 1
messages = memory.get_by_role("user2")
assert messages[0].content == message2.content
memory.clear()
assert memory.count() == 0
assert len(memory.index) == 0

View file

@ -4,20 +4,28 @@
@Desc : the unittests of metagpt/memory/memory_storage.py
"""
import os
import shutil
from pathlib import Path
from typing import List
from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.action_node import ActionNode
from metagpt.config import CONFIG
from metagpt.const import DATA_PATH
from metagpt.memory.memory_storage import MemoryStorage
from metagpt.schema import Message
os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key)
def test_idea_message():
idea = "Write a cli snake game"
role_id = "UTUser1(Product Manager)"
message = Message(role="User", content=idea, cause_by=UserRequirement)
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"))
memory_storage: MemoryStorage = MemoryStorage()
messages = memory_storage.recover_memory(role_id)
assert len(messages) == 0
@ -27,12 +35,12 @@ def test_idea_message():
sim_idea = "Write a game of cli snake"
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
new_messages = memory_storage.search(sim_message)
new_messages = memory_storage.search_dissimilar(sim_message)
assert len(new_messages) == 0 # similar, return []
new_idea = "Write a 2048 web game"
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
new_messages = memory_storage.search(new_message)
new_messages = memory_storage.search_dissimilar(new_message)
assert new_messages[0].content == message.content
memory_storage.clean()
@ -50,6 +58,8 @@ def test_actionout_message():
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
) # WritePRD as test action
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"))
memory_storage: MemoryStorage = MemoryStorage()
messages = memory_storage.recover_memory(role_id)
assert len(messages) == 0
@ -59,12 +69,12 @@ def test_actionout_message():
sim_conent = "The request is command-line interface (CLI) snake game"
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search(sim_message)
new_messages = memory_storage.search_dissimilar(sim_message)
assert len(new_messages) == 0 # similar, return []
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search(new_message)
new_messages = memory_storage.search_dissimilar(new_message)
assert new_messages[0].content == message.content
memory_storage.clean()