mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-25 00:36:55 +02:00
add index factory and chromadb
This commit is contained in:
parent
8f3d56d18c
commit
90ca74147d
19 changed files with 505 additions and 29 deletions
1
.gitattributes
vendored
1
.gitattributes
vendored
|
|
@ -15,6 +15,7 @@
|
|||
*.jpeg binary
|
||||
*.mp3 binary
|
||||
*.zip binary
|
||||
*.bin binary
|
||||
|
||||
|
||||
# Preserve original line endings for specific document files
|
||||
|
|
|
|||
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -174,6 +174,7 @@ tmp.png
|
|||
.dependencies.json
|
||||
tests/metagpt/utils/file_repo_git
|
||||
tests/data/rsp_cache.json
|
||||
tests/data/rsp_cache_new.json
|
||||
*.tmp
|
||||
*.png
|
||||
htmlcov
|
||||
|
|
@ -184,4 +185,5 @@ cov.xml
|
|||
*.faiss
|
||||
*-structure.csv
|
||||
*-structure.json
|
||||
metagpt/tools/schemas
|
||||
*.dot
|
||||
.python-version
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
Bojan likes traveling.
|
||||
Bob likes traveling.
|
||||
|
|
@ -15,7 +15,7 @@ DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
|
|||
QUESTION = "What are key qualities to be a good writer?"
|
||||
|
||||
TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt"
|
||||
TRAVEL_QUESTION = "What does Bojan like?"
|
||||
TRAVEL_QUESTION = "What does Bob like?"
|
||||
|
||||
LLM_TIP = "If you not sure, just answer I don't know"
|
||||
|
||||
|
|
@ -61,10 +61,10 @@ class RAGExample:
|
|||
|
||||
[After add docs]
|
||||
Retrieve Result:
|
||||
0. Bojan like..., 10.0
|
||||
0. Bob like..., 10.0
|
||||
|
||||
Query Result:
|
||||
Bojan likes traveling.
|
||||
Bob likes traveling.
|
||||
"""
|
||||
self._print_title("RAG Add Docs")
|
||||
|
||||
|
|
|
|||
|
|
@ -26,11 +26,16 @@ from llama_index.core.schema import (
|
|||
TransformComponent,
|
||||
)
|
||||
|
||||
from metagpt.rag.factories import get_rag_llm, get_rankers, get_retriever
|
||||
from metagpt.rag.factories import (
|
||||
get_index,
|
||||
get_rag_embedding,
|
||||
get_rag_llm,
|
||||
get_rankers,
|
||||
get_retriever,
|
||||
)
|
||||
from metagpt.rag.interface import RAGObject
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
from metagpt.rag.schema import BaseRankerConfig, BaseRetrieverConfig
|
||||
from metagpt.utils.embedding import get_embedding
|
||||
from metagpt.rag.schema import BaseIndexConfig, BaseRankerConfig, BaseRetrieverConfig
|
||||
|
||||
|
||||
class SimpleEngine(RetrieverQueryEngine):
|
||||
|
|
@ -83,8 +88,31 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
index = VectorStoreIndex.from_documents(
|
||||
documents=documents,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=embed_model or get_embedding(),
|
||||
embed_model=embed_model or get_rag_embedding(),
|
||||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def from_index(
|
||||
cls,
|
||||
index_config: BaseIndexConfig,
|
||||
embed_model: BaseEmbedding = None,
|
||||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
):
|
||||
"""Load from previously maintained"""
|
||||
index = get_index(index_config, embed_model=embed_model or get_rag_embedding())
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def _from_index(
|
||||
cls,
|
||||
index: BaseIndex,
|
||||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
):
|
||||
llm = llm or get_rag_llm()
|
||||
retriever = get_retriever(configs=retriever_configs, index=index)
|
||||
rankers = get_rankers(configs=ranker_configs, llm=llm)
|
||||
|
|
|
|||
|
|
@ -2,5 +2,7 @@
|
|||
from metagpt.rag.factories.retriever import get_retriever
|
||||
from metagpt.rag.factories.ranker import get_rankers
|
||||
from metagpt.rag.factories.llm import get_rag_llm
|
||||
from metagpt.rag.factories.embedding import get_rag_embedding
|
||||
from metagpt.rag.factories.index import get_index
|
||||
|
||||
__all__ = ["get_retriever", "get_rankers", "get_rag_llm"]
|
||||
__all__ = ["get_retriever", "get_rankers", "get_rag_llm", "get_rag_embedding", "get_index"]
|
||||
|
|
|
|||
39
metagpt/rag/factories/embedding.py
Normal file
39
metagpt/rag/factories/embedding.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""RAG LLM Factory.
|
||||
|
||||
The LLM of LlamaIndex and the LLM of MG are not the same.
|
||||
"""
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.base import GenericFactory
|
||||
|
||||
|
||||
class RAGEmbeddingFactory(GenericFactory):
|
||||
"""Create LlamaIndex LLM with MG config."""
|
||||
|
||||
def __init__(self):
|
||||
creators = {
|
||||
LLMType.OPENAI: self._create_openai,
|
||||
LLMType.AZURE: self._create_azure,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding:
|
||||
"""Key is LLMType, default use config.llm.api_type."""
|
||||
return super().get_instance(key or config.llm.api_type)
|
||||
|
||||
def _create_openai(self):
|
||||
return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url)
|
||||
|
||||
def _create_azure(self):
|
||||
return AzureOpenAIEmbedding(
|
||||
azure_endpoint=config.llm.base_url,
|
||||
api_key=config.llm.api_key,
|
||||
api_version=config.llm.api_version,
|
||||
)
|
||||
|
||||
|
||||
get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding
|
||||
51
metagpt/rag/factories/index.py
Normal file
51
metagpt/rag/factories/index.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
"""RAG Index Factory."""
|
||||
import chromadb
|
||||
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.factories.base import ConfigFactory
|
||||
from metagpt.rag.schema import BaseIndexConfig, ChromaIndexConfig, FAISSIndexConfig
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
class RAGIndexFactory(ConfigFactory):
|
||||
def __init__(self):
|
||||
creators = {
|
||||
FAISSIndexConfig: self._create_faiss,
|
||||
ChromaIndexConfig: self._create_chroma,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_index(self, config: BaseIndexConfig, **kwargs) -> BaseIndex:
|
||||
"""Key is PersistType."""
|
||||
return super().get_instance(config, **kwargs)
|
||||
|
||||
def extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
|
||||
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
|
||||
|
||||
def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
embed_model = self.extract_embed_model(config, **kwargs)
|
||||
|
||||
vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
|
||||
storage_context = StorageContext.from_defaults(
|
||||
vector_store=vector_store, persist_dir=config.persist_path, embed_mode=embed_model
|
||||
)
|
||||
index = load_index_from_storage(storage_context=storage_context)
|
||||
return index
|
||||
|
||||
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
embed_model = self.extract_embed_model(config, **kwargs)
|
||||
|
||||
db2 = chromadb.PersistentClient(str(config.persist_path))
|
||||
chroma_collection = db2.get_or_create_collection(config.collection_name)
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
vector_store,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
return index
|
||||
|
||||
|
||||
get_index = RAGIndexFactory().get_index
|
||||
|
|
@ -44,7 +44,7 @@ class RAGLLMFactory(GenericFactory):
|
|||
azure_endpoint=config.llm.base_url,
|
||||
api_key=config.llm.api_key,
|
||||
api_version=config.llm.api_version,
|
||||
model=config.llm.model,
|
||||
deployment_name=config.llm.model,
|
||||
max_tokens=config.llm.max_token,
|
||||
temperature=config.llm.temperature,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,14 +20,10 @@ class RankerFactory(ConfigFactory):
|
|||
def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]:
|
||||
"""Creates and returns a retriever instance based on the provided configurations."""
|
||||
if not configs:
|
||||
return self._create_default(**kwargs)
|
||||
return []
|
||||
|
||||
return super().get_instances(configs, **kwargs)
|
||||
|
||||
def _create_default(self, **kwargs) -> list[LLMRerank]:
|
||||
config = LLMRankerConfig(llm=self._extract_llm(**kwargs))
|
||||
return [LLMRerank(**config.model_dump())]
|
||||
|
||||
def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
|
||||
return self._val_from_config_or_kwargs("llm", config, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,25 @@
|
|||
"""RAG Retriever Factory."""
|
||||
|
||||
import chromadb
|
||||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.factories.base import ConfigFactory
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BaseRetrieverConfig,
|
||||
BM25RetrieverConfig,
|
||||
ChromaRetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
IndexRetrieverConfig,
|
||||
)
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
class RetrieverFactory(ConfigFactory):
|
||||
|
|
@ -23,6 +29,7 @@ class RetrieverFactory(ConfigFactory):
|
|||
creators = {
|
||||
FAISSRetrieverConfig: self._create_faiss_retriever,
|
||||
BM25RetrieverConfig: self._create_bm25_retriever,
|
||||
ChromaRetrieverConfig: self._create_chroma_retriever,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
|
|
@ -44,8 +51,9 @@ class RetrieverFactory(ConfigFactory):
|
|||
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
|
||||
return self._val_from_config_or_kwargs("index", config, **kwargs)
|
||||
|
||||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
def _build_index_from_vector_store(
|
||||
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
|
||||
) -> VectorStoreIndex:
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
old_index = self._extract_index(config, **kwargs)
|
||||
new_index = VectorStoreIndex(
|
||||
|
|
@ -53,12 +61,23 @@ class RetrieverFactory(ConfigFactory):
|
|||
storage_context=storage_context,
|
||||
embed_model=old_index._embed_model,
|
||||
)
|
||||
config.index = new_index
|
||||
return new_index
|
||||
|
||||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
return FAISSRetriever(**config.model_dump())
|
||||
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
|
||||
config.index = self._extract_index(config, **kwargs)
|
||||
return DynamicBM25Retriever.from_defaults(**config.model_dump())
|
||||
|
||||
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
|
||||
db = chromadb.PersistentClient(path=str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name)
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
return ChromaRetriever(**config.model_dump())
|
||||
|
||||
|
||||
get_retriever = RetrieverFactory().get_retriever
|
||||
|
|
|
|||
11
metagpt/rag/retrievers/chroma_retriever.py
Normal file
11
metagpt/rag/retrievers/chroma_retriever.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""Chroma retriever."""
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.schema import BaseNode
|
||||
|
||||
|
||||
class ChromaRetriever(VectorIndexRetriever):
|
||||
"""FAISS retriever."""
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs):
|
||||
"""Support add nodes"""
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
|
@ -1,7 +1,9 @@
|
|||
"""RAG schemas."""
|
||||
|
||||
from typing import Any
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
|
@ -9,7 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||
class BaseRetrieverConfig(BaseModel):
|
||||
"""Common config for retrievers.
|
||||
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factory.
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.retriever.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
|
@ -19,7 +21,7 @@ class BaseRetrieverConfig(BaseModel):
|
|||
class IndexRetrieverConfig(BaseRetrieverConfig):
|
||||
"""Config for Index-basd retrievers."""
|
||||
|
||||
index: BaseIndex = Field(default=None, description="Index for retriver")
|
||||
index: BaseIndex = Field(default=None, description="Index for retriver.")
|
||||
|
||||
|
||||
class FAISSRetrieverConfig(IndexRetrieverConfig):
|
||||
|
|
@ -32,10 +34,17 @@ class BM25RetrieverConfig(IndexRetrieverConfig):
|
|||
"""Config for BM25-based retrievers."""
|
||||
|
||||
|
||||
class ChromaRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for Chroma-based retrievers."""
|
||||
|
||||
persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
|
||||
|
||||
class BaseRankerConfig(BaseModel):
|
||||
"""Common config for rankers.
|
||||
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factory.
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.ranker.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
|
@ -48,5 +57,30 @@ class LLMRankerConfig(BaseRankerConfig):
|
|||
|
||||
llm: Any = Field(
|
||||
default=None,
|
||||
description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1",
|
||||
description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1.",
|
||||
)
|
||||
|
||||
|
||||
class BaseIndexConfig(BaseModel):
|
||||
"""Common config for index.
|
||||
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index.
|
||||
"""
|
||||
|
||||
persist_path: Union[str, Path] = Field(description="The directory of saved data.")
|
||||
|
||||
|
||||
class VectorIndexConfig(BaseIndexConfig):
|
||||
"""Config for vector-based index."""
|
||||
|
||||
embed_model: BaseEmbedding = Field(default=None, description="Embed model.")
|
||||
|
||||
|
||||
class FAISSIndexConfig(VectorIndexConfig):
|
||||
"""Config for faiss-based index."""
|
||||
|
||||
|
||||
class ChromaIndexConfig(VectorIndexConfig):
|
||||
"""Config for chroma-based index."""
|
||||
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
|
|
|
|||
0
metagpt/rag/vector_stores/__init__.py
Normal file
0
metagpt/rag/vector_stores/__init__.py
Normal file
3
metagpt/rag/vector_stores/chroma/__init__.py
Normal file
3
metagpt/rag/vector_stores/chroma/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from metagpt.rag.vector_stores.chroma.base import ChromaVectorStore
|
||||
|
||||
__all__ = ["ChromaVectorStore"]
|
||||
290
metagpt/rag/vector_stores/chroma/base.py
Normal file
290
metagpt/rag/vector_stores/chroma/base.py
Normal file
|
|
@ -0,0 +1,290 @@
|
|||
"""Chroma vector store.
|
||||
|
||||
Refs to https://github.com/run-llama/llama_index/blob/v0.10.12/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py.
|
||||
The repo requires onnxruntime = "^1.17.0", which is too new for many OS systems, such as CentOS7.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Dict, Generator, List, Optional, cast
|
||||
|
||||
import chromadb
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from llama_index.core.bridge.pydantic import Field, PrivateAttr
|
||||
from llama_index.core.schema import BaseNode, MetadataMode, TextNode
|
||||
from llama_index.core.utils import truncate_text
|
||||
from llama_index.core.vector_stores.types import (
|
||||
BasePydanticVectorStore,
|
||||
MetadataFilters,
|
||||
VectorStoreQuery,
|
||||
VectorStoreQueryResult,
|
||||
)
|
||||
from llama_index.core.vector_stores.utils import (
|
||||
legacy_metadata_dict_to_node,
|
||||
metadata_dict_to_node,
|
||||
node_to_metadata_dict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _transform_chroma_filter_condition(condition: str) -> str:
|
||||
"""Translate standard metadata filter op to Chroma specific spec."""
|
||||
if condition == "and":
|
||||
return "$and"
|
||||
elif condition == "or":
|
||||
return "$or"
|
||||
else:
|
||||
raise ValueError(f"Filter condition {condition} not supported")
|
||||
|
||||
|
||||
def _transform_chroma_filter_operator(operator: str) -> str:
|
||||
"""Translate standard metadata filter operator to Chroma specific spec."""
|
||||
if operator == "!=":
|
||||
return "$ne"
|
||||
elif operator == "==":
|
||||
return "$eq"
|
||||
elif operator == ">":
|
||||
return "$gt"
|
||||
elif operator == "<":
|
||||
return "$lt"
|
||||
elif operator == ">=":
|
||||
return "$gte"
|
||||
elif operator == "<=":
|
||||
return "$lte"
|
||||
else:
|
||||
raise ValueError(f"Filter operator {operator} not supported")
|
||||
|
||||
|
||||
def _to_chroma_filter(
|
||||
standard_filters: MetadataFilters,
|
||||
) -> dict:
|
||||
"""Translate standard metadata filters to Chroma specific spec."""
|
||||
filters = {}
|
||||
filters_list = []
|
||||
condition = standard_filters.condition or "and"
|
||||
condition = _transform_chroma_filter_condition(condition)
|
||||
if standard_filters.filters:
|
||||
for filter in standard_filters.filters:
|
||||
if filter.operator:
|
||||
filters_list.append({filter.key: {_transform_chroma_filter_operator(filter.operator): filter.value}})
|
||||
else:
|
||||
filters_list.append({filter.key: filter.value})
|
||||
if len(filters_list) == 1:
|
||||
# If there is only one filter, return it directly
|
||||
return filters_list[0]
|
||||
elif len(filters_list) > 1:
|
||||
filters[condition] = filters_list
|
||||
return filters
|
||||
|
||||
|
||||
import_err_msg = "`chromadb` package not found, please run `pip install chromadb`"
|
||||
MAX_CHUNK_SIZE = 41665 # One less than the max chunk size for ChromaDB
|
||||
|
||||
|
||||
def chunk_list(lst: List[BaseNode], max_chunk_size: int) -> Generator[List[BaseNode], None, None]:
|
||||
"""Yield successive max_chunk_size-sized chunks from lst.
|
||||
Args:
|
||||
lst (List[BaseNode]): list of nodes with embeddings
|
||||
max_chunk_size (int): max chunk size
|
||||
Yields:
|
||||
Generator[List[BaseNode], None, None]: list of nodes with embeddings
|
||||
"""
|
||||
for i in range(0, len(lst), max_chunk_size):
|
||||
yield lst[i : i + max_chunk_size]
|
||||
|
||||
|
||||
class ChromaVectorStore(BasePydanticVectorStore):
|
||||
"""Chroma vector store.
|
||||
In this vector store, embeddings are stored within a ChromaDB collection.
|
||||
During query time, the index uses ChromaDB to query for the top
|
||||
k most similar nodes.
|
||||
Args:
|
||||
chroma_collection (chromadb.api.models.Collection.Collection):
|
||||
ChromaDB collection instance
|
||||
"""
|
||||
|
||||
stores_text: bool = True
|
||||
flat_metadata: bool = True
|
||||
collection_name: Optional[str]
|
||||
host: Optional[str]
|
||||
port: Optional[str]
|
||||
ssl: bool
|
||||
headers: Optional[Dict[str, str]]
|
||||
persist_dir: Optional[str]
|
||||
collection_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
_collection: Any = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chroma_collection: Optional[Any] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
ssl: bool = False,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
persist_dir: Optional[str] = None,
|
||||
collection_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
collection_kwargs = collection_kwargs or {}
|
||||
if chroma_collection is None:
|
||||
client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
|
||||
self._collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
|
||||
else:
|
||||
self._collection = cast(Collection, chroma_collection)
|
||||
super().__init__(
|
||||
host=host,
|
||||
port=port,
|
||||
ssl=ssl,
|
||||
headers=headers,
|
||||
collection_name=collection_name,
|
||||
persist_dir=persist_dir,
|
||||
collection_kwargs=collection_kwargs or {},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_collection(cls, collection: Any) -> "ChromaVectorStore":
|
||||
try:
|
||||
from chromadb import Collection
|
||||
except ImportError:
|
||||
raise ImportError(import_err_msg)
|
||||
if not isinstance(collection, Collection):
|
||||
raise Exception("argument is not chromadb collection instance")
|
||||
return cls(chroma_collection=collection)
|
||||
|
||||
@classmethod
|
||||
def from_params(
|
||||
cls,
|
||||
collection_name: str,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
ssl: bool = False,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
persist_dir: Optional[str] = None,
|
||||
collection_kwargs: dict = {},
|
||||
**kwargs: Any,
|
||||
) -> "ChromaVectorStore":
|
||||
if persist_dir:
|
||||
client = chromadb.PersistentClient(path=persist_dir)
|
||||
collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
|
||||
elif host and port:
|
||||
client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
|
||||
collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
|
||||
else:
|
||||
raise ValueError("Either `persist_dir` or (`host`,`port`) must be specified")
|
||||
return cls(
|
||||
chroma_collection=collection,
|
||||
host=host,
|
||||
port=port,
|
||||
ssl=ssl,
|
||||
headers=headers,
|
||||
persist_dir=persist_dir,
|
||||
collection_kwargs=collection_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "ChromaVectorStore"
|
||||
|
||||
def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
|
||||
"""Add nodes to index.
|
||||
Args:
|
||||
nodes: List[BaseNode]: list of nodes with embeddings
|
||||
"""
|
||||
if not self._collection:
|
||||
raise ValueError("Collection not initialized")
|
||||
max_chunk_size = MAX_CHUNK_SIZE
|
||||
node_chunks = chunk_list(nodes, max_chunk_size)
|
||||
all_ids = []
|
||||
for node_chunk in node_chunks:
|
||||
embeddings = []
|
||||
metadatas = []
|
||||
ids = []
|
||||
documents = []
|
||||
for node in node_chunk:
|
||||
embeddings.append(node.get_embedding())
|
||||
metadata_dict = node_to_metadata_dict(node, remove_text=True, flat_metadata=self.flat_metadata)
|
||||
for key in metadata_dict:
|
||||
if metadata_dict[key] is None:
|
||||
metadata_dict[key] = ""
|
||||
metadatas.append(metadata_dict)
|
||||
ids.append(node.node_id)
|
||||
documents.append(node.get_content(metadata_mode=MetadataMode.NONE))
|
||||
self._collection.add(
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
)
|
||||
all_ids.extend(ids)
|
||||
return all_ids
|
||||
|
||||
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
|
||||
"""
|
||||
Delete nodes using with ref_doc_id.
|
||||
Args:
|
||||
ref_doc_id (str): The doc_id of the document to delete.
|
||||
"""
|
||||
self._collection.delete(where={"document_id": ref_doc_id})
|
||||
|
||||
@property
|
||||
def client(self) -> Any:
|
||||
"""Return client."""
|
||||
return self._collection
|
||||
|
||||
def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
|
||||
"""Query index for top k most similar nodes.
|
||||
Args:
|
||||
query_embedding (List[float]): query embedding
|
||||
similarity_top_k (int): top k most similar nodes
|
||||
"""
|
||||
if query.filters is not None:
|
||||
if "where" in kwargs:
|
||||
raise ValueError(
|
||||
"Cannot specify metadata filters via both query and kwargs. "
|
||||
"Use kwargs only for chroma specific items that are "
|
||||
"not supported via the generic query interface."
|
||||
)
|
||||
where = _to_chroma_filter(query.filters)
|
||||
else:
|
||||
where = kwargs.pop("where", {})
|
||||
results = self._collection.query(
|
||||
query_embeddings=query.query_embedding,
|
||||
n_results=query.similarity_top_k,
|
||||
where=where,
|
||||
**kwargs,
|
||||
)
|
||||
logger.debug(f"> Top {len(results['documents'])} nodes:")
|
||||
nodes = []
|
||||
similarities = []
|
||||
ids = []
|
||||
for node_id, text, metadata, distance in zip(
|
||||
results["ids"][0],
|
||||
results["documents"][0],
|
||||
results["metadatas"][0],
|
||||
results["distances"][0],
|
||||
):
|
||||
try:
|
||||
node = metadata_dict_to_node(metadata)
|
||||
node.set_content(text)
|
||||
except Exception:
|
||||
# NOTE: deprecated legacy logic for backward compatibility
|
||||
metadata, node_info, relationships = legacy_metadata_dict_to_node(metadata)
|
||||
node = TextNode(
|
||||
text=text,
|
||||
id_=node_id,
|
||||
metadata=metadata,
|
||||
start_char_idx=node_info.get("start", None),
|
||||
end_char_idx=node_info.get("end", None),
|
||||
relationships=relationships,
|
||||
)
|
||||
nodes.append(node)
|
||||
similarity_score = math.exp(-distance)
|
||||
similarities.append(similarity_score)
|
||||
logger.debug(
|
||||
f"> [Node {node_id}] [Similarity score: {similarity_score}] " f"{truncate_text(str(text), 100)}"
|
||||
)
|
||||
ids.append(node_id)
|
||||
return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
aiohttp==3.8.6
|
||||
#azure_storage==0.37.0
|
||||
channels==4.0.0
|
||||
# chromadb
|
||||
# Django==4.1.5
|
||||
# docx==0.2.4
|
||||
#faiss==1.5.3
|
||||
|
|
@ -12,6 +11,8 @@ typer==0.9.0
|
|||
# google_api_python_client==2.93.0 # Used by search_engine.py
|
||||
lancedb==0.4.0
|
||||
llama-index-core==0.10.12
|
||||
llama-index-embeddings-azure-openai==0.1.6
|
||||
llama-index-embeddings-huggingface==0.1.3
|
||||
llama-index-embeddings-openai==0.1.5
|
||||
llama-index-llms-azure-openai==0.1.4
|
||||
llama-index-llms-gemini==0.1.4
|
||||
|
|
@ -20,6 +21,7 @@ llama-index-llms-openai==0.1.5
|
|||
llama-index-readers-file==0.1.4
|
||||
llama-index-retrievers-bm25==0.1.3
|
||||
llama-index-vector-stores-faiss==0.1.1
|
||||
chromadb==0.4.23
|
||||
loguru==0.6.0
|
||||
meilisearch==0.21.0
|
||||
numpy==1.24.3
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -42,7 +42,7 @@ extras_require["test"] = [
|
|||
"connexion[uvicorn]~=3.0.5",
|
||||
"azure-cognitiveservices-speech~=1.31.0",
|
||||
"aioboto3~=11.3.0",
|
||||
"chromadb==0.4.14",
|
||||
"chromadb==0.4.23",
|
||||
"gradio==3.0.0",
|
||||
"grpcio-status==1.48.2",
|
||||
"pylint==3.0.3",
|
||||
|
|
|
|||
|
|
@ -18,9 +18,7 @@ class TestRankerFactory:
|
|||
def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker):
|
||||
mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm)
|
||||
default_rankers = ranker_factory.get_rankers()
|
||||
assert len(default_rankers) == 1
|
||||
assert isinstance(default_rankers[0], LLMRerank)
|
||||
ranker_factory._extract_llm.assert_called_once()
|
||||
assert len(default_rankers) == 0
|
||||
|
||||
def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue