diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index 476fe8c1a..b40085eec 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -3,7 +3,9 @@ from llama_index.core.llms import LLM from llama_index.core.postprocessor import LLMRerank from llama_index.core.postprocessor.types import BaseNodePostprocessor - +from llama_index.postprocessor.colbert_rerank import ColbertRerank +from llama_index.postprocessor.cohere_rerank import CohereRerank +from llama_index.postprocessor.flag_embedding_reranker import FlagEmbeddingReranker from metagpt.rag.factories.base import ConfigBasedFactory from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor from metagpt.rag.schema import ( @@ -11,6 +13,8 @@ from metagpt.rag.schema import ( ColbertRerankConfig, LLMRankerConfig, ObjectRankerConfig, + CohereRerankConfig, + FlagEmbeddingConfig ) @@ -22,6 +26,8 @@ class RankerFactory(ConfigBasedFactory): LLMRankerConfig: self._create_llm_ranker, ColbertRerankConfig: self._create_colbert_ranker, ObjectRankerConfig: self._create_object_ranker, + CohereRerankConfig: self._create_cohere_rerank, + FlagEmbeddingConfig: self._create_flag_rerank, } super().__init__(creators) @@ -48,6 +54,13 @@ class RankerFactory(ConfigBasedFactory): def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank: return ObjectSortPostprocessor(**config.model_dump()) + def _create_cohere_rerank(self, config: CohereRerankConfig, **kwargs) -> LLMRerank: + return CohereRerank(**config.model_dump()) + + def _create_flag_rerank(self, config: FlagEmbeddingReranker, **kwargs) -> LLMRerank: + return FlagEmbeddingReranker(**config.model_dump()) + + def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM: return self._val_from_config_or_kwargs("llm", config, **kwargs)