mirror of https://github.com/FoundationAgents/MetaGPT.git

commit fb6b9e2928 (parent 716cb1a0c5)

    fix

5 changed files with 34 additions and 44 deletions
@@ -61,8 +61,8 @@ class LongTermMemory(Memory):
         ltm_news: list[Message] = []
         for mem in stm_news:
             # filter out messages similar to those seen previously in ltm, only keep fresh news
-            mem_searched = await self.memory_storage.search_dissimilar(mem)
-            if len(mem_searched) > 0:
+            mem_searched = await self.memory_storage.search_similar(mem)
+            if len(mem_searched) == 0:
                 ltm_news.append(mem)
         return ltm_news[-k:]
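The fix above inverts the freshness check in LongTermMemory.find_news: rather than asking storage for "dissimilar" messages and keeping a message when something comes back, it asks for similar ones and keeps the message only when nothing comes back. A self-contained sketch of that filter, with a stub similarity search standing in for the FAISS-backed search_similar (the scoring function and 0.7 threshold here are illustrative only):

    from difflib import SequenceMatcher

    def search_similar_stub(stored: list[str], text: str, threshold: float = 0.7) -> list[str]:
        """Stand-in for MemoryStorage.search_similar: stored texts whose score passes the threshold."""
        return [s for s in stored if SequenceMatcher(None, s, text).ratio() >= threshold]

    def find_news_stub(stored: list[str], stm_news: list[str], k: int = 0) -> list[str]:
        """Keep only messages with no similar counterpart already in storage."""
        ltm_news = [mem for mem in stm_news if not search_similar_stub(stored, mem)]
        return ltm_news[-k:]  # k=0 slices to the full list, as in the source

    print(find_news_stub(["Write a cli snake game"],
                         ["Write a game of cli snake", "Write a Battle City"]))
    # likely ['Write a Battle City']: the near-duplicate is filtered, the fresh idea kept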
@@ -58,16 +58,14 @@ class MemoryStorage(object):
         self.faiss_engine.add_objs([message])
         logger.info(f"Role {self.role_id}'s memory_storage add a message")

-    async def search_dissimilar(self, message: Message, k=4) -> list[Message]:
-        """search for dissimilar messages"""
+    async def search_similar(self, message: Message, k=4) -> list[Message]:
+        """search for similar messages"""
         # filter the result which score is smaller than the threshold
         filtered_resp = []
         resp = await self.faiss_engine.aretrieve(message.content)
         for item in resp:
-            print(" item.score ", item.score, item)
             if item.score < self.threshold:
                 continue
             filtered_resp.append(item.metadata.get("obj"))
         return filtered_resp

     def clean(self):
@@ -76,4 +74,4 @@ class MemoryStorage(object):

     def persit(self):
         if self.faiss_engine:
-            self.faiss_engine.index.storage_context.persist(self.cache_dir)
+            self.faiss_engine.retriever._index.storage_context.persist(self.cache_dir)
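Three things change in MemoryStorage: the retrieval method is renamed from search_dissimilar to search_similar to match what it actually returns (hits at or above the similarity threshold), a leftover debug print is dropped, and persit (spelled that way in the source) reaches the index through the engine's retriever instead of an index attribute the engine does not expose. A minimal sketch of the threshold filter, with a dataclass standing in for llama_index's NodeWithScore (the 0.92/0.31 scores and 0.7 threshold are illustrative):

    from dataclasses import dataclass, field
    from typing import Any

    @dataclass
    class ScoredItem:
        """Stand-in for a retrieved node: a similarity score plus metadata."""
        score: float
        metadata: dict[str, Any] = field(default_factory=dict)

    def filter_by_threshold(resp: list[ScoredItem], threshold: float = 0.7) -> list[Any]:
        """Keep hits whose score passes the threshold; unwrap the stored object."""
        filtered = []
        for item in resp:
            if item.score < threshold:
                continue
            filtered.append(item.metadata.get("obj"))
        return filtered

    hits = [ScoredItem(0.92, {"obj": "similar msg"}), ScoredItem(0.31, {"obj": "unrelated"})]
    print(filter_by_threshold(hits))  # -> ['similar msg']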
@@ -29,10 +29,8 @@ class RAGIndexFactory(ConfigFactory):
         embed_model = self.extract_embed_model(config, **kwargs)

         vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
-        storage_context = StorageContext.from_defaults(
-            vector_store=vector_store, persist_dir=config.persist_path, embed_mode=embed_model
-        )
-        index = load_index_from_storage(storage_context=storage_context)
+        storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)
+        index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
         return index

     def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
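The FAISS branch of RAGIndexFactory was handing the embedding model to StorageContext.from_defaults as embed_mode, a keyword that method does not accept, and then loading the index without any model at all. The fix passes embed_model to load_index_from_storage, which forwards extra kwargs to the reconstructed index. A hedged sketch of the corrected loading pattern (persist path and model are placeholders; assumes llama-index with the FAISS vector-store package installed):

    from llama_index.core import StorageContext, load_index_from_storage
    from llama_index.core.embeddings import MockEmbedding
    from llama_index.vector_stores.faiss import FaissVectorStore

    persist_path = "./storage"  # placeholder: wherever the index was persisted
    embed_model = MockEmbedding(embed_dim=1536)  # placeholder for the real model

    vector_store = FaissVectorStore.from_persist_dir(persist_path)
    storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=persist_path)
    # embed_model must match the model used at build time, or query vectors
    # will live in a different space than the stored ones.
    index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)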
@@ -12,9 +12,9 @@ from metagpt.memory.longterm_memory import LongTermMemory
 from metagpt.roles.role import RoleContext
 from metagpt.schema import Message
 from tests.metagpt.memory.mock_text_embed import (
-    mock_openai_embed_documents,
-    mock_openai_embed_document,
     mock_openai_aembed_document,
+    mock_openai_embed_document,
+    mock_openai_embed_documents,
     text_embed_arr,
 )
@@ -23,7 +23,9 @@ from tests.metagpt.memory.mock_text_embed import (
 async def test_ltm_search(mocker):
     mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
     mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
-    mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
+    mocker.patch(
+        "llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
+    )

     role_id = "UTUserLtm(Product Manager)"
     from metagpt.environment import Environment
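Both test modules stub the OpenAI embedding entry points so FAISS retrieval runs deterministically and offline: pytest-mock's mocker.patch replaces OpenAIEmbedding's private hooks with table-driven fakes from tests/metagpt/memory/mock_text_embed. A minimal sketch of the same pattern (the fake below is illustrative, not MetaGPT's implementation; the patched attribute path is taken from the diff):

    def fake_embed_document(self, text: str) -> list[float]:
        return [0.1] * 1536  # fixed-size dummy vector, deterministic by design

    def test_with_mocked_embeddings(mocker):
        mocker.patch(
            "llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding",
            fake_embed_document,
        )
        # ...build MemoryStorage / LongTermMemory here and exercise
        # search_similar without any network call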
@@ -37,7 +39,7 @@ async def test_ltm_search(mocker):
     idea = text_embed_arr[0].get("text", "Write a cli snake game")
     message = Message(role="User", content=idea, cause_by=UserRequirement)
     news = await ltm.find_news([message])
-    assert len(news) == 0
+    assert len(news) == 1
     ltm.add(message)

     sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
@@ -52,24 +54,8 @@ async def test_ltm_search(mocker):
     news = await ltm.find_news([new_message])
     assert len(news) == 1
     ltm.add(new_message)
-    ltm.persit()
-
-    # restore from local index
-    ltm_new = LongTermMemory()
-    ltm_new.recover_memory(role_id, rc)
-    news = await ltm_new.find_news([message])
-    assert len(news) == 0
-
-    ltm_new.recover_memory(role_id, rc)
-    news = await ltm_new.find_news([sim_message])
-    assert len(news) == 0
-
-    new_idea = text_embed_arr[3].get("text", "Write a Battle City")
-    new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
-    news = await ltm_new.find_news([new_message])
-    assert len(news) == 1
-
-    ltm_new.clear()
+    ltm.clear()


 if __name__ == "__main__":
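With search_similar in place the test's expectations invert: a message is news precisely when nothing similar is stored yet, so the first lookup returns 1 instead of 0, and the persist-and-recover sequence gives way to a plain ltm.clear(). The post-fix contract, restated as a sketch (assumes an initialized ltm and the message fixtures from the test):

    async def expected_behavior(ltm, message, sim_message):
        """Sketch of the contract the updated assertions encode."""
        news = await ltm.find_news([message])      # nothing similar stored yet
        assert len(news) == 1                      # -> counts as fresh news
        ltm.add(message)

        news = await ltm.find_news([sim_message])  # near-duplicate of message
        assert len(news) == 0                      # -> filtered as already known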
@@ -16,9 +16,9 @@ from metagpt.const import DATA_PATH
 from metagpt.memory.memory_storage import MemoryStorage
 from metagpt.schema import Message
 from tests.metagpt.memory.mock_text_embed import (
-    mock_openai_embed_documents,
-    mock_openai_embed_document,
     mock_openai_aembed_document,
+    mock_openai_embed_document,
+    mock_openai_embed_documents,
     text_embed_arr,
 )
@@ -27,7 +27,9 @@ from tests.metagpt.memory.mock_text_embed import (
 async def test_idea_message(mocker):
     mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
     mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
-    mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
+    mocker.patch(
+        "llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
+    )

     idea = text_embed_arr[0].get("text", "Write a cli snake game")
     role_id = "UTUser1(Product Manager)"
@@ -43,13 +45,13 @@ async def test_idea_message(mocker):

     sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
     sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
-    new_messages = await memory_storage.search_dissimilar(sim_message)
-    assert len(new_messages) == 0  # similar, return []
+    new_messages = await memory_storage.search_similar(sim_message)
+    assert len(new_messages) == 1  # similar message found, so it is returned

     new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
     new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
-    new_messages = await memory_storage.search_dissimilar(new_message)
-    assert new_messages[0].content == message.content
+    new_messages = await memory_storage.search_similar(new_message)
+    assert len(new_messages) == 0

     memory_storage.clean()
     assert memory_storage.is_initialized is False
@@ -59,7 +61,9 @@
 async def test_actionout_message(mocker):
     mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
     mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
-    mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
+    mocker.patch(
+        "llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
+    )

     out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
     out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
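test_actionout_message wraps structured action output into the message: out_mapping declares field names and types, and ic_obj(**out_data) instantiates a dynamically built model. A hedged sketch of that pattern using plain pydantic's create_model (MetaGPT's own helper for building ic_obj may differ):

    from typing import List
    from pydantic import create_model

    out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
    out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}

    # Build a model class from the mapping, then instantiate it with the data;
    # the resulting object plays the role of ic_obj(**out_data) in the test.
    ic_obj = create_model("ActionOut", **out_mapping)
    instruct_content = ic_obj(**out_data)
    print(instruct_content.field2)  # -> ['field2 value1', 'field2 value2']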
@@ -83,15 +87,15 @@ async def test_actionout_message(mocker):

     sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game")
     sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
-    new_messages = await memory_storage.search_dissimilar(sim_message)
-    assert len(new_messages) == 0  # similar, return []
+    new_messages = await memory_storage.search_similar(sim_message)
+    assert len(new_messages) == 1  # similar message found, so it is returned

     new_conent = text_embed_arr[6].get(
         "text", "Incorporate basic features of a snake game such as scoring and increasing difficulty"
     )
     new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
-    new_messages = await memory_storage.search_dissimilar(new_message)
-    assert new_messages[0].content == message.content
+    new_messages = await memory_storage.search_similar(new_message)
+    assert len(new_messages) == 0

     memory_storage.clean()
     assert memory_storage.is_initialized is False