diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 2f26ff052..daf4014fc 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -19,15 +19,15 @@ QUESTION = "What are key qualities to be a good writer?" TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt" TRAVEL_QUESTION = "What does Bob like?" -LLM_TIP = "If you not sure, just answer I don't know" +LLM_TIP = "If you not sure, just answer I don't know." class Player(BaseModel): - """To demonstrate rag add objs""" + """To demonstrate rag add objs.""" name: str = "" - goal: str = "Win The 100-meter Sprint" - tool: str = "Red Bull Energy Drink" + goal: str = "Win The 100-meter Sprint." + tool: str = "Red Bull Energy Drink." def rag_key(self) -> str: """For search""" @@ -108,7 +108,7 @@ class RAGExample: self._print_title("RAG Add Objs") player = Player(name="Mike") - question = f"{player.rag_key()}{LLM_TIP}" + question = f"{player.rag_key()}" print("[Before add objs]") await self._retrieve_and_print(question) diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 4d47c7084..22351d8fd 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -22,7 +22,6 @@ from llama_index.core.schema import ( NodeWithScore, QueryBundle, QueryType, - TextNode, TransformComponent, ) @@ -35,7 +34,12 @@ from metagpt.rag.factories import ( ) from metagpt.rag.interface import RAGObject from metagpt.rag.retrievers.base import ModifiableRAGRetriever -from metagpt.rag.schema import BaseIndexConfig, BaseRankerConfig, BaseRetrieverConfig +from metagpt.rag.schema import ( + BaseIndexConfig, + BaseRankerConfig, + BaseRetrieverConfig, + ObjectNode, +) from metagpt.utils.common import import_class @@ -149,18 +153,9 @@ class SimpleEngine(RetrieverQueryEngine): """Adds objects to the retriever, storing each object's original form in metadata for future reference.""" self._ensure_retriever_modifiable() - nodes = [TextNode(text=obj.rag_key(), metadata=self._get_obj_metadata(obj)) for obj in objs] + nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] self._save_nodes(nodes) - def _get_obj_metadata(self, obj: RAGObject) -> dict: - metadata = { - "is_obj": True, - "obj_dict": obj.model_dump(), - "obj_cls_name": obj.__class__.__name__, - "obj_mod_name": obj.__class__.__module__, - } - return metadata - def _ensure_retriever_modifiable(self): if not isinstance(self.retriever, ModifiableRAGRetriever): raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}") diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 35e16e286..9657ae846 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -5,8 +5,11 @@ from typing import Any, Union from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex +from llama_index.core.schema import TextNode from pydantic import BaseModel, ConfigDict, Field +from metagpt.rag.interface import RAGObject + class BaseRetrieverConfig(BaseModel): """Common config for retrievers. @@ -84,3 +87,27 @@ class ChromaIndexConfig(VectorIndexConfig): """Config for chroma-based index.""" collection_name: str = Field(default="metagpt", description="The name of the collection.") + + +class ObjectNodeMetadata(BaseModel): + """Metadata of ObjectNode.""" + + is_obj: bool = Field(default=True) + obj_dict: dict = Field(..., description="Inplement rag.interface.RAGObject.model_dump(), e.g. obj.model_dump()") + obj_cls_name: str = Field(..., description="The class name of object, e.g. obj.__class__.__name__") + obj_mod_name: str = Field(..., description="The module name of class, e.g. obj.__class__.__module__") + + +class ObjectNode(TextNode): + """RAG add object.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys()) + + @staticmethod + def get_obj_metadata(obj: RAGObject) -> dict: + metadata = ObjectNodeMetadata( + obj_dict=obj.model_dump(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__ + ) + return metadata.model_dump() diff --git a/tests/metagpt/rag/engines/test_simple.py b/tests/metagpt/rag/engines/test_simple.py index 1d1ddad12..60e72e422 100644 --- a/tests/metagpt/rag/engines/test_simple.py +++ b/tests/metagpt/rag/engines/test_simple.py @@ -1,6 +1,6 @@ import pytest from llama_index.core import VectorStoreIndex -from llama_index.core.schema import TextNode +from llama_index.core.schema import NodeWithScore, TextNode from metagpt.rag.engines import SimpleEngine from metagpt.rag.retrievers.base import ModifiableRAGRetriever @@ -97,7 +97,8 @@ class TestSimpleEngine: mock_super_aretrieve = mocker.patch( "metagpt.rag.engines.simple.RetrieverQueryEngine.aretrieve", new_callable=mocker.AsyncMock ) - mock_super_aretrieve.return_value = ["node_with_score"] + nodes = [NodeWithScore(node=TextNode())] + mock_super_aretrieve.return_value = nodes # Setup engine = SimpleEngine(retriever=mocker.MagicMock()) @@ -109,7 +110,7 @@ class TestSimpleEngine: # Assertions mock_query_bundle.assert_called_once_with(test_query) mock_super_aretrieve.assert_called_once_with("query_bundle") - assert result == ["node_with_score"] + assert result == nodes def test_add_docs(self, mocker): # Mock @@ -157,4 +158,4 @@ class TestSimpleEngine: assert mock_retriever.add_nodes.call_count == 1 for node in mock_retriever.add_nodes.call_args[0][0]: assert isinstance(node, TextNode) - assert "obj" in node.metadata + assert "obj_dict" in node.metadata