fix redundant embedding

This commit is contained in:
seehi 2024-04-23 21:48:50 +08:00
parent e212e7b7b9
commit 263d79980e
6 changed files with 137 additions and 73 deletions

View file

@ -4,7 +4,7 @@ import json
import os
from typing import Any, Optional, Union
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core import SimpleDirectoryReader
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
@ -63,7 +63,7 @@ class SimpleEngine(RetrieverQueryEngine):
response_synthesizer: Optional[BaseSynthesizer] = None,
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
callback_manager: Optional[CallbackManager] = None,
index: Optional[BaseIndex] = None,
transformations: Optional[list[TransformComponent]] = None,
) -> None:
super().__init__(
retriever=retriever,
@ -71,7 +71,7 @@ class SimpleEngine(RetrieverQueryEngine):
node_postprocessors=node_postprocessors,
callback_manager=callback_manager,
)
self.index = index
self._transformations = transformations or self._default_transformations()
@classmethod
def from_docs(
@ -103,12 +103,17 @@ class SimpleEngine(RetrieverQueryEngine):
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
cls._fix_document_metadata(documents)
index = VectorStoreIndex.from_documents(
documents=documents,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
transformations = transformations or cls._default_transformations()
nodes = run_transformations(documents, transformations=transformations)
return cls._from_nodes(
nodes=nodes,
transformations=transformations,
embed_model=embed_model,
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@classmethod
def from_objs(
@ -137,12 +142,15 @@ class SimpleEngine(RetrieverQueryEngine):
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
index = VectorStoreIndex(
return cls._from_nodes(
nodes=nodes,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
transformations=transformations,
embed_model=embed_model,
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@classmethod
def from_index(
@ -183,7 +191,7 @@ class SimpleEngine(RetrieverQueryEngine):
documents = SimpleDirectoryReader(input_files=input_files).load_data()
self._fix_document_metadata(documents)
nodes = run_transformations(documents, transformations=self.index._transformations)
nodes = run_transformations(documents, transformations=self._transformations)
self._save_nodes(nodes)
def add_objs(self, objs: list[RAGObject]):
@ -199,6 +207,29 @@ class SimpleEngine(RetrieverQueryEngine):
self._persist(str(persist_dir), **kwargs)
@classmethod
def _from_nodes(
cls,
nodes: list[BaseNode],
transformations: Optional[list[TransformComponent]] = None,
embed_model: BaseEmbedding = None,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
embed_model = cls._resolve_embed_model(embed_model, retriever_configs)
llm = llm or get_rag_llm()
retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model)
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
return cls(
retriever=retriever,
node_postprocessors=rankers,
response_synthesizer=get_response_synthesizer(llm=llm),
transformations=transformations,
)
@classmethod
def _from_index(
cls,
@ -208,6 +239,7 @@ class SimpleEngine(RetrieverQueryEngine):
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
llm = llm or get_rag_llm()
retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
@ -215,7 +247,6 @@ class SimpleEngine(RetrieverQueryEngine):
retriever=retriever,
node_postprocessors=rankers,
response_synthesizer=get_response_synthesizer(llm=llm),
index=index,
)
def _ensure_retriever_modifiable(self):
@ -266,3 +297,7 @@ class SimpleEngine(RetrieverQueryEngine):
return MockEmbedding(embed_dim=1)
return embed_model or get_rag_embedding()
@staticmethod
def _default_transformations():
return [SentenceSplitter()]

View file

@ -36,19 +36,26 @@ class ConfigBasedFactory(GenericFactory):
"""Designed to get objects based on object type."""
def get_instance(self, key: Any, **kwargs) -> Any:
"""Key is config, such as a pydantic model.
"""Get instance by the type of key.
Call func by the type of key, and the key will be passed to func.
Key is config, such as a pydantic model, call func by the type of key, and the key will be passed to func.
Raise Exception if key not found.
"""
creator = self._creators.get(type(key))
if creator:
return creator(key, **kwargs)
self._raise_for_key(key)
def _raise_for_key(self, key: Any):
raise ValueError(f"Unknown config: `{type(key)}`, {key}")
@staticmethod
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.
Return None if not found.
"""
if config is not None and hasattr(config, key):
val = getattr(config, key)
if val is not None:
@ -57,6 +64,4 @@ class ConfigBasedFactory(GenericFactory):
if key in kwargs:
return kwargs[key]
raise KeyError(
f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
)
return None

View file

@ -1,10 +1,11 @@
"""RAG Retriever Factory."""
import copy
import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.schema import BaseNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
@ -24,7 +25,6 @@ from metagpt.rag.schema import (
ElasticsearchKeywordRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
IndexRetrieverConfig,
)
@ -54,48 +54,75 @@ class RetrieverFactory(ConfigBasedFactory):
return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]
def _create_default(self, **kwargs) -> RAGRetriever:
return self._extract_index(**kwargs).as_retriever()
index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs)
return index.as_retriever()
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._extract_index(config, **kwargs) or self._build_faiss_index(config, **kwargs)
return FAISSRetriever(**config.model_dump())
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
nodes = self._extract_nodes(config, **kwargs)
return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_chroma_index(config, **kwargs)
return ChromaRetriever(**config.model_dump())
def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
vector_store = ElasticsearchStore(**config.store_config.model_dump())
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_es_index(config, **kwargs)
return ElasticsearchRetriever(**config.model_dump())
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
return self._val_from_config_or_kwargs("index", config, **kwargs)
def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[BaseNode]:
return self._val_from_config_or_kwargs("nodes", config, **kwargs)
def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding:
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
def _build_default_index(self, **kwargs) -> VectorStoreIndex:
index = VectorStoreIndex(
nodes=self._extract_nodes(**kwargs),
embed_model=self._extract_embed_model(**kwargs),
)
return index
def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
return self._build_index_from_vector_store(config, vector_store, **kwargs)
def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
return self._build_index_from_vector_store(config, vector_store, **kwargs)
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = ElasticsearchStore(**config.store_config.model_dump())
return self._build_index_from_vector_store(config, vector_store, **kwargs)
def _build_index_from_vector_store(
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
) -> VectorStoreIndex:
storage_context = StorageContext.from_defaults(vector_store=vector_store)
old_index = self._extract_index(config, **kwargs)
new_index = VectorStoreIndex(
nodes=list(old_index.docstore.docs.values()),
index = VectorStoreIndex(
nodes=self._extract_nodes(config, **kwargs),
storage_context=storage_context,
embed_model=old_index._embed_model,
embed_model=self._extract_embed_model(config, **kwargs),
)
return new_index
return index
get_retriever = RetrieverFactory().get_retriever

View file

@ -25,10 +25,6 @@ class TestSimpleEngine:
def mock_simple_directory_reader(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
@pytest.fixture
def mock_vector_store_index(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
@pytest.fixture
def mock_get_retriever(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
@ -45,7 +41,6 @@ class TestSimpleEngine:
self,
mocker,
mock_simple_directory_reader,
mock_vector_store_index,
mock_get_retriever,
mock_get_rankers,
mock_get_response_synthesizer,
@ -81,11 +76,8 @@ class TestSimpleEngine:
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_vector_store_index.assert_called_once()
mock_get_retriever.assert_called_once_with(
configs=retriever_configs, index=mock_vector_store_index.return_value
)
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
mock_get_retriever.assert_called_once()
mock_get_rankers.assert_called_once()
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
assert isinstance(engine, SimpleEngine)
@ -119,7 +111,7 @@ class TestSimpleEngine:
# Assert
assert isinstance(engine, SimpleEngine)
assert engine.index is not None
assert engine._transformations is not None
def test_from_objs_with_bm25_config(self):
# Setup
@ -137,6 +129,7 @@ class TestSimpleEngine:
def test_from_index(self, mocker, mock_llm, mock_embedding):
# Mock
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_index.as_retriever.return_value = "retriever"
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
mock_get_index.return_value = mock_index
@ -149,7 +142,7 @@ class TestSimpleEngine:
# Assert
assert isinstance(engine, SimpleEngine)
assert engine.index is mock_index
assert engine._retriever == "retriever"
@pytest.mark.asyncio
async def test_asearch(self, mocker):
@ -200,14 +193,11 @@ class TestSimpleEngine:
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_index._transformations = mocker.MagicMock()
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
mock_run_transformations.return_value = ["node1", "node2"]
# Setup
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
engine = SimpleEngine(retriever=mock_retriever)
input_files = ["test_file1", "test_file2"]
# Exec
@ -230,7 +220,7 @@ class TestSimpleEngine:
return ""
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
engine = SimpleEngine(retriever=mock_retriever)
# Exec
engine.add_objs(objs=objs)

View file

@ -97,6 +97,5 @@ class TestConfigBasedFactory:
def test_val_from_config_or_kwargs_key_error(self):
# Test KeyError when the key is not found in both config object and kwargs
config = DummyConfig(name=None)
with pytest.raises(KeyError) as exc_info:
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
val = ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
assert val is None

View file

@ -1,6 +1,8 @@
import faiss
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.embeddings import MockEmbedding
from llama_index.core.schema import TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
@ -43,6 +45,14 @@ class TestRetrieverFactory:
def mock_es_vector_store(self, mocker):
return mocker.MagicMock(spec=ElasticsearchStore)
@pytest.fixture
def mock_nodes(self, mocker):
return [TextNode(text="msg")]
@pytest.fixture
def mock_embedding(self):
return MockEmbedding(embed_dim=1)
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
mock_config = FAISSRetrieverConfig(dimensions=128)
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
@ -52,42 +62,40 @@ class TestRetrieverFactory:
assert isinstance(retriever, FAISSRetriever)
def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
def test_get_retriever_with_bm25_config(self, mocker, mock_nodes):
mock_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=mock_nodes)
assert isinstance(retriever, DynamicBM25Retriever)
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_nodes, mock_embedding):
mock_faiss_config = FAISSRetrieverConfig(dimensions=1)
mock_bm25_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
retriever = self.retriever_factory.get_retriever(
configs=[mock_faiss_config, mock_bm25_config], nodes=mock_nodes, embed_model=mock_embedding
)
assert isinstance(retriever, SimpleHybridRetriever)
def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store, mock_embedding):
mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
assert isinstance(retriever, ChromaRetriever)
def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding):
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
assert isinstance(retriever, ElasticsearchRetriever)