from objs

This commit is contained in:
seehi 2024-03-07 17:56:35 +08:00
parent 0f2f460ddc
commit baa30d3ced
2 changed files with 73 additions and 22 deletions

View file

@ -38,14 +38,16 @@ from metagpt.rag.schema import (
BaseIndexConfig,
BaseRankerConfig,
BaseRetrieverConfig,
BM25RetrieverConfig,
ObjectNode,
)
from metagpt.utils.common import import_class
class SimpleEngine(RetrieverQueryEngine):
"""
SimpleEngine is a lightweight and easy-to-use search engine that integrates
"""SimpleEngine is designed to be simple and straightforward.
It is a lightweight and easy-to-use search engine that integrates
document reading, embedding, indexing, retrieving, and ranking functionalities
into a single, straightforward workflow. It is designed to quickly set up a
search engine from a collection of documents.
@ -78,7 +80,9 @@ class SimpleEngine(RetrieverQueryEngine):
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
"""This engine is designed to be simple and straightforward
"""From docs.
Must provide either `input_dir` or `input_files`.
Args:
input_dir: Path to the directory.
@ -89,6 +93,9 @@ class SimpleEngine(RetrieverQueryEngine):
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
ranker_configs: Configuration for rankers.
"""
if not input_dir and not input_files:
raise ValueError("Must provide either `input_dir` or `input_files`.")
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
index = VectorStoreIndex.from_documents(
documents=documents,
@ -97,6 +104,39 @@ class SimpleEngine(RetrieverQueryEngine):
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@classmethod
def from_objs(
cls,
objs: Optional[list[RAGObject]] = None,
transformations: Optional[list[TransformComponent]] = None,
embed_model: BaseEmbedding = None,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
"""From objs.
Args:
objs: List of RAGObject.
transformations: Parse documents to nodes. Default [SentenceSplitter].
embed_model: Parse nodes to embedding. Must supported by llama index. Default OpenAIEmbedding.
llm: Must supported by llama index. Default OpenAI.
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
ranker_configs: Configuration for rankers.
"""
# check
if not retriever_configs or any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
raise ValueError("Must provide retriever_configs, and BM25RetrieverConfig is not supported.")
objs = objs or []
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
index = VectorStoreIndex(
nodes=nodes,
transformations=transformations or [SentenceSplitter()],
embed_model=embed_model or get_rag_embedding(),
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@classmethod
def from_index(
cls,
@ -110,25 +150,6 @@ class SimpleEngine(RetrieverQueryEngine):
index = get_index(index_config, embed_model=embed_model or get_rag_embedding())
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@classmethod
def _from_index(
cls,
index: BaseIndex,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
llm = llm or get_rag_llm()
retriever = get_retriever(configs=retriever_configs, index=index)
rankers = get_rankers(configs=ranker_configs, llm=llm)
return cls(
retriever=retriever,
node_postprocessors=rankers,
response_synthesizer=get_response_synthesizer(llm=llm),
index=index,
)
async def asearch(self, content: str, **kwargs) -> str:
"""Inplement tools.SearchInterface"""
return await self.aquery(content)
@ -156,6 +177,25 @@ class SimpleEngine(RetrieverQueryEngine):
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
self._save_nodes(nodes)
@classmethod
def _from_index(
cls,
index: BaseIndex,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
llm = llm or get_rag_llm()
retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
return cls(
retriever=retriever,
node_postprocessors=rankers,
response_synthesizer=get_response_synthesizer(llm=llm),
index=index,
)
def _ensure_retriever_modifiable(self):
if not isinstance(self.retriever, ModifiableRAGRetriever):
raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}")