From 835d6987b9cee338e753c73bb0e5b9ea7f71c3dc Mon Sep 17 00:00:00 2001 From: better629 Date: Wed, 10 Apr 2024 19:34:17 +0800 Subject: [PATCH] rag chroma db add metadata --- metagpt/rag/factories/index.py | 2 +- metagpt/rag/factories/retriever.py | 2 +- metagpt/rag/schema.py | 9 ++++++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index a56471359..f897af3ad 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -48,7 +48,7 @@ class RAGIndexFactory(ConfigBasedFactory): def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex: db = chromadb.PersistentClient(str(config.persist_path)) - chroma_collection = db.get_or_create_collection(config.collection_name) + chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata) vector_store = ChromaVectorStore(chroma_collection=chroma_collection) return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs) diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index 65729002e..68f2c2313 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -69,7 +69,7 @@ class RetrieverFactory(ConfigBasedFactory): def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever: db = chromadb.PersistentClient(path=str(config.persist_path)) - chroma_collection = db.get_or_create_collection(config.collection_name) + chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata) vector_store = ChromaVectorStore(chroma_collection=chroma_collection) config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 183f6e0c7..581815321 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -1,8 +1,9 @@ """RAG schemas.""" from pathlib import Path -from typing import Any, Literal, Union +from typing import Any, Literal, Optional, Union +from chromadb.api.types import CollectionMetadata from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.core.schema import TextNode @@ -45,6 +46,9 @@ class ChromaRetrieverConfig(IndexRetrieverConfig): persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.") collection_name: str = Field(default="metagpt", description="The name of the collection.") + metadata: Optional[CollectionMetadata] = Field( + default=None, description="Optional metadata to associate with the collection" + ) class ElasticsearchStoreConfig(BaseModel): @@ -130,6 +134,9 @@ class ChromaIndexConfig(VectorIndexConfig): """Config for chroma-based index.""" collection_name: str = Field(default="metagpt", description="The name of the collection.") + metadata: Optional[CollectionMetadata] = Field( + default=None, description="Optional metadata to associate with the collection" + ) class BM25IndexConfig(BaseIndexConfig):