From aca3d1a0cb945b152ab563125570bb659598d6d0 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 20 Feb 2024 16:59:18 +0800 Subject: [PATCH] RAGObject interface add model_dump method; modify by pylint --- examples/rag_pipeline.py | 4 +-- metagpt/rag/engines/simple.py | 6 ++-- metagpt/rag/factory.py | 37 ++++++++++++---------- metagpt/rag/interface.py | 12 +++++-- metagpt/rag/llm.py | 4 +++ metagpt/rag/retrievers/bm25_retriever.py | 4 +++ metagpt/rag/retrievers/faiss_retriever.py | 4 +++ metagpt/rag/retrievers/hybrid_retriever.py | 1 + metagpt/rag/schema.py | 18 +++++++---- 9 files changed, 60 insertions(+), 30 deletions(-) diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 3aae9aa70..675fe62f1 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -92,7 +92,7 @@ class RAGExample: tool: str = "Red Bull Energy Drink" def rag_key(self) -> str: - return "100m Sprint" + return self.goal foo = Player(name="foo") question = f"{foo.rag_key()}" @@ -106,7 +106,7 @@ class RAGExample: print("[Object Detail]") player: Player = nodes[0].metadata["obj"] - print(f"{player.model_dump()}") + print(player) @staticmethod def _print_title(title): diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index e036f6aa9..d48fc8a1a 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -87,9 +87,9 @@ class SimpleEngine(RetrieverQueryEngine): """Inplement tools.SearchInterface""" return await self.aquery(content) - async def aretrieve(self, query: QueryType) -> list[NodeWithScore]: + async def aretrieve(self, query_bundle: QueryType) -> list[NodeWithScore]: """Allow query to be str""" - query_bundle = QueryBundle(query) if isinstance(query, str) else query + query_bundle = QueryBundle(query_bundle) if isinstance(query_bundle, str) else query_bundle return await super().aretrieve(query_bundle) def add_docs(self, input_files: list[str]): @@ -104,7 +104,7 @@ class SimpleEngine(RetrieverQueryEngine): """Adds objects to the retriever, storing each object's original form in metadata for future reference.""" self._ensure_retriever_modifiable() - nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj}) for obj in objs] + nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj.model_dump()}) for obj in objs] self.retriever.add_nodes(nodes) def _ensure_retriever_modifiable(self): diff --git a/metagpt/rag/factory.py b/metagpt/rag/factory.py index 475acc476..04543f57e 100644 --- a/metagpt/rag/factory.py +++ b/metagpt/rag/factory.py @@ -28,16 +28,11 @@ class BaseFactory: """ def __init__(self, creators: dict[Any, Callable]): - """ - Creators is a dictionary mapping configuration types to creator functions. - The first arg of Creator function should be config. - """ + """Creators is a dictionary mapping configuration types to creator functions.""" self.creators = creators def get_instances(self, configs: list[Any] = None, **kwargs) -> list[Any]: - if not configs: - return [self._default_instance(**kwargs)] - + """Get instances by configs""" return [self._get_instance(config, **kwargs) for config in configs] def _get_instance(self, config: Any, **kwargs) -> Any: @@ -47,13 +42,11 @@ class BaseFactory: raise ValueError(f"Unknown config: {config}") - def _default_instance(self, **kwargs) -> Any: - raise NotImplementedError("This method should be implemented by subclasses.") - class RetrieverFactory(BaseFactory): + """Modify creators for dynamically instance implementation""" + def __init__(self): - # Dynamically add configuration and corresponding instance implementation. creators = { FAISSRetrieverConfig: self._create_faiss_retriever, BM25RetrieverConfig: self._create_bm25_retriever, @@ -61,7 +54,12 @@ class RetrieverFactory(BaseFactory): super().__init__(creators) def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever: - """Creates and returns a retriever instance based on the provided configurations.""" + """Creates and returns a retriever instance based on the provided configurations. + If multiple retrievers, using SimpleHybridRetriever + """ + if not configs: + return self._default_instance(index) + retrievers = super().get_instances(configs, index=index) return ( @@ -73,7 +71,7 @@ class RetrieverFactory(BaseFactory): def _default_instance(self, index: BaseIndex) -> RAGRetriever: return index.as_retriever() - def _create_faiss_retriever(self, config: FAISSRetrieverConfig, index: BaseIndex, **kwargs) -> FAISSRetriever: + def _create_faiss_retriever(self, config: FAISSRetrieverConfig, index: BaseIndex) -> FAISSRetriever: vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) storage_context = StorageContext.from_defaults(vector_store=vector_store) vector_index = VectorStoreIndex( @@ -83,13 +81,14 @@ class RetrieverFactory(BaseFactory): ) return FAISSRetriever(**config.model_dump(), index=vector_index) - def _create_bm25_retriever(self, config: BM25RetrieverConfig, index: BaseIndex, **kwargs) -> DynamicBM25Retriever: + def _create_bm25_retriever(self, config: BM25RetrieverConfig, index: BaseIndex) -> DynamicBM25Retriever: return DynamicBM25Retriever.from_defaults(**config.model_dump(), index=index) class RankerFactory(BaseFactory): + """Modify creators for dynamically instance implementation""" + def __init__(self): - # Dynamically add configuration and corresponding instance implementation. creators = { LLMRankerConfig: self._create_llm_ranker, } @@ -98,12 +97,16 @@ class RankerFactory(BaseFactory): def get_rankers( self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None ) -> list[BaseNodePostprocessor]: + """Creates and returns a retriever instance based on the provided configurations.""" + if not configs: + return [self._default_instance(service_context)] + return super().get_instances(configs, service_context=service_context) - def _default_instance(self, service_context: ServiceContext = None) -> LLMRerank: + def _default_instance(self, service_context: ServiceContext) -> LLMRerank: return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context) - def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None, **kwargs) -> LLMRerank: + def _create_llm_ranker(self, config: LLMRankerConfig, service_context: ServiceContext = None) -> LLMRerank: return LLMRerank(**config.model_dump(), service_context=service_context) diff --git a/metagpt/rag/interface.py b/metagpt/rag/interface.py index 7ed2c6b58..97faf9f01 100644 --- a/metagpt/rag/interface.py +++ b/metagpt/rag/interface.py @@ -1,6 +1,14 @@ -from typing import Protocol +"""RAG Interface.""" +from typing import Any, Protocol class RAGObject(Protocol): + """Support rag add object""" + def rag_key(self) -> str: - """for rag search""" + """For rag search.""" + + def model_dump(self) -> dict[str, Any]: + """For rag persist. + Pydantic Model don't need to implement this, as there is a built-in function named model_dump + """ diff --git a/metagpt/rag/llm.py b/metagpt/rag/llm.py index 405b29991..83b3a849d 100644 --- a/metagpt/rag/llm.py +++ b/metagpt/rag/llm.py @@ -1,7 +1,11 @@ +"""RAG LLM +The LLM of LlamaIndex and the LLM of MG are not the same. +""" from llama_index.llms import OpenAI from metagpt.config2 import config def get_default_llm() -> OpenAI: + """OpenAI""" return OpenAI(api_base=config.llm.base_url, api_key=config.llm.api_key, model=config.llm.model) diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py index c7257e00f..dc8d59802 100644 --- a/metagpt/rag/retrievers/bm25_retriever.py +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -1,9 +1,13 @@ +"""BM25 retriever.""" from llama_index.retrievers import BM25Retriever from llama_index.schema import BaseNode class DynamicBM25Retriever(BM25Retriever): + """BM25 retriever.""" + def add_nodes(self, nodes: list[BaseNode], **kwargs): + """Support add nodes""" try: from rank_bm25 import BM25Okapi except ImportError: diff --git a/metagpt/rag/retrievers/faiss_retriever.py b/metagpt/rag/retrievers/faiss_retriever.py index aa91aaaff..a898d0292 100644 --- a/metagpt/rag/retrievers/faiss_retriever.py +++ b/metagpt/rag/retrievers/faiss_retriever.py @@ -1,7 +1,11 @@ +"""FAISS retriever.""" from llama_index.retrievers import VectorIndexRetriever from llama_index.schema import BaseNode class FAISSRetriever(VectorIndexRetriever): + """FAISS retriever.""" + def add_nodes(self, nodes: list[BaseNode], **kwargs): + """Support add nodes""" self._index.insert_nodes(nodes, **kwargs) diff --git a/metagpt/rag/retrievers/hybrid_retriever.py b/metagpt/rag/retrievers/hybrid_retriever.py index 04889b702..d514194c9 100644 --- a/metagpt/rag/retrievers/hybrid_retriever.py +++ b/metagpt/rag/retrievers/hybrid_retriever.py @@ -38,5 +38,6 @@ class SimpleHybridRetriever(RAGRetriever): return result def add_nodes(self, nodes: list[BaseNode]): + """Support add nodes""" for r in self.retrievers: r.add_nodes(nodes) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index d1cbf31bf..1e3d945f2 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -1,28 +1,34 @@ -"""Retriever schemas""" +"""RAG schemas""" from typing import Union -from pydantic import BaseModel +from pydantic import BaseModel, Field class RetrieverConfig(BaseModel): - similarity_top_k: int = 5 + """Common config for retrievers.""" + + similarity_top_k: int = Field(default=5, description="Number of top-k similar results to return during retrieval.") class FAISSRetrieverConfig(RetrieverConfig): - dimensions: int = 1536 + """Config for FAISS-based retrievers.""" + + dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.") class BM25RetrieverConfig(RetrieverConfig): - ... + """Config for BM25-based retrievers.""" class RankerConfig(BaseModel): + """Common config for rankers.""" + top_n: int = 5 class LLMRankerConfig(RankerConfig): - ... + """Config for LLM-based rankers.""" # If add new config, it is necessary to add the corresponding instance implementation in rag.factory