mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-12 01:02:37 +02:00
Merge branch 'main' into feature_mute_stream_log_for_info_level
This commit is contained in:
commit
67130fc2d6
15 changed files with 325 additions and 37 deletions
|
|
@ -11,9 +11,13 @@ from metagpt.rag.schema import (
|
|||
BM25RetrieverConfig,
|
||||
ChromaIndexConfig,
|
||||
ChromaRetrieverConfig,
|
||||
ElasticsearchIndexConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
ElasticsearchStoreConfig,
|
||||
FAISSRetrieverConfig,
|
||||
LLMRankerConfig,
|
||||
)
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
|
||||
QUESTION = "What are key qualities to be a good writer?"
|
||||
|
|
@ -39,12 +43,22 @@ class Player(BaseModel):
|
|||
class RAGExample:
|
||||
"""Show how to use RAG."""
|
||||
|
||||
def __init__(self):
|
||||
self.engine = SimpleEngine.from_docs(
|
||||
input_files=[DOC_PATH],
|
||||
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
|
||||
ranker_configs=[LLMRankerConfig()],
|
||||
)
|
||||
def __init__(self, engine: SimpleEngine = None):
|
||||
self._engine = engine
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
if not self._engine:
|
||||
self._engine = SimpleEngine.from_docs(
|
||||
input_files=[DOC_PATH],
|
||||
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
|
||||
ranker_configs=[LLMRankerConfig()],
|
||||
)
|
||||
return self._engine
|
||||
|
||||
@engine.setter
|
||||
def engine(self, value: SimpleEngine):
|
||||
self._engine = value
|
||||
|
||||
async def run_pipeline(self, question=QUESTION, print_title=True):
|
||||
"""This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like:
|
||||
|
|
@ -97,6 +111,7 @@ class RAGExample:
|
|||
self.engine.add_docs([travel_filepath])
|
||||
await self.run_pipeline(question=travel_question, print_title=False)
|
||||
|
||||
@handle_exception
|
||||
async def add_objects(self, print_title=True):
|
||||
"""This example show how to add objects.
|
||||
|
||||
|
|
@ -154,20 +169,41 @@ class RAGExample:
|
|||
"""
|
||||
self._print_title("Init And Query ChromaDB")
|
||||
|
||||
# save index
|
||||
# 1. save index
|
||||
output_dir = DATA_PATH / "rag"
|
||||
SimpleEngine.from_docs(
|
||||
input_files=[TRAVEL_DOC_PATH],
|
||||
retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)],
|
||||
)
|
||||
|
||||
# load index
|
||||
engine = SimpleEngine.from_index(
|
||||
index_config=ChromaIndexConfig(persist_path=output_dir),
|
||||
# 2. load index
|
||||
engine = SimpleEngine.from_index(index_config=ChromaIndexConfig(persist_path=output_dir))
|
||||
|
||||
# 3. query
|
||||
answer = await engine.aquery(TRAVEL_QUESTION)
|
||||
self._print_query_result(answer)
|
||||
|
||||
@handle_exception
|
||||
async def init_and_query_es(self):
|
||||
"""This example show how to use es. how to save and load index. will print something like:
|
||||
|
||||
Query Result:
|
||||
Bob likes traveling.
|
||||
"""
|
||||
self._print_title("Init And Query Elasticsearch")
|
||||
|
||||
# 1. create es index and save docs
|
||||
store_config = ElasticsearchStoreConfig(index_name="travel", es_url="http://127.0.0.1:9200")
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_files=[TRAVEL_DOC_PATH],
|
||||
retriever_configs=[ElasticsearchRetrieverConfig(store_config=store_config)],
|
||||
)
|
||||
|
||||
# query
|
||||
answer = engine.query(TRAVEL_QUESTION)
|
||||
# 2. load index
|
||||
engine = SimpleEngine.from_index(index_config=ElasticsearchIndexConfig(store_config=store_config))
|
||||
|
||||
# 3. query
|
||||
answer = await engine.aquery(TRAVEL_QUESTION)
|
||||
self._print_query_result(answer)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -205,6 +241,7 @@ async def main():
|
|||
await e.add_objects()
|
||||
await e.init_objects()
|
||||
await e.init_and_query_chromadb()
|
||||
await e.init_and_query_es()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Engines init"""
|
||||
|
||||
from metagpt.rag.engines.simple import SimpleEngine
|
||||
from metagpt.rag.engines.flare import FLAREEngine
|
||||
|
||||
__all__ = ["SimpleEngine"]
|
||||
__all__ = ["SimpleEngine", "FLAREEngine"]
|
||||
|
|
|
|||
9
metagpt/rag/engines/flare.py
Normal file
9
metagpt/rag/engines/flare.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
"""FLARE Engine.
|
||||
|
||||
Use llamaindex's FLAREInstructQueryEngine as FLAREEngine, which accepts other engines as parameters.
|
||||
For example, Create a simple engine, and then pass it to FLAREEngine.
|
||||
"""
|
||||
|
||||
from llama_index.core.query_engine import ( # noqa: F401
|
||||
FLAREInstructQueryEngine as FLAREEngine,
|
||||
)
|
||||
|
|
@ -130,10 +130,12 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
|
||||
ranker_configs: Configuration for rankers.
|
||||
"""
|
||||
objs = objs or []
|
||||
retriever_configs = retriever_configs or []
|
||||
|
||||
if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
|
||||
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
|
||||
|
||||
objs = objs or []
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
index = VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class ConfigBasedFactory(GenericFactory):
|
|||
if creator:
|
||||
return creator(key, **kwargs)
|
||||
|
||||
raise ValueError(f"Unknown config: {key}")
|
||||
raise ValueError(f"Unknown config: `{type(key)}`, {key}")
|
||||
|
||||
@staticmethod
|
||||
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import chromadb
|
|||
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory
|
||||
|
|
@ -11,6 +13,8 @@ from metagpt.rag.schema import (
|
|||
BaseIndexConfig,
|
||||
BM25IndexConfig,
|
||||
ChromaIndexConfig,
|
||||
ElasticsearchIndexConfig,
|
||||
ElasticsearchKeywordIndexConfig,
|
||||
FAISSIndexConfig,
|
||||
)
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
|
|
@ -22,6 +26,8 @@ class RAGIndexFactory(ConfigBasedFactory):
|
|||
FAISSIndexConfig: self._create_faiss,
|
||||
ChromaIndexConfig: self._create_chroma,
|
||||
BM25IndexConfig: self._create_bm25,
|
||||
ElasticsearchIndexConfig: self._create_es,
|
||||
ElasticsearchKeywordIndexConfig: self._create_es,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
|
|
@ -30,31 +36,44 @@ class RAGIndexFactory(ConfigBasedFactory):
|
|||
return super().get_instance(config, **kwargs)
|
||||
|
||||
def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
embed_model = self._extract_embed_model(config, **kwargs)
|
||||
|
||||
vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)
|
||||
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
|
||||
return index
|
||||
|
||||
return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)
|
||||
|
||||
def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)
|
||||
|
||||
return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)
|
||||
|
||||
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
embed_model = self._extract_embed_model(config, **kwargs)
|
||||
|
||||
db = chromadb.PersistentClient(str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name)
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
vector_store,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
return index
|
||||
|
||||
def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)
|
||||
|
||||
def _create_es(self, config: ElasticsearchIndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
vector_store = ElasticsearchStore(**config.store_config.model_dump())
|
||||
|
||||
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)
|
||||
|
||||
def _index_from_storage(
|
||||
self, storage_context: StorageContext, config: BaseIndexConfig, **kwargs
|
||||
) -> VectorStoreIndex:
|
||||
embed_model = self._extract_embed_model(config, **kwargs)
|
||||
|
||||
storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)
|
||||
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
|
||||
return index
|
||||
return load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
|
||||
|
||||
def _index_from_vector_store(
|
||||
self, vector_store: BasePydanticVectorStore, config: BaseIndexConfig, **kwargs
|
||||
) -> VectorStoreIndex:
|
||||
embed_model = self._extract_embed_model(config, **kwargs)
|
||||
|
||||
return VectorStoreIndex.from_vector_store(
|
||||
vector_store=vector_store,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
|
||||
def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
|
||||
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,9 @@ class RAGLLM(CustomLLM):
|
|||
@property
|
||||
def metadata(self) -> LLMMetadata:
|
||||
"""Get LLM metadata."""
|
||||
return LLMMetadata(context_window=self.context_window, num_output=self.num_output, model_name=self.model_name)
|
||||
return LLMMetadata(
|
||||
context_window=self.context_window, num_output=self.num_output, model_name=self.model_name or "unknown"
|
||||
)
|
||||
|
||||
@llm_completion_callback()
|
||||
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
|
||||
|
|
|
|||
|
|
@ -3,9 +3,16 @@
|
|||
from llama_index.core.llms import LLM
|
||||
from llama_index.core.postprocessor import LLMRerank
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.postprocessor.colbert_rerank import ColbertRerank
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory
|
||||
from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig
|
||||
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
|
||||
from metagpt.rag.schema import (
|
||||
BaseRankerConfig,
|
||||
ColbertRerankConfig,
|
||||
LLMRankerConfig,
|
||||
ObjectRankerConfig,
|
||||
)
|
||||
|
||||
|
||||
class RankerFactory(ConfigBasedFactory):
|
||||
|
|
@ -14,6 +21,8 @@ class RankerFactory(ConfigBasedFactory):
|
|||
def __init__(self):
|
||||
creators = {
|
||||
LLMRankerConfig: self._create_llm_ranker,
|
||||
ColbertRerankConfig: self._create_colbert_ranker,
|
||||
ObjectRankerConfig: self._create_object_ranker,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
|
|
@ -28,6 +37,12 @@ class RankerFactory(ConfigBasedFactory):
|
|||
config.llm = self._extract_llm(config, **kwargs)
|
||||
return LLMRerank(**config.model_dump())
|
||||
|
||||
def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> LLMRerank:
|
||||
return ColbertRerank(**config.model_dump())
|
||||
|
||||
def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank:
|
||||
return ObjectSortPostprocessor(**config.model_dump())
|
||||
|
||||
def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
|
||||
return self._val_from_config_or_kwargs("llm", config, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,18 +6,22 @@ import chromadb
|
|||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
|
||||
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BaseRetrieverConfig,
|
||||
BM25RetrieverConfig,
|
||||
ChromaRetrieverConfig,
|
||||
ElasticsearchKeywordRetrieverConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
IndexRetrieverConfig,
|
||||
)
|
||||
|
|
@ -32,6 +36,8 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
FAISSRetrieverConfig: self._create_faiss_retriever,
|
||||
BM25RetrieverConfig: self._create_bm25_retriever,
|
||||
ChromaRetrieverConfig: self._create_chroma_retriever,
|
||||
ElasticsearchRetrieverConfig: self._create_es_retriever,
|
||||
ElasticsearchKeywordRetrieverConfig: self._create_es_retriever,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
|
|
@ -53,20 +59,29 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
return FAISSRetriever(**config.model_dump())
|
||||
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
|
||||
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
|
||||
nodes = list(config.index.docstore.docs.values())
|
||||
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
|
||||
|
||||
return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
|
||||
|
||||
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
|
||||
db = chromadb.PersistentClient(path=str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name)
|
||||
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
return ChromaRetriever(**config.model_dump())
|
||||
|
||||
def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
|
||||
vector_store = ElasticsearchStore(**config.store_config.model_dump())
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
return ElasticsearchRetriever(**config.model_dump())
|
||||
|
||||
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
|
||||
return self._val_from_config_or_kwargs("index", config, **kwargs)
|
||||
|
||||
|
|
|
|||
55
metagpt/rag/rankers/object_ranker.py
Normal file
55
metagpt/rag/rankers/object_ranker.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
"""Object ranker."""
|
||||
|
||||
import heapq
|
||||
import json
|
||||
from typing import Literal, Optional
|
||||
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.rag.schema import ObjectNode
|
||||
|
||||
|
||||
class ObjectSortPostprocessor(BaseNodePostprocessor):
|
||||
"""Sorted by object's field, desc or asc.
|
||||
|
||||
Assumes nodes is list of ObjectNode with score.
|
||||
"""
|
||||
|
||||
field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
|
||||
order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")
|
||||
top_n: int = 5
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "ObjectSortPostprocessor"
|
||||
|
||||
def _postprocess_nodes(
|
||||
self,
|
||||
nodes: list[NodeWithScore],
|
||||
query_bundle: Optional[QueryBundle] = None,
|
||||
) -> list[NodeWithScore]:
|
||||
"""Postprocess nodes."""
|
||||
if query_bundle is None:
|
||||
raise ValueError("Missing query bundle in extra info.")
|
||||
|
||||
if not nodes:
|
||||
return []
|
||||
|
||||
self._check_metadata(nodes[0].node)
|
||||
|
||||
sort_key = lambda node: json.loads(node.node.metadata["obj_json"])[self.field_name]
|
||||
return self._get_sort_func()(self.top_n, nodes, key=sort_key)
|
||||
|
||||
def _check_metadata(self, node: ObjectNode):
|
||||
try:
|
||||
obj_dict = json.loads(node.metadata.get("obj_json"))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid object json in metadata: {node.metadata}, error: {e}")
|
||||
|
||||
if self.field_name not in obj_dict:
|
||||
raise ValueError(f"Field '{self.field_name}' not found in object: {obj_dict}")
|
||||
|
||||
def _get_sort_func(self):
|
||||
return heapq.nlargest if self.order == "desc" else heapq.nsmallest
|
||||
17
metagpt/rag/retrievers/es_retriever.py
Normal file
17
metagpt/rag/retrievers/es_retriever.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""Elasticsearch retriever."""
|
||||
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.schema import BaseNode
|
||||
|
||||
|
||||
class ElasticsearchRetriever(VectorIndexRetriever):
|
||||
"""Elasticsearch retriever."""
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""Support add nodes."""
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
"""Support persist.
|
||||
|
||||
Elasticsearch automatically saves, so there is no need to implement."""
|
||||
|
|
@ -8,7 +8,7 @@ class FAISSRetriever(VectorIndexRetriever):
|
|||
"""FAISS retriever."""
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""Support add nodes"""
|
||||
"""Support add nodes."""
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
"""RAG schemas."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.core.vector_stores.types import VectorStoreQueryMode
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
|
||||
from metagpt.rag.interface import RAGObject
|
||||
|
|
@ -46,6 +47,35 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
|
|||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
|
||||
|
||||
class ElasticsearchStoreConfig(BaseModel):
|
||||
index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.")
|
||||
es_url: str = Field(default=None, description="Elasticsearch URL.")
|
||||
es_cloud_id: str = Field(default=None, description="Elasticsearch cloud ID.")
|
||||
es_api_key: str = Field(default=None, description="Elasticsearch API key.")
|
||||
es_user: str = Field(default=None, description="Elasticsearch username.")
|
||||
es_password: str = Field(default=None, description="Elasticsearch password.")
|
||||
batch_size: int = Field(default=200, description="Batch size for bulk indexing.")
|
||||
distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.")
|
||||
|
||||
|
||||
class ElasticsearchRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for Elasticsearch-based retrievers. Support both vector and text."""
|
||||
|
||||
store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
|
||||
vector_store_query_mode: VectorStoreQueryMode = Field(
|
||||
default=VectorStoreQueryMode.DEFAULT, description="default is vector query."
|
||||
)
|
||||
|
||||
|
||||
class ElasticsearchKeywordRetrieverConfig(ElasticsearchRetrieverConfig):
|
||||
"""Config for Elasticsearch-based retrievers. Support text only."""
|
||||
|
||||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
vector_store_query_mode: Literal[VectorStoreQueryMode.TEXT_SEARCH] = Field(
|
||||
default=VectorStoreQueryMode.TEXT_SEARCH, description="text query only."
|
||||
)
|
||||
|
||||
|
||||
class BaseRankerConfig(BaseModel):
|
||||
"""Common config for rankers.
|
||||
|
||||
|
|
@ -53,7 +83,6 @@ class BaseRankerConfig(BaseModel):
|
|||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
top_n: int = Field(default=5, description="The number of top results to return.")
|
||||
|
||||
|
||||
|
|
@ -66,12 +95,24 @@ class LLMRankerConfig(BaseRankerConfig):
|
|||
)
|
||||
|
||||
|
||||
class ColbertRerankConfig(BaseRankerConfig):
|
||||
model: str = Field(default="colbert-ir/colbertv2.0", description="Colbert model name.")
|
||||
device: str = Field(default="cpu", description="Device to use for sentence transformer.")
|
||||
keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.")
|
||||
|
||||
|
||||
class ObjectRankerConfig(BaseRankerConfig):
|
||||
field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
|
||||
order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")
|
||||
|
||||
|
||||
class BaseIndexConfig(BaseModel):
|
||||
"""Common config for index.
|
||||
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
persist_path: Union[str, Path] = Field(description="The directory of saved data.")
|
||||
|
||||
|
||||
|
|
@ -97,6 +138,19 @@ class BM25IndexConfig(BaseIndexConfig):
|
|||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
|
||||
|
||||
class ElasticsearchIndexConfig(VectorIndexConfig):
|
||||
"""Config for es-based index."""
|
||||
|
||||
store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
|
||||
persist_path: Union[str, Path] = ""
|
||||
|
||||
|
||||
class ElasticsearchKeywordIndexConfig(ElasticsearchIndexConfig):
|
||||
"""Config for es-based index. no embedding."""
|
||||
|
||||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
|
||||
|
||||
class ObjectNodeMetadata(BaseModel):
|
||||
"""Metadata of ObjectNode."""
|
||||
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -36,6 +36,8 @@ extras_require = {
|
|||
"llama-index-readers-file==0.1.4",
|
||||
"llama-index-retrievers-bm25==0.1.3",
|
||||
"llama-index-vector-stores-faiss==0.1.1",
|
||||
"llama-index-vector-stores-elasticsearch==0.1.6",
|
||||
"llama-index-postprocessor-colbert-rerank==0.1.1",
|
||||
"chromadb==0.4.23",
|
||||
],
|
||||
}
|
||||
|
|
|
|||
60
tests/metagpt/rag/rankers/test_object_ranker.py
Normal file
60
tests/metagpt/rag/rankers/test_object_ranker.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
|
||||
from metagpt.rag.schema import ObjectNode
|
||||
|
||||
|
||||
class Record(BaseModel):
|
||||
score: int
|
||||
|
||||
|
||||
class TestObjectSortPostprocessor:
|
||||
@pytest.fixture
|
||||
def nodes_with_scores(self):
|
||||
nodes = [
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10),
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20),
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=5).model_dump_json()}), score=5),
|
||||
]
|
||||
return nodes
|
||||
|
||||
@pytest.fixture
|
||||
def query_bundle(self, mocker):
|
||||
return mocker.MagicMock(spec=QueryBundle)
|
||||
|
||||
def test_sort_descending(self, nodes_with_scores, query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
assert [node.score for node in sorted_nodes] == [20, 10, 5]
|
||||
|
||||
def test_sort_ascending(self, nodes_with_scores, query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="asc")
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
assert [node.score for node in sorted_nodes] == [5, 10, 20]
|
||||
|
||||
def test_top_n_limit(self, nodes_with_scores, query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2)
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
assert len(sorted_nodes) == 2
|
||||
assert [node.score for node in sorted_nodes] == [20, 10]
|
||||
|
||||
def test_invalid_json_metadata(self, query_bundle):
|
||||
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)]
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes, query_bundle)
|
||||
|
||||
def test_missing_query_bundle(self, nodes_with_scores):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None)
|
||||
|
||||
def test_field_not_found_in_object(self):
|
||||
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)]
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes)
|
||||
Loading…
Add table
Add a link
Reference in a new issue