mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
add BM25IndexConfig
This commit is contained in:
parent
d29ebc91cc
commit
6c95e601a0
7 changed files with 84 additions and 23 deletions
|
|
@ -61,10 +61,10 @@ class RAGExample:
|
|||
self._print_title("RAG Pipeline")
|
||||
|
||||
nodes = await self.engine.aretrieve(question)
|
||||
self._print_result(nodes, state="Retrieve")
|
||||
self._print_retrieve_result(nodes)
|
||||
|
||||
answer = await self.engine.aquery(question)
|
||||
self._print_result(answer, state="Query")
|
||||
self._print_query_result(answer)
|
||||
|
||||
async def rag_add_docs(self):
|
||||
"""This example shows how to add docs: before the docs are added the LLM answers "I don't know"; after they are added the LLM gives the correct answer. It will print something like:
|
||||
|
|
@ -160,28 +160,32 @@ class RAGExample:
|
|||
|
||||
# query
|
||||
answer = engine.query(TRAVEL_QUESTION)
|
||||
self._print_result(answer, state="Query")
|
||||
self._print_query_result(answer)
|
||||
|
||||
@staticmethod
|
||||
def _print_title(title):
|
||||
logger.info(f"{'#'*30} {title} {'#'*30}")
|
||||
|
||||
@staticmethod
|
||||
def _print_result(result, state="Retrieve"):
|
||||
"""print retrieve or query result"""
|
||||
logger.info(f"{state} Result:")
|
||||
def _print_retrieve_result(result):
|
||||
"""Print retrieve result."""
|
||||
logger.info("Retrieve Result:")
|
||||
|
||||
if state == "Retrieve":
|
||||
for i, node in enumerate(result):
|
||||
logger.info(f"{i}. {node.text[:10]}..., {node.score}")
|
||||
logger.info("")
|
||||
return
|
||||
for i, node in enumerate(result):
|
||||
logger.info(f"{i}. {node.text[:10]}..., {node.score}")
|
||||
|
||||
logger.info("")
|
||||
|
||||
@staticmethod
|
||||
def _print_query_result(result):
|
||||
"""Print query result."""
|
||||
logger.info("Query Result:")
|
||||
|
||||
logger.info(f"{result}\n")
|
||||
|
||||
async def _retrieve_and_print(self, question):
|
||||
nodes = await self.engine.aretrieve(question)
|
||||
self._print_result(nodes, state="Retrieve")
|
||||
self._print_retrieve_result(nodes)
|
||||
return nodes
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
"""Simple Engine."""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
|
|
@ -128,8 +129,8 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
retriever_configs: Configuration for retrievers. If more than one config is given, SimpleHybridRetriever is used.
|
||||
ranker_configs: Configuration for rankers.
|
||||
"""
|
||||
if not retriever_configs or any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
|
||||
raise ValueError("Must provide retriever_configs, and BM25RetrieverConfig is not supported.")
|
||||
if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
|
||||
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
|
||||
|
||||
objs = objs or []
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
|
|
@ -182,11 +183,11 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs):
|
||||
def persist(self, persist_dir: Union[str, os.PathLike], **kwargs):
|
||||
"""Persist."""
|
||||
self._ensure_retriever_persistable()
|
||||
|
||||
self._persist(persist_dir, **kwargs)
|
||||
self._persist(str(persist_dir), **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _from_index(
|
||||
|
|
|
|||
|
|
@ -7,7 +7,12 @@ from llama_index.core.indices.base import BaseIndex
|
|||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.factories.base import ConfigFactory
|
||||
from metagpt.rag.schema import BaseIndexConfig, ChromaIndexConfig, FAISSIndexConfig
|
||||
from metagpt.rag.schema import (
|
||||
BaseIndexConfig,
|
||||
BM25IndexConfig,
|
||||
ChromaIndexConfig,
|
||||
FAISSIndexConfig,
|
||||
)
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
|
|
@ -16,6 +21,7 @@ class RAGIndexFactory(ConfigFactory):
|
|||
creators = {
|
||||
FAISSIndexConfig: self._create_faiss,
|
||||
ChromaIndexConfig: self._create_chroma,
|
||||
BM25IndexConfig: self._create_bm25,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
|
|
@ -46,5 +52,12 @@ class RAGIndexFactory(ConfigFactory):
|
|||
)
|
||||
return index
|
||||
|
||||
def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
embed_model = self._extract_embed_model(config, **kwargs)
|
||||
|
||||
storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)
|
||||
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
|
||||
return index
|
||||
|
||||
|
||||
get_index = RAGIndexFactory().get_index
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
"""RAG Retriever Factory."""
|
||||
|
||||
import copy
|
||||
|
||||
import chromadb
|
||||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
|
|
@ -69,8 +71,9 @@ class RetrieverFactory(ConfigFactory):
|
|||
return FAISSRetriever(**config.model_dump())
|
||||
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
|
||||
config.index = self._extract_index(config, **kwargs)
|
||||
return DynamicBM25Retriever.from_defaults(**config.model_dump())
|
||||
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
|
||||
nodes = list(config.index.docstore.docs.values())
|
||||
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
|
||||
|
||||
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
|
||||
db = chromadb.PersistentClient(path=str(config.persist_path))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
"""BM25 retriever."""
|
||||
from typing import Callable, Optional
|
||||
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
|
||||
from llama_index.core.schema import BaseNode, IndexNode
|
||||
from llama_index.retrievers.bm25 import BM25Retriever
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
||||
|
|
@ -8,8 +12,36 @@ from rank_bm25 import BM25Okapi
|
|||
class DynamicBM25Retriever(BM25Retriever):
|
||||
"""BM25 retriever."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nodes: list[BaseNode],
|
||||
tokenizer: Optional[Callable[[str], list[str]]] = None,
|
||||
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
objects: Optional[list[IndexNode]] = None,
|
||||
object_map: Optional[dict] = None,
|
||||
verbose: bool = False,
|
||||
index: VectorStoreIndex = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
nodes=nodes,
|
||||
tokenizer=tokenizer,
|
||||
similarity_top_k=similarity_top_k,
|
||||
callback_manager=callback_manager,
|
||||
object_map=object_map,
|
||||
objects=objects,
|
||||
verbose=verbose,
|
||||
)
|
||||
self._index = index
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""Support add nodes"""
|
||||
"""Support add nodes."""
|
||||
self._nodes.extend(nodes)
|
||||
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
|
||||
self.bm25 = BM25Okapi(self._corpus)
|
||||
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
"""Support persist."""
|
||||
self._index.storage_context.persist(persist_dir)
|
||||
|
|
|
|||
|
|
@ -89,6 +89,10 @@ class ChromaIndexConfig(VectorIndexConfig):
|
|||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
|
||||
|
||||
class BM25IndexConfig(BaseIndexConfig):
|
||||
"""Config for bm25-based index."""
|
||||
|
||||
|
||||
class ObjectNodeMetadata(BaseModel):
|
||||
"""Metadata of ObjectNode."""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
|
|
@ -14,13 +15,16 @@ class TestDynamicBM25Retriever:
|
|||
self.doc2.get_content.return_value = "Document content 2"
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
# Mock the index
|
||||
index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
|
||||
# Mock the nodes and tokenizer arguments
|
||||
mock_nodes = []
|
||||
mock_tokenizer = mocker.MagicMock()
|
||||
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
|
||||
# Initialize the DynamicBM25Retriever object with the required arguments
|
||||
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer)
|
||||
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer, index=index)
|
||||
|
||||
def test_add_docs_updates_nodes_and_corpus(self):
|
||||
# Execute
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue