exp pool support bm25

This commit is contained in:
seehi 2024-08-16 21:13:08 +08:00
parent aa8e2fa8c3
commit a6af03efb8
10 changed files with 231 additions and 50 deletions

View file

@ -37,7 +37,11 @@ from metagpt.rag.factories import (
get_retriever,
)
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.base import (
ModifiableRAGRetriever,
PersistableRAGRetriever,
QueryableRAGRetriever,
)
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BaseIndexConfig,
@ -144,7 +148,7 @@ class SimpleEngine(RetrieverQueryEngine):
if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
nodes = cls.get_obj_nodes(objs)
return cls._from_nodes(
nodes=nodes,
@ -201,7 +205,7 @@ class SimpleEngine(RetrieverQueryEngine):
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
self._ensure_retriever_modifiable()
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
nodes = self.get_obj_nodes(objs)
self._save_nodes(nodes)
def persist(self, persist_dir: Union[str, os.PathLike], **kwargs):
@ -210,6 +214,18 @@ class SimpleEngine(RetrieverQueryEngine):
self._persist(str(persist_dir), **kwargs)
def count(self) -> int:
    """Return the total count reported by the underlying retriever.

    The retriever is first checked via ``_ensure_retriever_queryable``; the
    value returned is whatever ``self._retriever.query_total_count()`` yields.
    """
    # Guard first: the call below is only meaningful on a queryable retriever.
    self._ensure_retriever_queryable()
    total = self._retriever.query_total_count()
    return total
@staticmethod
def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]:
    """Convert a list of RAGObjects into a list of ObjectNodes.

    Each node's text is the object's ``rag_key()`` and its metadata is
    ``ObjectNode.get_obj_metadata(obj)``.

    Args:
        objs: Objects to convert. ``None`` (the default) is treated as an
            empty list.

    Returns:
        One ObjectNode per input object, in input order; an empty list when
        no objects are given.
    """
    # The signature advertises None as the default, so iterating it must not raise.
    return [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs or []]
@classmethod
def _from_nodes(
cls,
@ -258,6 +274,9 @@ class SimpleEngine(RetrieverQueryEngine):
def _ensure_retriever_persistable(self):
    """Check the retriever against ``PersistableRAGRetriever``.

    Delegates to ``_ensure_retriever_of_type``; how a mismatch is reported is
    decided there.
    """
    required_type = PersistableRAGRetriever
    self._ensure_retriever_of_type(required_type)
def _ensure_retriever_queryable(self):
    """Check the retriever against ``QueryableRAGRetriever``.

    Delegates to ``_ensure_retriever_of_type``; how a mismatch is reported is
    decided there.
    """
    required_type = QueryableRAGRetriever
    self._ensure_retriever_of_type(required_type)
def _ensure_retriever_of_type(self, required_type: BaseRetriever):
"""Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever.

View file

@ -84,6 +84,8 @@ class RetrieverFactory(ConfigBasedFactory):
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
    """Build a DynamicBM25Retriever from ``config``.

    Nodes are taken from the extracted index's docstore when an index is
    available; otherwise they come from ``_extract_nodes``.

    Note: when an index was extracted and ``config.index`` is not already set,
    the index is written onto ``config`` (the caller's object is mutated)
    before the config is dumped into the retriever's keyword arguments.
    """
    extracted_index = self._extract_index(config, **kwargs)

    if extracted_index:
        bm25_nodes = list(extracted_index.docstore.docs.values())
        # Remember the index on the config unless one was supplied explicitly.
        if not config.index:
            config.index = extracted_index
    else:
        bm25_nodes = self._extract_nodes(config, **kwargs)

    retriever_kwargs = config.model_dump()
    return DynamicBM25Retriever(nodes=bm25_nodes, **retriever_kwargs)

View file

@ -45,3 +45,17 @@ class PersistableRAGRetriever(RAGRetriever):
@abstractmethod
def persist(self, persist_dir: str, **kwargs) -> None:
"""To support persist, must inplement this func"""
class QueryableRAGRetriever(RAGRetriever):
    """Retriever interface for reporting a total count.

    Subclass checks made directly against this class are decided by
    ``check_methods(C, "query_total_count")``.
    """

    @classmethod
    def __subclasshook__(cls, C):
        # Only answer for this exact class; subclasses fall back to the default machinery.
        if cls is not QueryableRAGRetriever:
            return NotImplemented
        return check_methods(C, "query_total_count")

    @abstractmethod
    def query_total_count(self) -> int:
        """Return the total count; implement this to support count queries."""

View file

@ -47,3 +47,8 @@ class DynamicBM25Retriever(BM25Retriever):
"""Support persist."""
if self._index:
self._index.storage_context.persist(persist_dir)
def query_total_count(self) -> int:
    """Return the number of entries in ``self._nodes``."""
    held_nodes = self._nodes
    return len(held_nodes)

View file

@ -2,6 +2,7 @@
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
from llama_index.vector_stores.chroma import ChromaVectorStore
class ChromaRetriever(VectorIndexRetriever):
@ -15,3 +16,10 @@ class ChromaRetriever(VectorIndexRetriever):
"""Support persist.
Chromadb automatically saves, so there is no need to implement."""
def query_total_count(self) -> int:
    """Return ``count()`` of the collection held by the vector store.

    Reads the vector store's private ``_collection`` attribute directly.
    """
    collection = self._vector_store._collection
    return collection.count()