diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 675fe62f1..70c592a1e 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -16,6 +16,8 @@ QUESTION = "What are key qualities to be a good writer?" class RAGExample: + """Show how to use RAG.""" + def __init__(self): self.engine = SimpleEngine.from_docs( input_files=[DOC_PATH], @@ -84,14 +86,17 @@ class RAGExample: {'name': 'foo', 'goal': 'Win The Game', 'tool': 'Red Bull Energy Drink'} """ - self._print_title("RAG Add Docs") + self._print_title("RAG Add Objs") class Player(BaseModel): + """Player""" + name: str = "" goal: str = "Win The Game" tool: str = "Red Bull Energy Drink" def rag_key(self) -> str: + """For search""" return self.goal foo = Player(name="foo") diff --git a/metagpt/document.py b/metagpt/document.py index be238621c..4a8bb68d5 100644 --- a/metagpt/document.py +++ b/metagpt/document.py @@ -11,8 +11,9 @@ from pathlib import Path from typing import Optional, Union import pandas as pd -from llama_index.node_parser import SimpleNodeParser -from llama_index.readers import Document, PDFReader, SimpleDirectoryReader +from llama_index.core import Document, SimpleDirectoryReader +from llama_index.core.node_parser import SimpleNodeParser +from llama_index.readers.file import PDFReader from pydantic import BaseModel, ConfigDict, Field from tqdm import tqdm diff --git a/metagpt/document_store/faiss_store.py b/metagpt/document_store/faiss_store.py index 2136e49db..f8ce05072 100644 --- a/metagpt/document_store/faiss_store.py +++ b/metagpt/document_store/faiss_store.py @@ -10,11 +10,11 @@ from pathlib import Path from typing import Any, Optional import faiss -from llama_index import VectorStoreIndex, load_index_from_storage -from llama_index.embeddings import BaseEmbedding -from llama_index.schema import Document, QueryBundle, TextNode -from llama_index.storage import StorageContext -from llama_index.vector_stores import FaissVectorStore +from llama_index.core import VectorStoreIndex, load_index_from_storage +from llama_index.core.embeddings import BaseEmbedding +from llama_index.core.schema import Document, QueryBundle, TextNode +from llama_index.core.storage import StorageContext +from llama_index.vector_stores.faiss import FaissVectorStore from metagpt.document import IndexableDocument from metagpt.document_store.base_store import LocalStore diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index d48fc8a1a..ca09f1059 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -3,22 +3,32 @@ from typing import Optional -from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex -from llama_index.callbacks.base import CallbackManager -from llama_index.core.base_retriever import BaseRetriever -from llama_index.embeddings.base import BaseEmbedding -from llama_index.indices.base import BaseIndex -from llama_index.llms.llm import LLM -from llama_index.postprocessor.types import BaseNodePostprocessor -from llama_index.query_engine import RetrieverQueryEngine -from llama_index.response_synthesizers import BaseSynthesizer -from llama_index.schema import NodeWithScore, QueryBundle, QueryType, TextNode +from llama_index.core import SimpleDirectoryReader, VectorStoreIndex +from llama_index.core.callbacks.base import CallbackManager +from llama_index.core.embeddings import BaseEmbedding +from llama_index.core.indices.base import BaseIndex +from llama_index.core.ingestion.pipeline import run_transformations +from llama_index.core.llms import LLM +from llama_index.core.node_parser import SentenceSplitter +from llama_index.core.postprocessor.types import BaseNodePostprocessor +from llama_index.core.query_engine import RetrieverQueryEngine +from llama_index.core.response_synthesizers import ( + BaseSynthesizer, + get_response_synthesizer, +) +from llama_index.core.retrievers import BaseRetriever +from llama_index.core.schema import ( + NodeWithScore, + QueryBundle, + QueryType, + TextNode, + TransformComponent, +) -from metagpt.rag.factory import get_rankers, get_retriever +from metagpt.rag.factories import get_rag_llm, get_rankers, get_retriever from metagpt.rag.interface import RAGObject -from metagpt.rag.llm import get_default_llm from metagpt.rag.retrievers.base import ModifiableRAGRetriever -from metagpt.rag.schema import RankerConfigType, RetrieverConfigType +from metagpt.rag.schema import BaseRankerConfig, BaseRetrieverConfig from metagpt.utils.embedding import get_embedding @@ -51,45 +61,47 @@ class SimpleEngine(RetrieverQueryEngine): cls, input_dir: str = None, input_files: list[str] = None, - llm: LLM = None, + transformations: Optional[list[TransformComponent]] = None, embed_model: BaseEmbedding = None, - chunk_size: int = None, - chunk_overlap: int = None, - retriever_configs: list[RetrieverConfigType] = None, - ranker_configs: list[RankerConfigType] = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": """This engine is designed to be simple and straightforward Args: input_dir: Path to the directory. input_files: List of file paths to read (Optional; overrides input_dir, exclude). - llm: Must supported by llama index. - embed_model: Must supported by llama index. - chunk_size: The size of text chunks (in tokens) to split documents into for embedding. - chunk_overlap: The number of tokens for overlapping between consecutive chunks. Helps in maintaining context continuity. + transformations: Parse documents to nodes. Default [SentenceSplitter]. + embed_model: Parse nodes to embedding. Must supported by llama index. Default OpenAIEmbedding. + llm: Must supported by llama index. Default OpenAI. retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever. ranker_configs: Configuration for rankers. """ documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data() - service_context = ServiceContext.from_defaults( - llm=llm or get_default_llm(), + index = VectorStoreIndex.from_documents( + documents=documents, + transformations=transformations or [SentenceSplitter()], embed_model=embed_model or get_embedding(), - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, ) - index = VectorStoreIndex.from_documents(documents, service_context=service_context) - retriever = get_retriever(index, configs=retriever_configs) - rankers = get_rankers(configs=ranker_configs, service_context=service_context) + llm = llm or get_rag_llm() + retriever = get_retriever(configs=retriever_configs, index=index) + rankers = get_rankers(configs=ranker_configs, llm=llm) - return cls(retriever=retriever, node_postprocessors=rankers, index=index) + return cls( + retriever=retriever, + node_postprocessors=rankers, + response_synthesizer=get_response_synthesizer(llm=llm), + index=index, + ) async def asearch(self, content: str, **kwargs) -> str: """Inplement tools.SearchInterface""" return await self.aquery(content) - async def aretrieve(self, query_bundle: QueryType) -> list[NodeWithScore]: + async def aretrieve(self, query: QueryType) -> list[NodeWithScore]: """Allow query to be str""" - query_bundle = QueryBundle(query_bundle) if isinstance(query_bundle, str) else query_bundle + query_bundle = QueryBundle(query) if isinstance(query, str) else query return await super().aretrieve(query_bundle) def add_docs(self, input_files: list[str]): @@ -97,7 +109,7 @@ class SimpleEngine(RetrieverQueryEngine): self._ensure_retriever_modifiable() documents = SimpleDirectoryReader(input_files=input_files).load_data() - nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents) + nodes = run_transformations(documents, transformations=self.index._transformations) self.retriever.add_nodes(nodes) def add_objs(self, objs: list[RAGObject]): diff --git a/metagpt/rag/factories/__init__.py b/metagpt/rag/factories/__init__.py new file mode 100644 index 000000000..74290fd69 --- /dev/null +++ b/metagpt/rag/factories/__init__.py @@ -0,0 +1,6 @@ +"""RAG factories""" +from metagpt.rag.factories.retriever import get_retriever +from metagpt.rag.factories.ranker import get_rankers +from metagpt.rag.factories.llm import get_rag_llm + +__all__ = ["get_retriever", "get_rankers", "get_rag_llm"] diff --git a/metagpt/rag/factories/base.py b/metagpt/rag/factories/base.py new file mode 100644 index 000000000..5d27eb273 --- /dev/null +++ b/metagpt/rag/factories/base.py @@ -0,0 +1,58 @@ +"""Base Factory.""" +from typing import Any, Callable + + +class GenericFactory: + """Designed to get objects based on any keys.""" + + def __init__(self, creators: dict[Any, Callable] = None): + """Creators is a dictionary. + + Keys are identifiers, and the values are the associated creator function, which create objects. + """ + self._creators = creators or {} + + def get_instances(self, keys: list[Any], **kwargs) -> list[Any]: + """Get instances by keys.""" + return [self.get_instance(key, **kwargs) for key in keys] + + def get_instance(self, key: Any, **kwargs) -> Any: + """Get instance by key. + + Raise Exception if key not found. + """ + creator = self._creators.get(key) + if creator: + return creator(**kwargs) + + raise ValueError(f"Creator not registered for key: {key}") + + +class ConfigFactory(GenericFactory): + """Designed to get objects based on object type.""" + + def get_instance(self, key: Any, **kwargs) -> Any: + """Key is config, such as a pydantic model. + + Call func by the type of key, and the key will be passed to func. + """ + creator = self._creators.get(type(key)) + if creator: + return creator(key, **kwargs) + + raise ValueError(f"Unknown config: {key}") + + @staticmethod + def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any: + """It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.""" + if config is not None and hasattr(config, key): + val = getattr(config, key) + if val is not None: + return val + + if key in kwargs: + return kwargs[key] + + raise KeyError( + f"The key '{key}' is required but not provided in either configuration object or keyword arguments." + ) diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py new file mode 100644 index 000000000..b551532d4 --- /dev/null +++ b/metagpt/rag/factories/llm.py @@ -0,0 +1,76 @@ +"""RAG LLM Factory. + +The LLM of LlamaIndex and the LLM of MG are not the same. +""" +from llama_index.core.llms import LLM +from llama_index.llms.anthropic import Anthropic +from llama_index.llms.azure_openai import AzureOpenAI +from llama_index.llms.gemini import Gemini +from llama_index.llms.ollama import Ollama +from llama_index.llms.openai import OpenAI + +from metagpt.config2 import config +from metagpt.configs.llm_config import LLMType +from metagpt.rag.factories.base import GenericFactory + + +class RAGLLMFactory(GenericFactory): + """Create LlamaIndex LLM with MG config.""" + + def __init__(self): + creators = { + LLMType.OPENAI: self._create_openai, + LLMType.AZURE: self._create_azure, + LLMType.ANTHROPIC: self._create_anthropic, + LLMType.GEMINI: self._create_gemini, + LLMType.OLLAMA: self._create_ollama, + } + super().__init__(creators) + + def get_rag_llm(self, key: LLMType = None) -> LLM: + """Key is LLMType, default use config.llm.api_type.""" + return super().get_instance(key or config.llm.api_type) + + def _create_openai(self): + return OpenAI( + api_base=config.llm.base_url, + api_key=config.llm.api_key, + api_version=config.llm.api_version, + model=config.llm.model, + max_tokens=config.llm.max_token, + temperature=config.llm.temperature, + ) + + def _create_azure(self): + return AzureOpenAI( + azure_endpoint=config.llm.base_url, + api_key=config.llm.api_key, + api_version=config.llm.api_version, + model=config.llm.model, + max_tokens=config.llm.max_token, + temperature=config.llm.temperature, + ) + + def _create_anthropic(self): + return Anthropic( + base_url=config.llm.base_url, + api_key=config.llm.api_key, + model=config.llm.model, + max_tokens=config.llm.max_token, + temperature=config.llm.temperature, + ) + + def _create_gemini(self): + return Gemini( + api_base=config.llm.base_url, + api_key=config.llm.api_key, + model_name=config.llm.model, + max_tokens=config.llm.max_token, + temperature=config.llm.temperature, + ) + + def _create_ollama(self): + return Ollama(base_url=config.llm.base_url, model=config.llm.model, temperature=config.llm.temperature) + + +get_rag_llm = RAGLLMFactory().get_rag_llm diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py new file mode 100644 index 000000000..f74e30834 --- /dev/null +++ b/metagpt/rag/factories/ranker.py @@ -0,0 +1,39 @@ +"""RAG Ranker Factory.""" + +from llama_index.core.llms import LLM +from llama_index.core.postprocessor import LLMRerank +from llama_index.core.postprocessor.types import BaseNodePostprocessor + +from metagpt.rag.factories.base import ConfigFactory +from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig + + +class RankerFactory(ConfigFactory): + """Modify creators for dynamically instance implementation.""" + + def __init__(self): + creators = { + LLMRankerConfig: self._create_llm_ranker, + } + super().__init__(creators) + + def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]: + """Creates and returns a retriever instance based on the provided configurations.""" + if not configs: + return self._create_default(**kwargs) + + return super().get_instances(configs, **kwargs) + + def _create_default(self, **kwargs) -> list[LLMRerank]: + config = LLMRankerConfig(llm=self._extract_llm(**kwargs)) + return [LLMRerank(**config.model_dump())] + + def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM: + return self._val_from_config_or_kwargs("llm", config, **kwargs) + + def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank: + config.llm = self._extract_llm(config, **kwargs) + return LLMRerank(**config.model_dump()) + + +get_rankers = RankerFactory().get_rankers diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py new file mode 100644 index 000000000..44678fc92 --- /dev/null +++ b/metagpt/rag/factories/retriever.py @@ -0,0 +1,64 @@ +"""RAG Retriever Factory.""" + +import faiss +from llama_index.core import StorageContext, VectorStoreIndex +from llama_index.vector_stores.faiss import FaissVectorStore + +from metagpt.rag.factories.base import ConfigFactory +from metagpt.rag.retrievers.base import RAGRetriever +from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever +from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever +from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever +from metagpt.rag.schema import ( + BaseRetrieverConfig, + BM25RetrieverConfig, + FAISSRetrieverConfig, +) + + +class RetrieverFactory(ConfigFactory): + """Modify creators for dynamically instance implementation.""" + + def __init__(self): + creators = { + FAISSRetrieverConfig: self._create_faiss_retriever, + BM25RetrieverConfig: self._create_bm25_retriever, + } + super().__init__(creators) + + def get_retriever(self, configs: list[BaseRetrieverConfig] = None, **kwargs) -> RAGRetriever: + """Creates and returns a retriever instance based on the provided configurations. + + If multiple retrievers, using SimpleHybridRetriever. + """ + if not configs: + return self._create_default(**kwargs) + + retrievers = super().get_instances(configs, **kwargs) + + return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0] + + def _create_default(self, **kwargs) -> RAGRetriever: + return self._extract_index(**kwargs).as_retriever() + + def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex: + return self._val_from_config_or_kwargs("index", config, **kwargs) + + def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever: + vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) + storage_context = StorageContext.from_defaults(vector_store=vector_store) + old_index = self._extract_index(config, **kwargs) + new_index = VectorStoreIndex( + nodes=list(old_index.docstore.docs.values()), + storage_context=storage_context, + embed_model=old_index._embed_model, + ) + config.index = new_index + return FAISSRetriever(**config.model_dump()) + + def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever: + config.index = self._extract_index(config, **kwargs) + return DynamicBM25Retriever.from_defaults(**config.model_dump()) + + +get_retriever = RetrieverFactory().get_retriever diff --git a/metagpt/rag/factory.py b/metagpt/rag/factory.py deleted file mode 100644 index 04543f57e..000000000 --- a/metagpt/rag/factory.py +++ /dev/null @@ -1,114 +0,0 @@ -"""Factory for creating retriever, ranker""" -from typing import Any, Callable - -import faiss -from llama_index import ServiceContext, StorageContext, VectorStoreIndex -from llama_index.indices.base import BaseIndex -from llama_index.postprocessor import LLMRerank -from llama_index.postprocessor.types import BaseNodePostprocessor -from llama_index.vector_stores.faiss import FaissVectorStore - -from metagpt.rag.retrievers.base import RAGRetriever -from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever -from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever -from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever -from metagpt.rag.schema import ( - BM25RetrieverConfig, - FAISSRetrieverConfig, - LLMRankerConfig, - RankerConfigType, - RetrieverConfigType, -) - - -class BaseFactory: - """ - A base factory class for creating instances based on provided configurations. - It uses a registry of creator functions mapped to configuration types to instantiate objects dynamically. - """ - - def __init__(self, creators: dict[Any, Callable]): - """Creators is a dictionary mapping configuration types to creator functions.""" - self.creators = creators - - def get_instances(self, configs: list[Any] = None, **kwargs) -> list[Any]: - """Get instances by configs""" - return [self._get_instance(config, **kwargs) for config in configs] - - def _get_instance(self, config: Any, **kwargs) -> Any: - create_func = self.creators.get(type(config)) - if create_func: - return create_func(config, **kwargs) - - raise ValueError(f"Unknown config: {config}") - - -class RetrieverFactory(BaseFactory): - """Modify creators for dynamically instance implementation""" - - def __init__(self): - creators = { - FAISSRetrieverConfig: self._create_faiss_retriever, - BM25RetrieverConfig: self._create_bm25_retriever, - } - super().__init__(creators) - - def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever: - """Creates and returns a retriever instance based on the provided configurations. - If multiple retrievers, using SimpleHybridRetriever - """ - if not configs: - return self._default_instance(index) - - retrievers = super().get_instances(configs, index=index) - - return ( - SimpleHybridRetriever(*retrievers, service_context=index.service_context) - if len(retrievers) > 1 - else retrievers[0] - ) - - def _default_instance(self, index: BaseIndex) -> RAGRetriever: - return index.as_retriever() - - def _create_faiss_retriever(self, config: FAISSRetrieverConfig, index: BaseIndex) -> FAISSRetriever: - vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) - storage_context = StorageContext.from_defaults(vector_store=vector_store) - vector_index = VectorStoreIndex( - nodes=list(index.docstore.docs.values()), - storage_context=storage_context, - service_context=index.service_context, - ) - return FAISSRetriever(**config.model_dump(), index=vector_index) - - def _create_bm25_retriever(self, config: BM25RetrieverConfig, index: BaseIndex) -> DynamicBM25Retriever: - return DynamicBM25Retriever.from_defaults(**config.model_dump(), index=index) - - -class RankerFactory(BaseFactory): - """Modify creators for dynamically instance implementation""" - - def __init__(self): - creators = { - LLMRankerConfig: self._create_llm_ranker, - } - super().__init__(creators) - - def get_rankers( - self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None - ) -> list[BaseNodePostprocessor]: - """Creates and returns a retriever instance based on the provided configurations.""" - if not configs: - return [self._default_instance(service_context)] - - return super().get_instances(configs, service_context=service_context) - - def _default_instance(self, service_context: ServiceContext) -> LLMRerank: - return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context) - - def _create_llm_ranker(self, config: LLMRankerConfig, service_context: ServiceContext = None) -> LLMRerank: - return LLMRerank(**config.model_dump(), service_context=service_context) - - -get_retriever = RetrieverFactory().get_retriever -get_rankers = RankerFactory().get_rankers diff --git a/metagpt/rag/interface.py b/metagpt/rag/interface.py index 97faf9f01..8039e76d5 100644 --- a/metagpt/rag/interface.py +++ b/metagpt/rag/interface.py @@ -1,14 +1,15 @@ -"""RAG Interface.""" +"""RAG Interfaces.""" from typing import Any, Protocol class RAGObject(Protocol): - """Support rag add object""" + """Support rag add object.""" def rag_key(self) -> str: """For rag search.""" def model_dump(self) -> dict[str, Any]: """For rag persist. - Pydantic Model don't need to implement this, as there is a built-in function named model_dump + + Pydantic Model don't need to implement this, as there is a built-in function named model_dump. """ diff --git a/metagpt/rag/llm.py b/metagpt/rag/llm.py deleted file mode 100644 index 83b3a849d..000000000 --- a/metagpt/rag/llm.py +++ /dev/null @@ -1,11 +0,0 @@ -"""RAG LLM -The LLM of LlamaIndex and the LLM of MG are not the same. -""" -from llama_index.llms import OpenAI - -from metagpt.config2 import config - - -def get_default_llm() -> OpenAI: - """OpenAI""" - return OpenAI(api_base=config.llm.base_url, api_key=config.llm.api_key, model=config.llm.model) diff --git a/metagpt/rag/rankers/base.py b/metagpt/rag/rankers/base.py index 482fc4aef..ecb23cf3e 100644 --- a/metagpt/rag/rankers/base.py +++ b/metagpt/rag/rankers/base.py @@ -4,8 +4,8 @@ from abc import abstractmethod from typing import Optional from llama_index import QueryBundle -from llama_index.postprocessor.types import BaseNodePostprocessor -from llama_index.schema import NodeWithScore +from llama_index.core.postprocessor.types import BaseNodePostprocessor +from llama_index.core.schema import NodeWithScore class RAGRanker(BaseNodePostprocessor): diff --git a/metagpt/rag/retrievers/base.py b/metagpt/rag/retrievers/base.py index f89a078ca..87d678809 100644 --- a/metagpt/rag/retrievers/base.py +++ b/metagpt/rag/retrievers/base.py @@ -3,8 +3,8 @@ from abc import abstractmethod -from llama_index.retrievers import BaseRetriever -from llama_index.schema import BaseNode, NodeWithScore, QueryType +from llama_index.core.retrievers import BaseRetriever +from llama_index.core.schema import BaseNode, NodeWithScore, QueryType from metagpt.utils.reflection import check_methods diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py index dc8d59802..c451e98fd 100644 --- a/metagpt/rag/retrievers/bm25_retriever.py +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -1,6 +1,7 @@ """BM25 retriever.""" -from llama_index.retrievers import BM25Retriever -from llama_index.schema import BaseNode +from llama_index.core.schema import BaseNode +from llama_index.retrievers.bm25 import BM25Retriever +from rank_bm25 import BM25Okapi class DynamicBM25Retriever(BM25Retriever): @@ -8,11 +9,6 @@ class DynamicBM25Retriever(BM25Retriever): def add_nodes(self, nodes: list[BaseNode], **kwargs): """Support add nodes""" - try: - from rank_bm25 import BM25Okapi - except ImportError: - raise ImportError("Please install rank_bm25: pip install rank-bm25") - self._nodes.extend(nodes) self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes] self.bm25 = BM25Okapi(self._corpus) diff --git a/metagpt/rag/retrievers/faiss_retriever.py b/metagpt/rag/retrievers/faiss_retriever.py index a898d0292..8c1bc8f8a 100644 --- a/metagpt/rag/retrievers/faiss_retriever.py +++ b/metagpt/rag/retrievers/faiss_retriever.py @@ -1,6 +1,6 @@ """FAISS retriever.""" -from llama_index.retrievers import VectorIndexRetriever -from llama_index.schema import BaseNode +from llama_index.core.retrievers import VectorIndexRetriever +from llama_index.core.schema import BaseNode class FAISSRetriever(VectorIndexRetriever): diff --git a/metagpt/rag/retrievers/hybrid_retriever.py b/metagpt/rag/retrievers/hybrid_retriever.py index d514194c9..3074a4053 100644 --- a/metagpt/rag/retrievers/hybrid_retriever.py +++ b/metagpt/rag/retrievers/hybrid_retriever.py @@ -1,23 +1,20 @@ """Hybrid retriever.""" -from llama_index import ServiceContext -from llama_index.schema import BaseNode, QueryType +import copy + +from llama_index.core.schema import BaseNode, QueryType from metagpt.rag.retrievers.base import RAGRetriever class SimpleHybridRetriever(RAGRetriever): - """ - SimpleHybridRetriever is a composite retriever that aggregates search results from multiple retrievers. - """ + """A composite retriever that aggregates search results from multiple retrievers.""" - def __init__(self, *retrievers, service_context: ServiceContext = None): + def __init__(self, *retrievers): self.retrievers: list[RAGRetriever] = retrievers - self.service_context = service_context super().__init__() async def _aretrieve(self, query: QueryType, **kwargs): - """ - Asynchronously retrieves and aggregates search results from all configured retrievers. + """Asynchronously retrieves and aggregates search results from all configured retrievers. This method queries each retriever in the `retrievers` list with the given query and additional keyword arguments. It then combines the results, ensuring that each node is @@ -25,7 +22,9 @@ class SimpleHybridRetriever(RAGRetriever): """ all_nodes = [] for retriever in self.retrievers: - nodes = await retriever.aretrieve(query, **kwargs) + # 防止retriever可能改变query的属性 + query_copy = copy.deepcopy(query) + nodes = await retriever.aretrieve(query_copy, **kwargs) all_nodes.extend(nodes) # combine all nodes diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 1e3d945f2..c74846cb6 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -1,36 +1,52 @@ -"""RAG schemas""" +"""RAG schemas.""" -from typing import Union +from typing import Any -from pydantic import BaseModel, Field +from llama_index.core.indices.base import BaseIndex +from pydantic import BaseModel, ConfigDict, Field -class RetrieverConfig(BaseModel): - """Common config for retrievers.""" +class BaseRetrieverConfig(BaseModel): + """Common config for retrievers. + If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factory. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) similarity_top_k: int = Field(default=5, description="Number of top-k similar results to return during retrieval.") -class FAISSRetrieverConfig(RetrieverConfig): +class IndexRetrieverConfig(BaseRetrieverConfig): + """Config for Index-basd retrievers.""" + + index: BaseIndex = Field(default=None, description="Index for retriver") + + +class FAISSRetrieverConfig(IndexRetrieverConfig): """Config for FAISS-based retrievers.""" dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.") -class BM25RetrieverConfig(RetrieverConfig): +class BM25RetrieverConfig(IndexRetrieverConfig): """Config for BM25-based retrievers.""" -class RankerConfig(BaseModel): - """Common config for rankers.""" +class BaseRankerConfig(BaseModel): + """Common config for rankers. - top_n: int = 5 + If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factory. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + top_n: int = Field(default=5, description="The number of top results to return.") -class LLMRankerConfig(RankerConfig): +class LLMRankerConfig(BaseRankerConfig): """Config for LLM-based rankers.""" - -# If add new config, it is necessary to add the corresponding instance implementation in rag.factory -RetrieverConfigType = Union[FAISSRetrieverConfig, BM25RetrieverConfig] -RankerConfigType = LLMRankerConfig + llm: Any = Field( + default=None, + description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1", + ) diff --git a/metagpt/utils/embedding.py b/metagpt/utils/embedding.py index 3b5465f99..3d53a314c 100644 --- a/metagpt/utils/embedding.py +++ b/metagpt/utils/embedding.py @@ -5,12 +5,15 @@ @Author : alexanderwu @File : embedding.py """ -from llama_index.embeddings import OpenAIEmbedding +from llama_index.embeddings.openai import OpenAIEmbedding from metagpt.config2 import config def get_embedding() -> OpenAIEmbedding: llm = config.get_openai_llm() + if llm is None: + raise ValueError("To use OpenAIEmbedding, please ensure that config.llm.api_type is correctly set to 'openai'.") + embedding = OpenAIEmbedding(api_key=llm.api_key, api_base=llm.base_url) return embedding diff --git a/requirements.txt b/requirements.txt index d6651bba2..54583129c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,16 @@ typer==0.9.0 # godot==0.1.1 # google_api_python_client==2.93.0 # Used by search_engine.py lancedb==0.4.0 -llama-index==0.9.44 +llama-index-core==0.10.11.post1 +llama-index-embeddings-openai==0.1.5 +llama-index-llms-anthropic==0.1.3 +llama-index-llms-azure-openai==0.1.4 +llama-index-llms-gemini==0.1.4 +llama-index-llms-ollama==0.1.2 +llama-index-llms-openai==0.1.5 +llama-index-readers-file==0.1.4 +llama-index-retrievers-bm25==0.1.3 +llama-index-vector-stores-faiss==0.1.1 loguru==0.6.0 meilisearch==0.21.0 numpy==1.24.3 diff --git a/tests/metagpt/rag/engine/test_simple.py b/tests/metagpt/rag/engines/test_simple.py similarity index 52% rename from tests/metagpt/rag/engine/test_simple.py rename to tests/metagpt/rag/engines/test_simple.py index ceec4d63a..1d1ddad12 100644 --- a/tests/metagpt/rag/engine/test_simple.py +++ b/tests/metagpt/rag/engines/test_simple.py @@ -1,58 +1,75 @@ import pytest -from llama_index import VectorStoreIndex +from llama_index.core import VectorStoreIndex +from llama_index.core.schema import TextNode from metagpt.rag.engines import SimpleEngine from metagpt.rag.retrievers.base import ModifiableRAGRetriever class TestSimpleEngine: - def test_from_docs(self, mocker): + @pytest.fixture + def mock_simple_directory_reader(self, mocker): + return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader") + + @pytest.fixture + def mock_vector_store_index(self, mocker): + return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents") + + @pytest.fixture + def mock_get_retriever(self, mocker): + return mocker.patch("metagpt.rag.engines.simple.get_retriever") + + @pytest.fixture + def mock_get_rankers(self, mocker): + return mocker.patch("metagpt.rag.engines.simple.get_rankers") + + @pytest.fixture + def mock_get_response_synthesizer(self, mocker): + return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer") + + def test_from_docs( + self, + mocker, + mock_simple_directory_reader, + mock_vector_store_index, + mock_get_retriever, + mock_get_rankers, + mock_get_response_synthesizer, + ): # Mock - mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader") mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"] - - mock_service_context = mocker.patch("metagpt.rag.engines.simple.ServiceContext.from_defaults") - mock_service_context.return_value = "service_context" - - mock_vector_store_index = mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents") - mock_get_retriever = mocker.patch("metagpt.rag.engines.simple.get_retriever") - mock_get_rankers = mocker.patch("metagpt.rag.engines.simple.get_rankers") + mock_get_retriever.return_value = mocker.MagicMock() + mock_get_rankers.return_value = [mocker.MagicMock()] + mock_get_response_synthesizer.return_value = mocker.MagicMock() # Setup input_dir = "test_dir" input_files = ["test_file1", "test_file2"] + transformations = [mocker.MagicMock()] embed_model = mocker.MagicMock() llm = mocker.MagicMock() - chunk_size = 100 - chunk_overlap = 10 - retriever_configs = mocker.MagicMock() - ranker_configs = mocker.MagicMock() + retriever_configs = [mocker.MagicMock()] + ranker_configs = [mocker.MagicMock()] # Execute engine = SimpleEngine.from_docs( input_dir=input_dir, input_files=input_files, + transformations=transformations, embed_model=embed_model, llm=llm, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, retriever_configs=retriever_configs, ranker_configs=ranker_configs, ) # Assertions mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files) - mock_service_context.assert_called_once_with( - embed_model=embed_model, chunk_size=chunk_size, chunk_overlap=chunk_overlap, llm=llm + mock_vector_store_index.assert_called_once() + mock_get_retriever.assert_called_once_with( + configs=retriever_configs, index=mock_vector_store_index.return_value ) - mock_vector_store_index.assert_called_once_with( - ["document1", "document2"], service_context=mock_service_context.return_value - ) - mock_get_retriever.assert_called_once_with(mock_vector_store_index.return_value, configs=retriever_configs) - mock_get_rankers.assert_called_once_with( - configs=ranker_configs, service_context=mock_service_context.return_value - ) - + mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm) + mock_get_response_synthesizer.assert_called_once_with(llm=llm) assert isinstance(engine, SimpleEngine) @pytest.mark.asyncio @@ -100,8 +117,12 @@ class TestSimpleEngine: mock_simple_directory_reader.return_value.load_data.return_value = ["document1", "document2"] mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever) + mock_index = mocker.MagicMock(spec=VectorStoreIndex) - mock_index.service_context.node_parser.get_nodes_from_documents = lambda x: ["node1", "node2"] + mock_index._transformations = mocker.MagicMock() + + mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations") + mock_run_transformations.return_value = ["node1", "node2"] # Setup engine = SimpleEngine(retriever=mock_retriever, index=mock_index) @@ -113,3 +134,27 @@ class TestSimpleEngine: # Assertions mock_simple_directory_reader.assert_called_once_with(input_files=input_files) mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"]) + + def test_add_objs(self, mocker): + # Mock + mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever) + + # Setup + class CustomTextNode(TextNode): + def rag_key(self): + return "" + + def model_dump(self): + return {} + + objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)] + engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock()) + + # Execute + engine.add_objs(objs=objs) + + # Assertions + assert mock_retriever.add_nodes.call_count == 1 + for node in mock_retriever.add_nodes.call_args[0][0]: + assert isinstance(node, TextNode) + assert "obj" in node.metadata diff --git a/tests/metagpt/rag/factories/test_base.py b/tests/metagpt/rag/factories/test_base.py new file mode 100644 index 000000000..78e969ff4 --- /dev/null +++ b/tests/metagpt/rag/factories/test_base.py @@ -0,0 +1,102 @@ +import pytest + +from metagpt.rag.factories.base import ConfigFactory, GenericFactory + + +class TestGenericFactory: + @pytest.fixture + def creators(self): + return { + "type1": lambda name: f"Instance of type1 with {name}", + "type2": lambda name: f"Instance of type2 with {name}", + } + + @pytest.fixture + def factory(self, creators): + return GenericFactory(creators=creators) + + def test_get_instance_success(self, factory): + # Test successful retrieval of an instance + key = "type1" + instance = factory.get_instance(key, name="TestName") + assert instance == "Instance of type1 with TestName" + + def test_get_instance_failure(self, factory): + # Test failure to retrieve an instance due to unregistered key + with pytest.raises(ValueError) as exc_info: + factory.get_instance("unknown_key") + assert "Creator not registered for key: unknown_key" in str(exc_info.value) + + def test_get_instances_success(self, factory): + # Test successful retrieval of multiple instances + keys = ["type1", "type2"] + instances = factory.get_instances(keys, name="TestName") + expected = ["Instance of type1 with TestName", "Instance of type2 with TestName"] + assert instances == expected + + @pytest.mark.parametrize( + "keys,expected_exception_message", + [ + (["unknown_key"], "Creator not registered for key: unknown_key"), + (["type1", "unknown_key"], "Creator not registered for key: unknown_key"), + ], + ) + def test_get_instances_with_failure(self, factory, keys, expected_exception_message): + # Test failure to retrieve instances due to at least one unregistered key + with pytest.raises(ValueError) as exc_info: + factory.get_instances(keys, name="TestName") + assert expected_exception_message in str(exc_info.value) + + +class DummyConfig: + """A dummy config class for testing.""" + + def __init__(self, name): + self.name = name + + +class TestConfigFactory: + @pytest.fixture + def config_creators(self): + return { + DummyConfig: lambda config, **kwargs: f"Processed {config.name} with {kwargs.get('extra', 'no extra')}", + } + + @pytest.fixture + def config_factory(self, config_creators): + return ConfigFactory(creators=config_creators) + + def test_get_instance_success(self, config_factory): + # Test successful retrieval of an instance + config = DummyConfig(name="TestConfig") + instance = config_factory.get_instance(config, extra="additional data") + assert instance == "Processed TestConfig with additional data" + + def test_get_instance_failure(self, config_factory): + # Test failure to retrieve an instance due to unknown config type + class UnknownConfig: + pass + + config = UnknownConfig() + with pytest.raises(ValueError) as exc_info: + config_factory.get_instance(config) + assert "Unknown config:" in str(exc_info.value) + + def test_val_from_config_or_kwargs_priority(self): + # Test that the value from the config object has priority over kwargs + config = DummyConfig(name="ConfigName") + result = ConfigFactory._val_from_config_or_kwargs("name", config, name="KwargsName") + assert result == "ConfigName" + + def test_val_from_config_or_kwargs_fallback_to_kwargs(self): + # Test fallback to kwargs when config object does not have the value + config = DummyConfig(name=None) + result = ConfigFactory._val_from_config_or_kwargs("name", config, name="KwargsName") + assert result == "KwargsName" + + def test_val_from_config_or_kwargs_key_error(self): + # Test KeyError when the key is not found in both config object and kwargs + config = DummyConfig(name=None) + with pytest.raises(KeyError) as exc_info: + ConfigFactory._val_from_config_or_kwargs("missing_key", config) + assert "The key 'missing_key' is required but not provided" in str(exc_info.value) diff --git a/tests/metagpt/rag/factories/test_llm.py b/tests/metagpt/rag/factories/test_llm.py new file mode 100644 index 000000000..21f5ee823 --- /dev/null +++ b/tests/metagpt/rag/factories/test_llm.py @@ -0,0 +1,56 @@ +import pytest +from llama_index.llms.anthropic import Anthropic +from llama_index.llms.azure_openai import AzureOpenAI +from llama_index.llms.gemini import Gemini +from llama_index.llms.ollama import Ollama +from llama_index.llms.openai import OpenAI + +from metagpt.configs.llm_config import LLMType +from metagpt.rag.factories.llm import RAGLLMFactory + + +class TestRAGLLMFactory: + @pytest.fixture(autouse=True) + def setup(self, mocker): + # Mock the config object for all tests in this class + self.mock_config = mocker.MagicMock() + self.mock_config.llm.api_type = LLMType.OPENAI + self.mock_config.llm.base_url = "http://example.com" + self.mock_config.llm.api_key = "test_api_key" + self.mock_config.llm.api_version = "v1" + self.mock_config.llm.model = "test_model" + self.mock_config.llm.max_token = 100 + self.mock_config.llm.temperature = 0.5 + mocker.patch("metagpt.rag.factories.llm.config", self.mock_config) + self.factory = RAGLLMFactory() + + @pytest.mark.parametrize( + "llm_type,expected_class", + [ + (LLMType.OPENAI, OpenAI), + (LLMType.AZURE, AzureOpenAI), + (LLMType.ANTHROPIC, Anthropic), + (LLMType.GEMINI, Gemini), + (LLMType.OLLAMA, Ollama), + ], + ) + def test_creates_correct_llm_instance(self, llm_type, expected_class, mocker): + # Mock the LLM constructors + mocker.patch.object(expected_class, "__init__", return_value=None) + instance = self.factory.get_rag_llm(key=llm_type) + assert isinstance(instance, expected_class) + expected_class.__init__.assert_called_once() + + def test_uses_default_llm_type_when_no_key_provided(self, mocker): + # Assume the default API type is OPENAI for this test + mock = mocker.patch.object(OpenAI, "__init__", return_value=None) + instance = self.factory.get_rag_llm() + assert isinstance(instance, OpenAI) + mock.assert_called_once_with( + api_base=self.mock_config.llm.base_url, + api_key=self.mock_config.llm.api_key, + api_version=self.mock_config.llm.api_version, + model=self.mock_config.llm.model, + max_tokens=self.mock_config.llm.max_token, + temperature=self.mock_config.llm.temperature, + ) diff --git a/tests/metagpt/rag/factories/test_ranker.py b/tests/metagpt/rag/factories/test_ranker.py new file mode 100644 index 000000000..d4b4167a6 --- /dev/null +++ b/tests/metagpt/rag/factories/test_ranker.py @@ -0,0 +1,43 @@ +import pytest +from llama_index.core.llms import LLM +from llama_index.core.postprocessor import LLMRerank + +from metagpt.rag.factories.ranker import RankerFactory +from metagpt.rag.schema import LLMRankerConfig + + +class TestRankerFactory: + @pytest.fixture + def ranker_factory(self) -> RankerFactory: + return RankerFactory() + + @pytest.fixture + def mock_llm(self, mocker): + return mocker.MagicMock(spec=LLM) + + def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker): + mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm) + default_rankers = ranker_factory.get_rankers() + assert len(default_rankers) == 1 + assert isinstance(default_rankers[0], LLMRerank) + ranker_factory._extract_llm.assert_called_once() + + def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm): + mock_config = LLMRankerConfig(llm=mock_llm) + rankers = ranker_factory.get_rankers(configs=[mock_config]) + assert len(rankers) == 1 + assert isinstance(rankers[0], LLMRerank) + + def test_create_llm_ranker_creates_correct_instance(self, ranker_factory: RankerFactory, mock_llm): + mock_config = LLMRankerConfig(llm=mock_llm) + ranker = ranker_factory._create_llm_ranker(mock_config) + assert isinstance(ranker, LLMRerank) + + def test_extract_llm_from_config(self, ranker_factory: RankerFactory, mock_llm): + mock_config = LLMRankerConfig(llm=mock_llm) + extracted_llm = ranker_factory._extract_llm(config=mock_config) + assert extracted_llm == mock_llm + + def test_extract_llm_from_kwargs(self, ranker_factory: RankerFactory, mock_llm): + extracted_llm = ranker_factory._extract_llm(llm=mock_llm) + assert extracted_llm == mock_llm diff --git a/tests/metagpt/rag/factories/test_retriever.py b/tests/metagpt/rag/factories/test_retriever.py new file mode 100644 index 000000000..ac8926d46 --- /dev/null +++ b/tests/metagpt/rag/factories/test_retriever.py @@ -0,0 +1,79 @@ +import faiss +import pytest +from llama_index.core import VectorStoreIndex + +from metagpt.rag.factories.retriever import RetrieverFactory +from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever +from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever +from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever +from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig + + +class TestRetrieverFactory: + @pytest.fixture + def retriever_factory(self): + return RetrieverFactory() + + @pytest.fixture + def mock_faiss_index(self, mocker): + return mocker.MagicMock(spec=faiss.IndexFlatL2) + + @pytest.fixture + def mock_vector_store_index(self, mocker): + mock = mocker.MagicMock(spec=VectorStoreIndex) + mock._embed_model = mocker.MagicMock() + mock.docstore.docs.values.return_value = [] + return mock + + def test_get_retriever_with_faiss_config( + self, retriever_factory: RetrieverFactory, mock_faiss_index, mocker, mock_vector_store_index + ): + mock_config = FAISSRetrieverConfig(dimensions=128) + mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index) + mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index) + + retriever = retriever_factory.get_retriever(configs=[mock_config]) + + assert isinstance(retriever, FAISSRetriever) + + def test_get_retriever_with_bm25_config(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index): + mock_config = BM25RetrieverConfig() + mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) + mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index) + + retriever = retriever_factory.get_retriever(configs=[mock_config]) + + assert isinstance(retriever, DynamicBM25Retriever) + + def test_get_retriever_with_multiple_configs_returns_hybrid( + self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index + ): + mock_faiss_config = FAISSRetrieverConfig(dimensions=128) + mock_bm25_config = BM25RetrieverConfig() + mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) + mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index) + + retriever = retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config]) + + assert isinstance(retriever, SimpleHybridRetriever) + + def test_create_default_retriever(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index): + mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index) + mock_vector_store_index.as_retriever = mocker.MagicMock() + + retriever = retriever_factory.get_retriever() + + mock_vector_store_index.as_retriever.assert_called_once() + assert retriever is mock_vector_store_index.as_retriever.return_value + + def test_extract_index_from_config(self, retriever_factory: RetrieverFactory, mock_vector_store_index): + mock_config = FAISSRetrieverConfig(index=mock_vector_store_index) + + extracted_index = retriever_factory._extract_index(config=mock_config) + + assert extracted_index == mock_vector_store_index + + def test_extract_index_from_kwargs(self, retriever_factory: RetrieverFactory, mock_vector_store_index): + extracted_index = retriever_factory._extract_index(index=mock_vector_store_index) + + assert extracted_index == mock_vector_store_index diff --git a/tests/metagpt/rag/retrievers/test_bm25_retriever.py b/tests/metagpt/rag/retrievers/test_bm25_retriever.py index cc845a35a..77a1db495 100644 --- a/tests/metagpt/rag/retrievers/test_bm25_retriever.py +++ b/tests/metagpt/rag/retrievers/test_bm25_retriever.py @@ -1,5 +1,5 @@ import pytest -from llama_index.schema import Node +from llama_index.core.schema import Node from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever @@ -17,7 +17,7 @@ class TestDynamicBM25Retriever: # 模拟nodes和tokenizer参数 mock_nodes = [] mock_tokenizer = mocker.MagicMock() - self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi") + self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) # 初始化DynamicBM25Retriever对象,并提供必需的参数 self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer) diff --git a/tests/metagpt/rag/retrievers/test_faiss_retriever.py b/tests/metagpt/rag/retrievers/test_faiss_retriever.py index 7d5a5a5a3..9113f110c 100644 --- a/tests/metagpt/rag/retrievers/test_faiss_retriever.py +++ b/tests/metagpt/rag/retrievers/test_faiss_retriever.py @@ -1,5 +1,5 @@ import pytest -from llama_index.schema import Node +from llama_index.core.schema import Node from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever diff --git a/tests/metagpt/rag/retrievers/test_hybrid_retriever.py b/tests/metagpt/rag/retrievers/test_hybrid_retriever.py index 62d976ba2..8cc3087c8 100644 --- a/tests/metagpt/rag/retrievers/test_hybrid_retriever.py +++ b/tests/metagpt/rag/retrievers/test_hybrid_retriever.py @@ -1,7 +1,7 @@ from unittest.mock import AsyncMock import pytest -from llama_index.schema import NodeWithScore, TextNode +from llama_index.core.schema import NodeWithScore, TextNode from metagpt.rag.retrievers import SimpleHybridRetriever diff --git a/tests/metagpt/rag/test_factory.py b/tests/metagpt/rag/test_factory.py deleted file mode 100644 index 70e0809a9..000000000 --- a/tests/metagpt/rag/test_factory.py +++ /dev/null @@ -1,130 +0,0 @@ -import pytest -from llama_index import ServiceContext -from llama_index.indices.base import BaseIndex -from llama_index.postprocessor import LLMRerank - -from metagpt.rag.factory import RankerFactory, RetrieverFactory -from metagpt.rag.retrievers.base import RAGRetriever -from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever -from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever -from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever -from metagpt.rag.schema import ( - BM25RetrieverConfig, - FAISSRetrieverConfig, - LLMRankerConfig, -) - - -class TestRetrieverFactory: - @pytest.fixture - def mock_base_index(self, mocker): - mock = mocker.MagicMock(spec=BaseIndex) - mock.as_retriever.return_value = mocker.MagicMock(spec=RAGRetriever) - mock.service_context = mocker.MagicMock() - mock.docstore.docs.values.return_value = [] - return mock - - @pytest.fixture - def mock_faiss_retriever_config(self): - return FAISSRetrieverConfig(dimensions=128) - - @pytest.fixture - def mock_bm25_retriever_config(self): - return BM25RetrieverConfig() - - @pytest.fixture - def mock_faiss_vector_store(self, mocker): - return mocker.patch("metagpt.rag.factory.FaissVectorStore") - - @pytest.fixture - def mock_storage_context(self, mocker): - return mocker.patch("metagpt.rag.factory.StorageContext") - - @pytest.fixture - def mock_vector_store_index(self, mocker): - return mocker.patch("metagpt.rag.factory.VectorStoreIndex") - - @pytest.fixture - def mock_dynamic_bm25_retriever(self, mocker): - mock = mocker.MagicMock(spec=DynamicBM25Retriever) - return mocker.patch("metagpt.rag.factory.DynamicBM25Retriever", mock) - - def test_get_retriever_with_no_configs_returns_default_retriever(self, mock_base_index): - factory = RetrieverFactory() - retriever = factory.get_retriever(index=mock_base_index) - assert isinstance(retriever, RAGRetriever) - - def test_get_retriever_with_specific_config_returns_correct_retriever( - self, - mock_base_index, - mock_faiss_retriever_config, - mock_faiss_vector_store, - mock_storage_context, - mock_vector_store_index, - ): - factory = RetrieverFactory() - retriever = factory.get_retriever(index=mock_base_index, configs=[mock_faiss_retriever_config]) - assert isinstance(retriever, FAISSRetriever) - - def test_get_retriever_with_multiple_configs_returns_hybrid_retriever( - self, - mock_base_index, - mock_faiss_retriever_config, - mock_bm25_retriever_config, - mock_faiss_vector_store, - mock_storage_context, - mock_vector_store_index, - mock_dynamic_bm25_retriever, - ): - factory = RetrieverFactory() - retriever = factory.get_retriever( - index=mock_base_index, configs=[mock_faiss_retriever_config, mock_bm25_retriever_config] - ) - assert isinstance(retriever, SimpleHybridRetriever) - - def test_get_retriever_with_unknown_config_raises_value_error(self, mock_base_index, mocker): - mock_unknown_config = mocker.MagicMock() - factory = RetrieverFactory() - with pytest.raises(ValueError): - factory.get_retriever(index=mock_base_index, configs=[mock_unknown_config]) - - -class TestRankerFactory: - @pytest.fixture - def mock_service_context(self, mocker): - return mocker.MagicMock(spec=ServiceContext) - - def test_get_rankers_with_no_configs_returns_default_ranker(self, mock_service_context): - # Setup - factory = RankerFactory() - - # Execute - rankers = factory.get_rankers(service_context=mock_service_context) - - # Assertions - assert len(rankers) == 1 - assert isinstance(rankers[0], LLMRerank) - - def test_get_rankers_with_specific_config_returns_correct_ranker(self, mock_service_context): - # Setup - config = LLMRankerConfig(top_n=3) - factory = RankerFactory() - - # Execute - rankers = factory.get_rankers(configs=[config], service_context=mock_service_context) - - # Assertions - assert len(rankers) == 1 - assert isinstance(rankers[0], LLMRerank) - assert rankers[0].top_n == 3 - - def test_get_rankers_with_unknown_config_raises_value_error(self, mocker, mock_service_context): - # Mock - mock_config = mocker.MagicMock() # 使用 MagicMock 来模拟一个未知的配置类型 - - # Setup - factory = RankerFactory() - - # Execute & Assertions - with pytest.raises(ValueError): - factory.get_rankers(configs=[mock_config], service_context=mock_service_context)