mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
Merge branch 'feat-exp-pool-bm25' into 'mgx_ops'
Feat exp pool bm25 See merge request pub/MetaGPT!327
This commit is contained in:
commit
c6cc5e2da3
14 changed files with 244 additions and 54 deletions
|
|
@ -1,8 +1,15 @@
|
|||
from enum import Enum
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class ExperiencePoolRetrievalType(Enum):
    """Retrieval backends that can be selected for the experience pool."""

    # BM25-based retrieval.
    BM25 = "bm25"
    # Chroma-based retrieval.
    CHROMA = "chroma"
|
||||
|
||||
|
||||
class ExperiencePoolConfig(YamlModel):
|
||||
enabled: bool = Field(
|
||||
default=False,
|
||||
|
|
@ -11,3 +18,6 @@ class ExperiencePoolConfig(YamlModel):
|
|||
enable_read: bool = Field(default=False, description="Enable to read from experience pool.")
|
||||
enable_write: bool = Field(default=False, description="Enable to write to experience pool.")
|
||||
persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.")
|
||||
retrieval_type: ExperiencePoolRetrievalType = Field(
|
||||
default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for experience pool."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -134,14 +134,14 @@ class ExpCacheHandler(BaseModel):
|
|||
"""Fetch experiences by query_type."""
|
||||
|
||||
self._exps = await self.exp_manager.query_exps(self._req, query_type=self.query_type, tag=self.tag)
|
||||
logger.debug(f"Found {len(self._exps)} experiences for req '{self._req[:20]}...' and tag '{self.tag}'")
|
||||
logger.info(f"Found {len(self._exps)} experiences for tag '{self.tag}'")
|
||||
|
||||
async def get_one_perfect_exp(self) -> Optional[Any]:
    """Get a potentially perfect experience, and resolve resp.

    Walks ``self._exps`` in order and asks ``self.exp_perfect_judge`` whether
    each experience is perfect for ``self._req`` (forwarding ``self.args`` and
    ``self.kwargs``).

    Returns:
        The result of ``self.serializer.deserialize_resp(exp.resp)`` for the
        first experience judged perfect, or ``None`` when none qualifies.
    """
    for exp in self._exps:
        if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs):
            # Log the hit once; the identical message used to be emitted at both debug and info level.
            logger.info(f"Got one perfect experience for req '{exp.req[:20]}...'")
            return self.serializer.deserialize_resp(exp.resp)

    return None
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
"""Experience Manager."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.exp_pool_config import ExperiencePoolRetrievalType
|
||||
from metagpt.exp_pool.schema import (
|
||||
DEFAULT_COLLECTION_NAME,
|
||||
DEFAULT_SIMILARITY_TOP_K,
|
||||
|
|
@ -15,7 +17,7 @@ from metagpt.logs import logger
|
|||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
|
||||
|
||||
class ExperienceManager(BaseModel):
|
||||
|
|
@ -32,40 +34,16 @@ class ExperienceManager(BaseModel):
|
|||
config: Config = Field(default_factory=Config.default)
|
||||
|
||||
_storage: Any = None
|
||||
_vector_store: Any = None
|
||||
|
||||
@property
def storage(self):
    """Return the storage engine, creating it on first access.

    The engine is produced by ``self._resolve_storage()`` and cached in
    ``self._storage``; later accesses return the cached instance.
    """
    if self._storage is None:
        logger.info(f"exp_pool config: {self.config.exp_pool}")

        # Build the engine exactly once through the dispatcher. The previous inline
        # Chroma construction created an engine that was overwritten right afterwards.
        self._storage = self._resolve_storage()

    return self._storage
|
||||
|
||||
@property
def vector_store(self):
    """Return the vector store taken from the storage engine's retriever, caching it on first use."""
    cached = self._vector_store
    if cached:
        return cached

    # Reads private attributes of the engine (``_retriever``) and of its retriever (``_vector_store``).
    self._vector_store: ChromaVectorStore = self.storage._retriever._vector_store
    return self._vector_store
|
||||
|
||||
@handle_exception
|
||||
def create_exp(self, exp: Experience):
|
||||
"""Adds an experience to the storage if writing is enabled.
|
||||
|
|
@ -78,6 +56,7 @@ class ExperienceManager(BaseModel):
|
|||
return
|
||||
|
||||
self.storage.add_objs([exp])
|
||||
self.storage.persist(self.config.exp_pool.persist_path)
|
||||
|
||||
@handle_exception(default_return=[])
|
||||
async def query_exps(self, req: str, tag: str = "", query_type: QueryType = QueryType.SEMANTIC) -> list[Experience]:
|
||||
|
|
@ -110,7 +89,92 @@ class ExperienceManager(BaseModel):
|
|||
def get_exps_count(self) -> int:
    """Get the total number of experiences.

    Returns:
        The value reported by ``self.storage.count()``.
    """
    # Ask the engine for the count instead of reading Chroma's private collection
    # through ``vector_store``; the engine-level call used to sit unreachable behind
    # an earlier ``return``.
    return self.storage.count()
|
||||
|
||||
def _resolve_storage(self) -> "SimpleEngine":
    """Build the storage engine that matches ``config.exp_pool.retrieval_type``.

    Returns:
        The engine produced by the creator method registered for the configured type.

    Raises:
        KeyError: If the configured retrieval type has no registered creator.
    """
    # Dispatch table: retrieval type -> bound factory method.
    creators = {
        ExperiencePoolRetrievalType.BM25: self._create_bm25_storage,
        ExperiencePoolRetrievalType.CHROMA: self._create_chroma_storage,
    }
    create = creators[self.config.exp_pool.retrieval_type]

    return create()
|
||||
|
||||
def _create_bm25_storage(self) -> "SimpleEngine":
    """Creates or loads BM25 storage.

    If ``<persist_path>/docstore.json`` exists, the engine is loaded from the
    persisted index. Otherwise a new engine is created from a single placeholder
    experience, with index creation enabled on the retriever config.

    Returns:
        SimpleEngine: An instance of SimpleEngine configured with BM25 storage.

    Raises:
        ImportError: If required modules are not installed.
    """
    try:
        from metagpt.rag.engines import SimpleEngine
        from metagpt.rag.schema import (
            BM25IndexConfig,
            BM25RetrieverConfig,
            LLMRankerConfig,
        )
    except ImportError as exc:
        # Chain explicitly so the underlying missing module stays visible in the traceback.
        raise ImportError("To use the experience pool, you need to install the rag module.") from exc

    persist_path = Path(self.config.exp_pool.persist_path)
    docstore_path = persist_path / "docstore.json"

    if docstore_path.exists():
        logger.debug(f"Path `{docstore_path}` exists, try to load bm25 storage.")
        # NOTE(review): unlike the creation branch below, no ranker config is passed here — confirm this is intended.
        return SimpleEngine.from_index(
            BM25IndexConfig(persist_path=persist_path), retriever_configs=[BM25RetrieverConfig()]
        )

    logger.debug(f"Path `{docstore_path}` not exists, try to create a new bm25 storage.")

    # Placeholder entry so the object list handed to the engine is not empty.
    exps = [Experience(req="req", resp="resp")]
    retriever_configs = [BM25RetrieverConfig(create_index=True)]
    ranker_configs = [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)]

    return SimpleEngine.from_objs(objs=exps, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
def _create_chroma_storage(self) -> "SimpleEngine":
    """Creates Chroma storage.

    The engine is built with one Chroma retriever config (configured persist path,
    ``DEFAULT_COLLECTION_NAME``, ``DEFAULT_SIMILARITY_TOP_K``) and one LLM ranker
    config whose ``top_n`` is ``DEFAULT_SIMILARITY_TOP_K``.

    Returns:
        SimpleEngine: An instance of SimpleEngine configured with Chroma storage.

    Raises:
        ImportError: If required modules are not installed.
    """
    try:
        from metagpt.rag.engines import SimpleEngine
        from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
    except ImportError as exc:
        # Chain explicitly so the underlying missing module stays visible in the traceback.
        raise ImportError("To use the experience pool, you need to install the rag module.") from exc

    chroma_config = ChromaRetrieverConfig(
        persist_path=self.config.exp_pool.persist_path,
        collection_name=DEFAULT_COLLECTION_NAME,
        similarity_top_k=DEFAULT_SIMILARITY_TOP_K,
    )
    ranker_config = LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)

    return SimpleEngine.from_objs(retriever_configs=[chroma_config], ranker_configs=[ranker_config])
|
||||
|
||||
|
||||
_exp_manager = None
|
||||
|
|
|
|||
|
|
@ -37,7 +37,11 @@ from metagpt.rag.factories import (
|
|||
get_retriever,
|
||||
)
|
||||
from metagpt.rag.interface import NoEmbedding, RAGObject
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.retrievers.base import (
|
||||
ModifiableRAGRetriever,
|
||||
PersistableRAGRetriever,
|
||||
QueryableRAGRetriever,
|
||||
)
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BaseIndexConfig,
|
||||
|
|
@ -144,7 +148,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
|
||||
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
|
||||
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
nodes = cls.get_obj_nodes(objs)
|
||||
|
||||
return cls._from_nodes(
|
||||
nodes=nodes,
|
||||
|
|
@ -201,7 +205,7 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
|
||||
self._ensure_retriever_modifiable()
|
||||
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
nodes = self.get_obj_nodes(objs)
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def persist(self, persist_dir: Union[str, os.PathLike], **kwargs):
|
||||
|
|
@ -210,6 +214,18 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
|
||||
self._persist(str(persist_dir), **kwargs)
|
||||
|
||||
def count(self) -> int:
    """Return the total count reported by the retriever.

    ``_ensure_retriever_queryable`` is called first, before the retriever's
    ``query_total_count`` is invoked.
    """
    self._ensure_retriever_queryable()
    total = self._retriever.query_total_count()
    return total
|
||||
|
||||
@staticmethod
def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]:
    """Converts a list of RAGObjects to a list of ObjectNodes.

    Args:
        objs: Objects to convert. ``None`` (the default) is treated as an empty list.

    Returns:
        One ObjectNode per object, with ``obj.rag_key()`` as its text and
        ``ObjectNode.get_obj_metadata(obj)`` as its metadata.
    """
    # The parameter defaults to None, so fall back to an empty list instead of iterating None.
    return [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs or []]
|
||||
|
||||
@classmethod
|
||||
def _from_nodes(
|
||||
cls,
|
||||
|
|
@ -258,6 +274,9 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
def _ensure_retriever_persistable(self):
    """Check the retriever against ``PersistableRAGRetriever`` via ``_ensure_retriever_of_type``."""
    self._ensure_retriever_of_type(PersistableRAGRetriever)
|
||||
|
||||
def _ensure_retriever_queryable(self):
    """Check the retriever against ``QueryableRAGRetriever`` via ``_ensure_retriever_of_type``."""
    self._ensure_retriever_of_type(QueryableRAGRetriever)
|
||||
|
||||
def _ensure_retriever_of_type(self, required_type: BaseRetriever):
|
||||
"""Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever.
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import chromadb
|
|||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
|
|
@ -85,6 +86,12 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
index = self._extract_index(config, **kwargs)
|
||||
nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs)
|
||||
|
||||
if index and not config.index:
|
||||
config.index = index
|
||||
|
||||
if not config.index and config.create_index:
|
||||
config.index = VectorStoreIndex(nodes, embed_model=MockEmbedding(embed_dim=1))
|
||||
|
||||
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
|
||||
|
||||
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
|
||||
|
|
|
|||
|
|
@ -45,3 +45,17 @@ class PersistableRAGRetriever(RAGRetriever):
|
|||
@abstractmethod
def persist(self, persist_dir: str, **kwargs) -> None:
    """To support persist, must implement this func.

    Args:
        persist_dir: Directory to persist to.
        **kwargs: Extra options for the concrete implementation.
    """
|
||||
|
||||
|
||||
class QueryableRAGRetriever(RAGRetriever):
    """Retriever interface for implementations that can report a total count."""

    @classmethod
    def __subclasshook__(cls, C):
        """Delegate the subclass check to ``check_methods`` for ``query_total_count``.

        Only applies when the check is made against this class itself; for
        subclasses it returns ``NotImplemented``.
        """
        if cls is not QueryableRAGRetriever:
            return NotImplemented
        return check_methods(C, "query_total_count")

    @abstractmethod
    def query_total_count(self) -> int:
        """To support querying total count, must implement this func."""
|
||||
|
|
|
|||
|
|
@ -47,3 +47,8 @@ class DynamicBM25Retriever(BM25Retriever):
|
|||
"""Support persist."""
|
||||
if self._index:
|
||||
self._index.storage_context.persist(persist_dir)
|
||||
|
||||
def query_total_count(self) -> int:
    """Return the number of entries in ``self._nodes``."""
    nodes = self._nodes
    return len(nodes)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
class ChromaRetriever(VectorIndexRetriever):
|
||||
|
|
@ -15,3 +16,10 @@ class ChromaRetriever(VectorIndexRetriever):
|
|||
"""Support persist.
|
||||
|
||||
Chromadb automatically saves, so there is no need to implement."""
|
||||
|
||||
def query_total_count(self) -> int:
    """Return the count reported by the collection behind ``self._vector_store``."""
    store: ChromaVectorStore = self._vector_store
    # ``_collection`` is a private attribute of the vector store.
    collection = store._collection
    return collection.count()
|
||||
|
|
|
|||
|
|
@ -60,6 +60,11 @@ class FAISSRetrieverConfig(IndexRetrieverConfig):
|
|||
class BM25RetrieverConfig(IndexRetrieverConfig):
    """Config for BM25-based retrievers."""

    # Declared with ``exclude=True``.
    create_index: bool = Field(
        description="Indicates whether to create an index for the nodes. It is useful when you need to persist data while only using BM25.",
        default=False,
        exclude=True,
    )
    _no_embedding: bool = PrivateAttr(default=True)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue