mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-09 15:52:38 +02:00
add excluded_llm_metadata_keys
This commit is contained in:
parent
f149007752
commit
a3b2cf7f0b
4 changed files with 44 additions and 21 deletions
|
|
@ -22,7 +22,6 @@ from llama_index.core.schema import (
|
|||
NodeWithScore,
|
||||
QueryBundle,
|
||||
QueryType,
|
||||
TextNode,
|
||||
TransformComponent,
|
||||
)
|
||||
|
||||
|
|
@ -35,7 +34,12 @@ from metagpt.rag.factories import (
|
|||
)
|
||||
from metagpt.rag.interface import RAGObject
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
from metagpt.rag.schema import BaseIndexConfig, BaseRankerConfig, BaseRetrieverConfig
|
||||
from metagpt.rag.schema import (
|
||||
BaseIndexConfig,
|
||||
BaseRankerConfig,
|
||||
BaseRetrieverConfig,
|
||||
ObjectNode,
|
||||
)
|
||||
from metagpt.utils.common import import_class
|
||||
|
||||
|
||||
|
|
@ -149,18 +153,9 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
|
||||
self._ensure_retriever_modifiable()
|
||||
|
||||
nodes = [TextNode(text=obj.rag_key(), metadata=self._get_obj_metadata(obj)) for obj in objs]
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def _get_obj_metadata(self, obj: RAGObject) -> dict:
|
||||
metadata = {
|
||||
"is_obj": True,
|
||||
"obj_dict": obj.model_dump(),
|
||||
"obj_cls_name": obj.__class__.__name__,
|
||||
"obj_mod_name": obj.__class__.__module__,
|
||||
}
|
||||
return metadata
|
||||
|
||||
def _ensure_retriever_modifiable(self):
|
||||
if not isinstance(self.retriever, ModifiableRAGRetriever):
|
||||
raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}")
|
||||
|
|
|
|||
|
|
@ -5,8 +5,11 @@ from typing import Any, Union
|
|||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from metagpt.rag.interface import RAGObject
|
||||
|
||||
|
||||
class BaseRetrieverConfig(BaseModel):
|
||||
"""Common config for retrievers.
|
||||
|
|
@ -84,3 +87,27 @@ class ChromaIndexConfig(VectorIndexConfig):
|
|||
"""Config for chroma-based index."""
|
||||
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
|
||||
|
||||
class ObjectNodeMetadata(BaseModel):
|
||||
"""Metadata of ObjectNode."""
|
||||
|
||||
is_obj: bool = Field(default=True)
|
||||
obj_dict: dict = Field(..., description="Inplement rag.interface.RAGObject.model_dump(), e.g. obj.model_dump()")
|
||||
obj_cls_name: str = Field(..., description="The class name of object, e.g. obj.__class__.__name__")
|
||||
obj_mod_name: str = Field(..., description="The module name of class, e.g. obj.__class__.__module__")
|
||||
|
||||
|
||||
class ObjectNode(TextNode):
|
||||
"""RAG add object."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys())
|
||||
|
||||
@staticmethod
|
||||
def get_obj_metadata(obj: RAGObject) -> dict:
|
||||
metadata = ObjectNodeMetadata(
|
||||
obj_dict=obj.model_dump(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
|
||||
)
|
||||
return metadata.model_dump()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue