From 93a328de5b92d795af13d048813043cfc8c5b1cc Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 6 Mar 2024 17:39:41 +0800 Subject: [PATCH] rag add chromadb save&load example --- examples/rag_pipeline.py | 29 ++++++++++++++++++++++++++++- metagpt/rag/engines/simple.py | 4 ++-- requirements.txt | 4 ++-- 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 64a83e77c..1151268ed 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -3,10 +3,12 @@ import asyncio from pydantic import BaseModel -from metagpt.const import EXAMPLE_DATA_PATH +from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH from metagpt.rag.engines import SimpleEngine from metagpt.rag.schema import ( BM25RetrieverConfig, + ChromaIndexConfig, + ChromaRetrieverConfig, FAISSRetrieverConfig, LLMRankerConfig, ) @@ -118,6 +120,30 @@ class RAGExample: player: Player = nodes[0].metadata["obj"] print(player) + async def rag_chromadb(self): + """This example shows how to use ChromaDB and how to save and load an index. It will print something like: + + Query Result: + Bob likes traveling. 
+ """ + self._print_title("RAG ChromaDB") + + # save index + output_dir = DATA_PATH / "rag" + SimpleEngine.from_docs( + input_files=[TRAVEL_DOC_PATH], + retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)], + ) + + # load index + engine = SimpleEngine.from_index( + index_config=ChromaIndexConfig(persist_path=output_dir), + ) + + # query + answer = engine.query(TRAVEL_QUESTION) + self._print_result(answer, state="Query") + @staticmethod def _print_title(title): print(f"{'#'*50} {title} {'#'*50}") @@ -147,6 +173,7 @@ async def main(): await e.rag_pipeline() await e.rag_add_docs() await e.rag_add_objs() + await e.rag_chromadb() if __name__ == "__main__": diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 895b7bd1e..556f0f2f2 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -100,7 +100,7 @@ class SimpleEngine(RetrieverQueryEngine): llm: LLM = None, retriever_configs: list[BaseRetrieverConfig] = None, ranker_configs: list[BaseRankerConfig] = None, - ): + ) -> "SimpleEngine": """Load from previously maintained""" index = get_index(index_config, embed_model=embed_model or get_rag_embedding()) return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) @@ -112,7 +112,7 @@ class SimpleEngine(RetrieverQueryEngine): llm: LLM = None, retriever_configs: list[BaseRetrieverConfig] = None, ranker_configs: list[BaseRankerConfig] = None, - ): + ) -> "SimpleEngine": llm = llm or get_rag_llm() retriever = get_retriever(configs=retriever_configs, index=index) rankers = get_rankers(configs=ranker_configs, llm=llm) diff --git a/requirements.txt b/requirements.txt index 6586b3c82..991c318ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ typer==0.9.0 # godot==0.1.1 # google_api_python_client==2.93.0 # Used by search_engine.py lancedb==0.4.0 -llama-index-core==0.10.12 +llama-index-core==0.10.15 llama-index-embeddings-azure-openai==0.1.6 
llama-index-embeddings-huggingface==0.1.3 llama-index-embeddings-openai==0.1.5 @@ -70,7 +70,7 @@ typing-extensions==4.9.0 socksio~=1.0.0 gitignore-parser==0.1.9 # connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/openapi_v3_hello.py -websockets~=12.0 +websockets~=11.0 networkx~=3.2.1 google-generativeai==0.3.2 playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py