fix redundant embedding

This commit is contained in:
seehi 2024-04-24 15:56:15 +08:00
parent e976ece310
commit b8b1a666fe
3 changed files with 43 additions and 3 deletions

View file

@ -220,7 +220,9 @@ class SimpleEngine(RetrieverQueryEngine):
embed_model = cls._resolve_embed_model(embed_model, retriever_configs)
llm = llm or get_rag_llm()
retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model)
retriever = get_retriever(
configs=retriever_configs, nodes=nodes, embed_model=embed_model
) # Default VectorStoreIndex(nodes, embed_model).as_retriever
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
return cls(

View file

@ -1,6 +1,8 @@
"""RAG Retriever Factory."""
from functools import wraps
import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
@ -28,6 +30,22 @@ from metagpt.rag.schema import (
)
def get_or_build_index(build_index_func):
"""Find index using `_extract_index` method.
If no index is found, using build_index_func.
"""
@wraps(build_index_func)
def wrapper(self, config, **kwargs):
index = self._extract_index(config, **kwargs)
if index is not None:
return index
return build_index_func(self, config, **kwargs)
return wrapper
class RetrieverFactory(ConfigBasedFactory):
"""Modify creators for dynamically instance implementation."""
@ -59,12 +77,13 @@ class RetrieverFactory(ConfigBasedFactory):
return index.as_retriever()
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
config.index = self._extract_index(config, **kwargs) or self._build_faiss_index(config, **kwargs)
config.index = self._build_faiss_index(config, **kwargs)
return FAISSRetriever(**config.model_dump())
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
nodes = self._extract_nodes(config, **kwargs)
index = self._extract_index(config, **kwargs)
nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs)
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
@ -95,11 +114,13 @@ class RetrieverFactory(ConfigBasedFactory):
return index
@get_or_build_index
def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
return self._build_index_from_vector_store(config, vector_store, **kwargs)
@get_or_build_index
def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
@ -107,6 +128,7 @@ class RetrieverFactory(ConfigBasedFactory):
return self._build_index_from_vector_store(config, vector_store, **kwargs)
@get_or_build_index
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = ElasticsearchStore(**config.store_config.model_dump())

View file

@ -119,3 +119,19 @@ class TestRetrieverFactory:
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
assert extracted_index == mock_vector_store_index
def test_get_or_build_when_get(self, mocker):
want = "existing_index"
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=want)
got = self.retriever_factory._build_es_index(None)
assert got == want
def test_get_or_build_when_build(self, mocker):
want = "call_build_es_index"
mocker.patch.object(self.retriever_factory, "_build_es_index", return_value=want)
got = self.retriever_factory._build_es_index(None)
assert got == want