add excluded_llm_metadata_keys

This commit is contained in:
seehi 2024-03-07 15:35:28 +08:00
parent f149007752
commit a3b2cf7f0b
4 changed files with 44 additions and 21 deletions

View file

@ -19,15 +19,15 @@ QUESTION = "What are key qualities to be a good writer?"
TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt"
TRAVEL_QUESTION = "What does Bob like?"
LLM_TIP = "If you not sure, just answer I don't know"
LLM_TIP = "If you not sure, just answer I don't know."
class Player(BaseModel):
"""To demonstrate rag add objs"""
"""To demonstrate rag add objs."""
name: str = ""
goal: str = "Win The 100-meter Sprint"
tool: str = "Red Bull Energy Drink"
goal: str = "Win The 100-meter Sprint."
tool: str = "Red Bull Energy Drink."
def rag_key(self) -> str:
"""For search"""
@ -108,7 +108,7 @@ class RAGExample:
self._print_title("RAG Add Objs")
player = Player(name="Mike")
question = f"{player.rag_key()}{LLM_TIP}"
question = f"{player.rag_key()}"
print("[Before add objs]")
await self._retrieve_and_print(question)

View file

@ -22,7 +22,6 @@ from llama_index.core.schema import (
NodeWithScore,
QueryBundle,
QueryType,
TextNode,
TransformComponent,
)
@ -35,7 +34,12 @@ from metagpt.rag.factories import (
)
from metagpt.rag.interface import RAGObject
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
from metagpt.rag.schema import BaseIndexConfig, BaseRankerConfig, BaseRetrieverConfig
from metagpt.rag.schema import (
BaseIndexConfig,
BaseRankerConfig,
BaseRetrieverConfig,
ObjectNode,
)
from metagpt.utils.common import import_class
@ -149,18 +153,9 @@ class SimpleEngine(RetrieverQueryEngine):
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
self._ensure_retriever_modifiable()
nodes = [TextNode(text=obj.rag_key(), metadata=self._get_obj_metadata(obj)) for obj in objs]
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
self._save_nodes(nodes)
def _get_obj_metadata(self, obj: RAGObject) -> dict:
metadata = {
"is_obj": True,
"obj_dict": obj.model_dump(),
"obj_cls_name": obj.__class__.__name__,
"obj_mod_name": obj.__class__.__module__,
}
return metadata
def _ensure_retriever_modifiable(self):
if not isinstance(self.retriever, ModifiableRAGRetriever):
raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}")

View file

@ -5,8 +5,11 @@ from typing import Any, Union
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from pydantic import BaseModel, ConfigDict, Field
from metagpt.rag.interface import RAGObject
class BaseRetrieverConfig(BaseModel):
"""Common config for retrievers.
@ -84,3 +87,27 @@ class ChromaIndexConfig(VectorIndexConfig):
"""Config for chroma-based index."""
collection_name: str = Field(default="metagpt", description="The name of the collection.")
class ObjectNodeMetadata(BaseModel):
"""Metadata of ObjectNode."""
is_obj: bool = Field(default=True)
obj_dict: dict = Field(..., description="Inplement rag.interface.RAGObject.model_dump(), e.g. obj.model_dump()")
obj_cls_name: str = Field(..., description="The class name of object, e.g. obj.__class__.__name__")
obj_mod_name: str = Field(..., description="The module name of class, e.g. obj.__class__.__module__")
class ObjectNode(TextNode):
"""RAG add object."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys())
@staticmethod
def get_obj_metadata(obj: RAGObject) -> dict:
metadata = ObjectNodeMetadata(
obj_dict=obj.model_dump(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
)
return metadata.model_dump()

View file

@ -1,6 +1,6 @@
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode
from llama_index.core.schema import NodeWithScore, TextNode
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
@ -97,7 +97,8 @@ class TestSimpleEngine:
mock_super_aretrieve = mocker.patch(
"metagpt.rag.engines.simple.RetrieverQueryEngine.aretrieve", new_callable=mocker.AsyncMock
)
mock_super_aretrieve.return_value = ["node_with_score"]
nodes = [NodeWithScore(node=TextNode())]
mock_super_aretrieve.return_value = nodes
# Setup
engine = SimpleEngine(retriever=mocker.MagicMock())
@ -109,7 +110,7 @@ class TestSimpleEngine:
# Assertions
mock_query_bundle.assert_called_once_with(test_query)
mock_super_aretrieve.assert_called_once_with("query_bundle")
assert result == ["node_with_score"]
assert result == nodes
def test_add_docs(self, mocker):
# Mock
@ -157,4 +158,4 @@ class TestSimpleEngine:
assert mock_retriever.add_nodes.call_count == 1
for node in mock_retriever.add_nodes.call_args[0][0]:
assert isinstance(node, TextNode)
assert "obj" in node.metadata
assert "obj_dict" in node.metadata