rename RAGConfigRegistry to ConfigBasedFactory

This commit is contained in:
seehi 2024-03-15 18:55:08 +08:00
parent f46cc95bc2
commit 8e8075317e
6 changed files with 17 additions and 17 deletions

View file

@ -68,8 +68,8 @@ class RAGExample:
async def rag_add_docs(self):
"""This example show how to add docs.
before add docs llm anwser I don't know
after add docs llm give the correct answer, will print something like:
Before add docs llm anwser I don't know.
After add docs llm give the correct answer, will print something like:
[Before add docs]
Retrieve Result:
@ -98,8 +98,8 @@ class RAGExample:
async def rag_add_objs(self, print_title=True):
"""This example show how to add objs.
before add docs engine retrieve nothing.
after add objs engine give the correct answer, will print something like:
Before add docs engine retrieve nothing.
After add objs engine give the correct answer, will print something like:
[Before add objs]
Retrieve Result:

View file

@ -29,7 +29,7 @@ class GenericFactory:
raise ValueError(f"Creator not registered for key: {key}")
class RAGConfigRegistry(GenericFactory):
class ConfigBasedFactory(GenericFactory):
"""Designed to get objects based on object type."""
def get_instance(self, key: Any, **kwargs) -> Any:

View file

@ -6,7 +6,7 @@ from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.factories.base import RAGConfigRegistry
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import (
BaseIndexConfig,
BM25IndexConfig,
@ -16,7 +16,7 @@ from metagpt.rag.schema import (
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
class RAGIndexFactory(RAGConfigRegistry):
class RAGIndexFactory(ConfigBasedFactory):
def __init__(self):
creators = {
FAISSIndexConfig: self._create_faiss,

View file

@ -4,11 +4,11 @@ from llama_index.core.llms import LLM
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from metagpt.rag.factories.base import RAGConfigRegistry
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig
class RankerFactory(RAGConfigRegistry):
class RankerFactory(ConfigBasedFactory):
"""Modify creators for dynamically instance implementation."""
def __init__(self):

View file

@ -8,7 +8,7 @@ from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.factories.base import RAGConfigRegistry
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
@ -24,7 +24,7 @@ from metagpt.rag.schema import (
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
class RetrieverFactory(RAGConfigRegistry):
class RetrieverFactory(ConfigBasedFactory):
"""Modify creators for dynamically instance implementation."""
def __init__(self):

View file

@ -1,6 +1,6 @@
import pytest
from metagpt.rag.factories.base import GenericFactory, RAGConfigRegistry
from metagpt.rag.factories.base import ConfigBasedFactory, GenericFactory
class TestGenericFactory:
@ -55,7 +55,7 @@ class DummyConfig:
self.name = name
class TestRAGConfigRegistry:
class TestConfigBasedFactory:
@pytest.fixture
def config_creators(self):
return {
@ -64,7 +64,7 @@ class TestRAGConfigRegistry:
@pytest.fixture
def config_factory(self, config_creators):
return RAGConfigRegistry(creators=config_creators)
return ConfigBasedFactory(creators=config_creators)
def test_get_instance_success(self, config_factory):
# Test successful retrieval of an instance
@ -85,18 +85,18 @@ class TestRAGConfigRegistry:
def test_val_from_config_or_kwargs_priority(self):
# Test that the value from the config object has priority over kwargs
config = DummyConfig(name="ConfigName")
result = RAGConfigRegistry._val_from_config_or_kwargs("name", config, name="KwargsName")
result = ConfigBasedFactory._val_from_config_or_kwargs("name", config, name="KwargsName")
assert result == "ConfigName"
def test_val_from_config_or_kwargs_fallback_to_kwargs(self):
# Test fallback to kwargs when config object does not have the value
config = DummyConfig(name=None)
result = RAGConfigRegistry._val_from_config_or_kwargs("name", config, name="KwargsName")
result = ConfigBasedFactory._val_from_config_or_kwargs("name", config, name="KwargsName")
assert result == "KwargsName"
def test_val_from_config_or_kwargs_key_error(self):
# Test KeyError when the key is not found in both config object and kwargs
config = DummyConfig(name=None)
with pytest.raises(KeyError) as exc_info:
RAGConfigRegistry._val_from_config_or_kwargs("missing_key", config)
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)