upgrade llama-index-vector-stores-chroma and raise RAG test coverage to 100%

This commit is contained in:
seehi 2024-03-27 19:53:50 +08:00
parent 6450a09d3b
commit 6cb5492f02
21 changed files with 600 additions and 382 deletions

View file

@ -8,6 +8,7 @@ import re
import time
from typing import Any, Iterable
from llama_index.vector_stores.chroma import ChromaVectorStore
from pydantic import ConfigDict, Field
from metagpt.config2 import config as CONFIG
@ -15,7 +16,6 @@ from metagpt.environment.base_env import Environment
from metagpt.environment.minecraft_env.const import MC_CKPT_DIR
from metagpt.environment.minecraft_env.minecraft_ext_env import MinecraftExtEnv
from metagpt.logs import logger
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file

View file

@ -5,6 +5,7 @@ from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_s
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
@ -17,7 +18,6 @@ from metagpt.rag.schema import (
ElasticsearchKeywordIndexConfig,
FAISSIndexConfig,
)
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
class RAGIndexFactory(ConfigBasedFactory):

View file

@ -6,6 +6,7 @@ import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
@ -25,7 +26,6 @@ from metagpt.rag.schema import (
FAISSRetrieverConfig,
IndexRetrieverConfig,
)
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
class RetrieverFactory(ConfigBasedFactory):

View file

@ -1,3 +0,0 @@
from metagpt.rag.vector_stores.chroma.base import ChromaVectorStore
__all__ = ["ChromaVectorStore"]

View file

@ -1,290 +0,0 @@
"""Chroma vector store.
Refs to https://github.com/run-llama/llama_index/blob/v0.10.12/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py.
The repo requires onnxruntime = "^1.17.0", which is too new for many OS systems, such as CentOS7.
"""
import math
from typing import Any, Dict, Generator, List, Optional, cast
import chromadb
from chromadb.api.models.Collection import Collection
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.schema import BaseNode, MetadataMode, TextNode
from llama_index.core.utils import truncate_text
from llama_index.core.vector_stores.types import (
BasePydanticVectorStore,
MetadataFilters,
VectorStoreQuery,
VectorStoreQueryResult,
)
from llama_index.core.vector_stores.utils import (
legacy_metadata_dict_to_node,
metadata_dict_to_node,
node_to_metadata_dict,
)
from metagpt.logs import logger
def _transform_chroma_filter_condition(condition: str) -> str:
    """Translate standard metadata filter op to Chroma specific spec."""
    chroma_condition = {"and": "$and", "or": "$or"}.get(condition)
    if chroma_condition is None:
        raise ValueError(f"Filter condition {condition} not supported")
    return chroma_condition


def _transform_chroma_filter_operator(operator: str) -> str:
    """Translate standard metadata filter operator to Chroma specific spec."""
    operator_map = {
        "!=": "$ne",
        "==": "$eq",
        ">": "$gt",
        "<": "$lt",
        ">=": "$gte",
        "<=": "$lte",
    }
    if operator not in operator_map:
        raise ValueError(f"Filter operator {operator} not supported")
    return operator_map[operator]


def _to_chroma_filter(
    standard_filters: MetadataFilters,
) -> dict:
    """Translate standard metadata filters to Chroma specific spec.

    Returns the bare clause when there is exactly one filter, a
    ``{"$and"/"$or": [...]}`` wrapper for several, and ``{}`` for none.
    """
    # Validate/translate the condition up front, even when it ends up unused
    # (single- or zero-clause case), so an unsupported condition always raises.
    chroma_condition = _transform_chroma_filter_condition(standard_filters.condition or "and")
    clauses = []
    for metadata_filter in standard_filters.filters or []:
        if metadata_filter.operator:
            chroma_operator = _transform_chroma_filter_operator(metadata_filter.operator)
            clauses.append({metadata_filter.key: {chroma_operator: metadata_filter.value}})
        else:
            # No operator given: plain equality-style key/value clause.
            clauses.append({metadata_filter.key: metadata_filter.value})
    if len(clauses) == 1:
        # If there is only one filter, return it directly
        return clauses[0]
    if len(clauses) > 1:
        return {chroma_condition: clauses}
    return {}
import_err_msg = "`chromadb` package not found, please run `pip install chromadb`"

# One less than the max chunk size for ChromaDB
MAX_CHUNK_SIZE = 41665


def chunk_list(lst: List[BaseNode], max_chunk_size: int) -> Generator[List[BaseNode], None, None]:
    """Split *lst* into consecutive chunks of at most *max_chunk_size* items.

    Args:
        lst (List[BaseNode]): list of nodes with embeddings
        max_chunk_size (int): max chunk size

    Yields:
        Generator[List[BaseNode], None, None]: consecutive sub-lists of ``lst``,
        each of length <= ``max_chunk_size`` (the last one may be shorter)
    """
    for offset in range(0, len(lst), max_chunk_size):
        chunk = lst[offset : offset + max_chunk_size]
        yield chunk
class ChromaVectorStore(BasePydanticVectorStore):
    """Chroma vector store.

    In this vector store, embeddings are stored within a ChromaDB collection.
    During query time, the index uses ChromaDB to query for the top
    k most similar nodes.

    Args:
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance
    """

    stores_text: bool = True
    flat_metadata: bool = True
    # Connection/collection settings; only consulted when no ready-made
    # collection is handed in (see __init__ and from_params).
    collection_name: Optional[str]
    host: Optional[str]
    port: Optional[str]
    ssl: bool
    headers: Optional[Dict[str, str]]
    persist_dir: Optional[str]
    collection_kwargs: Dict[str, Any] = Field(default_factory=dict)
    # Underlying chromadb Collection; private attr, kept out of the pydantic schema.
    _collection: Any = PrivateAttr()

    def __init__(
        self,
        chroma_collection: Optional[Any] = None,
        collection_name: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[str] = None,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        persist_dir: Optional[str] = None,
        collection_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Init params.

        Either wraps the given ``chroma_collection`` directly, or connects
        over HTTP and gets/creates the named collection.
        """
        collection_kwargs = collection_kwargs or {}
        if chroma_collection is None:
            # NOTE(review): this path always connects via HttpClient; persist_dir
            # is stored but not used here — use from_params() for a PersistentClient.
            client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
            self._collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
        else:
            self._collection = cast(Collection, chroma_collection)
        # The private attr is assigned before super().__init__; the pydantic
        # fields are populated by the super call below.
        super().__init__(
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_name=collection_name,
            persist_dir=persist_dir,
            collection_kwargs=collection_kwargs or {},
        )

    @classmethod
    def from_collection(cls, collection: Any) -> "ChromaVectorStore":
        """Build a store around an existing chromadb ``Collection`` instance.

        Raises:
            ImportError: if chromadb is not installed.
            Exception: if ``collection`` is not a chromadb Collection.
        """
        try:
            from chromadb import Collection
        except ImportError:
            raise ImportError(import_err_msg)
        if not isinstance(collection, Collection):
            raise Exception("argument is not chromadb collection instance")
        return cls(chroma_collection=collection)

    @classmethod
    def from_params(
        cls,
        collection_name: str,
        host: Optional[str] = None,
        port: Optional[str] = None,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        persist_dir: Optional[str] = None,
        collection_kwargs: dict = {},  # NOTE(review): mutable default; never mutated here, but fragile
        **kwargs: Any,
    ) -> "ChromaVectorStore":
        """Create a store from connection params.

        ``persist_dir`` takes precedence (local PersistentClient); otherwise
        ``host``/``port`` select a remote HttpClient.

        Raises:
            ValueError: if neither persist_dir nor (host, port) is given.
        """
        if persist_dir:
            client = chromadb.PersistentClient(path=persist_dir)
            collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
        elif host and port:
            client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
            collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
        else:
            raise ValueError("Either `persist_dir` or (`host`,`port`) must be specified")
        return cls(
            chroma_collection=collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            persist_dir=persist_dir,
            collection_kwargs=collection_kwargs,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the class name used by llama_index serialization."""
        return "ChromaVectorStore"

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to index.

        Nodes are submitted in MAX_CHUNK_SIZE-sized batches.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        Returns:
            The node ids of all added nodes.

        Raises:
            ValueError: if the collection was never initialized.
        """
        if not self._collection:
            raise ValueError("Collection not initialized")
        max_chunk_size = MAX_CHUNK_SIZE
        node_chunks = chunk_list(nodes, max_chunk_size)
        all_ids = []
        for node_chunk in node_chunks:
            embeddings = []
            metadatas = []
            ids = []
            documents = []
            for node in node_chunk:
                embeddings.append(node.get_embedding())
                # Flattened metadata without the node text (text goes in `documents`).
                metadata_dict = node_to_metadata_dict(node, remove_text=True, flat_metadata=self.flat_metadata)
                # Replace None metadata values with "" — presumably because Chroma
                # rejects None in metadata; TODO confirm against chromadb docs.
                for key in metadata_dict:
                    if metadata_dict[key] is None:
                        metadata_dict[key] = ""
                metadatas.append(metadata_dict)
                ids.append(node.node_id)
                documents.append(node.get_content(metadata_mode=MetadataMode.NONE))
            self._collection.add(
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas,
                documents=documents,
            )
            all_ids.extend(ids)
        return all_ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using with ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.
        """
        # NOTE(review): assumes the source-document id is stored under the
        # "document_id" metadata key by node_to_metadata_dict — verify.
        self._collection.delete(where={"document_id": ref_doc_id})

    @property
    def client(self) -> Any:
        """Return client (the underlying chromadb collection)."""
        return self._collection

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query_embedding (List[float]): query embedding
            similarity_top_k (int): top k most similar nodes

        Raises:
            ValueError: if metadata filters are supplied both on ``query``
                and as a ``where`` kwarg.
        """
        if query.filters is not None:
            if "where" in kwargs:
                raise ValueError(
                    "Cannot specify metadata filters via both query and kwargs. "
                    "Use kwargs only for chroma specific items that are "
                    "not supported via the generic query interface."
                )
            where = _to_chroma_filter(query.filters)
        else:
            where = kwargs.pop("where", {})
        results = self._collection.query(
            query_embeddings=query.query_embedding,
            n_results=query.similarity_top_k,
            where=where,
            **kwargs,
        )
        logger.debug(f"> Top {len(results['documents'])} nodes:")
        nodes = []
        similarities = []
        ids = []
        # Chroma returns one result list per query embedding; only the first
        # query's results ([0]) are unpacked here.
        for node_id, text, metadata, distance in zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        ):
            try:
                node = metadata_dict_to_node(metadata)
                node.set_content(text)
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                metadata, node_info, relationships = legacy_metadata_dict_to_node(metadata)
                node = TextNode(
                    text=text,
                    id_=node_id,
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships=relationships,
                )
            nodes.append(node)
            # Map distance to a (0, 1] similarity score via exp(-distance).
            similarity_score = math.exp(-distance)
            similarities.append(similarity_score)
            logger.debug(
                f"> [Node {node_id}] [Similarity score: {similarity_score}] " f"{truncate_text(str(text), 100)}"
            )
            ids.append(node_id)
        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

View file

@ -37,8 +37,8 @@ extras_require = {
"llama-index-retrievers-bm25==0.1.3",
"llama-index-vector-stores-faiss==0.1.1",
"llama-index-vector-stores-elasticsearch==0.1.6",
"llama-index-vector-stores-chroma==0.1.6",
"llama-index-postprocessor-colbert-rerank==0.1.1",
"chromadb==0.4.23",
],
}

View file

@ -1,12 +1,26 @@
import json
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import Document, TextNode
from llama_index.core.embeddings import MockEmbedding
from llama_index.core.llms import MockLLM
from llama_index.core.schema import Document, NodeWithScore, TextNode
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
from metagpt.rag.retrievers import SimpleHybridRetriever
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
class TestSimpleEngine:
@pytest.fixture
def mock_llm(self):
return MockLLM()
@pytest.fixture
def mock_embedding(self):
return MockEmbedding(embed_dim=1)
@pytest.fixture
def mock_simple_directory_reader(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
@ -54,7 +68,7 @@ class TestSimpleEngine:
retriever_configs = [mocker.MagicMock()]
ranker_configs = [mocker.MagicMock()]
# Execute
# Exec
engine = SimpleEngine.from_docs(
input_dir=input_dir,
input_files=input_files,
@ -65,7 +79,7 @@ class TestSimpleEngine:
ranker_configs=ranker_configs,
)
# Assertions
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_vector_store_index.assert_called_once()
mock_get_retriever.assert_called_once_with(
@ -75,6 +89,68 @@ class TestSimpleEngine:
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
assert isinstance(engine, SimpleEngine)
def test_from_docs_without_file(self):
with pytest.raises(ValueError):
SimpleEngine.from_docs()
def test_from_objs(self, mock_llm, mock_embedding):
# Mock
class MockRAGObject:
def rag_key(self):
return "key"
def model_dump_json(self):
return "{}"
objs = [MockRAGObject()]
# Setup
retriever_configs = []
ranker_configs = []
# Exec
engine = SimpleEngine.from_objs(
objs=objs,
llm=mock_llm,
embed_model=mock_embedding,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
# Assert
assert isinstance(engine, SimpleEngine)
assert engine.index is not None
def test_from_objs_with_bm25_config(self):
# Setup
retriever_configs = [BM25RetrieverConfig()]
# Exec
with pytest.raises(ValueError):
SimpleEngine.from_objs(
objs=[],
llm=MockLLM(),
retriever_configs=retriever_configs,
ranker_configs=[],
)
def test_from_index(self, mocker, mock_llm, mock_embedding):
# Mock
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
mock_get_index.return_value = mock_index
# Exec
engine = SimpleEngine.from_index(
index_config=mock_index,
embed_model=mock_embedding,
llm=mock_llm,
)
# Assert
assert isinstance(engine, SimpleEngine)
assert engine.index is mock_index
@pytest.mark.asyncio
async def test_asearch(self, mocker):
# Mock
@ -86,10 +162,10 @@ class TestSimpleEngine:
engine = SimpleEngine(retriever=mocker.MagicMock())
engine.aquery = mock_aquery
# Execute
# Exec
result = await engine.asearch(test_query)
# Assertions
# Assert
mock_aquery.assert_called_once_with(test_query)
assert result == expected_result
@ -106,10 +182,10 @@ class TestSimpleEngine:
engine = SimpleEngine(retriever=mocker.MagicMock())
test_query = "test query"
# Execute
# Exec
result = await engine.aretrieve(test_query)
# Assertions
# Assert
mock_query_bundle.assert_called_once_with(test_query)
mock_super_aretrieve.assert_called_once_with("query_bundle")
assert result[0].text == "node_with_score"
@ -134,10 +210,10 @@ class TestSimpleEngine:
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
input_files = ["test_file1", "test_file2"]
# Execute
# Exec
engine.add_docs(input_files=input_files)
# Assertions
# Assert
mock_simple_directory_reader.assert_called_once_with(input_files=input_files)
mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"])
@ -156,11 +232,79 @@ class TestSimpleEngine:
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
# Execute
# Exec
engine.add_objs(objs=objs)
# Assertions
# Assert
assert mock_retriever.add_nodes.call_count == 1
for node in mock_retriever.add_nodes.call_args[0][0]:
assert isinstance(node, TextNode)
assert "is_obj" in node.metadata
def test_persist_successfully(self, mocker):
# Mock
mock_retriever = mocker.MagicMock(spec=PersistableRAGRetriever)
mock_retriever.persist.return_value = mocker.MagicMock()
# Setup
engine = SimpleEngine(retriever=mock_retriever)
# Exec
engine.persist(persist_dir="")
def test_ensure_retriever_of_type(self, mocker):
# Mock
class MyRetriever:
def add_nodes(self):
...
mock_retriever = mocker.MagicMock(spec=SimpleHybridRetriever)
mock_retriever.retrievers = [MyRetriever()]
# Setup
engine = SimpleEngine(retriever=mock_retriever)
# Assert
engine._ensure_retriever_of_type(ModifiableRAGRetriever)
with pytest.raises(TypeError):
engine._ensure_retriever_of_type(PersistableRAGRetriever)
with pytest.raises(TypeError):
other_engine = SimpleEngine(retriever=mocker.MagicMock(spec=ModifiableRAGRetriever))
other_engine._ensure_retriever_of_type(PersistableRAGRetriever)
def test_with_obj_metadata(self, mocker):
# Mock
node = NodeWithScore(
node=ObjectNode(
text="example",
metadata={
"is_obj": True,
"obj_cls_name": "ExampleObject",
"obj_mod_name": "__main__",
"obj_json": json.dumps({"key": "test_key", "value": "test_value"}),
},
)
)
class ExampleObject:
def __init__(self, key, value):
self.key = key
self.value = value
def __eq__(self, other):
return self.key == other.key and self.value == other.value
mock_import_class = mocker.patch("metagpt.rag.engines.simple.import_class")
mock_import_class.return_value = ExampleObject
# Setup
SimpleEngine._try_reconstruct_obj([node])
# Exec
expected_obj = ExampleObject(key="test_key", value="test_value")
# Assert
assert "obj" in node.node.metadata
assert node.node.metadata["obj"] == expected_obj

View file

@ -0,0 +1,43 @@
import pytest
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
class TestRAGEmbeddingFactory:
    """Unit tests for RAGEmbeddingFactory.get_rag_embedding."""

    @pytest.fixture(autouse=True)
    def mock_embedding_factory(self):
        # Fresh factory instance for every test.
        self.embedding_factory = RAGEmbeddingFactory()

    @pytest.fixture
    def mock_openai_embedding(self, mocker):
        # Patch OpenAIEmbedding where the factory looks it up.
        return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding")

    @pytest.fixture
    def mock_azure_embedding(self, mocker):
        # Patch AzureOpenAIEmbedding where the factory looks it up.
        return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding")

    def test_get_rag_embedding_openai(self, mock_openai_embedding):
        # Explicit OPENAI type resolves to the OpenAI embedding class.
        # Exec
        self.embedding_factory.get_rag_embedding(LLMType.OPENAI)

        # Assert
        mock_openai_embedding.assert_called_once()

    def test_get_rag_embedding_azure(self, mock_azure_embedding):
        # Explicit AZURE type resolves to the Azure embedding class.
        # Exec
        self.embedding_factory.get_rag_embedding(LLMType.AZURE)

        # Assert
        mock_azure_embedding.assert_called_once()

    def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
        # Without an argument, the factory falls back to config.llm.api_type.
        # Mock
        mock_config = mocker.patch("metagpt.rag.factories.embedding.config")
        mock_config.llm.api_type = LLMType.OPENAI

        # Exec
        self.embedding_factory.get_rag_embedding()

        # Assert
        mock_openai_embedding.assert_called_once()

View file

@ -0,0 +1,89 @@
import pytest
from llama_index.core.embeddings import MockEmbedding
from metagpt.rag.factories.index import RAGIndexFactory
from metagpt.rag.schema import (
BM25IndexConfig,
ChromaIndexConfig,
ElasticsearchIndexConfig,
ElasticsearchStoreConfig,
FAISSIndexConfig,
)
class TestRAGIndexFactory:
    """Unit tests for RAGIndexFactory.get_index across the supported index configs."""

    @pytest.fixture(autouse=True)
    def setup(self):
        # Fresh factory instance for every test.
        self.index_factory = RAGIndexFactory()

    @pytest.fixture
    def faiss_config(self):
        return FAISSIndexConfig(persist_path="")

    @pytest.fixture
    def chroma_config(self):
        return ChromaIndexConfig(persist_path="", collection_name="")

    @pytest.fixture
    def bm25_config(self):
        return BM25IndexConfig(persist_path="")

    @pytest.fixture
    def es_config(self, mocker):
        return ElasticsearchIndexConfig(store_config=ElasticsearchStoreConfig())

    @pytest.fixture
    def mock_storage_context(self, mocker):
        # Prevent real storage-context construction during index creation.
        return mocker.patch("metagpt.rag.factories.index.StorageContext.from_defaults")

    @pytest.fixture
    def mock_load_index_from_storage(self, mocker):
        # Stub loading a persisted index from disk.
        return mocker.patch("metagpt.rag.factories.index.load_index_from_storage")

    @pytest.fixture
    def mock_from_vector_store(self, mocker):
        # Stub building an index directly from a vector store.
        return mocker.patch("metagpt.rag.factories.index.VectorStoreIndex.from_vector_store")

    @pytest.fixture
    def mock_embedding(self):
        return MockEmbedding(embed_dim=1)

    def test_create_faiss_index(
        self, mocker, faiss_config, mock_storage_context, mock_load_index_from_storage, mock_embedding
    ):
        # Mock
        mock_faiss_store = mocker.patch("metagpt.rag.factories.index.FaissVectorStore.from_persist_dir")

        # Exec
        self.index_factory.get_index(faiss_config, embed_model=mock_embedding)

        # Assert
        mock_faiss_store.assert_called_once()

    def test_create_bm25_index(
        self, mocker, bm25_config, mock_storage_context, mock_load_index_from_storage, mock_embedding
    ):
        # Smoke test: BM25 index creation succeeds with storage/load stubbed out.
        self.index_factory.get_index(bm25_config, embed_model=mock_embedding)

    def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding):
        # Mock
        mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient")
        mock_chroma_db.get_or_create_collection.return_value = mocker.MagicMock()
        mock_chroma_store = mocker.patch("metagpt.rag.factories.index.ChromaVectorStore")

        # Exec
        self.index_factory.get_index(chroma_config, embed_model=mock_embedding)

        # Assert
        mock_chroma_store.assert_called_once()

    def test_create_es_index(self, mocker, es_config, mock_from_vector_store, mock_embedding):
        # Mock
        mock_es_store = mocker.patch("metagpt.rag.factories.index.ElasticsearchStore")

        # Exec
        self.index_factory.get_index(es_config, embed_model=mock_embedding)

        # Assert
        mock_es_store.assert_called_once()

View file

@ -0,0 +1,71 @@
from typing import Optional, Union
import pytest
from llama_index.core.llms import LLMMetadata
from metagpt.configs.llm_config import LLMConfig
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.provider.base_llm import BaseLLM
from metagpt.rag.factories.llm import RAGLLM, get_rag_llm
class MockLLM(BaseLLM):
    """Minimal BaseLLM stub for tests: implements the abstract interface and answers "ok"."""

    def __init__(self, config: LLMConfig):
        # Intentionally skips BaseLLM.__init__ — no client/setup is needed for tests.
        ...

    async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
        """_achat_completion implemented by inherited class"""

    async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
        # Fixed async completion result.
        return "ok"

    def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
        # Fixed sync completion result.
        return "ok"

    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
        """_achat_completion_stream implemented by inherited class"""

    async def aask(
        self,
        msg: Union[str, list[dict[str, str]]],
        system_msgs: Optional[list[str]] = None,
        format_msgs: Optional[list[dict[str, str]]] = None,
        images: Optional[Union[str, list[str]]] = None,
        timeout=USE_CONFIG_TIMEOUT,
        stream=True,
    ) -> str:
        # Fixed ask result regardless of arguments.
        return "ok"
class TestRAGLLM:
    """Unit tests for the RAGLLM adapter wrapping a metagpt BaseLLM."""

    @pytest.fixture
    def mock_model_infer(self):
        # Underlying metagpt LLM the adapter delegates to.
        return MockLLM(config=LLMConfig())

    @pytest.fixture
    def rag_llm(self, mock_model_infer):
        return RAGLLM(model_infer=mock_model_infer)

    def test_metadata(self, rag_llm):
        # metadata mirrors the adapter's own context_window/num_output/model_name.
        metadata = rag_llm.metadata
        assert isinstance(metadata, LLMMetadata)
        assert metadata.context_window == rag_llm.context_window
        assert metadata.num_output == rag_llm.num_output
        assert metadata.model_name == rag_llm.model_name

    @pytest.mark.asyncio
    async def test_acomplete(self, rag_llm, mock_model_infer):
        # Async completion delegates to the wrapped LLM (MockLLM answers "ok").
        response = await rag_llm.acomplete("question")
        assert response.text == "ok"

    def test_complete(self, rag_llm, mock_model_infer):
        # Sync completion delegates to the wrapped LLM.
        response = rag_llm.complete("question")
        assert response.text == "ok"

    def test_stream_complete(self, rag_llm, mock_model_infer):
        # Smoke test: streaming entry point runs without raising.
        rag_llm.stream_complete("question")
def test_get_rag_llm():
    """get_rag_llm wraps any BaseLLM instance in a RAGLLM adapter."""
    result = get_rag_llm(MockLLM(config=LLMConfig()))
    assert isinstance(result, RAGLLM)

View file

@ -1,41 +1,57 @@
import pytest
from llama_index.core.llms import LLM
from llama_index.core.llms import MockLLM
from llama_index.core.postprocessor import LLMRerank
from metagpt.rag.factories.ranker import RankerFactory
from metagpt.rag.schema import LLMRankerConfig
from metagpt.rag.schema import ColbertRerankConfig, LLMRankerConfig, ObjectRankerConfig
class TestRankerFactory:
@pytest.fixture
def ranker_factory(self) -> RankerFactory:
return RankerFactory()
@pytest.fixture(autouse=True)
def ranker_factory(self):
self.ranker_factory: RankerFactory = RankerFactory()
@pytest.fixture
def mock_llm(self, mocker):
return mocker.MagicMock(spec=LLM)
def mock_llm(self):
return MockLLM()
def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker):
mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm)
default_rankers = ranker_factory.get_rankers()
def test_get_rankers_with_no_configs(self, mock_llm, mocker):
mocker.patch.object(self.ranker_factory, "_extract_llm", return_value=mock_llm)
default_rankers = self.ranker_factory.get_rankers()
assert len(default_rankers) == 0
def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm):
def test_get_rankers_with_configs(self, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
rankers = ranker_factory.get_rankers(configs=[mock_config])
rankers = self.ranker_factory.get_rankers(configs=[mock_config])
assert len(rankers) == 1
assert isinstance(rankers[0], LLMRerank)
def test_create_llm_ranker_creates_correct_instance(self, ranker_factory: RankerFactory, mock_llm):
def test_extract_llm_from_config(self, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
ranker = ranker_factory._create_llm_ranker(mock_config)
extracted_llm = self.ranker_factory._extract_llm(config=mock_config)
assert extracted_llm == mock_llm
def test_extract_llm_from_kwargs(self, mock_llm):
extracted_llm = self.ranker_factory._extract_llm(llm=mock_llm)
assert extracted_llm == mock_llm
def test_create_llm_ranker(self, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
ranker = self.ranker_factory._create_llm_ranker(mock_config)
assert isinstance(ranker, LLMRerank)
def test_extract_llm_from_config(self, ranker_factory: RankerFactory, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
extracted_llm = ranker_factory._extract_llm(config=mock_config)
assert extracted_llm == mock_llm
def test_create_colbert_ranker(self, mocker, mock_llm):
mocker.patch("metagpt.rag.factories.ranker.ColbertRerank", return_value="colbert")
def test_extract_llm_from_kwargs(self, ranker_factory: RankerFactory, mock_llm):
extracted_llm = ranker_factory._extract_llm(llm=mock_llm)
assert extracted_llm == mock_llm
mock_config = ColbertRerankConfig(llm=mock_llm)
ranker = self.ranker_factory._create_colbert_ranker(mock_config)
assert ranker == "colbert"
def test_create_object_ranker(self, mocker, mock_llm):
mocker.patch("metagpt.rag.factories.ranker.ObjectSortPostprocessor", return_value="object")
mock_config = ObjectRankerConfig(field_name="fake", llm=mock_llm)
ranker = self.ranker_factory._create_object_ranker(mock_config)
assert ranker == "object"

View file

@ -1,18 +1,28 @@
import faiss
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from metagpt.rag.factories.retriever import RetrieverFactory
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
from metagpt.rag.schema import (
BM25RetrieverConfig,
ChromaRetrieverConfig,
ElasticsearchRetrieverConfig,
ElasticsearchStoreConfig,
FAISSRetrieverConfig,
)
class TestRetrieverFactory:
@pytest.fixture
@pytest.fixture(autouse=True)
def retriever_factory(self):
return RetrieverFactory()
self.retriever_factory: RetrieverFactory = RetrieverFactory()
@pytest.fixture
def mock_faiss_index(self, mocker):
@ -25,55 +35,79 @@ class TestRetrieverFactory:
mock.docstore.docs.values.return_value = []
return mock
def test_get_retriever_with_faiss_config(
self, retriever_factory: RetrieverFactory, mock_faiss_index, mocker, mock_vector_store_index
):
@pytest.fixture
def mock_chroma_vector_store(self, mocker):
return mocker.MagicMock(spec=ChromaVectorStore)
@pytest.fixture
def mock_es_vector_store(self, mocker):
return mocker.MagicMock(spec=ElasticsearchStore)
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
mock_config = FAISSRetrieverConfig(dimensions=128)
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, FAISSRetriever)
def test_get_retriever_with_bm25_config(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
mock_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, DynamicBM25Retriever)
def test_get_retriever_with_multiple_configs_returns_hybrid(
self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index
):
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
mock_bm25_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
assert isinstance(retriever, SimpleHybridRetriever)
def test_create_default_retriever(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, ChromaRetriever)
def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, ElasticsearchRetriever)
def test_create_default_retriever(self, mocker, mock_vector_store_index):
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mock_vector_store_index.as_retriever = mocker.MagicMock()
retriever = retriever_factory.get_retriever()
retriever = self.retriever_factory.get_retriever()
mock_vector_store_index.as_retriever.assert_called_once()
assert retriever is mock_vector_store_index.as_retriever.return_value
def test_extract_index_from_config(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
def test_extract_index_from_config(self, mock_vector_store_index):
mock_config = FAISSRetrieverConfig(index=mock_vector_store_index)
extracted_index = retriever_factory._extract_index(config=mock_config)
extracted_index = self.retriever_factory._extract_index(config=mock_config)
assert extracted_index == mock_vector_store_index
def test_extract_index_from_kwargs(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
extracted_index = retriever_factory._extract_index(index=mock_vector_store_index)
def test_extract_index_from_kwargs(self, mock_vector_store_index):
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
assert extracted_index == mock_vector_store_index

View file

@ -0,0 +1,23 @@
import pytest
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
from metagpt.rag.rankers.base import RAGRanker
class SimpleRAGRanker(RAGRanker):
    """Minimal concrete RAGRanker used in tests: bumps every node's score by one."""

    def _postprocess_nodes(self, nodes, query_bundle=None):
        # Build fresh NodeWithScore objects so the input list is left untouched.
        reranked = []
        for item in nodes:
            reranked.append(NodeWithScore(node=item.node, score=item.score + 1))
        return reranked
class TestSimpleRAGRanker:
    """Tests for the trivial RAGRanker subclass defined above."""

    @pytest.fixture
    def ranker(self):
        return SimpleRAGRanker()

    def test_postprocess_nodes_increases_scores(self, ranker):
        nodes = [NodeWithScore(node=TextNode(text="a"), score=10), NodeWithScore(node=TextNode(text="b"), score=20)]
        query_bundle = QueryBundle(query_str="test query")
        processed_nodes = ranker._postprocess_nodes(nodes, query_bundle)
        # Guard against zip() silently truncating: without this, an empty or
        # short result would make the all(...) below pass vacuously.
        assert len(processed_nodes) == len(nodes)
        assert all(node.score == original_node.score + 1 for node, original_node in zip(processed_nodes, nodes))

View file

@ -14,7 +14,7 @@ class Record(BaseModel):
class TestObjectSortPostprocessor:
@pytest.fixture
def nodes_with_scores(self):
def mock_nodes_with_scores(self):
nodes = [
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10),
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20),
@ -23,38 +23,47 @@ class TestObjectSortPostprocessor:
return nodes
@pytest.fixture
def query_bundle(self, mocker):
def mock_query_bundle(self, mocker):
return mocker.MagicMock(spec=QueryBundle)
def test_sort_descending(self, nodes_with_scores, query_bundle):
def test_sort_descending(self, mock_nodes_with_scores, mock_query_bundle):
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
assert [node.score for node in sorted_nodes] == [20, 10, 5]
def test_sort_ascending(self, nodes_with_scores, query_bundle):
def test_sort_ascending(self, mock_nodes_with_scores, mock_query_bundle):
postprocessor = ObjectSortPostprocessor(field_name="score", order="asc")
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
assert [node.score for node in sorted_nodes] == [5, 10, 20]
def test_top_n_limit(self, nodes_with_scores, query_bundle):
def test_top_n_limit(self, mock_nodes_with_scores, mock_query_bundle):
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2)
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
assert len(sorted_nodes) == 2
assert [node.score for node in sorted_nodes] == [20, 10]
def test_invalid_json_metadata(self, query_bundle):
def test_invalid_json_metadata(self, mock_query_bundle):
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)]
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
with pytest.raises(ValueError):
postprocessor._postprocess_nodes(nodes, query_bundle)
postprocessor._postprocess_nodes(nodes, mock_query_bundle)
def test_missing_query_bundle(self, nodes_with_scores):
def test_missing_query_bundle(self, mock_nodes_with_scores):
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
with pytest.raises(ValueError):
postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None)
postprocessor._postprocess_nodes(mock_nodes_with_scores, query_bundle=None)
def test_field_not_found_in_object(self):
def test_field_not_found_in_object(self, mock_query_bundle):
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)]
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
with pytest.raises(ValueError):
postprocessor._postprocess_nodes(nodes)
postprocessor._postprocess_nodes(nodes, query_bundle=mock_query_bundle)
def test_not_nodes(self, mock_query_bundle):
    """An empty node list passes through unchanged as an empty list."""
    postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
    empty_nodes = []
    result = postprocessor._postprocess_nodes(empty_nodes, mock_query_bundle)
    assert result == []
def test_class_name(self):
    """class_name() reports the class's registry/serialization identifier."""
    expected = "ObjectSortPostprocessor"
    assert ObjectSortPostprocessor.class_name() == expected

View file

@ -0,0 +1,21 @@
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
class SubModifiableRAGRetriever(ModifiableRAGRetriever):
    """Bare subclass stub so ModifiableRAGRetriever's hooks can be exercised."""
class SubPersistableRAGRetriever(PersistableRAGRetriever):
    """Bare subclass stub so PersistableRAGRetriever's hooks can be exercised."""
class TestModifiableRAGRetriever:
    def test_subclasshook(self):
        """The default __subclasshook__ defers, returning NotImplemented."""
        assert SubModifiableRAGRetriever.__subclasshook__(SubModifiableRAGRetriever) is NotImplemented
class TestPersistableRAGRetriever:
    def test_subclasshook(self):
        """The default __subclasshook__ defers, returning NotImplemented."""
        assert SubPersistableRAGRetriever.__subclasshook__(SubPersistableRAGRetriever) is NotImplemented

View file

@ -8,30 +8,30 @@ from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
class TestDynamicBM25Retriever:
@pytest.fixture(autouse=True)
def setup(self, mocker):
# 创建模拟的Document对象
self.doc1 = mocker.MagicMock(spec=Node)
self.doc1.get_content.return_value = "Document content 1"
self.doc2 = mocker.MagicMock(spec=Node)
self.doc2.get_content.return_value = "Document content 2"
self.mock_nodes = [self.doc1, self.doc2]
# 模拟index
index = mocker.MagicMock(spec=VectorStoreIndex)
index.storage_context.persist.return_value = "ok"
# 模拟nodes和tokenizer参数
mock_nodes = []
mock_tokenizer = mocker.MagicMock()
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
# 初始化DynamicBM25Retriever对象并提供必需的参数
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer, index=index)
def test_add_docs_updates_nodes_and_corpus(self):
# Execute
# Exec
self.retriever.add_nodes(self.mock_nodes)
# Assertions
# Assert
assert len(self.retriever._nodes) == len(self.mock_nodes)
assert len(self.retriever._corpus) == len(self.mock_nodes)
self.retriever._tokenizer.assert_called()
self.mock_bm25okapi.assert_called()
def test_persist(self):
self.retriever.persist("")

View file

@ -0,0 +1,20 @@
import pytest
from llama_index.core.schema import Node
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
class TestChromaRetriever:
    """Unit tests for ChromaRetriever with a fully mocked index and nodes."""

    @pytest.fixture(autouse=True)
    def setup(self, mocker):
        # Two fake llama-index nodes; their content is irrelevant here.
        self.doc1, self.doc2 = mocker.MagicMock(spec=Node), mocker.MagicMock(spec=Node)
        self.mock_nodes = [self.doc1, self.doc2]
        # The retriever is expected to delegate insertion to the wrapped index.
        self.mock_index = mocker.MagicMock()
        self.retriever = ChromaRetriever(self.mock_index)

    def test_add_nodes(self):
        self.retriever.add_nodes(self.mock_nodes)
        self.mock_index.insert_nodes.assert_called()

View file

@ -0,0 +1,20 @@
import pytest
from llama_index.core.schema import Node
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
class TestElasticsearchRetriever:
    """Unit tests for ElasticsearchRetriever with a fully mocked index and nodes."""

    @pytest.fixture(autouse=True)
    def setup(self, mocker):
        # Two fake llama-index nodes; their content is irrelevant here.
        self.doc1, self.doc2 = mocker.MagicMock(spec=Node), mocker.MagicMock(spec=Node)
        self.mock_nodes = [self.doc1, self.doc2]
        # The retriever is expected to delegate insertion to the wrapped index.
        self.mock_index = mocker.MagicMock()
        self.retriever = ElasticsearchRetriever(self.mock_index)

    def test_add_nodes(self):
        self.retriever.add_nodes(self.mock_nodes)
        self.mock_index.insert_nodes.assert_called()

View file

@ -7,16 +7,19 @@ from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
class TestFAISSRetriever:
@pytest.fixture(autouse=True)
def setup(self, mocker):
# 创建模拟的Document对象
self.doc1 = mocker.MagicMock(spec=Node)
self.doc2 = mocker.MagicMock(spec=Node)
self.mock_nodes = [self.doc1, self.doc2]
# 模拟FAISSRetriever的_index属性
self.mock_index = mocker.MagicMock()
self.retriever = FAISSRetriever(self.mock_index)
def test_add_docs_calls_insert_for_each_document(self, mocker):
def test_add_docs_calls_insert_for_each_document(self):
self.retriever.add_nodes(self.mock_nodes)
assert self.mock_index.insert_nodes.assert_called
self.mock_index.insert_nodes.assert_called()
def test_persist(self):
self.retriever.persist("")
self.mock_index.storage_context.persist.assert_called()

View file

@ -1,5 +1,3 @@
from unittest.mock import AsyncMock
import pytest
from llama_index.core.schema import NodeWithScore, TextNode
@ -7,18 +5,30 @@ from metagpt.rag.retrievers import SimpleHybridRetriever
class TestSimpleHybridRetriever:
@pytest.fixture
def mock_retriever(self, mocker):
return mocker.MagicMock()
@pytest.fixture
def mock_hybrid_retriever(self, mock_retriever) -> SimpleHybridRetriever:
return SimpleHybridRetriever(mock_retriever)
@pytest.fixture
def mock_node(self):
return NodeWithScore(node=TextNode(id_="2"), score=0.95)
@pytest.mark.asyncio
async def test_aretrieve(self):
async def test_aretrieve(self, mocker):
question = "test query"
# Create mock retrievers
mock_retriever1 = AsyncMock()
mock_retriever1 = mocker.AsyncMock()
mock_retriever1.aretrieve.return_value = [
NodeWithScore(node=TextNode(id_="1"), score=1.0),
NodeWithScore(node=TextNode(id_="2"), score=0.95),
]
mock_retriever2 = AsyncMock()
mock_retriever2 = mocker.AsyncMock()
mock_retriever2.aretrieve.return_value = [
NodeWithScore(node=TextNode(id_="2"), score=0.95),
NodeWithScore(node=TextNode(id_="3"), score=0.8),
@ -37,3 +47,11 @@ class TestSimpleHybridRetriever:
# Check if the scores are correct (assuming you want the highest score)
node_scores = {node.node.node_id: node.score for node in results}
assert node_scores["2"] == 0.95
def test_add_nodes(self, mock_hybrid_retriever: SimpleHybridRetriever, mock_node):
    """add_nodes fans out to every wrapped retriever (here, the single mock)."""
    nodes_to_add = [mock_node]
    mock_hybrid_retriever.add_nodes(nodes_to_add)
    mock_hybrid_retriever.retrievers[0].add_nodes.assert_called_once()
def test_persist(self, mock_hybrid_retriever: SimpleHybridRetriever):
    """persist delegates to each wrapped retriever's own persist."""
    persist_dir = ""
    mock_hybrid_retriever.persist(persist_dir)
    mock_hybrid_retriever.retrievers[0].persist.assert_called_once()