mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-09 15:52:38 +02:00
from objs
This commit is contained in:
parent
0f2f460ddc
commit
baa30d3ced
2 changed files with 73 additions and 22 deletions
|
|
@ -38,14 +38,16 @@ from metagpt.rag.schema import (
|
|||
BaseIndexConfig,
|
||||
BaseRankerConfig,
|
||||
BaseRetrieverConfig,
|
||||
BM25RetrieverConfig,
|
||||
ObjectNode,
|
||||
)
|
||||
from metagpt.utils.common import import_class
|
||||
|
||||
|
||||
class SimpleEngine(RetrieverQueryEngine):
|
||||
"""
|
||||
SimpleEngine is a lightweight and easy-to-use search engine that integrates
|
||||
"""SimpleEngine is designed to be simple and straightforward.
|
||||
|
||||
It is a lightweight and easy-to-use search engine that integrates
|
||||
document reading, embedding, indexing, retrieving, and ranking functionalities
|
||||
into a single, straightforward workflow. It is designed to quickly set up a
|
||||
search engine from a collection of documents.
|
||||
|
|
@ -78,7 +80,9 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
"""This engine is designed to be simple and straightforward
|
||||
"""From docs.
|
||||
|
||||
Must provide either `input_dir` or `input_files`.
|
||||
|
||||
Args:
|
||||
input_dir: Path to the directory.
|
||||
|
|
@ -89,6 +93,9 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
|
||||
ranker_configs: Configuration for rankers.
|
||||
"""
|
||||
if not input_dir and not input_files:
|
||||
raise ValueError("Must provide either `input_dir` or `input_files`.")
|
||||
|
||||
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents=documents,
|
||||
|
|
@ -97,6 +104,39 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def from_objs(
|
||||
cls,
|
||||
objs: Optional[list[RAGObject]] = None,
|
||||
transformations: Optional[list[TransformComponent]] = None,
|
||||
embed_model: BaseEmbedding = None,
|
||||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
"""From objs.
|
||||
|
||||
Args:
|
||||
objs: List of RAGObject.
|
||||
transformations: Parse documents to nodes. Default [SentenceSplitter].
|
||||
embed_model: Parse nodes to embedding. Must supported by llama index. Default OpenAIEmbedding.
|
||||
llm: Must supported by llama index. Default OpenAI.
|
||||
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
|
||||
ranker_configs: Configuration for rankers.
|
||||
"""
|
||||
# check
|
||||
if not retriever_configs or any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
|
||||
raise ValueError("Must provide retriever_configs, and BM25RetrieverConfig is not supported.")
|
||||
|
||||
objs = objs or []
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
index = VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=embed_model or get_rag_embedding(),
|
||||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def from_index(
|
||||
cls,
|
||||
|
|
@ -110,25 +150,6 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
index = get_index(index_config, embed_model=embed_model or get_rag_embedding())
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def _from_index(
|
||||
cls,
|
||||
index: BaseIndex,
|
||||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
llm = llm or get_rag_llm()
|
||||
retriever = get_retriever(configs=retriever_configs, index=index)
|
||||
rankers = get_rankers(configs=ranker_configs, llm=llm)
|
||||
|
||||
return cls(
|
||||
retriever=retriever,
|
||||
node_postprocessors=rankers,
|
||||
response_synthesizer=get_response_synthesizer(llm=llm),
|
||||
index=index,
|
||||
)
|
||||
|
||||
async def asearch(self, content: str, **kwargs) -> str:
|
||||
"""Inplement tools.SearchInterface"""
|
||||
return await self.aquery(content)
|
||||
|
|
@ -156,6 +177,25 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
self._save_nodes(nodes)
|
||||
|
||||
@classmethod
|
||||
def _from_index(
|
||||
cls,
|
||||
index: BaseIndex,
|
||||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
llm = llm or get_rag_llm()
|
||||
retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
|
||||
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
|
||||
|
||||
return cls(
|
||||
retriever=retriever,
|
||||
node_postprocessors=rankers,
|
||||
response_synthesizer=get_response_synthesizer(llm=llm),
|
||||
index=index,
|
||||
)
|
||||
|
||||
def _ensure_retriever_modifiable(self):
|
||||
if not isinstance(self.retriever, ModifiableRAGRetriever):
|
||||
raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue