upgrade llama-index to v0.10

This commit is contained in:
seehi 2024-02-23 11:06:53 +08:00 committed by betterwang
parent 04527cf0eb
commit e14aedcea7
29 changed files with 725 additions and 370 deletions

View file

@ -1,58 +1,75 @@
import pytest
from llama_index import VectorStoreIndex
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
class TestSimpleEngine:
def test_from_docs(self, mocker):
@pytest.fixture
def mock_simple_directory_reader(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
@pytest.fixture
def mock_vector_store_index(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
@pytest.fixture
def mock_get_retriever(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
@pytest.fixture
def mock_get_rankers(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_rankers")
@pytest.fixture
def mock_get_response_synthesizer(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
def test_from_docs(
self,
mocker,
mock_simple_directory_reader,
mock_vector_store_index,
mock_get_retriever,
mock_get_rankers,
mock_get_response_synthesizer,
):
# Mock
mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"]
mock_service_context = mocker.patch("metagpt.rag.engines.simple.ServiceContext.from_defaults")
mock_service_context.return_value = "service_context"
mock_vector_store_index = mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
mock_get_retriever = mocker.patch("metagpt.rag.engines.simple.get_retriever")
mock_get_rankers = mocker.patch("metagpt.rag.engines.simple.get_rankers")
mock_get_retriever.return_value = mocker.MagicMock()
mock_get_rankers.return_value = [mocker.MagicMock()]
mock_get_response_synthesizer.return_value = mocker.MagicMock()
# Setup
input_dir = "test_dir"
input_files = ["test_file1", "test_file2"]
transformations = [mocker.MagicMock()]
embed_model = mocker.MagicMock()
llm = mocker.MagicMock()
chunk_size = 100
chunk_overlap = 10
retriever_configs = mocker.MagicMock()
ranker_configs = mocker.MagicMock()
retriever_configs = [mocker.MagicMock()]
ranker_configs = [mocker.MagicMock()]
# Execute
engine = SimpleEngine.from_docs(
input_dir=input_dir,
input_files=input_files,
transformations=transformations,
embed_model=embed_model,
llm=llm,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
# Assertions
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_service_context.assert_called_once_with(
embed_model=embed_model, chunk_size=chunk_size, chunk_overlap=chunk_overlap, llm=llm
mock_vector_store_index.assert_called_once()
mock_get_retriever.assert_called_once_with(
configs=retriever_configs, index=mock_vector_store_index.return_value
)
mock_vector_store_index.assert_called_once_with(
["document1", "document2"], service_context=mock_service_context.return_value
)
mock_get_retriever.assert_called_once_with(mock_vector_store_index.return_value, configs=retriever_configs)
mock_get_rankers.assert_called_once_with(
configs=ranker_configs, service_context=mock_service_context.return_value
)
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
assert isinstance(engine, SimpleEngine)
@pytest.mark.asyncio
@ -100,8 +117,12 @@ class TestSimpleEngine:
mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"]
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_index.service_context.node_parser.get_nodes_from_documents = lambda x: ["node1", "node2"]
mock_index._transformations = mocker.MagicMock()
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
mock_run_transformations.return_value = ["node1", "node2"]
# Setup
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
@ -113,3 +134,27 @@ class TestSimpleEngine:
# Assertions
mock_simple_directory_reader.assert_called_once_with(input_files=input_files)
mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"])
def test_add_objs(self, mocker):
# Mock
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
# Setup
class CustomTextNode(TextNode):
def rag_key(self):
return ""
def model_dump(self):
return {}
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
# Execute
engine.add_objs(objs=objs)
# Assertions
assert mock_retriever.add_nodes.call_count == 1
for node in mock_retriever.add_nodes.call_args[0][0]:
assert isinstance(node, TextNode)
assert "obj" in node.metadata

View file

@ -0,0 +1,102 @@
import pytest
from metagpt.rag.factories.base import ConfigFactory, GenericFactory
class TestGenericFactory:
@pytest.fixture
def creators(self):
return {
"type1": lambda name: f"Instance of type1 with {name}",
"type2": lambda name: f"Instance of type2 with {name}",
}
@pytest.fixture
def factory(self, creators):
return GenericFactory(creators=creators)
def test_get_instance_success(self, factory):
# Test successful retrieval of an instance
key = "type1"
instance = factory.get_instance(key, name="TestName")
assert instance == "Instance of type1 with TestName"
def test_get_instance_failure(self, factory):
# Test failure to retrieve an instance due to unregistered key
with pytest.raises(ValueError) as exc_info:
factory.get_instance("unknown_key")
assert "Creator not registered for key: unknown_key" in str(exc_info.value)
def test_get_instances_success(self, factory):
# Test successful retrieval of multiple instances
keys = ["type1", "type2"]
instances = factory.get_instances(keys, name="TestName")
expected = ["Instance of type1 with TestName", "Instance of type2 with TestName"]
assert instances == expected
@pytest.mark.parametrize(
"keys,expected_exception_message",
[
(["unknown_key"], "Creator not registered for key: unknown_key"),
(["type1", "unknown_key"], "Creator not registered for key: unknown_key"),
],
)
def test_get_instances_with_failure(self, factory, keys, expected_exception_message):
# Test failure to retrieve instances due to at least one unregistered key
with pytest.raises(ValueError) as exc_info:
factory.get_instances(keys, name="TestName")
assert expected_exception_message in str(exc_info.value)
class DummyConfig:
"""A dummy config class for testing."""
def __init__(self, name):
self.name = name
class TestConfigFactory:
@pytest.fixture
def config_creators(self):
return {
DummyConfig: lambda config, **kwargs: f"Processed {config.name} with {kwargs.get('extra', 'no extra')}",
}
@pytest.fixture
def config_factory(self, config_creators):
return ConfigFactory(creators=config_creators)
def test_get_instance_success(self, config_factory):
# Test successful retrieval of an instance
config = DummyConfig(name="TestConfig")
instance = config_factory.get_instance(config, extra="additional data")
assert instance == "Processed TestConfig with additional data"
def test_get_instance_failure(self, config_factory):
# Test failure to retrieve an instance due to unknown config type
class UnknownConfig:
pass
config = UnknownConfig()
with pytest.raises(ValueError) as exc_info:
config_factory.get_instance(config)
assert "Unknown config:" in str(exc_info.value)
def test_val_from_config_or_kwargs_priority(self):
# Test that the value from the config object has priority over kwargs
config = DummyConfig(name="ConfigName")
result = ConfigFactory._val_from_config_or_kwargs("name", config, name="KwargsName")
assert result == "ConfigName"
def test_val_from_config_or_kwargs_fallback_to_kwargs(self):
# Test fallback to kwargs when config object does not have the value
config = DummyConfig(name=None)
result = ConfigFactory._val_from_config_or_kwargs("name", config, name="KwargsName")
assert result == "KwargsName"
def test_val_from_config_or_kwargs_key_error(self):
# Test KeyError when the key is not found in both config object and kwargs
config = DummyConfig(name=None)
with pytest.raises(KeyError) as exc_info:
ConfigFactory._val_from_config_or_kwargs("missing_key", config)
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)

View file

@ -0,0 +1,56 @@
import pytest
from llama_index.llms.anthropic import Anthropic
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.llms.gemini import Gemini
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.llm import RAGLLMFactory
class TestRAGLLMFactory:
@pytest.fixture(autouse=True)
def setup(self, mocker):
# Mock the config object for all tests in this class
self.mock_config = mocker.MagicMock()
self.mock_config.llm.api_type = LLMType.OPENAI
self.mock_config.llm.base_url = "http://example.com"
self.mock_config.llm.api_key = "test_api_key"
self.mock_config.llm.api_version = "v1"
self.mock_config.llm.model = "test_model"
self.mock_config.llm.max_token = 100
self.mock_config.llm.temperature = 0.5
mocker.patch("metagpt.rag.factories.llm.config", self.mock_config)
self.factory = RAGLLMFactory()
@pytest.mark.parametrize(
"llm_type,expected_class",
[
(LLMType.OPENAI, OpenAI),
(LLMType.AZURE, AzureOpenAI),
(LLMType.ANTHROPIC, Anthropic),
(LLMType.GEMINI, Gemini),
(LLMType.OLLAMA, Ollama),
],
)
def test_creates_correct_llm_instance(self, llm_type, expected_class, mocker):
# Mock the LLM constructors
mocker.patch.object(expected_class, "__init__", return_value=None)
instance = self.factory.get_rag_llm(key=llm_type)
assert isinstance(instance, expected_class)
expected_class.__init__.assert_called_once()
def test_uses_default_llm_type_when_no_key_provided(self, mocker):
# Assume the default API type is OPENAI for this test
mock = mocker.patch.object(OpenAI, "__init__", return_value=None)
instance = self.factory.get_rag_llm()
assert isinstance(instance, OpenAI)
mock.assert_called_once_with(
api_base=self.mock_config.llm.base_url,
api_key=self.mock_config.llm.api_key,
api_version=self.mock_config.llm.api_version,
model=self.mock_config.llm.model,
max_tokens=self.mock_config.llm.max_token,
temperature=self.mock_config.llm.temperature,
)

View file

@ -0,0 +1,43 @@
import pytest
from llama_index.core.llms import LLM
from llama_index.core.postprocessor import LLMRerank
from metagpt.rag.factories.ranker import RankerFactory
from metagpt.rag.schema import LLMRankerConfig
class TestRankerFactory:
@pytest.fixture
def ranker_factory(self) -> RankerFactory:
return RankerFactory()
@pytest.fixture
def mock_llm(self, mocker):
return mocker.MagicMock(spec=LLM)
def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker):
mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm)
default_rankers = ranker_factory.get_rankers()
assert len(default_rankers) == 1
assert isinstance(default_rankers[0], LLMRerank)
ranker_factory._extract_llm.assert_called_once()
def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
rankers = ranker_factory.get_rankers(configs=[mock_config])
assert len(rankers) == 1
assert isinstance(rankers[0], LLMRerank)
def test_create_llm_ranker_creates_correct_instance(self, ranker_factory: RankerFactory, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
ranker = ranker_factory._create_llm_ranker(mock_config)
assert isinstance(ranker, LLMRerank)
def test_extract_llm_from_config(self, ranker_factory: RankerFactory, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
extracted_llm = ranker_factory._extract_llm(config=mock_config)
assert extracted_llm == mock_llm
def test_extract_llm_from_kwargs(self, ranker_factory: RankerFactory, mock_llm):
extracted_llm = ranker_factory._extract_llm(llm=mock_llm)
assert extracted_llm == mock_llm

View file

@ -0,0 +1,79 @@
import faiss
import pytest
from llama_index.core import VectorStoreIndex
from metagpt.rag.factories.retriever import RetrieverFactory
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
class TestRetrieverFactory:
@pytest.fixture
def retriever_factory(self):
return RetrieverFactory()
@pytest.fixture
def mock_faiss_index(self, mocker):
return mocker.MagicMock(spec=faiss.IndexFlatL2)
@pytest.fixture
def mock_vector_store_index(self, mocker):
mock = mocker.MagicMock(spec=VectorStoreIndex)
mock._embed_model = mocker.MagicMock()
mock.docstore.docs.values.return_value = []
return mock
def test_get_retriever_with_faiss_config(
self, retriever_factory: RetrieverFactory, mock_faiss_index, mocker, mock_vector_store_index
):
mock_config = FAISSRetrieverConfig(dimensions=128)
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, FAISSRetriever)
def test_get_retriever_with_bm25_config(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
mock_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, DynamicBM25Retriever)
def test_get_retriever_with_multiple_configs_returns_hybrid(
self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index
):
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
mock_bm25_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
assert isinstance(retriever, SimpleHybridRetriever)
def test_create_default_retriever(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mock_vector_store_index.as_retriever = mocker.MagicMock()
retriever = retriever_factory.get_retriever()
mock_vector_store_index.as_retriever.assert_called_once()
assert retriever is mock_vector_store_index.as_retriever.return_value
def test_extract_index_from_config(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
mock_config = FAISSRetrieverConfig(index=mock_vector_store_index)
extracted_index = retriever_factory._extract_index(config=mock_config)
assert extracted_index == mock_vector_store_index
def test_extract_index_from_kwargs(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
extracted_index = retriever_factory._extract_index(index=mock_vector_store_index)
assert extracted_index == mock_vector_store_index

View file

@ -1,5 +1,5 @@
import pytest
from llama_index.schema import Node
from llama_index.core.schema import Node
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
@ -17,7 +17,7 @@ class TestDynamicBM25Retriever:
# 模拟nodes和tokenizer参数
mock_nodes = []
mock_tokenizer = mocker.MagicMock()
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi")
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
# 初始化DynamicBM25Retriever对象并提供必需的参数
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer)

View file

@ -1,5 +1,5 @@
import pytest
from llama_index.schema import Node
from llama_index.core.schema import Node
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever

View file

@ -1,7 +1,7 @@
from unittest.mock import AsyncMock
import pytest
from llama_index.schema import NodeWithScore, TextNode
from llama_index.core.schema import NodeWithScore, TextNode
from metagpt.rag.retrievers import SimpleHybridRetriever

View file

@ -1,130 +0,0 @@
import pytest
from llama_index import ServiceContext
from llama_index.indices.base import BaseIndex
from llama_index.postprocessor import LLMRerank
from metagpt.rag.factory import RankerFactory, RetrieverFactory
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BM25RetrieverConfig,
FAISSRetrieverConfig,
LLMRankerConfig,
)
class TestRetrieverFactory:
@pytest.fixture
def mock_base_index(self, mocker):
mock = mocker.MagicMock(spec=BaseIndex)
mock.as_retriever.return_value = mocker.MagicMock(spec=RAGRetriever)
mock.service_context = mocker.MagicMock()
mock.docstore.docs.values.return_value = []
return mock
@pytest.fixture
def mock_faiss_retriever_config(self):
return FAISSRetrieverConfig(dimensions=128)
@pytest.fixture
def mock_bm25_retriever_config(self):
return BM25RetrieverConfig()
@pytest.fixture
def mock_faiss_vector_store(self, mocker):
return mocker.patch("metagpt.rag.factory.FaissVectorStore")
@pytest.fixture
def mock_storage_context(self, mocker):
return mocker.patch("metagpt.rag.factory.StorageContext")
@pytest.fixture
def mock_vector_store_index(self, mocker):
return mocker.patch("metagpt.rag.factory.VectorStoreIndex")
@pytest.fixture
def mock_dynamic_bm25_retriever(self, mocker):
mock = mocker.MagicMock(spec=DynamicBM25Retriever)
return mocker.patch("metagpt.rag.factory.DynamicBM25Retriever", mock)
def test_get_retriever_with_no_configs_returns_default_retriever(self, mock_base_index):
factory = RetrieverFactory()
retriever = factory.get_retriever(index=mock_base_index)
assert isinstance(retriever, RAGRetriever)
def test_get_retriever_with_specific_config_returns_correct_retriever(
self,
mock_base_index,
mock_faiss_retriever_config,
mock_faiss_vector_store,
mock_storage_context,
mock_vector_store_index,
):
factory = RetrieverFactory()
retriever = factory.get_retriever(index=mock_base_index, configs=[mock_faiss_retriever_config])
assert isinstance(retriever, FAISSRetriever)
def test_get_retriever_with_multiple_configs_returns_hybrid_retriever(
self,
mock_base_index,
mock_faiss_retriever_config,
mock_bm25_retriever_config,
mock_faiss_vector_store,
mock_storage_context,
mock_vector_store_index,
mock_dynamic_bm25_retriever,
):
factory = RetrieverFactory()
retriever = factory.get_retriever(
index=mock_base_index, configs=[mock_faiss_retriever_config, mock_bm25_retriever_config]
)
assert isinstance(retriever, SimpleHybridRetriever)
def test_get_retriever_with_unknown_config_raises_value_error(self, mock_base_index, mocker):
mock_unknown_config = mocker.MagicMock()
factory = RetrieverFactory()
with pytest.raises(ValueError):
factory.get_retriever(index=mock_base_index, configs=[mock_unknown_config])
class TestRankerFactory:
@pytest.fixture
def mock_service_context(self, mocker):
return mocker.MagicMock(spec=ServiceContext)
def test_get_rankers_with_no_configs_returns_default_ranker(self, mock_service_context):
# Setup
factory = RankerFactory()
# Execute
rankers = factory.get_rankers(service_context=mock_service_context)
# Assertions
assert len(rankers) == 1
assert isinstance(rankers[0], LLMRerank)
def test_get_rankers_with_specific_config_returns_correct_ranker(self, mock_service_context):
# Setup
config = LLMRankerConfig(top_n=3)
factory = RankerFactory()
# Execute
rankers = factory.get_rankers(configs=[config], service_context=mock_service_context)
# Assertions
assert len(rankers) == 1
assert isinstance(rankers[0], LLMRerank)
assert rankers[0].top_n == 3
def test_get_rankers_with_unknown_config_raises_value_error(self, mocker, mock_service_context):
# Mock
mock_config = mocker.MagicMock() # 使用 MagicMock 来模拟一个未知的配置类型
# Setup
factory = RankerFactory()
# Execute & Assertions
with pytest.raises(ValueError):
factory.get_rankers(configs=[mock_config], service_context=mock_service_context)