add BM25IndexConfig

This commit is contained in:
seehi 2024-03-14 17:44:37 +08:00
parent d29ebc91cc
commit 6c95e601a0
7 changed files with 84 additions and 23 deletions

View file

@ -61,10 +61,10 @@ class RAGExample:
self._print_title("RAG Pipeline")
nodes = await self.engine.aretrieve(question)
self._print_result(nodes, state="Retrieve")
self._print_retrieve_result(nodes)
answer = await self.engine.aquery(question)
self._print_result(answer, state="Query")
self._print_query_result(answer)
async def rag_add_docs(self):
"""This example show how to add docs, before add docs llm anwser I don't know, after add docs llm give the correct answer, will print something like:
@ -160,28 +160,32 @@ class RAGExample:
# query
answer = engine.query(TRAVEL_QUESTION)
self._print_result(answer, state="Query")
self._print_query_result(answer)
@staticmethod
def _print_title(title):
logger.info(f"{'#'*30} {title} {'#'*30}")
@staticmethod
def _print_result(result, state="Retrieve"):
"""print retrieve or query result"""
logger.info(f"{state} Result:")
def _print_retrieve_result(result):
"""Print retrieve result."""
logger.info("Retrieve Result:")
if state == "Retrieve":
for i, node in enumerate(result):
logger.info(f"{i}. {node.text[:10]}..., {node.score}")
logger.info("")
return
for i, node in enumerate(result):
logger.info(f"{i}. {node.text[:10]}..., {node.score}")
logger.info("")
@staticmethod
def _print_query_result(result):
"""Print query result."""
logger.info("Query Result:")
logger.info(f"{result}\n")
async def _retrieve_and_print(self, question):
nodes = await self.engine.aretrieve(question)
self._print_result(nodes, state="Retrieve")
self._print_retrieve_result(nodes)
return nodes

View file

@ -1,7 +1,8 @@
"""Simple Engine."""
import json
from typing import Optional
import os
from typing import Optional, Union
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.callbacks.base import CallbackManager
@ -128,8 +129,8 @@ class SimpleEngine(RetrieverQueryEngine):
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
ranker_configs: Configuration for rankers.
"""
if not retriever_configs or any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
raise ValueError("Must provide retriever_configs, and BM25RetrieverConfig is not supported.")
if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
objs = objs or []
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
@ -182,11 +183,11 @@ class SimpleEngine(RetrieverQueryEngine):
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
self._save_nodes(nodes)
def persist(self, persist_dir: str, **kwargs):
def persist(self, persist_dir: Union[str, os.PathLike], **kwargs):
"""Persist."""
self._ensure_retriever_persistable()
self._persist(persist_dir, **kwargs)
self._persist(str(persist_dir), **kwargs)
@classmethod
def _from_index(

View file

@ -7,7 +7,12 @@ from llama_index.core.indices.base import BaseIndex
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.factories.base import ConfigFactory
from metagpt.rag.schema import BaseIndexConfig, ChromaIndexConfig, FAISSIndexConfig
from metagpt.rag.schema import (
BaseIndexConfig,
BM25IndexConfig,
ChromaIndexConfig,
FAISSIndexConfig,
)
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
@ -16,6 +21,7 @@ class RAGIndexFactory(ConfigFactory):
creators = {
FAISSIndexConfig: self._create_faiss,
ChromaIndexConfig: self._create_chroma,
BM25IndexConfig: self._create_bm25,
}
super().__init__(creators)
@ -46,5 +52,12 @@ class RAGIndexFactory(ConfigFactory):
)
return index
def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)
storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
return index
get_index = RAGIndexFactory().get_index

View file

@ -1,5 +1,7 @@
"""RAG Retriever Factory."""
import copy
import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
@ -69,8 +71,9 @@ class RetrieverFactory(ConfigFactory):
return FAISSRetriever(**config.model_dump())
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
config.index = self._extract_index(config, **kwargs)
return DynamicBM25Retriever.from_defaults(**config.model_dump())
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
nodes = list(config.index.docstore.docs.values())
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
db = chromadb.PersistentClient(path=str(config.persist_path))

View file

@ -1,6 +1,10 @@
"""BM25 retriever."""
from typing import Callable, Optional
from llama_index.core.schema import BaseNode
from llama_index.core import VectorStoreIndex
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.core.schema import BaseNode, IndexNode
from llama_index.retrievers.bm25 import BM25Retriever
from rank_bm25 import BM25Okapi
@ -8,8 +12,36 @@ from rank_bm25 import BM25Okapi
class DynamicBM25Retriever(BM25Retriever):
"""BM25 retriever."""
def __init__(
self,
nodes: list[BaseNode],
tokenizer: Optional[Callable[[str], list[str]]] = None,
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
callback_manager: Optional[CallbackManager] = None,
objects: Optional[list[IndexNode]] = None,
object_map: Optional[dict] = None,
verbose: bool = False,
index: VectorStoreIndex = None,
) -> None:
super().__init__(
nodes=nodes,
tokenizer=tokenizer,
similarity_top_k=similarity_top_k,
callback_manager=callback_manager,
object_map=object_map,
objects=objects,
verbose=verbose,
)
self._index = index
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Support add nodes"""
"""Support add nodes."""
self._nodes.extend(nodes)
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
self.bm25 = BM25Okapi(self._corpus)
self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist."""
self._index.storage_context.persist(persist_dir)

View file

@ -89,6 +89,10 @@ class ChromaIndexConfig(VectorIndexConfig):
collection_name: str = Field(default="metagpt", description="The name of the collection.")
class BM25IndexConfig(BaseIndexConfig):
"""Config for bm25-based index."""
class ObjectNodeMetadata(BaseModel):
"""Metadata of ObjectNode."""

View file

@ -1,4 +1,5 @@
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import Node
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
@ -14,13 +15,16 @@ class TestDynamicBM25Retriever:
self.doc2.get_content.return_value = "Document content 2"
self.mock_nodes = [self.doc1, self.doc2]
# 模拟index
index = mocker.MagicMock(spec=VectorStoreIndex)
# 模拟nodes和tokenizer参数
mock_nodes = []
mock_tokenizer = mocker.MagicMock()
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
# 初始化DynamicBM25Retriever对象并提供必需的参数
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer)
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer, index=index)
def test_add_docs_updates_nodes_and_corpus(self):
# Execute