From cb9543b2b9b374dcc449956de44f88fdd988c82a Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Fri, 15 Mar 2024 15:36:10 +0800 Subject: [PATCH] rename ConfigFactory to RAGConfigRegistry --- metagpt/rag/factories/base.py | 2 +- metagpt/rag/factories/index.py | 4 ++-- metagpt/rag/factories/ranker.py | 4 ++-- metagpt/rag/factories/retriever.py | 4 ++-- tests/metagpt/rag/factories/test_base.py | 12 ++++++------ 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/metagpt/rag/factories/base.py b/metagpt/rag/factories/base.py index 5c6173a3f..bf7e55b17 100644 --- a/metagpt/rag/factories/base.py +++ b/metagpt/rag/factories/base.py @@ -29,7 +29,7 @@ class GenericFactory: raise ValueError(f"Creator not registered for key: {key}") -class ConfigFactory(GenericFactory): +class RAGConfigRegistry(GenericFactory): """Designed to get objects based on object type.""" def get_instance(self, key: Any, **kwargs) -> Any: diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index e6c87c64a..009bbc59f 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -6,7 +6,7 @@ from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.vector_stores.faiss import FaissVectorStore -from metagpt.rag.factories.base import ConfigFactory +from metagpt.rag.factories.base import RAGConfigRegistry from metagpt.rag.schema import ( BaseIndexConfig, BM25IndexConfig, @@ -16,7 +16,7 @@ from metagpt.rag.schema import ( from metagpt.rag.vector_stores.chroma import ChromaVectorStore -class RAGIndexFactory(ConfigFactory): +class RAGIndexFactory(RAGConfigRegistry): def __init__(self): creators = { FAISSIndexConfig: self._create_faiss, diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index 0867c7945..f92d27b15 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -4,11 +4,11 @@ from llama_index.core.llms import LLM from llama_index.core.postprocessor import LLMRerank from llama_index.core.postprocessor.types import BaseNodePostprocessor -from metagpt.rag.factories.base import ConfigFactory +from metagpt.rag.factories.base import RAGConfigRegistry from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig -class RankerFactory(ConfigFactory): +class RankerFactory(RAGConfigRegistry): """Modify creators for dynamically instance implementation.""" def __init__(self): diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index 2581cbef0..facb170ee 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -8,7 +8,7 @@ from llama_index.core import StorageContext, VectorStoreIndex from llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.faiss import FaissVectorStore -from metagpt.rag.factories.base import ConfigFactory +from metagpt.rag.factories.base import RAGConfigRegistry from metagpt.rag.retrievers.base import RAGRetriever from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever @@ -24,7 +24,7 @@ from metagpt.rag.schema import ( from metagpt.rag.vector_stores.chroma import ChromaVectorStore -class RetrieverFactory(ConfigFactory): +class RetrieverFactory(RAGConfigRegistry): """Modify creators for dynamically instance implementation.""" def __init__(self): diff --git a/tests/metagpt/rag/factories/test_base.py b/tests/metagpt/rag/factories/test_base.py index 78e969ff4..508bf3d2f 100644 --- a/tests/metagpt/rag/factories/test_base.py +++ b/tests/metagpt/rag/factories/test_base.py @@ -1,6 +1,6 @@ import pytest -from metagpt.rag.factories.base import ConfigFactory, GenericFactory +from metagpt.rag.factories.base import GenericFactory, RAGConfigRegistry class TestGenericFactory: @@ -55,7 +55,7 @@ class DummyConfig: self.name = name -class TestConfigFactory: +class TestRAGConfigRegistry: @pytest.fixture def config_creators(self): return { @@ -64,7 +64,7 @@ class TestConfigFactory: @pytest.fixture def config_factory(self, config_creators): - return ConfigFactory(creators=config_creators) + return RAGConfigRegistry(creators=config_creators) def test_get_instance_success(self, config_factory): # Test successful retrieval of an instance @@ -85,18 +85,18 @@ class TestConfigFactory: def test_val_from_config_or_kwargs_priority(self): # Test that the value from the config object has priority over kwargs config = DummyConfig(name="ConfigName") - result = ConfigFactory._val_from_config_or_kwargs("name", config, name="KwargsName") + result = RAGConfigRegistry._val_from_config_or_kwargs("name", config, name="KwargsName") assert result == "ConfigName" def test_val_from_config_or_kwargs_fallback_to_kwargs(self): # Test fallback to kwargs when config object does not have the value config = DummyConfig(name=None) - result = ConfigFactory._val_from_config_or_kwargs("name", config, name="KwargsName") + result = RAGConfigRegistry._val_from_config_or_kwargs("name", config, name="KwargsName") assert result == "KwargsName" def test_val_from_config_or_kwargs_key_error(self): # Test KeyError when the key is not found in both config object and kwargs config = DummyConfig(name=None) with pytest.raises(KeyError) as exc_info: - ConfigFactory._val_from_config_or_kwargs("missing_key", config) + RAGConfigRegistry._val_from_config_or_kwargs("missing_key", config) assert "The key 'missing_key' is required but not provided" in str(exc_info.value)