upgrade llama-index-vector-stores-chroma and rag test coverage 100%

This commit is contained in:
seehi 2024-03-27 19:53:50 +08:00
parent 6450a09d3b
commit 6cb5492f02
21 changed files with 600 additions and 382 deletions

View file

@ -1,12 +1,26 @@
import json
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import Document, TextNode
from llama_index.core.embeddings import MockEmbedding
from llama_index.core.llms import MockLLM
from llama_index.core.schema import Document, NodeWithScore, TextNode
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
from metagpt.rag.retrievers import SimpleHybridRetriever
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
class TestSimpleEngine:
@pytest.fixture
def mock_llm(self):
    """Provide a `MockLLM` instance (imported from `llama_index.core.llms`)."""
    llm = MockLLM()
    return llm
@pytest.fixture
def mock_embedding(self):
    """Provide a `MockEmbedding` configured with a one-dimensional embedding."""
    embedding = MockEmbedding(embed_dim=1)
    return embedding
@pytest.fixture
def mock_simple_directory_reader(self, mocker):
    """Patch `SimpleDirectoryReader` as seen from the engine module and return the patch mock."""
    target = "metagpt.rag.engines.simple.SimpleDirectoryReader"
    return mocker.patch(target)
@ -54,7 +68,7 @@ class TestSimpleEngine:
retriever_configs = [mocker.MagicMock()]
ranker_configs = [mocker.MagicMock()]
# Execute
# Exec
engine = SimpleEngine.from_docs(
input_dir=input_dir,
input_files=input_files,
@ -65,7 +79,7 @@ class TestSimpleEngine:
ranker_configs=ranker_configs,
)
# Assertions
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_vector_store_index.assert_called_once()
mock_get_retriever.assert_called_once_with(
@ -75,6 +89,68 @@ class TestSimpleEngine:
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
assert isinstance(engine, SimpleEngine)
def test_from_docs_without_file(self):
    """Calling `SimpleEngine.from_docs()` with no arguments is expected to raise `ValueError`."""
    expect_value_error = pytest.raises(ValueError)
    with expect_value_error:
        SimpleEngine.from_docs()
def test_from_objs(self, mock_llm, mock_embedding):
    """Build an engine from one duck-typed object and check that it ends up with an index."""

    # Mock: a bare object exposing only `rag_key` and `model_dump_json`.
    class MockRAGObject:
        def rag_key(self):
            return "key"

        def model_dump_json(self):
            return "{}"

    # Exec: retriever and ranker configs are passed explicitly as empty lists.
    engine = SimpleEngine.from_objs(
        objs=[MockRAGObject()],
        llm=mock_llm,
        embed_model=mock_embedding,
        retriever_configs=[],
        ranker_configs=[],
    )

    # Assert
    assert isinstance(engine, SimpleEngine)
    assert engine.index is not None
def test_from_objs_with_bm25_config(self):
    """`from_objs` with an empty object list and a BM25 retriever config is expected to raise `ValueError`."""
    # Setup
    bm25_only = [BM25RetrieverConfig()]

    # Exec / Assert
    with pytest.raises(ValueError):
        SimpleEngine.from_objs(objs=[], llm=MockLLM(), retriever_configs=bm25_only, ranker_configs=[])
def test_from_index(self, mocker, mock_llm, mock_embedding):
    """`from_index` should expose the index handed back by the patched `get_index`."""
    # Mock: `get_index` is patched in the engine module to return a spec'd index mock.
    fake_index = mocker.MagicMock(spec=VectorStoreIndex)
    mocker.patch("metagpt.rag.engines.simple.get_index", return_value=fake_index)

    # Exec
    engine = SimpleEngine.from_index(index_config=fake_index, embed_model=mock_embedding, llm=mock_llm)

    # Assert
    assert isinstance(engine, SimpleEngine)
    assert engine.index is fake_index
@pytest.mark.asyncio
async def test_asearch(self, mocker):
# Mock
@ -86,10 +162,10 @@ class TestSimpleEngine:
engine = SimpleEngine(retriever=mocker.MagicMock())
engine.aquery = mock_aquery
# Execute
# Exec
result = await engine.asearch(test_query)
# Assertions
# Assert
mock_aquery.assert_called_once_with(test_query)
assert result == expected_result
@ -106,10 +182,10 @@ class TestSimpleEngine:
engine = SimpleEngine(retriever=mocker.MagicMock())
test_query = "test query"
# Execute
# Exec
result = await engine.aretrieve(test_query)
# Assertions
# Assert
mock_query_bundle.assert_called_once_with(test_query)
mock_super_aretrieve.assert_called_once_with("query_bundle")
assert result[0].text == "node_with_score"
@ -134,10 +210,10 @@ class TestSimpleEngine:
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
input_files = ["test_file1", "test_file2"]
# Execute
# Exec
engine.add_docs(input_files=input_files)
# Assertions
# Assert
mock_simple_directory_reader.assert_called_once_with(input_files=input_files)
mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"])
@ -156,11 +232,79 @@ class TestSimpleEngine:
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
# Execute
# Exec
engine.add_objs(objs=objs)
# Assertions
# Assert
assert mock_retriever.add_nodes.call_count == 1
for node in mock_retriever.add_nodes.call_args[0][0]:
assert isinstance(node, TextNode)
assert "is_obj" in node.metadata
def test_persist_successfully(self, mocker):
    """Persisting through an engine whose retriever is persistable should reach that retriever."""
    # Mock
    mock_retriever = mocker.MagicMock(spec=PersistableRAGRetriever)
    mock_retriever.persist.return_value = mocker.MagicMock()

    # Setup
    engine = SimpleEngine(retriever=mock_retriever)

    # Exec
    engine.persist(persist_dir="")

    # Assert: without this check the test would pass even if `persist` did nothing.
    # NOTE(review): assumes the engine delegates to `retriever.persist` exactly once — confirm.
    mock_retriever.persist.assert_called_once()
def test_ensure_retriever_of_type(self, mocker):
    """Check `_ensure_retriever_of_type` for a hybrid retriever and for a plain modifiable retriever."""

    # Mock: a sub-retriever that only offers `add_nodes`.
    class MyRetriever:
        def add_nodes(self):
            ...

    mock_retriever = mocker.MagicMock(spec=SimpleHybridRetriever)
    mock_retriever.retrievers = [MyRetriever()]

    # Setup: both engines are built up front so each `pytest.raises` block below wraps
    # only the call under test; a TypeError from the constructor must not satisfy it.
    engine = SimpleEngine(retriever=mock_retriever)
    other_engine = SimpleEngine(retriever=mocker.MagicMock(spec=ModifiableRAGRetriever))

    # Assert
    engine._ensure_retriever_of_type(ModifiableRAGRetriever)

    with pytest.raises(TypeError):
        engine._ensure_retriever_of_type(PersistableRAGRetriever)

    with pytest.raises(TypeError):
        other_engine._ensure_retriever_of_type(PersistableRAGRetriever)
def test_with_obj_metadata(self, mocker):
    """`_try_reconstruct_obj` should add an "obj" entry to the metadata of a node flagged as an object."""

    # Mock: stand-in class that the patched `import_class` returns.
    class ExampleObject:
        def __init__(self, key, value):
            self.key = key
            self.value = value

        def __eq__(self, other):
            return self.key == other.key and self.value == other.value

    mocker.patch("metagpt.rag.engines.simple.import_class", return_value=ExampleObject)

    # Setup
    obj_fields = {"key": "test_key", "value": "test_value"}
    metadata = {
        "is_obj": True,
        "obj_cls_name": "ExampleObject",
        "obj_mod_name": "__main__",
        "obj_json": json.dumps(obj_fields),
    }
    node = NodeWithScore(node=ObjectNode(text="example", metadata=metadata))
    expected_obj = ExampleObject(key="test_key", value="test_value")

    # Exec
    SimpleEngine._try_reconstruct_obj([node])

    # Assert
    assert "obj" in node.node.metadata
    assert node.node.metadata["obj"] == expected_obj

View file

@ -0,0 +1,43 @@
import pytest
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
class TestRAGEmbeddingFactory:
    """Checks which embedding class `RAGEmbeddingFactory.get_rag_embedding` instantiates per LLM type."""

    @pytest.fixture(autouse=True)
    def mock_embedding_factory(self):
        # Autouse: every test gets a fresh factory on `self.embedding_factory`.
        self.embedding_factory = RAGEmbeddingFactory()

    @pytest.fixture
    def mock_openai_embedding(self, mocker):
        target = "metagpt.rag.factories.embedding.OpenAIEmbedding"
        return mocker.patch(target)

    @pytest.fixture
    def mock_azure_embedding(self, mocker):
        target = "metagpt.rag.factories.embedding.AzureOpenAIEmbedding"
        return mocker.patch(target)

    def test_get_rag_embedding_openai(self, mock_openai_embedding):
        """Requesting the OPENAI type instantiates the patched `OpenAIEmbedding` once."""
        factory = self.embedding_factory

        # Exec
        factory.get_rag_embedding(LLMType.OPENAI)

        # Assert
        mock_openai_embedding.assert_called_once()

    def test_get_rag_embedding_azure(self, mock_azure_embedding):
        """Requesting the AZURE type instantiates the patched `AzureOpenAIEmbedding` once."""
        factory = self.embedding_factory

        # Exec
        factory.get_rag_embedding(LLMType.AZURE)

        # Assert
        mock_azure_embedding.assert_called_once()

    def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
        """With no explicit type, the patched module config (set to OPENAI here) decides the embedding."""
        # Mock
        patched_config = mocker.patch("metagpt.rag.factories.embedding.config")
        patched_config.llm.api_type = LLMType.OPENAI

        # Exec
        self.embedding_factory.get_rag_embedding()

        # Assert
        mock_openai_embedding.assert_called_once()

View file

@ -0,0 +1,89 @@
import pytest
from llama_index.core.embeddings import MockEmbedding
from metagpt.rag.factories.index import RAGIndexFactory
from metagpt.rag.schema import (
BM25IndexConfig,
ChromaIndexConfig,
ElasticsearchIndexConfig,
ElasticsearchStoreConfig,
FAISSIndexConfig,
)
class TestRAGIndexFactory:
    """Exercises `RAGIndexFactory.get_index` for each index config type with storage backends patched out."""

    @pytest.fixture(autouse=True)
    def setup(self):
        # Autouse: every test gets a fresh factory on `self.index_factory`.
        self.index_factory = RAGIndexFactory()

    @pytest.fixture
    def faiss_config(self):
        return FAISSIndexConfig(persist_path="")

    @pytest.fixture
    def chroma_config(self):
        return ChromaIndexConfig(persist_path="", collection_name="")

    @pytest.fixture
    def bm25_config(self):
        return BM25IndexConfig(persist_path="")

    @pytest.fixture
    def es_config(self):
        return ElasticsearchIndexConfig(store_config=ElasticsearchStoreConfig())

    @pytest.fixture
    def mock_storage_context(self, mocker):
        return mocker.patch("metagpt.rag.factories.index.StorageContext.from_defaults")

    @pytest.fixture
    def mock_load_index_from_storage(self, mocker):
        return mocker.patch("metagpt.rag.factories.index.load_index_from_storage")

    @pytest.fixture
    def mock_from_vector_store(self, mocker):
        return mocker.patch("metagpt.rag.factories.index.VectorStoreIndex.from_vector_store")

    @pytest.fixture
    def mock_embedding(self):
        return MockEmbedding(embed_dim=1)

    def test_create_faiss_index(
        self, mocker, faiss_config, mock_storage_context, mock_load_index_from_storage, mock_embedding
    ):
        """A FAISS config should load the vector store via `FaissVectorStore.from_persist_dir`."""
        # Mock
        mock_faiss_store = mocker.patch("metagpt.rag.factories.index.FaissVectorStore.from_persist_dir")

        # Exec
        self.index_factory.get_index(faiss_config, embed_model=mock_embedding)

        # Assert
        mock_faiss_store.assert_called_once()

    def test_create_bm25_index(self, bm25_config, mock_storage_context, mock_load_index_from_storage, mock_embedding):
        """Smoke test: building an index from a BM25 config must not raise.

        The storage fixtures are requested only for their patching side effects.
        """
        # Exec
        self.index_factory.get_index(bm25_config, embed_model=mock_embedding)

    def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding):
        """A Chroma config should construct a `ChromaVectorStore` exactly once."""
        # Mock
        mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient")
        # Configure the client *instance* (what calling the patched class returns), not the class mock.
        mock_chroma_db.return_value.get_or_create_collection.return_value = mocker.MagicMock()

        mock_chroma_store = mocker.patch("metagpt.rag.factories.index.ChromaVectorStore")

        # Exec
        self.index_factory.get_index(chroma_config, embed_model=mock_embedding)

        # Assert
        mock_chroma_store.assert_called_once()

    def test_create_es_index(self, mocker, es_config, mock_from_vector_store, mock_embedding):
        """An Elasticsearch config should construct an `ElasticsearchStore` exactly once."""
        # Mock
        mock_es_store = mocker.patch("metagpt.rag.factories.index.ElasticsearchStore")

        # Exec
        self.index_factory.get_index(es_config, embed_model=mock_embedding)

        # Assert
        mock_es_store.assert_called_once()

View file

@ -0,0 +1,71 @@
from typing import Optional, Union
import pytest
from llama_index.core.llms import LLMMetadata
from metagpt.configs.llm_config import LLMConfig
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.provider.base_llm import BaseLLM
from metagpt.rag.factories.llm import RAGLLM, get_rag_llm
class MockLLM(BaseLLM):
    """Stub `BaseLLM` whose completion-style methods all answer with the literal "ok"."""

    def __init__(self, config: LLMConfig):
        # Deliberately empty: the config is accepted but not used here.
        pass

    async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
        """Stub with no body; returns None."""

    async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
        """Async completion stub."""
        return "ok"

    def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
        """Sync completion stub."""
        return "ok"

    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
        """Stub with no body; returns None."""

    async def aask(
        self,
        msg: Union[str, list[dict[str, str]]],
        system_msgs: Optional[list[str]] = None,
        format_msgs: Optional[list[dict[str, str]]] = None,
        images: Optional[Union[str, list[str]]] = None,
        timeout=USE_CONFIG_TIMEOUT,
        stream=True,
    ) -> str:
        """Ask stub: ignores every argument."""
        return "ok"
class TestRAGLLM:
    """Tests for `RAGLLM` wrapped around the stub `MockLLM` defined in this module."""

    @pytest.fixture
    def mock_model_infer(self):
        default_config = LLMConfig()
        return MockLLM(config=default_config)

    @pytest.fixture
    def rag_llm(self, mock_model_infer):
        return RAGLLM(model_infer=mock_model_infer)

    def test_metadata(self, rag_llm):
        """`metadata` is an `LLMMetadata` whose fields mirror the wrapper's own attributes."""
        metadata = rag_llm.metadata

        assert isinstance(metadata, LLMMetadata)
        for field in ("context_window", "num_output", "model_name"):
            assert getattr(metadata, field) == getattr(rag_llm, field)

    @pytest.mark.asyncio
    async def test_acomplete(self, rag_llm, mock_model_infer):
        """The async completion's text is the stub's "ok"."""
        response = await rag_llm.acomplete("question")

        assert response.text == "ok"

    def test_complete(self, rag_llm, mock_model_infer):
        """The sync completion's text is the stub's "ok"."""
        response = rag_llm.complete("question")

        assert response.text == "ok"

    def test_stream_complete(self, rag_llm, mock_model_infer):
        """Smoke test: only checks that `stream_complete` can be called without raising."""
        rag_llm.stream_complete("question")
def test_get_rag_llm():
    """`get_rag_llm` wraps the given model in a `RAGLLM`."""
    stub_llm = MockLLM(config=LLMConfig())

    wrapped = get_rag_llm(stub_llm)

    assert isinstance(wrapped, RAGLLM)

View file

@ -1,41 +1,57 @@
import pytest
from llama_index.core.llms import LLM
from llama_index.core.llms import MockLLM
from llama_index.core.postprocessor import LLMRerank
from metagpt.rag.factories.ranker import RankerFactory
from metagpt.rag.schema import LLMRankerConfig
from metagpt.rag.schema import ColbertRerankConfig, LLMRankerConfig, ObjectRankerConfig
class TestRankerFactory:
@pytest.fixture
def ranker_factory(self) -> RankerFactory:
return RankerFactory()
@pytest.fixture(autouse=True)
def ranker_factory(self):
self.ranker_factory: RankerFactory = RankerFactory()
@pytest.fixture
def mock_llm(self, mocker):
return mocker.MagicMock(spec=LLM)
def mock_llm(self):
return MockLLM()
def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker):
mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm)
default_rankers = ranker_factory.get_rankers()
def test_get_rankers_with_no_configs(self, mock_llm, mocker):
mocker.patch.object(self.ranker_factory, "_extract_llm", return_value=mock_llm)
default_rankers = self.ranker_factory.get_rankers()
assert len(default_rankers) == 0
def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm):
def test_get_rankers_with_configs(self, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
rankers = ranker_factory.get_rankers(configs=[mock_config])
rankers = self.ranker_factory.get_rankers(configs=[mock_config])
assert len(rankers) == 1
assert isinstance(rankers[0], LLMRerank)
def test_create_llm_ranker_creates_correct_instance(self, ranker_factory: RankerFactory, mock_llm):
def test_extract_llm_from_config(self, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
ranker = ranker_factory._create_llm_ranker(mock_config)
extracted_llm = self.ranker_factory._extract_llm(config=mock_config)
assert extracted_llm == mock_llm
def test_extract_llm_from_kwargs(self, mock_llm):
extracted_llm = self.ranker_factory._extract_llm(llm=mock_llm)
assert extracted_llm == mock_llm
def test_create_llm_ranker(self, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
ranker = self.ranker_factory._create_llm_ranker(mock_config)
assert isinstance(ranker, LLMRerank)
def test_extract_llm_from_config(self, ranker_factory: RankerFactory, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
extracted_llm = ranker_factory._extract_llm(config=mock_config)
assert extracted_llm == mock_llm
def test_create_colbert_ranker(self, mocker, mock_llm):
mocker.patch("metagpt.rag.factories.ranker.ColbertRerank", return_value="colbert")
def test_extract_llm_from_kwargs(self, ranker_factory: RankerFactory, mock_llm):
extracted_llm = ranker_factory._extract_llm(llm=mock_llm)
assert extracted_llm == mock_llm
mock_config = ColbertRerankConfig(llm=mock_llm)
ranker = self.ranker_factory._create_colbert_ranker(mock_config)
assert ranker == "colbert"
def test_create_object_ranker(self, mocker, mock_llm):
    """`_create_object_ranker` returns whatever the patched `ObjectSortPostprocessor` produces."""
    sentinel = "object"
    mocker.patch("metagpt.rag.factories.ranker.ObjectSortPostprocessor", return_value=sentinel)

    config = ObjectRankerConfig(field_name="fake", llm=mock_llm)
    created = self.ranker_factory._create_object_ranker(config)

    assert created == sentinel

View file

@ -1,18 +1,28 @@
import faiss
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from metagpt.rag.factories.retriever import RetrieverFactory
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
from metagpt.rag.schema import (
BM25RetrieverConfig,
ChromaRetrieverConfig,
ElasticsearchRetrieverConfig,
ElasticsearchStoreConfig,
FAISSRetrieverConfig,
)
class TestRetrieverFactory:
@pytest.fixture
@pytest.fixture(autouse=True)
def retriever_factory(self):
return RetrieverFactory()
self.retriever_factory: RetrieverFactory = RetrieverFactory()
@pytest.fixture
def mock_faiss_index(self, mocker):
@ -25,55 +35,79 @@ class TestRetrieverFactory:
mock.docstore.docs.values.return_value = []
return mock
def test_get_retriever_with_faiss_config(
self, retriever_factory: RetrieverFactory, mock_faiss_index, mocker, mock_vector_store_index
):
@pytest.fixture
def mock_chroma_vector_store(self, mocker):
    """Provide a `MagicMock` constrained to the `ChromaVectorStore` interface."""
    store = mocker.MagicMock(spec=ChromaVectorStore)
    return store
@pytest.fixture
def mock_es_vector_store(self, mocker):
    """Provide a `MagicMock` constrained to the `ElasticsearchStore` interface."""
    store = mocker.MagicMock(spec=ElasticsearchStore)
    return store
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
mock_config = FAISSRetrieverConfig(dimensions=128)
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, FAISSRetriever)
def test_get_retriever_with_bm25_config(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
mock_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, DynamicBM25Retriever)
def test_get_retriever_with_multiple_configs_returns_hybrid(
self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index
):
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
mock_bm25_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
assert isinstance(retriever, SimpleHybridRetriever)
def test_create_default_retriever(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
    """A single `ChromaRetrieverConfig` should yield a `ChromaRetriever`, with chromadb patched out."""
    mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")

    mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
    # Configure the client *instance* (what calling the patched class returns), not the class mock.
    mock_chromadb.return_value.get_or_create_collection.return_value = mocker.MagicMock()

    mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
    mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)

    retriever = self.retriever_factory.get_retriever(configs=[mock_config])

    assert isinstance(retriever, ChromaRetriever)
def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
    """A single `ElasticsearchRetrieverConfig` should yield an `ElasticsearchRetriever`."""
    factory = self.retriever_factory
    es_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())

    # Both the store class and the factory's index lookup are replaced with mocks.
    mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
    mocker.patch.object(factory, "_extract_index", return_value=mock_vector_store_index)

    retriever = factory.get_retriever(configs=[es_config])

    assert isinstance(retriever, ElasticsearchRetriever)
def test_create_default_retriever(self, mocker, mock_vector_store_index):
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mock_vector_store_index.as_retriever = mocker.MagicMock()
retriever = retriever_factory.get_retriever()
retriever = self.retriever_factory.get_retriever()
mock_vector_store_index.as_retriever.assert_called_once()
assert retriever is mock_vector_store_index.as_retriever.return_value
def test_extract_index_from_config(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
def test_extract_index_from_config(self, mock_vector_store_index):
mock_config = FAISSRetrieverConfig(index=mock_vector_store_index)
extracted_index = retriever_factory._extract_index(config=mock_config)
extracted_index = self.retriever_factory._extract_index(config=mock_config)
assert extracted_index == mock_vector_store_index
def test_extract_index_from_kwargs(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
extracted_index = retriever_factory._extract_index(index=mock_vector_store_index)
def test_extract_index_from_kwargs(self, mock_vector_store_index):
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
assert extracted_index == mock_vector_store_index

View file

@ -0,0 +1,23 @@
import pytest
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
from metagpt.rag.rankers.base import RAGRanker
class SimpleRAGRanker(RAGRanker):
    """Minimal concrete ranker for the tests: returns new nodes with every score raised by one."""

    def _postprocess_nodes(self, nodes, query_bundle=None):
        bumped = []
        for scored in nodes:
            bumped.append(NodeWithScore(node=scored.node, score=scored.score + 1))
        return bumped
class TestSimpleRAGRanker:
    """Tests for the `SimpleRAGRanker` helper subclass of `RAGRanker`."""

    @pytest.fixture
    def ranker(self):
        return SimpleRAGRanker()

    def test_postprocess_nodes_increases_scores(self, ranker):
        """Every input node comes back with its score increased by exactly one."""
        nodes = [NodeWithScore(node=TextNode(text="a"), score=10), NodeWithScore(node=TextNode(text="b"), score=20)]
        query_bundle = QueryBundle(query_str="test query")

        processed_nodes = ranker._postprocess_nodes(nodes, query_bundle)

        # `zip` stops at the shorter input, so an empty or truncated result would make
        # the `all(...)` below pass vacuously; pin the length first.
        assert len(processed_nodes) == len(nodes)
        assert all(node.score == original_node.score + 1 for node, original_node in zip(processed_nodes, nodes))

View file

@ -14,7 +14,7 @@ class Record(BaseModel):
class TestObjectSortPostprocessor:
@pytest.fixture
def nodes_with_scores(self):
def mock_nodes_with_scores(self):
nodes = [
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10),
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20),
@ -23,38 +23,47 @@ class TestObjectSortPostprocessor:
return nodes
@pytest.fixture
def query_bundle(self, mocker):
def mock_query_bundle(self, mocker):
return mocker.MagicMock(spec=QueryBundle)
def test_sort_descending(self, nodes_with_scores, query_bundle):
def test_sort_descending(self, mock_nodes_with_scores, mock_query_bundle):
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
assert [node.score for node in sorted_nodes] == [20, 10, 5]
def test_sort_ascending(self, nodes_with_scores, query_bundle):
def test_sort_ascending(self, mock_nodes_with_scores, mock_query_bundle):
postprocessor = ObjectSortPostprocessor(field_name="score", order="asc")
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
assert [node.score for node in sorted_nodes] == [5, 10, 20]
def test_top_n_limit(self, nodes_with_scores, query_bundle):
def test_top_n_limit(self, mock_nodes_with_scores, mock_query_bundle):
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2)
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
assert len(sorted_nodes) == 2
assert [node.score for node in sorted_nodes] == [20, 10]
def test_invalid_json_metadata(self, query_bundle):
def test_invalid_json_metadata(self, mock_query_bundle):
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)]
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
with pytest.raises(ValueError):
postprocessor._postprocess_nodes(nodes, query_bundle)
postprocessor._postprocess_nodes(nodes, mock_query_bundle)
def test_missing_query_bundle(self, nodes_with_scores):
def test_missing_query_bundle(self, mock_nodes_with_scores):
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
with pytest.raises(ValueError):
postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None)
postprocessor._postprocess_nodes(mock_nodes_with_scores, query_bundle=None)
def test_field_not_found_in_object(self):
def test_field_not_found_in_object(self, mock_query_bundle):
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)]
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
with pytest.raises(ValueError):
postprocessor._postprocess_nodes(nodes)
postprocessor._postprocess_nodes(nodes, query_bundle=mock_query_bundle)
def test_not_nodes(self, mock_query_bundle):
    """An empty node list is returned as an empty list."""
    sorter = ObjectSortPostprocessor(field_name="score", order="desc")

    outcome = sorter._postprocess_nodes([], mock_query_bundle)

    assert outcome == []
def test_class_name(self):
    """`class_name()` reports the postprocessor's own class name."""
    reported = ObjectSortPostprocessor.class_name()
    assert reported == "ObjectSortPostprocessor"

View file

@ -0,0 +1,21 @@
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
class SubModifiableRAGRetriever(ModifiableRAGRetriever):
    """Empty subclass used to call `__subclasshook__` from a class other than the base itself."""
class SubPersistableRAGRetriever(PersistableRAGRetriever):
    """Empty subclass used to call `__subclasshook__` from a class other than the base itself."""
class TestModifiableRAGRetriever:
    """Covers `__subclasshook__` when it is invoked through a subclass."""

    def test_subclasshook(self):
        subclass = SubModifiableRAGRetriever
        outcome = subclass.__subclasshook__(subclass)
        assert outcome is NotImplemented
class TestPersistableRAGRetriever:
    """Covers `__subclasshook__` when it is invoked through a subclass."""

    def test_subclasshook(self):
        subclass = SubPersistableRAGRetriever
        outcome = subclass.__subclasshook__(subclass)
        assert outcome is NotImplemented

View file

@ -8,30 +8,30 @@ from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
class TestDynamicBM25Retriever:
    """Tests for `DynamicBM25Retriever` with `BM25Okapi.__init__` and the index mocked out."""

    @pytest.fixture(autouse=True)
    def setup(self, mocker):
        # Mocked document nodes that report fixed content.
        self.doc1 = mocker.MagicMock(spec=Node)
        self.doc1.get_content.return_value = "Document content 1"
        self.doc2 = mocker.MagicMock(spec=Node)
        self.doc2.get_content.return_value = "Document content 2"
        self.mock_nodes = [self.doc1, self.doc2]

        # Mocked index whose storage context reports a successful persist.
        index = mocker.MagicMock(spec=VectorStoreIndex)
        index.storage_context.persist.return_value = "ok"

        # The retriever starts with no nodes and a mocked tokenizer.
        initial_nodes = []
        tokenizer = mocker.MagicMock()
        self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)

        # Build the retriever under test with the required arguments.
        self.retriever = DynamicBM25Retriever(nodes=initial_nodes, tokenizer=tokenizer, index=index)

    def test_add_docs_updates_nodes_and_corpus(self):
        """Adding nodes grows `_nodes` and `_corpus` and touches the tokenizer and BM25 constructor."""
        expected_count = len(self.mock_nodes)

        # Exec
        self.retriever.add_nodes(self.mock_nodes)

        # Assert
        assert len(self.retriever._nodes) == expected_count
        assert len(self.retriever._corpus) == expected_count
        self.retriever._tokenizer.assert_called()
        self.mock_bm25okapi.assert_called()

    def test_persist(self):
        """Smoke test: only checks that `persist` can be called without raising."""
        self.retriever.persist("")

View file

@ -0,0 +1,20 @@
import pytest
from llama_index.core.schema import Node
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
class TestChromaRetriever:
    """Tests for `ChromaRetriever` built on a mocked index."""

    @pytest.fixture(autouse=True)
    def setup(self, mocker):
        # Two mocked nodes and a free-form mocked index shared by the tests.
        self.doc1, self.doc2 = mocker.MagicMock(spec=Node), mocker.MagicMock(spec=Node)
        self.mock_nodes = [self.doc1, self.doc2]

        self.mock_index = mocker.MagicMock()
        self.retriever = ChromaRetriever(self.mock_index)

    def test_add_nodes(self):
        """`add_nodes` should reach `insert_nodes` on the underlying index."""
        self.retriever.add_nodes(self.mock_nodes)

        insert_nodes = self.mock_index.insert_nodes
        insert_nodes.assert_called()

View file

@ -0,0 +1,20 @@
import pytest
from llama_index.core.schema import Node
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
class TestElasticsearchRetriever:
    """Tests for `ElasticsearchRetriever` built on a mocked index."""

    @pytest.fixture(autouse=True)
    def setup(self, mocker):
        # Two mocked nodes and a free-form mocked index shared by the tests.
        self.doc1, self.doc2 = mocker.MagicMock(spec=Node), mocker.MagicMock(spec=Node)
        self.mock_nodes = [self.doc1, self.doc2]

        self.mock_index = mocker.MagicMock()
        self.retriever = ElasticsearchRetriever(self.mock_index)

    def test_add_nodes(self):
        """`add_nodes` should reach `insert_nodes` on the underlying index."""
        self.retriever.add_nodes(self.mock_nodes)

        insert_nodes = self.mock_index.insert_nodes
        insert_nodes.assert_called()

View file

@ -7,16 +7,19 @@ from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
class TestFAISSRetriever:
    """Tests for `FAISSRetriever` built on a mocked index."""

    @pytest.fixture(autouse=True)
    def setup(self, mocker):
        # Mocked document nodes.
        self.doc1 = mocker.MagicMock(spec=Node)
        self.doc2 = mocker.MagicMock(spec=Node)
        self.mock_nodes = [self.doc1, self.doc2]

        # Mocked index handed to the retriever under test.
        self.mock_index = mocker.MagicMock()
        self.retriever = FAISSRetriever(self.mock_index)

    def test_add_docs_calls_insert_for_each_document(self):
        """`add_nodes` should reach `insert_nodes` on the underlying index."""
        self.retriever.add_nodes(self.mock_nodes)

        # Call the mock's assertion helper. Writing `assert mock.assert_called` only
        # tests the truthiness of a bound method and can never fail.
        self.mock_index.insert_nodes.assert_called()

    def test_persist(self):
        """`persist` should reach the index's storage context."""
        self.retriever.persist("")

        self.mock_index.storage_context.persist.assert_called()

View file

@ -1,5 +1,3 @@
from unittest.mock import AsyncMock
import pytest
from llama_index.core.schema import NodeWithScore, TextNode
@ -7,18 +5,30 @@ from metagpt.rag.retrievers import SimpleHybridRetriever
class TestSimpleHybridRetriever:
@pytest.fixture
def mock_retriever(self, mocker):
    """Provide an unconstrained `MagicMock` standing in for a sub-retriever."""
    retriever = mocker.MagicMock()
    return retriever
@pytest.fixture
def mock_hybrid_retriever(self, mock_retriever) -> SimpleHybridRetriever:
    """Provide a `SimpleHybridRetriever` wrapping the single mocked sub-retriever."""
    hybrid = SimpleHybridRetriever(mock_retriever)
    return hybrid
@pytest.fixture
def mock_node(self):
    """Provide one scored text node (id "2", score 0.95)."""
    text_node = TextNode(id_="2")
    return NodeWithScore(node=text_node, score=0.95)
@pytest.mark.asyncio
async def test_aretrieve(self):
async def test_aretrieve(self, mocker):
question = "test query"
# Create mock retrievers
mock_retriever1 = AsyncMock()
mock_retriever1 = mocker.AsyncMock()
mock_retriever1.aretrieve.return_value = [
NodeWithScore(node=TextNode(id_="1"), score=1.0),
NodeWithScore(node=TextNode(id_="2"), score=0.95),
]
mock_retriever2 = AsyncMock()
mock_retriever2 = mocker.AsyncMock()
mock_retriever2.aretrieve.return_value = [
NodeWithScore(node=TextNode(id_="2"), score=0.95),
NodeWithScore(node=TextNode(id_="3"), score=0.8),
@ -37,3 +47,11 @@ class TestSimpleHybridRetriever:
# Check if the scores are correct (assuming you want the highest score)
node_scores = {node.node.node_id: node.score for node in results}
assert node_scores["2"] == 0.95
def test_add_nodes(self, mock_hybrid_retriever: SimpleHybridRetriever, mock_node):
    """`add_nodes` on the hybrid retriever is forwarded once to its first sub-retriever."""
    mock_hybrid_retriever.add_nodes([mock_node])

    first_sub_retriever = mock_hybrid_retriever.retrievers[0]
    first_sub_retriever.add_nodes.assert_called_once()
def test_persist(self, mock_hybrid_retriever: SimpleHybridRetriever):
    """`persist` on the hybrid retriever is forwarded once to its first sub-retriever."""
    mock_hybrid_retriever.persist("")

    first_sub_retriever = mock_hybrid_retriever.retrievers[0]
    first_sub_retriever.persist.assert_called_once()