mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-05 05:42:37 +02:00
Merge branch 'geekan:main' into main
This commit is contained in:
commit
9e4e32e7c7
33 changed files with 321 additions and 140 deletions
|
|
@ -4,7 +4,7 @@ import json
|
|||
import os
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
|
|
@ -63,7 +63,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
response_synthesizer: Optional[BaseSynthesizer] = None,
|
||||
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
index: Optional[BaseIndex] = None,
|
||||
transformations: Optional[list[TransformComponent]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
retriever=retriever,
|
||||
|
|
@ -71,7 +71,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
node_postprocessors=node_postprocessors,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
self.index = index
|
||||
self._transformations = transformations or self._default_transformations()
|
||||
|
||||
@classmethod
|
||||
def from_docs(
|
||||
|
|
@ -103,12 +103,17 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
|
||||
cls._fix_document_metadata(documents)
|
||||
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents=documents,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
|
||||
transformations = transformations or cls._default_transformations()
|
||||
nodes = run_transformations(documents, transformations=transformations)
|
||||
|
||||
return cls._from_nodes(
|
||||
nodes=nodes,
|
||||
transformations=transformations,
|
||||
embed_model=embed_model,
|
||||
llm=llm,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def from_objs(
|
||||
|
|
@ -137,12 +142,15 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
|
||||
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
index = VectorStoreIndex(
|
||||
|
||||
return cls._from_nodes(
|
||||
nodes=nodes,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
|
||||
transformations=transformations,
|
||||
embed_model=embed_model,
|
||||
llm=llm,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def from_index(
|
||||
|
|
@ -183,7 +191,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
documents = SimpleDirectoryReader(input_files=input_files).load_data()
|
||||
self._fix_document_metadata(documents)
|
||||
|
||||
nodes = run_transformations(documents, transformations=self.index._transformations)
|
||||
nodes = run_transformations(documents, transformations=self._transformations)
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def add_objs(self, objs: list[RAGObject]):
|
||||
|
|
@ -199,6 +207,29 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
|
||||
self._persist(str(persist_dir), **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _from_nodes(
|
||||
cls,
|
||||
nodes: list[BaseNode],
|
||||
transformations: Optional[list[TransformComponent]] = None,
|
||||
embed_model: BaseEmbedding = None,
|
||||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
embed_model = cls._resolve_embed_model(embed_model, retriever_configs)
|
||||
llm = llm or get_rag_llm()
|
||||
|
||||
retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model)
|
||||
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
|
||||
|
||||
return cls(
|
||||
retriever=retriever,
|
||||
node_postprocessors=rankers,
|
||||
response_synthesizer=get_response_synthesizer(llm=llm),
|
||||
transformations=transformations,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_index(
|
||||
cls,
|
||||
|
|
@ -208,6 +239,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
llm = llm or get_rag_llm()
|
||||
|
||||
retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
|
||||
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
|
||||
|
||||
|
|
@ -215,7 +247,6 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
retriever=retriever,
|
||||
node_postprocessors=rankers,
|
||||
response_synthesizer=get_response_synthesizer(llm=llm),
|
||||
index=index,
|
||||
)
|
||||
|
||||
def _ensure_retriever_modifiable(self):
|
||||
|
|
@ -266,3 +297,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
return embed_model or get_rag_embedding()
|
||||
|
||||
@staticmethod
|
||||
def _default_transformations():
|
||||
return [SentenceSplitter()]
|
||||
|
|
|
|||
|
|
@ -36,19 +36,26 @@ class ConfigBasedFactory(GenericFactory):
|
|||
"""Designed to get objects based on object type."""
|
||||
|
||||
def get_instance(self, key: Any, **kwargs) -> Any:
|
||||
"""Key is config, such as a pydantic model.
|
||||
"""Get instance by the type of key.
|
||||
|
||||
Call func by the type of key, and the key will be passed to func.
|
||||
Key is config, such as a pydantic model, call func by the type of key, and the key will be passed to func.
|
||||
Raise Exception if key not found.
|
||||
"""
|
||||
creator = self._creators.get(type(key))
|
||||
if creator:
|
||||
return creator(key, **kwargs)
|
||||
|
||||
self._raise_for_key(key)
|
||||
|
||||
def _raise_for_key(self, key: Any):
|
||||
raise ValueError(f"Unknown config: `{type(key)}`, {key}")
|
||||
|
||||
@staticmethod
|
||||
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
|
||||
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
|
||||
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.
|
||||
|
||||
Return None if not found.
|
||||
"""
|
||||
if config is not None and hasattr(config, key):
|
||||
val = getattr(config, key)
|
||||
if val is not None:
|
||||
|
|
@ -57,6 +64,4 @@ class ConfigBasedFactory(GenericFactory):
|
|||
if key in kwargs:
|
||||
return kwargs[key]
|
||||
|
||||
raise KeyError(
|
||||
f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
|
||||
)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
"""RAG Retriever Factory."""
|
||||
|
||||
import copy
|
||||
|
||||
from functools import wraps
|
||||
|
||||
import chromadb
|
||||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
|
|
@ -24,10 +27,25 @@ from metagpt.rag.schema import (
|
|||
ElasticsearchKeywordRetrieverConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
IndexRetrieverConfig,
|
||||
)
|
||||
|
||||
|
||||
def get_or_build_index(build_index_func):
|
||||
"""Decorator to get or build an index.
|
||||
|
||||
Get index using `_extract_index` method, if not found, using build_index_func.
|
||||
"""
|
||||
|
||||
@wraps(build_index_func)
|
||||
def wrapper(self, config, **kwargs):
|
||||
index = self._extract_index(config, **kwargs)
|
||||
if index is not None:
|
||||
return index
|
||||
return build_index_func(self, config, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class RetrieverFactory(ConfigBasedFactory):
|
||||
"""Modify creators for dynamically instance implementation."""
|
||||
|
||||
|
|
@ -54,48 +72,79 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]
|
||||
|
||||
def _create_default(self, **kwargs) -> RAGRetriever:
|
||||
return self._extract_index(**kwargs).as_retriever()
|
||||
index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs)
|
||||
|
||||
return index.as_retriever()
|
||||
|
||||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
config.index = self._build_faiss_index(config, **kwargs)
|
||||
|
||||
return FAISSRetriever(**config.model_dump())
|
||||
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
|
||||
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
|
||||
index = self._extract_index(config, **kwargs)
|
||||
nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs)
|
||||
|
||||
return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
|
||||
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
|
||||
|
||||
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
|
||||
db = chromadb.PersistentClient(path=str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
|
||||
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
config.index = self._build_chroma_index(config, **kwargs)
|
||||
|
||||
return ChromaRetriever(**config.model_dump())
|
||||
|
||||
def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
|
||||
vector_store = ElasticsearchStore(**config.store_config.model_dump())
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
config.index = self._build_es_index(config, **kwargs)
|
||||
|
||||
return ElasticsearchRetriever(**config.model_dump())
|
||||
|
||||
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
|
||||
return self._val_from_config_or_kwargs("index", config, **kwargs)
|
||||
|
||||
def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[BaseNode]:
|
||||
return self._val_from_config_or_kwargs("nodes", config, **kwargs)
|
||||
|
||||
def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding:
|
||||
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
|
||||
|
||||
def _build_default_index(self, **kwargs) -> VectorStoreIndex:
|
||||
index = VectorStoreIndex(
|
||||
nodes=self._extract_nodes(**kwargs),
|
||||
embed_model=self._extract_embed_model(**kwargs),
|
||||
)
|
||||
|
||||
return index
|
||||
|
||||
@get_or_build_index
|
||||
def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
@get_or_build_index
|
||||
def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
db = chromadb.PersistentClient(path=str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
@get_or_build_index
|
||||
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
vector_store = ElasticsearchStore(**config.store_config.model_dump())
|
||||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
def _build_index_from_vector_store(
|
||||
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
|
||||
self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
|
||||
) -> VectorStoreIndex:
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
old_index = self._extract_index(config, **kwargs)
|
||||
new_index = VectorStoreIndex(
|
||||
nodes=list(old_index.docstore.docs.values()),
|
||||
index = VectorStoreIndex(
|
||||
nodes=self._extract_nodes(config, **kwargs),
|
||||
storage_context=storage_context,
|
||||
embed_model=old_index._embed_model,
|
||||
embed_model=self._extract_embed_model(config, **kwargs),
|
||||
)
|
||||
return new_index
|
||||
|
||||
return index
|
||||
|
||||
|
||||
get_retriever = RetrieverFactory().get_retriever
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue