mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-26 17:26:22 +02:00
rag add objs
This commit is contained in:
parent
cd605bf8f4
commit
a35f13b4c4
3 changed files with 120 additions and 59 deletions
|
|
@ -1,6 +1,8 @@
|
|||
"""RAG pipeline"""
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.const import EXAMPLE_PATH
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.schema import (
|
||||
|
|
@ -13,81 +15,128 @@ DOC_PATH = EXAMPLE_PATH / "data/rag_writer.txt"
|
|||
QUESTION = "What are key qualities to be a good writer?"
|
||||
|
||||
|
||||
def print_result(result, state="Retrieve"):
|
||||
"""print retrieve or query result"""
|
||||
print("-" * 50)
|
||||
print(f"{state} Result:")
|
||||
class RAGExample:
|
||||
def __init__(self):
|
||||
self.engine = SimpleEngine.from_docs(
|
||||
input_files=[DOC_PATH],
|
||||
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
|
||||
ranker_configs=[LLMRankerConfig()],
|
||||
)
|
||||
|
||||
if state == "Retrieve":
|
||||
for i, node in enumerate(result):
|
||||
print(f"{i}. {node.text[:10]}..., {node.score}")
|
||||
return
|
||||
async def rag_pipeline(self, question=QUESTION, print_title=True):
|
||||
"""This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like:
|
||||
|
||||
print(result)
|
||||
Retrieve Result:
|
||||
0. Productivi..., 10.0
|
||||
1. I wrote cu..., 7.0
|
||||
2. I highly r..., 5.0
|
||||
|
||||
Query Result:
|
||||
Passion, adaptability, open-mindedness, creativity, discipline, and empathy are key qualities to be a good writer.
|
||||
"""
|
||||
if print_title:
|
||||
self._print_title("RAG Pipeline")
|
||||
|
||||
def build_engine(input_files: list[str]):
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_files=input_files,
|
||||
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
|
||||
ranker_configs=[LLMRankerConfig()],
|
||||
)
|
||||
return engine
|
||||
nodes = await self.engine.aretrieve(question)
|
||||
self._print_result(nodes, state="Retrieve")
|
||||
|
||||
answer = await self.engine.aquery(question)
|
||||
self._print_result(answer, state="Query")
|
||||
|
||||
async def rag_pipeline(engine: SimpleEngine, question=QUESTION):
|
||||
"""This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like:
|
||||
async def rag_add_docs(self):
|
||||
"""This example show how to add docs, before add docs llm anwser I don't know, after add docs llm give the correct answer, will print something like:
|
||||
|
||||
Retrieve Result:
|
||||
0. Productivi..., 10.0
|
||||
1. I wrote cu..., 7.0
|
||||
2. I highly r..., 5.0
|
||||
--------------------------------------------------
|
||||
Query Result:
|
||||
Passion, adaptability, open-mindedness, creativity, discipline, and empathy are key qualities to be a good writer.
|
||||
"""
|
||||
nodes = await engine.aretrieve(question)
|
||||
print_result(nodes, state="Retrieve")
|
||||
[Before add docs]
|
||||
Retrieve Result:
|
||||
|
||||
answer = await engine.aquery(question)
|
||||
print_result(answer, state="Query")
|
||||
Query Result:
|
||||
Empty Response
|
||||
|
||||
[After add docs]
|
||||
Retrieve Result:
|
||||
0. Bojan like..., 10.0
|
||||
|
||||
async def rag_add_docs(engine: SimpleEngine):
|
||||
"""This example show how to add docs, before add docs llm anwser I don't know, after add docs llm give the correct answer, will print something like:
|
||||
Query Result:
|
||||
Bojan likes traveling.
|
||||
"""
|
||||
self._print_title("RAG Add Docs")
|
||||
|
||||
[Before add docs]
|
||||
--------------------------------------------------
|
||||
Retrieve Result:
|
||||
--------------------------------------------------
|
||||
Query Result:
|
||||
I don't know.
|
||||
travel_question = "What does Bojan like? If you not sure, just answer I don't know"
|
||||
travel_filepath = EXAMPLE_PATH / "data/rag_travel.txt"
|
||||
|
||||
[After add docs]
|
||||
--------------------------------------------------
|
||||
Retrieve Result:
|
||||
0. Bojan like..., 10.0
|
||||
--------------------------------------------------
|
||||
Query Result:
|
||||
Bojan likes traveling.
|
||||
"""
|
||||
travel_question = "What does Bojan like? If you not sure, just answer i don't know"
|
||||
travel_filepath = EXAMPLE_PATH / "data/rag_travel.txt"
|
||||
print("[Before add docs]")
|
||||
await self.rag_pipeline(question=travel_question, print_title=False)
|
||||
|
||||
print("[Before add docs]")
|
||||
await rag_pipeline(engine, question=travel_question)
|
||||
print("[After add docs]")
|
||||
self.engine.add_docs([travel_filepath])
|
||||
await self.rag_pipeline(question=travel_question, print_title=False)
|
||||
|
||||
print("\n[After add docs]")
|
||||
engine.add_docs([travel_filepath])
|
||||
await rag_pipeline(engine, question=travel_question)
|
||||
async def rag_add_objs(self):
|
||||
"""This example show how to add objs, before add docs engine retrieve nothing, after add objs engine give the correct answer, will print something like:
|
||||
[Before add objs]
|
||||
Retrieve Result:
|
||||
|
||||
[After add objs]
|
||||
Retrieve Result:
|
||||
0. 100m Sprin..., 10.0
|
||||
|
||||
[Object Detail]
|
||||
{'name': 'foo', 'goal': 'Win The Game', 'tool': 'Red Bull Energy Drink'}
|
||||
"""
|
||||
|
||||
self._print_title("RAG Add Docs")
|
||||
|
||||
class Player(BaseModel):
|
||||
name: str = ""
|
||||
goal: str = "Win The Game"
|
||||
tool: str = "Red Bull Energy Drink"
|
||||
|
||||
def rag_key(self) -> str:
|
||||
return "100m Sprint"
|
||||
|
||||
foo = Player(name="foo")
|
||||
question = f"{foo.rag_key()}"
|
||||
|
||||
print("[Before add objs]")
|
||||
await self._retrieve_and_print(question)
|
||||
|
||||
print("[After add objs]")
|
||||
self.engine.add_objs([foo])
|
||||
nodes = await self._retrieve_and_print(question)
|
||||
|
||||
print("[Object Detail]")
|
||||
player: Player = nodes[0].metadata["obj"]
|
||||
print(f"{player.model_dump()}")
|
||||
|
||||
@staticmethod
|
||||
def _print_title(title):
|
||||
print(f"{'#'*50} {title} {'#'*50}")
|
||||
|
||||
@staticmethod
|
||||
def _print_result(result, state="Retrieve"):
|
||||
"""print retrieve or query result"""
|
||||
print(f"{state} Result:")
|
||||
|
||||
if state == "Retrieve":
|
||||
for i, node in enumerate(result):
|
||||
print(f"{i}. {node.text[:10]}..., {node.score}")
|
||||
print()
|
||||
return
|
||||
|
||||
print(f"{result}\n")
|
||||
|
||||
async def _retrieve_and_print(self, question):
|
||||
nodes = await self.engine.aretrieve(question)
|
||||
self._print_result(nodes, state="Retrieve")
|
||||
return nodes
|
||||
|
||||
|
||||
async def main():
|
||||
"""RAG pipeline"""
|
||||
engine = build_engine([DOC_PATH])
|
||||
await rag_pipeline(engine)
|
||||
print("#" * 100)
|
||||
await rag_add_docs(engine)
|
||||
e = RAGExample()
|
||||
await e.rag_pipeline()
|
||||
await e.rag_add_docs()
|
||||
await e.rag_add_objs()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue