upgrade llama-index to v0.10

This commit is contained in:
seehi 2024-02-23 11:06:53 +08:00
parent fae24fd381
commit 19a9a98c0b
29 changed files with 725 additions and 370 deletions

View file

@ -3,22 +3,32 @@
from typing import Optional
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
from llama_index.callbacks.base import CallbackManager
from llama_index.core.base_retriever import BaseRetriever
from llama_index.embeddings.base import BaseEmbedding
from llama_index.indices.base import BaseIndex
from llama_index.llms.llm import LLM
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import BaseSynthesizer
from llama_index.schema import NodeWithScore, QueryBundle, QueryType, TextNode
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.ingestion.pipeline import run_transformations
from llama_index.core.llms import LLM
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import (
BaseSynthesizer,
get_response_synthesizer,
)
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import (
NodeWithScore,
QueryBundle,
QueryType,
TextNode,
TransformComponent,
)
from metagpt.rag.factory import get_rankers, get_retriever
from metagpt.rag.factories import get_rag_llm, get_rankers, get_retriever
from metagpt.rag.interface import RAGObject
from metagpt.rag.llm import get_default_llm
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
from metagpt.rag.schema import RankerConfigType, RetrieverConfigType
from metagpt.rag.schema import BaseRankerConfig, BaseRetrieverConfig
from metagpt.utils.embedding import get_embedding
@ -51,45 +61,47 @@ class SimpleEngine(RetrieverQueryEngine):
cls,
input_dir: str = None,
input_files: list[str] = None,
llm: LLM = None,
transformations: Optional[list[TransformComponent]] = None,
embed_model: BaseEmbedding = None,
chunk_size: int = None,
chunk_overlap: int = None,
retriever_configs: list[RetrieverConfigType] = None,
ranker_configs: list[RankerConfigType] = None,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
"""This engine is designed to be simple and straightforward
Args:
input_dir: Path to the directory.
input_files: List of file paths to read (Optional; overrides input_dir, exclude).
llm: Must supported by llama index.
embed_model: Must supported by llama index.
chunk_size: The size of text chunks (in tokens) to split documents into for embedding.
chunk_overlap: The number of tokens for overlapping between consecutive chunks. Helps in maintaining context continuity.
transformations: Parse documents to nodes. Default [SentenceSplitter].
embed_model: Parse nodes to embedding. Must supported by llama index. Default OpenAIEmbedding.
llm: Must supported by llama index. Default OpenAI.
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
ranker_configs: Configuration for rankers.
"""
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
service_context = ServiceContext.from_defaults(
llm=llm or get_default_llm(),
index = VectorStoreIndex.from_documents(
documents=documents,
transformations=transformations or [SentenceSplitter()],
embed_model=embed_model or get_embedding(),
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
retriever = get_retriever(index, configs=retriever_configs)
rankers = get_rankers(configs=ranker_configs, service_context=service_context)
llm = llm or get_rag_llm()
retriever = get_retriever(configs=retriever_configs, index=index)
rankers = get_rankers(configs=ranker_configs, llm=llm)
return cls(retriever=retriever, node_postprocessors=rankers, index=index)
return cls(
retriever=retriever,
node_postprocessors=rankers,
response_synthesizer=get_response_synthesizer(llm=llm),
index=index,
)
async def asearch(self, content: str, **kwargs) -> str:
"""Inplement tools.SearchInterface"""
return await self.aquery(content)
async def aretrieve(self, query_bundle: QueryType) -> list[NodeWithScore]:
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Allow query to be str"""
query_bundle = QueryBundle(query_bundle) if isinstance(query_bundle, str) else query_bundle
query_bundle = QueryBundle(query) if isinstance(query, str) else query
return await super().aretrieve(query_bundle)
def add_docs(self, input_files: list[str]):
@ -97,7 +109,7 @@ class SimpleEngine(RetrieverQueryEngine):
self._ensure_retriever_modifiable()
documents = SimpleDirectoryReader(input_files=input_files).load_data()
nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents)
nodes = run_transformations(documents, transformations=self.index._transformations)
self.retriever.add_nodes(nodes)
def add_objs(self, objs: list[RAGObject]):

View file

@ -0,0 +1,6 @@
"""RAG factories"""
from metagpt.rag.factories.retriever import get_retriever
from metagpt.rag.factories.ranker import get_rankers
from metagpt.rag.factories.llm import get_rag_llm
__all__ = ["get_retriever", "get_rankers", "get_rag_llm"]

View file

@ -0,0 +1,58 @@
"""Base Factory."""
from typing import Any, Callable
class GenericFactory:
    """Designed to get objects based on any keys."""

    def __init__(self, creators: dict[Any, Callable] = None):
        """Creators is a dictionary.

        Keys are identifiers, and the values are the associated creator function, which create objects.
        """
        # An absent/empty mapping degrades to an empty registry.
        self._creators = creators or {}

    def get_instances(self, keys: list[Any], **kwargs) -> list[Any]:
        """Get instances by keys."""
        instances = []
        for key in keys:
            instances.append(self.get_instance(key, **kwargs))
        return instances

    def get_instance(self, key: Any, **kwargs) -> Any:
        """Get instance by key.

        Raise Exception if key not found.
        """
        make = self._creators.get(key)
        if not make:
            raise ValueError(f"Creator not registered for key: {key}")
        return make(**kwargs)
class ConfigFactory(GenericFactory):
    """Designed to get objects based on object type."""

    def get_instance(self, key: Any, **kwargs) -> Any:
        """Key is config, such as a pydantic model.

        Call func by the type of key, and the key will be passed to func.
        """
        # Dispatch on the config's concrete type; the config itself is forwarded.
        make = self._creators.get(type(key))
        if not make:
            raise ValueError(f"Unknown config: {key}")
        return make(key, **kwargs)

    @staticmethod
    def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
        """It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
        if config is not None:
            val = getattr(config, key, None)
            if val is not None:
                return val
        if key in kwargs:
            return kwargs[key]
        raise KeyError(
            f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
        )

View file

@ -0,0 +1,76 @@
"""RAG LLM Factory.
The LLM of LlamaIndex and the LLM of MG are not the same.
"""
from llama_index.core.llms import LLM
from llama_index.llms.anthropic import Anthropic
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.llms.gemini import Gemini
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from metagpt.config2 import config
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.base import GenericFactory
class RAGLLMFactory(GenericFactory):
    """Create LlamaIndex LLM with MG config."""

    def __init__(self):
        # Registry: MetaGPT provider type -> builder for the matching LlamaIndex LLM.
        creators = {
            LLMType.OPENAI: self._create_openai,
            LLMType.AZURE: self._create_azure,
            LLMType.ANTHROPIC: self._create_anthropic,
            LLMType.GEMINI: self._create_gemini,
            LLMType.OLLAMA: self._create_ollama,
        }
        super().__init__(creators)

    def get_rag_llm(self, key: LLMType = None) -> LLM:
        """Key is LLMType, default use config.llm.api_type."""
        return super().get_instance(key or config.llm.api_type)

    def _create_openai(self):
        # All builders below read the global MetaGPT `config.llm` section.
        return OpenAI(
            api_base=config.llm.base_url,
            api_key=config.llm.api_key,
            api_version=config.llm.api_version,
            model=config.llm.model,
            max_tokens=config.llm.max_token,
            temperature=config.llm.temperature,
        )

    def _create_azure(self):
        return AzureOpenAI(
            azure_endpoint=config.llm.base_url,
            api_key=config.llm.api_key,
            api_version=config.llm.api_version,
            model=config.llm.model,
            max_tokens=config.llm.max_token,
            temperature=config.llm.temperature,
        )

    def _create_anthropic(self):
        return Anthropic(
            base_url=config.llm.base_url,
            api_key=config.llm.api_key,
            model=config.llm.model,
            max_tokens=config.llm.max_token,
            temperature=config.llm.temperature,
        )

    def _create_gemini(self):
        return Gemini(
            api_base=config.llm.base_url,
            api_key=config.llm.api_key,
            model_name=config.llm.model,
            max_tokens=config.llm.max_token,
            temperature=config.llm.temperature,
        )

    def _create_ollama(self):
        # NOTE(review): unlike the other builders, api_key/max_token are not
        # forwarded here — presumably Ollama needs neither; confirm.
        return Ollama(base_url=config.llm.base_url, model=config.llm.model, temperature=config.llm.temperature)


# Module-level convenience: a bound method ready for direct import.
get_rag_llm = RAGLLMFactory().get_rag_llm

View file

@ -0,0 +1,39 @@
"""RAG Ranker Factory."""
from llama_index.core.llms import LLM
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from metagpt.rag.factories.base import ConfigFactory
from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig
class RankerFactory(ConfigFactory):
    """Modify creators for dynamically instance implementation."""

    def __init__(self):
        # Registry: ranker config type -> builder for the matching postprocessor.
        creators = {
            LLMRankerConfig: self._create_llm_ranker,
        }
        super().__init__(creators)

    def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]:
        """Creates and returns ranker instances based on the provided configurations.

        With no configs, falls back to a single default LLMRerank; kwargs must then
        supply the `llm` used by the default config.
        """
        if not configs:
            return self._create_default(**kwargs)
        return super().get_instances(configs, **kwargs)

    def _create_default(self, **kwargs) -> list[LLMRerank]:
        # Default ranker: LLM-based rerank built from an LLMRankerConfig with defaults.
        config = LLMRankerConfig(llm=self._extract_llm(**kwargs))
        return [LLMRerank(**config.model_dump())]

    def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
        # Prefer the llm set on the config; otherwise take it from kwargs.
        return self._val_from_config_or_kwargs("llm", config, **kwargs)

    def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank:
        # NOTE(review): mutates the caller's config in place before dumping it.
        config.llm = self._extract_llm(config, **kwargs)
        return LLMRerank(**config.model_dump())


# Module-level convenience: a bound method ready for direct import.
get_rankers = RankerFactory().get_rankers

View file

@ -0,0 +1,64 @@
"""RAG Retriever Factory."""
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.factories.base import ConfigFactory
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
FAISSRetrieverConfig,
)
class RetrieverFactory(ConfigFactory):
    """Modify creators for dynamically instance implementation."""

    def __init__(self):
        # Registry: retriever config type -> builder for the matching retriever.
        creators = {
            FAISSRetrieverConfig: self._create_faiss_retriever,
            BM25RetrieverConfig: self._create_bm25_retriever,
        }
        super().__init__(creators)

    def get_retriever(self, configs: list[BaseRetrieverConfig] = None, **kwargs) -> RAGRetriever:
        """Creates and returns a retriever instance based on the provided configurations.

        If multiple retrievers, using SimpleHybridRetriever.
        """
        if not configs:
            return self._create_default(**kwargs)
        retrievers = super().get_instances(configs, **kwargs)
        return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]

    def _create_default(self, **kwargs) -> RAGRetriever:
        # No configs: fall back to the plain retriever of the index passed via kwargs.
        return self._extract_index(**kwargs).as_retriever()

    def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
        # Prefer the index set on the config; otherwise take it from kwargs.
        return self._val_from_config_or_kwargs("index", config, **kwargs)

    def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
        # Re-index the source index's nodes into a FAISS-backed vector store
        # using exact L2 search over `config.dimensions`-dimensional embeddings.
        vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        old_index = self._extract_index(config, **kwargs)
        new_index = VectorStoreIndex(
            nodes=list(old_index.docstore.docs.values()),
            storage_context=storage_context,
            # NOTE(review): reaches into a private attribute of the source index.
            embed_model=old_index._embed_model,
        )
        # NOTE(review): mutates the caller's config in place before dumping it.
        config.index = new_index
        return FAISSRetriever(**config.model_dump())

    def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
        config.index = self._extract_index(config, **kwargs)
        return DynamicBM25Retriever.from_defaults(**config.model_dump())


# Module-level convenience: a bound method ready for direct import.
get_retriever = RetrieverFactory().get_retriever

View file

@ -1,114 +0,0 @@
"""Factory for creating retriever, ranker"""
from typing import Any, Callable
import faiss
from llama_index import ServiceContext, StorageContext, VectorStoreIndex
from llama_index.indices.base import BaseIndex
from llama_index.postprocessor import LLMRerank
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BM25RetrieverConfig,
FAISSRetrieverConfig,
LLMRankerConfig,
RankerConfigType,
RetrieverConfigType,
)
class BaseFactory:
    """
    A base factory class for creating instances based on provided configurations.
    It uses a registry of creator functions mapped to configuration types to instantiate objects dynamically.
    """

    def __init__(self, creators: dict[Any, Callable]):
        """Creators is a dictionary mapping configuration types to creator functions."""
        self.creators = creators

    def get_instances(self, configs: list[Any] = None, **kwargs) -> list[Any]:
        """Get instances by configs.

        Returns an empty list when configs is None/empty (the declared default),
        instead of raising TypeError by iterating None.
        """
        return [self._get_instance(config, **kwargs) for config in (configs or [])]

    def _get_instance(self, config: Any, **kwargs) -> Any:
        # Dispatch on the config's concrete type; the config is forwarded to the creator.
        create_func = self.creators.get(type(config))
        if create_func:
            return create_func(config, **kwargs)
        raise ValueError(f"Unknown config: {config}")
class RetrieverFactory(BaseFactory):
    """Modify creators for dynamically instance implementation"""

    def __init__(self):
        # Registry: retriever config type -> builder for the matching retriever.
        creators = {
            FAISSRetrieverConfig: self._create_faiss_retriever,
            BM25RetrieverConfig: self._create_bm25_retriever,
        }
        super().__init__(creators)

    def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
        """Creates and returns a retriever instance based on the provided configurations.

        If multiple retrievers, using SimpleHybridRetriever
        """
        if not configs:
            return self._default_instance(index)
        retrievers = super().get_instances(configs, index=index)
        return (
            SimpleHybridRetriever(*retrievers, service_context=index.service_context)
            if len(retrievers) > 1
            else retrievers[0]
        )

    def _default_instance(self, index: BaseIndex) -> RAGRetriever:
        # No configs: fall back to the index's plain retriever.
        return index.as_retriever()

    def _create_faiss_retriever(self, config: FAISSRetrieverConfig, index: BaseIndex) -> FAISSRetriever:
        # Re-index the source index's nodes into a FAISS-backed vector store
        # using exact L2 search over `config.dimensions`-dimensional embeddings.
        vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        vector_index = VectorStoreIndex(
            nodes=list(index.docstore.docs.values()),
            storage_context=storage_context,
            service_context=index.service_context,
        )
        return FAISSRetriever(**config.model_dump(), index=vector_index)

    def _create_bm25_retriever(self, config: BM25RetrieverConfig, index: BaseIndex) -> DynamicBM25Retriever:
        return DynamicBM25Retriever.from_defaults(**config.model_dump(), index=index)
class RankerFactory(BaseFactory):
    """Modify creators for dynamically instance implementation"""

    def __init__(self):
        # Registry: ranker config type -> builder for the matching postprocessor.
        creators = {
            LLMRankerConfig: self._create_llm_ranker,
        }
        super().__init__(creators)

    def get_rankers(
        self, configs: list[RankerConfigType] = None, service_context: ServiceContext = None
    ) -> list[BaseNodePostprocessor]:
        """Creates and returns ranker instances based on the provided configurations."""
        if not configs:
            return [self._default_instance(service_context)]
        return super().get_instances(configs, service_context=service_context)

    def _default_instance(self, service_context: ServiceContext) -> LLMRerank:
        # Default ranker: LLM rerank with the default top_n from LLMRankerConfig.
        return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context)

    def _create_llm_ranker(self, config: LLMRankerConfig, service_context: ServiceContext = None) -> LLMRerank:
        return LLMRerank(**config.model_dump(), service_context=service_context)


# Module-level convenience: bound methods ready for direct import.
get_retriever = RetrieverFactory().get_retriever
get_rankers = RankerFactory().get_rankers

View file

@ -1,14 +1,15 @@
"""RAG Interface."""
"""RAG Interfaces."""
from typing import Any, Protocol
class RAGObject(Protocol):
"""Support rag add object"""
"""Support rag add object."""
def rag_key(self) -> str:
"""For rag search."""
def model_dump(self) -> dict[str, Any]:
"""For rag persist.
Pydantic Model don't need to implement this, as there is a built-in function named model_dump
Pydantic Model don't need to implement this, as there is a built-in function named model_dump.
"""

View file

@ -1,11 +0,0 @@
"""RAG LLM
The LLM of LlamaIndex and the LLM of MG are not the same.
"""
from llama_index.llms import OpenAI
from metagpt.config2 import config
def get_default_llm() -> OpenAI:
    """Build the default LlamaIndex OpenAI LLM from MetaGPT's global `config.llm` settings."""
    return OpenAI(api_base=config.llm.base_url, api_key=config.llm.api_key, model=config.llm.model)

View file

@ -4,8 +4,8 @@ from abc import abstractmethod
from typing import Optional
from llama_index import QueryBundle
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.schema import NodeWithScore
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore
class RAGRanker(BaseNodePostprocessor):

View file

@ -3,8 +3,8 @@
from abc import abstractmethod
from llama_index.retrievers import BaseRetriever
from llama_index.schema import BaseNode, NodeWithScore, QueryType
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import BaseNode, NodeWithScore, QueryType
from metagpt.utils.reflection import check_methods

