Merge branch 'feat-merge-github-rag' into 'mgx_ops'

Merge the newest rag

See merge request pub/MetaGPT!136
This commit is contained in:
林义章 2024-06-03 02:51:18 +00:00
commit 7b293235b8
17 changed files with 482 additions and 113 deletions

View file

@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Literal, Optional
from pydantic import BaseModel, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.configs.embedding_config import EmbeddingConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.configs.mermaid_config import MermaidConfig
from metagpt.configs.redis_config import RedisConfig
@ -48,6 +49,9 @@ class Config(CLIParams, YamlModel):
# Key Parameters
llm: LLMConfig
# RAG Embedding
embedding: EmbeddingConfig = EmbeddingConfig()
# Global Proxy. Not used by LLM, but by other tools such as browsers.
proxy: str = ""

View file

@ -0,0 +1,50 @@
from enum import Enum
from typing import Optional
from pydantic import field_validator
from metagpt.utils.yaml_model import YamlModel
class EmbeddingType(Enum):
    """Embedding back ends that can be named in `EmbeddingConfig.api_type`.

    Each member's value is the lowercase string used in the configuration.
    """

    OPENAI = "openai"
    AZURE = "azure"
    GEMINI = "gemini"
    OLLAMA = "ollama"
class EmbeddingConfig(YamlModel):
    """Settings for the embedding model used by RAG.

    Every field is optional and defaults to ``None``.

    Examples:
    ---------
    api_type: "openai"
    api_key: "YOUR_API_KEY"

    api_type: "azure"
    api_key: "YOUR_API_KEY"
    base_url: "YOUR_BASE_URL"
    api_version: "YOUR_API_VERSION"

    api_type: "gemini"
    api_key: "YOUR_API_KEY"

    api_type: "ollama"
    base_url: "YOUR_BASE_URL"
    model: "YOUR_MODEL"
    """

    api_type: Optional[EmbeddingType] = None
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    api_version: Optional[str] = None

    model: Optional[str] = None
    embed_batch_size: Optional[int] = None

    @field_validator("api_type", mode="before")
    @classmethod
    def check_api_type(cls, v):
        """Turn an empty-string `api_type` into None; pass any other value through unchanged."""
        return None if v == "" else v

View file

@ -4,8 +4,7 @@ import json
import os
from typing import Any, Optional, Union
from fsspec import AbstractFileSystem
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core import SimpleDirectoryReader
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
@ -64,7 +63,7 @@ class SimpleEngine(RetrieverQueryEngine):
response_synthesizer: Optional[BaseSynthesizer] = None,
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
callback_manager: Optional[CallbackManager] = None,
index: Optional[BaseIndex] = None,
transformations: Optional[list[TransformComponent]] = None,
) -> None:
super().__init__(
retriever=retriever,
@ -72,7 +71,7 @@ class SimpleEngine(RetrieverQueryEngine):
node_postprocessors=node_postprocessors,
callback_manager=callback_manager,
)
self.index = index
self._transformations = transformations or self._default_transformations()
@classmethod
def from_docs(
@ -84,7 +83,6 @@ class SimpleEngine(RetrieverQueryEngine):
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
fs: Optional[AbstractFileSystem] = None,
) -> "SimpleEngine":
"""From docs.
@ -102,15 +100,20 @@ class SimpleEngine(RetrieverQueryEngine):
if not input_dir and not input_files:
raise ValueError("Must provide either `input_dir` or `input_files`.")
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files, fs=fs).load_data()
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
cls._fix_document_metadata(documents)
index = VectorStoreIndex.from_documents(
documents=documents,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
transformations = transformations or cls._default_transformations()
nodes = run_transformations(documents, transformations=transformations)
return cls._from_nodes(
nodes=nodes,
transformations=transformations,
embed_model=embed_model,
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@classmethod
def from_objs(
@ -139,12 +142,15 @@ class SimpleEngine(RetrieverQueryEngine):
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
index = VectorStoreIndex(
return cls._from_nodes(
nodes=nodes,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
transformations=transformations,
embed_model=embed_model,
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@classmethod
def from_index(
@ -163,6 +169,13 @@ class SimpleEngine(RetrieverQueryEngine):
"""Inplement tools.SearchInterface"""
return await self.aquery(content)
def retrieve(self, query: QueryType) -> list[NodeWithScore]:
    """Retrieve nodes for `query`, which may be a plain string or a query bundle.

    A string is wrapped in a `QueryBundle` before it is passed to the parent
    class's `retrieve`. The retrieved nodes are handed to
    `_try_reconstruct_obj` and then returned.
    """
    if isinstance(query, str):
        query = QueryBundle(query)
    retrieved = super().retrieve(query)
    self._try_reconstruct_obj(retrieved)
    return retrieved
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Allow query to be str."""
query_bundle = QueryBundle(query) if isinstance(query, str) else query
@ -178,7 +191,7 @@ class SimpleEngine(RetrieverQueryEngine):
documents = SimpleDirectoryReader(input_files=input_files).load_data()
self._fix_document_metadata(documents)
nodes = run_transformations(documents, transformations=self.index._transformations)
nodes = run_transformations(documents, transformations=self._transformations)
self._save_nodes(nodes)
def add_objs(self, objs: list[RAGObject]):
@ -194,6 +207,29 @@ class SimpleEngine(RetrieverQueryEngine):
self._persist(str(persist_dir), **kwargs)
@classmethod
def _from_nodes(
    cls,
    nodes: list[BaseNode],
    transformations: Optional[list[TransformComponent]] = None,
    embed_model: BaseEmbedding = None,
    llm: LLM = None,
    retriever_configs: list[BaseRetrieverConfig] = None,
    ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
    """Assemble an engine directly from prepared nodes.

    Args:
        nodes: Nodes passed on to `get_retriever`.
        transformations: Forwarded unchanged to the engine constructor.
        embed_model: Candidate embedding model; resolved through `_resolve_embed_model`.
        llm: Language model; `get_rag_llm()` is used when this is falsy.
        retriever_configs: Configs used to pick the retriever and to resolve the embedding model.
        ranker_configs: Configs used to build the node postprocessors.

    Returns:
        A new engine instance.
    """
    resolved_embed_model = cls._resolve_embed_model(embed_model, retriever_configs)
    resolved_llm = llm or get_rag_llm()

    # Keyword arguments are evaluated top to bottom: retriever, rankers, synthesizer.
    return cls(
        retriever=get_retriever(configs=retriever_configs, nodes=nodes, embed_model=resolved_embed_model),
        node_postprocessors=get_rankers(configs=ranker_configs, llm=resolved_llm),  # Default []
        response_synthesizer=get_response_synthesizer(llm=resolved_llm),
        transformations=transformations,
    )
@classmethod
def _from_index(
cls,
@ -203,6 +239,7 @@ class SimpleEngine(RetrieverQueryEngine):
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
llm = llm or get_rag_llm()
retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
@ -210,7 +247,6 @@ class SimpleEngine(RetrieverQueryEngine):
retriever=retriever,
node_postprocessors=rankers,
response_synthesizer=get_response_synthesizer(llm=llm),
index=index,
)
def _ensure_retriever_modifiable(self):
@ -261,3 +297,7 @@ class SimpleEngine(RetrieverQueryEngine):
return MockEmbedding(embed_dim=1)
return embed_model or get_rag_embedding()
@staticmethod
def _default_transformations():
    """Return the fallback transformation list: one freshly created `SentenceSplitter`."""
    splitter = SentenceSplitter()
    return [splitter]

View file

@ -26,6 +26,9 @@ class GenericFactory:
if creator:
return creator(**kwargs)
self._raise_for_key(key)
def _raise_for_key(self, key: Any):
raise ValueError(f"Creator not registered for key: {key}")
@ -33,19 +36,26 @@ class ConfigBasedFactory(GenericFactory):
"""Designed to get objects based on object type."""
def get_instance(self, key: Any, **kwargs) -> Any:
"""Key is config, such as a pydantic model.
"""Get instance by the type of key.
Call func by the type of key, and the key will be passed to func.
Key is config, such as a pydantic model, call func by the type of key, and the key will be passed to func.
Raise Exception if key not found.
"""
creator = self._creators.get(type(key))
if creator:
return creator(key, **kwargs)
self._raise_for_key(key)
def _raise_for_key(self, key: Any):
raise ValueError(f"Unknown config: `{type(key)}`, {key}")
@staticmethod
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.
Return None if not found.
"""
if config is not None and hasattr(config, key):
val = getattr(config, key)
if val is not None:
@ -54,6 +64,4 @@ class ConfigBasedFactory(GenericFactory):
if key in kwargs:
return kwargs[key]
raise KeyError(
f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
)
return None

View file

@ -1,37 +1,103 @@
"""RAG Embedding Factory."""
from __future__ import annotations
from typing import Any
from llama_index.core.embeddings import BaseEmbedding
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from metagpt.config2 import config
from metagpt.configs.embedding_config import EmbeddingType
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.base import GenericFactory
class RAGEmbeddingFactory(GenericFactory):
"""Create LlamaIndex Embedding with MetaGPT's config."""
"""Create LlamaIndex Embedding with MetaGPT's embedding config."""
def __init__(self):
creators = {
EmbeddingType.OPENAI: self._create_openai,
EmbeddingType.AZURE: self._create_azure,
EmbeddingType.GEMINI: self._create_gemini,
EmbeddingType.OLLAMA: self._create_ollama,
# For backward compatibility
LLMType.OPENAI: self._create_openai,
LLMType.AZURE: self._create_azure,
}
super().__init__(creators)
def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding:
"""Key is LLMType, default use config.llm.api_type."""
return super().get_instance(key or config.llm.api_type)
def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding:
"""Key is EmbeddingType."""
return super().get_instance(key or self._resolve_embedding_type())
def _create_openai(self):
return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url)
def _resolve_embedding_type(self) -> EmbeddingType | LLMType:
"""Resolves the embedding type.
def _create_azure(self):
return AzureOpenAIEmbedding(
azure_endpoint=config.llm.base_url,
api_key=config.llm.api_key,
api_version=config.llm.api_version,
If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE.
Raise TypeError if embedding type not found.
"""
if config.embedding.api_type:
return config.embedding.api_type
if config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]:
return config.llm.api_type
raise TypeError("To use RAG, please set your embedding in config2.yaml.")
def _create_openai(self) -> OpenAIEmbedding:
    """Build an `OpenAIEmbedding`.

    The API key and base URL come from `config.embedding`; when either is
    falsy, the matching `config.llm` value is used instead. Model name and
    batch size are added by `_try_set_model_and_batch_size`.
    """
    embedding_cfg = config.embedding
    kwargs = {
        "api_key": embedding_cfg.api_key or config.llm.api_key,
        "api_base": embedding_cfg.base_url or config.llm.base_url,
    }
    self._try_set_model_and_batch_size(kwargs)

    return OpenAIEmbedding(**kwargs)
