reconstruct object in rag node

This commit is contained in:
seehi 2024-03-07 12:07:27 +08:00 committed by betterwang
parent f37828c75e
commit 38e8adf9b4
2 changed files with 39 additions and 15 deletions

View file

@ -22,6 +22,18 @@ TRAVEL_QUESTION = "What does Bob like?"
LLM_TIP = "If you not sure, just answer I don't know"
class Player(BaseModel):
"""To demonstrate rag add objs"""
name: str = ""
goal: str = "Win The 100-meter Sprint"
tool: str = "Red Bull Energy Drink"
def rag_key(self) -> str:
"""For search"""
return self.goal
class RAGExample:
"""Show how to use RAG."""
@ -95,17 +107,6 @@ class RAGExample:
self._print_title("RAG Add Objs")
class Player(BaseModel):
"""Player"""
name: str = ""
goal: str = "Win The 100-meter Sprint"
tool: str = "Red Bull Energy Drink"
def rag_key(self) -> str:
"""For search"""
return self.goal
player = Player(name="Mike")
question = f"{player.rag_key()}{LLM_TIP}"
@ -118,7 +119,7 @@ class RAGExample:
print("[Object Detail]")
player: Player = nodes[0].metadata["obj"]
print(player)
print(player.name)
async def rag_chromadb(self):
"""This example show how to use chromadb. how to save and load index. will print something like:

View file

@ -36,6 +36,7 @@ from metagpt.rag.factories import (
from metagpt.rag.interface import RAGObject
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
from metagpt.rag.schema import BaseIndexConfig, BaseRankerConfig, BaseRetrieverConfig
from metagpt.utils.common import import_class
class SimpleEngine(RetrieverQueryEngine):
@ -129,9 +130,12 @@ class SimpleEngine(RetrieverQueryEngine):
return await self.aquery(content)
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Allow query to be str"""
"""Allow query to be str."""
query_bundle = QueryBundle(query) if isinstance(query, str) else query
return await super().aretrieve(query_bundle)
nodes = await super().aretrieve(query_bundle)
self._try_reconstruct_object(nodes)
return nodes
def add_docs(self, input_files: list[str]):
"""Add docs to retriever. retriever must has add_nodes func."""
@ -145,7 +149,18 @@ class SimpleEngine(RetrieverQueryEngine):
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
self._ensure_retriever_modifiable()
nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj.model_dump()}) for obj in objs]
nodes = [
TextNode(
text=obj.rag_key(),
metadata={
"is_obj": True,
"obj_dict": obj.model_dump(),
"obj_cls_name": obj.__class__.__name__,
"obj_mod_name": obj.__class__.__module__,
},
)
for obj in objs
]
self._save_nodes(nodes)
def _ensure_retriever_modifiable(self):
@ -158,3 +173,11 @@ class SimpleEngine(RetrieverQueryEngine):
# for persist
self.index.insert_nodes(nodes)
@staticmethod
def _try_reconstruct_object(nodes: list[NodeWithScore]):
"""If node is object, then dynamically reconstruct object, and save object to node.metadata["obj"]."""
for node in nodes:
if node.metadata.get("is_obj"):
obj_cls = import_class(node.metadata["obj_cls_name"], node.metadata["obj_mod_name"])
node.metadata["obj"] = obj_cls(**node.metadata["obj_dict"])