View file

@ -1,6 +1,7 @@
"""BM25 retriever."""
from llama_index.retrievers import BM25Retriever
from llama_index.schema import BaseNode
from llama_index.core.schema import BaseNode
from llama_index.retrievers.bm25 import BM25Retriever
from rank_bm25 import BM25Okapi
class DynamicBM25Retriever(BM25Retriever):
@ -8,11 +9,6 @@ class DynamicBM25Retriever(BM25Retriever):
def add_nodes(self, nodes: list[BaseNode], **kwargs):
"""Support add nodes"""
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError("Please install rank_bm25: pip install rank-bm25")
self._nodes.extend(nodes)
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
self.bm25 = BM25Okapi(self._corpus)

View file

@ -1,6 +1,6 @@
"""FAISS retriever."""
from llama_index.retrievers import VectorIndexRetriever
from llama_index.schema import BaseNode
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
class FAISSRetriever(VectorIndexRetriever):

View file

@ -1,23 +1,20 @@
"""Hybrid retriever."""
from llama_index import ServiceContext
from llama_index.schema import BaseNode, QueryType
import copy
from llama_index.core.schema import BaseNode, QueryType
from metagpt.rag.retrievers.base import RAGRetriever
class SimpleHybridRetriever(RAGRetriever):
"""
SimpleHybridRetriever is a composite retriever that aggregates search results from multiple retrievers.
"""
"""A composite retriever that aggregates search results from multiple retrievers."""
def __init__(self, *retrievers, service_context: ServiceContext = None):
def __init__(self, *retrievers):
self.retrievers: list[RAGRetriever] = retrievers
self.service_context = service_context
super().__init__()
async def _aretrieve(self, query: QueryType, **kwargs):
"""
Asynchronously retrieves and aggregates search results from all configured retrievers.
"""Asynchronously retrieves and aggregates search results from all configured retrievers.
This method queries each retriever in the `retrievers` list with the given query and
additional keyword arguments. It then combines the results, ensuring that each node is
@ -25,7 +22,9 @@ class SimpleHybridRetriever(RAGRetriever):
"""
all_nodes = []
for retriever in self.retrievers:
nodes = await retriever.aretrieve(query, **kwargs)
# Guard against a retriever mutating attributes of the shared query object
query_copy = copy.deepcopy(query)
nodes = await retriever.aretrieve(query_copy, **kwargs)
all_nodes.extend(nodes)
# combine all nodes

View file

@ -1,36 +1,52 @@
"""RAG schemas"""
"""RAG schemas."""
from typing import Union
from typing import Any
from pydantic import BaseModel, Field
from llama_index.core.indices.base import BaseIndex
from pydantic import BaseModel, ConfigDict, Field
class RetrieverConfig(BaseModel):
"""Common config for retrievers."""
class BaseRetrieverConfig(BaseModel):
"""Common config for retrievers.
If you add a new subconfig, add the corresponding instance implementation in rag.factories.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
similarity_top_k: int = Field(default=5, description="Number of top-k similar results to return during retrieval.")
class FAISSRetrieverConfig(RetrieverConfig):
class IndexRetrieverConfig(BaseRetrieverConfig):
"""Config for Index-basd retrievers."""
index: BaseIndex = Field(default=None, description="Index for retriver")
class FAISSRetrieverConfig(IndexRetrieverConfig):
"""Config for FAISS-based retrievers."""
dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.")
class BM25RetrieverConfig(RetrieverConfig):
class BM25RetrieverConfig(IndexRetrieverConfig):
"""Config for BM25-based retrievers."""
class RankerConfig(BaseModel):
"""Common config for rankers."""
class BaseRankerConfig(BaseModel):
"""Common config for rankers.
top_n: int = 5
If you add a new subconfig, add the corresponding instance implementation in rag.factories.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
top_n: int = Field(default=5, description="The number of top results to return.")
class LLMRankerConfig(RankerConfig):
class LLMRankerConfig(BaseRankerConfig):
"""Config for LLM-based rankers."""
# If add new config, it is necessary to add the corresponding instance implementation in rag.factory
RetrieverConfigType = Union[FAISSRetrieverConfig, BM25RetrieverConfig]
RankerConfigType = LLMRankerConfig
llm: Any = Field(
default=None,
description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1",
)