mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-28 02:23:52 +02:00
upgrade llama-index-vector-stores-chroma and rag test coverage 100%
This commit is contained in:
parent
6450a09d3b
commit
6cb5492f02
21 changed files with 600 additions and 382 deletions
|
|
@ -8,6 +8,7 @@ import re
|
|||
import time
|
||||
from typing import Any, Iterable
|
||||
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from metagpt.config2 import config as CONFIG
|
||||
|
|
@ -15,7 +16,6 @@ from metagpt.environment.base_env import Environment
|
|||
from metagpt.environment.minecraft_env.const import MC_CKPT_DIR
|
||||
from metagpt.environment.minecraft_env.minecraft_ext_env import MinecraftExtEnv
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_s
|
|||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
|
|
@ -17,7 +18,6 @@ from metagpt.rag.schema import (
|
|||
ElasticsearchKeywordIndexConfig,
|
||||
FAISSIndexConfig,
|
||||
)
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
class RAGIndexFactory(ConfigBasedFactory):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import chromadb
|
|||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
|
|
@ -25,7 +26,6 @@ from metagpt.rag.schema import (
|
|||
FAISSRetrieverConfig,
|
||||
IndexRetrieverConfig,
|
||||
)
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
class RetrieverFactory(ConfigBasedFactory):
|
||||
|
|
|
|||
|
|
@ -1,3 +0,0 @@
|
|||
from metagpt.rag.vector_stores.chroma.base import ChromaVectorStore
|
||||
|
||||
__all__ = ["ChromaVectorStore"]
|
||||
|
|
@ -1,290 +0,0 @@
|
|||
"""Chroma vector store.
|
||||
|
||||
Refs to https://github.com/run-llama/llama_index/blob/v0.10.12/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py.
|
||||
The repo requires onnxruntime = "^1.17.0", which is too new for many OS systems, such as CentOS7.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Generator, List, Optional, cast
|
||||
|
||||
import chromadb
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from llama_index.core.bridge.pydantic import Field, PrivateAttr
|
||||
from llama_index.core.schema import BaseNode, MetadataMode, TextNode
|
||||
from llama_index.core.utils import truncate_text
|
||||
from llama_index.core.vector_stores.types import (
|
||||
BasePydanticVectorStore,
|
||||
MetadataFilters,
|
||||
VectorStoreQuery,
|
||||
VectorStoreQueryResult,
|
||||
)
|
||||
from llama_index.core.vector_stores.utils import (
|
||||
legacy_metadata_dict_to_node,
|
||||
metadata_dict_to_node,
|
||||
node_to_metadata_dict,
|
||||
)
|
||||
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def _transform_chroma_filter_condition(condition: str) -> str:
|
||||
"""Translate standard metadata filter op to Chroma specific spec."""
|
||||
if condition == "and":
|
||||
return "$and"
|
||||
elif condition == "or":
|
||||
return "$or"
|
||||
else:
|
||||
raise ValueError(f"Filter condition {condition} not supported")
|
||||
|
||||
|
||||
def _transform_chroma_filter_operator(operator: str) -> str:
|
||||
"""Translate standard metadata filter operator to Chroma specific spec."""
|
||||
if operator == "!=":
|
||||
return "$ne"
|
||||
elif operator == "==":
|
||||
return "$eq"
|
||||
elif operator == ">":
|
||||
return "$gt"
|
||||
elif operator == "<":
|
||||
return "$lt"
|
||||
elif operator == ">=":
|
||||
return "$gte"
|
||||
elif operator == "<=":
|
||||
return "$lte"
|
||||
else:
|
||||
raise ValueError(f"Filter operator {operator} not supported")
|
||||
|
||||
|
||||
def _to_chroma_filter(
|
||||
standard_filters: MetadataFilters,
|
||||
) -> dict:
|
||||
"""Translate standard metadata filters to Chroma specific spec."""
|
||||
filters = {}
|
||||
filters_list = []
|
||||
condition = standard_filters.condition or "and"
|
||||
condition = _transform_chroma_filter_condition(condition)
|
||||
if standard_filters.filters:
|
||||
for filter in standard_filters.filters:
|
||||
if filter.operator:
|
||||
filters_list.append({filter.key: {_transform_chroma_filter_operator(filter.operator): filter.value}})
|
||||
else:
|
||||
filters_list.append({filter.key: filter.value})
|
||||
if len(filters_list) == 1:
|
||||
# If there is only one filter, return it directly
|
||||
return filters_list[0]
|
||||
elif len(filters_list) > 1:
|
||||
filters[condition] = filters_list
|
||||
return filters
|
||||
|
||||
|
||||
import_err_msg = "`chromadb` package not found, please run `pip install chromadb`"
|
||||
MAX_CHUNK_SIZE = 41665 # One less than the max chunk size for ChromaDB
|
||||
|
||||
|
||||
def chunk_list(lst: List[BaseNode], max_chunk_size: int) -> Generator[List[BaseNode], None, None]:
|
||||
"""Yield successive max_chunk_size-sized chunks from lst.
|
||||
Args:
|
||||
lst (List[BaseNode]): list of nodes with embeddings
|
||||
max_chunk_size (int): max chunk size
|
||||
Yields:
|
||||
Generator[List[BaseNode], None, None]: list of nodes with embeddings
|
||||
"""
|
||||
for i in range(0, len(lst), max_chunk_size):
|
||||
yield lst[i : i + max_chunk_size]
|
||||
|
||||
|
||||
class ChromaVectorStore(BasePydanticVectorStore):
|
||||
"""Chroma vector store.
|
||||
In this vector store, embeddings are stored within a ChromaDB collection.
|
||||
During query time, the index uses ChromaDB to query for the top
|
||||
k most similar nodes.
|
||||
Args:
|
||||
chroma_collection (chromadb.api.models.Collection.Collection):
|
||||
ChromaDB collection instance
|
||||
"""
|
||||
|
||||
stores_text: bool = True
|
||||
flat_metadata: bool = True
|
||||
collection_name: Optional[str]
|
||||
host: Optional[str]
|
||||
port: Optional[str]
|
||||
ssl: bool
|
||||
headers: Optional[Dict[str, str]]
|
||||
persist_dir: Optional[str]
|
||||
collection_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
_collection: Any = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chroma_collection: Optional[Any] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
ssl: bool = False,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
persist_dir: Optional[str] = None,
|
||||
collection_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
collection_kwargs = collection_kwargs or {}
|
||||
if chroma_collection is None:
|
||||
client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
|
||||
self._collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
|
||||
else:
|
||||
self._collection = cast(Collection, chroma_collection)
|
||||
super().__init__(
|
||||
host=host,
|
||||
port=port,
|
||||
ssl=ssl,
|
||||
headers=headers,
|
||||
collection_name=collection_name,
|
||||
persist_dir=persist_dir,
|
||||
collection_kwargs=collection_kwargs or {},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_collection(cls, collection: Any) -> "ChromaVectorStore":
|
||||
try:
|
||||
from chromadb import Collection
|
||||
except ImportError:
|
||||
raise ImportError(import_err_msg)
|
||||
if not isinstance(collection, Collection):
|
||||
raise Exception("argument is not chromadb collection instance")
|
||||
return cls(chroma_collection=collection)
|
||||
|
||||
@classmethod
|
||||
def from_params(
|
||||
cls,
|
||||
collection_name: str,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
ssl: bool = False,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
persist_dir: Optional[str] = None,
|
||||
collection_kwargs: dict = {},
|
||||
**kwargs: Any,
|
||||
) -> "ChromaVectorStore":
|
||||
if persist_dir:
|
||||
client = chromadb.PersistentClient(path=persist_dir)
|
||||
collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
|
||||
elif host and port:
|
||||
client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
|
||||
collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
|
||||
else:
|
||||
raise ValueError("Either `persist_dir` or (`host`,`port`) must be specified")
|
||||
return cls(
|
||||
chroma_collection=collection,
|
||||
host=host,
|
||||
port=port,
|
||||
ssl=ssl,
|
||||
headers=headers,
|
||||
persist_dir=persist_dir,
|
||||
collection_kwargs=collection_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "ChromaVectorStore"
|
||||
|
||||
def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
|
||||
"""Add nodes to index.
|
||||
Args:
|
||||
nodes: List[BaseNode]: list of nodes with embeddings
|
||||
"""
|
||||
if not self._collection:
|
||||
raise ValueError("Collection not initialized")
|
||||
max_chunk_size = MAX_CHUNK_SIZE
|
||||
node_chunks = chunk_list(nodes, max_chunk_size)
|
||||
all_ids = []
|
||||
for node_chunk in node_chunks:
|
||||
embeddings = []
|
||||
metadatas = []
|
||||
ids = []
|
||||
documents = []
|
||||
for node in node_chunk:
|
||||
embeddings.append(node.get_embedding())
|
||||
metadata_dict = node_to_metadata_dict(node, remove_text=True, flat_metadata=self.flat_metadata)
|
||||
for key in metadata_dict:
|
||||
if metadata_dict[key] is None:
|
||||
metadata_dict[key] = ""
|
||||
metadatas.append(metadata_dict)
|
||||
ids.append(node.node_id)
|
||||
documents.append(node.get_content(metadata_mode=MetadataMode.NONE))
|
||||
self._collection.add(
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
)
|
||||
all_ids.extend(ids)
|
||||
return all_ids
|
||||
|
||||
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
|
||||
"""
|
||||
Delete nodes using with ref_doc_id.
|
||||
Args:
|
||||
ref_doc_id (str): The doc_id of the document to delete.
|
||||
"""
|
||||
self._collection.delete(where={"document_id": ref_doc_id})
|
||||
|
||||
@property
|
||||
def client(self) -> Any:
|
||||
"""Return client."""
|
||||
return self._collection
|
||||
|
||||
def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
|
||||
"""Query index for top k most similar nodes.
|
||||
Args:
|
||||
query_embedding (List[float]): query embedding
|
||||
similarity_top_k (int): top k most similar nodes
|
||||
"""
|
||||
if query.filters is not None:
|
||||
if "where" in kwargs:
|
||||
raise ValueError(
|
||||
"Cannot specify metadata filters via both query and kwargs. "
|
||||
"Use kwargs only for chroma specific items that are "
|
||||
"not supported via the generic query interface."
|
||||
)
|
||||
where = _to_chroma_filter(query.filters)
|
||||
else:
|
||||
where = kwargs.pop("where", {})
|
||||
results = self._collection.query(
|
||||
query_embeddings=query.query_embedding,
|
||||
n_results=query.similarity_top_k,
|
||||
where=where,
|
||||
**kwargs,
|
||||
)
|
||||
logger.debug(f"> Top {len(results['documents'])} nodes:")
|
||||
nodes = []
|
||||
similarities = []
|
||||
ids = []
|
||||
for node_id, text, metadata, distance in zip(
|
||||
results["ids"][0],
|
||||
results["documents"][0],
|
||||
results["metadatas"][0],
|
||||
results["distances"][0],
|
||||
):
|
||||
try:
|
||||
node = metadata_dict_to_node(metadata)
|
||||
node.set_content(text)
|
||||
except Exception:
|
||||
# NOTE: deprecated legacy logic for backward compatibility
|
||||
metadata, node_info, relationships = legacy_metadata_dict_to_node(metadata)
|
||||
node = TextNode(
|
||||
text=text,
|
||||
id_=node_id,
|
||||
metadata=metadata,
|
||||
start_char_idx=node_info.get("start", None),
|
||||
end_char_idx=node_info.get("end", None),
|
||||
relationships=relationships,
|
||||
)
|
||||
nodes.append(node)
|
||||
similarity_score = math.exp(-distance)
|
||||
similarities.append(similarity_score)
|
||||
logger.debug(
|
||||
f"> [Node {node_id}] [Similarity score: {similarity_score}] " f"{truncate_text(str(text), 100)}"
|
||||
)
|
||||
ids.append(node_id)
|
||||
return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
|
||||
Loading…
Add table
Add a link
Reference in a new issue