diff --git a/.gitattributes b/.gitattributes index 867a5ad7b..865da2ca2 100644 --- a/.gitattributes +++ b/.gitattributes @@ -15,6 +15,7 @@ *.jpeg binary *.mp3 binary *.zip binary +*.bin binary # Preserve original line endings for specific document files diff --git a/.gitignore b/.gitignore index 468b631ae..1e5ee4374 100644 --- a/.gitignore +++ b/.gitignore @@ -174,6 +174,7 @@ tmp.png .dependencies.json tests/metagpt/utils/file_repo_git tests/data/rsp_cache.json +tests/data/rsp_cache_new.json *.tmp *.png htmlcov @@ -183,3 +184,4 @@ htmlcov.* *-structure.csv *-structure.json *.dot +.python-version diff --git a/examples/data/rag/travel.txt b/examples/data/rag/travel.txt index 1c738c54a..f72ad5c59 100644 --- a/examples/data/rag/travel.txt +++ b/examples/data/rag/travel.txt @@ -1 +1 @@ -Bojan likes traveling. \ No newline at end of file +Bob likes traveling. \ No newline at end of file diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 6e8e5a2cc..64a83e77c 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -15,7 +15,7 @@ DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt" QUESTION = "What are key qualities to be a good writer?" TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt" -TRAVEL_QUESTION = "What does Bojan like?" +TRAVEL_QUESTION = "What does Bob like?" LLM_TIP = "If you not sure, just answer I don't know" @@ -61,10 +61,10 @@ class RAGExample: [After add docs] Retrieve Result: - 0. Bojan like..., 10.0 + 0. Bob like..., 10.0 Query Result: - Bojan likes traveling. + Bob likes traveling. 
""" self._print_title("RAG Add Docs") diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 5f81f6309..895b7bd1e 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -26,11 +26,16 @@ from llama_index.core.schema import ( TransformComponent, ) -from metagpt.rag.factories import get_rag_llm, get_rankers, get_retriever +from metagpt.rag.factories import ( + get_index, + get_rag_embedding, + get_rag_llm, + get_rankers, + get_retriever, +) from metagpt.rag.interface import RAGObject from metagpt.rag.retrievers.base import ModifiableRAGRetriever -from metagpt.rag.schema import BaseRankerConfig, BaseRetrieverConfig -from metagpt.utils.embedding import get_embedding +from metagpt.rag.schema import BaseIndexConfig, BaseRankerConfig, BaseRetrieverConfig class SimpleEngine(RetrieverQueryEngine): @@ -83,8 +88,31 @@ class SimpleEngine(RetrieverQueryEngine): index = VectorStoreIndex.from_documents( documents=documents, transformations=transformations or [SentenceSplitter()], - embed_model=embed_model or get_embedding(), + embed_model=embed_model or get_rag_embedding(), ) + return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) + + @classmethod + def from_index( + cls, + index_config: BaseIndexConfig, + embed_model: BaseEmbedding = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, + ): + """Load from previously maintained""" + index = get_index(index_config, embed_model=embed_model or get_rag_embedding()) + return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) + + @classmethod + def _from_index( + cls, + index: BaseIndex, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, + ): llm = llm or get_rag_llm() retriever = get_retriever(configs=retriever_configs, index=index) 
rankers = get_rankers(configs=ranker_configs, llm=llm) diff --git a/metagpt/rag/factories/__init__.py b/metagpt/rag/factories/__init__.py index 74290fd69..df2d38502 100644 --- a/metagpt/rag/factories/__init__.py +++ b/metagpt/rag/factories/__init__.py @@ -2,5 +2,7 @@ from metagpt.rag.factories.retriever import get_retriever from metagpt.rag.factories.ranker import get_rankers from metagpt.rag.factories.llm import get_rag_llm +from metagpt.rag.factories.embedding import get_rag_embedding +from metagpt.rag.factories.index import get_index -__all__ = ["get_retriever", "get_rankers", "get_rag_llm"] +__all__ = ["get_retriever", "get_rankers", "get_rag_llm", "get_rag_embedding", "get_index"] diff --git a/metagpt/rag/factories/embedding.py b/metagpt/rag/factories/embedding.py new file mode 100644 index 000000000..67c2f3d06 --- /dev/null +++ b/metagpt/rag/factories/embedding.py @@ -0,0 +1,39 @@ +"""RAG Embedding Factory. + +The embedding of LlamaIndex and the embedding of MG are not the same. +""" +from llama_index.core.embeddings import BaseEmbedding +from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding +from llama_index.embeddings.openai import OpenAIEmbedding + +from metagpt.config2 import config +from metagpt.configs.llm_config import LLMType +from metagpt.rag.factories.base import GenericFactory + + +class RAGEmbeddingFactory(GenericFactory): + """Create LlamaIndex embedding with MG config.""" + + def __init__(self): + creators = { + LLMType.OPENAI: self._create_openai, + LLMType.AZURE: self._create_azure, + } + super().__init__(creators) + + def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding: + """Key is LLMType, default use config.llm.api_type.""" + return super().get_instance(key or config.llm.api_type) + + def _create_openai(self): + return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url) + + def _create_azure(self): + return AzureOpenAIEmbedding( + azure_endpoint=config.llm.base_url, + api_key=config.llm.api_key, 
api_version=config.llm.api_version, + ) + + +get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py new file mode 100644 index 000000000..d1008081c --- /dev/null +++ b/metagpt/rag/factories/index.py @@ -0,0 +1,51 @@ +"""RAG Index Factory.""" +import chromadb +from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage +from llama_index.core.embeddings import BaseEmbedding +from llama_index.core.indices.base import BaseIndex +from llama_index.vector_stores.faiss import FaissVectorStore + +from metagpt.rag.factories.base import ConfigFactory +from metagpt.rag.schema import BaseIndexConfig, ChromaIndexConfig, FAISSIndexConfig +from metagpt.rag.vector_stores.chroma import ChromaVectorStore + + +class RAGIndexFactory(ConfigFactory): + def __init__(self): + creators = { + FAISSIndexConfig: self._create_faiss, + ChromaIndexConfig: self._create_chroma, + } + super().__init__(creators) + + def get_index(self, config: BaseIndexConfig, **kwargs) -> BaseIndex: + """Key is PersistType.""" + return super().get_instance(config, **kwargs) + + def extract_embed_model(self, config, **kwargs) -> BaseEmbedding: + return self._val_from_config_or_kwargs("embed_model", config, **kwargs) + + def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex: + embed_model = self.extract_embed_model(config, **kwargs) + + vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path)) + storage_context = StorageContext.from_defaults( + vector_store=vector_store, persist_dir=config.persist_path + ) + index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model) + return index + + def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex: + embed_model = self.extract_embed_model(config, **kwargs) + + db2 = chromadb.PersistentClient(str(config.persist_path)) + chroma_collection = 
db2.get_or_create_collection(config.collection_name) + vector_store = ChromaVectorStore(chroma_collection=chroma_collection) + index = VectorStoreIndex.from_vector_store( + vector_store, + embed_model=embed_model, + ) + return index + + +get_index = RAGIndexFactory().get_index diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py index 70d66dd37..c5d12079e 100644 --- a/metagpt/rag/factories/llm.py +++ b/metagpt/rag/factories/llm.py @@ -44,7 +44,7 @@ class RAGLLMFactory(GenericFactory): azure_endpoint=config.llm.base_url, api_key=config.llm.api_key, api_version=config.llm.api_version, - model=config.llm.model, + deployment_name=config.llm.model, max_tokens=config.llm.max_token, temperature=config.llm.temperature, ) diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index f74e30834..0867c7945 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -20,14 +20,10 @@ class RankerFactory(ConfigFactory): def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]: """Creates and returns a retriever instance based on the provided configurations.""" if not configs: - return self._create_default(**kwargs) + return [] return super().get_instances(configs, **kwargs) - def _create_default(self, **kwargs) -> list[LLMRerank]: - config = LLMRankerConfig(llm=self._extract_llm(**kwargs)) - return [LLMRerank(**config.model_dump())] - def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM: return self._val_from_config_or_kwargs("llm", config, **kwargs) diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index 44678fc92..d9ec6b12d 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -1,19 +1,25 @@ """RAG Retriever Factory.""" +import chromadb import faiss from llama_index.core import StorageContext, VectorStoreIndex +from llama_index.core.vector_stores.types import 
BasePydanticVectorStore from llama_index.vector_stores.faiss import FaissVectorStore from metagpt.rag.factories.base import ConfigFactory from metagpt.rag.retrievers.base import RAGRetriever from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever +from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever from metagpt.rag.schema import ( BaseRetrieverConfig, BM25RetrieverConfig, + ChromaRetrieverConfig, FAISSRetrieverConfig, + IndexRetrieverConfig, ) +from metagpt.rag.vector_stores.chroma import ChromaVectorStore class RetrieverFactory(ConfigFactory): @@ -23,6 +29,7 @@ class RetrieverFactory(ConfigFactory): creators = { FAISSRetrieverConfig: self._create_faiss_retriever, BM25RetrieverConfig: self._create_bm25_retriever, + ChromaRetrieverConfig: self._create_chroma_retriever, } super().__init__(creators) @@ -44,8 +51,9 @@ class RetrieverFactory(ConfigFactory): def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex: return self._val_from_config_or_kwargs("index", config, **kwargs) - def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever: - vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) + def _build_index_from_vector_store( + self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs + ) -> VectorStoreIndex: storage_context = StorageContext.from_defaults(vector_store=vector_store) old_index = self._extract_index(config, **kwargs) new_index = VectorStoreIndex( @@ -53,12 +61,23 @@ class RetrieverFactory(ConfigFactory): storage_context=storage_context, embed_model=old_index._embed_model, ) - config.index = new_index + return new_index + + def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever: + vector_store = 
FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) + config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) return FAISSRetriever(**config.model_dump()) def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever: config.index = self._extract_index(config, **kwargs) return DynamicBM25Retriever.from_defaults(**config.model_dump()) + def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever: + db = chromadb.PersistentClient(path=str(config.persist_path)) + chroma_collection = db.get_or_create_collection(config.collection_name) + vector_store = ChromaVectorStore(chroma_collection=chroma_collection) + config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + return ChromaRetriever(**config.model_dump()) + get_retriever = RetrieverFactory().get_retriever diff --git a/metagpt/rag/retrievers/chroma_retriever.py b/metagpt/rag/retrievers/chroma_retriever.py new file mode 100644 index 000000000..035969421 --- /dev/null +++ b/metagpt/rag/retrievers/chroma_retriever.py @@ -0,0 +1,11 @@ +"""Chroma retriever.""" +from llama_index.core.retrievers import VectorIndexRetriever +from llama_index.core.schema import BaseNode + + +class ChromaRetriever(VectorIndexRetriever): + """Chroma retriever.""" + + def add_nodes(self, nodes: list[BaseNode], **kwargs): + """Support add nodes.""" + self._index.insert_nodes(nodes, **kwargs) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index c74846cb6..35e16e286 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -1,7 +1,9 @@ """RAG schemas.""" -from typing import Any +from pathlib import Path +from typing import Any, Union +from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from pydantic import BaseModel, ConfigDict, Field @@ -9,7 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field class BaseRetrieverConfig(BaseModel): 
"""Common config for retrievers. - If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factory. + If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.retriever. """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -19,7 +21,7 @@ class BaseRetrieverConfig(BaseModel): class IndexRetrieverConfig(BaseRetrieverConfig): """Config for Index-basd retrievers.""" - index: BaseIndex = Field(default=None, description="Index for retriver") + index: BaseIndex = Field(default=None, description="Index for retriver.") class FAISSRetrieverConfig(IndexRetrieverConfig): @@ -32,10 +34,17 @@ class BM25RetrieverConfig(IndexRetrieverConfig): """Config for BM25-based retrievers.""" +class ChromaRetrieverConfig(IndexRetrieverConfig): + """Config for Chroma-based retrievers.""" + + persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.") + collection_name: str = Field(default="metagpt", description="The name of the collection.") + + class BaseRankerConfig(BaseModel): """Common config for rankers. - If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factory. + If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.ranker. """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -48,5 +57,30 @@ class LLMRankerConfig(BaseRankerConfig): llm: Any = Field( default=None, - description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1", + description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1.", ) + + +class BaseIndexConfig(BaseModel): + """Common config for index. + + If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index. 
+ """ + + persist_path: Union[str, Path] = Field(description="The directory of saved data.") + + +class VectorIndexConfig(BaseIndexConfig): + """Config for vector-based index.""" + + embed_model: BaseEmbedding = Field(default=None, description="Embed model.") + + +class FAISSIndexConfig(VectorIndexConfig): + """Config for faiss-based index.""" + + +class ChromaIndexConfig(VectorIndexConfig): + """Config for chroma-based index.""" + + collection_name: str = Field(default="metagpt", description="The name of the collection.") diff --git a/metagpt/rag/vector_stores/__init__.py b/metagpt/rag/vector_stores/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/metagpt/rag/vector_stores/chroma/__init__.py b/metagpt/rag/vector_stores/chroma/__init__.py new file mode 100644 index 000000000..87ba4d8a7 --- /dev/null +++ b/metagpt/rag/vector_stores/chroma/__init__.py @@ -0,0 +1,3 @@ +from metagpt.rag.vector_stores.chroma.base import ChromaVectorStore + +__all__ = ["ChromaVectorStore"] diff --git a/metagpt/rag/vector_stores/chroma/base.py b/metagpt/rag/vector_stores/chroma/base.py new file mode 100644 index 000000000..94728f23b --- /dev/null +++ b/metagpt/rag/vector_stores/chroma/base.py @@ -0,0 +1,290 @@ +"""Chroma vector store. + +Refs to https://github.com/run-llama/llama_index/blob/v0.10.12/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py. +The repo requires onnxruntime = "^1.17.0", which is too new for many OS systems, such as CentOS7. 
+""" +import logging +import math +from typing import Any, Dict, Generator, List, Optional, cast + +import chromadb +from chromadb.api.models.Collection import Collection +from llama_index.core.bridge.pydantic import Field, PrivateAttr +from llama_index.core.schema import BaseNode, MetadataMode, TextNode +from llama_index.core.utils import truncate_text +from llama_index.core.vector_stores.types import ( + BasePydanticVectorStore, + MetadataFilters, + VectorStoreQuery, + VectorStoreQueryResult, +) +from llama_index.core.vector_stores.utils import ( + legacy_metadata_dict_to_node, + metadata_dict_to_node, + node_to_metadata_dict, +) + +logger = logging.getLogger(__name__) + + +def _transform_chroma_filter_condition(condition: str) -> str: + """Translate standard metadata filter op to Chroma specific spec.""" + if condition == "and": + return "$and" + elif condition == "or": + return "$or" + else: + raise ValueError(f"Filter condition {condition} not supported") + + +def _transform_chroma_filter_operator(operator: str) -> str: + """Translate standard metadata filter operator to Chroma specific spec.""" + if operator == "!=": + return "$ne" + elif operator == "==": + return "$eq" + elif operator == ">": + return "$gt" + elif operator == "<": + return "$lt" + elif operator == ">=": + return "$gte" + elif operator == "<=": + return "$lte" + else: + raise ValueError(f"Filter operator {operator} not supported") + + +def _to_chroma_filter( + standard_filters: MetadataFilters, +) -> dict: + """Translate standard metadata filters to Chroma specific spec.""" + filters = {} + filters_list = [] + condition = standard_filters.condition or "and" + condition = _transform_chroma_filter_condition(condition) + if standard_filters.filters: + for filter in standard_filters.filters: + if filter.operator: + filters_list.append({filter.key: {_transform_chroma_filter_operator(filter.operator): filter.value}}) + else: + filters_list.append({filter.key: filter.value}) + if len(filters_list) 
== 1: + # If there is only one filter, return it directly + return filters_list[0] + elif len(filters_list) > 1: + filters[condition] = filters_list + return filters + + +import_err_msg = "`chromadb` package not found, please run `pip install chromadb`" +MAX_CHUNK_SIZE = 41665 # One less than the max chunk size for ChromaDB + + +def chunk_list(lst: List[BaseNode], max_chunk_size: int) -> Generator[List[BaseNode], None, None]: + """Yield successive max_chunk_size-sized chunks from lst. + Args: + lst (List[BaseNode]): list of nodes with embeddings + max_chunk_size (int): max chunk size + Yields: + Generator[List[BaseNode], None, None]: list of nodes with embeddings + """ + for i in range(0, len(lst), max_chunk_size): + yield lst[i : i + max_chunk_size] + + +class ChromaVectorStore(BasePydanticVectorStore): + """Chroma vector store. + In this vector store, embeddings are stored within a ChromaDB collection. + During query time, the index uses ChromaDB to query for the top + k most similar nodes. 
+ Args: + chroma_collection (chromadb.api.models.Collection.Collection): + ChromaDB collection instance + """ + + stores_text: bool = True + flat_metadata: bool = True + collection_name: Optional[str] + host: Optional[str] + port: Optional[str] + ssl: bool + headers: Optional[Dict[str, str]] + persist_dir: Optional[str] + collection_kwargs: Dict[str, Any] = Field(default_factory=dict) + _collection: Any = PrivateAttr() + + def __init__( + self, + chroma_collection: Optional[Any] = None, + collection_name: Optional[str] = None, + host: Optional[str] = None, + port: Optional[str] = None, + ssl: bool = False, + headers: Optional[Dict[str, str]] = None, + persist_dir: Optional[str] = None, + collection_kwargs: Optional[dict] = None, + **kwargs: Any, + ) -> None: + """Init params.""" + collection_kwargs = collection_kwargs or {} + if chroma_collection is None: + client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers) + self._collection = client.get_or_create_collection(name=collection_name, **collection_kwargs) + else: + self._collection = cast(Collection, chroma_collection) + super().__init__( + host=host, + port=port, + ssl=ssl, + headers=headers, + collection_name=collection_name, + persist_dir=persist_dir, + collection_kwargs=collection_kwargs or {}, + ) + + @classmethod + def from_collection(cls, collection: Any) -> "ChromaVectorStore": + try: + from chromadb import Collection + except ImportError: + raise ImportError(import_err_msg) + if not isinstance(collection, Collection): + raise Exception("argument is not chromadb collection instance") + return cls(chroma_collection=collection) + + @classmethod + def from_params( + cls, + collection_name: str, + host: Optional[str] = None, + port: Optional[str] = None, + ssl: bool = False, + headers: Optional[Dict[str, str]] = None, + persist_dir: Optional[str] = None, + collection_kwargs: dict = {}, + **kwargs: Any, + ) -> "ChromaVectorStore": + if persist_dir: + client = 
chromadb.PersistentClient(path=persist_dir) + collection = client.get_or_create_collection(name=collection_name, **collection_kwargs) + elif host and port: + client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers) + collection = client.get_or_create_collection(name=collection_name, **collection_kwargs) + else: + raise ValueError("Either `persist_dir` or (`host`,`port`) must be specified") + return cls( + chroma_collection=collection, + host=host, + port=port, + ssl=ssl, + headers=headers, + persist_dir=persist_dir, + collection_kwargs=collection_kwargs, + **kwargs, + ) + + @classmethod + def class_name(cls) -> str: + return "ChromaVectorStore" + + def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]: + """Add nodes to index. + Args: + nodes: List[BaseNode]: list of nodes with embeddings + """ + if not self._collection: + raise ValueError("Collection not initialized") + max_chunk_size = MAX_CHUNK_SIZE + node_chunks = chunk_list(nodes, max_chunk_size) + all_ids = [] + for node_chunk in node_chunks: + embeddings = [] + metadatas = [] + ids = [] + documents = [] + for node in node_chunk: + embeddings.append(node.get_embedding()) + metadata_dict = node_to_metadata_dict(node, remove_text=True, flat_metadata=self.flat_metadata) + for key in metadata_dict: + if metadata_dict[key] is None: + metadata_dict[key] = "" + metadatas.append(metadata_dict) + ids.append(node.node_id) + documents.append(node.get_content(metadata_mode=MetadataMode.NONE)) + self._collection.add( + embeddings=embeddings, + ids=ids, + metadatas=metadatas, + documents=documents, + ) + all_ids.extend(ids) + return all_ids + + def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + """ + Delete nodes using with ref_doc_id. + Args: + ref_doc_id (str): The doc_id of the document to delete. 
+ """ + self._collection.delete(where={"document_id": ref_doc_id}) + + @property + def client(self) -> Any: + """Return client.""" + return self._collection + + def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: + """Query index for top k most similar nodes. + Args: + query_embedding (List[float]): query embedding + similarity_top_k (int): top k most similar nodes + """ + if query.filters is not None: + if "where" in kwargs: + raise ValueError( + "Cannot specify metadata filters via both query and kwargs. " + "Use kwargs only for chroma specific items that are " + "not supported via the generic query interface." + ) + where = _to_chroma_filter(query.filters) + else: + where = kwargs.pop("where", {}) + results = self._collection.query( + query_embeddings=query.query_embedding, + n_results=query.similarity_top_k, + where=where, + **kwargs, + ) + logger.debug(f"> Top {len(results['documents'])} nodes:") + nodes = [] + similarities = [] + ids = [] + for node_id, text, metadata, distance in zip( + results["ids"][0], + results["documents"][0], + results["metadatas"][0], + results["distances"][0], + ): + try: + node = metadata_dict_to_node(metadata) + node.set_content(text) + except Exception: + # NOTE: deprecated legacy logic for backward compatibility + metadata, node_info, relationships = legacy_metadata_dict_to_node(metadata) + node = TextNode( + text=text, + id_=node_id, + metadata=metadata, + start_char_idx=node_info.get("start", None), + end_char_idx=node_info.get("end", None), + relationships=relationships, + ) + nodes.append(node) + similarity_score = math.exp(-distance) + similarities.append(similarity_score) + logger.debug( + f"> [Node {node_id}] [Similarity score: {similarity_score}] " f"{truncate_text(str(text), 100)}" + ) + ids.append(node_id) + return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) diff --git a/requirements.txt b/requirements.txt index 366fd7545..cc3cf03ec 100644 --- a/requirements.txt 
+++ b/requirements.txt @@ -1,7 +1,6 @@ aiohttp==3.8.6 #azure_storage==0.37.0 channels==4.0.0 -# chromadb # Django==4.1.5 # docx==0.2.4 #faiss==1.5.3 @@ -12,6 +11,8 @@ typer==0.9.0 # google_api_python_client==2.93.0 # Used by search_engine.py lancedb==0.4.0 llama-index-core==0.10.12 +llama-index-embeddings-azure-openai==0.1.6 +llama-index-embeddings-huggingface==0.1.3 llama-index-embeddings-openai==0.1.5 llama-index-llms-azure-openai==0.1.4 llama-index-llms-gemini==0.1.4 @@ -20,6 +21,7 @@ llama-index-llms-openai==0.1.5 llama-index-readers-file==0.1.4 llama-index-retrievers-bm25==0.1.3 llama-index-vector-stores-faiss==0.1.1 +chromadb==0.4.23 loguru==0.6.0 meilisearch==0.21.0 numpy==1.24.3 @@ -35,7 +37,7 @@ python_docx==0.8.11 PyYAML==6.0.1 # sentence_transformers==2.2.2 setuptools==65.6.3 -tenacity==8.2.2 +tenacity==8.2.3 tiktoken==0.5.2 tqdm==4.66.2 #unstructured[local-inference] diff --git a/setup.py b/setup.py index 0439d6cd4..230fd19c7 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ extras_require["test"] = [ "connexion[uvicorn]~=3.0.5", "azure-cognitiveservices-speech~=1.31.0", "aioboto3~=11.3.0", - "chromadb==0.4.14", + "chromadb==0.4.23", "gradio==3.0.0", "grpcio-status==1.48.2", "mock==5.1.0", diff --git a/tests/metagpt/rag/factories/test_ranker.py b/tests/metagpt/rag/factories/test_ranker.py index d4b4167a6..563cffa73 100644 --- a/tests/metagpt/rag/factories/test_ranker.py +++ b/tests/metagpt/rag/factories/test_ranker.py @@ -18,9 +18,7 @@ class TestRankerFactory: def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker): mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm) default_rankers = ranker_factory.get_rankers() - assert len(default_rankers) == 1 - assert isinstance(default_rankers[0], LLMRerank) - ranker_factory._extract_llm.assert_called_once() + assert len(default_rankers) == 0 def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm): mock_config = 
LLMRankerConfig(llm=mock_llm)