mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-02 14:45:17 +02:00
Merge branch 'feat-merge-github-rag' into 'mgx_ops'
Merge the newest rag See merge request pub/MetaGPT!136
This commit is contained in:
commit
7b293235b8
17 changed files with 482 additions and 113 deletions
|
|
@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Literal, Optional
|
|||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from metagpt.configs.browser_config import BrowserConfig
|
||||
from metagpt.configs.embedding_config import EmbeddingConfig
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.configs.mermaid_config import MermaidConfig
|
||||
from metagpt.configs.redis_config import RedisConfig
|
||||
|
|
@ -48,6 +49,9 @@ class Config(CLIParams, YamlModel):
|
|||
# Key Parameters
|
||||
llm: LLMConfig
|
||||
|
||||
# RAG Embedding
|
||||
embedding: EmbeddingConfig = EmbeddingConfig()
|
||||
|
||||
# Global Proxy. Not used by LLM, but by other tools such as browsers.
|
||||
proxy: str = ""
|
||||
|
||||
|
|
|
|||
50
metagpt/configs/embedding_config.py
Normal file
50
metagpt/configs/embedding_config.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class EmbeddingType(Enum):
|
||||
OPENAI = "openai"
|
||||
AZURE = "azure"
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
class EmbeddingConfig(YamlModel):
|
||||
"""Config for Embedding.
|
||||
|
||||
Examples:
|
||||
---------
|
||||
api_type: "openai"
|
||||
api_key: "YOU_API_KEY"
|
||||
|
||||
api_type: "azure"
|
||||
api_key: "YOU_API_KEY"
|
||||
base_url: "YOU_BASE_URL"
|
||||
api_version: "YOU_API_VERSION"
|
||||
|
||||
api_type: "gemini"
|
||||
api_key: "YOU_API_KEY"
|
||||
|
||||
api_type: "ollama"
|
||||
base_url: "YOU_BASE_URL"
|
||||
model: "YOU_MODEL"
|
||||
"""
|
||||
|
||||
api_type: Optional[EmbeddingType] = None
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_version: Optional[str] = None
|
||||
|
||||
model: Optional[str] = None
|
||||
embed_batch_size: Optional[int] = None
|
||||
|
||||
@field_validator("api_type", mode="before")
|
||||
@classmethod
|
||||
def check_api_type(cls, v):
|
||||
if v == "":
|
||||
return None
|
||||
return v
|
||||
|
|
@ -4,8 +4,7 @@ import json
|
|||
import os
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from fsspec import AbstractFileSystem
|
||||
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
|
|
@ -64,7 +63,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
response_synthesizer: Optional[BaseSynthesizer] = None,
|
||||
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
index: Optional[BaseIndex] = None,
|
||||
transformations: Optional[list[TransformComponent]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
retriever=retriever,
|
||||
|
|
@ -72,7 +71,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
node_postprocessors=node_postprocessors,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
self.index = index
|
||||
self._transformations = transformations or self._default_transformations()
|
||||
|
||||
@classmethod
|
||||
def from_docs(
|
||||
|
|
@ -84,7 +83,6 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
fs: Optional[AbstractFileSystem] = None,
|
||||
) -> "SimpleEngine":
|
||||
"""From docs.
|
||||
|
||||
|
|
@ -102,15 +100,20 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
if not input_dir and not input_files:
|
||||
raise ValueError("Must provide either `input_dir` or `input_files`.")
|
||||
|
||||
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files, fs=fs).load_data()
|
||||
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
|
||||
cls._fix_document_metadata(documents)
|
||||
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents=documents,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
|
||||
transformations = transformations or cls._default_transformations()
|
||||
nodes = run_transformations(documents, transformations=transformations)
|
||||
|
||||
return cls._from_nodes(
|
||||
nodes=nodes,
|
||||
transformations=transformations,
|
||||
embed_model=embed_model,
|
||||
llm=llm,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def from_objs(
|
||||
|
|
@ -139,12 +142,15 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
|
||||
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
index = VectorStoreIndex(
|
||||
|
||||
return cls._from_nodes(
|
||||
nodes=nodes,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
|
||||
transformations=transformations,
|
||||
embed_model=embed_model,
|
||||
llm=llm,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def from_index(
|
||||
|
|
@ -163,6 +169,13 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
"""Inplement tools.SearchInterface"""
|
||||
return await self.aquery(content)
|
||||
|
||||
def retrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
query_bundle = QueryBundle(query) if isinstance(query, str) else query
|
||||
|
||||
nodes = super().retrieve(query_bundle)
|
||||
self._try_reconstruct_obj(nodes)
|
||||
return nodes
|
||||
|
||||
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
"""Allow query to be str."""
|
||||
query_bundle = QueryBundle(query) if isinstance(query, str) else query
|
||||
|
|
@ -178,7 +191,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
documents = SimpleDirectoryReader(input_files=input_files).load_data()
|
||||
self._fix_document_metadata(documents)
|
||||
|
||||
nodes = run_transformations(documents, transformations=self.index._transformations)
|
||||
nodes = run_transformations(documents, transformations=self._transformations)
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def add_objs(self, objs: list[RAGObject]):
|
||||
|
|
@ -194,6 +207,29 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
|
||||
self._persist(str(persist_dir), **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _from_nodes(
|
||||
cls,
|
||||
nodes: list[BaseNode],
|
||||
transformations: Optional[list[TransformComponent]] = None,
|
||||
embed_model: BaseEmbedding = None,
|
||||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
embed_model = cls._resolve_embed_model(embed_model, retriever_configs)
|
||||
llm = llm or get_rag_llm()
|
||||
|
||||
retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model)
|
||||
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
|
||||
|
||||
return cls(
|
||||
retriever=retriever,
|
||||
node_postprocessors=rankers,
|
||||
response_synthesizer=get_response_synthesizer(llm=llm),
|
||||
transformations=transformations,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_index(
|
||||
cls,
|
||||
|
|
@ -203,6 +239,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
llm = llm or get_rag_llm()
|
||||
|
||||
retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
|
||||
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
|
||||
|
||||
|
|
@ -210,7 +247,6 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
retriever=retriever,
|
||||
node_postprocessors=rankers,
|
||||
response_synthesizer=get_response_synthesizer(llm=llm),
|
||||
index=index,
|
||||
)
|
||||
|
||||
def _ensure_retriever_modifiable(self):
|
||||
|
|
@ -261,3 +297,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
return embed_model or get_rag_embedding()
|
||||
|
||||
@staticmethod
|
||||
def _default_transformations():
|
||||
return [SentenceSplitter()]
|
||||
|
|
|
|||
|
|
@ -26,6 +26,9 @@ class GenericFactory:
|
|||
if creator:
|
||||
return creator(**kwargs)
|
||||
|
||||
self._raise_for_key(key)
|
||||
|
||||
def _raise_for_key(self, key: Any):
|
||||
raise ValueError(f"Creator not registered for key: {key}")
|
||||
|
||||
|
||||
|
|
@ -33,19 +36,26 @@ class ConfigBasedFactory(GenericFactory):
|
|||
"""Designed to get objects based on object type."""
|
||||
|
||||
def get_instance(self, key: Any, **kwargs) -> Any:
|
||||
"""Key is config, such as a pydantic model.
|
||||
"""Get instance by the type of key.
|
||||
|
||||
Call func by the type of key, and the key will be passed to func.
|
||||
Key is config, such as a pydantic model, call func by the type of key, and the key will be passed to func.
|
||||
Raise Exception if key not found.
|
||||
"""
|
||||
creator = self._creators.get(type(key))
|
||||
if creator:
|
||||
return creator(key, **kwargs)
|
||||
|
||||
self._raise_for_key(key)
|
||||
|
||||
def _raise_for_key(self, key: Any):
|
||||
raise ValueError(f"Unknown config: `{type(key)}`, {key}")
|
||||
|
||||
@staticmethod
|
||||
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
|
||||
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
|
||||
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.
|
||||
|
||||
Return None if not found.
|
||||
"""
|
||||
if config is not None and hasattr(config, key):
|
||||
val = getattr(config, key)
|
||||
if val is not None:
|
||||
|
|
@ -54,6 +64,4 @@ class ConfigBasedFactory(GenericFactory):
|
|||
if key in kwargs:
|
||||
return kwargs[key]
|
||||
|
||||
raise KeyError(
|
||||
f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
|
||||
)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,37 +1,103 @@
|
|||
"""RAG Embedding Factory."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
from llama_index.embeddings.gemini import GeminiEmbedding
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.base import GenericFactory
|
||||
|
||||
|
||||
class RAGEmbeddingFactory(GenericFactory):
|
||||
"""Create LlamaIndex Embedding with MetaGPT's config."""
|
||||
"""Create LlamaIndex Embedding with MetaGPT's embedding config."""
|
||||
|
||||
def __init__(self):
|
||||
creators = {
|
||||
EmbeddingType.OPENAI: self._create_openai,
|
||||
EmbeddingType.AZURE: self._create_azure,
|
||||
EmbeddingType.GEMINI: self._create_gemini,
|
||||
EmbeddingType.OLLAMA: self._create_ollama,
|
||||
# For backward compatibility
|
||||
LLMType.OPENAI: self._create_openai,
|
||||
LLMType.AZURE: self._create_azure,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding:
|
||||
"""Key is LLMType, default use config.llm.api_type."""
|
||||
return super().get_instance(key or config.llm.api_type)
|
||||
def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding:
|
||||
"""Key is EmbeddingType."""
|
||||
return super().get_instance(key or self._resolve_embedding_type())
|
||||
|
||||
def _create_openai(self):
|
||||
return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url)
|
||||
def _resolve_embedding_type(self) -> EmbeddingType | LLMType:
|
||||
"""Resolves the embedding type.
|
||||
|
||||
def _create_azure(self):
|
||||
return AzureOpenAIEmbedding(
|
||||
azure_endpoint=config.llm.base_url,
|
||||
api_key=config.llm.api_key,
|
||||
api_version=config.llm.api_version,
|
||||
If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE.
|
||||
Raise TypeError if embedding type not found.
|
||||
"""
|
||||
if config.embedding.api_type:
|
||||
return config.embedding.api_type
|
||||
|
||||
if config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]:
|
||||
return config.llm.api_type
|
||||
|
||||
raise TypeError("To use RAG, please set your embedding in config2.yaml.")
|
||||
|
||||
def _create_openai(self) -> OpenAIEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key or config.llm.api_key,
|
||||
api_base=config.embedding.base_url or config.llm.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return OpenAIEmbedding(**params)
|
||||
|
||||
def _create_azure(self) -> AzureOpenAIEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key or config.llm.api_key,
|
||||
azure_endpoint=config.embedding.base_url or config.llm.base_url,
|
||||
api_version=config.embedding.api_version or config.llm.api_version,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return AzureOpenAIEmbedding(**params)
|
||||
|
||||
def _create_gemini(self) -> GeminiEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key,
|
||||
api_base=config.embedding.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return GeminiEmbedding(**params)
|
||||
|
||||
def _create_ollama(self) -> OllamaEmbedding:
|
||||
params = dict(
|
||||
base_url=config.embedding.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return OllamaEmbedding(**params)
|
||||
|
||||
def _try_set_model_and_batch_size(self, params: dict):
|
||||
"""Set the model_name and embed_batch_size only when they are specified."""
|
||||
if config.embedding.model:
|
||||
params["model_name"] = config.embedding.model
|
||||
|
||||
if config.embedding.embed_batch_size:
|
||||
params["embed_batch_size"] = config.embedding.embed_batch_size
|
||||
|
||||
def _raise_for_key(self, key: Any):
|
||||
raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}")
|
||||
|
||||
|
||||
get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class RAGIndexFactory(ConfigBasedFactory):
|
|||
|
||||
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
db = chromadb.PersistentClient(str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name)
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
|
||||
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""RAG LLM."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW
|
||||
|
|
@ -15,7 +15,7 @@ from pydantic import Field
|
|||
from metagpt.config2 import config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.async_helper import run_coroutine_in_new_loop
|
||||
from metagpt.utils.async_helper import NestAsyncio
|
||||
from metagpt.utils.token_counter import TOKEN_MAX
|
||||
|
||||
|
||||
|
|
@ -39,7 +39,8 @@ class RAGLLM(CustomLLM):
|
|||
|
||||
@llm_completion_callback()
|
||||
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
|
||||
return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))
|
||||
NestAsyncio.apply_once()
|
||||
return asyncio.get_event_loop().run_until_complete(self.acomplete(prompt, **kwargs))
|
||||
|
||||
@llm_completion_callback()
|
||||
async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ from metagpt.rag.factories.base import ConfigBasedFactory
|
|||
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
|
||||
from metagpt.rag.schema import (
|
||||
BaseRankerConfig,
|
||||
BGERerankConfig,
|
||||
CohereRerankConfig,
|
||||
ColbertRerankConfig,
|
||||
LLMRankerConfig,
|
||||
ObjectRankerConfig,
|
||||
|
|
@ -22,6 +24,8 @@ class RankerFactory(ConfigBasedFactory):
|
|||
LLMRankerConfig: self._create_llm_ranker,
|
||||
ColbertRerankConfig: self._create_colbert_ranker,
|
||||
ObjectRankerConfig: self._create_object_ranker,
|
||||
CohereRerankConfig: self._create_cohere_rerank,
|
||||
BGERerankConfig: self._create_bge_rerank,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
|
|
@ -45,6 +49,26 @@ class RankerFactory(ConfigBasedFactory):
|
|||
)
|
||||
return ColbertRerank(**config.model_dump())
|
||||
|
||||
def _create_cohere_rerank(self, config: CohereRerankConfig, **kwargs) -> LLMRerank:
|
||||
try:
|
||||
from llama_index.postprocessor.cohere_rerank import CohereRerank
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`llama-index-postprocessor-cohere-rerank` package not found, please run `pip install llama-index-postprocessor-cohere-rerank`"
|
||||
)
|
||||
return CohereRerank(**config.model_dump())
|
||||
|
||||
def _create_bge_rerank(self, config: BGERerankConfig, **kwargs) -> LLMRerank:
|
||||
try:
|
||||
from llama_index.postprocessor.flag_embedding_reranker import (
|
||||
FlagEmbeddingReranker,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`llama-index-postprocessor-flag-embedding-reranker` package not found, please run `pip install llama-index-postprocessor-flag-embedding-reranker`"
|
||||
)
|
||||
return FlagEmbeddingReranker(**config.model_dump())
|
||||
|
||||
def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank:
|
||||
return ObjectSortPostprocessor(**config.model_dump())
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
"""RAG Retriever Factory."""
|
||||
|
||||
import copy
|
||||
|
||||
from functools import wraps
|
||||
|
||||
import chromadb
|
||||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
|
|
@ -24,10 +27,25 @@ from metagpt.rag.schema import (
|
|||
ElasticsearchKeywordRetrieverConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
IndexRetrieverConfig,
|
||||
)
|
||||
|
||||
|
||||
def get_or_build_index(build_index_func):
|
||||
"""Decorator to get or build an index.
|
||||
|
||||
Get index using `_extract_index` method, if not found, using build_index_func.
|
||||
"""
|
||||
|
||||
@wraps(build_index_func)
|
||||
def wrapper(self, config, **kwargs):
|
||||
index = self._extract_index(config, **kwargs)
|
||||
if index is not None:
|
||||
return index
|
||||
return build_index_func(self, config, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class RetrieverFactory(ConfigBasedFactory):
|
||||
"""Modify creators for dynamically instance implementation."""
|
||||
|
||||
|
|
@ -54,48 +72,79 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]
|
||||
|
||||
def _create_default(self, **kwargs) -> RAGRetriever:
|
||||
return self._extract_index(**kwargs).as_retriever()
|
||||
index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs)
|
||||
|
||||
return index.as_retriever()
|
||||
|
||||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
config.index = self._build_faiss_index(config, **kwargs)
|
||||
|
||||
return FAISSRetriever(**config.model_dump())
|
||||
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
|
||||
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
|
||||
index = self._extract_index(config, **kwargs)
|
||||
nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs)
|
||||
|
||||
return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
|
||||
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
|
||||
|
||||
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
|
||||
db = chromadb.PersistentClient(path=str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name)
|
||||
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
config.index = self._build_chroma_index(config, **kwargs)
|
||||
|
||||
return ChromaRetriever(**config.model_dump())
|
||||
|
||||
def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
|
||||
vector_store = ElasticsearchStore(**config.store_config.model_dump())
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
config.index = self._build_es_index(config, **kwargs)
|
||||
|
||||
return ElasticsearchRetriever(**config.model_dump())
|
||||
|
||||
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
|
||||
return self._val_from_config_or_kwargs("index", config, **kwargs)
|
||||
|
||||
def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[BaseNode]:
|
||||
return self._val_from_config_or_kwargs("nodes", config, **kwargs)
|
||||
|
||||
def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding:
|
||||
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
|
||||
|
||||
def _build_default_index(self, **kwargs) -> VectorStoreIndex:
|
||||
index = VectorStoreIndex(
|
||||
nodes=self._extract_nodes(**kwargs),
|
||||
embed_model=self._extract_embed_model(**kwargs),
|
||||
)
|
||||
|
||||
return index
|
||||
|
||||
@get_or_build_index
|
||||
def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
@get_or_build_index
|
||||
def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
db = chromadb.PersistentClient(path=str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
@get_or_build_index
|
||||
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
vector_store = ElasticsearchStore(**config.store_config.model_dump())
|
||||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
def _build_index_from_vector_store(
|
||||
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
|
||||
self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
|
||||
) -> VectorStoreIndex:
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
old_index = self._extract_index(config, **kwargs)
|
||||
new_index = VectorStoreIndex(
|
||||
nodes=list(old_index.docstore.docs.values()),
|
||||
index = VectorStoreIndex(
|
||||
nodes=self._extract_nodes(config, **kwargs),
|
||||
storage_context=storage_context,
|
||||
embed_model=old_index._embed_model,
|
||||
embed_model=self._extract_embed_model(config, **kwargs),
|
||||
)
|
||||
return new_index
|
||||
|
||||
return index
|
||||
|
||||
|
||||
get_retriever = RetrieverFactory().get_retriever
|
||||
|
|
|
|||
|
|
@ -40,8 +40,10 @@ class DynamicBM25Retriever(BM25Retriever):
|
|||
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
|
||||
self.bm25 = BM25Okapi(self._corpus)
|
||||
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
if self._index:
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
"""Support persist."""
|
||||
self._index.storage_context.persist(persist_dir)
|
||||
if self._index:
|
||||
self._index.storage_context.persist(persist_dir)
|
||||
|
|
@ -1,14 +1,17 @@
|
|||
"""RAG schemas."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Union
|
||||
from typing import Any, ClassVar, Literal, Optional, Union
|
||||
|
||||
from chromadb.api.types import CollectionMetadata
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.core.vector_stores.types import VectorStoreQueryMode
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.rag.interface import RAGObject
|
||||
|
||||
|
||||
|
|
@ -31,7 +34,19 @@ class IndexRetrieverConfig(BaseRetrieverConfig):
|
|||
class FAISSRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for FAISS-based retrievers."""
|
||||
|
||||
dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.")
|
||||
dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.")
|
||||
|
||||
_embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
|
||||
EmbeddingType.GEMINI: 768,
|
||||
EmbeddingType.OLLAMA: 4096,
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_dimensions(self):
|
||||
if self.dimensions == 0:
|
||||
self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BM25RetrieverConfig(IndexRetrieverConfig):
|
||||
|
|
@ -45,6 +60,9 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
|
|||
|
||||
persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
metadata: Optional[CollectionMetadata] = Field(
|
||||
default=None, description="Optional metadata to associate with the collection"
|
||||
)
|
||||
|
||||
|
||||
class ElasticsearchStoreConfig(BaseModel):
|
||||
|
|
@ -101,6 +119,16 @@ class ColbertRerankConfig(BaseRankerConfig):
|
|||
keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.")
|
||||
|
||||
|
||||
class CohereRerankConfig(BaseRankerConfig):
|
||||
model: str = Field(default="rerank-english-v3.0")
|
||||
api_key: str = Field(default="YOUR_COHERE_API")
|
||||
|
||||
|
||||
class BGERerankConfig(BaseRankerConfig):
|
||||
model: str = Field(default="BAAI/bge-reranker-large", description="BAAI Reranker model name.")
|
||||
use_fp16: bool = Field(default=True, description="Whether to use fp16 for inference.")
|
||||
|
||||
|
||||
class ObjectRankerConfig(BaseRankerConfig):
|
||||
field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
|
||||
order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")
|
||||
|
|
@ -130,6 +158,9 @@ class ChromaIndexConfig(VectorIndexConfig):
|
|||
"""Config for chroma-based index."""
|
||||
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
metadata: Optional[CollectionMetadata] = Field(
|
||||
default=None, description="Optional metadata to associate with the collection"
|
||||
)
|
||||
|
||||
|
||||
class BM25IndexConfig(BaseIndexConfig):
|
||||
|
|
|
|||
|
|
@ -20,3 +20,18 @@ def run_coroutine_in_new_loop(coroutine) -> Any:
|
|||
new_loop.call_soon_threadsafe(new_loop.stop)
|
||||
t.join()
|
||||
new_loop.close()
|
||||
|
||||
|
||||
class NestAsyncio:
|
||||
"""Make asyncio event loop reentrant."""
|
||||
|
||||
is_applied = False
|
||||
|
||||
@classmethod
|
||||
def apply_once(cls):
|
||||
"""Ensures `nest_asyncio.apply()` is called only once."""
|
||||
if not cls.is_applied:
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
cls.is_applied = True
|
||||
|
|
|
|||
3
setup.py
3
setup.py
|
|
@ -32,12 +32,15 @@ extras_require = {
|
|||
"llama-index-core==0.10.15",
|
||||
"llama-index-embeddings-azure-openai==0.1.6",
|
||||
"llama-index-embeddings-openai==0.1.5",
|
||||
"llama-index-embeddings-gemini==0.1.6",
|
||||
"llama-index-embeddings-ollama==0.1.2",
|
||||
"llama-index-llms-azure-openai==0.1.4",
|
||||
"llama-index-readers-file==0.1.4",
|
||||
"llama-index-retrievers-bm25==0.1.3",
|
||||
"llama-index-vector-stores-faiss==0.1.1",
|
||||
"llama-index-vector-stores-elasticsearch==0.1.6",
|
||||
"llama-index-vector-stores-chroma==0.1.6",
|
||||
"docx2txt==0.8",
|
||||
],
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -25,10 +25,6 @@ class TestSimpleEngine:
|
|||
def mock_simple_directory_reader(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_retriever(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
|
||||
|
|
@ -45,7 +41,6 @@ class TestSimpleEngine:
|
|||
self,
|
||||
mocker,
|
||||
mock_simple_directory_reader,
|
||||
mock_vector_store_index,
|
||||
mock_get_retriever,
|
||||
mock_get_rankers,
|
||||
mock_get_response_synthesizer,
|
||||
|
|
@ -81,11 +76,8 @@ class TestSimpleEngine:
|
|||
|
||||
# Assert
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_vector_store_index.assert_called_once()
|
||||
mock_get_retriever.assert_called_once_with(
|
||||
configs=retriever_configs, index=mock_vector_store_index.return_value
|
||||
)
|
||||
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
|
||||
mock_get_retriever.assert_called_once()
|
||||
mock_get_rankers.assert_called_once()
|
||||
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
|
||||
|
|
@ -119,7 +111,7 @@ class TestSimpleEngine:
|
|||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is not None
|
||||
assert engine._transformations is not None
|
||||
|
||||
def test_from_objs_with_bm25_config(self):
|
||||
# Setup
|
||||
|
|
@ -137,6 +129,7 @@ class TestSimpleEngine:
|
|||
def test_from_index(self, mocker, mock_llm, mock_embedding):
|
||||
# Mock
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index.as_retriever.return_value = "retriever"
|
||||
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
|
||||
mock_get_index.return_value = mock_index
|
||||
|
||||
|
|
@ -149,7 +142,7 @@ class TestSimpleEngine:
|
|||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is mock_index
|
||||
assert engine._retriever == "retriever"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch(self, mocker):
|
||||
|
|
@ -200,14 +193,11 @@ class TestSimpleEngine:
|
|||
|
||||
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
|
||||
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index._transformations = mocker.MagicMock()
|
||||
|
||||
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
|
||||
mock_run_transformations.return_value = ["node1", "node2"]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
|
||||
# Exec
|
||||
|
|
@ -230,7 +220,7 @@ class TestSimpleEngine:
|
|||
return ""
|
||||
|
||||
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
|
||||
# Exec
|
||||
engine.add_objs(objs=objs)
|
||||
|
|
|
|||
|
|
@ -97,6 +97,5 @@ class TestConfigBasedFactory:
|
|||
def test_val_from_config_or_kwargs_key_error(self):
|
||||
# Test KeyError when the key is not found in both config object and kwargs
|
||||
config = DummyConfig(name=None)
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
|
||||
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
|
||||
val = ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
|
||||
assert val is None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
|
||||
|
||||
|
|
@ -10,30 +11,51 @@ class TestRAGEmbeddingFactory:
|
|||
self.embedding_factory = RAGEmbeddingFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_embedding(self, mocker):
|
||||
def mock_config(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
|
||||
@staticmethod
|
||||
def mock_openai_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_embedding(self, mocker):
|
||||
@staticmethod
|
||||
def mock_azure_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding")
|
||||
|
||||
def test_get_rag_embedding_openai(self, mock_openai_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.OPENAI)
|
||||
@staticmethod
|
||||
def mock_gemini_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding")
|
||||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
@staticmethod
|
||||
def mock_ollama_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding")
|
||||
|
||||
def test_get_rag_embedding_azure(self, mock_azure_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.AZURE)
|
||||
|
||||
# Assert
|
||||
mock_azure_embedding.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
|
||||
@pytest.mark.parametrize(
|
||||
("mock_func", "embedding_type"),
|
||||
[
|
||||
(mock_openai_embedding, LLMType.OPENAI),
|
||||
(mock_azure_embedding, LLMType.AZURE),
|
||||
(mock_openai_embedding, EmbeddingType.OPENAI),
|
||||
(mock_azure_embedding, EmbeddingType.AZURE),
|
||||
(mock_gemini_embedding, EmbeddingType.GEMINI),
|
||||
(mock_ollama_embedding, EmbeddingType.OLLAMA),
|
||||
],
|
||||
)
|
||||
def test_get_rag_embedding(self, mock_func, embedding_type, mocker):
|
||||
# Mock
|
||||
mock_config = mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
mock = mock_func(mocker)
|
||||
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(embedding_type)
|
||||
|
||||
# Assert
|
||||
mock.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_default(self, mocker, mock_config):
|
||||
# Mock
|
||||
mock_openai_embedding = self.mock_openai_embedding(mocker)
|
||||
|
||||
mock_config.embedding.api_type = None
|
||||
mock_config.llm.api_type = LLMType.OPENAI
|
||||
|
||||
# Exec
|
||||
|
|
@ -41,3 +63,44 @@ class TestRAGEmbeddingFactory:
|
|||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, embed_batch_size, expected_params",
|
||||
[("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})],
|
||||
)
|
||||
def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params):
|
||||
# Mock
|
||||
mock_config.embedding.model = model
|
||||
mock_config.embedding.embed_batch_size = embed_batch_size
|
||||
|
||||
# Setup
|
||||
test_params = {}
|
||||
|
||||
# Exec
|
||||
self.embedding_factory._try_set_model_and_batch_size(test_params)
|
||||
|
||||
# Assert
|
||||
assert test_params == expected_params
|
||||
|
||||
def test_resolve_embedding_type(self, mock_config):
|
||||
# Mock
|
||||
mock_config.embedding.api_type = EmbeddingType.OPENAI
|
||||
|
||||
# Exec
|
||||
embedding_type = self.embedding_factory._resolve_embedding_type()
|
||||
|
||||
# Assert
|
||||
assert embedding_type == EmbeddingType.OPENAI
|
||||
|
||||
def test_resolve_embedding_type_exception(self, mock_config):
|
||||
# Mock
|
||||
mock_config.embedding.api_type = None
|
||||
mock_config.llm.api_type = LLMType.GEMINI
|
||||
|
||||
# Assert
|
||||
with pytest.raises(TypeError):
|
||||
self.embedding_factory._resolve_embedding_type()
|
||||
|
||||
def test_raise_for_key(self):
|
||||
with pytest.raises(ValueError):
|
||||
self.embedding_factory._raise_for_key("key")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import faiss
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.embeddings import MockEmbedding
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
|
||||
|
|
@ -43,6 +45,14 @@ class TestRetrieverFactory:
|
|||
def mock_es_vector_store(self, mocker):
|
||||
return mocker.MagicMock(spec=ElasticsearchStore)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_nodes(self, mocker):
|
||||
return [TextNode(text="msg")]
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding(self):
|
||||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
|
||||
mock_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
|
||||
|
|
@ -52,42 +62,40 @@ class TestRetrieverFactory:
|
|||
|
||||
assert isinstance(retriever, FAISSRetriever)
|
||||
|
||||
def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
|
||||
def test_get_retriever_with_bm25_config(self, mocker, mock_nodes):
|
||||
mock_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=mock_nodes)
|
||||
|
||||
assert isinstance(retriever, DynamicBM25Retriever)
|
||||
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_nodes, mock_embedding):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=1)
|
||||
mock_bm25_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
|
||||
retriever = self.retriever_factory.get_retriever(
|
||||
configs=[mock_faiss_config, mock_bm25_config], nodes=mock_nodes, embed_model=mock_embedding
|
||||
)
|
||||
|
||||
assert isinstance(retriever, SimpleHybridRetriever)
|
||||
|
||||
def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
|
||||
def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store, mock_embedding):
|
||||
mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
|
||||
mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
|
||||
mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
|
||||
mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
|
||||
|
||||
assert isinstance(retriever, ChromaRetriever)
|
||||
|
||||
def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
|
||||
def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding):
|
||||
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
|
||||
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
|
||||
|
||||
assert isinstance(retriever, ElasticsearchRetriever)
|
||||
|
||||
|
|
@ -111,3 +119,19 @@ class TestRetrieverFactory:
|
|||
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
||||
def test_get_or_build_when_get(self, mocker):
|
||||
want = "existing_index"
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=want)
|
||||
|
||||
got = self.retriever_factory._build_es_index(None)
|
||||
|
||||
assert got == want
|
||||
|
||||
def test_get_or_build_when_build(self, mocker):
|
||||
want = "call_build_es_index"
|
||||
mocker.patch.object(self.retriever_factory, "_build_es_index", return_value=want)
|
||||
|
||||
got = self.retriever_factory._build_es_index(None)
|
||||
|
||||
assert got == want
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue