RAGObject interface add model_dump method; modify by pylint

This commit is contained in:
seehi 2024-02-20 16:59:18 +08:00
parent ada8e8e37c
commit aca3d1a0cb
9 changed files with 60 additions and 30 deletions

View file

@ -87,9 +87,9 @@ class SimpleEngine(RetrieverQueryEngine):
"""Inplement tools.SearchInterface"""
return await self.aquery(content)
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
async def aretrieve(self, query_bundle: QueryType) -> list[NodeWithScore]:
"""Allow query to be str"""
query_bundle = QueryBundle(query) if isinstance(query, str) else query
query_bundle = QueryBundle(query_bundle) if isinstance(query_bundle, str) else query_bundle
return await super().aretrieve(query_bundle)
def add_docs(self, input_files: list[str]):
@ -104,7 +104,7 @@ class SimpleEngine(RetrieverQueryEngine):
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
self._ensure_retriever_modifiable()
nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj}) for obj in objs]
nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj.model_dump()}) for obj in objs]
self.retriever.add_nodes(nodes)
def _ensure_retriever_modifiable(self):

View file

@ -28,16 +28,11 @@ class BaseFactory:
"""
def __init__(self, creators: dict[Any, Callable]):
"""
Creators is a dictionary mapping configuration types to creator functions.
The first arg of Creator function should be config.
"""
"""Creators is a dictionary mapping configuration types to creator functions."""
self.creators = creators
def get_instances(self, configs: list[Any] = None, **kwargs) -> list[Any]:
if not configs:
return [self._default_instance(**kwargs)]
"""Get instances by configs"""
return [self._get_instance(config, **kwargs) for config in configs]
def _get_instance(self, config: Any, **kwargs) -> Any:
@ -47,13 +42,11 @@ class BaseFactory:
raise ValueError(f"Unknown config: {config}")
def _default_instance(self, **kwargs) -> Any:
raise NotImplementedError("This method should be implemented by subclasses.")
class RetrieverFactory(BaseFactory):
"""Modify creators for dynamically instance implementation"""
def __init__(self):
# Dynamically add configuration and corresponding instance implementation.
creators = {
FAISSRetrieverConfig: self._create_faiss_retriever,
BM25RetrieverConfig: self._create_bm25_retriever,
@ -61,7 +54,12 @@ class RetrieverFactory(BaseFactory):
super().__init__(creators)
def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
"""Creates and returns a retriever instance based on the provided configurations."""
"""Creates and returns a retriever instance based on the provided configurations.
If multiple retrievers, using SimpleHybridRetriever
"""
if not configs:
return self._default_instance(index)
retrievers = super().get_instances(configs, index=index)
return (
@ -73,7 +71,7 @@ class RetrieverFactory(BaseFactory):
def _default_instance(self, index: BaseIndex) -> RAGRetriever:
return index.as_retriever()
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, index: BaseIndex, **kwargs) -> FAISSRetriever:
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, index: BaseIndex) -> FAISSRetriever:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
storage_context = StorageContext.from_defaults(vector_store=vector_store)
vector_index = VectorStoreIndex(
@ -83,13 +81,14 @@ class RetrieverFactory(BaseFactory):
)
return FAISSRetriever(**config.model_dump(), index=vector_index)
def _create_bm25_retriever(self, config: BM25RetrieverConfig, index: BaseIndex, **kwargs) -> DynamicBM25Retriever:
def _create_bm25_retriever(self, config: BM25RetrieverConfig, index: BaseIndex) -> DynamicBM25Retriever:
return DynamicBM25Retriever.from_defaults(**config.model_dump(), index=index)
class RankerFactory(BaseFactory):
"""Modify creators for dynamically instance implementation"""
def __init__(self):
# Dynamically add configuration and corresponding instance implementation.
creators = {
LLMRankerConfig: self._create_llm_ranker,
}
@ -98,12 +97,16 @@ class RankerFactory(BaseFactory):
def get_rankers(
self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None
) -> list[BaseNodePostprocessor]:
"""Creates and returns a retriever instance based on the provided configurations."""
if not configs:
return [self._default_instance(service_context)]
return super().get_instances(configs, service_context=service_context)
def _default_instance(self, service_context: ServiceContext = None) -> LLMRerank:
def _default_instance(self, service_context: ServiceContext) -> LLMRerank:
return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context)
def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None, **kwargs) -> LLMRerank:
def _create_llm_ranker(self, config: LLMRankerConfig, service_context: ServiceContext = None) -> LLMRerank:
return LLMRerank(**config.model_dump(), service_context=service_context)

View file

@ -1,6 +1,14 @@
from typing import Protocol
"""RAG Interface."""
from typing import Any, Protocol
class RAGObject(Protocol):
    """Protocol for objects that can be added to a RAG engine.

    Any object satisfying this structural interface can be stored and
    later retrieved by the engine.
    """

    def rag_key(self) -> str:
        """Return the text used as the retrieval key for this object."""
        ...

    def model_dump(self) -> dict[str, Any]:
        """Return a dict form of the object, used for persistence.

        Pydantic models don't need to implement this, as pydantic already
        provides a built-in method named ``model_dump``.
        """
        ...

View file

@ -1,7 +1,11 @@
"""RAG LLM
The LLM of LlamaIndex and the LLM of MetaGPT (MG) are not the same.
"""
from llama_index.llms import OpenAI
from metagpt.config2 import config
def get_default_llm() -> OpenAI:
    """Build the default LlamaIndex OpenAI LLM from the global MetaGPT config."""
    llm_conf = config.llm
    return OpenAI(api_base=llm_conf.base_url, api_key=llm_conf.api_key, model=llm_conf.model)

View file

@ -1,9 +1,13 @@
"""BM25 retriever."""
from llama_index.retrievers import BM25Retriever
from llama_index.schema import BaseNode
class DynamicBM25Retriever(BM25Retriever):
"""BM25 retriever."""
def add_nodes(self, nodes: list[BaseNode], **kwargs):
"""Support add nodes"""
try:
from rank_bm25 import BM25Okapi
except ImportError:

View file

@ -1,7 +1,11 @@
"""FAISS retriever."""
from llama_index.retrievers import VectorIndexRetriever
from llama_index.schema import BaseNode
class FAISSRetriever(VectorIndexRetriever):
    """FAISS-based vector retriever.

    Extends VectorIndexRetriever with the ability to add nodes after
    construction.
    """

    def add_nodes(self, nodes: list[BaseNode], **kwargs):
        """Insert nodes into the underlying vector index.

        Extra keyword arguments are forwarded to ``insert_nodes``.
        """
        self._index.insert_nodes(nodes, **kwargs)

View file

@ -38,5 +38,6 @@ class SimpleHybridRetriever(RAGRetriever):
return result
def add_nodes(self, nodes: list[BaseNode]):
"""Support add nodes"""
for r in self.retrievers:
r.add_nodes(nodes)

View file

@ -1,28 +1,34 @@
"""Retriever schemas"""
"""RAG schemas"""
from typing import Union
from pydantic import BaseModel
from pydantic import BaseModel, Field
class RetrieverConfig(BaseModel):
    """Common config for retrievers."""

    similarity_top_k: int = Field(default=5, description="Number of top-k similar results to return during retrieval.")
class FAISSRetrieverConfig(RetrieverConfig):
    """Config for FAISS-based retrievers."""

    dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.")
class BM25RetrieverConfig(RetrieverConfig):
    """Config for BM25-based retrievers."""
class RankerConfig(BaseModel):
    """Common config for rankers."""

    # Field with description for consistency with the retriever configs above.
    top_n: int = Field(default=5, description="Number of top-ranked results to return after ranking.")
class LLMRankerConfig(RankerConfig):
    """Config for LLM-based rankers."""


# NOTE: When adding a new config, add the corresponding instance implementation in rag.factory.