From 211ba3dce11923fb536b358c0df3b312df994a12 Mon Sep 17 00:00:00 2001 From: YangQianli92 <108046369+YangQianli92@users.noreply.github.com> Date: Tue, 16 Apr 2024 11:32:40 +0800 Subject: [PATCH] Add files via upload --- metagpt/rag/factories/ranker.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index b40085eec..61b81ccdc 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -3,9 +3,7 @@ from llama_index.core.llms import LLM from llama_index.core.postprocessor import LLMRerank from llama_index.core.postprocessor.types import BaseNodePostprocessor -from llama_index.postprocessor.colbert_rerank import ColbertRerank -from llama_index.postprocessor.cohere_rerank import CohereRerank -from llama_index.postprocessor.flag_embedding_reranker import FlagEmbeddingReranker + from metagpt.rag.factories.base import ConfigBasedFactory from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor from metagpt.rag.schema import ( @@ -55,12 +53,23 @@ class RankerFactory(ConfigBasedFactory): return ObjectSortPostprocessor(**config.model_dump()) def _create_cohere_rerank(self, config: CohereRerankConfig, **kwargs) -> LLMRerank: + try: + from llama_index.postprocessor.cohere_rerank import CohereRerank + except ImportError: + raise ImportError( + "`llama-index-postprocessor-cohere-rerank` package not found, please run `pip install llama-index-postprocessor-cohere-rerank`" + ) return CohereRerank(**config.model_dump()) - def _create_flag_rerank(self, config: FlagEmbeddingReranker, **kwargs) -> LLMRerank: + def _create_flag_rerank(self, config: FlagEmbeddingConfig, **kwargs) -> LLMRerank: + try: + from llama_index.postprocessor.flag_embedding_reranker import FlagEmbeddingReranker + except ImportError: + raise ImportError( + "`llama-index-postprocessor-flag-embedding-reranker` package not found, please run `pip install llama-index-postprocessor-flag-embedding-reranker`" + ) return FlagEmbeddingReranker(**config.model_dump()) - def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM: return self._val_from_config_or_kwargs("llm", config, **kwargs)