mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
upgrade llama-index to v0.10
This commit is contained in:
parent
04527cf0eb
commit
e14aedcea7
29 changed files with 725 additions and 370 deletions
|
|
@ -16,6 +16,8 @@ QUESTION = "What are key qualities to be a good writer?"
|
|||
|
||||
|
||||
class RAGExample:
|
||||
"""Show how to use RAG."""
|
||||
|
||||
def __init__(self):
|
||||
self.engine = SimpleEngine.from_docs(
|
||||
input_files=[DOC_PATH],
|
||||
|
|
@ -84,14 +86,17 @@ class RAGExample:
|
|||
{'name': 'foo', 'goal': 'Win The Game', 'tool': 'Red Bull Energy Drink'}
|
||||
"""
|
||||
|
||||
self._print_title("RAG Add Docs")
|
||||
self._print_title("RAG Add Objs")
|
||||
|
||||
class Player(BaseModel):
|
||||
"""Player"""
|
||||
|
||||
name: str = ""
|
||||
goal: str = "Win The Game"
|
||||
tool: str = "Red Bull Energy Drink"
|
||||
|
||||
def rag_key(self) -> str:
|
||||
"""For search"""
|
||||
return self.goal
|
||||
|
||||
foo = Player(name="foo")
|
||||
|
|
|
|||
|
|
@ -11,8 +11,9 @@ from pathlib import Path
|
|||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
from llama_index.node_parser import SimpleNodeParser
|
||||
from llama_index.readers import Document, PDFReader, SimpleDirectoryReader
|
||||
from llama_index.core import Document, SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SimpleNodeParser
|
||||
from llama_index.readers.file import PDFReader
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from tqdm import tqdm
|
||||
|
||||
|
|
|
|||
|
|
@ -10,11 +10,11 @@ from pathlib import Path
|
|||
from typing import Any, Optional
|
||||
|
||||
import faiss
|
||||
from llama_index import VectorStoreIndex, load_index_from_storage
|
||||
from llama_index.embeddings import BaseEmbedding
|
||||
from llama_index.schema import Document, QueryBundle, TextNode
|
||||
from llama_index.storage import StorageContext
|
||||
from llama_index.vector_stores import FaissVectorStore
|
||||
from llama_index.core import VectorStoreIndex, load_index_from_storage
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.schema import Document, QueryBundle, TextNode
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.document import IndexableDocument
|
||||
from metagpt.document_store.base_store import LocalStore
|
||||
|
|
|
|||
|
|
@ -3,22 +3,32 @@
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
|
||||
from llama_index.callbacks.base import CallbackManager
|
||||
from llama_index.core.base_retriever import BaseRetriever
|
||||
from llama_index.embeddings.base import BaseEmbedding
|
||||
from llama_index.indices.base import BaseIndex
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.query_engine import RetrieverQueryEngine
|
||||
from llama_index.response_synthesizers import BaseSynthesizer
|
||||
from llama_index.schema import NodeWithScore, QueryBundle, QueryType, TextNode
|
||||
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.ingestion.pipeline import run_transformations
|
||||
from llama_index.core.llms import LLM
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.response_synthesizers import (
|
||||
BaseSynthesizer,
|
||||
get_response_synthesizer,
|
||||
)
|
||||
from llama_index.core.retrievers import BaseRetriever
|
||||
from llama_index.core.schema import (
|
||||
NodeWithScore,
|
||||
QueryBundle,
|
||||
QueryType,
|
||||
TextNode,
|
||||
TransformComponent,
|
||||
)
|
||||
|
||||
from metagpt.rag.factory import get_rankers, get_retriever
|
||||
from metagpt.rag.factories import get_rag_llm, get_rankers, get_retriever
|
||||
from metagpt.rag.interface import RAGObject
|
||||
from metagpt.rag.llm import get_default_llm
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
from metagpt.rag.schema import RankerConfigType, RetrieverConfigType
|
||||
from metagpt.rag.schema import BaseRankerConfig, BaseRetrieverConfig
|
||||
from metagpt.utils.embedding import get_embedding
|
||||
|
||||
|
||||
|
|
@ -51,45 +61,47 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
cls,
|
||||
input_dir: str = None,
|
||||
input_files: list[str] = None,
|
||||
llm: LLM = None,
|
||||
transformations: Optional[list[TransformComponent]] = None,
|
||||
embed_model: BaseEmbedding = None,
|
||||
chunk_size: int = None,
|
||||
chunk_overlap: int = None,
|
||||
retriever_configs: list[RetrieverConfigType] = None,
|
||||
ranker_configs: list[RankerConfigType] = None,
|
||||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
"""This engine is designed to be simple and straightforward
|
||||
|
||||
Args:
|
||||
input_dir: Path to the directory.
|
||||
input_files: List of file paths to read (Optional; overrides input_dir, exclude).
|
||||
llm: Must supported by llama index.
|
||||
embed_model: Must supported by llama index.
|
||||
chunk_size: The size of text chunks (in tokens) to split documents into for embedding.
|
||||
chunk_overlap: The number of tokens for overlapping between consecutive chunks. Helps in maintaining context continuity.
|
||||
transformations: Parse documents to nodes. Default [SentenceSplitter].
|
||||
embed_model: Parse nodes to embedding. Must supported by llama index. Default OpenAIEmbedding.
|
||||
llm: Must supported by llama index. Default OpenAI.
|
||||
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
|
||||
ranker_configs: Configuration for rankers.
|
||||
"""
|
||||
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
|
||||
service_context = ServiceContext.from_defaults(
|
||||
llm=llm or get_default_llm(),
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents=documents,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=embed_model or get_embedding(),
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
|
||||
retriever = get_retriever(index, configs=retriever_configs)
|
||||
rankers = get_rankers(configs=ranker_configs, service_context=service_context)
|
||||
llm = llm or get_rag_llm()
|
||||
retriever = get_retriever(configs=retriever_configs, index=index)
|
||||
rankers = get_rankers(configs=ranker_configs, llm=llm)
|
||||
|
||||
return cls(retriever=retriever, node_postprocessors=rankers, index=index)
|
||||
return cls(
|
||||
retriever=retriever,
|
||||
node_postprocessors=rankers,
|
||||
response_synthesizer=get_response_synthesizer(llm=llm),
|
||||
index=index,
|
||||
)
|
||||
|
||||
async def asearch(self, content: str, **kwargs) -> str:
|
||||
"""Inplement tools.SearchInterface"""
|
||||
return await self.aquery(content)
|
||||
|
||||
async def aretrieve(self, query_bundle: QueryType) -> list[NodeWithScore]:
|
||||
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
"""Allow query to be str"""
|
||||
query_bundle = QueryBundle(query_bundle) if isinstance(query_bundle, str) else query_bundle
|
||||
query_bundle = QueryBundle(query) if isinstance(query, str) else query
|
||||
return await super().aretrieve(query_bundle)
|
||||
|
||||
def add_docs(self, input_files: list[str]):
|
||||
|
|
@ -97,7 +109,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
self._ensure_retriever_modifiable()
|
||||
|
||||
documents = SimpleDirectoryReader(input_files=input_files).load_data()
|
||||
nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents)
|
||||
nodes = run_transformations(documents, transformations=self.index._transformations)
|
||||
self.retriever.add_nodes(nodes)
|
||||
|
||||
def add_objs(self, objs: list[RAGObject]):
|
||||
|
|
|
|||
6
metagpt/rag/factories/__init__.py
Normal file
6
metagpt/rag/factories/__init__.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
"""RAG factories"""
|
||||
from metagpt.rag.factories.retriever import get_retriever
|
||||
from metagpt.rag.factories.ranker import get_rankers
|
||||
from metagpt.rag.factories.llm import get_rag_llm
|
||||
|
||||
__all__ = ["get_retriever", "get_rankers", "get_rag_llm"]
|
||||
58
metagpt/rag/factories/base.py
Normal file
58
metagpt/rag/factories/base.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""Base Factory."""
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
class GenericFactory:
|
||||
"""Designed to get objects based on any keys."""
|
||||
|
||||
def __init__(self, creators: dict[Any, Callable] = None):
|
||||
"""Creators is a dictionary.
|
||||
|
||||
Keys are identifiers, and the values are the associated creator function, which create objects.
|
||||
"""
|
||||
self._creators = creators or {}
|
||||
|
||||
def get_instances(self, keys: list[Any], **kwargs) -> list[Any]:
|
||||
"""Get instances by keys."""
|
||||
return [self.get_instance(key, **kwargs) for key in keys]
|
||||
|
||||
def get_instance(self, key: Any, **kwargs) -> Any:
|
||||
"""Get instance by key.
|
||||
|
||||
Raise Exception if key not found.
|
||||
"""
|
||||
creator = self._creators.get(key)
|
||||
if creator:
|
||||
return creator(**kwargs)
|
||||
|
||||
raise ValueError(f"Creator not registered for key: {key}")
|
||||
|
||||
|
||||
class ConfigFactory(GenericFactory):
|
||||
"""Designed to get objects based on object type."""
|
||||
|
||||
def get_instance(self, key: Any, **kwargs) -> Any:
|
||||
"""Key is config, such as a pydantic model.
|
||||
|
||||
Call func by the type of key, and the key will be passed to func.
|
||||
"""
|
||||
creator = self._creators.get(type(key))
|
||||
if creator:
|
||||
return creator(key, **kwargs)
|
||||
|
||||
raise ValueError(f"Unknown config: {key}")
|
||||
|
||||
@staticmethod
|
||||
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
|
||||
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
|
||||
if config is not None and hasattr(config, key):
|
||||
val = getattr(config, key)
|
||||
if val is not None:
|
||||
return val
|
||||
|
||||
if key in kwargs:
|
||||
return kwargs[key]
|
||||
|
||||
raise KeyError(
|
||||
f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
|
||||
)
|
||||
76
metagpt/rag/factories/llm.py
Normal file
76
metagpt/rag/factories/llm.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
"""RAG LLM Factory.
|
||||
|
||||
The LLM of LlamaIndex and the LLM of MG are not the same.
|
||||
"""
|
||||
from llama_index.core.llms import LLM
|
||||
from llama_index.llms.anthropic import Anthropic
|
||||
from llama_index.llms.azure_openai import AzureOpenAI
|
||||
from llama_index.llms.gemini import Gemini
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.base import GenericFactory
|
||||
|
||||
|
||||
class RAGLLMFactory(GenericFactory):
|
||||
"""Create LlamaIndex LLM with MG config."""
|
||||
|
||||
def __init__(self):
|
||||
creators = {
|
||||
LLMType.OPENAI: self._create_openai,
|
||||
LLMType.AZURE: self._create_azure,
|
||||
LLMType.ANTHROPIC: self._create_anthropic,
|
||||
LLMType.GEMINI: self._create_gemini,
|
||||
LLMType.OLLAMA: self._create_ollama,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_rag_llm(self, key: LLMType = None) -> LLM:
|
||||
"""Key is LLMType, default use config.llm.api_type."""
|
||||
return super().get_instance(key or config.llm.api_type)
|
||||
|
||||
def _create_openai(self):
|
||||
return OpenAI(
|
||||
api_base=config.llm.base_url,
|
||||
api_key=config.llm.api_key,
|
||||
api_version=config.llm.api_version,
|
||||
model=config.llm.model,
|
||||
max_tokens=config.llm.max_token,
|
||||
temperature=config.llm.temperature,
|
||||
)
|
||||
|
||||
def _create_azure(self):
|
||||
return AzureOpenAI(
|
||||
azure_endpoint=config.llm.base_url,
|
||||
api_key=config.llm.api_key,
|
||||
api_version=config.llm.api_version,
|
||||
model=config.llm.model,
|
||||
max_tokens=config.llm.max_token,
|
||||
temperature=config.llm.temperature,
|
||||
)
|
||||
|
||||
def _create_anthropic(self):
|
||||
return Anthropic(
|
||||
base_url=config.llm.base_url,
|
||||
api_key=config.llm.api_key,
|
||||
model=config.llm.model,
|
||||
max_tokens=config.llm.max_token,
|
||||
temperature=config.llm.temperature,
|
||||
)
|
||||
|
||||
def _create_gemini(self):
|
||||
return Gemini(
|
||||
api_base=config.llm.base_url,
|
||||
api_key=config.llm.api_key,
|
||||
model_name=config.llm.model,
|
||||
max_tokens=config.llm.max_token,
|
||||
temperature=config.llm.temperature,
|
||||
)
|
||||
|
||||
def _create_ollama(self):
|
||||
return Ollama(base_url=config.llm.base_url, model=config.llm.model, temperature=config.llm.temperature)
|
||||
|
||||
|
||||
get_rag_llm = RAGLLMFactory().get_rag_llm
|
||||
39
metagpt/rag/factories/ranker.py
Normal file
39
metagpt/rag/factories/ranker.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""RAG Ranker Factory."""
|
||||
|
||||
from llama_index.core.llms import LLM
|
||||
from llama_index.core.postprocessor import LLMRerank
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
|
||||
from metagpt.rag.factories.base import ConfigFactory
|
||||
from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig
|
||||
|
||||
|
||||
class RankerFactory(ConfigFactory):
|
||||
"""Modify creators for dynamically instance implementation."""
|
||||
|
||||
def __init__(self):
|
||||
creators = {
|
||||
LLMRankerConfig: self._create_llm_ranker,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]:
|
||||
"""Creates and returns a retriever instance based on the provided configurations."""
|
||||
if not configs:
|
||||
return self._create_default(**kwargs)
|
||||
|
||||
return super().get_instances(configs, **kwargs)
|
||||
|
||||
def _create_default(self, **kwargs) -> list[LLMRerank]:
|
||||
config = LLMRankerConfig(llm=self._extract_llm(**kwargs))
|
||||
return [LLMRerank(**config.model_dump())]
|
||||
|
||||
def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
|
||||
return self._val_from_config_or_kwargs("llm", config, **kwargs)
|
||||
|
||||
def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank:
|
||||
config.llm = self._extract_llm(config, **kwargs)
|
||||
return LLMRerank(**config.model_dump())
|
||||
|
||||
|
||||
get_rankers = RankerFactory().get_rankers
|
||||
64
metagpt/rag/factories/retriever.py
Normal file
64
metagpt/rag/factories/retriever.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
"""RAG Retriever Factory."""
|
||||
|
||||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.factories.base import ConfigFactory
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BaseRetrieverConfig,
|
||||
BM25RetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
)
|
||||
|
||||
|
||||
class RetrieverFactory(ConfigFactory):
|
||||
"""Modify creators for dynamically instance implementation."""
|
||||
|
||||
def __init__(self):
|
||||
creators = {
|
||||
FAISSRetrieverConfig: self._create_faiss_retriever,
|
||||
BM25RetrieverConfig: self._create_bm25_retriever,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_retriever(self, configs: list[BaseRetrieverConfig] = None, **kwargs) -> RAGRetriever:
|
||||
"""Creates and returns a retriever instance based on the provided configurations.
|
||||
|
||||
If multiple retrievers, using SimpleHybridRetriever.
|
||||
"""
|
||||
if not configs:
|
||||
return self._create_default(**kwargs)
|
||||
|
||||
retrievers = super().get_instances(configs, **kwargs)
|
||||
|
||||
return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]
|
||||
|
||||
def _create_default(self, **kwargs) -> RAGRetriever:
|
||||
return self._extract_index(**kwargs).as_retriever()
|
||||
|
||||
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
|
||||
return self._val_from_config_or_kwargs("index", config, **kwargs)
|
||||
|
||||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
old_index = self._extract_index(config, **kwargs)
|
||||
new_index = VectorStoreIndex(
|
||||
nodes=list(old_index.docstore.docs.values()),
|
||||
storage_context=storage_context,
|
||||
embed_model=old_index._embed_model,
|
||||
)
|
||||
config.index = new_index
|
||||
return FAISSRetriever(**config.model_dump())
|
||||
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
|
||||
config.index = self._extract_index(config, **kwargs)
|
||||
return DynamicBM25Retriever.from_defaults(**config.model_dump())
|
||||
|
||||
|
||||
get_retriever = RetrieverFactory().get_retriever
|
||||
|
|
@ -1,114 +0,0 @@
|
|||
"""Factory for creating retriever, ranker"""
|
||||
from typing import Any, Callable
|
||||
|
||||
import faiss
|
||||
from llama_index import ServiceContext, StorageContext, VectorStoreIndex
|
||||
from llama_index.indices.base import BaseIndex
|
||||
from llama_index.postprocessor import LLMRerank
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
LLMRankerConfig,
|
||||
RankerConfigType,
|
||||
RetrieverConfigType,
|
||||
)
|
||||
|
||||
|
||||
class BaseFactory:
|
||||
"""
|
||||
A base factory class for creating instances based on provided configurations.
|
||||
It uses a registry of creator functions mapped to configuration types to instantiate objects dynamically.
|
||||
"""
|
||||
|
||||
def __init__(self, creators: dict[Any, Callable]):
|
||||
"""Creators is a dictionary mapping configuration types to creator functions."""
|
||||
self.creators = creators
|
||||
|
||||
def get_instances(self, configs: list[Any] = None, **kwargs) -> list[Any]:
|
||||
"""Get instances by configs"""
|
||||
return [self._get_instance(config, **kwargs) for config in configs]
|
||||
|
||||
def _get_instance(self, config: Any, **kwargs) -> Any:
|
||||
create_func = self.creators.get(type(config))
|
||||
if create_func:
|
||||
return create_func(config, **kwargs)
|
||||
|
||||
raise ValueError(f"Unknown config: {config}")
|
||||
|
||||
|
||||
class RetrieverFactory(BaseFactory):
|
||||
"""Modify creators for dynamically instance implementation"""
|
||||
|
||||
def __init__(self):
|
||||
creators = {
|
||||
FAISSRetrieverConfig: self._create_faiss_retriever,
|
||||
BM25RetrieverConfig: self._create_bm25_retriever,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
|
||||
"""Creates and returns a retriever instance based on the provided configurations.
|
||||
If multiple retrievers, using SimpleHybridRetriever
|
||||
"""
|
||||
if not configs:
|
||||
return self._default_instance(index)
|
||||
|
||||
retrievers = super().get_instances(configs, index=index)
|
||||
|
||||
return (
|
||||
SimpleHybridRetriever(*retrievers, service_context=index.service_context)
|
||||
if len(retrievers) > 1
|
||||
else retrievers[0]
|
||||
)
|
||||
|
||||
def _default_instance(self, index: BaseIndex) -> RAGRetriever:
|
||||
return index.as_retriever()
|
||||
|
||||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, index: BaseIndex) -> FAISSRetriever:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
vector_index = VectorStoreIndex(
|
||||
nodes=list(index.docstore.docs.values()),
|
||||
storage_context=storage_context,
|
||||
service_context=index.service_context,
|
||||
)
|
||||
return FAISSRetriever(**config.model_dump(), index=vector_index)
|
||||
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, index: BaseIndex) -> DynamicBM25Retriever:
|
||||
return DynamicBM25Retriever.from_defaults(**config.model_dump(), index=index)
|
||||
|
||||
|
||||
class RankerFactory(BaseFactory):
|
||||
"""Modify creators for dynamically instance implementation"""
|
||||
|
||||
def __init__(self):
|
||||
creators = {
|
||||
LLMRankerConfig: self._create_llm_ranker,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_rankers(
|
||||
self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None
|
||||
) -> list[BaseNodePostprocessor]:
|
||||
"""Creates and returns a retriever instance based on the provided configurations."""
|
||||
if not configs:
|
||||
return [self._default_instance(service_context)]
|
||||
|
||||
return super().get_instances(configs, service_context=service_context)
|
||||
|
||||
def _default_instance(self, service_context: ServiceContext) -> LLMRerank:
|
||||
return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context)
|
||||
|
||||
def _create_llm_ranker(self, config: LLMRankerConfig, service_context: ServiceContext = None) -> LLMRerank:
|
||||
return LLMRerank(**config.model_dump(), service_context=service_context)
|
||||
|
||||
|
||||
get_retriever = RetrieverFactory().get_retriever
|
||||
get_rankers = RankerFactory().get_rankers
|
||||
|
|
@ -1,14 +1,15 @@
|
|||
"""RAG Interface."""
|
||||
"""RAG Interfaces."""
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
class RAGObject(Protocol):
|
||||
"""Support rag add object"""
|
||||
"""Support rag add object."""
|
||||
|
||||
def rag_key(self) -> str:
|
||||
"""For rag search."""
|
||||
|
||||
def model_dump(self) -> dict[str, Any]:
|
||||
"""For rag persist.
|
||||
Pydantic Model don't need to implement this, as there is a built-in function named model_dump
|
||||
|
||||
Pydantic Model don't need to implement this, as there is a built-in function named model_dump.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,11 +0,0 @@
|
|||
"""RAG LLM
|
||||
The LLM of LlamaIndex and the LLM of MG are not the same.
|
||||
"""
|
||||
from llama_index.llms import OpenAI
|
||||
|
||||
from metagpt.config2 import config
|
||||
|
||||
|
||||
def get_default_llm() -> OpenAI:
|
||||
"""OpenAI"""
|
||||
return OpenAI(api_base=config.llm.base_url, api_key=config.llm.api_key, model=config.llm.model)
|
||||
|
|
@ -4,8 +4,8 @@ from abc import abstractmethod
|
|||
from typing import Optional
|
||||
|
||||
from llama_index import QueryBundle
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.schema import NodeWithScore
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
|
||||
|
||||
class RAGRanker(BaseNodePostprocessor):
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@
|
|||
|
||||
from abc import abstractmethod
|
||||
|
||||
from llama_index.retrievers import BaseRetriever
|
||||
from llama_index.schema import BaseNode, NodeWithScore, QueryType
|
||||
from llama_index.core.retrievers import BaseRetriever
|
||||
from llama_index.core.schema import BaseNode, NodeWithScore, QueryType
|
||||
|
||||
from metagpt.utils.reflection import check_methods
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""BM25 retriever."""
|
||||
from llama_index.retrievers import BM25Retriever
|
||||
from llama_index.schema import BaseNode
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.retrievers.bm25 import BM25Retriever
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
||||
|
||||
class DynamicBM25Retriever(BM25Retriever):
|
||||
|
|
@ -8,11 +9,6 @@ class DynamicBM25Retriever(BM25Retriever):
|
|||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs):
|
||||
"""Support add nodes"""
|
||||
try:
|
||||
from rank_bm25 import BM25Okapi
|
||||
except ImportError:
|
||||
raise ImportError("Please install rank_bm25: pip install rank-bm25")
|
||||
|
||||
self._nodes.extend(nodes)
|
||||
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
|
||||
self.bm25 = BM25Okapi(self._corpus)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""FAISS retriever."""
|
||||
from llama_index.retrievers import VectorIndexRetriever
|
||||
from llama_index.schema import BaseNode
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.schema import BaseNode
|
||||
|
||||
|
||||
class FAISSRetriever(VectorIndexRetriever):
|
||||
|
|
|
|||
|
|
@ -1,23 +1,20 @@
|
|||
"""Hybrid retriever."""
|
||||
from llama_index import ServiceContext
|
||||
from llama_index.schema import BaseNode, QueryType
|
||||
import copy
|
||||
|
||||
from llama_index.core.schema import BaseNode, QueryType
|
||||
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
|
||||
|
||||
class SimpleHybridRetriever(RAGRetriever):
|
||||
"""
|
||||
SimpleHybridRetriever is a composite retriever that aggregates search results from multiple retrievers.
|
||||
"""
|
||||
"""A composite retriever that aggregates search results from multiple retrievers."""
|
||||
|
||||
def __init__(self, *retrievers, service_context: ServiceContext = None):
|
||||
def __init__(self, *retrievers):
|
||||
self.retrievers: list[RAGRetriever] = retrievers
|
||||
self.service_context = service_context
|
||||
super().__init__()
|
||||
|
||||
async def _aretrieve(self, query: QueryType, **kwargs):
|
||||
"""
|
||||
Asynchronously retrieves and aggregates search results from all configured retrievers.
|
||||
"""Asynchronously retrieves and aggregates search results from all configured retrievers.
|
||||
|
||||
This method queries each retriever in the `retrievers` list with the given query and
|
||||
additional keyword arguments. It then combines the results, ensuring that each node is
|
||||
|
|
@ -25,7 +22,9 @@ class SimpleHybridRetriever(RAGRetriever):
|
|||
"""
|
||||
all_nodes = []
|
||||
for retriever in self.retrievers:
|
||||
nodes = await retriever.aretrieve(query, **kwargs)
|
||||
# 防止retriever可能改变query的属性
|
||||
query_copy = copy.deepcopy(query)
|
||||
nodes = await retriever.aretrieve(query_copy, **kwargs)
|
||||
all_nodes.extend(nodes)
|
||||
|
||||
# combine all nodes
|
||||
|
|
|
|||
|
|
@ -1,36 +1,52 @@
|
|||
"""RAG schemas"""
|
||||
"""RAG schemas."""
|
||||
|
||||
from typing import Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class RetrieverConfig(BaseModel):
|
||||
"""Common config for retrievers."""
|
||||
class BaseRetrieverConfig(BaseModel):
|
||||
"""Common config for retrievers.
|
||||
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factory.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
similarity_top_k: int = Field(default=5, description="Number of top-k similar results to return during retrieval.")
|
||||
|
||||
|
||||
class FAISSRetrieverConfig(RetrieverConfig):
|
||||
class IndexRetrieverConfig(BaseRetrieverConfig):
|
||||
"""Config for Index-basd retrievers."""
|
||||
|
||||
index: BaseIndex = Field(default=None, description="Index for retriver")
|
||||
|
||||
|
||||
class FAISSRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for FAISS-based retrievers."""
|
||||
|
||||
dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.")
|
||||
|
||||
|
||||
class BM25RetrieverConfig(RetrieverConfig):
|
||||
class BM25RetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for BM25-based retrievers."""
|
||||
|
||||
|
||||
class RankerConfig(BaseModel):
|
||||
"""Common config for rankers."""
|
||||
class BaseRankerConfig(BaseModel):
|
||||
"""Common config for rankers.
|
||||
|
||||
top_n: int = 5
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factory.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
top_n: int = Field(default=5, description="The number of top results to return.")
|
||||
|
||||
|
||||
class LLMRankerConfig(RankerConfig):
|
||||
class LLMRankerConfig(BaseRankerConfig):
|
||||
"""Config for LLM-based rankers."""
|
||||
|
||||
|
||||
# If add new config, it is necessary to add the corresponding instance implementation in rag.factory
|
||||
RetrieverConfigType = Union[FAISSRetrieverConfig, BM25RetrieverConfig]
|
||||
RankerConfigType = LLMRankerConfig
|
||||
llm: Any = Field(
|
||||
default=None,
|
||||
description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,12 +5,15 @@
|
|||
@Author : alexanderwu
|
||||
@File : embedding.py
|
||||
"""
|
||||
from llama_index.embeddings import OpenAIEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
from metagpt.config2 import config
|
||||
|
||||
|
||||
def get_embedding() -> OpenAIEmbedding:
|
||||
llm = config.get_openai_llm()
|
||||
if llm is None:
|
||||
raise ValueError("To use OpenAIEmbedding, please ensure that config.llm.api_type is correctly set to 'openai'.")
|
||||
|
||||
embedding = OpenAIEmbedding(api_key=llm.api_key, api_base=llm.base_url)
|
||||
return embedding
|
||||
|
|
|
|||
|
|
@ -11,7 +11,16 @@ typer==0.9.0
|
|||
# godot==0.1.1
|
||||
# google_api_python_client==2.93.0 # Used by search_engine.py
|
||||
lancedb==0.4.0
|
||||
llama-index==0.9.44
|
||||
llama-index-core==0.10.11.post1
|
||||
llama-index-embeddings-openai==0.1.5
|
||||
llama-index-llms-anthropic==0.1.3
|
||||
llama-index-llms-azure-openai==0.1.4
|
||||
llama-index-llms-gemini==0.1.4
|
||||
llama-index-llms-ollama==0.1.2
|
||||
llama-index-llms-openai==0.1.5
|
||||
llama-index-readers-file==0.1.4
|
||||
llama-index-retrievers-bm25==0.1.3
|
||||
llama-index-vector-stores-faiss==0.1.1
|
||||
loguru==0.6.0
|
||||
meilisearch==0.21.0
|
||||
numpy==1.24.3
|
||||
|
|
|
|||
|
|
@ -1,58 +1,75 @@
|
|||
import pytest
|
||||
from llama_index import VectorStoreIndex
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
|
||||
|
||||
class TestSimpleEngine:
|
||||
def test_from_docs(self, mocker):
|
||||
@pytest.fixture
|
||||
def mock_simple_directory_reader(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_retriever(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_rankers(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_rankers")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_response_synthesizer(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
|
||||
|
||||
def test_from_docs(
|
||||
self,
|
||||
mocker,
|
||||
mock_simple_directory_reader,
|
||||
mock_vector_store_index,
|
||||
mock_get_retriever,
|
||||
mock_get_rankers,
|
||||
mock_get_response_synthesizer,
|
||||
):
|
||||
# Mock
|
||||
mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"]
|
||||
|
||||
mock_service_context = mocker.patch("metagpt.rag.engines.simple.ServiceContext.from_defaults")
|
||||
mock_service_context.return_value = "service_context"
|
||||
|
||||
mock_vector_store_index = mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
|
||||
mock_get_retriever = mocker.patch("metagpt.rag.engines.simple.get_retriever")
|
||||
mock_get_rankers = mocker.patch("metagpt.rag.engines.simple.get_rankers")
|
||||
mock_get_retriever.return_value = mocker.MagicMock()
|
||||
mock_get_rankers.return_value = [mocker.MagicMock()]
|
||||
mock_get_response_synthesizer.return_value = mocker.MagicMock()
|
||||
|
||||
# Setup
|
||||
input_dir = "test_dir"
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
transformations = [mocker.MagicMock()]
|
||||
embed_model = mocker.MagicMock()
|
||||
llm = mocker.MagicMock()
|
||||
chunk_size = 100
|
||||
chunk_overlap = 10
|
||||
retriever_configs = mocker.MagicMock()
|
||||
ranker_configs = mocker.MagicMock()
|
||||
retriever_configs = [mocker.MagicMock()]
|
||||
ranker_configs = [mocker.MagicMock()]
|
||||
|
||||
# Execute
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_dir=input_dir,
|
||||
input_files=input_files,
|
||||
transformations=transformations,
|
||||
embed_model=embed_model,
|
||||
llm=llm,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
|
||||
# Assertions
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_service_context.assert_called_once_with(
|
||||
embed_model=embed_model, chunk_size=chunk_size, chunk_overlap=chunk_overlap, llm=llm
|
||||
mock_vector_store_index.assert_called_once()
|
||||
mock_get_retriever.assert_called_once_with(
|
||||
configs=retriever_configs, index=mock_vector_store_index.return_value
|
||||
)
|
||||
mock_vector_store_index.assert_called_once_with(
|
||||
["document1", "document2"], service_context=mock_service_context.return_value
|
||||
)
|
||||
mock_get_retriever.assert_called_once_with(mock_vector_store_index.return_value, configs=retriever_configs)
|
||||
mock_get_rankers.assert_called_once_with(
|
||||
configs=ranker_configs, service_context=mock_service_context.return_value
|
||||
)
|
||||
|
||||
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
|
||||
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -100,8 +117,12 @@ class TestSimpleEngine:
|
|||
mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"]
|
||||
|
||||
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
|
||||
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index.service_context.node_parser.get_nodes_from_documents = lambda x: ["node1", "node2"]
|
||||
mock_index._transformations = mocker.MagicMock()
|
||||
|
||||
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
|
||||
mock_run_transformations.return_value = ["node1", "node2"]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
|
||||
|
|
@ -113,3 +134,27 @@ class TestSimpleEngine:
|
|||
# Assertions
|
||||
mock_simple_directory_reader.assert_called_once_with(input_files=input_files)
|
||||
mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"])
|
||||
|
||||
def test_add_objs(self, mocker):
|
||||
# Mock
|
||||
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
|
||||
|
||||
# Setup
|
||||
class CustomTextNode(TextNode):
|
||||
def rag_key(self):
|
||||
return ""
|
||||
|
||||
def model_dump(self):
|
||||
return {}
|
||||
|
||||
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
|
||||
|
||||
# Execute
|
||||
engine.add_objs(objs=objs)
|
||||
|
||||
# Assertions
|
||||
assert mock_retriever.add_nodes.call_count == 1
|
||||
for node in mock_retriever.add_nodes.call_args[0][0]:
|
||||
assert isinstance(node, TextNode)
|
||||
assert "obj" in node.metadata
|
||||
102
tests/metagpt/rag/factories/test_base.py
Normal file
102
tests/metagpt/rag/factories/test_base.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.rag.factories.base import ConfigFactory, GenericFactory
|
||||
|
||||
|
||||
class TestGenericFactory:
|
||||
@pytest.fixture
|
||||
def creators(self):
|
||||
return {
|
||||
"type1": lambda name: f"Instance of type1 with {name}",
|
||||
"type2": lambda name: f"Instance of type2 with {name}",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def factory(self, creators):
|
||||
return GenericFactory(creators=creators)
|
||||
|
||||
def test_get_instance_success(self, factory):
|
||||
# Test successful retrieval of an instance
|
||||
key = "type1"
|
||||
instance = factory.get_instance(key, name="TestName")
|
||||
assert instance == "Instance of type1 with TestName"
|
||||
|
||||
def test_get_instance_failure(self, factory):
|
||||
# Test failure to retrieve an instance due to unregistered key
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
factory.get_instance("unknown_key")
|
||||
assert "Creator not registered for key: unknown_key" in str(exc_info.value)
|
||||
|
||||
def test_get_instances_success(self, factory):
|
||||
# Test successful retrieval of multiple instances
|
||||
keys = ["type1", "type2"]
|
||||
instances = factory.get_instances(keys, name="TestName")
|
||||
expected = ["Instance of type1 with TestName", "Instance of type2 with TestName"]
|
||||
assert instances == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"keys,expected_exception_message",
|
||||
[
|
||||
(["unknown_key"], "Creator not registered for key: unknown_key"),
|
||||
(["type1", "unknown_key"], "Creator not registered for key: unknown_key"),
|
||||
],
|
||||
)
|
||||
def test_get_instances_with_failure(self, factory, keys, expected_exception_message):
|
||||
# Test failure to retrieve instances due to at least one unregistered key
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
factory.get_instances(keys, name="TestName")
|
||||
assert expected_exception_message in str(exc_info.value)
|
||||
|
||||
|
||||
class DummyConfig:
|
||||
"""A dummy config class for testing."""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
class TestConfigFactory:
|
||||
@pytest.fixture
|
||||
def config_creators(self):
|
||||
return {
|
||||
DummyConfig: lambda config, **kwargs: f"Processed {config.name} with {kwargs.get('extra', 'no extra')}",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def config_factory(self, config_creators):
|
||||
return ConfigFactory(creators=config_creators)
|
||||
|
||||
def test_get_instance_success(self, config_factory):
|
||||
# Test successful retrieval of an instance
|
||||
config = DummyConfig(name="TestConfig")
|
||||
instance = config_factory.get_instance(config, extra="additional data")
|
||||
assert instance == "Processed TestConfig with additional data"
|
||||
|
||||
def test_get_instance_failure(self, config_factory):
|
||||
# Test failure to retrieve an instance due to unknown config type
|
||||
class UnknownConfig:
|
||||
pass
|
||||
|
||||
config = UnknownConfig()
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
config_factory.get_instance(config)
|
||||
assert "Unknown config:" in str(exc_info.value)
|
||||
|
||||
def test_val_from_config_or_kwargs_priority(self):
|
||||
# Test that the value from the config object has priority over kwargs
|
||||
config = DummyConfig(name="ConfigName")
|
||||
result = ConfigFactory._val_from_config_or_kwargs("name", config, name="KwargsName")
|
||||
assert result == "ConfigName"
|
||||
|
||||
def test_val_from_config_or_kwargs_fallback_to_kwargs(self):
|
||||
# Test fallback to kwargs when config object does not have the value
|
||||
config = DummyConfig(name=None)
|
||||
result = ConfigFactory._val_from_config_or_kwargs("name", config, name="KwargsName")
|
||||
assert result == "KwargsName"
|
||||
|
||||
def test_val_from_config_or_kwargs_key_error(self):
|
||||
# Test KeyError when the key is not found in both config object and kwargs
|
||||
config = DummyConfig(name=None)
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
ConfigFactory._val_from_config_or_kwargs("missing_key", config)
|
||||
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
|
||||
56
tests/metagpt/rag/factories/test_llm.py
Normal file
56
tests/metagpt/rag/factories/test_llm.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
import pytest
|
||||
from llama_index.llms.anthropic import Anthropic
|
||||
from llama_index.llms.azure_openai import AzureOpenAI
|
||||
from llama_index.llms.gemini import Gemini
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.llm import RAGLLMFactory
|
||||
|
||||
|
||||
class TestRAGLLMFactory:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
# Mock the config object for all tests in this class
|
||||
self.mock_config = mocker.MagicMock()
|
||||
self.mock_config.llm.api_type = LLMType.OPENAI
|
||||
self.mock_config.llm.base_url = "http://example.com"
|
||||
self.mock_config.llm.api_key = "test_api_key"
|
||||
self.mock_config.llm.api_version = "v1"
|
||||
self.mock_config.llm.model = "test_model"
|
||||
self.mock_config.llm.max_token = 100
|
||||
self.mock_config.llm.temperature = 0.5
|
||||
mocker.patch("metagpt.rag.factories.llm.config", self.mock_config)
|
||||
self.factory = RAGLLMFactory()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_type,expected_class",
|
||||
[
|
||||
(LLMType.OPENAI, OpenAI),
|
||||
(LLMType.AZURE, AzureOpenAI),
|
||||
(LLMType.ANTHROPIC, Anthropic),
|
||||
(LLMType.GEMINI, Gemini),
|
||||
(LLMType.OLLAMA, Ollama),
|
||||
],
|
||||
)
|
||||
def test_creates_correct_llm_instance(self, llm_type, expected_class, mocker):
|
||||
# Mock the LLM constructors
|
||||
mocker.patch.object(expected_class, "__init__", return_value=None)
|
||||
instance = self.factory.get_rag_llm(key=llm_type)
|
||||
assert isinstance(instance, expected_class)
|
||||
expected_class.__init__.assert_called_once()
|
||||
|
||||
def test_uses_default_llm_type_when_no_key_provided(self, mocker):
|
||||
# Assume the default API type is OPENAI for this test
|
||||
mock = mocker.patch.object(OpenAI, "__init__", return_value=None)
|
||||
instance = self.factory.get_rag_llm()
|
||||
assert isinstance(instance, OpenAI)
|
||||
mock.assert_called_once_with(
|
||||
api_base=self.mock_config.llm.base_url,
|
||||
api_key=self.mock_config.llm.api_key,
|
||||
api_version=self.mock_config.llm.api_version,
|
||||
model=self.mock_config.llm.model,
|
||||
max_tokens=self.mock_config.llm.max_token,
|
||||
temperature=self.mock_config.llm.temperature,
|
||||
)
|
||||
43
tests/metagpt/rag/factories/test_ranker.py
Normal file
43
tests/metagpt/rag/factories/test_ranker.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
import pytest
|
||||
from llama_index.core.llms import LLM
|
||||
from llama_index.core.postprocessor import LLMRerank
|
||||
|
||||
from metagpt.rag.factories.ranker import RankerFactory
|
||||
from metagpt.rag.schema import LLMRankerConfig
|
||||
|
||||
|
||||
class TestRankerFactory:
|
||||
@pytest.fixture
|
||||
def ranker_factory(self) -> RankerFactory:
|
||||
return RankerFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm(self, mocker):
|
||||
return mocker.MagicMock(spec=LLM)
|
||||
|
||||
def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker):
|
||||
mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm)
|
||||
default_rankers = ranker_factory.get_rankers()
|
||||
assert len(default_rankers) == 1
|
||||
assert isinstance(default_rankers[0], LLMRerank)
|
||||
ranker_factory._extract_llm.assert_called_once()
|
||||
|
||||
def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
rankers = ranker_factory.get_rankers(configs=[mock_config])
|
||||
assert len(rankers) == 1
|
||||
assert isinstance(rankers[0], LLMRerank)
|
||||
|
||||
def test_create_llm_ranker_creates_correct_instance(self, ranker_factory: RankerFactory, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
ranker = ranker_factory._create_llm_ranker(mock_config)
|
||||
assert isinstance(ranker, LLMRerank)
|
||||
|
||||
def test_extract_llm_from_config(self, ranker_factory: RankerFactory, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
extracted_llm = ranker_factory._extract_llm(config=mock_config)
|
||||
assert extracted_llm == mock_llm
|
||||
|
||||
def test_extract_llm_from_kwargs(self, ranker_factory: RankerFactory, mock_llm):
|
||||
extracted_llm = ranker_factory._extract_llm(llm=mock_llm)
|
||||
assert extracted_llm == mock_llm
|
||||
79
tests/metagpt/rag/factories/test_retriever.py
Normal file
79
tests/metagpt/rag/factories/test_retriever.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
import faiss
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
|
||||
from metagpt.rag.factories.retriever import RetrieverFactory
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
|
||||
|
||||
|
||||
class TestRetrieverFactory:
|
||||
@pytest.fixture
|
||||
def retriever_factory(self):
|
||||
return RetrieverFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_faiss_index(self, mocker):
|
||||
return mocker.MagicMock(spec=faiss.IndexFlatL2)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
mock = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock._embed_model = mocker.MagicMock()
|
||||
mock.docstore.docs.values.return_value = []
|
||||
return mock
|
||||
|
||||
def test_get_retriever_with_faiss_config(
|
||||
self, retriever_factory: RetrieverFactory, mock_faiss_index, mocker, mock_vector_store_index
|
||||
):
|
||||
mock_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, FAISSRetriever)
|
||||
|
||||
def test_get_retriever_with_bm25_config(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
|
||||
mock_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, DynamicBM25Retriever)
|
||||
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(
|
||||
self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index
|
||||
):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mock_bm25_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
|
||||
|
||||
assert isinstance(retriever, SimpleHybridRetriever)
|
||||
|
||||
def test_create_default_retriever(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
mock_vector_store_index.as_retriever = mocker.MagicMock()
|
||||
|
||||
retriever = retriever_factory.get_retriever()
|
||||
|
||||
mock_vector_store_index.as_retriever.assert_called_once()
|
||||
assert retriever is mock_vector_store_index.as_retriever.return_value
|
||||
|
||||
def test_extract_index_from_config(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
|
||||
mock_config = FAISSRetrieverConfig(index=mock_vector_store_index)
|
||||
|
||||
extracted_index = retriever_factory._extract_index(config=mock_config)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
||||
def test_extract_index_from_kwargs(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
|
||||
extracted_index = retriever_factory._extract_index(index=mock_vector_store_index)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
import pytest
|
||||
from llama_index.schema import Node
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
|
||||
|
|
@ -17,7 +17,7 @@ class TestDynamicBM25Retriever:
|
|||
# 模拟nodes和tokenizer参数
|
||||
mock_nodes = []
|
||||
mock_tokenizer = mocker.MagicMock()
|
||||
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi")
|
||||
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
|
||||
# 初始化DynamicBM25Retriever对象,并提供必需的参数
|
||||
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import pytest
|
||||
from llama_index.schema import Node
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from llama_index.schema import NodeWithScore, TextNode
|
||||
from llama_index.core.schema import NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
|
||||
|
|
|
|||
|
|
@ -1,130 +0,0 @@
|
|||
import pytest
|
||||
from llama_index import ServiceContext
|
||||
from llama_index.indices.base import BaseIndex
|
||||
from llama_index.postprocessor import LLMRerank
|
||||
|
||||
from metagpt.rag.factory import RankerFactory, RetrieverFactory
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
LLMRankerConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestRetrieverFactory:
|
||||
@pytest.fixture
|
||||
def mock_base_index(self, mocker):
|
||||
mock = mocker.MagicMock(spec=BaseIndex)
|
||||
mock.as_retriever.return_value = mocker.MagicMock(spec=RAGRetriever)
|
||||
mock.service_context = mocker.MagicMock()
|
||||
mock.docstore.docs.values.return_value = []
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_faiss_retriever_config(self):
|
||||
return FAISSRetrieverConfig(dimensions=128)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bm25_retriever_config(self):
|
||||
return BM25RetrieverConfig()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_faiss_vector_store(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factory.FaissVectorStore")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage_context(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factory.StorageContext")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factory.VectorStoreIndex")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dynamic_bm25_retriever(self, mocker):
|
||||
mock = mocker.MagicMock(spec=DynamicBM25Retriever)
|
||||
return mocker.patch("metagpt.rag.factory.DynamicBM25Retriever", mock)
|
||||
|
||||
def test_get_retriever_with_no_configs_returns_default_retriever(self, mock_base_index):
|
||||
factory = RetrieverFactory()
|
||||
retriever = factory.get_retriever(index=mock_base_index)
|
||||
assert isinstance(retriever, RAGRetriever)
|
||||
|
||||
def test_get_retriever_with_specific_config_returns_correct_retriever(
|
||||
self,
|
||||
mock_base_index,
|
||||
mock_faiss_retriever_config,
|
||||
mock_faiss_vector_store,
|
||||
mock_storage_context,
|
||||
mock_vector_store_index,
|
||||
):
|
||||
factory = RetrieverFactory()
|
||||
retriever = factory.get_retriever(index=mock_base_index, configs=[mock_faiss_retriever_config])
|
||||
assert isinstance(retriever, FAISSRetriever)
|
||||
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid_retriever(
|
||||
self,
|
||||
mock_base_index,
|
||||
mock_faiss_retriever_config,
|
||||
mock_bm25_retriever_config,
|
||||
mock_faiss_vector_store,
|
||||
mock_storage_context,
|
||||
mock_vector_store_index,
|
||||
mock_dynamic_bm25_retriever,
|
||||
):
|
||||
factory = RetrieverFactory()
|
||||
retriever = factory.get_retriever(
|
||||
index=mock_base_index, configs=[mock_faiss_retriever_config, mock_bm25_retriever_config]
|
||||
)
|
||||
assert isinstance(retriever, SimpleHybridRetriever)
|
||||
|
||||
def test_get_retriever_with_unknown_config_raises_value_error(self, mock_base_index, mocker):
|
||||
mock_unknown_config = mocker.MagicMock()
|
||||
factory = RetrieverFactory()
|
||||
with pytest.raises(ValueError):
|
||||
factory.get_retriever(index=mock_base_index, configs=[mock_unknown_config])
|
||||
|
||||
|
||||
class TestRankerFactory:
|
||||
@pytest.fixture
|
||||
def mock_service_context(self, mocker):
|
||||
return mocker.MagicMock(spec=ServiceContext)
|
||||
|
||||
def test_get_rankers_with_no_configs_returns_default_ranker(self, mock_service_context):
|
||||
# Setup
|
||||
factory = RankerFactory()
|
||||
|
||||
# Execute
|
||||
rankers = factory.get_rankers(service_context=mock_service_context)
|
||||
|
||||
# Assertions
|
||||
assert len(rankers) == 1
|
||||
assert isinstance(rankers[0], LLMRerank)
|
||||
|
||||
def test_get_rankers_with_specific_config_returns_correct_ranker(self, mock_service_context):
|
||||
# Setup
|
||||
config = LLMRankerConfig(top_n=3)
|
||||
factory = RankerFactory()
|
||||
|
||||
# Execute
|
||||
rankers = factory.get_rankers(configs=[config], service_context=mock_service_context)
|
||||
|
||||
# Assertions
|
||||
assert len(rankers) == 1
|
||||
assert isinstance(rankers[0], LLMRerank)
|
||||
assert rankers[0].top_n == 3
|
||||
|
||||
def test_get_rankers_with_unknown_config_raises_value_error(self, mocker, mock_service_context):
|
||||
# Mock
|
||||
mock_config = mocker.MagicMock() # 使用 MagicMock 来模拟一个未知的配置类型
|
||||
|
||||
# Setup
|
||||
factory = RankerFactory()
|
||||
|
||||
# Execute & Assertions
|
||||
with pytest.raises(ValueError):
|
||||
factory.get_rankers(configs=[mock_config], service_context=mock_service_context)
|
||||
Loading…
Add table
Add a link
Reference in a new issue