Merge pull request #1224 from seehi/fix-rag-redundant-embedding

Fix the potential duplicate embeddings in the RAG module
This commit is contained in:
Alexander Wu 2024-04-24 20:01:43 +08:00 committed by GitHub
commit c779f6977e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 184 additions and 75 deletions

View file

@ -40,7 +40,10 @@ class Player(BaseModel):
class RAGExample:
"""Show how to use RAG."""
"""Show how to use RAG.
Default engine use LLM Reranker, if the answer from the LLM is incorrect, may encounter `IndexError: list index out of range`.
"""
def __init__(self, engine: SimpleEngine = None):
self._engine = engine
@ -59,6 +62,7 @@ class RAGExample:
def engine(self, value: SimpleEngine):
self._engine = value
@handle_exception
async def run_pipeline(self, question=QUESTION, print_title=True):
"""This example run rag pipeline, use faiss retriever and llm ranker, will print something like:
@ -79,6 +83,7 @@ class RAGExample:
answer = await self.engine.aquery(question)
self._print_query_result(answer)
@handle_exception
async def add_docs(self):
"""This example show how to add docs.
@ -148,6 +153,7 @@ class RAGExample:
except Exception as e:
logger.error(f"nodes is empty, llm don't answer correctly, exception: {e}")
@handle_exception
async def init_objects(self):
"""This example show how to from objs, will print something like:
@ -160,6 +166,7 @@ class RAGExample:
await self.add_objects(print_title=False)
self.engine = pre_engine
@handle_exception
async def init_and_query_chromadb(self):
"""This example show how to use chromadb. how to save and load index. will print something like:
@ -233,7 +240,7 @@ class RAGExample:
async def main():
"""RAG pipeline"""
"""RAG pipeline."""
e = RAGExample()
await e.run_pipeline()
await e.add_docs()

View file

@ -4,7 +4,7 @@ import json
import os
from typing import Any, Optional, Union
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core import SimpleDirectoryReader
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
@ -63,7 +63,7 @@ class SimpleEngine(RetrieverQueryEngine):
response_synthesizer: Optional[BaseSynthesizer] = None,
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
callback_manager: Optional[CallbackManager] = None,
index: Optional[BaseIndex] = None,
transformations: Optional[list[TransformComponent]] = None,
) -> None:
super().__init__(
retriever=retriever,
@ -71,7 +71,7 @@ class SimpleEngine(RetrieverQueryEngine):
node_postprocessors=node_postprocessors,
callback_manager=callback_manager,
)
self.index = index
self._transformations = transformations or self._default_transformations()
@classmethod
def from_docs(
@ -103,12 +103,17 @@ class SimpleEngine(RetrieverQueryEngine):
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
cls._fix_document_metadata(documents)
index = VectorStoreIndex.from_documents(
documents=documents,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
transformations = transformations or cls._default_transformations()
nodes = run_transformations(documents, transformations=transformations)
return cls._from_nodes(
nodes=nodes,
transformations=transformations,
embed_model=embed_model,
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@classmethod
def from_objs(
@ -137,12 +142,15 @@ class SimpleEngine(RetrieverQueryEngine):
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
index = VectorStoreIndex(
return cls._from_nodes(
nodes=nodes,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
transformations=transformations,
embed_model=embed_model,
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@classmethod
def from_index(
@ -183,7 +191,7 @@ class SimpleEngine(RetrieverQueryEngine):
documents = SimpleDirectoryReader(input_files=input_files).load_data()
self._fix_document_metadata(documents)
nodes = run_transformations(documents, transformations=self.index._transformations)
nodes = run_transformations(documents, transformations=self._transformations)
self._save_nodes(nodes)
def add_objs(self, objs: list[RAGObject]):
@ -199,6 +207,29 @@ class SimpleEngine(RetrieverQueryEngine):
self._persist(str(persist_dir), **kwargs)
@classmethod
def _from_nodes(
cls,
nodes: list[BaseNode],
transformations: Optional[list[TransformComponent]] = None,
embed_model: BaseEmbedding = None,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
embed_model = cls._resolve_embed_model(embed_model, retriever_configs)
llm = llm or get_rag_llm()
retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model)
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
return cls(
retriever=retriever,
node_postprocessors=rankers,
response_synthesizer=get_response_synthesizer(llm=llm),
transformations=transformations,
)
@classmethod
def _from_index(
cls,
@ -208,6 +239,7 @@ class SimpleEngine(RetrieverQueryEngine):
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
llm = llm or get_rag_llm()
retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
@ -215,7 +247,6 @@ class SimpleEngine(RetrieverQueryEngine):
retriever=retriever,
node_postprocessors=rankers,
response_synthesizer=get_response_synthesizer(llm=llm),
index=index,
)
def _ensure_retriever_modifiable(self):
@ -266,3 +297,7 @@ class SimpleEngine(RetrieverQueryEngine):
return MockEmbedding(embed_dim=1)
return embed_model or get_rag_embedding()
@staticmethod
def _default_transformations():
return [SentenceSplitter()]

View file

@ -36,19 +36,26 @@ class ConfigBasedFactory(GenericFactory):
"""Designed to get objects based on object type."""
def get_instance(self, key: Any, **kwargs) -> Any:
"""Key is config, such as a pydantic model.
"""Get instance by the type of key.
Call func by the type of key, and the key will be passed to func.
Key is config, such as a pydantic model, call func by the type of key, and the key will be passed to func.
Raise Exception if key not found.
"""
creator = self._creators.get(type(key))
if creator:
return creator(key, **kwargs)
self._raise_for_key(key)
def _raise_for_key(self, key: Any):
raise ValueError(f"Unknown config: `{type(key)}`, {key}")
@staticmethod
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.
Return None if not found.
"""
if config is not None and hasattr(config, key):
val = getattr(config, key)
if val is not None:
@ -57,6 +64,4 @@ class ConfigBasedFactory(GenericFactory):
if key in kwargs:
return kwargs[key]
raise KeyError(
f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
)
return None

View file

@ -1,10 +1,13 @@
"""RAG Retriever Factory."""
import copy
from functools import wraps
import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.schema import BaseNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
@ -24,10 +27,25 @@ from metagpt.rag.schema import (
ElasticsearchKeywordRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
IndexRetrieverConfig,
)
def get_or_build_index(build_index_func):
"""Decorator to get or build an index.
Get index using `_extract_index` method, if not found, using build_index_func.
"""
@wraps(build_index_func)
def wrapper(self, config, **kwargs):
index = self._extract_index(config, **kwargs)
if index is not None:
return index
return build_index_func(self, config, **kwargs)
return wrapper
class RetrieverFactory(ConfigBasedFactory):
"""Modify creators for dynamically instance implementation."""
@ -54,48 +72,79 @@ class RetrieverFactory(ConfigBasedFactory):
return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]
def _create_default(self, **kwargs) -> RAGRetriever:
return self._extract_index(**kwargs).as_retriever()
index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs)
return index.as_retriever()
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_faiss_index(config, **kwargs)
return FAISSRetriever(**config.model_dump())
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
index = self._extract_index(config, **kwargs)
nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs)
return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_chroma_index(config, **kwargs)
return ChromaRetriever(**config.model_dump())
def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
vector_store = ElasticsearchStore(**config.store_config.model_dump())
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_es_index(config, **kwargs)
return ElasticsearchRetriever(**config.model_dump())
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
return self._val_from_config_or_kwargs("index", config, **kwargs)
def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[BaseNode]:
return self._val_from_config_or_kwargs("nodes", config, **kwargs)
def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding:
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
def _build_default_index(self, **kwargs) -> VectorStoreIndex:
index = VectorStoreIndex(
nodes=self._extract_nodes(**kwargs),
embed_model=self._extract_embed_model(**kwargs),
)
return index
@get_or_build_index
def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
return self._build_index_from_vector_store(config, vector_store, **kwargs)
@get_or_build_index
def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
return self._build_index_from_vector_store(config, vector_store, **kwargs)
@get_or_build_index
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = ElasticsearchStore(**config.store_config.model_dump())
return self._build_index_from_vector_store(config, vector_store, **kwargs)
def _build_index_from_vector_store(
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
) -> VectorStoreIndex:
storage_context = StorageContext.from_defaults(vector_store=vector_store)
old_index = self._extract_index(config, **kwargs)
new_index = VectorStoreIndex(
nodes=list(old_index.docstore.docs.values()),
index = VectorStoreIndex(
nodes=self._extract_nodes(config, **kwargs),
storage_context=storage_context,
embed_model=old_index._embed_model,
embed_model=self._extract_embed_model(config, **kwargs),
)
return new_index
return index
get_retriever = RetrieverFactory().get_retriever

View file

@ -25,10 +25,6 @@ class TestSimpleEngine:
def mock_simple_directory_reader(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
@pytest.fixture
def mock_vector_store_index(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
@pytest.fixture
def mock_get_retriever(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
@ -45,7 +41,6 @@ class TestSimpleEngine:
self,
mocker,
mock_simple_directory_reader,
mock_vector_store_index,
mock_get_retriever,
mock_get_rankers,
mock_get_response_synthesizer,
@ -81,11 +76,8 @@ class TestSimpleEngine:
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_vector_store_index.assert_called_once()
mock_get_retriever.assert_called_once_with(
configs=retriever_configs, index=mock_vector_store_index.return_value
)
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
mock_get_retriever.assert_called_once()
mock_get_rankers.assert_called_once()
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
assert isinstance(engine, SimpleEngine)
@ -119,7 +111,7 @@ class TestSimpleEngine:
# Assert
assert isinstance(engine, SimpleEngine)
assert engine.index is not None
assert engine._transformations is not None
def test_from_objs_with_bm25_config(self):
# Setup
@ -137,6 +129,7 @@ class TestSimpleEngine:
def test_from_index(self, mocker, mock_llm, mock_embedding):
# Mock
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_index.as_retriever.return_value = "retriever"
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
mock_get_index.return_value = mock_index
@ -149,7 +142,7 @@ class TestSimpleEngine:
# Assert
assert isinstance(engine, SimpleEngine)
assert engine.index is mock_index
assert engine._retriever == "retriever"
@pytest.mark.asyncio
async def test_asearch(self, mocker):
@ -200,14 +193,11 @@ class TestSimpleEngine:
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_index._transformations = mocker.MagicMock()
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
mock_run_transformations.return_value = ["node1", "node2"]
# Setup
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
engine = SimpleEngine(retriever=mock_retriever)
input_files = ["test_file1", "test_file2"]
# Exec
@ -230,7 +220,7 @@ class TestSimpleEngine:
return ""
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
engine = SimpleEngine(retriever=mock_retriever)
# Exec
engine.add_objs(objs=objs)

View file

@ -97,6 +97,5 @@ class TestConfigBasedFactory:
def test_val_from_config_or_kwargs_key_error(self):
# Test KeyError when the key is not found in both config object and kwargs
config = DummyConfig(name=None)
with pytest.raises(KeyError) as exc_info:
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
val = ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
assert val is None

View file

@ -1,6 +1,8 @@
import faiss
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.embeddings import MockEmbedding
from llama_index.core.schema import TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
@ -43,6 +45,14 @@ class TestRetrieverFactory:
def mock_es_vector_store(self, mocker):
return mocker.MagicMock(spec=ElasticsearchStore)
@pytest.fixture
def mock_nodes(self, mocker):
return [TextNode(text="msg")]
@pytest.fixture
def mock_embedding(self):
return MockEmbedding(embed_dim=1)
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
mock_config = FAISSRetrieverConfig(dimensions=128)
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
@ -52,42 +62,40 @@ class TestRetrieverFactory:
assert isinstance(retriever, FAISSRetriever)
def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
def test_get_retriever_with_bm25_config(self, mocker, mock_nodes):
mock_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=mock_nodes)
assert isinstance(retriever, DynamicBM25Retriever)
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_nodes, mock_embedding):
mock_faiss_config = FAISSRetrieverConfig(dimensions=1)
mock_bm25_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
retriever = self.retriever_factory.get_retriever(
configs=[mock_faiss_config, mock_bm25_config], nodes=mock_nodes, embed_model=mock_embedding
)
assert isinstance(retriever, SimpleHybridRetriever)
def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store, mock_embedding):
mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
assert isinstance(retriever, ChromaRetriever)
def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding):
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
assert isinstance(retriever, ElasticsearchRetriever)
@ -111,3 +119,19 @@ class TestRetrieverFactory:
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
assert extracted_index == mock_vector_store_index
def test_get_or_build_when_get(self, mocker):
want = "existing_index"
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=want)
got = self.retriever_factory._build_es_index(None)
assert got == want
def test_get_or_build_when_build(self, mocker):
want = "call_build_es_index"
mocker.patch.object(self.retriever_factory, "_build_es_index", return_value=want)
got = self.retriever_factory._build_es_index(None)
assert got == want