This commit is contained in:
betterwang 2024-03-07 23:02:06 +08:00
parent f14fee9b0d
commit c2a280d72e
5 changed files with 45 additions and 57 deletions

View file

@ -4,25 +4,28 @@
@Desc : unittest of `metagpt/memory/longterm_memory.py`
"""
import os
import pytest
from metagpt.actions import UserRequirement
from metagpt.config2 import config
from metagpt.memory.longterm_memory import LongTermMemory
from metagpt.roles.role import RoleContext
from metagpt.schema import Message
from tests.metagpt.memory.mock_text_embed import text_embed_arr
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
from tests.metagpt.memory.mock_text_embed import (
mock_openai_aembed_document,
mock_openai_embed_document,
mock_openai_embed_documents,
text_embed_arr,
)
@pytest.mark.asyncio
async def test_ltm_search(mocker):
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
mocker.patch(
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
)
role_id = "UTUserLtm(Product Manager)"
from metagpt.environment import Environment
@ -36,7 +39,7 @@ async def test_ltm_search(mocker):
idea = text_embed_arr[0].get("text", "Write a cli snake game")
message = Message(role="User", content=idea, cause_by=UserRequirement)
news = await ltm.find_news([message])
assert len(news) == 0
assert len(news) == 1
ltm.add(message)
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
@ -51,24 +54,8 @@ async def test_ltm_search(mocker):
news = await ltm.find_news([new_message])
assert len(news) == 1
ltm.add(new_message)
ltm.persit()
# restore from local index
ltm_new = LongTermMemory()
ltm_new.recover_memory(role_id, rc)
news = await ltm_new.find_news([message])
assert len(news) == 0
ltm_new.recover_memory(role_id, rc)
news = await ltm_new.find_news([sim_message])
assert len(news) == 0
new_idea = text_embed_arr[3].get("text", "Write a Battle City")
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
news = await ltm_new.find_news([new_message])
assert len(news) == 1
ltm_new.clear()
ltm.clear()
if __name__ == "__main__":

View file

@ -4,7 +4,6 @@
@Desc : the unittests of metagpt/memory/memory_storage.py
"""
import os
import shutil
from pathlib import Path
from typing import List
@ -13,20 +12,24 @@ import pytest
from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.action_node import ActionNode
from metagpt.config2 import config
from metagpt.const import DATA_PATH
from metagpt.memory.memory_storage import MemoryStorage
from metagpt.schema import Message
from tests.metagpt.memory.mock_text_embed import text_embed_arr
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
from tests.metagpt.memory.mock_text_embed import (
mock_openai_aembed_document,
mock_openai_embed_document,
mock_openai_embed_documents,
text_embed_arr,
)
@pytest.mark.asyncio
async def test_idea_message(mocker):
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
mocker.patch(
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
)
idea = text_embed_arr[0].get("text", "Write a cli snake game")
role_id = "UTUser1(Product Manager)"
@ -42,13 +45,13 @@ async def test_idea_message(mocker):
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
new_messages = await memory_storage.search_dissimilar(sim_message)
assert len(new_messages) == 0 # similar, return []
new_messages = await memory_storage.search_similar(sim_message)
assert len(new_messages) == 1 # similar, return []
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
new_messages = await memory_storage.search_dissimilar(new_message)
assert new_messages[0].content == message.content
new_messages = await memory_storage.search_similar(new_message)
assert len(new_messages) == 0
memory_storage.clean()
assert memory_storage.is_initialized is False
@ -56,9 +59,11 @@ async def test_idea_message(mocker):
@pytest.mark.asyncio
async def test_actionout_message(mocker):
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
# mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
mocker.patch(
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
)
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
@ -82,15 +87,15 @@ async def test_actionout_message(mocker):
sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game")
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = await memory_storage.search_dissimilar(sim_message)
assert len(new_messages) == 0 # similar, return []
new_messages = await memory_storage.search_similar(sim_message)
assert len(new_messages) == 1 # similar, return []
new_conent = text_embed_arr[6].get(
"text", "Incorporate basic features of a snake game such as scoring and increasing difficulty"
)
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = await memory_storage.search_dissimilar(new_message)
assert new_messages[0].content == message.content
new_messages = await memory_storage.search_similar(new_message)
assert len(new_messages) == 0
memory_storage.clean()
assert memory_storage.is_initialized is False