mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-27 01:36:29 +02:00
upgrade llama-index-vector-stores-chroma and rag test coverage 100%
This commit is contained in:
parent
6450a09d3b
commit
6cb5492f02
21 changed files with 600 additions and 382 deletions
|
|
@ -1,12 +1,26 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import Document, TextNode
|
||||
from llama_index.core.embeddings import MockEmbedding
|
||||
from llama_index.core.llms import MockLLM
|
||||
from llama_index.core.schema import Document, NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
|
||||
|
||||
|
||||
class TestSimpleEngine:
|
||||
@pytest.fixture
|
||||
def mock_llm(self):
|
||||
return MockLLM()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding(self):
|
||||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_simple_directory_reader(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
|
|
@ -54,7 +68,7 @@ class TestSimpleEngine:
|
|||
retriever_configs = [mocker.MagicMock()]
|
||||
ranker_configs = [mocker.MagicMock()]
|
||||
|
||||
# Execute
|
||||
# Exec
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_dir=input_dir,
|
||||
input_files=input_files,
|
||||
|
|
@ -65,7 +79,7 @@ class TestSimpleEngine:
|
|||
ranker_configs=ranker_configs,
|
||||
)
|
||||
|
||||
# Assertions
|
||||
# Assert
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_vector_store_index.assert_called_once()
|
||||
mock_get_retriever.assert_called_once_with(
|
||||
|
|
@ -75,6 +89,68 @@ class TestSimpleEngine:
|
|||
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
|
||||
def test_from_docs_without_file(self):
|
||||
with pytest.raises(ValueError):
|
||||
SimpleEngine.from_docs()
|
||||
|
||||
def test_from_objs(self, mock_llm, mock_embedding):
|
||||
# Mock
|
||||
class MockRAGObject:
|
||||
def rag_key(self):
|
||||
return "key"
|
||||
|
||||
def model_dump_json(self):
|
||||
return "{}"
|
||||
|
||||
objs = [MockRAGObject()]
|
||||
|
||||
# Setup
|
||||
retriever_configs = []
|
||||
ranker_configs = []
|
||||
|
||||
# Exec
|
||||
engine = SimpleEngine.from_objs(
|
||||
objs=objs,
|
||||
llm=mock_llm,
|
||||
embed_model=mock_embedding,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is not None
|
||||
|
||||
def test_from_objs_with_bm25_config(self):
|
||||
# Setup
|
||||
retriever_configs = [BM25RetrieverConfig()]
|
||||
|
||||
# Exec
|
||||
with pytest.raises(ValueError):
|
||||
SimpleEngine.from_objs(
|
||||
objs=[],
|
||||
llm=MockLLM(),
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=[],
|
||||
)
|
||||
|
||||
def test_from_index(self, mocker, mock_llm, mock_embedding):
|
||||
# Mock
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
|
||||
mock_get_index.return_value = mock_index
|
||||
|
||||
# Exec
|
||||
engine = SimpleEngine.from_index(
|
||||
index_config=mock_index,
|
||||
embed_model=mock_embedding,
|
||||
llm=mock_llm,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is mock_index
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch(self, mocker):
|
||||
# Mock
|
||||
|
|
@ -86,10 +162,10 @@ class TestSimpleEngine:
|
|||
engine = SimpleEngine(retriever=mocker.MagicMock())
|
||||
engine.aquery = mock_aquery
|
||||
|
||||
# Execute
|
||||
# Exec
|
||||
result = await engine.asearch(test_query)
|
||||
|
||||
# Assertions
|
||||
# Assert
|
||||
mock_aquery.assert_called_once_with(test_query)
|
||||
assert result == expected_result
|
||||
|
||||
|
|
@ -106,10 +182,10 @@ class TestSimpleEngine:
|
|||
engine = SimpleEngine(retriever=mocker.MagicMock())
|
||||
test_query = "test query"
|
||||
|
||||
# Execute
|
||||
# Exec
|
||||
result = await engine.aretrieve(test_query)
|
||||
|
||||
# Assertions
|
||||
# Assert
|
||||
mock_query_bundle.assert_called_once_with(test_query)
|
||||
mock_super_aretrieve.assert_called_once_with("query_bundle")
|
||||
assert result[0].text == "node_with_score"
|
||||
|
|
@ -134,10 +210,10 @@ class TestSimpleEngine:
|
|||
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
|
||||
# Execute
|
||||
# Exec
|
||||
engine.add_docs(input_files=input_files)
|
||||
|
||||
# Assertions
|
||||
# Assert
|
||||
mock_simple_directory_reader.assert_called_once_with(input_files=input_files)
|
||||
mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"])
|
||||
|
||||
|
|
@ -156,11 +232,79 @@ class TestSimpleEngine:
|
|||
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
|
||||
|
||||
# Execute
|
||||
# Exec
|
||||
engine.add_objs(objs=objs)
|
||||
|
||||
# Assertions
|
||||
# Assert
|
||||
assert mock_retriever.add_nodes.call_count == 1
|
||||
for node in mock_retriever.add_nodes.call_args[0][0]:
|
||||
assert isinstance(node, TextNode)
|
||||
assert "is_obj" in node.metadata
|
||||
|
||||
def test_persist_successfully(self, mocker):
|
||||
# Mock
|
||||
mock_retriever = mocker.MagicMock(spec=PersistableRAGRetriever)
|
||||
mock_retriever.persist.return_value = mocker.MagicMock()
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
|
||||
# Exec
|
||||
engine.persist(persist_dir="")
|
||||
|
||||
def test_ensure_retriever_of_type(self, mocker):
|
||||
# Mock
|
||||
class MyRetriever:
|
||||
def add_nodes(self):
|
||||
...
|
||||
|
||||
mock_retriever = mocker.MagicMock(spec=SimpleHybridRetriever)
|
||||
mock_retriever.retrievers = [MyRetriever()]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
|
||||
# Assert
|
||||
engine._ensure_retriever_of_type(ModifiableRAGRetriever)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
engine._ensure_retriever_of_type(PersistableRAGRetriever)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
other_engine = SimpleEngine(retriever=mocker.MagicMock(spec=ModifiableRAGRetriever))
|
||||
other_engine._ensure_retriever_of_type(PersistableRAGRetriever)
|
||||
|
||||
def test_with_obj_metadata(self, mocker):
|
||||
# Mock
|
||||
node = NodeWithScore(
|
||||
node=ObjectNode(
|
||||
text="example",
|
||||
metadata={
|
||||
"is_obj": True,
|
||||
"obj_cls_name": "ExampleObject",
|
||||
"obj_mod_name": "__main__",
|
||||
"obj_json": json.dumps({"key": "test_key", "value": "test_value"}),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
class ExampleObject:
|
||||
def __init__(self, key, value):
|
||||
self.key = key
|
||||
self.value = value
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.key == other.key and self.value == other.value
|
||||
|
||||
mock_import_class = mocker.patch("metagpt.rag.engines.simple.import_class")
|
||||
mock_import_class.return_value = ExampleObject
|
||||
|
||||
# Setup
|
||||
SimpleEngine._try_reconstruct_obj([node])
|
||||
|
||||
# Exec
|
||||
expected_obj = ExampleObject(key="test_key", value="test_value")
|
||||
|
||||
# Assert
|
||||
assert "obj" in node.node.metadata
|
||||
assert node.node.metadata["obj"] == expected_obj
|
||||
|
|
|
|||
43
tests/metagpt/rag/factories/test_embedding.py
Normal file
43
tests/metagpt/rag/factories/test_embedding.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
|
||||
|
||||
|
||||
class TestRAGEmbeddingFactory:
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_embedding_factory(self):
|
||||
self.embedding_factory = RAGEmbeddingFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_embedding(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_embedding(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding")
|
||||
|
||||
def test_get_rag_embedding_openai(self, mock_openai_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.OPENAI)
|
||||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_azure(self, mock_azure_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.AZURE)
|
||||
|
||||
# Assert
|
||||
mock_azure_embedding.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
|
||||
# Mock
|
||||
mock_config = mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
mock_config.llm.api_type = LLMType.OPENAI
|
||||
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding()
|
||||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
89
tests/metagpt/rag/factories/test_index.py
Normal file
89
tests/metagpt/rag/factories/test_index.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
import pytest
|
||||
from llama_index.core.embeddings import MockEmbedding
|
||||
|
||||
from metagpt.rag.factories.index import RAGIndexFactory
|
||||
from metagpt.rag.schema import (
|
||||
BM25IndexConfig,
|
||||
ChromaIndexConfig,
|
||||
ElasticsearchIndexConfig,
|
||||
ElasticsearchStoreConfig,
|
||||
FAISSIndexConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestRAGIndexFactory:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self):
|
||||
self.index_factory = RAGIndexFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def faiss_config(self):
|
||||
return FAISSIndexConfig(persist_path="")
|
||||
|
||||
@pytest.fixture
|
||||
def chroma_config(self):
|
||||
return ChromaIndexConfig(persist_path="", collection_name="")
|
||||
|
||||
@pytest.fixture
|
||||
def bm25_config(self):
|
||||
return BM25IndexConfig(persist_path="")
|
||||
|
||||
@pytest.fixture
|
||||
def es_config(self, mocker):
|
||||
return ElasticsearchIndexConfig(store_config=ElasticsearchStoreConfig())
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage_context(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.index.StorageContext.from_defaults")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_load_index_from_storage(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.index.load_index_from_storage")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_from_vector_store(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.index.VectorStoreIndex.from_vector_store")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding(self):
|
||||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
def test_create_faiss_index(
|
||||
self, mocker, faiss_config, mock_storage_context, mock_load_index_from_storage, mock_embedding
|
||||
):
|
||||
# Mock
|
||||
mock_faiss_store = mocker.patch("metagpt.rag.factories.index.FaissVectorStore.from_persist_dir")
|
||||
|
||||
# Exec
|
||||
self.index_factory.get_index(faiss_config, embed_model=mock_embedding)
|
||||
|
||||
# Assert
|
||||
mock_faiss_store.assert_called_once()
|
||||
|
||||
def test_create_bm25_index(
|
||||
self, mocker, bm25_config, mock_storage_context, mock_load_index_from_storage, mock_embedding
|
||||
):
|
||||
self.index_factory.get_index(bm25_config, embed_model=mock_embedding)
|
||||
|
||||
def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding):
|
||||
# Mock
|
||||
mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient")
|
||||
mock_chroma_db.get_or_create_collection.return_value = mocker.MagicMock()
|
||||
|
||||
mock_chroma_store = mocker.patch("metagpt.rag.factories.index.ChromaVectorStore")
|
||||
|
||||
# Exec
|
||||
self.index_factory.get_index(chroma_config, embed_model=mock_embedding)
|
||||
|
||||
# Assert
|
||||
mock_chroma_store.assert_called_once()
|
||||
|
||||
def test_create_es_index(self, mocker, es_config, mock_from_vector_store, mock_embedding):
|
||||
# Mock
|
||||
mock_es_store = mocker.patch("metagpt.rag.factories.index.ElasticsearchStore")
|
||||
|
||||
# Exec
|
||||
self.index_factory.get_index(es_config, embed_model=mock_embedding)
|
||||
|
||||
# Assert
|
||||
mock_es_store.assert_called_once()
|
||||
71
tests/metagpt/rag/factories/test_llm.py
Normal file
71
tests/metagpt/rag/factories/test_llm.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
from llama_index.core.llms import LLMMetadata
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.rag.factories.llm import RAGLLM, get_rag_llm
|
||||
|
||||
|
||||
class MockLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig):
|
||||
...
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
"""_achat_completion implemented by inherited class"""
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
return "ok"
|
||||
|
||||
def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
return "ok"
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
"""_achat_completion_stream implemented by inherited class"""
|
||||
|
||||
async def aask(
|
||||
self,
|
||||
msg: Union[str, list[dict[str, str]]],
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=USE_CONFIG_TIMEOUT,
|
||||
stream=True,
|
||||
) -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
class TestRAGLLM:
|
||||
@pytest.fixture
|
||||
def mock_model_infer(self):
|
||||
return MockLLM(config=LLMConfig())
|
||||
|
||||
@pytest.fixture
|
||||
def rag_llm(self, mock_model_infer):
|
||||
return RAGLLM(model_infer=mock_model_infer)
|
||||
|
||||
def test_metadata(self, rag_llm):
|
||||
metadata = rag_llm.metadata
|
||||
assert isinstance(metadata, LLMMetadata)
|
||||
assert metadata.context_window == rag_llm.context_window
|
||||
assert metadata.num_output == rag_llm.num_output
|
||||
assert metadata.model_name == rag_llm.model_name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acomplete(self, rag_llm, mock_model_infer):
|
||||
response = await rag_llm.acomplete("question")
|
||||
assert response.text == "ok"
|
||||
|
||||
def test_complete(self, rag_llm, mock_model_infer):
|
||||
response = rag_llm.complete("question")
|
||||
assert response.text == "ok"
|
||||
|
||||
def test_stream_complete(self, rag_llm, mock_model_infer):
|
||||
rag_llm.stream_complete("question")
|
||||
|
||||
|
||||
def test_get_rag_llm():
|
||||
result = get_rag_llm(MockLLM(config=LLMConfig()))
|
||||
assert isinstance(result, RAGLLM)
|
||||
|
|
@ -1,41 +1,57 @@
|
|||
import pytest
|
||||
from llama_index.core.llms import LLM
|
||||
from llama_index.core.llms import MockLLM
|
||||
from llama_index.core.postprocessor import LLMRerank
|
||||
|
||||
from metagpt.rag.factories.ranker import RankerFactory
|
||||
from metagpt.rag.schema import LLMRankerConfig
|
||||
from metagpt.rag.schema import ColbertRerankConfig, LLMRankerConfig, ObjectRankerConfig
|
||||
|
||||
|
||||
class TestRankerFactory:
|
||||
@pytest.fixture
|
||||
def ranker_factory(self) -> RankerFactory:
|
||||
return RankerFactory()
|
||||
@pytest.fixture(autouse=True)
|
||||
def ranker_factory(self):
|
||||
self.ranker_factory: RankerFactory = RankerFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm(self, mocker):
|
||||
return mocker.MagicMock(spec=LLM)
|
||||
def mock_llm(self):
|
||||
return MockLLM()
|
||||
|
||||
def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker):
|
||||
mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm)
|
||||
default_rankers = ranker_factory.get_rankers()
|
||||
def test_get_rankers_with_no_configs(self, mock_llm, mocker):
|
||||
mocker.patch.object(self.ranker_factory, "_extract_llm", return_value=mock_llm)
|
||||
default_rankers = self.ranker_factory.get_rankers()
|
||||
assert len(default_rankers) == 0
|
||||
|
||||
def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm):
|
||||
def test_get_rankers_with_configs(self, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
rankers = ranker_factory.get_rankers(configs=[mock_config])
|
||||
rankers = self.ranker_factory.get_rankers(configs=[mock_config])
|
||||
assert len(rankers) == 1
|
||||
assert isinstance(rankers[0], LLMRerank)
|
||||
|
||||
def test_create_llm_ranker_creates_correct_instance(self, ranker_factory: RankerFactory, mock_llm):
|
||||
def test_extract_llm_from_config(self, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
ranker = ranker_factory._create_llm_ranker(mock_config)
|
||||
extracted_llm = self.ranker_factory._extract_llm(config=mock_config)
|
||||
assert extracted_llm == mock_llm
|
||||
|
||||
def test_extract_llm_from_kwargs(self, mock_llm):
|
||||
extracted_llm = self.ranker_factory._extract_llm(llm=mock_llm)
|
||||
assert extracted_llm == mock_llm
|
||||
|
||||
def test_create_llm_ranker(self, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
ranker = self.ranker_factory._create_llm_ranker(mock_config)
|
||||
assert isinstance(ranker, LLMRerank)
|
||||
|
||||
def test_extract_llm_from_config(self, ranker_factory: RankerFactory, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
extracted_llm = ranker_factory._extract_llm(config=mock_config)
|
||||
assert extracted_llm == mock_llm
|
||||
def test_create_colbert_ranker(self, mocker, mock_llm):
|
||||
mocker.patch("metagpt.rag.factories.ranker.ColbertRerank", return_value="colbert")
|
||||
|
||||
def test_extract_llm_from_kwargs(self, ranker_factory: RankerFactory, mock_llm):
|
||||
extracted_llm = ranker_factory._extract_llm(llm=mock_llm)
|
||||
assert extracted_llm == mock_llm
|
||||
mock_config = ColbertRerankConfig(llm=mock_llm)
|
||||
ranker = self.ranker_factory._create_colbert_ranker(mock_config)
|
||||
|
||||
assert ranker == "colbert"
|
||||
|
||||
def test_create_object_ranker(self, mocker, mock_llm):
|
||||
mocker.patch("metagpt.rag.factories.ranker.ObjectSortPostprocessor", return_value="object")
|
||||
|
||||
mock_config = ObjectRankerConfig(field_name="fake", llm=mock_llm)
|
||||
ranker = self.ranker_factory._create_object_ranker(mock_config)
|
||||
|
||||
assert ranker == "object"
|
||||
|
|
|
|||
|
|
@ -1,18 +1,28 @@
|
|||
import faiss
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
|
||||
from metagpt.rag.factories.retriever import RetrieverFactory
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
|
||||
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
ChromaRetrieverConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
ElasticsearchStoreConfig,
|
||||
FAISSRetrieverConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestRetrieverFactory:
|
||||
@pytest.fixture
|
||||
@pytest.fixture(autouse=True)
|
||||
def retriever_factory(self):
|
||||
return RetrieverFactory()
|
||||
self.retriever_factory: RetrieverFactory = RetrieverFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_faiss_index(self, mocker):
|
||||
|
|
@ -25,55 +35,79 @@ class TestRetrieverFactory:
|
|||
mock.docstore.docs.values.return_value = []
|
||||
return mock
|
||||
|
||||
def test_get_retriever_with_faiss_config(
|
||||
self, retriever_factory: RetrieverFactory, mock_faiss_index, mocker, mock_vector_store_index
|
||||
):
|
||||
@pytest.fixture
|
||||
def mock_chroma_vector_store(self, mocker):
|
||||
return mocker.MagicMock(spec=ChromaVectorStore)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_es_vector_store(self, mocker):
|
||||
return mocker.MagicMock(spec=ElasticsearchStore)
|
||||
|
||||
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
|
||||
mock_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, FAISSRetriever)
|
||||
|
||||
def test_get_retriever_with_bm25_config(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
|
||||
def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
|
||||
mock_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, DynamicBM25Retriever)
|
||||
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(
|
||||
self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index
|
||||
):
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mock_bm25_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
|
||||
|
||||
assert isinstance(retriever, SimpleHybridRetriever)
|
||||
|
||||
def test_create_default_retriever(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
|
||||
mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
|
||||
mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
|
||||
mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
|
||||
mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, ChromaRetriever)
|
||||
|
||||
def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
|
||||
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
|
||||
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, ElasticsearchRetriever)
|
||||
|
||||
def test_create_default_retriever(self, mocker, mock_vector_store_index):
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
mock_vector_store_index.as_retriever = mocker.MagicMock()
|
||||
|
||||
retriever = retriever_factory.get_retriever()
|
||||
retriever = self.retriever_factory.get_retriever()
|
||||
|
||||
mock_vector_store_index.as_retriever.assert_called_once()
|
||||
assert retriever is mock_vector_store_index.as_retriever.return_value
|
||||
|
||||
def test_extract_index_from_config(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
|
||||
def test_extract_index_from_config(self, mock_vector_store_index):
|
||||
mock_config = FAISSRetrieverConfig(index=mock_vector_store_index)
|
||||
|
||||
extracted_index = retriever_factory._extract_index(config=mock_config)
|
||||
extracted_index = self.retriever_factory._extract_index(config=mock_config)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
||||
def test_extract_index_from_kwargs(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
|
||||
extracted_index = retriever_factory._extract_index(index=mock_vector_store_index)
|
||||
def test_extract_index_from_kwargs(self, mock_vector_store_index):
|
||||
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
|
|
|||
23
tests/metagpt/rag/rankers/test_base_ranker.py
Normal file
23
tests/metagpt/rag/rankers/test_base_ranker.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import pytest
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
|
||||
|
||||
from metagpt.rag.rankers.base import RAGRanker
|
||||
|
||||
|
||||
class SimpleRAGRanker(RAGRanker):
|
||||
def _postprocess_nodes(self, nodes, query_bundle=None):
|
||||
return [NodeWithScore(node=node.node, score=node.score + 1) for node in nodes]
|
||||
|
||||
|
||||
class TestSimpleRAGRanker:
|
||||
@pytest.fixture
|
||||
def ranker(self):
|
||||
return SimpleRAGRanker()
|
||||
|
||||
def test_postprocess_nodes_increases_scores(self, ranker):
|
||||
nodes = [NodeWithScore(node=TextNode(text="a"), score=10), NodeWithScore(node=TextNode(text="b"), score=20)]
|
||||
query_bundle = QueryBundle(query_str="test query")
|
||||
|
||||
processed_nodes = ranker._postprocess_nodes(nodes, query_bundle)
|
||||
|
||||
assert all(node.score == original_node.score + 1 for node, original_node in zip(processed_nodes, nodes))
|
||||
|
|
@ -14,7 +14,7 @@ class Record(BaseModel):
|
|||
|
||||
class TestObjectSortPostprocessor:
|
||||
@pytest.fixture
|
||||
def nodes_with_scores(self):
|
||||
def mock_nodes_with_scores(self):
|
||||
nodes = [
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10),
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20),
|
||||
|
|
@ -23,38 +23,47 @@ class TestObjectSortPostprocessor:
|
|||
return nodes
|
||||
|
||||
@pytest.fixture
|
||||
def query_bundle(self, mocker):
|
||||
def mock_query_bundle(self, mocker):
|
||||
return mocker.MagicMock(spec=QueryBundle)
|
||||
|
||||
def test_sort_descending(self, nodes_with_scores, query_bundle):
|
||||
def test_sort_descending(self, mock_nodes_with_scores, mock_query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
|
||||
assert [node.score for node in sorted_nodes] == [20, 10, 5]
|
||||
|
||||
def test_sort_ascending(self, nodes_with_scores, query_bundle):
|
||||
def test_sort_ascending(self, mock_nodes_with_scores, mock_query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="asc")
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
|
||||
assert [node.score for node in sorted_nodes] == [5, 10, 20]
|
||||
|
||||
def test_top_n_limit(self, nodes_with_scores, query_bundle):
|
||||
def test_top_n_limit(self, mock_nodes_with_scores, mock_query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2)
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
|
||||
assert len(sorted_nodes) == 2
|
||||
assert [node.score for node in sorted_nodes] == [20, 10]
|
||||
|
||||
def test_invalid_json_metadata(self, query_bundle):
|
||||
def test_invalid_json_metadata(self, mock_query_bundle):
|
||||
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)]
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes, query_bundle)
|
||||
postprocessor._postprocess_nodes(nodes, mock_query_bundle)
|
||||
|
||||
def test_missing_query_bundle(self, nodes_with_scores):
|
||||
def test_missing_query_bundle(self, mock_nodes_with_scores):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None)
|
||||
postprocessor._postprocess_nodes(mock_nodes_with_scores, query_bundle=None)
|
||||
|
||||
def test_field_not_found_in_object(self):
|
||||
def test_field_not_found_in_object(self, mock_query_bundle):
|
||||
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)]
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes)
|
||||
postprocessor._postprocess_nodes(nodes, query_bundle=mock_query_bundle)
|
||||
|
||||
def test_not_nodes(self, mock_query_bundle):
|
||||
nodes = []
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
result = postprocessor._postprocess_nodes(nodes, mock_query_bundle)
|
||||
assert result == []
|
||||
|
||||
def test_class_name(self):
|
||||
assert ObjectSortPostprocessor.class_name() == "ObjectSortPostprocessor"
|
||||
|
|
|
|||
21
tests/metagpt/rag/retrievers/test_base_retriever.py
Normal file
21
tests/metagpt/rag/retrievers/test_base_retriever.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
|
||||
|
||||
class SubModifiableRAGRetriever(ModifiableRAGRetriever):
|
||||
...
|
||||
|
||||
|
||||
class SubPersistableRAGRetriever(PersistableRAGRetriever):
|
||||
...
|
||||
|
||||
|
||||
class TestModifiableRAGRetriever:
|
||||
def test_subclasshook(self):
|
||||
result = SubModifiableRAGRetriever.__subclasshook__(SubModifiableRAGRetriever)
|
||||
assert result is NotImplemented
|
||||
|
||||
|
||||
class TestPersistableRAGRetriever:
|
||||
def test_subclasshook(self):
|
||||
result = SubPersistableRAGRetriever.__subclasshook__(SubPersistableRAGRetriever)
|
||||
assert result is NotImplemented
|
||||
|
|
@ -8,30 +8,30 @@ from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
|||
class TestDynamicBM25Retriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
# 创建模拟的Document对象
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc1.get_content.return_value = "Document content 1"
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.doc2.get_content.return_value = "Document content 2"
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
# 模拟index
|
||||
index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
index.storage_context.persist.return_value = "ok"
|
||||
|
||||
# 模拟nodes和tokenizer参数
|
||||
mock_nodes = []
|
||||
mock_tokenizer = mocker.MagicMock()
|
||||
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
|
||||
# 初始化DynamicBM25Retriever对象,并提供必需的参数
|
||||
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer, index=index)
|
||||
|
||||
def test_add_docs_updates_nodes_and_corpus(self):
|
||||
# Execute
|
||||
# Exec
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
# Assertions
|
||||
# Assert
|
||||
assert len(self.retriever._nodes) == len(self.mock_nodes)
|
||||
assert len(self.retriever._corpus) == len(self.mock_nodes)
|
||||
self.retriever._tokenizer.assert_called()
|
||||
self.mock_bm25okapi.assert_called()
|
||||
|
||||
def test_persist(self):
|
||||
self.retriever.persist("")
|
||||
|
|
|
|||
20
tests/metagpt/rag/retrievers/test_chroma_retriever.py
Normal file
20
tests/metagpt/rag/retrievers/test_chroma_retriever.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
import pytest
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
|
||||
|
||||
|
||||
class TestChromaRetriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
self.mock_index = mocker.MagicMock()
|
||||
self.retriever = ChromaRetriever(self.mock_index)
|
||||
|
||||
def test_add_nodes(self):
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
self.mock_index.insert_nodes.assert_called()
|
||||
20
tests/metagpt/rag/retrievers/test_es_retriever.py
Normal file
20
tests/metagpt/rag/retrievers/test_es_retriever.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
import pytest
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
|
||||
|
||||
|
||||
class TestElasticsearchRetriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
self.mock_index = mocker.MagicMock()
|
||||
self.retriever = ElasticsearchRetriever(self.mock_index)
|
||||
|
||||
def test_add_nodes(self):
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
self.mock_index.insert_nodes.assert_called()
|
||||
|
|
@ -7,16 +7,19 @@ from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
|||
class TestFAISSRetriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
# 创建模拟的Document对象
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
# 模拟FAISSRetriever的_index属性
|
||||
self.mock_index = mocker.MagicMock()
|
||||
self.retriever = FAISSRetriever(self.mock_index)
|
||||
|
||||
def test_add_docs_calls_insert_for_each_document(self, mocker):
|
||||
def test_add_docs_calls_insert_for_each_document(self):
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
assert self.mock_index.insert_nodes.assert_called
|
||||
self.mock_index.insert_nodes.assert_called()
|
||||
|
||||
def test_persist(self):
|
||||
self.retriever.persist("")
|
||||
|
||||
self.mock_index.storage_context.persist.assert_called()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from llama_index.core.schema import NodeWithScore, TextNode
|
||||
|
||||
|
|
@ -7,18 +5,30 @@ from metagpt.rag.retrievers import SimpleHybridRetriever
|
|||
|
||||
|
||||
class TestSimpleHybridRetriever:
|
||||
@pytest.fixture
|
||||
def mock_retriever(self, mocker):
|
||||
return mocker.MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_hybrid_retriever(self, mock_retriever) -> SimpleHybridRetriever:
|
||||
return SimpleHybridRetriever(mock_retriever)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_node(self):
|
||||
return NodeWithScore(node=TextNode(id_="2"), score=0.95)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aretrieve(self):
|
||||
async def test_aretrieve(self, mocker):
|
||||
question = "test query"
|
||||
|
||||
# Create mock retrievers
|
||||
mock_retriever1 = AsyncMock()
|
||||
mock_retriever1 = mocker.AsyncMock()
|
||||
mock_retriever1.aretrieve.return_value = [
|
||||
NodeWithScore(node=TextNode(id_="1"), score=1.0),
|
||||
NodeWithScore(node=TextNode(id_="2"), score=0.95),
|
||||
]
|
||||
|
||||
mock_retriever2 = AsyncMock()
|
||||
mock_retriever2 = mocker.AsyncMock()
|
||||
mock_retriever2.aretrieve.return_value = [
|
||||
NodeWithScore(node=TextNode(id_="2"), score=0.95),
|
||||
NodeWithScore(node=TextNode(id_="3"), score=0.8),
|
||||
|
|
@ -37,3 +47,11 @@ class TestSimpleHybridRetriever:
|
|||
# Check if the scores are correct (assuming you want the highest score)
|
||||
node_scores = {node.node.node_id: node.score for node in results}
|
||||
assert node_scores["2"] == 0.95
|
||||
|
||||
def test_add_nodes(self, mock_hybrid_retriever: SimpleHybridRetriever, mock_node):
|
||||
mock_hybrid_retriever.add_nodes([mock_node])
|
||||
mock_hybrid_retriever.retrievers[0].add_nodes.assert_called_once()
|
||||
|
||||
def test_persist(self, mock_hybrid_retriever: SimpleHybridRetriever):
|
||||
mock_hybrid_retriever.persist("")
|
||||
mock_hybrid_retriever.retrievers[0].persist.assert_called_once()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue