add NoEmbedding interface

This commit is contained in:
seehi 2024-03-14 20:39:19 +08:00
parent 6c95e601a0
commit 468e574ef0
3 changed files with 24 additions and 6 deletions

View file

@ -2,11 +2,12 @@
import json
import os
from typing import Optional, Union
from typing import Any, Optional, Union
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.ingestion.pipeline import run_transformations
from llama_index.core.llms import LLM
@ -33,7 +34,7 @@ from metagpt.rag.factories import (
get_rankers,
get_retriever,
)
from metagpt.rag.interface import RAGObject
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.llm import get_rag_llm
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
@ -105,7 +106,7 @@ class SimpleEngine(RetrieverQueryEngine):
index = VectorStoreIndex.from_documents(
documents=documents,
transformations=transformations or [SentenceSplitter()],
embed_model=embed_model or get_rag_embedding(),
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@ -137,7 +138,7 @@ class SimpleEngine(RetrieverQueryEngine):
index = VectorStoreIndex(
nodes=nodes,
transformations=transformations or [SentenceSplitter()],
embed_model=embed_model or get_rag_embedding(),
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@ -151,7 +152,7 @@ class SimpleEngine(RetrieverQueryEngine):
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
"""Load from previously maintained"""
index = get_index(index_config, embed_model=embed_model or get_rag_embedding())
index = get_index(index_config, embed_model=cls._resolve_embed_model(embed_model, [index_config]))
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
async def asearch(self, content: str, **kwargs) -> str:
@ -249,3 +250,10 @@ class SimpleEngine(RetrieverQueryEngine):
"""LlamaIndex keep metadata['file_path'], which is unnecessary, maybe deleted in the near future."""
for doc in documents:
doc.excluded_embed_metadata_keys.append("file_path")
@staticmethod
def _resolve_embed_model(embed_model: BaseEmbedding = None, configs: list[Any] = None) -> BaseEmbedding:
if configs and all(isinstance(c, NoEmbedding) for c in configs):
return MockEmbedding(embed_dim=1)
return embed_model or get_rag_embedding()

View file

@ -14,3 +14,9 @@ class RAGObject(Protocol):
Pydantic Model don't need to implement this, as there is a built-in function named model_dump_json.
"""
class NoEmbedding(Protocol):
"""Some retriever does not require embeddings, e.g. BM25"""
_no_embedding: bool

View file

@ -6,7 +6,7 @@ from typing import Any, Union
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from metagpt.rag.interface import RAGObject
@ -36,6 +36,8 @@ class FAISSRetrieverConfig(IndexRetrieverConfig):
class BM25RetrieverConfig(IndexRetrieverConfig):
"""Config for BM25-based retrievers."""
_no_embedding: bool = PrivateAttr(default=True)
class ChromaRetrieverConfig(IndexRetrieverConfig):
"""Config for Chroma-based retrievers."""
@ -92,6 +94,8 @@ class ChromaIndexConfig(VectorIndexConfig):
class BM25IndexConfig(BaseIndexConfig):
"""Config for bm25-based index."""
_no_embedding: bool = PrivateAttr(default=True)
class ObjectNodeMetadata(BaseModel):
"""Metadata of ObjectNode."""