memory_storage use rag_pipeline

This commit is contained in:
betterwang 2024-03-07 22:07:04 +08:00
parent b669a7df80
commit 716cb1a0c5
7 changed files with 63 additions and 23 deletions

View file

@ -6,7 +6,10 @@
@File : test_faiss_store.py
"""
<<<<<<< HEAD
from typing import Optional
=======
>>>>>>> f14fee9b (memory_storage use rag_pipeline)
import numpy as np
import pytest
@ -17,16 +20,22 @@ from metagpt.logs import logger
from metagpt.roles import Sales
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]:
num = len(texts)
embeds = np.random.randint(1, 100, size=(num, 1536)) # 1536: openai embedding dim
embeds = (embeds - embeds.mean(axis=0)) / (embeds.std(axis=0))
return embeds
embeds = (embeds - embeds.mean(axis=0)) / embeds.std(axis=0)
return embeds.tolist()
def mock_openai_embed_document(self, text: str) -> list[float]:
embeds = mock_openai_embed_documents(self, [text])
return embeds[0]
@pytest.mark.asyncio
async def test_search_json(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.json")
role = Sales(profile="Sales", store=store)
@ -37,9 +46,10 @@ async def test_search_json(mocker):
@pytest.mark.asyncio
async def test_search_xlsx(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx")
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question")
role = Sales(profile="Sales", store=store)
query = "Which facial cleanser is good for oily skin?"
result = await role.run(query)
@ -48,7 +58,8 @@ async def test_search_xlsx(mocker):
@pytest.mark.asyncio
async def test_write(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question")
_faiss_store = store.write()

View file

@ -2,32 +2,42 @@
# -*- coding: utf-8 -*-
# @Desc :
from typing import Optional
import numpy as np
dim = 1536 # openai embedding dim
embed_zeros_arrr = np.zeros(shape=[1, dim]).tolist()
embed_ones_arrr = np.ones(shape=[1, dim]).tolist()
text_embed_arr = [
{"text": "Write a cli snake game", "embed": np.zeros(shape=[1, dim])}, # mock data, same as below
{"text": "Write a game of cli snake", "embed": np.zeros(shape=[1, dim])},
{"text": "Write a 2048 web game", "embed": np.ones(shape=[1, dim])},
{"text": "Write a Battle City", "embed": np.ones(shape=[1, dim])},
{"text": "Write a cli snake game", "embed": embed_zeros_arrr}, # mock data, same as below
{"text": "Write a game of cli snake", "embed": embed_zeros_arrr},
{"text": "Write a 2048 web game", "embed": embed_ones_arrr},
{"text": "Write a Battle City", "embed": embed_ones_arrr},
{
"text": "The user has requested the creation of a command-line interface (CLI) snake game",
"embed": np.zeros(shape=[1, dim]),
"embed": embed_zeros_arrr,
},
{"text": "The request is command-line interface (CLI) snake game", "embed": np.zeros(shape=[1, dim])},
{"text": "The request is command-line interface (CLI) snake game", "embed": embed_zeros_arrr},
{
"text": "Incorporate basic features of a snake game such as scoring and increasing difficulty",
"embed": np.ones(shape=[1, dim]),
"embed": embed_ones_arrr,
},
]
text_idx_dict = {item["text"]: idx for idx, item in enumerate(text_embed_arr)}
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]:
idx = text_idx_dict.get(texts[0])
embed = text_embed_arr[idx].get("embed")
return embed
def mock_openai_embed_document(self, text: str) -> list[float]:
embeds = mock_openai_embed_documents(self, [text])
return embeds[0]
async def mock_openai_aembed_document(self, text: str) -> list[float]:
return mock_openai_embed_document(self, text)

View file

@ -13,13 +13,17 @@ from metagpt.roles.role import RoleContext
from metagpt.schema import Message
from tests.metagpt.memory.mock_text_embed import (
mock_openai_embed_documents,
mock_openai_embed_document,
mock_openai_aembed_document,
text_embed_arr,
)
@pytest.mark.asyncio
async def test_ltm_search(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
role_id = "UTUserLtm(Product Manager)"
from metagpt.environment import Environment
@ -33,7 +37,7 @@ async def test_ltm_search(mocker):
idea = text_embed_arr[0].get("text", "Write a cli snake game")
message = Message(role="User", content=idea, cause_by=UserRequirement)
news = await ltm.find_news([message])
assert len(news) == 1
assert len(news) == 0
ltm.add(message)
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
@ -48,6 +52,7 @@ async def test_ltm_search(mocker):
news = await ltm.find_news([new_message])
assert len(news) == 1
ltm.add(new_message)
ltm.persit()
# restore from local index
ltm_new = LongTermMemory()

View file

@ -17,13 +17,17 @@ from metagpt.memory.memory_storage import MemoryStorage
from metagpt.schema import Message
from tests.metagpt.memory.mock_text_embed import (
mock_openai_embed_documents,
mock_openai_embed_document,
mock_openai_aembed_document,
text_embed_arr,
)
@pytest.mark.asyncio
async def test_idea_message(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
idea = text_embed_arr[0].get("text", "Write a cli snake game")
role_id = "UTUser1(Product Manager)"
@ -53,7 +57,9 @@ async def test_idea_message(mocker):
@pytest.mark.asyncio
async def test_actionout_message(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document)
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}