def _create_azure(self) -> AzureOpenAIEmbedding:
    """Build an `AzureOpenAIEmbedding`.

    API key, endpoint and API version come from `config.embedding`; each one
    falls back to the matching `config.llm` value when falsy. Model name and
    batch size are added by `_try_set_model_and_batch_size`.
    """
    embedding_cfg = config.embedding
    kwargs = {
        "api_key": embedding_cfg.api_key or config.llm.api_key,
        "azure_endpoint": embedding_cfg.base_url or config.llm.base_url,
        "api_version": embedding_cfg.api_version or config.llm.api_version,
    }
    self._try_set_model_and_batch_size(kwargs)

    return AzureOpenAIEmbedding(**kwargs)
def _create_gemini(self) -> GeminiEmbedding:
    """Build a `GeminiEmbedding` from `config.embedding` only (no `config.llm` fallback).

    Model name and batch size are added by `_try_set_model_and_batch_size`.
    """
    embedding_cfg = config.embedding
    kwargs = {
        "api_key": embedding_cfg.api_key,
        "api_base": embedding_cfg.base_url,
    }
    self._try_set_model_and_batch_size(kwargs)

    return GeminiEmbedding(**kwargs)
def _create_ollama(self) -> OllamaEmbedding:
    """Build an `OllamaEmbedding` that points at `config.embedding.base_url`.

    Model name and batch size are added by `_try_set_model_and_batch_size`.
    """
    kwargs = {"base_url": config.embedding.base_url}
    self._try_set_model_and_batch_size(kwargs)

    return OllamaEmbedding(**kwargs)
def _try_set_model_and_batch_size(self, params: dict):
    """Add `model_name` and `embed_batch_size` to `params` in place when they are configured.

    A key is written only if the matching `config.embedding` value is truthy;
    otherwise `params` is left without that key.
    """
    embedding_cfg = config.embedding
    optional_values = (
        ("model_name", embedding_cfg.model),
        ("embed_batch_size", embedding_cfg.embed_batch_size),
    )
    for param_name, value in optional_values:
        if value:
            params[param_name] = value
def _raise_for_key(self, key: Any):
raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}")
get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding

View file

@ -48,7 +48,7 @@ class RAGIndexFactory(ConfigBasedFactory):
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
db = chromadb.PersistentClient(str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)

View file

@ -1,5 +1,5 @@
"""RAG LLM."""
import asyncio
from typing import Any
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW
@ -15,7 +15,7 @@ from pydantic import Field
from metagpt.config2 import config
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.async_helper import run_coroutine_in_new_loop
from metagpt.utils.async_helper import NestAsyncio
from metagpt.utils.token_counter import TOKEN_MAX
@ -39,7 +39,8 @@ class RAGLLM(CustomLLM):
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))
NestAsyncio.apply_once()
return asyncio.get_event_loop().run_until_complete(self.acomplete(prompt, **kwargs))
@llm_completion_callback()
async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:

View file

@ -8,6 +8,8 @@ from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
from metagpt.rag.schema import (
BaseRankerConfig,
BGERerankConfig,
CohereRerankConfig,
ColbertRerankConfig,
LLMRankerConfig,
ObjectRankerConfig,
@ -22,6 +24,8 @@ class RankerFactory(ConfigBasedFactory):
LLMRankerConfig: self._create_llm_ranker,
ColbertRerankConfig: self._create_colbert_ranker,
ObjectRankerConfig: self._create_object_ranker,
CohereRerankConfig: self._create_cohere_rerank,
BGERerankConfig: self._create_bge_rerank,
}
super().__init__(creators)
@ -45,6 +49,26 @@ class RankerFactory(ConfigBasedFactory):
)
return ColbertRerank(**config.model_dump())
def _create_cohere_rerank(self, config: CohereRerankConfig, **kwargs) -> LLMRerank:
    """Build a Cohere reranker from `config`.

    The reranker class is imported inside the method, so its package is only
    required when this ranker is actually requested.

    Args:
        config: Ranker settings; every field of `config.model_dump()` is passed to `CohereRerank`.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        The constructed `CohereRerank` instance.

    Raises:
        ImportError: If `llama_index.postprocessor.cohere_rerank` cannot be imported.
            The original import failure is attached as the cause.
    """
    # NOTE(review): annotated as `LLMRerank` although a `CohereRerank` is constructed — confirm the intended type.
    try:
        from llama_index.postprocessor.cohere_rerank import CohereRerank
    except ImportError as err:
        # Chain explicitly so the underlying import failure stays visible as the cause.
        raise ImportError(
            "`llama-index-postprocessor-cohere-rerank` package not found, please run `pip install llama-index-postprocessor-cohere-rerank`"
        ) from err
    return CohereRerank(**config.model_dump())
def _create_bge_rerank(self, config: BGERerankConfig, **kwargs) -> LLMRerank:
    """Build a BGE (FlagEmbedding) reranker from `config`.

    The reranker class is imported inside the method, so its package is only
    required when this ranker is actually requested.

    Args:
        config: Ranker settings; every field of `config.model_dump()` is passed to `FlagEmbeddingReranker`.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        The constructed `FlagEmbeddingReranker` instance.

    Raises:
        ImportError: If `llama_index.postprocessor.flag_embedding_reranker` cannot be imported.
            The original import failure is attached as the cause.
    """
    # NOTE(review): annotated as `LLMRerank` although a `FlagEmbeddingReranker` is constructed — confirm the intended type.
    try:
        from llama_index.postprocessor.flag_embedding_reranker import (
            FlagEmbeddingReranker,
        )
    except ImportError as err:
        # Chain explicitly so the underlying import failure stays visible as the cause.
        raise ImportError(
            "`llama-index-postprocessor-flag-embedding-reranker` package not found, please run `pip install llama-index-postprocessor-flag-embedding-reranker`"
        ) from err
    return FlagEmbeddingReranker(**config.model_dump())
def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank:
return ObjectSortPostprocessor(**config.model_dump())

View file

@ -1,10 +1,13 @@
"""RAG Retriever Factory."""
import copy
from functools import wraps
import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.schema import BaseNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
@ -24,10 +27,25 @@ from metagpt.rag.schema import (
ElasticsearchKeywordRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
IndexRetrieverConfig,
)
def get_or_build_index(build_index_func):
    """Decorator that reuses an existing index before building a new one.

    The decorated method first calls `self._extract_index(config, **kwargs)`.
    A result other than None is returned as is; otherwise `build_index_func`
    is called with the same arguments and its result is returned.
    """

    @wraps(build_index_func)
    def _get_or_build(self, config, **kwargs):
        existing = self._extract_index(config, **kwargs)
        return build_index_func(self, config, **kwargs) if existing is None else existing

    return _get_or_build
class RetrieverFactory(ConfigBasedFactory):
"""Modify creators for dynamically instance implementation."""
@ -54,48 +72,79 @@ class RetrieverFactory(ConfigBasedFactory):
return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]
def _create_default(self, **kwargs) -> RAGRetriever:
return self._extract_index(**kwargs).as_retriever()
index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs)
return index.as_retriever()
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_faiss_index(config, **kwargs)
return FAISSRetriever(**config.model_dump())
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
index = self._extract_index(config, **kwargs)
nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs)
return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_chroma_index(config, **kwargs)
return ChromaRetriever(**config.model_dump())
def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
vector_store = ElasticsearchStore(**config.store_config.model_dump())
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_es_index(config, **kwargs)
return ElasticsearchRetriever(**config.model_dump())
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
    """Look up "index" via `_val_from_config_or_kwargs`, passing `config` and `kwargs`."""
    index = self._val_from_config_or_kwargs("index", config, **kwargs)
    return index
