add FLAREEngine and ColbertRerank

This commit is contained in:
seehi 2024-03-21 16:50:59 +08:00
parent e53188f898
commit 6e30b42cc0
5 changed files with 15 additions and 5 deletions

View file

@ -1,5 +1,6 @@
"""Engines init"""
from metagpt.rag.engines.simple import SimpleEngine
from metagpt.rag.engines.flare import FLAREEngine
__all__ = ["SimpleEngine"]
__all__ = ["SimpleEngine", "FLAREEngine"]

View file

View file

@ -3,18 +3,17 @@
from llama_index.core.llms import LLM
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.postprocessor.colbert_rerank import ColbertRerank
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig
from metagpt.rag.schema import BaseRankerConfig, ColbertRerankConfig, LLMRankerConfig
class RankerFactory(ConfigBasedFactory):
"""Modify creators for dynamically instance implementation."""
def __init__(self):
creators = {
LLMRankerConfig: self._create_llm_ranker,
}
creators = {LLMRankerConfig: self._create_llm_ranker, ColbertRerankConfig: self._create_colbert_ranker}
super().__init__(creators)
def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]:
@ -28,6 +27,9 @@ class RankerFactory(ConfigBasedFactory):
config.llm = self._extract_llm(config, **kwargs)
return LLMRerank(**config.model_dump())
def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> LLMRerank:
return ColbertRerank(**config.model_dump())
def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
return self._val_from_config_or_kwargs("llm", config, **kwargs)

View file

@ -84,6 +84,12 @@ class LLMRankerConfig(BaseRankerConfig):
)
class ColbertRerankConfig(BaseRankerConfig):
model: str = Field(default="colbert-ir/colbertv2.0", description="Colbert model name.")
device: str = Field(default="cpu", description="Device to use for sentence transformer.")
keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.")
class BaseIndexConfig(BaseModel):
"""Common config for index.

View file

@ -18,6 +18,7 @@ llama-index-readers-file==0.1.4
llama-index-retrievers-bm25==0.1.3
llama-index-vector-stores-faiss==0.1.1
llama-index-vector-stores-elasticsearch==0.1.5
llama-index-postprocessor-colbert-rerank==0.1.1
chromadb==0.4.23
loguru==0.6.0
meilisearch==0.21.0