From 44eec631ea18575b79b7e4638c5f244c1151400d Mon Sep 17 00:00:00 2001 From: better629 Date: Mon, 25 Dec 2023 14:47:42 +0800 Subject: [PATCH] update MemoryStorage unittest --- tests/metagpt/memory/test_memory_storage.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index 7b74eb512..f1cc12aac 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -4,20 +4,28 @@ @Desc : the unittests of metagpt/memory/memory_storage.py """ - +import os +import shutil +from pathlib import Path from typing import List from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.action_node import ActionNode +from metagpt.config import CONFIG +from metagpt.const import DATA_PATH from metagpt.memory.memory_storage import MemoryStorage from metagpt.schema import Message +os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key) + def test_idea_message(): idea = "Write a cli snake game" role_id = "UTUser1(Product Manager)" message = Message(role="User", content=idea, cause_by=UserRequirement) + shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/")) + memory_storage: MemoryStorage = MemoryStorage() messages = memory_storage.recover_memory(role_id) assert len(messages) == 0 @@ -27,12 +35,12 @@ def test_idea_message(): sim_idea = "Write a game of cli snake" sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement) - new_messages = memory_storage.search(sim_message) + new_messages = memory_storage.search_dissimilar(sim_message) assert len(new_messages) == 0 # similar, return [] new_idea = "Write a 2048 web game" new_message = Message(role="User", content=new_idea, cause_by=UserRequirement) - new_messages = memory_storage.search(new_message) + new_messages = memory_storage.search_dissimilar(new_message) assert new_messages[0].content == message.content memory_storage.clean() @@ -50,6 +58,8 @@ def test_actionout_message(): content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD ) # WritePRD as test action + shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/")) + memory_storage: MemoryStorage = MemoryStorage() messages = memory_storage.recover_memory(role_id) assert len(messages) == 0 @@ -59,12 +69,12 @@ def test_actionout_message(): sim_conent = "The request is command-line interface (CLI) snake game" sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD) - new_messages = memory_storage.search(sim_message) + new_messages = memory_storage.search_dissimilar(sim_message) assert len(new_messages) == 0 # similar, return [] new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty" new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD) - new_messages = memory_storage.search(new_message) + new_messages = memory_storage.search_dissimilar(new_message) assert new_messages[0].content == message.content memory_storage.clean()