mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-02 20:32:38 +02:00
rag pipeline
This commit is contained in:
parent
cc91df59e5
commit
29d36948bf
18 changed files with 372 additions and 15 deletions
|
|
@ -28,7 +28,7 @@ def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int
|
|||
async def test_search_json(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.json")
|
||||
store = FaissStore(EXAMPLE_PATH / "data/example.json")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
|
|
@ -39,7 +39,7 @@ async def test_search_json(mocker):
|
|||
async def test_search_xlsx(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
|
||||
store = FaissStore(EXAMPLE_PATH / "data/example.xlsx")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
|
|
@ -50,7 +50,7 @@ async def test_search_xlsx(mocker):
|
|||
async def test_write(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
|
||||
store = FaissStore(EXAMPLE_PATH / "data/example.xlsx", meta_col="Answer", content_col="Question")
|
||||
_faiss_store = store.write()
|
||||
assert _faiss_store.storage_context.docstore
|
||||
assert _faiss_store.storage_context.vector_store.client
|
||||
|
|
|
|||
67
tests/metagpt/rag/engine/test_simple.py
Normal file
67
tests/metagpt/rag/engine/test_simple.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.rag import SimpleEngine
|
||||
|
||||
|
||||
class TestSimpleEngineFromDocs:
    def test_from_docs(self, mocker):
        """SimpleEngine.from_docs wires reader, service context, index and retriever together."""
        # Patch every collaborator that from_docs is expected to construct.
        reader_cls = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
        reader_cls.return_value.load_data.return_value = ["document1", "document2"]

        service_ctx = mocker.patch("metagpt.rag.engines.simple.ServiceContext.from_defaults")
        index_factory = mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
        retriever_cls = mocker.patch("metagpt.rag.engines.simple.VectorIndexRetriever")

        # Inputs that must be passed through unchanged to each collaborator.
        input_dir = "test_dir"
        input_files = ["test_file1", "test_file2"]
        embed_model = mocker.MagicMock()
        llm = mocker.MagicMock()
        chunk_size = 100
        chunk_overlap = 10
        similarity_top_k = 5

        engine = SimpleEngine.from_docs(
            input_dir=input_dir,
            input_files=input_files,
            embed_model=embed_model,
            llm=llm,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            similarity_top_k=similarity_top_k,
        )

        # Each collaborator received exactly the arguments supplied above.
        reader_cls.assert_called_once_with(input_dir=input_dir, input_files=input_files)
        service_ctx.assert_called_once_with(
            embed_model=embed_model, chunk_size=chunk_size, chunk_overlap=chunk_overlap, llm=llm
        )
        index_factory.assert_called_once_with(
            ["document1", "document2"], service_context=service_ctx.return_value
        )
        retriever_cls.assert_called_once_with(
            index=index_factory.return_value, similarity_top_k=similarity_top_k
        )
        assert isinstance(engine, SimpleEngine)

    @pytest.mark.asyncio
    async def test_asearch_calls_aquery(self, mocker):
        """asearch should delegate to aquery and hand back its result unchanged."""
        query = "test query"
        expected = "expected result"
        aquery_stub = AsyncMock(return_value=expected)

        engine = SimpleEngine(retriever=mocker.MagicMock())
        engine.aquery = aquery_stub

        outcome = await engine.asearch(query)

        aquery_stub.assert_called_once_with(query)
        assert outcome == expected
|
||||
39
tests/metagpt/rag/retrievers/test_hybrid_retriever.py
Normal file
39
tests/metagpt/rag/retrievers/test_hybrid_retriever.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from llama_index.schema import NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
|
||||
|
||||
class TestSimpleHybridRetriever:
    @pytest.mark.asyncio
    async def test_aretrieve(self):
        """_aretrieve merges results from all sub-retrievers, de-duplicating by node id."""
        question = "test query"

        # Two stub retrievers whose result sets overlap on node "2".
        first = AsyncMock()
        first.aretrieve.return_value = [
            NodeWithScore(node=TextNode(id_="1"), score=1.0),
            NodeWithScore(node=TextNode(id_="2"), score=0.95),
        ]

        second = AsyncMock()
        second.aretrieve.return_value = [
            NodeWithScore(node=TextNode(id_="2"), score=0.95),
            NodeWithScore(node=TextNode(id_="3"), score=0.8),
        ]

        # Build the hybrid retriever over the two stubs and query it.
        hybrid = SimpleHybridRetriever(first, second)
        results = await hybrid._aretrieve(question)

        # The duplicated node must appear only once: 3 unique nodes total.
        assert len(results) == 3
        assert {item.node.node_id for item in results} == {"1", "2", "3"}

        # The shared node keeps its (identical) score from both sources.
        score_by_id = {item.node.node_id: item.score for item in results}
        assert score_by_id["2"] == 0.95
||||
Loading…
Add table
Add a link
Reference in a new issue