mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
fix redundant embedding
This commit is contained in:
parent
e976ece310
commit
b8b1a666fe
3 changed files with 43 additions and 3 deletions
|
|
@ -220,7 +220,9 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
embed_model = cls._resolve_embed_model(embed_model, retriever_configs)
|
||||
llm = llm or get_rag_llm()
|
||||
|
||||
retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model)
|
||||
retriever = get_retriever(
|
||||
configs=retriever_configs, nodes=nodes, embed_model=embed_model
|
||||
) # Default VectorStoreIndex(nodes, embed_model).as_retriever
|
||||
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
|
||||
|
||||
return cls(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
"""RAG Retriever Factory."""
|
||||
|
||||
|
||||
from functools import wraps
|
||||
|
||||
import chromadb
|
||||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
|
|
@ -28,6 +30,22 @@ from metagpt.rag.schema import (
|
|||
)
|
||||
|
||||
|
||||
def get_or_build_index(build_index_func):
    """Decorate an index-building method so an available index is reused.

    The wrapped method first calls ``self._extract_index(config, **kwargs)``.
    Any result other than ``None`` is returned as-is; only when the lookup
    yields ``None`` is ``build_index_func`` invoked to build the index.
    """

    @wraps(build_index_func)
    def _get_or_build(self, config, **kwargs):
        existing = self._extract_index(config, **kwargs)
        if existing is None:
            # Nothing to reuse: delegate to the decorated builder.
            return build_index_func(self, config, **kwargs)
        return existing

    return _get_or_build
|
||||
|
||||
|
||||
class RetrieverFactory(ConfigBasedFactory):
|
||||
"""Modify creators for dynamically instance implementation."""
|
||||
|
||||
|
|
@ -59,12 +77,13 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
return index.as_retriever()
|
||||
|
||||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
    """Create a FAISS retriever from ``config``.

    ``config.index`` is set to the result of ``_build_faiss_index`` and the
    retriever is then constructed from ``config.model_dump()``.
    """
    # `_build_faiss_index` is wrapped by `get_or_build_index`, which already returns the
    # index found by `_extract_index` when there is one. A preceding
    # `_extract_index(...) or _build_faiss_index(...)` assignment would only be
    # overwritten and could build the index a second time.
    config.index = self._build_faiss_index(config, **kwargs)

    return FAISSRetriever(**config.model_dump())
|
||||
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
    """Create a BM25 retriever from ``config``.

    When ``_extract_index`` yields an index, the retriever's nodes are taken from
    that index's docstore; otherwise they come from ``_extract_nodes``.
    """
    index = self._extract_index(config, **kwargs)
    if index:
        # Reuse the nodes already stored in the index.
        nodes = list(index.docstore.docs.values())
    else:
        # No index to reuse: only now is it necessary to extract nodes.
        nodes = self._extract_nodes(config, **kwargs)

    return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
|
||||
|
||||
|
|
@ -95,11 +114,13 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
|
||||
return index
|
||||
|
||||
@get_or_build_index
def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex:
    """Build an index on top of a FAISS vector store.

    The store wraps a ``faiss.IndexFlatL2`` created with ``config.dimensions``;
    the index itself is produced by ``_build_index_from_vector_store``.
    """
    flat_l2_index = faiss.IndexFlatL2(config.dimensions)
    store = FaissVectorStore(faiss_index=flat_l2_index)

    return self._build_index_from_vector_store(config, store, **kwargs)
|
||||
|
||||
@get_or_build_index
|
||||
def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
db = chromadb.PersistentClient(path=str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
|
||||
|
|
@ -107,6 +128,7 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
@get_or_build_index
|
||||
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
vector_store = ElasticsearchStore(**config.store_config.model_dump())
|
||||
|
||||
|
|
|
|||
|
|
@ -119,3 +119,19 @@ class TestRetrieverFactory:
|
|||
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
||||
def test_get_or_build_when_get(self, mocker):
    """`_build_es_index` returns whatever `_extract_index` is patched to return."""
    existing = "existing_index"
    mocker.patch.object(self.retriever_factory, "_extract_index", return_value=existing)

    assert self.retriever_factory._build_es_index(None) == existing
|
||||
|
||||
def test_get_or_build_when_build(self, mocker):
    """Check the value returned by `_build_es_index` once it is patched to return `want`."""
    want = "call_build_es_index"
    # NOTE(review): this patches `_build_es_index` itself on the factory instance, so the
    # call below hits the mock; the real method (and any decorator applied to it) is not
    # executed. TODO confirm whether the intent was to patch `_extract_index` to return
    # None and stub the underlying build step instead.
    mocker.patch.object(self.retriever_factory, "_build_es_index", return_value=want)

    got = self.retriever_factory._build_es_index(None)

    assert got == want
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue