mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-29 15:59:42 +02:00
exp pool support bm25
This commit is contained in:
parent
aa8e2fa8c3
commit
a6af03efb8
10 changed files with 231 additions and 50 deletions
|
|
@ -37,7 +37,11 @@ from metagpt.rag.factories import (
|
|||
get_retriever,
|
||||
)
|
||||
from metagpt.rag.interface import NoEmbedding, RAGObject
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.retrievers.base import (
|
||||
ModifiableRAGRetriever,
|
||||
PersistableRAGRetriever,
|
||||
QueryableRAGRetriever,
|
||||
)
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BaseIndexConfig,
|
||||
|
|
@ -144,7 +148,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
|
||||
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
|
||||
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
nodes = cls.get_obj_nodes(objs)
|
||||
|
||||
return cls._from_nodes(
|
||||
nodes=nodes,
|
||||
|
|
@ -201,7 +205,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
|
||||
self._ensure_retriever_modifiable()
|
||||
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
nodes = self.get_obj_nodes(objs)
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def persist(self, persist_dir: Union[str, os.PathLike], **kwargs):
|
||||
|
|
@ -210,6 +214,18 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
|
||||
self._persist(str(persist_dir), **kwargs)
|
||||
|
||||
def count(self) -> int:
    """Return the total count reported by the underlying retriever.

    Calls ``_ensure_retriever_queryable()`` first, then delegates to the
    retriever's ``query_total_count()``.
    """
    self._ensure_retriever_queryable()
    total = self._retriever.query_total_count()
    return total
|
||||
|
||||
@staticmethod
|
||||
def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]:
|
||||
"""Converts a list of RAGObjects to a list of ObjectNodes."""
|
||||
|
||||
return [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
|
||||
@classmethod
|
||||
def _from_nodes(
|
||||
cls,
|
||||
|
|
@ -258,6 +274,9 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
def _ensure_retriever_persistable(self):
    """Check the retriever against ``PersistableRAGRetriever`` via ``_ensure_retriever_of_type``."""
    required = PersistableRAGRetriever
    self._ensure_retriever_of_type(required)
|
||||
|
||||
def _ensure_retriever_queryable(self):
    """Check the retriever against ``QueryableRAGRetriever`` via ``_ensure_retriever_of_type``."""
    required = QueryableRAGRetriever
    self._ensure_retriever_of_type(required)
|
||||
|
||||
def _ensure_retriever_of_type(self, required_type: BaseRetriever):
|
||||
"""Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever.
|
||||
|
||||
|
|
|
|||
|
|
@ -84,6 +84,8 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
    """Build a DynamicBM25Retriever from ``config``.

    When ``_extract_index`` yields an index, the nodes are read from its
    docstore and the index is stored on ``config.index`` if that field is not
    already set; otherwise the nodes come from ``_extract_nodes``.
    """
    index = self._extract_index(config, **kwargs)

    if index:
        nodes = list(index.docstore.docs.values())
        # Remember the extracted index on the config so it is forwarded through model_dump().
        if not config.index:
            config.index = index
    else:
        nodes = self._extract_nodes(config, **kwargs)

    retriever_kwargs = config.model_dump()
    return DynamicBM25Retriever(nodes=nodes, **retriever_kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -45,3 +45,17 @@ class PersistableRAGRetriever(RAGRetriever):
|
|||
@abstractmethod
def persist(self, persist_dir: str, **kwargs) -> None:
    """Persist the retriever to ``persist_dir``.

    Subclasses must implement this method to support persistence.
    """
|
||||
|
||||
|
||||
class QueryableRAGRetriever(RAGRetriever):
    """Interface for retrievers that support querying their total count."""

    @classmethod
    def __subclasshook__(cls, C):
        # Only answer structurally for this exact class; for anything else
        # defer to the default subclass check.
        if cls is not QueryableRAGRetriever:
            return NotImplemented
        return check_methods(C, "query_total_count")

    @abstractmethod
    def query_total_count(self) -> int:
        """Return the total count; subclasses must implement this to be queryable."""
|
||||
|
|
|
|||
|
|
@ -47,3 +47,8 @@ class DynamicBM25Retriever(BM25Retriever):
|
|||
"""Support persist."""
|
||||
if self._index:
|
||||
self._index.storage_context.persist(persist_dir)
|
||||
|
||||
def query_total_count(self) -> int:
    """Return the number of entries in ``self._nodes``."""
    node_total = len(self._nodes)
    return node_total
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
class ChromaRetriever(VectorIndexRetriever):
|
||||
|
|
@ -15,3 +16,10 @@ class ChromaRetriever(VectorIndexRetriever):
|
|||
"""Support persist.
|
||||
|
||||
Chromadb automatically saves, so there is no need to implement."""
|
||||
|
||||
def query_total_count(self) -> int:
    """Return ``count()`` of the collection held by the retriever's vector store."""
    # self._vector_store is annotated as a ChromaVectorStore in the original code.
    collection = self._vector_store._collection
    return collection.count()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue