just change some func's position

This commit is contained in:
seehi 2024-03-16 09:11:54 +08:00
parent ec2e8bdca3
commit d27026ad81
3 changed files with 21 additions and 21 deletions

View file

@ -29,9 +29,6 @@ class RAGIndexFactory(ConfigBasedFactory):
"""Key is PersistType."""
return super().get_instance(config, **kwargs)
def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)
@ -59,5 +56,8 @@ class RAGIndexFactory(ConfigBasedFactory):
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
return index
def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
get_index = RAGIndexFactory().get_index

View file

@ -24,12 +24,12 @@ class RankerFactory(ConfigBasedFactory):
return super().get_instances(configs, **kwargs)
def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
return self._val_from_config_or_kwargs("llm", config, **kwargs)
def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank:
config.llm = self._extract_llm(config, **kwargs)
return LLMRerank(**config.model_dump())
def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
return self._val_from_config_or_kwargs("llm", config, **kwargs)
get_rankers = RankerFactory().get_rankers

View file

@ -50,21 +50,6 @@ class RetrieverFactory(ConfigBasedFactory):
def _create_default(self, **kwargs) -> RAGRetriever:
return self._extract_index(**kwargs).as_retriever()
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
return self._val_from_config_or_kwargs("index", config, **kwargs)
def _build_index_from_vector_store(
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
) -> VectorStoreIndex:
storage_context = StorageContext.from_defaults(vector_store=vector_store)
old_index = self._extract_index(config, **kwargs)
new_index = VectorStoreIndex(
nodes=list(old_index.docstore.docs.values()),
storage_context=storage_context,
embed_model=old_index._embed_model,
)
return new_index
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
@ -82,5 +67,20 @@ class RetrieverFactory(ConfigBasedFactory):
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
return ChromaRetriever(**config.model_dump())
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
return self._val_from_config_or_kwargs("index", config, **kwargs)
def _build_index_from_vector_store(
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
) -> VectorStoreIndex:
storage_context = StorageContext.from_defaults(vector_store=vector_store)
old_index = self._extract_index(config, **kwargs)
new_index = VectorStoreIndex(
nodes=list(old_index.docstore.docs.values()),
storage_context=storage_context,
embed_model=old_index._embed_model,
)
return new_index
get_retriever = RetrieverFactory().get_retriever