Merge pull request #1129 from seehi/feat-lazy-import-colbert

Feat lazy import colbert
This commit is contained in:
Alexander Wu 2024-03-28 16:23:43 +08:00 committed by GitHub
commit f48a07389c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 13 additions and 6 deletions

View file

@ -3,7 +3,6 @@
from llama_index.core.llms import LLM
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.postprocessor.colbert_rerank import ColbertRerank
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
@ -38,6 +37,12 @@ class RankerFactory(ConfigBasedFactory):
return LLMRerank(**config.model_dump())
def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> LLMRerank:
try:
from llama_index.postprocessor.colbert_rerank import ColbertRerank
except ImportError:
raise ImportError(
"`llama-index-postprocessor-colbert-rerank` package not found, please run `pip install llama-index-postprocessor-colbert-rerank`"
)
return ColbertRerank(**config.model_dump())
def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank:

View file

@ -38,7 +38,6 @@ extras_require = {
"llama-index-vector-stores-faiss==0.1.1",
"llama-index-vector-stores-elasticsearch==0.1.6",
"llama-index-vector-stores-chroma==0.1.6",
"llama-index-postprocessor-colbert-rerank==0.1.1",
],
}

View file

@ -1,3 +1,5 @@
import contextlib
import pytest
from llama_index.core.llms import MockLLM
from llama_index.core.postprocessor import LLMRerank
@ -41,12 +43,13 @@ class TestRankerFactory:
assert isinstance(ranker, LLMRerank)
def test_create_colbert_ranker(self, mocker, mock_llm):
mocker.patch("metagpt.rag.factories.ranker.ColbertRerank", return_value="colbert")
with contextlib.suppress(ImportError):
mocker.patch("llama_index.postprocessor.colbert_rerank.ColbertRerank", return_value="colbert")
mock_config = ColbertRerankConfig(llm=mock_llm)
ranker = self.ranker_factory._create_colbert_ranker(mock_config)
mock_config = ColbertRerankConfig(llm=mock_llm)
ranker = self.ranker_factory._create_colbert_ranker(mock_config)
assert ranker == "colbert"
assert ranker == "colbert"
def test_create_object_ranker(self, mocker, mock_llm):
mocker.patch("metagpt.rag.factories.ranker.ObjectSortPostprocessor", return_value="object")