mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-07 06:42:38 +02:00
RAGObject interface add model_dump method; modify by pylint
This commit is contained in:
parent
ada8e8e37c
commit
aca3d1a0cb
9 changed files with 60 additions and 30 deletions
|
|
@ -87,9 +87,9 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
"""Implement tools.SearchInterface"""
|
||||
return await self.aquery(content)
|
||||
|
||||
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
async def aretrieve(self, query_bundle: QueryType) -> list[NodeWithScore]:
|
||||
"""Allow query to be str"""
|
||||
query_bundle = QueryBundle(query) if isinstance(query, str) else query
|
||||
query_bundle = QueryBundle(query_bundle) if isinstance(query_bundle, str) else query_bundle
|
||||
return await super().aretrieve(query_bundle)
|
||||
|
||||
def add_docs(self, input_files: list[str]):
|
||||
|
|
@ -104,7 +104,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
|
||||
self._ensure_retriever_modifiable()
|
||||
|
||||
nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj}) for obj in objs]
|
||||
nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj.model_dump()}) for obj in objs]
|
||||
self.retriever.add_nodes(nodes)
|
||||
|
||||
def _ensure_retriever_modifiable(self):
|
||||
|
|
|
|||
|
|
@ -28,16 +28,11 @@ class BaseFactory:
|
|||
"""
|
||||
|
||||
def __init__(self, creators: dict[Any, Callable]):
|
||||
"""
|
||||
Creators is a dictionary mapping configuration types to creator functions.
|
||||
The first arg of Creator function should be config.
|
||||
"""
|
||||
"""Creators is a dictionary mapping configuration types to creator functions."""
|
||||
self.creators = creators
|
||||
|
||||
def get_instances(self, configs: list[Any] = None, **kwargs) -> list[Any]:
|
||||
if not configs:
|
||||
return [self._default_instance(**kwargs)]
|
||||
|
||||
"""Get instances by configs"""
|
||||
return [self._get_instance(config, **kwargs) for config in configs]
|
||||
|
||||
def _get_instance(self, config: Any, **kwargs) -> Any:
|
||||
|
|
@ -47,13 +42,11 @@ class BaseFactory:
|
|||
|
||||
raise ValueError(f"Unknown config: {config}")
|
||||
|
||||
def _default_instance(self, **kwargs) -> Any:
|
||||
raise NotImplementedError("This method should be implemented by subclasses.")
|
||||
|
||||
|
||||
class RetrieverFactory(BaseFactory):
|
||||
"""Modify creators to dynamically add instance implementations."""
|
||||
|
||||
def __init__(self):
|
||||
# Dynamically add configuration and corresponding instance implementation.
|
||||
creators = {
|
||||
FAISSRetrieverConfig: self._create_faiss_retriever,
|
||||
BM25RetrieverConfig: self._create_bm25_retriever,
|
||||
|
|
@ -61,7 +54,12 @@ class RetrieverFactory(BaseFactory):
|
|||
super().__init__(creators)
|
||||
|
||||
def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
|
||||
"""Creates and returns a retriever instance based on the provided configurations."""
|
||||
"""Creates and returns a retriever instance based on the provided configurations.
|
||||
If multiple retrievers are configured, they are combined using SimpleHybridRetriever.
|
||||
"""
|
||||
if not configs:
|
||||
return self._default_instance(index)
|
||||
|
||||
retrievers = super().get_instances(configs, index=index)
|
||||
|
||||
return (
|
||||
|
|
@ -73,7 +71,7 @@ class RetrieverFactory(BaseFactory):
|
|||
def _default_instance(self, index: BaseIndex) -> RAGRetriever:
|
||||
return index.as_retriever()
|
||||
|
||||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, index: BaseIndex, **kwargs) -> FAISSRetriever:
|
||||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, index: BaseIndex) -> FAISSRetriever:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
vector_index = VectorStoreIndex(
|
||||
|
|
@ -83,13 +81,14 @@ class RetrieverFactory(BaseFactory):
|
|||
)
|
||||
return FAISSRetriever(**config.model_dump(), index=vector_index)
|
||||
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, index: BaseIndex, **kwargs) -> DynamicBM25Retriever:
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, index: BaseIndex) -> DynamicBM25Retriever:
|
||||
return DynamicBM25Retriever.from_defaults(**config.model_dump(), index=index)
|
||||
|
||||
|
||||
class RankerFactory(BaseFactory):
|
||||
"""Modify creators to dynamically add instance implementations."""
|
||||
|
||||
def __init__(self):
|
||||
# Dynamically add configuration and corresponding instance implementation.
|
||||
creators = {
|
||||
LLMRankerConfig: self._create_llm_ranker,
|
||||
}
|
||||
|
|
@ -98,12 +97,16 @@ class RankerFactory(BaseFactory):
|
|||
def get_rankers(
|
||||
self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None
|
||||
) -> list[BaseNodePostprocessor]:
|
||||
"""Creates and returns ranker instances based on the provided configurations."""
|
||||
if not configs:
|
||||
return [self._default_instance(service_context)]
|
||||
|
||||
return super().get_instances(configs, service_context=service_context)
|
||||
|
||||
def _default_instance(self, service_context: ServiceContext = None) -> LLMRerank:
|
||||
def _default_instance(self, service_context: ServiceContext) -> LLMRerank:
|
||||
return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context)
|
||||
|
||||
def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None, **kwargs) -> LLMRerank:
|
||||
def _create_llm_ranker(self, config: LLMRankerConfig, service_context: ServiceContext = None) -> LLMRerank:
|
||||
return LLMRerank(**config.model_dump(), service_context=service_context)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,14 @@
|
|||
from typing import Protocol
|
||||
"""RAG Interface."""
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
class RAGObject(Protocol):
|
||||
"""Support rag add object"""
|
||||
|
||||
def rag_key(self) -> str:
|
||||
"""for rag search"""
|
||||
"""For rag search."""
|
||||
|
||||
def model_dump(self) -> dict[str, Any]:
|
||||
"""For rag persist.
|
||||
Pydantic models don't need to implement this, as they already provide a built-in method named model_dump.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
"""RAG LLM
|
||||
The LLM of LlamaIndex and the LLM of MG are not the same.
|
||||
"""
|
||||
from llama_index.llms import OpenAI
|
||||
|
||||
from metagpt.config2 import config
|
||||
|
||||
|
||||
def get_default_llm() -> OpenAI:
|
||||
"""OpenAI"""
|
||||
return OpenAI(api_base=config.llm.base_url, api_key=config.llm.api_key, model=config.llm.model)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
"""BM25 retriever."""
|
||||
from llama_index.retrievers import BM25Retriever
|
||||
from llama_index.schema import BaseNode
|
||||
|
||||
|
||||
class DynamicBM25Retriever(BM25Retriever):
|
||||
"""BM25 retriever."""
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs):
|
||||
"""Support add nodes"""
|
||||
try:
|
||||
from rank_bm25 import BM25Okapi
|
||||
except ImportError:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
"""FAISS retriever."""
|
||||
from llama_index.retrievers import VectorIndexRetriever
|
||||
from llama_index.schema import BaseNode
|
||||
|
||||
|
||||
class FAISSRetriever(VectorIndexRetriever):
|
||||
"""FAISS retriever."""
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs):
|
||||
"""Support add nodes"""
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
|
|
|||
|
|
@ -38,5 +38,6 @@ class SimpleHybridRetriever(RAGRetriever):
|
|||
return result
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode]):
|
||||
"""Support add nodes"""
|
||||
for r in self.retrievers:
|
||||
r.add_nodes(nodes)
|
||||
|
|
|
|||
|
|
@ -1,28 +1,34 @@
|
|||
"""Retriever schemas"""
|
||||
"""RAG schemas"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RetrieverConfig(BaseModel):
|
||||
similarity_top_k: int = 5
|
||||
"""Common config for retrievers."""
|
||||
|
||||
similarity_top_k: int = Field(default=5, description="Number of top-k similar results to return during retrieval.")
|
||||
|
||||
|
||||
class FAISSRetrieverConfig(RetrieverConfig):
|
||||
dimensions: int = 1536
|
||||
"""Config for FAISS-based retrievers."""
|
||||
|
||||
dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.")
|
||||
|
||||
|
||||
class BM25RetrieverConfig(RetrieverConfig):
|
||||
...
|
||||
"""Config for BM25-based retrievers."""
|
||||
|
||||
|
||||
class RankerConfig(BaseModel):
|
||||
"""Common config for rankers."""
|
||||
|
||||
top_n: int = 5
|
||||
|
||||
|
||||
class LLMRankerConfig(RankerConfig):
|
||||
...
|
||||
"""Config for LLM-based rankers."""
|
||||
|
||||
|
||||
# When adding a new config, the corresponding instance implementation must also be added in rag.factory
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue