diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 80fb95842..70d748b7d 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -68,8 +68,8 @@ class RAGExample: async def rag_add_docs(self): """This example show how to add docs. - before add docs llm anwser I don't know - after add docs llm give the correct answer, will print something like: + Before add docs llm anwser I don't know. + After add docs llm give the correct answer, will print something like: [Before add docs] Retrieve Result: @@ -98,8 +98,8 @@ class RAGExample: async def rag_add_objs(self, print_title=True): """This example show how to add objs. - before add docs engine retrieve nothing. - after add objs engine give the correct answer, will print something like: + Before add docs engine retrieve nothing. + After add objs engine give the correct answer, will print something like: [Before add objs] Retrieve Result: diff --git a/metagpt/rag/factories/base.py b/metagpt/rag/factories/base.py index bf7e55b17..8f8155914 100644 --- a/metagpt/rag/factories/base.py +++ b/metagpt/rag/factories/base.py @@ -29,7 +29,7 @@ class GenericFactory: raise ValueError(f"Creator not registered for key: {key}") -class RAGConfigRegistry(GenericFactory): +class ConfigBasedFactory(GenericFactory): """Designed to get objects based on object type.""" def get_instance(self, key: Any, **kwargs) -> Any: diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index 009bbc59f..cdac64fa6 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -6,7 +6,7 @@ from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.vector_stores.faiss import FaissVectorStore -from metagpt.rag.factories.base import RAGConfigRegistry +from metagpt.rag.factories.base import ConfigBasedFactory from metagpt.rag.schema import ( BaseIndexConfig, BM25IndexConfig, @@ -16,7 +16,7 @@ from metagpt.rag.schema import ( from metagpt.rag.vector_stores.chroma import ChromaVectorStore -class RAGIndexFactory(RAGConfigRegistry): +class RAGIndexFactory(ConfigBasedFactory): def __init__(self): creators = { FAISSIndexConfig: self._create_faiss, diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index f92d27b15..753041c6b 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -4,11 +4,11 @@ from llama_index.core.llms import LLM from llama_index.core.postprocessor import LLMRerank from llama_index.core.postprocessor.types import BaseNodePostprocessor -from metagpt.rag.factories.base import RAGConfigRegistry +from metagpt.rag.factories.base import ConfigBasedFactory from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig -class RankerFactory(RAGConfigRegistry): +class RankerFactory(ConfigBasedFactory): """Modify creators for dynamically instance implementation.""" def __init__(self): diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index facb170ee..e5e810b45 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -8,7 +8,7 @@ from llama_index.core import StorageContext, VectorStoreIndex from llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.faiss import FaissVectorStore -from metagpt.rag.factories.base import RAGConfigRegistry +from metagpt.rag.factories.base import ConfigBasedFactory from metagpt.rag.retrievers.base import RAGRetriever from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever @@ -24,7 +24,7 @@ from metagpt.rag.schema import ( from metagpt.rag.vector_stores.chroma import ChromaVectorStore -class RetrieverFactory(RAGConfigRegistry): +class RetrieverFactory(ConfigBasedFactory): """Modify creators for dynamically instance implementation.""" def __init__(self): diff --git a/tests/metagpt/rag/factories/test_base.py b/tests/metagpt/rag/factories/test_base.py index 508bf3d2f..1d41e1872 100644 --- a/tests/metagpt/rag/factories/test_base.py +++ b/tests/metagpt/rag/factories/test_base.py @@ -1,6 +1,6 @@ import pytest -from metagpt.rag.factories.base import GenericFactory, RAGConfigRegistry +from metagpt.rag.factories.base import ConfigBasedFactory, GenericFactory class TestGenericFactory: @@ -55,7 +55,7 @@ class DummyConfig: self.name = name -class TestRAGConfigRegistry: +class TestConfigBasedFactory: @pytest.fixture def config_creators(self): return { @@ -64,7 +64,7 @@ class TestRAGConfigRegistry: @pytest.fixture def config_factory(self, config_creators): - return RAGConfigRegistry(creators=config_creators) + return ConfigBasedFactory(creators=config_creators) def test_get_instance_success(self, config_factory): # Test successful retrieval of an instance @@ -85,18 +85,18 @@ class TestRAGConfigRegistry: def test_val_from_config_or_kwargs_priority(self): # Test that the value from the config object has priority over kwargs config = DummyConfig(name="ConfigName") - result = RAGConfigRegistry._val_from_config_or_kwargs("name", config, name="KwargsName") + result = ConfigBasedFactory._val_from_config_or_kwargs("name", config, name="KwargsName") assert result == "ConfigName" def test_val_from_config_or_kwargs_fallback_to_kwargs(self): # Test fallback to kwargs when config object does not have the value config = DummyConfig(name=None) - result = RAGConfigRegistry._val_from_config_or_kwargs("name", config, name="KwargsName") + result = ConfigBasedFactory._val_from_config_or_kwargs("name", config, name="KwargsName") assert result == "KwargsName" def test_val_from_config_or_kwargs_key_error(self): # Test KeyError when the key is not found in both config object and kwargs config = DummyConfig(name=None) with pytest.raises(KeyError) as exc_info: - RAGConfigRegistry._val_from_config_or_kwargs("missing_key", config) + ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config) assert "The key 'missing_key' is required but not provided" in str(exc_info.value)