add excluded_llm_metadata_keys

This commit is contained in:
seehi 2024-03-07 15:35:28 +08:00
parent f149007752
commit a3b2cf7f0b
4 changed files with 44 additions and 21 deletions

View file

@ -22,7 +22,6 @@ from llama_index.core.schema import (
NodeWithScore,
QueryBundle,
QueryType,
TextNode,
TransformComponent,
)
@ -35,7 +34,12 @@ from metagpt.rag.factories import (
)
from metagpt.rag.interface import RAGObject
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
from metagpt.rag.schema import BaseIndexConfig, BaseRankerConfig, BaseRetrieverConfig
from metagpt.rag.schema import (
BaseIndexConfig,
BaseRankerConfig,
BaseRetrieverConfig,
ObjectNode,
)
from metagpt.utils.common import import_class
@ -149,18 +153,9 @@ class SimpleEngine(RetrieverQueryEngine):
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
self._ensure_retriever_modifiable()
nodes = [TextNode(text=obj.rag_key(), metadata=self._get_obj_metadata(obj)) for obj in objs]
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
self._save_nodes(nodes)
def _get_obj_metadata(self, obj: RAGObject) -> dict:
metadata = {
"is_obj": True,
"obj_dict": obj.model_dump(),
"obj_cls_name": obj.__class__.__name__,
"obj_mod_name": obj.__class__.__module__,
}
return metadata
def _ensure_retriever_modifiable(self):
if not isinstance(self.retriever, ModifiableRAGRetriever):
raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}")

View file

@ -5,8 +5,11 @@ from typing import Any, Union
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from pydantic import BaseModel, ConfigDict, Field
from metagpt.rag.interface import RAGObject
class BaseRetrieverConfig(BaseModel):
"""Common config for retrievers.
@ -84,3 +87,27 @@ class ChromaIndexConfig(VectorIndexConfig):
"""Config for chroma-based index."""
collection_name: str = Field(default="metagpt", description="The name of the collection.")
class ObjectNodeMetadata(BaseModel):
"""Metadata of ObjectNode."""
is_obj: bool = Field(default=True)
obj_dict: dict = Field(..., description="Inplement rag.interface.RAGObject.model_dump(), e.g. obj.model_dump()")
obj_cls_name: str = Field(..., description="The class name of object, e.g. obj.__class__.__name__")
obj_mod_name: str = Field(..., description="The module name of class, e.g. obj.__class__.__module__")
class ObjectNode(TextNode):
"""RAG add object."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys())
@staticmethod
def get_obj_metadata(obj: RAGObject) -> dict:
metadata = ObjectNodeMetadata(
obj_dict=obj.model_dump(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
)
return metadata.model_dump()