rag add docs

This commit is contained in:
seehi 2024-02-06 13:56:49 +08:00
parent 3ae422193d
commit 254088b026
15 changed files with 209 additions and 111 deletions

View file

@ -1,3 +1,6 @@
from metagpt.rag.engines.simple import SimpleEngine
"""Engines init"""
__all__ = ["SimpleEngine"]
from metagpt.rag.engines.simple import SimpleEngine

View file

@ -1,6 +1,7 @@
"""Simple Engine."""
from llama_index import ServiceContext, SimpleDirectoryReader
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
from llama_index.embeddings.base import BaseEmbedding
from llama_index.llms.llm import LLM
from llama_index.query_engine import RetrieverQueryEngine
@ -9,26 +10,23 @@ from llama_index.schema import NodeWithScore, QueryBundle, QueryType
from metagpt.rag.llm import get_default_llm
from metagpt.rag.rankers import get_rankers
from metagpt.rag.retrievers import get_retriever
from metagpt.rag.schema import RankerConfig, RetrieverConfig
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.schema import RankerConfigType, RetrieverConfigType
from metagpt.utils.embedding import get_embedding
class SimpleEngine(RetrieverQueryEngine):
"""
SimpleEngine is a search engine that uses a vector index for retrieving documents.
"""
@classmethod
def from_docs(
cls,
input_dir: str = None,
input_files: list = None,
input_files: list[str] = None,
llm: LLM = None,
embed_model: BaseEmbedding = None,
chunk_size: int = None,
chunk_overlap: int = None,
retriever_configs: list[RetrieverConfig] = None,
ranker_configs: list[RankerConfig] = None,
retriever_configs: list[RetrieverConfigType] = None,
ranker_configs: list[RankerConfigType] = None,
) -> "SimpleEngine":
"""This engine is designed to be simple and straightforward
@ -44,8 +42,8 @@ class SimpleEngine(RetrieverQueryEngine):
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
nodes = service_context.node_parser.get_nodes_from_documents(documents)
retriever = get_retriever(nodes, configs=retriever_configs, service_context=service_context)
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
retriever = get_retriever(index, configs=retriever_configs)
rankers = get_rankers(configs=ranker_configs, service_context=service_context)
return SimpleEngine(retriever=retriever, node_postprocessors=rankers)
@ -58,3 +56,8 @@ class SimpleEngine(RetrieverQueryEngine):
"""Allow query to be str"""
query_bundle = QueryBundle(query) if isinstance(query, str) else query
return await super().aretrieve(query_bundle)
def add_docs(self, input_files: list[str]):
    """Load the given files and hand the resulting documents to the engine's retriever."""
    reader = SimpleDirectoryReader(input_files=input_files)
    loaded_documents = reader.load_data()
    # The retriever is expected to expose `add_docs` (the RAGRetriever interface).
    self.retriever.add_docs(loaded_documents)

View file

@ -4,4 +4,4 @@ from metagpt.config2 import config
def get_default_llm() -> OpenAI:
    """Build the default OpenAI LLM client from the global config.

    The base URL, API key and model name are all read from ``config.llm``.
    """
    # A single return: the earlier variant without `model=` made this call unreachable,
    # so the configured model name was never passed to the client.
    return OpenAI(api_base=config.llm.base_url, api_key=config.llm.api_key, model=config.llm.model)

View file

@ -1,34 +1,6 @@
"""init"""
from metagpt.rag.schema import RankerConfig, LLMRankerConfig
from llama_index import ServiceContext
from llama_index.postprocessor import LLMRerank
from llama_index.postprocessor.types import BaseNodePostprocessor
"""Rankers init"""
from metagpt.rag.rankers.factory import get_rankers
def get_rankers(
    configs: list[RankerConfig] = None, service_context: ServiceContext = None
) -> list[BaseNodePostprocessor]:
    """Return one ranker per config, or a single default ranker when no configs are given."""
    if configs:
        return [_get_ranker(ranker_config, service_context) for ranker_config in configs]
    # None or an empty list: fall back to the default ranker.
    return [_default_ranker(service_context)]
def _default_ranker(service_context: ServiceContext = None):
    """Create an LLMRerank using the `top_n` of a default-constructed LLMRankerConfig."""
    default_top_n = LLMRankerConfig().top_n
    return LLMRerank(top_n=default_top_n, service_context=service_context)
def _get_ranker(config: RankerConfig, service_context: ServiceContext = None):
    """Build the ranker registered for the exact type of `config`.

    Raises:
        ValueError: if no creator is registered for the config's type.
    """
    creators = {LLMRankerConfig: _create_llm_ranker}
    builder = creators.get(type(config))
    if not builder:
        raise ValueError(f"Unknown ranker config: {config}")
    return builder(config, service_context)
def _create_llm_ranker(config, service_context=None):
    """Create an LLMRerank whose `top_n` comes from `config`."""
    ranker_kwargs = {"top_n": config.top_n, "service_context": service_context}
    return LLMRerank(**ranker_kwargs)
__all__ = ["get_rankers"]

View file

@ -0,0 +1,36 @@
from llama_index import ServiceContext
from llama_index.postprocessor import LLMRerank
from llama_index.postprocessor.types import BaseNodePostprocessor
from metagpt.rag.schema import LLMRankerConfig, RankerConfigType
class RankerFactory:
    """Create ranker post-processors from ranker configuration objects."""

    def __init__(self):
        # Dispatch table: config class -> bound method that builds the matching ranker.
        self.ranker_creators = {
            LLMRankerConfig: self._create_llm_ranker,
        }

    def get_rankers(
        self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None
    ) -> list[BaseNodePostprocessor]:
        """Return one ranker per config, or a single default ranker when no configs are given."""
        if not configs:
            return [self._default_ranker(service_context)]

        return [self._get_ranker(config, service_context) for config in configs]

    def _default_ranker(self, service_context: ServiceContext = None):
        """Create an LLMRerank using the `top_n` of a default-constructed LLMRankerConfig."""
        return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context)

    def _get_ranker(self, config: RankerConfigType, service_context: ServiceContext = None):
        """Build the ranker registered for `config`'s class or its nearest registered base class.

        Raises:
            ValueError: if neither the config's class nor any of its bases is registered.
        """
        # Walk the MRO (most specific class first) so that a subclass of a registered
        # config is still dispatched, instead of being rejected by an exact-type lookup.
        for config_cls in type(config).__mro__:
            create_func = self.ranker_creators.get(config_cls)
            if create_func:
                return create_func(config, service_context)

        raise ValueError(f"Unknown ranker config: {config}")

    def _create_llm_ranker(self, config, service_context=None):
        """Create an LLMRerank whose `top_n` comes from `config`."""
        return LLMRerank(top_n=config.top_n, service_context=service_context)
# Module-level entry point: the bound `get_rankers` method of one RankerFactory instance
# created at import time, so callers can use it like a plain function.
get_rankers = RankerFactory().get_rankers

View file

@ -1,55 +1,6 @@
"""Retrievers init"""
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.retrievers.factory import get_retriever
__all__ = ["SimpleHybridRetriever", "get_retriever"]
from llama_index import (
ServiceContext,
StorageContext,
VectorStoreIndex,
)
from llama_index.retrievers import BaseRetriever, BM25Retriever, VectorIndexRetriever
from llama_index.schema import BaseNode
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.retrievers.hybrid import SimpleHybridRetriever
from metagpt.rag.schema import RetrieverConfig, FAISSRetrieverConfig, BM25RetrieverConfig
import faiss
def get_retriever(
    nodes: list[BaseNode], configs: list[RetrieverConfig] = None, service_context: ServiceContext = None
) -> BaseRetriever:
    """Build a retriever over `nodes`.

    No configs yields the default retriever, one config yields that retriever,
    and several configs are combined into a SimpleHybridRetriever.
    """
    if not configs:
        return _default_retriever(nodes, service_context)

    built = [_get_retriever(nodes, retriever_config, service_context) for retriever_config in configs]
    if len(built) > 1:
        return SimpleHybridRetriever(*built, service_context=service_context)
    return built[0]
def _default_retriever(nodes: list[BaseNode], service_context: ServiceContext = None) -> BaseRetriever:
    """Index `nodes` in a VectorStoreIndex and return that index's retriever."""
    vector_index = VectorStoreIndex(nodes=nodes, service_context=service_context)
    return vector_index.as_retriever()
def _get_retriever(
    nodes: list[BaseNode], config: RetrieverConfig, service_context: ServiceContext = None
) -> BaseRetriever:
    """Build the retriever registered for the exact type of `config`.

    Raises:
        ValueError: if no creator is registered for the config's type.
    """
    creators = {
        FAISSRetrieverConfig: _create_faiss_retriever,
        BM25RetrieverConfig: _create_bm25_retriever,
    }
    builder = creators.get(type(config))
    if not builder:
        raise ValueError(f"Unknown retriever config: {config}")
    return builder(nodes, config, service_context)
def _create_faiss_retriever(nodes: list[BaseNode], config: RetrieverConfig, service_context: ServiceContext):
    """Index `nodes` in a FAISS-backed VectorStoreIndex and wrap it in a VectorIndexRetriever."""
    # Flat L2 FAISS index sized by the configured embedding dimension.
    flat_l2_index = faiss.IndexFlatL2(config.dimensions)
    storage = StorageContext.from_defaults(vector_store=FaissVectorStore(faiss_index=flat_l2_index))
    index = VectorStoreIndex(nodes=nodes, storage_context=storage, service_context=service_context)
    return VectorIndexRetriever(index=index, similarity_top_k=config.similarity_top_k)
def _create_bm25_retriever(nodes: list[BaseNode], config: RetrieverConfig, service_context: ServiceContext = None):
    """Create a BM25Retriever over `nodes`, passing every dumped config field as a keyword argument."""
    options = config.model_dump()
    return BM25Retriever.from_defaults(nodes=nodes, **options)

View file

@ -3,6 +3,7 @@
from abc import abstractmethod
from llama_index import Document
from llama_index.retrievers import BaseRetriever
from llama_index.schema import NodeWithScore, QueryType
@ -14,5 +15,9 @@ class RAGRetriever(BaseRetriever):
async def _aretrieve(self, query: QueryType) -> list[NodeWithScore]:
"""retrieve nodes"""
@abstractmethod
def add_docs(self, documents: list[Document]) -> None:
    """Add the given documents to this retriever.

    Declared abstract; this declaration carries no implementation of its own.
    """
def _retrieve(self, query: QueryType) -> list[NodeWithScore]:
"""retrieve nodes"""

View file

@ -0,0 +1,14 @@
from llama_index import Document
from llama_index.retrievers import BM25Retriever
class DynamicBM25Retriever(BM25Retriever):
    """BM25 retriever whose corpus can be extended after construction."""

    def add_docs(self, documents: list[Document]):
        """Append `documents` to the retriever's nodes and rebuild the BM25 index.

        Raises:
            ImportError: if the `rank_bm25` package is not installed.
        """
        try:
            from rank_bm25 import BM25Okapi
        except ImportError as err:
            # Chain the original error so the real import failure stays visible in the traceback.
            raise ImportError("Please install rank_bm25: pip install rank-bm25") from err

        if not documents:
            # Nothing new to index: re-tokenizing the unchanged nodes would only redo existing work.
            return

        self._nodes.extend(documents)
        # The whole corpus is re-tokenized and a fresh BM25Okapi is built from it.
        self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
        self.bm25 = BM25Okapi(self._corpus)

View file

@ -0,0 +1,60 @@
import faiss
from llama_index import StorageContext, VectorStoreIndex
from llama_index.indices.base import BaseIndex
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BM25RetrieverConfig,
FAISSRetrieverConfig,
RetrieverConfigType,
)
class RetrieverFactory:
    """Build retrievers on top of an existing index from retriever configuration objects."""

    def __init__(self):
        # Dispatch table: config class -> bound method that builds the matching retriever.
        self.retriever_creators = {
            FAISSRetrieverConfig: self._create_faiss_retriever,
            BM25RetrieverConfig: self._create_bm25_retriever,
        }

    def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
        """Build a retriever for `index`.

        No configs yields the default retriever, one config yields that retriever,
        and several configs are combined into a SimpleHybridRetriever.
        """
        if not configs:
            return self._default_retriever(index)

        retrievers = [self._get_retriever(index, config) for config in configs]
        if len(retrievers) > 1:
            return SimpleHybridRetriever(*retrievers, service_context=index.service_context)
        return retrievers[0]

    def _default_retriever(self, index: BaseIndex) -> RAGRetriever:
        """Return the retriever that `index` itself provides."""
        # NOTE(review): the result comes straight from `index.as_retriever()`; whether it
        # implements `add_docs` like the other retrievers built here is not visible - confirm.
        return index.as_retriever()

    def _get_retriever(self, index: BaseIndex, config: RetrieverConfigType) -> RAGRetriever:
        """Build the retriever registered for `config`'s class or its nearest registered base class.

        Raises:
            ValueError: if neither the config's class nor any of its bases is registered.
        """
        # Walk the MRO (most specific class first) so that a subclass of a registered
        # config is still dispatched, instead of being rejected by an exact-type lookup.
        for config_cls in type(config).__mro__:
            create_func = self.retriever_creators.get(config_cls)
            if create_func:
                return create_func(index, config)

        raise ValueError(f"Unknown retriever config: {config}")

    def _create_faiss_retriever(self, index: BaseIndex, config: FAISSRetrieverConfig):
        """Re-index the nodes of `index`'s docstore in a FAISS-backed index and wrap it in a FAISSRetriever."""
        # Flat L2 FAISS index sized by the configured embedding dimension.
        vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        vector_index = VectorStoreIndex(
            nodes=list(index.docstore.docs.values()),
            storage_context=storage_context,
            service_context=index.service_context,
        )
        return FAISSRetriever(vector_index, **config.model_dump())

    def _create_bm25_retriever(self, index: BaseIndex, config: BM25RetrieverConfig):
        """Create a DynamicBM25Retriever from `index`, passing every dumped config field as a keyword argument."""
        return DynamicBM25Retriever.from_defaults(**config.model_dump(), index=index)
# Module-level entry point: the bound `get_retriever` method of one RetrieverFactory instance
# created at import time, so callers can use it like a plain function.
get_retriever = RetrieverFactory().get_retriever

View file

@ -0,0 +1,8 @@
from llama_index import Document
from llama_index.retrievers import VectorIndexRetriever
class FAISSRetriever(VectorIndexRetriever):
    """Vector index retriever that accepts new documents after creation."""

    def add_docs(self, documents: list[Document]):
        """Insert each document into the retriever's index, one at a time."""
        for doc in documents:
            self._index.insert(doc)

View file

@ -1,5 +1,5 @@
"""Hybrid retriever."""
from llama_index import ServiceContext
from llama_index import Document, ServiceContext
from llama_index.schema import QueryType
from metagpt.rag.retrievers.base import RAGRetriever
@ -36,3 +36,7 @@ class SimpleHybridRetriever(RAGRetriever):
result.append(n)
node_ids.add(n.node.node_id)
return result
def add_docs(self, documents: list[Document]):
    """Forward `documents` to every wrapped retriever, in order."""
    for sub_retriever in self.retrievers:
        sub_retriever.add_docs(documents)

View file

@ -1,5 +1,7 @@
"""Retriever schemas"""
from typing import Union
from pydantic import BaseModel
@ -21,3 +23,7 @@ class RankerConfig(BaseModel):
class LLMRankerConfig(RankerConfig):
    # Adds nothing of its own: every field and behavior is inherited from RankerConfig.
    pass
# Union of the concrete retriever config classes; used in type annotations for retriever configs.
RetrieverConfigType = Union[FAISSRetrieverConfig, BM25RetrieverConfig]
# Alias for the single concrete ranker config class; used in type annotations for ranker configs.
RankerConfigType = LLMRankerConfig