rag pipeline

This commit is contained in:
seehi 2024-01-30 20:19:50 +08:00 committed by betterwang
parent cc91df59e5
commit 29d36948bf
18 changed files with 372 additions and 15 deletions

View file

@ -28,7 +28,7 @@ def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int
async def test_search_json(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
store = FaissStore(EXAMPLE_PATH / "example.json")
store = FaissStore(EXAMPLE_PATH / "data/example.json")
role = Sales(profile="Sales", store=store)
query = "Which facial cleanser is good for oily skin?"
result = await role.run(query)
@ -39,7 +39,7 @@ async def test_search_json(mocker):
async def test_search_xlsx(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
store = FaissStore(EXAMPLE_PATH / "data/example.xlsx")
role = Sales(profile="Sales", store=store)
query = "Which facial cleanser is good for oily skin?"
result = await role.run(query)
@ -50,7 +50,7 @@ async def test_search_xlsx(mocker):
async def test_write(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
store = FaissStore(EXAMPLE_PATH / "data/example.xlsx", meta_col="Answer", content_col="Question")
_faiss_store = store.write()
assert _faiss_store.storage_context.docstore
assert _faiss_store.storage_context.vector_store.client

View file

@ -0,0 +1,67 @@
from unittest.mock import AsyncMock
import pytest
from metagpt.rag import SimpleEngine
class TestSimpleEngineFromDocs:
def test_from_docs(self, mocker):
# Mock dependencies
mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"]
mock_service_context = mocker.patch("metagpt.rag.engines.simple.ServiceContext.from_defaults")
mock_vector_store_index = mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
mock_vector_index_retriever = mocker.patch("metagpt.rag.engines.simple.VectorIndexRetriever")
# Setup
input_dir = "test_dir"
input_files = ["test_file1", "test_file2"]
embed_model = mocker.MagicMock()
llm = mocker.MagicMock()
chunk_size = 100
chunk_overlap = 10
similarity_top_k = 5
# Execute
engine = SimpleEngine.from_docs(
input_dir=input_dir,
input_files=input_files,
embed_model=embed_model,
llm=llm,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
similarity_top_k=similarity_top_k,
)
# Assertions
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_service_context.assert_called_once_with(
embed_model=embed_model, chunk_size=chunk_size, chunk_overlap=chunk_overlap, llm=llm
)
mock_vector_store_index.assert_called_once_with(
["document1", "document2"], service_context=mock_service_context.return_value
)
mock_vector_index_retriever.assert_called_once_with(
index=mock_vector_store_index.return_value, similarity_top_k=similarity_top_k
)
assert isinstance(engine, SimpleEngine)
@pytest.mark.asyncio
async def test_asearch_calls_aquery(self, mocker):
# Mock
test_query = "test query"
expected_result = "expected result"
mock_aquery = AsyncMock(return_value=expected_result)
# Setup
engine = SimpleEngine(retriever=mocker.MagicMock())
engine.aquery = mock_aquery
# Execute
result = await engine.asearch(test_query)
# Assertions
mock_aquery.assert_called_once_with(test_query)
assert result == expected_result

View file

@ -0,0 +1,39 @@
from unittest.mock import AsyncMock
import pytest
from llama_index.schema import NodeWithScore, TextNode
from metagpt.rag.retrievers import SimpleHybridRetriever
class TestSimpleHybridRetriever:
@pytest.mark.asyncio
async def test_aretrieve(self):
question = "test query"
# Create mock retrievers
mock_retriever1 = AsyncMock()
mock_retriever1.aretrieve.return_value = [
NodeWithScore(node=TextNode(id_="1"), score=1.0),
NodeWithScore(node=TextNode(id_="2"), score=0.95),
]
mock_retriever2 = AsyncMock()
mock_retriever2.aretrieve.return_value = [
NodeWithScore(node=TextNode(id_="2"), score=0.95),
NodeWithScore(node=TextNode(id_="3"), score=0.8),
]
# Instantiate the SimpleHybridRetriever with the mock retrievers
hybrid_retriever = SimpleHybridRetriever(mock_retriever1, mock_retriever2)
# Call the _aretrieve method
results = await hybrid_retriever._aretrieve(question)
# Check if the results are as expected
assert len(results) == 3 # Should be 3 unique nodes
assert set(node.node.node_id for node in results) == {"1", "2", "3"}
# Check if the scores are correct (assuming you want the highest score)
node_scores = {node.node.node_id: node.score for node in results}
assert node_scores["2"] == 0.95