def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[BaseNode]:
    """Look up "nodes" via `_val_from_config_or_kwargs`, passing `config` and `kwargs`."""
    nodes = self._val_from_config_or_kwargs("nodes", config, **kwargs)
    return nodes
def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding:
    """Look up "embed_model" via `_val_from_config_or_kwargs`, passing `config` and `kwargs`."""
    embed_model = self._val_from_config_or_kwargs("embed_model", config, **kwargs)
    return embed_model
def _build_default_index(self, **kwargs) -> VectorStoreIndex:
    """Create a `VectorStoreIndex` from the nodes and embedding model found in `kwargs`."""
    return VectorStoreIndex(
        nodes=self._extract_nodes(**kwargs),
        embed_model=self._extract_embed_model(**kwargs),
    )
@get_or_build_index
def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex:
    """Build an index on a FAISS `IndexFlatL2` store sized by `config.dimensions`.

    The decorator returns an already available index instead, when one is found.
    """
    flat_l2 = faiss.IndexFlatL2(config.dimensions)
    faiss_store = FaissVectorStore(faiss_index=flat_l2)
    return self._build_index_from_vector_store(config, faiss_store, **kwargs)
@get_or_build_index
def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex:
    """Build an index on a persistent Chroma collection.

    The client is opened at `config.persist_path`; the collection named
    `config.collection_name` is fetched or created with `config.metadata`.
    The decorator returns an already available index instead, when one is found.
    """
    client = chromadb.PersistentClient(path=str(config.persist_path))
    collection = client.get_or_create_collection(config.collection_name, metadata=config.metadata)
    chroma_store = ChromaVectorStore(chroma_collection=collection)
    return self._build_index_from_vector_store(config, chroma_store, **kwargs)
@get_or_build_index
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
    """Build an index on an `ElasticsearchStore` configured from `config.store_config`.

    The decorator returns an already available index instead, when one is found.
    """
    store_kwargs = config.store_config.model_dump()
    es_store = ElasticsearchStore(**store_kwargs)
    return self._build_index_from_vector_store(config, es_store, **kwargs)
def _build_index_from_vector_store(
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
) -> VectorStoreIndex:
storage_context = StorageContext.from_defaults(vector_store=vector_store)
old_index = self._extract_index(config, **kwargs)
new_index = VectorStoreIndex(
nodes=list(old_index.docstore.docs.values()),
index = VectorStoreIndex(
nodes=self._extract_nodes(config, **kwargs),
storage_context=storage_context,
embed_model=old_index._embed_model,
embed_model=self._extract_embed_model(config, **kwargs),
)
return new_index
return index
get_retriever = RetrieverFactory().get_retriever

View file

@ -40,8 +40,10 @@ class DynamicBM25Retriever(BM25Retriever):
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
self.bm25 = BM25Okapi(self._corpus)
self._index.insert_nodes(nodes, **kwargs)
if self._index:
self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist."""
self._index.storage_context.persist(persist_dir)
if self._index:
self._index.storage_context.persist(persist_dir)

View file

@ -1,14 +1,17 @@
"""RAG schemas."""
from pathlib import Path
from typing import Any, Literal, Union
from typing import Any, ClassVar, Literal, Optional, Union
from chromadb.api.types import CollectionMetadata
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from metagpt.config2 import config
from metagpt.configs.embedding_config import EmbeddingType
from metagpt.rag.interface import RAGObject
@ -31,7 +34,19 @@ class IndexRetrieverConfig(BaseRetrieverConfig):
class FAISSRetrieverConfig(IndexRetrieverConfig):
"""Config for FAISS-based retrievers."""
dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.")
dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.")
_embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
EmbeddingType.GEMINI: 768,
EmbeddingType.OLLAMA: 4096,
}
@model_validator(mode="after")
def check_dimensions(self):
    """Fill in `dimensions` when it was left at 0.

    The value is looked up in `_embedding_type_to_dimensions` by the
    configured embedding API type; 1536 is used when that type has no entry.
    A non-zero `dimensions` is kept unchanged.
    """
    if self.dimensions != 0:
        return self

    self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536)
    return self
class BM25RetrieverConfig(IndexRetrieverConfig):
@ -45,6 +60,9 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
collection_name: str = Field(default="metagpt", description="The name of the collection.")
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
class ElasticsearchStoreConfig(BaseModel):
@ -101,6 +119,16 @@ class ColbertRerankConfig(BaseRankerConfig):
keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.")
class CohereRerankConfig(BaseRankerConfig):
    """Config for the Cohere reranker."""

    # Name of the rerank model to use.
    model: str = Field(default="rerank-english-v3.0")
    # The default is a placeholder; supply a real key.
    api_key: str = Field(default="YOUR_COHERE_API")
class BGERerankConfig(BaseRankerConfig):
    """Config for the BGE (FlagEmbedding) reranker."""

    model: str = Field(default="BAAI/bge-reranker-large", description="BAAI Reranker model name.")
    use_fp16: bool = Field(default=True, description="Whether to use fp16 for inference.")
class ObjectRankerConfig(BaseRankerConfig):
field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")
@ -130,6 +158,9 @@ class ChromaIndexConfig(VectorIndexConfig):
"""Config for chroma-based index."""
collection_name: str = Field(default="metagpt", description="The name of the collection.")
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
class BM25IndexConfig(BaseIndexConfig):

View file

@ -20,3 +20,18 @@ def run_coroutine_in_new_loop(coroutine) -> Any:
new_loop.call_soon_threadsafe(new_loop.stop)
t.join()
new_loop.close()
class NestAsyncio:
    """Helper that patches asyncio through `nest_asyncio` at most once."""

    # Class-level flag: True once `nest_asyncio.apply()` has been called.
    is_applied = False

    @classmethod
    def apply_once(cls):
        """Call `nest_asyncio.apply()` unless `is_applied` is already set."""
        if cls.is_applied:
            return

        # Imported here so the package is only needed when the patch is applied.
        import nest_asyncio

        nest_asyncio.apply()
        cls.is_applied = True

View file

@ -32,12 +32,15 @@ extras_require = {
"llama-index-core==0.10.15",
"llama-index-embeddings-azure-openai==0.1.6",
"llama-index-embeddings-openai==0.1.5",
"llama-index-embeddings-gemini==0.1.6",
"llama-index-embeddings-ollama==0.1.2",
"llama-index-llms-azure-openai==0.1.4",
"llama-index-readers-file==0.1.4",
"llama-index-retrievers-bm25==0.1.3",
"llama-index-vector-stores-faiss==0.1.1",
"llama-index-vector-stores-elasticsearch==0.1.6",
"llama-index-vector-stores-chroma==0.1.6",
"docx2txt==0.8",
],
}

View file

@ -25,10 +25,6 @@ class TestSimpleEngine:
def mock_simple_directory_reader(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
@pytest.fixture
def mock_vector_store_index(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
@pytest.fixture
def mock_get_retriever(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
@ -45,7 +41,6 @@ class TestSimpleEngine:
self,
mocker,
mock_simple_directory_reader,
mock_vector_store_index,
mock_get_retriever,
mock_get_rankers,
mock_get_response_synthesizer,
@ -81,11 +76,8 @@ class TestSimpleEngine:
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_vector_store_index.assert_called_once()
mock_get_retriever.assert_called_once_with(
configs=retriever_configs, index=mock_vector_store_index.return_value
)
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
mock_get_retriever.assert_called_once()
mock_get_rankers.assert_called_once()
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
assert isinstance(engine, SimpleEngine)
@ -119,7 +111,7 @@ class TestSimpleEngine:
# Assert
assert isinstance(engine, SimpleEngine)
assert engine.index is not None
assert engine._transformations is not None
def test_from_objs_with_bm25_config(self):
# Setup
@ -137,6 +129,7 @@ class TestSimpleEngine:
def test_from_index(self, mocker, mock_llm, mock_embedding):
# Mock
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_index.as_retriever.return_value = "retriever"
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
mock_get_index.return_value = mock_index
@ -149,7 +142,7 @@ class TestSimpleEngine:
# Assert
assert isinstance(engine, SimpleEngine)
assert engine.index is mock_index
assert engine._retriever == "retriever"
@pytest.mark.asyncio
async def test_asearch(self, mocker):
@ -200,14 +193,11 @@ class TestSimpleEngine:
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_index._transformations = mocker.MagicMock()
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
mock_run_transformations.return_value = ["node1", "node2"]
# Setup
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
engine = SimpleEngine(retriever=mock_retriever)
input_files = ["test_file1", "test_file2"]
# Exec
@ -230,7 +220,7 @@ class TestSimpleEngine:
return ""
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
engine = SimpleEngine(retriever=mock_retriever)
# Exec
engine.add_objs(objs=objs)

View file

@ -97,6 +97,5 @@ class TestConfigBasedFactory:
def test_val_from_config_or_kwargs_key_error(self):
# Test KeyError when the key is not found in both config object and kwargs
config = DummyConfig(name=None)
with pytest.raises(KeyError) as exc_info:
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
val = ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
assert val is None

View file

@ -1,5 +1,6 @@
import pytest
from metagpt.configs.embedding_config import EmbeddingType
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
@ -10,30 +11,51 @@ class TestRAGEmbeddingFactory:
self.embedding_factory = RAGEmbeddingFactory()
@pytest.fixture
def mock_openai_embedding(self, mocker):
def mock_config(self, mocker):
return mocker.patch("metagpt.rag.factories.embedding.config")
@staticmethod
def mock_openai_embedding(mocker):
return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding")
@pytest.fixture
def mock_azure_embedding(self, mocker):
@staticmethod
def mock_azure_embedding(mocker):
return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding")
def test_get_rag_embedding_openai(self, mock_openai_embedding):
# Exec
self.embedding_factory.get_rag_embedding(LLMType.OPENAI)
@staticmethod
def mock_gemini_embedding(mocker):
return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding")
# Assert
mock_openai_embedding.assert_called_once()
@staticmethod
def mock_ollama_embedding(mocker):
return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding")
def test_get_rag_embedding_azure(self, mock_azure_embedding):
# Exec
self.embedding_factory.get_rag_embedding(LLMType.AZURE)
# Assert
mock_azure_embedding.assert_called_once()
def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
@pytest.mark.parametrize(
("mock_func", "embedding_type"),
[
(mock_openai_embedding, LLMType.OPENAI),
(mock_azure_embedding, LLMType.AZURE),
(mock_openai_embedding, EmbeddingType.OPENAI),
(mock_azure_embedding, EmbeddingType.AZURE),
(mock_gemini_embedding, EmbeddingType.GEMINI),
(mock_ollama_embedding, EmbeddingType.OLLAMA),
],
)
def test_get_rag_embedding(self, mock_func, embedding_type, mocker):
# Mock
mock_config = mocker.patch("metagpt.rag.factories.embedding.config")
mock = mock_func(mocker)
# Exec
self.embedding_factory.get_rag_embedding(embedding_type)
# Assert
mock.assert_called_once()
def test_get_rag_embedding_default(self, mocker, mock_config):
# Mock
mock_openai_embedding = self.mock_openai_embedding(mocker)
mock_config.embedding.api_type = None
mock_config.llm.api_type = LLMType.OPENAI
# Exec
@ -41,3 +63,44 @@ class TestRAGEmbeddingFactory:
# Assert
mock_openai_embedding.assert_called_once()
@pytest.mark.parametrize(
"model, embed_batch_size, expected_params",
[("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})],
)
def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params):
# Mock
mock_config.embedding.model = model
mock_config.embedding.embed_batch_size = embed_batch_size
# Setup
test_params = {}
# Exec
self.embedding_factory._try_set_model_and_batch_size(test_params)
# Assert
assert test_params == expected_params
def test_resolve_embedding_type(self, mock_config):
# Mock
mock_config.embedding.api_type = EmbeddingType.OPENAI
# Exec
embedding_type = self.embedding_factory._resolve_embedding_type()
# Assert
assert embedding_type == EmbeddingType.OPENAI
def test_resolve_embedding_type_exception(self, mock_config):
    """No embedding ``api_type`` plus a GEMINI LLM type must raise ``TypeError``."""
    # Mock
    mock_config.embedding.api_type = None
    mock_config.llm.api_type = LLMType.GEMINI

    # Assert
    expect_type_error = pytest.raises(TypeError)
    with expect_type_error:
        self.embedding_factory._resolve_embedding_type()
def test_raise_for_key(self):
    """``_raise_for_key("key")`` must raise ``ValueError``."""
    expect_value_error = pytest.raises(ValueError)
    with expect_value_error:
        self.embedding_factory._raise_for_key("key")

View file

@ -1,6 +1,8 @@
import faiss
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.embeddings import MockEmbedding
from llama_index.core.schema import TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
@ -43,6 +45,14 @@ class TestRetrieverFactory:
def mock_es_vector_store(self, mocker):
    """Return a ``MagicMock`` created with ``spec=ElasticsearchStore``."""
    es_store_double = mocker.MagicMock(spec=ElasticsearchStore)
    return es_store_double
@pytest.fixture
def mock_nodes(self, mocker):
    """Provide a one-element list holding a ``TextNode`` whose text is ``"msg"``."""
    only_node = TextNode(text="msg")
    return [only_node]
@pytest.fixture
def mock_embedding(self):
    """Provide a ``MockEmbedding`` built with ``embed_dim=1``."""
    embedding_double = MockEmbedding(embed_dim=1)
    return embedding_double
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
mock_config = FAISSRetrieverConfig(dimensions=128)
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
@ -52,42 +62,40 @@ class TestRetrieverFactory:
assert isinstance(retriever, FAISSRetriever)
def test_get_retriever_with_bm25_config(self, mocker, mock_nodes):
    """A lone ``BM25RetrieverConfig`` must yield a ``DynamicBM25Retriever``.

    Args:
        mocker: Patching helper; used to stub ``rank_bm25.BM25Okapi.__init__``.
        mock_nodes: Nodes forwarded to ``get_retriever`` through ``nodes=``.
    """
    mock_config = BM25RetrieverConfig()
    # Stub the BM25Okapi constructor so it does nothing and returns None.
    mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)

    # Nodes are passed explicitly; no index is injected through `_extract_index`.
    retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=mock_nodes)

    assert isinstance(retriever, DynamicBM25Retriever)
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_nodes, mock_embedding):
    """A FAISS config plus a BM25 config must yield a ``SimpleHybridRetriever``.

    Args:
        mocker: Patching helper; used to stub ``rank_bm25.BM25Okapi.__init__``.
        mock_nodes: Nodes forwarded to ``get_retriever`` through ``nodes=``.
        mock_embedding: Embedding model forwarded through ``embed_model=``.
    """
    # dimensions=1 pairs with the one-dimensional `mock_embedding` fixture
    # (MockEmbedding(embed_dim=1)) defined alongside these tests.
    mock_faiss_config = FAISSRetrieverConfig(dimensions=1)
    mock_bm25_config = BM25RetrieverConfig()
    # Stub the BM25Okapi constructor so it does nothing and returns None.
    mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)

    retriever = self.retriever_factory.get_retriever(
        configs=[mock_faiss_config, mock_bm25_config], nodes=mock_nodes, embed_model=mock_embedding
    )

    assert isinstance(retriever, SimpleHybridRetriever)
def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store, mock_embedding):
    """A ``ChromaRetrieverConfig`` must yield a ``ChromaRetriever``.

    Args:
        mocker: Patching helper used for the chromadb client and vector store.
        mock_chroma_vector_store: Object returned by the patched
            ``ChromaVectorStore`` constructor.
        mock_embedding: Embedding model forwarded through ``embed_model=``.
    """
    mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
    # Patch the persistent client as referenced from the retriever factory module.
    mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
    mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
    mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)

    # An empty node list and the mock embedding are passed explicitly; no
    # index is injected through `_extract_index`.
    retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)

    assert isinstance(retriever, ChromaRetriever)
def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding):
    """An ``ElasticsearchRetrieverConfig`` must yield an ``ElasticsearchRetriever``.

    Args:
        mocker: Patching helper used for the Elasticsearch store class.
        mock_es_vector_store: Object returned by the patched
            ``ElasticsearchStore`` constructor.
        mock_embedding: Embedding model forwarded through ``embed_model=``.
    """
    mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
    # Patch the store class as referenced from the retriever factory module.
    mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)

    # An empty node list and the mock embedding are passed explicitly; no
    # index is injected through `_extract_index`.
    retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)

    assert isinstance(retriever, ElasticsearchRetriever)
@ -111,3 +119,19 @@ class TestRetrieverFactory:
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
assert extracted_index == mock_vector_store_index
def test_get_or_build_when_get(self, mocker):
    """With ``_extract_index`` patched, ``_build_es_index(None)`` must return its value."""
    expected_index = "existing_index"
    factory = self.retriever_factory
    mocker.patch.object(factory, "_extract_index", return_value=expected_index)

    actual_index = self.retriever_factory._build_es_index(None)

    assert actual_index == expected_index
def test_get_or_build_when_build(self, mocker):
    """Patch ``_build_es_index`` and check that calling it returns the patched value."""
    expected_index = "call_build_es_index"
    factory = self.retriever_factory
    # NOTE(review): the method being called is the same one that is patched
    # here, so the assertion only observes the patch's return value; confirm
    # this is the intended coverage.
    mocker.patch.object(factory, "_build_es_index", return_value=expected_index)

    actual_index = self.retriever_factory._build_es_index(None)

    assert actual_index == expected_index