From a35f13b4c4e9c4c54d1306842061b8a117d6988b Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 7 Feb 2024 18:19:22 +0800 Subject: [PATCH] rag add objs --- examples/rag_pipeline.py | 163 ++++++++++++++++++++++------------ metagpt/rag/engines/simple.py | 10 ++- metagpt/rag/interface.py | 6 ++ 3 files changed, 120 insertions(+), 59 deletions(-) create mode 100644 metagpt/rag/interface.py diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index ba8287f4b..3aae9aa70 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -1,6 +1,8 @@ """RAG pipeline""" import asyncio +from pydantic import BaseModel + from metagpt.const import EXAMPLE_PATH from metagpt.rag.engines import SimpleEngine from metagpt.rag.schema import ( @@ -13,81 +15,128 @@ DOC_PATH = EXAMPLE_PATH / "data/rag_writer.txt" QUESTION = "What are key qualities to be a good writer?" -def print_result(result, state="Retrieve"): - """print retrieve or query result""" - print("-" * 50) - print(f"{state} Result:") +class RAGExample: + def __init__(self): + self.engine = SimpleEngine.from_docs( + input_files=[DOC_PATH], + retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()], + ranker_configs=[LLMRankerConfig()], + ) - if state == "Retrieve": - for i, node in enumerate(result): - print(f"{i}. {node.text[:10]}..., {node.score}") - return + async def rag_pipeline(self, question=QUESTION, print_title=True): + """This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like: - print(result) + Retrieve Result: + 0. Productivi..., 10.0 + 1. I wrote cu..., 7.0 + 2. I highly r..., 5.0 + Query Result: + Passion, adaptability, open-mindedness, creativity, discipline, and empathy are key qualities to be a good writer. + """ + if print_title: + self._print_title("RAG Pipeline") -def build_engine(input_files: list[str]): - engine = SimpleEngine.from_docs( - input_files=input_files, - retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()], - ranker_configs=[LLMRankerConfig()], - ) - return engine + nodes = await self.engine.aretrieve(question) + self._print_result(nodes, state="Retrieve") + answer = await self.engine.aquery(question) + self._print_result(answer, state="Query") -async def rag_pipeline(engine: SimpleEngine, question=QUESTION): - """This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like: + async def rag_add_docs(self): + """This example show how to add docs, before add docs llm anwser I don't know, after add docs llm give the correct answer, will print something like: - Retrieve Result: - 0. Productivi..., 10.0 - 1. I wrote cu..., 7.0 - 2. I highly r..., 5.0 - -------------------------------------------------- - Query Result: - Passion, adaptability, open-mindedness, creativity, discipline, and empathy are key qualities to be a good writer. - """ - nodes = await engine.aretrieve(question) - print_result(nodes, state="Retrieve") + [Before add docs] + Retrieve Result: - answer = await engine.aquery(question) - print_result(answer, state="Query") + Query Result: + Empty Response + [After add docs] + Retrieve Result: + 0. Bojan like..., 10.0 -async def rag_add_docs(engine: SimpleEngine): - """This example show how to add docs, before add docs llm anwser I don't know, after add docs llm give the correct answer, will print something like: + Query Result: + Bojan likes traveling. + """ + self._print_title("RAG Add Docs") - [Before add docs] - -------------------------------------------------- - Retrieve Result: - -------------------------------------------------- - Query Result: - I don't know. + travel_question = "What does Bojan like? If you not sure, just answer I don't know" + travel_filepath = EXAMPLE_PATH / "data/rag_travel.txt" - [After add docs] - -------------------------------------------------- - Retrieve Result: - 0. Bojan like..., 10.0 - -------------------------------------------------- - Query Result: - Bojan likes traveling. - """ - travel_question = "What does Bojan like? If you not sure, just answer i don't know" - travel_filepath = EXAMPLE_PATH / "data/rag_travel.txt" + print("[Before add docs]") + await self.rag_pipeline(question=travel_question, print_title=False) - print("[Before add docs]") - await rag_pipeline(engine, question=travel_question) + print("[After add docs]") + self.engine.add_docs([travel_filepath]) + await self.rag_pipeline(question=travel_question, print_title=False) - print("\n[After add docs]") - engine.add_docs([travel_filepath]) - await rag_pipeline(engine, question=travel_question) + async def rag_add_objs(self): + """This example show how to add objs, before add docs engine retrieve nothing, after add objs engine give the correct answer, will print something like: + [Before add objs] + Retrieve Result: + + [After add objs] + Retrieve Result: + 0. 100m Sprin..., 10.0 + + [Object Detail] + {'name': 'foo', 'goal': 'Win The Game', 'tool': 'Red Bull Energy Drink'} + """ + + self._print_title("RAG Add Docs") + + class Player(BaseModel): + name: str = "" + goal: str = "Win The Game" + tool: str = "Red Bull Energy Drink" + + def rag_key(self) -> str: + return "100m Sprint" + + foo = Player(name="foo") + question = f"{foo.rag_key()}" + + print("[Before add objs]") + await self._retrieve_and_print(question) + + print("[After add objs]") + self.engine.add_objs([foo]) + nodes = await self._retrieve_and_print(question) + + print("[Object Detail]") + player: Player = nodes[0].metadata["obj"] + print(f"{player.model_dump()}") + + @staticmethod + def _print_title(title): + print(f"{'#'*50} {title} {'#'*50}") + + @staticmethod + def _print_result(result, state="Retrieve"): + """print retrieve or query result""" + print(f"{state} Result:") + + if state == "Retrieve": + for i, node in enumerate(result): + print(f"{i}. {node.text[:10]}..., {node.score}") + print() + return + + print(f"{result}\n") + + async def _retrieve_and_print(self, question): + nodes = await self.engine.aretrieve(question) + self._print_result(nodes, state="Retrieve") + return nodes async def main(): """RAG pipeline""" - engine = build_engine([DOC_PATH]) - await rag_pipeline(engine) - print("#" * 100) - await rag_add_docs(engine) + e = RAGExample() + await e.rag_pipeline() + await e.rag_add_docs() + await e.rag_add_objs() if __name__ == "__main__": diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index e71cfc439..1b8a63434 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -12,9 +12,10 @@ from llama_index.llms.llm import LLM from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.query_engine import RetrieverQueryEngine from llama_index.response_synthesizers import BaseSynthesizer -from llama_index.schema import NodeWithScore, QueryBundle, QueryType +from llama_index.schema import NodeWithScore, QueryBundle, QueryType, TextNode from metagpt.rag.factory import get_rankers, get_retriever +from metagpt.rag.interface import RAGObject from metagpt.rag.llm import get_default_llm from metagpt.rag.retrievers.base import ModifiableRAGRetriever from metagpt.rag.schema import RankerConfigType, RetrieverConfigType @@ -92,10 +93,15 @@ class SimpleEngine(RetrieverQueryEngine): return await super().aretrieve(query_bundle) def add_docs(self, input_files: list[str]): - """Add docs to retriever. retriever must has add_nodes func""" + """Add docs to retriever. retriever must has add_nodes func.""" if not isinstance(self.retriever, ModifiableRAGRetriever): raise TypeError(f"must be inplement to add_docs: {type(self.retriever)}") documents = SimpleDirectoryReader(input_files=input_files).load_data() nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents) self.retriever.add_nodes(nodes) + + def add_objs(self, obj_list: list[RAGObject]): + """Adds objects to the retriever, storing each object's original form in metadata for future reference.""" + nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj}) for obj in obj_list] + self.retriever.add_nodes(nodes) diff --git a/metagpt/rag/interface.py b/metagpt/rag/interface.py new file mode 100644 index 000000000..7ed2c6b58 --- /dev/null +++ b/metagpt/rag/interface.py @@ -0,0 +1,6 @@ +from typing import Protocol + + +class RAGObject(Protocol): + def rag_key(self) -> str: + """for rag search"""