From 9ec56263133f916269c04355f925ce8140217790 Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 27 Feb 2024 14:08:05 +0800 Subject: [PATCH] mock openai embed for document_store and memory UTs --- metagpt/memory/memory_storage.py | 4 +-- .../document_store/test_faiss_store.py | 22 +++++++++++-- tests/metagpt/memory/mock_text_embed.py | 33 +++++++++++++++++++ tests/metagpt/memory/test_longterm_memory.py | 20 ++++++----- tests/metagpt/memory/test_memory_storage.py | 32 +++++++++++------- 5 files changed, 85 insertions(+), 26 deletions(-) create mode 100644 tests/metagpt/memory/mock_text_embed.py diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py index c029d027b..fa04d8138 100644 --- a/metagpt/memory/memory_storage.py +++ b/metagpt/memory/memory_storage.py @@ -7,7 +7,6 @@ from pathlib import Path from typing import Optional -from langchain.embeddings import OpenAIEmbeddings from langchain.vectorstores.faiss import FAISS from langchain_core.embeddings import Embeddings @@ -15,6 +14,7 @@ from metagpt.const import DATA_PATH, MEM_TTL from metagpt.document_store.faiss_store import FaissStore from metagpt.logs import logger from metagpt.schema import Message +from metagpt.utils.embedding import get_embedding from metagpt.utils.serialize import deserialize_message, serialize_message @@ -30,7 +30,7 @@ class MemoryStorage(FaissStore): self.threshold: float = 0.1 # experience value. 
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
    """Offline stand-in for ``OpenAIEmbeddings.embed_documents``.

    The tests install this with ``mocker.patch`` in place of the real method,
    which is why it takes ``self`` even though it never uses it.

    Args:
        self: The embeddings instance the patched method is bound to (unused).
        texts: Texts to embed; only their count matters.
        chunk_size: Accepted for signature compatibility; ignored.

    Returns:
        One random, standardised 1536-dimensional vector per input text, as
        plain nested lists of floats. An empty ``texts`` yields ``[]``.
    """
    dim = 1536  # openai embedding dim
    embeds = np.random.randint(1, 100, size=(len(texts), dim)).astype(float)

    # Standardise every vector over its own components (axis=1). Doing it
    # across the batch (axis=0) divides 0 by 0 as soon as a single text is
    # embedded, and the resulting vector is all NaN.
    std = embeds.std(axis=1, keepdims=True)
    std[std == 0] = 1.0  # a constant vector would otherwise divide by zero
    embeds = (embeds - embeds.mean(axis=1, keepdims=True)) / std

    # Hand back nested lists so the value matches the declared return type.
    return embeds.tolist()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc   : fixed text -> embedding table used to mock OpenAI embeddings in the memory unit tests

from typing import Optional

import numpy as np

dim = 1536  # openai embedding dim

# Mock data. The memory tests expect texts that share a vector to be reported
# as similar and texts with different vectors as new: the all-zeros entries
# form one group, the all-ones entries the other.
text_embed_arr = [
    {"text": "Write a cli snake game", "embed": np.zeros(shape=[1, dim])},
    {"text": "Write a game of cli snake", "embed": np.zeros(shape=[1, dim])},
    {"text": "Write a 2048 web game", "embed": np.ones(shape=[1, dim])},
    {"text": "Write a Battle City", "embed": np.ones(shape=[1, dim])},
    {
        "text": "The user has requested the creation of a command-line interface (CLI) snake game",
        "embed": np.zeros(shape=[1, dim]),
    },
    {"text": "The request is command-line interface (CLI) snake game", "embed": np.zeros(shape=[1, dim])},
    {
        "text": "Incorporate basic features of a snake game such as scoring and increasing difficulty",
        "embed": np.ones(shape=[1, dim]),
    },
]

# Reverse lookup: text -> position of its entry in ``text_embed_arr``.
text_idx_dict = {item["text"]: idx for idx, item in enumerate(text_embed_arr)}


def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
    """Offline stand-in for ``OpenAIEmbeddings.embed_documents``.

    The tests install this with ``mocker.patch`` in place of the real method,
    which is why it takes ``self`` even though it never uses it.

    Args:
        self: The embeddings instance the patched method is bound to (unused).
        texts: Texts to embed; each must be a ``"text"`` value registered in
            ``text_embed_arr``.
        chunk_size: Accepted for signature compatibility; ignored.

    Returns:
        One ``dim``-long list of floats per input text, in input order.
        An empty ``texts`` yields ``[]``.

    Raises:
        KeyError: If a text has no registered mock embedding.
    """
    embeds = []
    for text in texts:
        idx = text_idx_dict.get(text)
        if idx is None:
            # Fail with the offending text instead of an opaque
            # "list indices must be integers" TypeError.
            raise KeyError(f"no mock embedding registered for text: {text!r}")
        # Table entries have shape (1, dim); emit the single row as a flat list.
        embeds.append(text_embed_arr[idx]["embed"][0].tolist())
    return embeds
b/tests/metagpt/memory/test_longterm_memory.py @@ -4,20 +4,22 @@ @Desc : unittest of `metagpt/memory/longterm_memory.py` """ -import os import pytest from metagpt.actions import UserRequirement -from metagpt.config2 import config from metagpt.memory.longterm_memory import LongTermMemory from metagpt.roles.role import RoleContext from metagpt.schema import Message - -os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key) +from tests.metagpt.memory.mock_text_embed import ( + mock_openai_embed_documents, + text_embed_arr, +) -def test_ltm_search(): +def test_ltm_search(mocker): + mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) + role_id = "UTUserLtm(Product Manager)" from metagpt.environment import Environment @@ -27,20 +29,20 @@ def test_ltm_search(): ltm = LongTermMemory() ltm.recover_memory(role_id, rc) - idea = "Write a cli snake game" + idea = text_embed_arr[0].get("text", "Write a cli snake game") message = Message(role="User", content=idea, cause_by=UserRequirement) news = ltm.find_news([message]) assert len(news) == 1 ltm.add(message) - sim_idea = "Write a game of cli snake" + sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake") sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement) news = ltm.find_news([sim_message]) assert len(news) == 0 ltm.add(sim_message) - new_idea = "Write a 2048 web game" + new_idea = text_embed_arr[2].get("text", "Write a 2048 web game") new_message = Message(role="User", content=new_idea, cause_by=UserRequirement) news = ltm.find_news([new_message]) assert len(news) == 1 @@ -56,7 +58,7 @@ def test_ltm_search(): news = ltm_new.find_news([sim_message]) assert len(news) == 0 - new_idea = "Write a Battle City" + new_idea = text_embed_arr[3].get("text", "Write a Battle City") new_message = Message(role="User", content=new_idea, cause_by=UserRequirement) news = ltm_new.find_news([new_message]) assert len(news) == 1 
diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index e82a82fc8..28a73276b 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -4,23 +4,25 @@ @Desc : the unittests of metagpt/memory/memory_storage.py """ -import os import shutil from pathlib import Path from typing import List from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.action_node import ActionNode -from metagpt.config2 import config from metagpt.const import DATA_PATH from metagpt.memory.memory_storage import MemoryStorage from metagpt.schema import Message - -os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key) +from tests.metagpt.memory.mock_text_embed import ( + mock_openai_embed_documents, + text_embed_arr, +) -def test_idea_message(): - idea = "Write a cli snake game" +def test_idea_message(mocker): + mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) + + idea = text_embed_arr[0].get("text", "Write a cli snake game") role_id = "UTUser1(Product Manager)" message = Message(role="User", content=idea, cause_by=UserRequirement) @@ -33,12 +35,12 @@ def test_idea_message(): memory_storage.add(message) assert memory_storage.is_initialized is True - sim_idea = "Write a game of cli snake" + sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake") sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement) new_messages = memory_storage.search_dissimilar(sim_message) assert len(new_messages) == 0 # similar, return [] - new_idea = "Write a 2048 web game" + new_idea = text_embed_arr[2].get("text", "Write a 2048 web game") new_message = Message(role="User", content=new_idea, cause_by=UserRequirement) new_messages = memory_storage.search_dissimilar(new_message) assert new_messages[0].content == message.content @@ -47,13 +49,17 @@ def test_idea_message(): assert 
memory_storage.is_initialized is False -def test_actionout_message(): +def test_actionout_message(mocker): + mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) + out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]} ic_obj = ActionNode.create_model_class("prd", out_mapping) role_id = "UTUser2(Architect)" - content = "The user has requested the creation of a command-line interface (CLI) snake game" + content = text_embed_arr[4].get( + "text", "The user has requested the creation of a command-line interface (CLI) snake game" + ) message = Message( content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD ) # WritePRD as test action @@ -67,12 +73,14 @@ def test_actionout_message(): memory_storage.add(message) assert memory_storage.is_initialized is True - sim_conent = "The request is command-line interface (CLI) snake game" + sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game") sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD) new_messages = memory_storage.search_dissimilar(sim_message) assert len(new_messages) == 0 # similar, return [] - new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty" + new_conent = text_embed_arr[6].get( + "text", "Incorporate basic features of a snake game such as scoring and increasing difficulty" + ) new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD) new_messages = memory_storage.search_dissimilar(new_message) assert new_messages[0].content == message.content