replace rag llm factory with llamaindex custom llm

This commit is contained in:
seehi 2024-03-08 20:19:28 +08:00
parent 4712b2136b
commit 9fe9a4a2d1
9 changed files with 87 additions and 142 deletions

View file

@ -4,6 +4,7 @@ import asyncio
from pydantic import BaseModel
from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH
from metagpt.logs import logger
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import (
BM25RetrieverConfig,
@ -85,10 +86,10 @@ class RAGExample:
travel_question = f"{TRAVEL_QUESTION}{LLM_TIP}"
travel_filepath = TRAVEL_DOC_PATH
print("[Before add docs]")
logger.info("[Before add docs]")
await self.rag_pipeline(question=travel_question, print_title=False)
print("[After add docs]")
logger.info("[After add docs]")
self.engine.add_docs([travel_filepath])
await self.rag_pipeline(question=travel_question, print_title=False)
@ -110,19 +111,19 @@ class RAGExample:
player = Player(name="Mike")
question = f"{player.rag_key()}"
print("[Before add objs]")
logger.info("[Before add objs]")
await self._retrieve_and_print(question)
print("[After add objs]")
logger.info("[After add objs]")
self.engine.add_objs([player])
nodes = await self._retrieve_and_print(question)
print("[Object Detail]")
logger.info("[Object Detail]")
try:
player: Player = nodes[0].metadata["obj"]
print(player.name)
logger.info(player.name)
except Exception as e:
print(f"ERROR: nodes is empty, llm don't answer correctly, exception: {e}")
logger.info(f"ERROR: nodes is empty, llm don't answer correctly, exception: {e}")
async def rag_ini_objs(self):
"""This example show how to from objs, will print something like:
@ -162,20 +163,20 @@ class RAGExample:
@staticmethod
def _print_title(title):
print(f"{'#'*50} {title} {'#'*50}")
logger.info(f"{'#'*30} {title} {'#'*30}")
@staticmethod
def _print_result(result, state="Retrieve"):
"""print retrieve or query result"""
print(f"{state} Result:")
logger.info(f"{state} Result:")
if state == "Retrieve":
for i, node in enumerate(result):
print(f"{i}. {node.text[:10]}..., {node.score}")
print()
logger.info(f"{i}. {node.text[:10]}..., {node.score}")
logger.info("")
return
print(f"{result}\n")
logger.info(f"{result}\n")
async def _retrieve_and_print(self, question):
nodes = await self.engine.aretrieve(question)