update exp_pool manager

This commit is contained in:
seehi 2024-08-19 14:07:13 +08:00
parent 243c7a65d6
commit 665ca6ff97
9 changed files with 31 additions and 22 deletions

View file

@ -5,7 +5,7 @@ from pydantic import Field
from metagpt.utils.yaml_model import YamlModel
class ExperiencePoolStorageType(Enum):
class ExperiencePoolRetrievalType(Enum):
BM25 = "bm25"
CHROMA = "chroma"
@ -18,6 +18,6 @@ class ExperiencePoolConfig(YamlModel):
enable_read: bool = Field(default=False, description="Enable to read from experience pool.")
enable_write: bool = Field(default=False, description="Enable to write to experience pool.")
persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.")
storage_type: ExperiencePoolStorageType = Field(
default=ExperiencePoolStorageType.BM25, description="The storage type for experience pool."
retrieval_type: ExperiencePoolRetrievalType = Field(
default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for experience pool."
)

View file

@ -132,14 +132,14 @@ class ExpCacheHandler(BaseModel):
"""Fetch experiences by query_type."""
self._exps = await self.exp_manager.query_exps(self._req, query_type=self.query_type, tag=self.tag)
logger.debug(f"Found {len(self._exps)} experiences for req '{self._req[:20]}...' and tag '{self.tag}'")
logger.info(f"Found {len(self._exps)} experiences for tag '{self.tag}'")
async def get_one_perfect_exp(self) -> Optional[Any]:
"""Get a potentially perfect experience, and resolve resp."""
for exp in self._exps:
if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs):
logger.debug(f"Got one perfect experience for req '{exp.req[:20]}...'")
logger.info(f"Got one perfect experience for req '{exp.req[:20]}...'")
return self.serializer.deserialize_resp(exp.resp)
return None

View file

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, ConfigDict, Field
from metagpt.config2 import Config
from metagpt.configs.exp_pool_config import ExperiencePoolStorageType
from metagpt.configs.exp_pool_config import ExperiencePoolRetrievalType
from metagpt.exp_pool.schema import (
DEFAULT_COLLECTION_NAME,
DEFAULT_SIMILARITY_TOP_K,
@ -95,11 +95,11 @@ class ExperienceManager(BaseModel):
"""Selects the appropriate storage creation method based on the configured storage type."""
storage_creators = {
ExperiencePoolStorageType.BM25: self._create_bm25_storage,
ExperiencePoolStorageType.CHROMA: self._create_chroma_storage,
ExperiencePoolRetrievalType.BM25: self._create_bm25_storage,
ExperiencePoolRetrievalType.CHROMA: self._create_chroma_storage,
}
return storage_creators[self.config.exp_pool.storage_type]()
return storage_creators[self.config.exp_pool.retrieval_type]()
def _create_bm25_storage(self) -> "SimpleEngine":
"""Creates or loads BM25 storage.
@ -116,8 +116,6 @@ class ExperienceManager(BaseModel):
"""
try:
from llama_index.core import VectorStoreIndex
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import (
BM25IndexConfig,
@ -133,11 +131,8 @@ class ExperienceManager(BaseModel):
if not docstore_path.exists():
logger.debug(f"Path `{docstore_path}` not exists, try to create a new bm25 storage.")
exps = [Experience(req="req", resp="resp")]
nodes = SimpleEngine.get_obj_nodes(exps)
embed_model = SimpleEngine._resolve_embed_model(configs=[BM25RetrieverConfig()])
index = VectorStoreIndex(nodes, embed_model=embed_model)
retriever_configs = [BM25RetrieverConfig(index=index)]
retriever_configs = [BM25RetrieverConfig(create_index=True)]
ranker_configs = [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)]
storage = SimpleEngine.from_objs(

View file

@ -7,6 +7,7 @@ import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.schema import BaseNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
@ -84,9 +85,13 @@ class RetrieverFactory(ConfigBasedFactory):
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
    """Build a `DynamicBM25Retriever` from `config` and any extra keyword arguments.

    Nodes come from the docstore of the index returned by `_extract_index`
    when there is one; otherwise from `_extract_nodes`. As a side effect,
    `config.index` may be filled in (see inline comments).

    Args:
        config: BM25 retriever configuration; its `index` attribute may be mutated.
        **kwargs: Forwarded unchanged to `_extract_index` / `_extract_nodes`.

    Returns:
        A `DynamicBM25Retriever` built from the resolved nodes and `config.model_dump()`.
    """
    resolved_index = self._extract_index(config, **kwargs)

    if resolved_index:
        # Reuse the nodes already stored in the index, and remember the index
        # on the config if the config does not carry one yet.
        bm25_nodes = list(resolved_index.docstore.docs.values())
        if not config.index:
            config.index = resolved_index
    else:
        bm25_nodes = self._extract_nodes(config, **kwargs)

    needs_new_index = not config.index and config.create_index
    if needs_new_index:
        # NOTE(review): a 1-dimensional MockEmbedding is used, presumably because
        # BM25 needs no real embeddings and the index exists only so the data
        # can be persisted — confirm against the `create_index` field description.
        config.index = VectorStoreIndex(bm25_nodes, embed_model=MockEmbedding(embed_dim=1))

    return DynamicBM25Retriever(nodes=bm25_nodes, **config.model_dump())
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:

View file

@ -60,6 +60,11 @@ class FAISSRetrieverConfig(IndexRetrieverConfig):
class BM25RetrieverConfig(IndexRetrieverConfig):
"""Config for BM25-based retrievers."""
create_index: bool = Field(
default=False,
description="Indicates whether to create an index for the nodes. It is useful when you need to persist data while only using BM25.",
exclude=True,
)
_no_embedding: bool = PrivateAttr(default=True)