From da359fdbb156ee442234e344d52c989dabd9374c Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 28 Mar 2024 15:47:38 +0800 Subject: [PATCH 1/2] lazy import colbert --- metagpt/rag/factories/ranker.py | 7 ++++++- setup.py | 1 - 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index 07cb1b929..476fe8c1a 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -3,7 +3,6 @@ from llama_index.core.llms import LLM from llama_index.core.postprocessor import LLMRerank from llama_index.core.postprocessor.types import BaseNodePostprocessor -from llama_index.postprocessor.colbert_rerank import ColbertRerank from metagpt.rag.factories.base import ConfigBasedFactory from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor @@ -38,6 +37,12 @@ class RankerFactory(ConfigBasedFactory): return LLMRerank(**config.model_dump()) def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> LLMRerank: + try: + from llama_index.postprocessor.colbert_rerank import ColbertRerank + except ImportError: + raise ImportError( + "`llama-index-postprocessor-colbert-rerank` package not found, please run `pip install llama-index-postprocessor-colbert-rerank`" + ) return ColbertRerank(**config.model_dump()) def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank: diff --git a/setup.py b/setup.py index 4fa5499da..3eab2b6a0 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,6 @@ extras_require = { "llama-index-vector-stores-faiss==0.1.1", "llama-index-vector-stores-elasticsearch==0.1.6", "llama-index-vector-stores-chroma==0.1.6", - "llama-index-postprocessor-colbert-rerank==0.1.1", ], } From 4ee273df4722435c9327a00c49b2fd421e7d2884 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 28 Mar 2024 16:06:32 +0800 Subject: [PATCH 2/2] lazy import colbert --- tests/metagpt/rag/factories/test_ranker.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/metagpt/rag/factories/test_ranker.py b/tests/metagpt/rag/factories/test_ranker.py index 3f6b94b47..e40f7f8df 100644 --- a/tests/metagpt/rag/factories/test_ranker.py +++ b/tests/metagpt/rag/factories/test_ranker.py @@ -1,3 +1,5 @@ +import contextlib + import pytest from llama_index.core.llms import MockLLM from llama_index.core.postprocessor import LLMRerank @@ -41,12 +43,13 @@ class TestRankerFactory: assert isinstance(ranker, LLMRerank) def test_create_colbert_ranker(self, mocker, mock_llm): - mocker.patch("metagpt.rag.factories.ranker.ColbertRerank", return_value="colbert") + with contextlib.suppress(ImportError): + mocker.patch("llama_index.postprocessor.colbert_rerank.ColbertRerank", return_value="colbert") - mock_config = ColbertRerankConfig(llm=mock_llm) - ranker = self.ranker_factory._create_colbert_ranker(mock_config) + mock_config = ColbertRerankConfig(llm=mock_llm) + ranker = self.ranker_factory._create_colbert_ranker(mock_config) - assert ranker == "colbert" + assert ranker == "colbert" def test_create_object_ranker(self, mocker, mock_llm): mocker.patch("metagpt.rag.factories.ranker.ObjectSortPostprocessor", return_value="object")