diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 1151268ed..2f26ff052 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -22,6 +22,18 @@ TRAVEL_QUESTION = "What does Bob like?" LLM_TIP = "If you not sure, just answer I don't know" +class Player(BaseModel): + """To demonstrate rag add objs""" + + name: str = "" + goal: str = "Win The 100-meter Sprint" + tool: str = "Red Bull Energy Drink" + + def rag_key(self) -> str: + """For search""" + return self.goal + + class RAGExample: """Show how to use RAG.""" @@ -95,17 +107,6 @@ class RAGExample: self._print_title("RAG Add Objs") - class Player(BaseModel): - """Player""" - - name: str = "" - goal: str = "Win The 100-meter Sprint" - tool: str = "Red Bull Energy Drink" - - def rag_key(self) -> str: - """For search""" - return self.goal - player = Player(name="Mike") question = f"{player.rag_key()}{LLM_TIP}" @@ -118,7 +119,7 @@ class RAGExample: print("[Object Detail]") player: Player = nodes[0].metadata["obj"] - print(player) + print(player.name) async def rag_chromadb(self): """This example show how to use chromadb. how to save and load index. will print something like: diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 556f0f2f2..d5d1fc9c4 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -36,6 +36,7 @@ from metagpt.rag.factories import ( from metagpt.rag.interface import RAGObject from metagpt.rag.retrievers.base import ModifiableRAGRetriever from metagpt.rag.schema import BaseIndexConfig, BaseRankerConfig, BaseRetrieverConfig +from metagpt.utils.common import import_class class SimpleEngine(RetrieverQueryEngine): @@ -129,9 +130,12 @@ class SimpleEngine(RetrieverQueryEngine): return await self.aquery(content) async def aretrieve(self, query: QueryType) -> list[NodeWithScore]: - """Allow query to be str""" + """Allow query to be str.""" query_bundle = QueryBundle(query) if isinstance(query, str) else query - return await super().aretrieve(query_bundle) + + nodes = await super().aretrieve(query_bundle) + self._try_reconstruct_object(nodes) + return nodes def add_docs(self, input_files: list[str]): """Add docs to retriever. retriever must has add_nodes func.""" @@ -145,7 +149,18 @@ class SimpleEngine(RetrieverQueryEngine): """Adds objects to the retriever, storing each object's original form in metadata for future reference.""" self._ensure_retriever_modifiable() - nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj.model_dump()}) for obj in objs] + nodes = [ + TextNode( + text=obj.rag_key(), + metadata={ + "is_obj": True, + "obj_dict": obj.model_dump(), + "obj_cls_name": obj.__class__.__name__, + "obj_mod_name": obj.__class__.__module__, + }, + ) + for obj in objs + ] self._save_nodes(nodes) def _ensure_retriever_modifiable(self): @@ -158,3 +173,11 @@ class SimpleEngine(RetrieverQueryEngine): # for persist self.index.insert_nodes(nodes) + + @staticmethod + def _try_reconstruct_object(nodes: list[NodeWithScore]): + """If node is object, then dynamically reconstruct object, and save object to node.metadata["obj"].""" + for node in nodes: + if node.metadata.get("is_obj"): + obj_cls = import_class(node.metadata["obj_cls_name"], node.metadata["obj_mod_name"]) + node.metadata["obj"] = obj_cls(**node.metadata["obj_dict"])