mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
add excluded_llm_metadata_keys
This commit is contained in:
parent
f149007752
commit
a3b2cf7f0b
4 changed files with 44 additions and 21 deletions
|
|
@ -19,15 +19,15 @@ QUESTION = "What are key qualities to be a good writer?"
|
|||
TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt"
|
||||
TRAVEL_QUESTION = "What does Bob like?"
|
||||
|
||||
LLM_TIP = "If you not sure, just answer I don't know"
|
||||
LLM_TIP = "If you not sure, just answer I don't know."
|
||||
|
||||
|
||||
class Player(BaseModel):
|
||||
"""To demonstrate rag add objs"""
|
||||
"""To demonstrate rag add objs."""
|
||||
|
||||
name: str = ""
|
||||
goal: str = "Win The 100-meter Sprint"
|
||||
tool: str = "Red Bull Energy Drink"
|
||||
goal: str = "Win The 100-meter Sprint."
|
||||
tool: str = "Red Bull Energy Drink."
|
||||
|
||||
def rag_key(self) -> str:
|
||||
"""For search"""
|
||||
|
|
@ -108,7 +108,7 @@ class RAGExample:
|
|||
self._print_title("RAG Add Objs")
|
||||
|
||||
player = Player(name="Mike")
|
||||
question = f"{player.rag_key()}{LLM_TIP}"
|
||||
question = f"{player.rag_key()}"
|
||||
|
||||
print("[Before add objs]")
|
||||
await self._retrieve_and_print(question)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from llama_index.core.schema import (
|
|||
NodeWithScore,
|
||||
QueryBundle,
|
||||
QueryType,
|
||||
TextNode,
|
||||
TransformComponent,
|
||||
)
|
||||
|
||||
|
|
@ -35,7 +34,12 @@ from metagpt.rag.factories import (
|
|||
)
|
||||
from metagpt.rag.interface import RAGObject
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
from metagpt.rag.schema import BaseIndexConfig, BaseRankerConfig, BaseRetrieverConfig
|
||||
from metagpt.rag.schema import (
|
||||
BaseIndexConfig,
|
||||
BaseRankerConfig,
|
||||
BaseRetrieverConfig,
|
||||
ObjectNode,
|
||||
)
|
||||
from metagpt.utils.common import import_class
|
||||
|
||||
|
||||
|
|
@ -149,18 +153,9 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
|
||||
self._ensure_retriever_modifiable()
|
||||
|
||||
nodes = [TextNode(text=obj.rag_key(), metadata=self._get_obj_metadata(obj)) for obj in objs]
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def _get_obj_metadata(self, obj: RAGObject) -> dict:
|
||||
metadata = {
|
||||
"is_obj": True,
|
||||
"obj_dict": obj.model_dump(),
|
||||
"obj_cls_name": obj.__class__.__name__,
|
||||
"obj_mod_name": obj.__class__.__module__,
|
||||
}
|
||||
return metadata
|
||||
|
||||
def _ensure_retriever_modifiable(self):
|
||||
if not isinstance(self.retriever, ModifiableRAGRetriever):
|
||||
raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}")
|
||||
|
|
|
|||
|
|
@ -5,8 +5,11 @@ from typing import Any, Union
|
|||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from metagpt.rag.interface import RAGObject
|
||||
|
||||
|
||||
class BaseRetrieverConfig(BaseModel):
|
||||
"""Common config for retrievers.
|
||||
|
|
@ -84,3 +87,27 @@ class ChromaIndexConfig(VectorIndexConfig):
|
|||
"""Config for chroma-based index."""
|
||||
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
|
||||
|
||||
class ObjectNodeMetadata(BaseModel):
    """Metadata of ObjectNode.

    Describes a RAG object well enough to refer back to it later: a dump of
    its data plus the class and module names it came from. Keep the field
    names and their order stable; other code enumerates
    ``ObjectNodeMetadata.model_fields`` to obtain the metadata keys.
    """

    # Marker flag; every instance defaults to True.
    is_obj: bool = Field(True)

    # NOTE(review): "Inplement" below is a typo for "Implement". It is left
    # untouched here because the description is emitted in the model's schema.
    obj_dict: dict = Field(
        ...,
        description="Inplement rag.interface.RAGObject.model_dump(), e.g. obj.model_dump()",
    )
    obj_cls_name: str = Field(
        ...,
        description="The class name of object, e.g. obj.__class__.__name__",
    )
    obj_mod_name: str = Field(
        ...,
        description="The module name of class, e.g. obj.__class__.__module__",
    )
|
||||
|
||||
|
||||
class ObjectNode(TextNode):
    """RAG add object.

    A ``TextNode`` whose metadata describes a RAG object. The bookkeeping
    keys (the fields of ``ObjectNodeMetadata``) are always present in
    ``excluded_llm_metadata_keys``.
    """

    def __init__(self, **kwargs):
        """Create the node and register the object metadata keys as LLM-excluded.

        Args:
            **kwargs: Forwarded unchanged to ``TextNode.__init__``.
        """
        super().__init__(**kwargs)

        # Extend instead of overwrite: keys the caller already asked to
        # exclude must survive. With no caller-supplied keys the result is
        # exactly the ObjectNodeMetadata field names, in declaration order.
        excluded = list(self.excluded_llm_metadata_keys or [])
        for key in ObjectNodeMetadata.model_fields:
            if key not in excluded:
                excluded.append(key)
        self.excluded_llm_metadata_keys = excluded

    @staticmethod
    def get_obj_metadata(obj: RAGObject) -> dict:
        """Build the metadata dict describing ``obj``.

        Args:
            obj: The object to describe; its ``model_dump()`` output, class
                name and module name are recorded.

        Returns:
            The ``ObjectNodeMetadata`` for ``obj`` dumped to a plain dict.
        """
        obj_cls = obj.__class__
        metadata = ObjectNodeMetadata(
            obj_dict=obj.model_dump(),
            obj_cls_name=obj_cls.__name__,
            obj_mod_name=obj_cls.__module__,
        )
        return metadata.model_dump()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.core.schema import NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
|
|
@ -97,7 +97,8 @@ class TestSimpleEngine:
|
|||
mock_super_aretrieve = mocker.patch(
|
||||
"metagpt.rag.engines.simple.RetrieverQueryEngine.aretrieve", new_callable=mocker.AsyncMock
|
||||
)
|
||||
mock_super_aretrieve.return_value = ["node_with_score"]
|
||||
nodes = [NodeWithScore(node=TextNode())]
|
||||
mock_super_aretrieve.return_value = nodes
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mocker.MagicMock())
|
||||
|
|
@ -109,7 +110,7 @@ class TestSimpleEngine:
|
|||
# Assertions
|
||||
mock_query_bundle.assert_called_once_with(test_query)
|
||||
mock_super_aretrieve.assert_called_once_with("query_bundle")
|
||||
assert result == ["node_with_score"]
|
||||
assert result == nodes
|
||||
|
||||
def test_add_docs(self, mocker):
|
||||
# Mock
|
||||
|
|
@ -157,4 +158,4 @@ class TestSimpleEngine:
|
|||
assert mock_retriever.add_nodes.call_count == 1
|
||||
for node in mock_retriever.add_nodes.call_args[0][0]:
|
||||
assert isinstance(node, TextNode)
|
||||
assert "obj" in node.metadata
|
||||
assert "obj_dict" in node.metadata
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue