From 468e574ef0ea0d5ba35c1ee1a86e09a21095807e Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 14 Mar 2024 20:39:19 +0800 Subject: [PATCH] add NoEmbedding interface --- metagpt/rag/engines/simple.py | 18 +++++++++++++----- metagpt/rag/interface.py | 6 ++++++ metagpt/rag/schema.py | 6 +++++- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index ebe467ecf..6045a8005 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -2,11 +2,12 @@ import json import os -from typing import Optional, Union +from typing import Any, Optional, Union from llama_index.core import SimpleDirectoryReader, VectorStoreIndex from llama_index.core.callbacks.base import CallbackManager from llama_index.core.embeddings import BaseEmbedding +from llama_index.core.embeddings.mock_embed_model import MockEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.core.ingestion.pipeline import run_transformations from llama_index.core.llms import LLM @@ -33,7 +34,7 @@ from metagpt.rag.factories import ( get_rankers, get_retriever, ) -from metagpt.rag.interface import RAGObject +from metagpt.rag.interface import NoEmbedding, RAGObject from metagpt.rag.llm import get_rag_llm from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever @@ -105,7 +106,7 @@ class SimpleEngine(RetrieverQueryEngine): index = VectorStoreIndex.from_documents( documents=documents, transformations=transformations or [SentenceSplitter()], - embed_model=embed_model or get_rag_embedding(), + embed_model=cls._resolve_embed_model(embed_model, retriever_configs), ) return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) @@ -137,7 +138,7 @@ class SimpleEngine(RetrieverQueryEngine): index = VectorStoreIndex( nodes=nodes, transformations=transformations or [SentenceSplitter()], - embed_model=embed_model or get_rag_embedding(), + embed_model=cls._resolve_embed_model(embed_model, retriever_configs), ) return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) @@ -151,7 +152,7 @@ class SimpleEngine(RetrieverQueryEngine): ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": """Load from previously maintained""" - index = get_index(index_config, embed_model=embed_model or get_rag_embedding()) + index = get_index(index_config, embed_model=cls._resolve_embed_model(embed_model, [index_config])) return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs) async def asearch(self, content: str, **kwargs) -> str: @@ -249,3 +250,10 @@ class SimpleEngine(RetrieverQueryEngine): """LlamaIndex keep metadata['file_path'], which is unnecessary, maybe deleted in the near future.""" for doc in documents: doc.excluded_embed_metadata_keys.append("file_path") + + @staticmethod + def _resolve_embed_model(embed_model: BaseEmbedding = None, configs: list[Any] = None) -> BaseEmbedding: + if configs and all(isinstance(c, NoEmbedding) for c in configs): + return MockEmbedding(embed_dim=1) + + return embed_model or get_rag_embedding() diff --git a/metagpt/rag/interface.py b/metagpt/rag/interface.py index 9af2c1219..726f68772 100644 --- a/metagpt/rag/interface.py +++ b/metagpt/rag/interface.py @@ -14,3 +14,9 @@ class RAGObject(Protocol): Pydantic Model don't need to implement this, as there is a built-in function named model_dump_json. """ + + +class NoEmbedding(Protocol): + """Some retriever does not require embeddings, e.g. BM25""" + + _no_embedding: bool diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index ade4b3def..2894dc05a 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -6,7 +6,7 @@ from typing import Any, Union from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.core.schema import TextNode -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from metagpt.rag.interface import RAGObject @@ -36,6 +36,8 @@ class FAISSRetrieverConfig(IndexRetrieverConfig): class BM25RetrieverConfig(IndexRetrieverConfig): """Config for BM25-based retrievers.""" + _no_embedding: bool = PrivateAttr(default=True) + class ChromaRetrieverConfig(IndexRetrieverConfig): """Config for Chroma-based retrievers.""" @@ -92,6 +94,8 @@ class ChromaIndexConfig(VectorIndexConfig): class BM25IndexConfig(BaseIndexConfig): """Config for bm25-based index.""" + _no_embedding: bool = PrivateAttr(default=True) + class ObjectNodeMetadata(BaseModel): """Metadata of ObjectNode."""