add index factory and chromadb

seehi 2024-02-28 13:42:41 +08:00 committed by betterwang
parent 8f3d56d18c
commit 90ca74147d
19 changed files with 505 additions and 29 deletions

.gitattributes

@@ -15,6 +15,7 @@
*.jpeg binary
*.mp3 binary
*.zip binary
+*.bin binary
# Preserve original line endings for specific document files

.gitignore

@@ -174,6 +174,7 @@ tmp.png
.dependencies.json
tests/metagpt/utils/file_repo_git
tests/data/rsp_cache.json
+tests/data/rsp_cache_new.json
*.tmp
*.png
htmlcov
@@ -184,4 +185,5 @@ cov.xml
*.faiss
*-structure.csv
*-structure.json
metagpt/tools/schemas
*.dot
.python-version

examples/data/rag/travel.txt
@@ -1 +1 @@
-Bojan likes traveling.
+Bob likes traveling.

examples/rag_pipeline.py

@@ -15,7 +15,7 @@ DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
QUESTION = "What are key qualities to be a good writer?"
TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt"
-TRAVEL_QUESTION = "What does Bojan like?"
+TRAVEL_QUESTION = "What does Bob like?"
LLM_TIP = "If you not sure, just answer I don't know"
@@ -61,10 +61,10 @@ class RAGExample:
        [After add docs]
        Retrieve Result:
-        0. Bojan like..., 10.0
+        0. Bob like..., 10.0
        Query Result:
-        Bojan likes traveling.
+        Bob likes traveling.
        """
        self._print_title("RAG Add Docs")

metagpt/rag/engines/simple.py

@@ -26,11 +26,16 @@ from llama_index.core.schema import (
    TransformComponent,
)

-from metagpt.rag.factories import get_rag_llm, get_rankers, get_retriever
+from metagpt.rag.factories import (
+    get_index,
+    get_rag_embedding,
+    get_rag_llm,
+    get_rankers,
+    get_retriever,
+)
from metagpt.rag.interface import RAGObject
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
-from metagpt.rag.schema import BaseRankerConfig, BaseRetrieverConfig
-from metagpt.utils.embedding import get_embedding
+from metagpt.rag.schema import BaseIndexConfig, BaseRankerConfig, BaseRetrieverConfig
class SimpleEngine(RetrieverQueryEngine):
@@ -83,8 +88,31 @@ class SimpleEngine(RetrieverQueryEngine):
        index = VectorStoreIndex.from_documents(
            documents=documents,
            transformations=transformations or [SentenceSplitter()],
-            embed_model=embed_model or get_embedding(),
+            embed_model=embed_model or get_rag_embedding(),
        )
+        return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
+
+    @classmethod
+    def from_index(
+        cls,
+        index_config: BaseIndexConfig,
+        embed_model: BaseEmbedding = None,
+        llm: LLM = None,
+        retriever_configs: list[BaseRetrieverConfig] = None,
+        ranker_configs: list[BaseRankerConfig] = None,
+    ):
+        """Load an index that was previously built and persisted."""
+        index = get_index(index_config, embed_model=embed_model or get_rag_embedding())
+        return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
+
+    @classmethod
+    def _from_index(
+        cls,
+        index: BaseIndex,
+        llm: LLM = None,
+        retriever_configs: list[BaseRetrieverConfig] = None,
+        ranker_configs: list[BaseRankerConfig] = None,
+    ):
        llm = llm or get_rag_llm()
        retriever = get_retriever(configs=retriever_configs, index=index)
        rankers = get_rankers(configs=ranker_configs, llm=llm)
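
For orientation, a minimal sketch of how the new entry point might be used once an index has been persisted by an earlier run. The engine class and config come from this commit; the path and question are illustrative:

from metagpt.rag.engines.simple import SimpleEngine
from metagpt.rag.schema import FAISSIndexConfig

# Hypothetical usage: reload a previously persisted FAISS index and query it.
engine = SimpleEngine.from_index(FAISSIndexConfig(persist_path="./storage"))
answer = engine.query("What does Bob like?")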

metagpt/rag/factories/__init__.py

@@ -2,5 +2,7 @@
from metagpt.rag.factories.retriever import get_retriever
from metagpt.rag.factories.ranker import get_rankers
from metagpt.rag.factories.llm import get_rag_llm
+from metagpt.rag.factories.embedding import get_rag_embedding
+from metagpt.rag.factories.index import get_index

-__all__ = ["get_retriever", "get_rankers", "get_rag_llm"]
+__all__ = ["get_retriever", "get_rankers", "get_rag_llm", "get_rag_embedding", "get_index"]

metagpt/rag/factories/embedding.py

@@ -0,0 +1,39 @@
"""RAG Embedding Factory.

The embedding of LlamaIndex and the LLM config of MG are not the same, so this factory builds LlamaIndex embeddings from the MG config.
"""
from llama_index.core.embeddings import BaseEmbedding
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding

from metagpt.config2 import config
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.base import GenericFactory


class RAGEmbeddingFactory(GenericFactory):
    """Create LlamaIndex embedding with MG config."""

    def __init__(self):
        creators = {
            LLMType.OPENAI: self._create_openai,
            LLMType.AZURE: self._create_azure,
        }
        super().__init__(creators)

    def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding:
        """Key is LLMType, default use config.llm.api_type."""
        return super().get_instance(key or config.llm.api_type)

    def _create_openai(self):
        return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url)

    def _create_azure(self):
        return AzureOpenAIEmbedding(
            azure_endpoint=config.llm.base_url,
            api_key=config.llm.api_key,
            api_version=config.llm.api_version,
        )


get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding
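
A short sketch of how this factory resolves an embedding from the MG config; calling get_rag_embedding with no argument falls back to config.llm.api_type, and the explicit LLMType is shown only to illustrate the override:

from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories import get_rag_embedding

embedding = get_rag_embedding()  # resolved from config.llm.api_type
azure_embedding = get_rag_embedding(LLMType.AZURE)  # or force a specific backend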

metagpt/rag/factories/index.py

@@ -0,0 +1,51 @@
"""RAG Index Factory."""
import chromadb
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.vector_stores.faiss import FaissVectorStore

from metagpt.rag.factories.base import ConfigFactory
from metagpt.rag.schema import BaseIndexConfig, ChromaIndexConfig, FAISSIndexConfig
from metagpt.rag.vector_stores.chroma import ChromaVectorStore


class RAGIndexFactory(ConfigFactory):
    def __init__(self):
        creators = {
            FAISSIndexConfig: self._create_faiss,
            ChromaIndexConfig: self._create_chroma,
        }
        super().__init__(creators)

    def get_index(self, config: BaseIndexConfig, **kwargs) -> BaseIndex:
        """Key is the index config class, e.g. FAISSIndexConfig or ChromaIndexConfig."""
        return super().get_instance(config, **kwargs)

    def extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
        return self._val_from_config_or_kwargs("embed_model", config, **kwargs)

    def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex:
        embed_model = self.extract_embed_model(config, **kwargs)

        vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
        storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)
        # embed_model belongs to load_index_from_storage; StorageContext does not accept it.
        index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)

        return index

    def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
        embed_model = self.extract_embed_model(config, **kwargs)

        db = chromadb.PersistentClient(str(config.persist_path))
        chroma_collection = db.get_or_create_collection(config.collection_name)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        index = VectorStoreIndex.from_vector_store(
            vector_store,
            embed_model=embed_model,
        )

        return index


get_index = RAGIndexFactory().get_index
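
For reference, a hedged sketch of reloading a persisted store through this factory. It assumes a previous run already wrote a Chroma collection under the given path; the path and collection name are the schema defaults:

from metagpt.rag.factories import get_index, get_rag_embedding
from metagpt.rag.schema import ChromaIndexConfig

config = ChromaIndexConfig(persist_path="./chroma_db", collection_name="metagpt")
index = get_index(config, embed_model=get_rag_embedding())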

metagpt/rag/factories/llm.py

@@ -44,7 +44,7 @@ class RAGLLMFactory(GenericFactory):
            azure_endpoint=config.llm.base_url,
            api_key=config.llm.api_key,
            api_version=config.llm.api_version,
-            model=config.llm.model,
+            deployment_name=config.llm.model,
            max_tokens=config.llm.max_token,
            temperature=config.llm.temperature,
        )

metagpt/rag/factories/ranker.py

@@ -20,14 +20,10 @@ class RankerFactory(ConfigFactory):
    def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]:
        """Creates and returns ranker instances based on the provided configurations."""
        if not configs:
-            return self._create_default(**kwargs)
+            return []
        return super().get_instances(configs, **kwargs)

-    def _create_default(self, **kwargs) -> list[LLMRerank]:
-        config = LLMRankerConfig(llm=self._extract_llm(**kwargs))
-        return [LLMRerank(**config.model_dump())]
-
    def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
        return self._val_from_config_or_kwargs("llm", config, **kwargs)

metagpt/rag/factories/retriever.py

@@ -1,19 +1,25 @@
"""RAG Retriever Factory."""

+import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
+from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.faiss import FaissVectorStore

from metagpt.rag.factories.base import ConfigFactory
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
+from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
    BaseRetrieverConfig,
    BM25RetrieverConfig,
+    ChromaRetrieverConfig,
    FAISSRetrieverConfig,
+    IndexRetrieverConfig,
)
+from metagpt.rag.vector_stores.chroma import ChromaVectorStore
class RetrieverFactory(ConfigFactory):
@@ -23,6 +29,7 @@ class RetrieverFactory(ConfigFactory):
        creators = {
            FAISSRetrieverConfig: self._create_faiss_retriever,
            BM25RetrieverConfig: self._create_bm25_retriever,
+            ChromaRetrieverConfig: self._create_chroma_retriever,
        }
        super().__init__(creators)
@@ -44,8 +51,9 @@ class RetrieverFactory(ConfigFactory):
    def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
        return self._val_from_config_or_kwargs("index", config, **kwargs)

-    def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
-        vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
+    def _build_index_from_vector_store(
+        self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
+    ) -> VectorStoreIndex:
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        old_index = self._extract_index(config, **kwargs)
        new_index = VectorStoreIndex(
@@ -53,12 +61,23 @@ class RetrieverFactory(ConfigFactory):
            storage_context=storage_context,
            embed_model=old_index._embed_model,
        )
-        config.index = new_index
+        return new_index
+
+    def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
+        vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
+        config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
        return FAISSRetriever(**config.model_dump())

    def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
        config.index = self._extract_index(config, **kwargs)
        return DynamicBM25Retriever.from_defaults(**config.model_dump())

+    def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
+        db = chromadb.PersistentClient(path=str(config.persist_path))
+        chroma_collection = db.get_or_create_collection(config.collection_name)
+        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
+        config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
+        return ChromaRetriever(**config.model_dump())

get_retriever = RetrieverFactory().get_retriever
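
A minimal sketch of exercising the new Chroma path through the factory; `index` is assumed to be an existing VectorStoreIndex whose embed model the new store should reuse, and the path and collection name are the schema defaults:

from metagpt.rag.factories import get_retriever
from metagpt.rag.schema import ChromaRetrieverConfig

retriever = get_retriever(
    configs=[ChromaRetrieverConfig(persist_path="./chroma_db", collection_name="metagpt")],
    index=index,  # assumption: a previously built VectorStoreIndex
)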

metagpt/rag/retrievers/chroma_retriever.py

@@ -0,0 +1,11 @@
"""Chroma retriever."""
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode


class ChromaRetriever(VectorIndexRetriever):
    """Chroma retriever."""

    def add_nodes(self, nodes: list[BaseNode], **kwargs):
        """Support add nodes."""
        self._index.insert_nodes(nodes, **kwargs)

metagpt/rag/schema.py

@@ -1,7 +1,9 @@
"""RAG schemas."""

-from typing import Any
+from pathlib import Path
+from typing import Any, Union

+from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from pydantic import BaseModel, ConfigDict, Field
@@ -9,7 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field
class BaseRetrieverConfig(BaseModel):
    """Common config for retrievers.

-    If adding a new subconfig, it is necessary to add the corresponding instance implementation in rag.factory.
+    If adding a new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.retriever.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -19,7 +21,7 @@
class IndexRetrieverConfig(BaseRetrieverConfig):
    """Config for index-based retrievers."""

-    index: BaseIndex = Field(default=None, description="Index for retriever")
+    index: BaseIndex = Field(default=None, description="Index for retriever.")
class FAISSRetrieverConfig(IndexRetrieverConfig):
@@ -32,10 +34,17 @@ class BM25RetrieverConfig(IndexRetrieverConfig):
    """Config for BM25-based retrievers."""


+class ChromaRetrieverConfig(IndexRetrieverConfig):
+    """Config for Chroma-based retrievers."""
+
+    persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
+    collection_name: str = Field(default="metagpt", description="The name of the collection.")
+
+
class BaseRankerConfig(BaseModel):
    """Common config for rankers.

-    If adding a new subconfig, it is necessary to add the corresponding instance implementation in rag.factory.
+    If adding a new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.ranker.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -48,5 +57,30 @@ class LLMRankerConfig(BaseRankerConfig):
    llm: Any = Field(
        default=None,
-        description="The LLM to rerank with. Using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1",
+        description="The LLM to rerank with. Using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1.",
    )


+class BaseIndexConfig(BaseModel):
+    """Common config for index.
+
+    If adding a new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index.
+    """
+
+    persist_path: Union[str, Path] = Field(description="The directory of saved data.")
+
+
+class VectorIndexConfig(BaseIndexConfig):
+    """Config for vector-based index."""
+
+    embed_model: BaseEmbedding = Field(default=None, description="Embed model.")
+
+
+class FAISSIndexConfig(VectorIndexConfig):
+    """Config for faiss-based index."""
+
+
+class ChromaIndexConfig(VectorIndexConfig):
+    """Config for chroma-based index."""
+
+    collection_name: str = Field(default="metagpt", description="The name of the collection.")


metagpt/rag/vector_stores/chroma/__init__.py

@@ -0,0 +1,3 @@
from metagpt.rag.vector_stores.chroma.base import ChromaVectorStore

__all__ = ["ChromaVectorStore"]

metagpt/rag/vector_stores/chroma/base.py

@@ -0,0 +1,290 @@
"""Chroma vector store.

Refers to https://github.com/run-llama/llama_index/blob/v0.10.12/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py.
That package requires onnxruntime = "^1.17.0", which is too new for many operating systems, such as CentOS 7.
"""
import logging
import math
from typing import Any, Dict, Generator, List, Optional, cast

import chromadb
from chromadb.api.models.Collection import Collection
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.schema import BaseNode, MetadataMode, TextNode
from llama_index.core.utils import truncate_text
from llama_index.core.vector_stores.types import (
    BasePydanticVectorStore,
    MetadataFilters,
    VectorStoreQuery,
    VectorStoreQueryResult,
)
from llama_index.core.vector_stores.utils import (
    legacy_metadata_dict_to_node,
    metadata_dict_to_node,
    node_to_metadata_dict,
)

logger = logging.getLogger(__name__)
def _transform_chroma_filter_condition(condition: str) -> str:
    """Translate standard metadata filter condition to Chroma specific spec."""
    if condition == "and":
        return "$and"
    elif condition == "or":
        return "$or"
    else:
        raise ValueError(f"Filter condition {condition} not supported")


def _transform_chroma_filter_operator(operator: str) -> str:
    """Translate standard metadata filter operator to Chroma specific spec."""
    if operator == "!=":
        return "$ne"
    elif operator == "==":
        return "$eq"
    elif operator == ">":
        return "$gt"
    elif operator == "<":
        return "$lt"
    elif operator == ">=":
        return "$gte"
    elif operator == "<=":
        return "$lte"
    else:
        raise ValueError(f"Filter operator {operator} not supported")


def _to_chroma_filter(
    standard_filters: MetadataFilters,
) -> dict:
    """Translate standard metadata filters to Chroma specific spec."""
    filters = {}
    filters_list = []
    condition = standard_filters.condition or "and"
    condition = _transform_chroma_filter_condition(condition)

    if standard_filters.filters:
        for filter in standard_filters.filters:
            if filter.operator:
                filters_list.append({filter.key: {_transform_chroma_filter_operator(filter.operator): filter.value}})
            else:
                filters_list.append({filter.key: filter.value})

    if len(filters_list) == 1:
        # If there is only one filter, return it directly
        return filters_list[0]
    elif len(filters_list) > 1:
        filters[condition] = filters_list
    return filters
import_err_msg = "`chromadb` package not found, please run `pip install chromadb`"

MAX_CHUNK_SIZE = 41665  # One less than the max chunk size for ChromaDB


def chunk_list(lst: List[BaseNode], max_chunk_size: int) -> Generator[List[BaseNode], None, None]:
    """Yield successive max_chunk_size-sized chunks from lst.

    Args:
        lst (List[BaseNode]): list of nodes with embeddings
        max_chunk_size (int): max chunk size

    Yields:
        Generator[List[BaseNode], None, None]: list of nodes with embeddings
    """
    for i in range(0, len(lst), max_chunk_size):
        yield lst[i : i + max_chunk_size]
class ChromaVectorStore(BasePydanticVectorStore):
    """Chroma vector store.

    In this vector store, embeddings are stored within a ChromaDB collection.
    During query time, the index uses ChromaDB to query for the top
    k most similar nodes.

    Args:
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance
    """

    stores_text: bool = True
    flat_metadata: bool = True

    collection_name: Optional[str]
    host: Optional[str]
    port: Optional[str]
    ssl: bool
    headers: Optional[Dict[str, str]]
    persist_dir: Optional[str]
    collection_kwargs: Dict[str, Any] = Field(default_factory=dict)

    _collection: Any = PrivateAttr()

    def __init__(
        self,
        chroma_collection: Optional[Any] = None,
        collection_name: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[str] = None,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        persist_dir: Optional[str] = None,
        collection_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        collection_kwargs = collection_kwargs or {}
        if chroma_collection is None:
            client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
            self._collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
        else:
            self._collection = cast(Collection, chroma_collection)
        super().__init__(
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_name=collection_name,
            persist_dir=persist_dir,
            collection_kwargs=collection_kwargs,
        )
    @classmethod
    def from_collection(cls, collection: Any) -> "ChromaVectorStore":
        try:
            from chromadb import Collection
        except ImportError:
            raise ImportError(import_err_msg)

        if not isinstance(collection, Collection):
            raise Exception("argument is not chromadb collection instance")

        return cls(chroma_collection=collection)

    @classmethod
    def from_params(
        cls,
        collection_name: str,
        host: Optional[str] = None,
        port: Optional[str] = None,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        persist_dir: Optional[str] = None,
        collection_kwargs: dict = {},
        **kwargs: Any,
    ) -> "ChromaVectorStore":
        if persist_dir:
            client = chromadb.PersistentClient(path=persist_dir)
            collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
        elif host and port:
            client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
            collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
        else:
            raise ValueError("Either `persist_dir` or (`host`, `port`) must be specified")
        return cls(
            chroma_collection=collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            persist_dir=persist_dir,
            collection_kwargs=collection_kwargs,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "ChromaVectorStore"
    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to index.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings
        """
        if not self._collection:
            raise ValueError("Collection not initialized")

        max_chunk_size = MAX_CHUNK_SIZE
        node_chunks = chunk_list(nodes, max_chunk_size)

        all_ids = []
        for node_chunk in node_chunks:
            embeddings = []
            metadatas = []
            ids = []
            documents = []
            for node in node_chunk:
                embeddings.append(node.get_embedding())
                metadata_dict = node_to_metadata_dict(node, remove_text=True, flat_metadata=self.flat_metadata)
                for key in metadata_dict:
                    if metadata_dict[key] is None:
                        metadata_dict[key] = ""
                metadatas.append(metadata_dict)
                ids.append(node.node_id)
                documents.append(node.get_content(metadata_mode=MetadataMode.NONE))

            self._collection.add(
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas,
                documents=documents,
            )
            all_ids.extend(ids)

        return all_ids
    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete nodes with the given ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.
        """
        self._collection.delete(where={"document_id": ref_doc_id})

    @property
    def client(self) -> Any:
        """Return client."""
        return self._collection
    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query_embedding (List[float]): query embedding
            similarity_top_k (int): top k most similar nodes
        """
        if query.filters is not None:
            if "where" in kwargs:
                raise ValueError(
                    "Cannot specify metadata filters via both query and kwargs. "
                    "Use kwargs only for chroma specific items that are "
                    "not supported via the generic query interface."
                )
            where = _to_chroma_filter(query.filters)
        else:
            where = kwargs.pop("where", {})

        results = self._collection.query(
            query_embeddings=query.query_embedding,
            n_results=query.similarity_top_k,
            where=where,
            **kwargs,
        )

        logger.debug(f"> Top {len(results['documents'])} nodes:")
        nodes = []
        similarities = []
        ids = []
        for node_id, text, metadata, distance in zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        ):
            try:
                node = metadata_dict_to_node(metadata)
                node.set_content(text)
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                metadata, node_info, relationships = legacy_metadata_dict_to_node(metadata)
                node = TextNode(
                    text=text,
                    id_=node_id,
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships=relationships,
                )
            nodes.append(node)
            similarity_score = math.exp(-distance)
            similarities.append(similarity_score)
            logger.debug(f"> [Node {node_id}] [Similarity score: {similarity_score}] {truncate_text(str(text), 100)}")
            ids.append(node_id)

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
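
Two details of this port are worth calling out. _to_chroma_filter rewrites generic MetadataFilters into Chroma's operator dialect, and query converts each returned distance into a similarity via exp(-distance), so smaller distances map to scores near 1.0. A sketch under those assumptions (the filter key and value are illustrative):

import math

from llama_index.core.vector_stores.types import FilterOperator, MetadataFilter, MetadataFilters

filters = MetadataFilters(filters=[MetadataFilter(key="author", operator=FilterOperator.EQ, value="bob")])
# _to_chroma_filter(filters) would yield {"author": {"$eq": "bob"}}

assert math.exp(-0.0) == 1.0  # identical vectors score 1.0
assert round(math.exp(-1.0), 3) == 0.368  # distance 1.0 maps to ~0.368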

requirements.txt

@@ -1,7 +1,6 @@
aiohttp==3.8.6
#azure_storage==0.37.0
channels==4.0.0
-# chromadb
# Django==4.1.5
# docx==0.2.4
#faiss==1.5.3
@@ -12,6 +11,8 @@ typer==0.9.0
# google_api_python_client==2.93.0 # Used by search_engine.py
lancedb==0.4.0
llama-index-core==0.10.12
llama-index-embeddings-azure-openai==0.1.6
llama-index-embeddings-huggingface==0.1.3
llama-index-embeddings-openai==0.1.5
llama-index-llms-azure-openai==0.1.4
llama-index-llms-gemini==0.1.4
@@ -20,6 +21,7 @@ llama-index-llms-openai==0.1.5
llama-index-readers-file==0.1.4
llama-index-retrievers-bm25==0.1.3
llama-index-vector-stores-faiss==0.1.1
+chromadb==0.4.23
loguru==0.6.0
meilisearch==0.21.0
numpy==1.24.3

setup.py

@@ -42,7 +42,7 @@ extras_require["test"] = [
    "connexion[uvicorn]~=3.0.5",
    "azure-cognitiveservices-speech~=1.31.0",
    "aioboto3~=11.3.0",
-    "chromadb==0.4.14",
+    "chromadb==0.4.23",
    "gradio==3.0.0",
    "grpcio-status==1.48.2",
    "pylint==3.0.3",

tests/metagpt/rag/factories/test_ranker.py

@@ -18,9 +18,7 @@ class TestRankerFactory:
    def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker):
        mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm)
        default_rankers = ranker_factory.get_rankers()
-        assert len(default_rankers) == 1
-        assert isinstance(default_rankers[0], LLMRerank)
-        ranker_factory._extract_llm.assert_called_once()
+        assert len(default_rankers) == 0

    def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm):
        mock_config = LLMRankerConfig(llm=mock_llm)