mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-25 00:36:55 +02:00
upgrade llama-index-vector-stores-chroma and rag test coverage 100%
This commit is contained in:
parent
6450a09d3b
commit
6cb5492f02
21 changed files with 600 additions and 382 deletions
|
|
@ -8,6 +8,7 @@ import re
|
|||
import time
|
||||
from typing import Any, Iterable
|
||||
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from metagpt.config2 import config as CONFIG
|
||||
|
|
@ -15,7 +16,6 @@ from metagpt.environment.base_env import Environment
|
|||
from metagpt.environment.minecraft_env.const import MC_CKPT_DIR
|
||||
from metagpt.environment.minecraft_env.minecraft_ext_env import MinecraftExtEnv
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_s
|
|||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
|
|
@ -17,7 +18,6 @@ from metagpt.rag.schema import (
|
|||
ElasticsearchKeywordIndexConfig,
|
||||
FAISSIndexConfig,
|
||||
)
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
class RAGIndexFactory(ConfigBasedFactory):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import chromadb
|
|||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
|
|
@ -25,7 +26,6 @@ from metagpt.rag.schema import (
|
|||
FAISSRetrieverConfig,
|
||||
IndexRetrieverConfig,
|
||||
)
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
class RetrieverFactory(ConfigBasedFactory):
|
||||
|
|
|
|||
|
|
@ -1,3 +0,0 @@
|
|||
from metagpt.rag.vector_stores.chroma.base import ChromaVectorStore
|
||||
|
||||
__all__ = ["ChromaVectorStore"]
|
||||
|
|
@ -1,290 +0,0 @@
|
|||
"""Chroma vector store.
|
||||
|
||||
Refs to https://github.com/run-llama/llama_index/blob/v0.10.12/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py.
|
||||
The repo requires onnxruntime = "^1.17.0", which is too new for many OS systems, such as CentOS7.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Generator, List, Optional, cast
|
||||
|
||||
import chromadb
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from llama_index.core.bridge.pydantic import Field, PrivateAttr
|
||||
from llama_index.core.schema import BaseNode, MetadataMode, TextNode
|
||||
from llama_index.core.utils import truncate_text
|
||||
from llama_index.core.vector_stores.types import (
|
||||
BasePydanticVectorStore,
|
||||
MetadataFilters,
|
||||
VectorStoreQuery,
|
||||
VectorStoreQueryResult,
|
||||
)
|
||||
from llama_index.core.vector_stores.utils import (
|
||||
legacy_metadata_dict_to_node,
|
||||
metadata_dict_to_node,
|
||||
node_to_metadata_dict,
|
||||
)
|
||||
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def _transform_chroma_filter_condition(condition: str) -> str:
|
||||
"""Translate standard metadata filter op to Chroma specific spec."""
|
||||
if condition == "and":
|
||||
return "$and"
|
||||
elif condition == "or":
|
||||
return "$or"
|
||||
else:
|
||||
raise ValueError(f"Filter condition {condition} not supported")
|
||||
|
||||
|
||||
def _transform_chroma_filter_operator(operator: str) -> str:
|
||||
"""Translate standard metadata filter operator to Chroma specific spec."""
|
||||
if operator == "!=":
|
||||
return "$ne"
|
||||
elif operator == "==":
|
||||
return "$eq"
|
||||
elif operator == ">":
|
||||
return "$gt"
|
||||
elif operator == "<":
|
||||
return "$lt"
|
||||
elif operator == ">=":
|
||||
return "$gte"
|
||||
elif operator == "<=":
|
||||
return "$lte"
|
||||
else:
|
||||
raise ValueError(f"Filter operator {operator} not supported")
|
||||
|
||||
|
||||
def _to_chroma_filter(
    standard_filters: MetadataFilters,
) -> dict:
    """Convert generic metadata filters into a Chroma ``where`` dictionary.

    Args:
        standard_filters: Filter container exposing ``condition`` and
            ``filters``; each entry exposes ``key``, ``operator`` and ``value``.

    Returns:
        ``{}`` when there are no filters, the single clause itself when there
        is exactly one, otherwise ``{<chroma condition>: [clause, ...]}``.

    Raises:
        ValueError: If the condition or any operator is not supported.
    """
    # A missing condition falls back to "and"; it is validated up front, even
    # when there turn out to be no filters.
    chroma_condition = _transform_chroma_filter_condition(standard_filters.condition or "and")

    clauses = []
    for entry in standard_filters.filters or []:
        if entry.operator:
            clause = {entry.key: {_transform_chroma_filter_operator(entry.operator): entry.value}}
        else:
            # No operator: plain ``key: value`` clause.
            clause = {entry.key: entry.value}
        clauses.append(clause)

    if not clauses:
        return {}
    if len(clauses) == 1:
        # If there is only one filter, return it directly
        return clauses[0]
    return {chroma_condition: clauses}
|
||||
|
||||
|
||||
# Message used for the ImportError raised when `chromadb` cannot be imported.
import_err_msg = "`chromadb` package not found, please run `pip install chromadb`"
|
||||
# Upper bound on how many nodes are sent to the collection in a single add call.
MAX_CHUNK_SIZE = 41665 # One less than the max chunk size for ChromaDB
|
||||
|
||||
|
||||
def chunk_list(lst: List[BaseNode], max_chunk_size: int) -> Generator[List[BaseNode], None, None]:
    """Split ``lst`` into consecutive slices of at most ``max_chunk_size`` items.

    Args:
        lst (List[BaseNode]): list of nodes with embeddings
        max_chunk_size (int): maximum number of items per yielded slice

    Yields:
        Generator[List[BaseNode], None, None]: successive slices of ``lst``;
        only the last one may be shorter than ``max_chunk_size``.
    """
    chunk_starts = range(0, len(lst), max_chunk_size)
    yield from (lst[start : start + max_chunk_size] for start in chunk_starts)
|
||||
|
||||
|
||||
class ChromaVectorStore(BasePydanticVectorStore):
    """Chroma vector store.

    In this vector store, embeddings are stored within a ChromaDB collection.
    During query time, the index uses ChromaDB to query for the top
    k most similar nodes.

    Args:
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance
    """

    stores_text: bool = True
    flat_metadata: bool = True
    collection_name: Optional[str]
    host: Optional[str]
    port: Optional[str]
    ssl: bool
    headers: Optional[Dict[str, str]]
    persist_dir: Optional[str]
    collection_kwargs: Dict[str, Any] = Field(default_factory=dict)
    _collection: Any = PrivateAttr()

    def __init__(
        self,
        chroma_collection: Optional[Any] = None,
        collection_name: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[str] = None,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        persist_dir: Optional[str] = None,
        collection_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Init params.

        Args:
            chroma_collection: Existing collection to wrap. When ``None``, an
                HTTP client is created from ``host``/``port``/``ssl``/``headers``
                and the collection is fetched or created by ``collection_name``.
            collection_name: Name used with ``get_or_create_collection``.
            host: Host passed to ``chromadb.HttpClient``.
            port: Port passed to ``chromadb.HttpClient``.
            ssl: SSL flag passed to ``chromadb.HttpClient``.
            headers: Headers passed to ``chromadb.HttpClient``.
            persist_dir: Stored on the model; not used to open a client here.
            collection_kwargs: Extra keyword arguments for
                ``get_or_create_collection``; ``None`` means no extras.
            **kwargs: Accepted for signature compatibility; not forwarded.
        """
        collection_kwargs = collection_kwargs or {}
        if chroma_collection is None:
            client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
            self._collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
        else:
            self._collection = cast(Collection, chroma_collection)
        super().__init__(
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_name=collection_name,
            persist_dir=persist_dir,
            collection_kwargs=collection_kwargs,
        )

    @classmethod
    def from_collection(cls, collection: Any) -> "ChromaVectorStore":
        """Build a store around an existing ChromaDB collection.

        Raises:
            ImportError: If ``chromadb`` cannot be imported.
            Exception: If ``collection`` is not a ``chromadb.Collection``.
        """
        try:
            from chromadb import Collection
        except ImportError:
            raise ImportError(import_err_msg)
        if not isinstance(collection, Collection):
            raise Exception("argument is not chromadb collection instance")
        return cls(chroma_collection=collection)

    @classmethod
    def from_params(
        cls,
        collection_name: str,
        host: Optional[str] = None,
        port: Optional[str] = None,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        persist_dir: Optional[str] = None,
        collection_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> "ChromaVectorStore":
        """Build a store from connection parameters.

        A persistent client is used when ``persist_dir`` is given; otherwise an
        HTTP client is used when both ``host`` and ``port`` are given.

        Args:
            collection_name: Collection to fetch or create.
            host: Host for the HTTP client.
            port: Port for the HTTP client.
            ssl: SSL flag for the HTTP client.
            headers: Headers for the HTTP client.
            persist_dir: Path for ``chromadb.PersistentClient``.
            collection_kwargs: Extra keyword arguments for
                ``get_or_create_collection``; ``None`` means no extras.
            **kwargs: Forwarded to the constructor.

        Raises:
            ValueError: If neither ``persist_dir`` nor (``host`` and ``port``)
                is provided.
        """
        # Default is None rather than a shared mutable ``{}`` so calls cannot
        # leak state into each other through the default argument.
        collection_kwargs = collection_kwargs or {}
        if persist_dir:
            client = chromadb.PersistentClient(path=persist_dir)
            collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
        elif host and port:
            client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
            collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
        else:
            raise ValueError("Either `persist_dir` or (`host`,`port`) must be specified")
        return cls(
            chroma_collection=collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            persist_dir=persist_dir,
            collection_kwargs=collection_kwargs,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the class name string ``"ChromaVectorStore"``."""
        return "ChromaVectorStore"

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to index.

        Nodes are written in batches of at most ``MAX_CHUNK_SIZE``.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        Returns:
            The ids of all nodes that were added, in input order.

        Raises:
            ValueError: If the underlying collection is not set.
        """
        if not self._collection:
            raise ValueError("Collection not initialized")
        max_chunk_size = MAX_CHUNK_SIZE
        node_chunks = chunk_list(nodes, max_chunk_size)
        all_ids = []
        for node_chunk in node_chunks:
            embeddings = []
            metadatas = []
            ids = []
            documents = []
            for node in node_chunk:
                embeddings.append(node.get_embedding())
                metadata_dict = node_to_metadata_dict(node, remove_text=True, flat_metadata=self.flat_metadata)
                # None metadata values are replaced with empty strings before storing.
                for key in metadata_dict:
                    if metadata_dict[key] is None:
                        metadata_dict[key] = ""
                metadatas.append(metadata_dict)
                ids.append(node.node_id)
                documents.append(node.get_content(metadata_mode=MetadataMode.NONE))
            self._collection.add(
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas,
                documents=documents,
            )
            all_ids.extend(ids)
        return all_ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete nodes whose ``document_id`` metadata equals ``ref_doc_id``.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.
        """
        self._collection.delete(where={"document_id": ref_doc_id})

    @property
    def client(self) -> Any:
        """Return client (the wrapped collection object)."""
        return self._collection

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query: Query whose ``query_embedding``, ``similarity_top_k`` and
                ``filters`` are used.
            **kwargs: Forwarded to the collection's ``query``. ``where`` may
                only be given here when ``query.filters`` is ``None``.

        Returns:
            Result holding the nodes, their similarity scores
            (``exp(-distance)``) and their ids.

        Raises:
            ValueError: If filters are given both on ``query`` and via ``where``.
        """
        if query.filters is not None:
            if "where" in kwargs:
                raise ValueError(
                    "Cannot specify metadata filters via both query and kwargs. "
                    "Use kwargs only for chroma specific items that are "
                    "not supported via the generic query interface."
                )
            where = _to_chroma_filter(query.filters)
        else:
            where = kwargs.pop("where", {})
        results = self._collection.query(
            query_embeddings=query.query_embedding,
            n_results=query.similarity_top_k,
            where=where,
            **kwargs,
        )
        # Results are indexed per query first; entry 0 holds the hits for our
        # single query, so count that inner list rather than the outer one.
        logger.debug(f"> Top {len(results['documents'][0])} nodes:")
        nodes = []
        similarities = []
        ids = []
        for node_id, text, metadata, distance in zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        ):
            try:
                node = metadata_dict_to_node(metadata)
                node.set_content(text)
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                metadata, node_info, relationships = legacy_metadata_dict_to_node(metadata)
                node = TextNode(
                    text=text,
                    id_=node_id,
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships=relationships,
                )
            nodes.append(node)
            # Convert distance to a similarity score: smaller distance -> larger score.
            similarity_score = math.exp(-distance)
            similarities.append(similarity_score)
            logger.debug(
                f"> [Node {node_id}] [Similarity score: {similarity_score}] " f"{truncate_text(str(text), 100)}"
            )
            ids.append(node_id)
        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
|
||||
2
setup.py
2
setup.py
|
|
@ -37,8 +37,8 @@ extras_require = {
|
|||
"llama-index-retrievers-bm25==0.1.3",
|
||||
"llama-index-vector-stores-faiss==0.1.1",
|
||||
"llama-index-vector-stores-elasticsearch==0.1.6",
|
||||
"llama-index-vector-stores-chroma==0.1.6",
|
||||
"llama-index-postprocessor-colbert-rerank==0.1.1",
|
||||
"chromadb==0.4.23",
|
||||
],
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,26 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import Document, TextNode
|
||||
from llama_index.core.embeddings import MockEmbedding
|
||||
from llama_index.core.llms import MockLLM
|
||||
from llama_index.core.schema import Document, NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
|
||||
|
||||
|
||||
class TestSimpleEngine:
|
||||
@pytest.fixture
|
||||
def mock_llm(self):
|
||||
return MockLLM()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding(self):
|
||||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_simple_directory_reader(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
|
|
@ -54,7 +68,7 @@ class TestSimpleEngine:
|
|||
retriever_configs = [mocker.MagicMock()]
|
||||
ranker_configs = [mocker.MagicMock()]
|
||||
|
||||
# Execute
|
||||
# Exec
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_dir=input_dir,
|
||||
input_files=input_files,
|
||||
|
|
@ -65,7 +79,7 @@ class TestSimpleEngine:
|
|||
ranker_configs=ranker_configs,
|
||||
)
|
||||
|
||||
# Assertions
|
||||
# Assert
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_vector_store_index.assert_called_once()
|
||||
mock_get_retriever.assert_called_once_with(
|
||||
|
|
@ -75,6 +89,68 @@ class TestSimpleEngine:
|
|||
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
|
||||
def test_from_docs_without_file(self):
|
||||
with pytest.raises(ValueError):
|
||||
SimpleEngine.from_docs()
|
||||
|
||||
def test_from_objs(self, mock_llm, mock_embedding):
|
||||
# Mock
|
||||
class MockRAGObject:
|
||||
def rag_key(self):
|
||||
return "key"
|
||||
|
||||
def model_dump_json(self):
|
||||
return "{}"
|
||||
|
||||
objs = [MockRAGObject()]
|
||||
|
||||
# Setup
|
||||
retriever_configs = []
|
||||
ranker_configs = []
|
||||
|
||||
# Exec
|
||||
engine = SimpleEngine.from_objs(
|
||||
objs=objs,
|
||||
llm=mock_llm,
|
||||
embed_model=mock_embedding,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is not None
|
||||
|
||||
def test_from_objs_with_bm25_config(self):
|
||||
# Setup
|
||||
retriever_configs = [BM25RetrieverConfig()]
|
||||
|
||||
# Exec
|
||||
with pytest.raises(ValueError):
|
||||
SimpleEngine.from_objs(
|
||||
objs=[],
|
||||
llm=MockLLM(),
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=[],
|
||||
)
|
||||
|
||||
def test_from_index(self, mocker, mock_llm, mock_embedding):
|
||||
# Mock
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
|
||||
mock_get_index.return_value = mock_index
|
||||
|
||||
# Exec
|
||||
engine = SimpleEngine.from_index(
|
||||
index_config=mock_index,
|
||||
embed_model=mock_embedding,
|
||||
llm=mock_llm,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is mock_index
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch(self, mocker):
|
||||
# Mock
|
||||
|
|
@ -86,10 +162,10 @@ class TestSimpleEngine:
|
|||
engine = SimpleEngine(retriever=mocker.MagicMock())
|
||||
engine.aquery = mock_aquery
|
||||
|
||||
# Execute
|
||||
# Exec
|
||||
result = await engine.asearch(test_query)
|
||||
|
||||
# Assertions
|
||||
# Assert
|
||||
mock_aquery.assert_called_once_with(test_query)
|
||||
assert result == expected_result
|
||||
|
||||
|
|
@ -106,10 +182,10 @@ class TestSimpleEngine:
|
|||
engine = SimpleEngine(retriever=mocker.MagicMock())
|
||||
test_query = "test query"
|
||||
|
||||
# Execute
|
||||
# Exec
|
||||
result = await engine.aretrieve(test_query)
|
||||
|
||||
# Assertions
|
||||
# Assert
|
||||
mock_query_bundle.assert_called_once_with(test_query)
|
||||
mock_super_aretrieve.assert_called_once_with("query_bundle")
|
||||
assert result[0].text == "node_with_score"
|
||||
|
|
@ -134,10 +210,10 @@ class TestSimpleEngine:
|
|||
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
|
||||
# Execute
|
||||
# Exec
|
||||
engine.add_docs(input_files=input_files)
|
||||
|
||||
# Assertions
|
||||
# Assert
|
||||
mock_simple_directory_reader.assert_called_once_with(input_files=input_files)
|
||||
mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"])
|
||||
|
||||
|
|
@ -156,11 +232,79 @@ class TestSimpleEngine:
|
|||
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
|
||||
|
||||
# Execute
|
||||
# Exec
|
||||
engine.add_objs(objs=objs)
|
||||
|
||||
# Assertions
|
||||
# Assert
|
||||
assert mock_retriever.add_nodes.call_count == 1
|
||||
for node in mock_retriever.add_nodes.call_args[0][0]:
|
||||
assert isinstance(node, TextNode)
|
||||
assert "is_obj" in node.metadata
|
||||
|
||||
def test_persist_successfully(self, mocker):
|
||||
# Mock
|
||||
mock_retriever = mocker.MagicMock(spec=PersistableRAGRetriever)
|
||||
mock_retriever.persist.return_value = mocker.MagicMock()
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
|
||||
# Exec
|
||||
engine.persist(persist_dir="")
|
||||
|
||||
def test_ensure_retriever_of_type(self, mocker):
|
||||
# Mock
|
||||
class MyRetriever:
|
||||
def add_nodes(self):
|
||||
...
|
||||
|
||||
mock_retriever = mocker.MagicMock(spec=SimpleHybridRetriever)
|
||||
mock_retriever.retrievers = [MyRetriever()]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
|
||||
# Assert
|
||||
engine._ensure_retriever_of_type(ModifiableRAGRetriever)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
engine._ensure_retriever_of_type(PersistableRAGRetriever)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
other_engine = SimpleEngine(retriever=mocker.MagicMock(spec=ModifiableRAGRetriever))
|
||||
other_engine._ensure_retriever_of_type(PersistableRAGRetriever)
|
||||
|
||||
def test_with_obj_metadata(self, mocker):
|
||||
# Mock
|
||||
node = NodeWithScore(
|
||||
node=ObjectNode(
|
||||
text="example",
|
||||
metadata={
|
||||
"is_obj": True,
|
||||
"obj_cls_name": "ExampleObject",
|
||||
"obj_mod_name": "__main__",
|
||||
"obj_json": json.dumps({"key": "test_key", "value": "test_value"}),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
class ExampleObject:
|
||||
def __init__(self, key, value):
|
||||
self.key = key
|
||||
self.value = value
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.key == other.key and self.value == other.value
|
||||
|
||||
mock_import_class = mocker.patch("metagpt.rag.engines.simple.import_class")
|
||||
mock_import_class.return_value = ExampleObject
|
||||
|
||||
# Setup
|
||||
SimpleEngine._try_reconstruct_obj([node])
|
||||
|
||||
# Exec
|
||||
expected_obj = ExampleObject(key="test_key", value="test_value")
|
||||
|
||||
# Assert
|
||||
assert "obj" in node.node.metadata
|
||||
assert node.node.metadata["obj"] == expected_obj
|
||||
|
|
|
|||
43
tests/metagpt/rag/factories/test_embedding.py
Normal file
43
tests/metagpt/rag/factories/test_embedding.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
|
||||
|
||||
|
||||
class TestRAGEmbeddingFactory:
    """Checks which embedding class RAGEmbeddingFactory instantiates per LLM type."""

    @pytest.fixture(autouse=True)
    def mock_embedding_factory(self):
        """Attach a fresh factory to every test instance."""
        self.embedding_factory = RAGEmbeddingFactory()

    @pytest.fixture
    def mock_openai_embedding(self, mocker):
        """Patch OpenAIEmbedding in the factory module."""
        target = "metagpt.rag.factories.embedding.OpenAIEmbedding"
        return mocker.patch(target)

    @pytest.fixture
    def mock_azure_embedding(self, mocker):
        """Patch AzureOpenAIEmbedding in the factory module."""
        target = "metagpt.rag.factories.embedding.AzureOpenAIEmbedding"
        return mocker.patch(target)

    def test_get_rag_embedding_openai(self, mock_openai_embedding):
        """An explicit OPENAI type constructs the OpenAI embedding once."""
        # Exec
        self.embedding_factory.get_rag_embedding(LLMType.OPENAI)

        # Assert
        mock_openai_embedding.assert_called_once()

    def test_get_rag_embedding_azure(self, mock_azure_embedding):
        """An explicit AZURE type constructs the Azure embedding once."""
        # Exec
        self.embedding_factory.get_rag_embedding(LLMType.AZURE)

        # Assert
        mock_azure_embedding.assert_called_once()

    def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
        """With no argument and config api_type OPENAI, the OpenAI embedding is built."""
        # Mock
        patched_config = mocker.patch("metagpt.rag.factories.embedding.config")
        patched_config.llm.api_type = LLMType.OPENAI

        # Exec
        self.embedding_factory.get_rag_embedding()

        # Assert
        mock_openai_embedding.assert_called_once()
|
||||
89
tests/metagpt/rag/factories/test_index.py
Normal file
89
tests/metagpt/rag/factories/test_index.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
import pytest
|
||||
from llama_index.core.embeddings import MockEmbedding
|
||||
|
||||
from metagpt.rag.factories.index import RAGIndexFactory
|
||||
from metagpt.rag.schema import (
|
||||
BM25IndexConfig,
|
||||
ChromaIndexConfig,
|
||||
ElasticsearchIndexConfig,
|
||||
ElasticsearchStoreConfig,
|
||||
FAISSIndexConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestRAGIndexFactory:
    """Checks that RAGIndexFactory.get_index builds the expected store per config type."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Attach a fresh factory to every test instance."""
        self.index_factory = RAGIndexFactory()

    @pytest.fixture
    def faiss_config(self):
        return FAISSIndexConfig(persist_path="")

    @pytest.fixture
    def chroma_config(self):
        return ChromaIndexConfig(persist_path="", collection_name="")

    @pytest.fixture
    def bm25_config(self):
        return BM25IndexConfig(persist_path="")

    @pytest.fixture
    def es_config(self, mocker):
        return ElasticsearchIndexConfig(store_config=ElasticsearchStoreConfig())

    @pytest.fixture
    def mock_storage_context(self, mocker):
        """Patch StorageContext.from_defaults in the factory module."""
        return mocker.patch("metagpt.rag.factories.index.StorageContext.from_defaults")

    @pytest.fixture
    def mock_load_index_from_storage(self, mocker):
        """Patch load_index_from_storage in the factory module."""
        return mocker.patch("metagpt.rag.factories.index.load_index_from_storage")

    @pytest.fixture
    def mock_from_vector_store(self, mocker):
        """Patch VectorStoreIndex.from_vector_store in the factory module."""
        return mocker.patch("metagpt.rag.factories.index.VectorStoreIndex.from_vector_store")

    @pytest.fixture
    def mock_embedding(self):
        return MockEmbedding(embed_dim=1)

    def test_create_faiss_index(
        self, mocker, faiss_config, mock_storage_context, mock_load_index_from_storage, mock_embedding
    ):
        """A FAISS config loads the vector store from its persist dir."""
        # Mock
        mock_faiss_store = mocker.patch("metagpt.rag.factories.index.FaissVectorStore.from_persist_dir")

        # Exec
        self.index_factory.get_index(faiss_config, embed_model=mock_embedding)

        # Assert
        mock_faiss_store.assert_called_once()

    def test_create_bm25_index(
        self, mocker, bm25_config, mock_storage_context, mock_load_index_from_storage, mock_embedding
    ):
        """A BM25 config is accepted without raising (smoke test, no assertions)."""
        self.index_factory.get_index(bm25_config, embed_model=mock_embedding)

    def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding):
        """A Chroma config constructs a ChromaVectorStore once."""
        # Mock
        mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient")
        # The patched name is the client *class*; the client instance the code
        # under test would receive is its return_value, so configure that.
        mock_chroma_db.return_value.get_or_create_collection.return_value = mocker.MagicMock()

        mock_chroma_store = mocker.patch("metagpt.rag.factories.index.ChromaVectorStore")

        # Exec
        self.index_factory.get_index(chroma_config, embed_model=mock_embedding)

        # Assert
        mock_chroma_store.assert_called_once()

    def test_create_es_index(self, mocker, es_config, mock_from_vector_store, mock_embedding):
        """An Elasticsearch config constructs an ElasticsearchStore once."""
        # Mock
        mock_es_store = mocker.patch("metagpt.rag.factories.index.ElasticsearchStore")

        # Exec
        self.index_factory.get_index(es_config, embed_model=mock_embedding)

        # Assert
        mock_es_store.assert_called_once()
|
||||
71
tests/metagpt/rag/factories/test_llm.py
Normal file
71
tests/metagpt/rag/factories/test_llm.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
from llama_index.core.llms import LLMMetadata
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.rag.factories.llm import RAGLLM, get_rag_llm
|
||||
|
||||
|
||||
class MockLLM(BaseLLM):
    """Minimal BaseLLM stand-in whose completion methods always answer ``"ok"``."""

    def __init__(self, config: LLMConfig):
        """Accept a config but deliberately do no setup."""
        pass

    async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
        """No-op stub; returns None."""
        return None

    async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
        """Async completion stub."""
        return "ok"

    def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
        """Sync completion stub."""
        return "ok"

    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
        """No-op stub; returns None."""
        return None

    async def aask(
        self,
        msg: Union[str, list[dict[str, str]]],
        system_msgs: Optional[list[str]] = None,
        format_msgs: Optional[list[dict[str, str]]] = None,
        images: Optional[Union[str, list[str]]] = None,
        timeout=USE_CONFIG_TIMEOUT,
        stream=True,
    ) -> str:
        """Ask stub; ignores every argument."""
        return "ok"
|
||||
|
||||
|
||||
class TestRAGLLM:
    """Exercises the RAGLLM wrapper against the MockLLM stub."""

    @pytest.fixture
    def mock_model_infer(self):
        """Stub LLM built from a default LLMConfig."""
        llm_config = LLMConfig()
        return MockLLM(config=llm_config)

    @pytest.fixture
    def rag_llm(self, mock_model_infer):
        """RAGLLM wrapping the stub LLM."""
        return RAGLLM(model_infer=mock_model_infer)

    def test_metadata(self, rag_llm):
        """metadata mirrors the wrapper's own window, output and model-name fields."""
        meta = rag_llm.metadata
        assert isinstance(meta, LLMMetadata)
        assert meta.context_window == rag_llm.context_window
        assert meta.num_output == rag_llm.num_output
        assert meta.model_name == rag_llm.model_name

    @pytest.mark.asyncio
    async def test_acomplete(self, rag_llm, mock_model_infer):
        """acomplete returns the stub's answer as response text."""
        completion = await rag_llm.acomplete("question")
        assert completion.text == "ok"

    def test_complete(self, rag_llm, mock_model_infer):
        """complete returns the stub's answer as response text."""
        completion = rag_llm.complete("question")
        assert completion.text == "ok"

    def test_stream_complete(self, rag_llm, mock_model_infer):
        """stream_complete can be called without raising (no result check)."""
        rag_llm.stream_complete("question")
|
||||
|
||||
|
||||
def test_get_rag_llm():
    """get_rag_llm wraps the given model in a RAGLLM."""
    stub_llm = MockLLM(config=LLMConfig())
    wrapped = get_rag_llm(stub_llm)
    assert isinstance(wrapped, RAGLLM)
|
||||
|
|
@ -1,41 +1,57 @@
|
|||
import pytest
|
||||
from llama_index.core.llms import LLM
|
||||
from llama_index.core.llms import MockLLM
|
||||
from llama_index.core.postprocessor import LLMRerank
|
||||
|
||||
from metagpt.rag.factories.ranker import RankerFactory
|
||||
from metagpt.rag.schema import LLMRankerConfig
|
||||
from metagpt.rag.schema import ColbertRerankConfig, LLMRankerConfig, ObjectRankerConfig
|
||||
|
||||
|
||||
class TestRankerFactory:
|
||||
@pytest.fixture
|
||||
def ranker_factory(self) -> RankerFactory:
|
||||
return RankerFactory()
|
||||
@pytest.fixture(autouse=True)
|
||||
def ranker_factory(self):
|
||||
self.ranker_factory: RankerFactory = RankerFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm(self, mocker):
|
||||
return mocker.MagicMock(spec=LLM)
|
||||
def mock_llm(self):
|
||||
return MockLLM()
|
||||
|
||||
def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker):
|
||||
mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm)
|
||||
default_rankers = ranker_factory.get_rankers()
|
||||
def test_get_rankers_with_no_configs(self, mock_llm, mocker):
|
||||
mocker.patch.object(self.ranker_factory, "_extract_llm", return_value=mock_llm)
|
||||
default_rankers = self.ranker_factory.get_rankers()
|
||||
assert len(default_rankers) == 0
|
||||
|
||||
def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm):
|
||||
def test_get_rankers_with_configs(self, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
rankers = ranker_factory.get_rankers(configs=[mock_config])
|
||||
rankers = self.ranker_factory.get_rankers(configs=[mock_config])
|
||||
assert len(rankers) == 1
|
||||
assert isinstance(rankers[0], LLMRerank)
|
||||
|
||||
def test_create_llm_ranker_creates_correct_instance(self, ranker_factory: RankerFactory, mock_llm):
|
||||
def test_extract_llm_from_config(self, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
ranker = ranker_factory._create_llm_ranker(mock_config)
|
||||
extracted_llm = self.ranker_factory._extract_llm(config=mock_config)
|
||||
assert extracted_llm == mock_llm
|
||||
|
||||
def test_extract_llm_from_kwargs(self, mock_llm):
|
||||
extracted_llm = self.ranker_factory._extract_llm(llm=mock_llm)
|
||||
assert extracted_llm == mock_llm
|
||||
|
||||
def test_create_llm_ranker(self, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
ranker = self.ranker_factory._create_llm_ranker(mock_config)
|
||||
assert isinstance(ranker, LLMRerank)
|
||||
|
||||
def test_extract_llm_from_config(self, ranker_factory: RankerFactory, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
extracted_llm = ranker_factory._extract_llm(config=mock_config)
|
||||
assert extracted_llm == mock_llm
|
||||
def test_create_colbert_ranker(self, mocker, mock_llm):
|
||||
mocker.patch("metagpt.rag.factories.ranker.ColbertRerank", return_value="colbert")
|
||||
|
||||
def test_extract_llm_from_kwargs(self, ranker_factory: RankerFactory, mock_llm):
|
||||
extracted_llm = ranker_factory._extract_llm(llm=mock_llm)
|
||||
assert extracted_llm == mock_llm
|
||||
mock_config = ColbertRerankConfig(llm=mock_llm)
|
||||
ranker = self.ranker_factory._create_colbert_ranker(mock_config)
|
||||
|
||||
assert ranker == "colbert"
|
||||
|
||||
def test_create_object_ranker(self, mocker, mock_llm):
|
||||
mocker.patch("metagpt.rag.factories.ranker.ObjectSortPostprocessor", return_value="object")
|
||||
|
||||
mock_config = ObjectRankerConfig(field_name="fake", llm=mock_llm)
|
||||
ranker = self.ranker_factory._create_object_ranker(mock_config)
|
||||
|
||||
assert ranker == "object"
|
||||
|
|
|
|||
|
|
@ -1,18 +1,28 @@
|
|||
import faiss
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
|
||||
from metagpt.rag.factories.retriever import RetrieverFactory
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
|
||||
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
ChromaRetrieverConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
ElasticsearchStoreConfig,
|
||||
FAISSRetrieverConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestRetrieverFactory:
|
||||
@pytest.fixture
|
||||
@pytest.fixture(autouse=True)
|
||||
def retriever_factory(self):
|
||||
return RetrieverFactory()
|
||||
self.retriever_factory: RetrieverFactory = RetrieverFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_faiss_index(self, mocker):
|
||||
|
|
@ -25,55 +35,79 @@ class TestRetrieverFactory:
|
|||
mock.docstore.docs.values.return_value = []
|
||||
return mock
|
||||
|
||||
def test_get_retriever_with_faiss_config(
|
||||
self, retriever_factory: RetrieverFactory, mock_faiss_index, mocker, mock_vector_store_index
|
||||
):
|
||||
@pytest.fixture
|
||||
def mock_chroma_vector_store(self, mocker):
|
||||
return mocker.MagicMock(spec=ChromaVectorStore)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_es_vector_store(self, mocker):
|
||||
return mocker.MagicMock(spec=ElasticsearchStore)
|
||||
|
||||
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
|
||||
mock_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, FAISSRetriever)
|
||||
|
||||
def test_get_retriever_with_bm25_config(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
|
||||
def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
|
||||
mock_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, DynamicBM25Retriever)
|
||||
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(
|
||||
self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index
|
||||
):
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mock_bm25_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
|
||||
|
||||
assert isinstance(retriever, SimpleHybridRetriever)
|
||||
|
||||
def test_create_default_retriever(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
|
||||
mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
|
||||
mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
|
||||
mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
|
||||
mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, ChromaRetriever)
|
||||
|
||||
def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
|
||||
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
|
||||
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, ElasticsearchRetriever)
|
||||
|
||||
def test_create_default_retriever(self, mocker, mock_vector_store_index):
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
mock_vector_store_index.as_retriever = mocker.MagicMock()
|
||||
|
||||
retriever = retriever_factory.get_retriever()
|
||||
retriever = self.retriever_factory.get_retriever()
|
||||
|
||||
mock_vector_store_index.as_retriever.assert_called_once()
|
||||
assert retriever is mock_vector_store_index.as_retriever.return_value
|
||||
|
||||
def test_extract_index_from_config(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
|
||||
def test_extract_index_from_config(self, mock_vector_store_index):
|
||||
mock_config = FAISSRetrieverConfig(index=mock_vector_store_index)
|
||||
|
||||
extracted_index = retriever_factory._extract_index(config=mock_config)
|
||||
extracted_index = self.retriever_factory._extract_index(config=mock_config)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
||||
def test_extract_index_from_kwargs(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
|
||||
extracted_index = retriever_factory._extract_index(index=mock_vector_store_index)
|
||||
def test_extract_index_from_kwargs(self, mock_vector_store_index):
|
||||
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
|
|
|||
23
tests/metagpt/rag/rankers/test_base_ranker.py
Normal file
23
tests/metagpt/rag/rankers/test_base_ranker.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import pytest
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
|
||||
|
||||
from metagpt.rag.rankers.base import RAGRanker
|
||||
|
||||
|
||||
class SimpleRAGRanker(RAGRanker):
|
||||
def _postprocess_nodes(self, nodes, query_bundle=None):
|
||||
return [NodeWithScore(node=node.node, score=node.score + 1) for node in nodes]
|
||||
|
||||
|
||||
class TestSimpleRAGRanker:
|
||||
@pytest.fixture
|
||||
def ranker(self):
|
||||
return SimpleRAGRanker()
|
||||
|
||||
def test_postprocess_nodes_increases_scores(self, ranker):
|
||||
nodes = [NodeWithScore(node=TextNode(text="a"), score=10), NodeWithScore(node=TextNode(text="b"), score=20)]
|
||||
query_bundle = QueryBundle(query_str="test query")
|
||||
|
||||
processed_nodes = ranker._postprocess_nodes(nodes, query_bundle)
|
||||
|
||||
assert all(node.score == original_node.score + 1 for node, original_node in zip(processed_nodes, nodes))
|
||||
|
|
@ -14,7 +14,7 @@ class Record(BaseModel):
|
|||
|
||||
class TestObjectSortPostprocessor:
|
||||
@pytest.fixture
|
||||
def nodes_with_scores(self):
|
||||
def mock_nodes_with_scores(self):
|
||||
nodes = [
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10),
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20),
|
||||
|
|
@ -23,38 +23,47 @@ class TestObjectSortPostprocessor:
|
|||
return nodes
|
||||
|
||||
@pytest.fixture
|
||||
def query_bundle(self, mocker):
|
||||
def mock_query_bundle(self, mocker):
|
||||
return mocker.MagicMock(spec=QueryBundle)
|
||||
|
||||
def test_sort_descending(self, nodes_with_scores, query_bundle):
|
||||
def test_sort_descending(self, mock_nodes_with_scores, mock_query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
|
||||
assert [node.score for node in sorted_nodes] == [20, 10, 5]
|
||||
|
||||
def test_sort_ascending(self, nodes_with_scores, query_bundle):
|
||||
def test_sort_ascending(self, mock_nodes_with_scores, mock_query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="asc")
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
|
||||
assert [node.score for node in sorted_nodes] == [5, 10, 20]
|
||||
|
||||
def test_top_n_limit(self, nodes_with_scores, query_bundle):
|
||||
def test_top_n_limit(self, mock_nodes_with_scores, mock_query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2)
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
|
||||
assert len(sorted_nodes) == 2
|
||||
assert [node.score for node in sorted_nodes] == [20, 10]
|
||||
|
||||
def test_invalid_json_metadata(self, query_bundle):
|
||||
def test_invalid_json_metadata(self, mock_query_bundle):
|
||||
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)]
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes, query_bundle)
|
||||
postprocessor._postprocess_nodes(nodes, mock_query_bundle)
|
||||
|
||||
def test_missing_query_bundle(self, nodes_with_scores):
|
||||
def test_missing_query_bundle(self, mock_nodes_with_scores):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None)
|
||||
postprocessor._postprocess_nodes(mock_nodes_with_scores, query_bundle=None)
|
||||
|
||||
def test_field_not_found_in_object(self):
|
||||
def test_field_not_found_in_object(self, mock_query_bundle):
|
||||
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)]
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes)
|
||||
postprocessor._postprocess_nodes(nodes, query_bundle=mock_query_bundle)
|
||||
|
||||
def test_not_nodes(self, mock_query_bundle):
|
||||
nodes = []
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
result = postprocessor._postprocess_nodes(nodes, mock_query_bundle)
|
||||
assert result == []
|
||||
|
||||
def test_class_name(self):
|
||||
assert ObjectSortPostprocessor.class_name() == "ObjectSortPostprocessor"
|
||||
|
|
|
|||
21
tests/metagpt/rag/retrievers/test_base_retriever.py
Normal file
21
tests/metagpt/rag/retrievers/test_base_retriever.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
|
||||
|
||||
class SubModifiableRAGRetriever(ModifiableRAGRetriever):
|
||||
...
|
||||
|
||||
|
||||
class SubPersistableRAGRetriever(PersistableRAGRetriever):
|
||||
...
|
||||
|
||||
|
||||
class TestModifiableRAGRetriever:
|
||||
def test_subclasshook(self):
|
||||
result = SubModifiableRAGRetriever.__subclasshook__(SubModifiableRAGRetriever)
|
||||
assert result is NotImplemented
|
||||
|
||||
|
||||
class TestPersistableRAGRetriever:
|
||||
def test_subclasshook(self):
|
||||
result = SubPersistableRAGRetriever.__subclasshook__(SubPersistableRAGRetriever)
|
||||
assert result is NotImplemented
|
||||
|
|
@ -8,30 +8,30 @@ from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
|||
class TestDynamicBM25Retriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
# 创建模拟的Document对象
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc1.get_content.return_value = "Document content 1"
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.doc2.get_content.return_value = "Document content 2"
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
# 模拟index
|
||||
index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
index.storage_context.persist.return_value = "ok"
|
||||
|
||||
# 模拟nodes和tokenizer参数
|
||||
mock_nodes = []
|
||||
mock_tokenizer = mocker.MagicMock()
|
||||
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
|
||||
# 初始化DynamicBM25Retriever对象,并提供必需的参数
|
||||
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer, index=index)
|
||||
|
||||
def test_add_docs_updates_nodes_and_corpus(self):
|
||||
# Execute
|
||||
# Exec
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
# Assertions
|
||||
# Assert
|
||||
assert len(self.retriever._nodes) == len(self.mock_nodes)
|
||||
assert len(self.retriever._corpus) == len(self.mock_nodes)
|
||||
self.retriever._tokenizer.assert_called()
|
||||
self.mock_bm25okapi.assert_called()
|
||||
|
||||
def test_persist(self):
|
||||
self.retriever.persist("")
|
||||
|
|
|
|||
20
tests/metagpt/rag/retrievers/test_chroma_retriever.py
Normal file
20
tests/metagpt/rag/retrievers/test_chroma_retriever.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
import pytest
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
|
||||
|
||||
|
||||
class TestChromaRetriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
self.mock_index = mocker.MagicMock()
|
||||
self.retriever = ChromaRetriever(self.mock_index)
|
||||
|
||||
def test_add_nodes(self):
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
self.mock_index.insert_nodes.assert_called()
|
||||
20
tests/metagpt/rag/retrievers/test_es_retriever.py
Normal file
20
tests/metagpt/rag/retrievers/test_es_retriever.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
import pytest
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
|
||||
|
||||
|
||||
class TestElasticsearchRetriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
self.mock_index = mocker.MagicMock()
|
||||
self.retriever = ElasticsearchRetriever(self.mock_index)
|
||||
|
||||
def test_add_nodes(self):
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
self.mock_index.insert_nodes.assert_called()
|
||||
|
|
@ -7,16 +7,19 @@ from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
|||
class TestFAISSRetriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
# 创建模拟的Document对象
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
# 模拟FAISSRetriever的_index属性
|
||||
self.mock_index = mocker.MagicMock()
|
||||
self.retriever = FAISSRetriever(self.mock_index)
|
||||
|
||||
def test_add_docs_calls_insert_for_each_document(self, mocker):
|
||||
def test_add_docs_calls_insert_for_each_document(self):
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
assert self.mock_index.insert_nodes.assert_called
|
||||
self.mock_index.insert_nodes.assert_called()
|
||||
|
||||
def test_persist(self):
|
||||
self.retriever.persist("")
|
||||
|
||||
self.mock_index.storage_context.persist.assert_called()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from llama_index.core.schema import NodeWithScore, TextNode
|
||||
|
||||
|
|
@ -7,18 +5,30 @@ from metagpt.rag.retrievers import SimpleHybridRetriever
|
|||
|
||||
|
||||
class TestSimpleHybridRetriever:
|
||||
@pytest.fixture
|
||||
def mock_retriever(self, mocker):
|
||||
return mocker.MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_hybrid_retriever(self, mock_retriever) -> SimpleHybridRetriever:
|
||||
return SimpleHybridRetriever(mock_retriever)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_node(self):
|
||||
return NodeWithScore(node=TextNode(id_="2"), score=0.95)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aretrieve(self):
|
||||
async def test_aretrieve(self, mocker):
|
||||
question = "test query"
|
||||
|
||||
# Create mock retrievers
|
||||
mock_retriever1 = AsyncMock()
|
||||
mock_retriever1 = mocker.AsyncMock()
|
||||
mock_retriever1.aretrieve.return_value = [
|
||||
NodeWithScore(node=TextNode(id_="1"), score=1.0),
|
||||
NodeWithScore(node=TextNode(id_="2"), score=0.95),
|
||||
]
|
||||
|
||||
mock_retriever2 = AsyncMock()
|
||||
mock_retriever2 = mocker.AsyncMock()
|
||||
mock_retriever2.aretrieve.return_value = [
|
||||
NodeWithScore(node=TextNode(id_="2"), score=0.95),
|
||||
NodeWithScore(node=TextNode(id_="3"), score=0.8),
|
||||
|
|
@ -37,3 +47,11 @@ class TestSimpleHybridRetriever:
|
|||
# Check if the scores are correct (assuming you want the highest score)
|
||||
node_scores = {node.node.node_id: node.score for node in results}
|
||||
assert node_scores["2"] == 0.95
|
||||
|
||||
def test_add_nodes(self, mock_hybrid_retriever: SimpleHybridRetriever, mock_node):
|
||||
mock_hybrid_retriever.add_nodes([mock_node])
|
||||
mock_hybrid_retriever.retrievers[0].add_nodes.assert_called_once()
|
||||
|
||||
def test_persist(self, mock_hybrid_retriever: SimpleHybridRetriever):
|
||||
mock_hybrid_retriever.persist("")
|
||||
mock_hybrid_retriever.retrievers[0].persist.assert_called_once()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue