diff --git a/examples/data/rag_travel.txt b/examples/data/rag_travel.txt new file mode 100644 index 000000000..1c738c54a --- /dev/null +++ b/examples/data/rag_travel.txt @@ -0,0 +1 @@ +Bojan likes traveling. \ No newline at end of file diff --git a/examples/data/rag.txt b/examples/data/rag_writer.txt similarity index 100% rename from examples/data/rag.txt rename to examples/data/rag_writer.txt diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index c90b160f3..ba8287f4b 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -9,7 +9,7 @@ from metagpt.rag.schema import ( LLMRankerConfig, ) -DOC_PATH = EXAMPLE_PATH / "data/rag.txt" +DOC_PATH = EXAMPLE_PATH / "data/rag_writer.txt" QUESTION = "What are key qualities to be a good writer?" @@ -26,7 +26,16 @@ def print_result(result, state="Retrieve"): print(result) -async def rag_pipeline(): +def build_engine(input_files: list[str]): + engine = SimpleEngine.from_docs( + input_files=input_files, + retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()], + ranker_configs=[LLMRankerConfig()], + ) + return engine + + +async def rag_pipeline(engine: SimpleEngine, question=QUESTION): """This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like: Retrieve Result: @@ -37,22 +46,48 @@ async def rag_pipeline(): Query Result: Passion, adaptability, open-mindedness, creativity, discipline, and empathy are key qualities to be a good writer. """ - engine = SimpleEngine.from_docs( - input_files=[DOC_PATH], - retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()], - ranker_configs=[LLMRankerConfig()], - ) - - nodes = await engine.aretrieve(QUESTION) + nodes = await engine.aretrieve(question) print_result(nodes, state="Retrieve") - answer = await engine.aquery(QUESTION) + answer = await engine.aquery(question) print_result(answer, state="Query") +async def rag_add_docs(engine: SimpleEngine): + """This example show how to add docs, before add docs llm anwser I don't know, after add docs llm give the correct answer, will print something like: + + [Before add docs] + -------------------------------------------------- + Retrieve Result: + -------------------------------------------------- + Query Result: + I don't know. + + [After add docs] + -------------------------------------------------- + Retrieve Result: + 0. Bojan like..., 10.0 + -------------------------------------------------- + Query Result: + Bojan likes traveling. + """ + travel_question = "What does Bojan like? If you not sure, just answer i don't know" + travel_filepath = EXAMPLE_PATH / "data/rag_travel.txt" + + print("[Before add docs]") + await rag_pipeline(engine, question=travel_question) + + print("\n[After add docs]") + engine.add_docs([travel_filepath]) + await rag_pipeline(engine, question=travel_question) + + async def main(): """RAG pipeline""" - await rag_pipeline() + engine = build_engine([DOC_PATH]) + await rag_pipeline(engine) + print("#" * 100) + await rag_add_docs(engine) if __name__ == "__main__": diff --git a/metagpt/rag/engines/__init__.py b/metagpt/rag/engines/__init__.py index 7b4e37e88..4e862b908 100644 --- a/metagpt/rag/engines/__init__.py +++ b/metagpt/rag/engines/__init__.py @@ -1,3 +1,6 @@ -from metagpt.rag.engines.simple import SimpleEngine +"""Engines init""" __all__ = ["SimpleEngine"] + + +from metagpt.rag.engines.simple import SimpleEngine diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 3f6f15aad..e136b4092 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -1,6 +1,7 @@ """Simple Engine.""" -from llama_index import ServiceContext, SimpleDirectoryReader + +from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex from llama_index.embeddings.base import BaseEmbedding from llama_index.llms.llm import LLM from llama_index.query_engine import RetrieverQueryEngine @@ -9,26 +10,23 @@ from llama_index.schema import NodeWithScore, QueryBundle, QueryType from metagpt.rag.llm import get_default_llm from metagpt.rag.rankers import get_rankers from metagpt.rag.retrievers import get_retriever -from metagpt.rag.schema import RankerConfig, RetrieverConfig +from metagpt.rag.retrievers.base import RAGRetriever +from metagpt.rag.schema import RankerConfigType, RetrieverConfigType from metagpt.utils.embedding import get_embedding class SimpleEngine(RetrieverQueryEngine): - """ - SimpleEngine is a search engine that uses a vector index for retrieving documents. - """ - @classmethod def from_docs( cls, input_dir: str = None, - input_files: list = None, + input_files: list[str] = None, llm: LLM = None, embed_model: BaseEmbedding = None, chunk_size: int = None, chunk_overlap: int = None, - retriever_configs: list[RetrieverConfig] = None, - ranker_configs: list[RankerConfig] = None, + retriever_configs: list[RetrieverConfigType] = None, + ranker_configs: list[RankerConfigType] = None, ) -> "SimpleEngine": """This engine is designed to be simple and straightforward @@ -44,8 +42,8 @@ class SimpleEngine(RetrieverQueryEngine): chunk_size=chunk_size, chunk_overlap=chunk_overlap, ) - nodes = service_context.node_parser.get_nodes_from_documents(documents) - retriever = get_retriever(nodes, configs=retriever_configs, service_context=service_context) + index = VectorStoreIndex.from_documents(documents, service_context=service_context) + retriever = get_retriever(index, configs=retriever_configs) rankers = get_rankers(configs=ranker_configs, service_context=service_context) return SimpleEngine(retriever=retriever, node_postprocessors=rankers) @@ -58,3 +56,8 @@ class SimpleEngine(RetrieverQueryEngine): """Allow query to be str""" query_bundle = QueryBundle(query) if isinstance(query, str) else query return await super().aretrieve(query_bundle) + + def add_docs(self, input_files: list[str]): + documents = SimpleDirectoryReader(input_files=input_files).load_data() + retriever: RAGRetriever = self.retriever + retriever.add_docs(documents) diff --git a/metagpt/rag/llm.py b/metagpt/rag/llm.py index e67be1416..405b29991 100644 --- a/metagpt/rag/llm.py +++ b/metagpt/rag/llm.py @@ -4,4 +4,4 @@ from metagpt.config2 import config def get_default_llm() -> OpenAI: - return OpenAI(api_base=config.llm.base_url, api_key=config.llm.api_key) + return OpenAI(api_base=config.llm.base_url, api_key=config.llm.api_key, model=config.llm.model) diff --git a/metagpt/rag/rankers/__init__.py b/metagpt/rag/rankers/__init__.py index 5bfa866ef..bb14007ba 100644 --- a/metagpt/rag/rankers/__init__.py +++ b/metagpt/rag/rankers/__init__.py @@ -1,34 +1,6 @@ -"""init""" -from metagpt.rag.schema import RankerConfig, LLMRankerConfig -from llama_index import ServiceContext -from llama_index.postprocessor import LLMRerank -from llama_index.postprocessor.types import BaseNodePostprocessor +"""Rankers init""" + +from metagpt.rag.rankers.factory import get_rankers -def get_rankers( - configs: list[RankerConfig] = None, service_context: ServiceContext = None -) -> list[BaseNodePostprocessor]: - if not configs: - return [_default_ranker(service_context)] - - return [_get_ranker(config, service_context) for config in configs] - - -def _default_ranker(service_context: ServiceContext = None): - return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context) - - -def _get_ranker(config: RankerConfig, service_context: ServiceContext = None): - ranker_factory = { - LLMRankerConfig: _create_llm_ranker, - } - - create_func = ranker_factory.get(type(config)) - if create_func: - return create_func(config, service_context) - - raise ValueError(f"Unknown ranker config: {config}") - - -def _create_llm_ranker(config, service_context=None): - return LLMRerank(top_n=config.top_n, service_context=service_context) +__all__ = ["get_rankers"] diff --git a/metagpt/rag/rankers/factory.py b/metagpt/rag/rankers/factory.py new file mode 100644 index 000000000..14dc89604 --- /dev/null +++ b/metagpt/rag/rankers/factory.py @@ -0,0 +1,36 @@ +from llama_index import ServiceContext +from llama_index.postprocessor import LLMRerank +from llama_index.postprocessor.types import BaseNodePostprocessor + +from metagpt.rag.schema import LLMRankerConfig, RankerConfigType + + +class RankerFactory: + def __init__(self): + self.ranker_creators = { + LLMRankerConfig: self._create_llm_ranker, + } + + def get_rankers( + self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None + ) -> list[BaseNodePostprocessor]: + if not configs: + return [self._default_ranker(service_context)] + + return [self._get_ranker(config, service_context) for config in configs] + + def _default_ranker(self, service_context: ServiceContext = None): + return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context) + + def _get_ranker(self, config: RankerConfigType, service_context: ServiceContext = None): + create_func = self.ranker_creators.get(type(config)) + if create_func: + return create_func(config, service_context) + + raise ValueError(f"Unknown ranker config: {config}") + + def _create_llm_ranker(self, config, service_context=None): + return LLMRerank(top_n=config.top_n, service_context=service_context) + + +get_rankers = RankerFactory().get_rankers diff --git a/metagpt/rag/retrievers/__init__.py b/metagpt/rag/retrievers/__init__.py index 3f9098e35..88cb4cc77 100644 --- a/metagpt/rag/retrievers/__init__.py +++ b/metagpt/rag/retrievers/__init__.py @@ -1,55 +1,6 @@ +"""Retrievers init""" + +from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever +from metagpt.rag.retrievers.factory import get_retriever + __all__ = ["SimpleHybridRetriever", "get_retriever"] - -from llama_index import ( - ServiceContext, - StorageContext, - VectorStoreIndex, -) -from llama_index.retrievers import BaseRetriever, BM25Retriever, VectorIndexRetriever -from llama_index.schema import BaseNode -from llama_index.vector_stores.faiss import FaissVectorStore - -from metagpt.rag.retrievers.hybrid import SimpleHybridRetriever -from metagpt.rag.schema import RetrieverConfig, FAISSRetrieverConfig, BM25RetrieverConfig -import faiss - - -def get_retriever( - nodes: list[BaseNode], configs: list[RetrieverConfig] = None, service_context: ServiceContext = None -) -> BaseRetriever: - if not configs: - return _default_retriever(nodes, service_context) - - retrivers = [_get_retriever(nodes, config, service_context) for config in configs] - - return SimpleHybridRetriever(*retrivers, service_context=service_context) if len(retrivers) > 1 else retrivers[0] - - -def _default_retriever(nodes: list[BaseNode], service_context: ServiceContext = None) -> BaseRetriever: - return VectorStoreIndex(nodes=nodes, service_context=service_context).as_retriever() - - -def _get_retriever( - nodes: list[BaseNode], config: RetrieverConfig, service_context: ServiceContext = None -) -> BaseRetriever: - retriever_factory = { - FAISSRetrieverConfig: _create_faiss_retriever, - BM25RetrieverConfig: _create_bm25_retriever, - } - - create_func = retriever_factory.get(type(config)) - if create_func: - return create_func(nodes, config, service_context) - - raise ValueError(f"Unknown retriever config: {config}") - - -def _create_faiss_retriever(nodes: list[BaseNode], config: RetrieverConfig, service_context: ServiceContext): - vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) - storage_context = StorageContext.from_defaults(vector_store=vector_store) - vector_index = VectorStoreIndex(nodes=nodes, storage_context=storage_context, service_context=service_context) - return VectorIndexRetriever(index=vector_index, similarity_top_k=config.similarity_top_k) - - -def _create_bm25_retriever(nodes: list[BaseNode], config: RetrieverConfig, service_context: ServiceContext = None): - return BM25Retriever.from_defaults(**config.model_dump(), nodes=nodes) diff --git a/metagpt/rag/retrievers/base.py b/metagpt/rag/retrievers/base.py index c0291f217..535e427c3 100644 --- a/metagpt/rag/retrievers/base.py +++ b/metagpt/rag/retrievers/base.py @@ -3,6 +3,7 @@ from abc import abstractmethod +from llama_index import Document from llama_index.retrievers import BaseRetriever from llama_index.schema import NodeWithScore, QueryType @@ -14,5 +15,9 @@ class RAGRetriever(BaseRetriever): async def _aretrieve(self, query: QueryType) -> list[NodeWithScore]: """retrieve nodes""" + @abstractmethod + def add_docs(self, documents: list[Document]) -> None: + """add docs""" + def _retrieve(self, query: QueryType) -> list[NodeWithScore]: """retrieve nodes""" diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py new file mode 100644 index 000000000..4141827dd --- /dev/null +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -0,0 +1,14 @@ +from llama_index import Document +from llama_index.retrievers import BM25Retriever + + +class DynamicBM25Retriever(BM25Retriever): + def add_docs(self, documents: list[Document]): + try: + from rank_bm25 import BM25Okapi + except ImportError: + raise ImportError("Please install rank_bm25: pip install rank-bm25") + + self._nodes.extend(documents) + self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes] + self.bm25 = BM25Okapi(self._corpus) diff --git a/metagpt/rag/retrievers/factory.py b/metagpt/rag/retrievers/factory.py new file mode 100644 index 000000000..cde70e219 --- /dev/null +++ b/metagpt/rag/retrievers/factory.py @@ -0,0 +1,60 @@ +import faiss +from llama_index import StorageContext, VectorStoreIndex +from llama_index.indices.base import BaseIndex +from llama_index.vector_stores.faiss import FaissVectorStore + +from metagpt.rag.retrievers.base import RAGRetriever +from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever +from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever +from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever +from metagpt.rag.schema import ( + BM25RetrieverConfig, + FAISSRetrieverConfig, + RetrieverConfigType, +) + + +class RetrieverFactory: + def __init__(self): + self.retriever_creators = { + FAISSRetrieverConfig: self._create_faiss_retriever, + BM25RetrieverConfig: self._create_bm25_retriever, + } + + def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever: + if not configs: + return self._default_retriever(index) + + retrievers = [self._get_retriever(index, config) for config in configs] + + return ( + SimpleHybridRetriever(*retrievers, service_context=index.service_context) + if len(retrievers) > 1 + else retrievers[0] + ) + + def _default_retriever(self, index: BaseIndex) -> RAGRetriever: + return index.as_retriever() + + def _get_retriever(self, index: BaseIndex, config: RetrieverConfigType) -> RAGRetriever: + create_func = self.retriever_creators.get(type(config)) + if create_func: + return create_func(index, config) + + raise ValueError(f"Unknown retriever config: {config}") + + def _create_faiss_retriever(self, index: BaseIndex, config: FAISSRetrieverConfig): + vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) + storage_context = StorageContext.from_defaults(vector_store=vector_store) + vector_index = VectorStoreIndex( + nodes=list(index.docstore.docs.values()), + storage_context=storage_context, + service_context=index.service_context, + ) + return FAISSRetriever(vector_index, **config.model_dump()) + + def _create_bm25_retriever(self, index: BaseIndex, config: BM25RetrieverConfig): + return DynamicBM25Retriever.from_defaults(**config.model_dump(), index=index) + + +get_retriever = RetrieverFactory().get_retriever diff --git a/metagpt/rag/retrievers/faiss_retriever.py b/metagpt/rag/retrievers/faiss_retriever.py new file mode 100644 index 000000000..9888959e1 --- /dev/null +++ b/metagpt/rag/retrievers/faiss_retriever.py @@ -0,0 +1,8 @@ +from llama_index import Document +from llama_index.retrievers import VectorIndexRetriever + + +class FAISSRetriever(VectorIndexRetriever): + def add_docs(self, documents: list[Document]): + for document in documents: + self._index.insert(document) diff --git a/metagpt/rag/retrievers/hybrid.py b/metagpt/rag/retrievers/hybrid_retriever.py similarity index 88% rename from metagpt/rag/retrievers/hybrid.py rename to metagpt/rag/retrievers/hybrid_retriever.py index 701b13aa2..f4e9c3479 100644 --- a/metagpt/rag/retrievers/hybrid.py +++ b/metagpt/rag/retrievers/hybrid_retriever.py @@ -1,5 +1,5 @@ """Hybrid retriever.""" -from llama_index import ServiceContext +from llama_index import Document, ServiceContext from llama_index.schema import QueryType from metagpt.rag.retrievers.base import RAGRetriever @@ -36,3 +36,7 @@ class SimpleHybridRetriever(RAGRetriever): result.append(n) node_ids.add(n.node.node_id) return result + + def add_docs(self, documents: list[Document]): + for r in self.retrievers: + r.add_docs(documents) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index e781cc2ab..9eb76d43d 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -1,5 +1,7 @@ """Retriever schemas""" +from typing import Union + from pydantic import BaseModel @@ -21,3 +23,7 @@ class RankerConfig(BaseModel): class LLMRankerConfig(RankerConfig): ... + + +RetrieverConfigType = Union[FAISSRetrieverConfig, BM25RetrieverConfig] +RankerConfigType = LLMRankerConfig