diff --git a/metagpt/rag/engines/__init__.py b/metagpt/rag/engines/__init__.py
index 373181384..93699db88 100644
--- a/metagpt/rag/engines/__init__.py
+++ b/metagpt/rag/engines/__init__.py
@@ -1,5 +1,6 @@
 """Engines init"""
 
 from metagpt.rag.engines.simple import SimpleEngine
+from metagpt.rag.engines.flare import FLAREEngine
 
-__all__ = ["SimpleEngine"]
+__all__ = ["SimpleEngine", "FLAREEngine"]
diff --git a/metagpt/rag/engines/flare.py b/metagpt/rag/engines/flare.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py
index f05599e15..15dc55bf9 100644
--- a/metagpt/rag/factories/ranker.py
+++ b/metagpt/rag/factories/ranker.py
@@ -3,18 +3,17 @@
 from llama_index.core.llms import LLM
 from llama_index.core.postprocessor import LLMRerank
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
+from llama_index.postprocessor.colbert_rerank import ColbertRerank
 
 from metagpt.rag.factories.base import ConfigBasedFactory
-from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig
+from metagpt.rag.schema import BaseRankerConfig, ColbertRerankConfig, LLMRankerConfig
 
 
 class RankerFactory(ConfigBasedFactory):
     """Modify creators for dynamically instance implementation."""
 
     def __init__(self):
-        creators = {
-            LLMRankerConfig: self._create_llm_ranker,
-        }
+        creators = {LLMRankerConfig: self._create_llm_ranker, ColbertRerankConfig: self._create_colbert_ranker}
         super().__init__(creators)
 
     def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]:
@@ -28,6 +27,10 @@ class RankerFactory(ConfigBasedFactory):
         config.llm = self._extract_llm(config, **kwargs)
         return LLMRerank(**config.model_dump())
 
+    def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> ColbertRerank:
+        """Build a ColbertRerank postprocessor from the dumped config fields."""
+        return ColbertRerank(**config.model_dump())
+
     def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
         return self._val_from_config_or_kwargs("llm", config, **kwargs)
 
diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py
index e98a6fc89..cacce3178 100644
--- a/metagpt/rag/schema.py
+++ b/metagpt/rag/schema.py
@@ -84,6 +84,14 @@ class LLMRankerConfig(BaseRankerConfig):
     )
 
 
+class ColbertRerankConfig(BaseRankerConfig):
+    """Config for the ColBERT-based reranker."""
+
+    model: str = Field(default="colbert-ir/colbertv2.0", description="Colbert model name.")
+    device: str = Field(default="cpu", description="Device to use for sentence transformer.")
+    keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.")
+
+
 class BaseIndexConfig(BaseModel):
     """Common config for index.
 
diff --git a/requirements.txt b/requirements.txt
index 3e545d146..9bcd2a45b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,6 +18,7 @@ llama-index-readers-file==0.1.4
 llama-index-retrievers-bm25==0.1.3
 llama-index-vector-stores-faiss==0.1.1
 llama-index-vector-stores-elasticsearch==0.1.5
+llama-index-postprocessor-colbert-rerank==0.1.1
 chromadb==0.4.23
 loguru==0.6.0
 meilisearch==0.21.0