mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
add NoEmbedding interface
This commit is contained in:
parent
6c95e601a0
commit
468e574ef0
3 changed files with 24 additions and 6 deletions
|
|
@ -2,11 +2,12 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.ingestion.pipeline import run_transformations
|
||||
from llama_index.core.llms import LLM
|
||||
|
|
@ -33,7 +34,7 @@ from metagpt.rag.factories import (
|
|||
get_rankers,
|
||||
get_retriever,
|
||||
)
|
||||
from metagpt.rag.interface import RAGObject
|
||||
from metagpt.rag.interface import NoEmbedding, RAGObject
|
||||
from metagpt.rag.llm import get_rag_llm
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
|
|
@ -105,7 +106,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
index = VectorStoreIndex.from_documents(
|
||||
documents=documents,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=embed_model or get_rag_embedding(),
|
||||
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
|
||||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
|
|
@ -137,7 +138,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
index = VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=embed_model or get_rag_embedding(),
|
||||
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
|
||||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
|
|
@ -151,7 +152,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
"""Load from previously maintained"""
|
||||
index = get_index(index_config, embed_model=embed_model or get_rag_embedding())
|
||||
index = get_index(index_config, embed_model=cls._resolve_embed_model(embed_model, [index_config]))
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
async def asearch(self, content: str, **kwargs) -> str:
|
||||
|
|
@ -249,3 +250,10 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
"""LlamaIndex keep metadata['file_path'], which is unnecessary, maybe deleted in the near future."""
|
||||
for doc in documents:
|
||||
doc.excluded_embed_metadata_keys.append("file_path")
|
||||
|
||||
@staticmethod
|
||||
def _resolve_embed_model(embed_model: BaseEmbedding = None, configs: list[Any] = None) -> BaseEmbedding:
|
||||
if configs and all(isinstance(c, NoEmbedding) for c in configs):
|
||||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
return embed_model or get_rag_embedding()
|
||||
|
|
|
|||
|
|
@ -14,3 +14,9 @@ class RAGObject(Protocol):
|
|||
|
||||
Pydantic Model don't need to implement this, as there is a built-in function named model_dump_json.
|
||||
"""
|
||||
|
||||
|
||||
class NoEmbedding(Protocol):
|
||||
"""Some retriever does not require embeddings, e.g. BM25"""
|
||||
|
||||
_no_embedding: bool
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Any, Union
|
|||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
|
||||
from metagpt.rag.interface import RAGObject
|
||||
|
||||
|
|
@ -36,6 +36,8 @@ class FAISSRetrieverConfig(IndexRetrieverConfig):
|
|||
class BM25RetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for BM25-based retrievers."""
|
||||
|
||||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
|
||||
|
||||
class ChromaRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for Chroma-based retrievers."""
|
||||
|
|
@ -92,6 +94,8 @@ class ChromaIndexConfig(VectorIndexConfig):
|
|||
class BM25IndexConfig(BaseIndexConfig):
|
||||
"""Config for bm25-based index."""
|
||||
|
||||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
|
||||
|
||||
class ObjectNodeMetadata(BaseModel):
|
||||
"""Metadata of ObjectNode."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue