mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
update exp_pool manager
This commit is contained in:
parent
243c7a65d6
commit
665ca6ff97
9 changed files with 31 additions and 22 deletions
|
|
@ -79,7 +79,7 @@ exp_pool:
|
|||
enable_read: false
|
||||
enable_write: false
|
||||
persist_path: .chroma_exp_data # The directory.
|
||||
storage_type: bm25 # Default is `bm25`, can be set to `chroma` for vector storage, which requires setting up embedding.
|
||||
retrieval_type: bm25 # Default is `bm25`, can be set to `chroma` for vector storage, which requires setting up embedding.
|
||||
|
||||
azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY"
|
||||
azure_tts_region: "eastus"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from pydantic import Field
|
|||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class ExperiencePoolStorageType(Enum):
|
||||
class ExperiencePoolRetrievalType(Enum):
|
||||
BM25 = "bm25"
|
||||
CHROMA = "chroma"
|
||||
|
||||
|
|
@ -18,6 +18,6 @@ class ExperiencePoolConfig(YamlModel):
|
|||
enable_read: bool = Field(default=False, description="Enable to read from experience pool.")
|
||||
enable_write: bool = Field(default=False, description="Enable to write to experience pool.")
|
||||
persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.")
|
||||
storage_type: ExperiencePoolStorageType = Field(
|
||||
default=ExperiencePoolStorageType.BM25, description="The storage type for experience pool."
|
||||
retrieval_type: ExperiencePoolRetrievalType = Field(
|
||||
default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for experience pool."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -132,14 +132,14 @@ class ExpCacheHandler(BaseModel):
|
|||
"""Fetch experiences by query_type."""
|
||||
|
||||
self._exps = await self.exp_manager.query_exps(self._req, query_type=self.query_type, tag=self.tag)
|
||||
logger.debug(f"Found {len(self._exps)} experiences for req '{self._req[:20]}...' and tag '{self.tag}'")
|
||||
logger.info(f"Found {len(self._exps)} experiences for tag '{self.tag}'")
|
||||
|
||||
async def get_one_perfect_exp(self) -> Optional[Any]:
|
||||
"""Get a potentially perfect experience, and resolve resp."""
|
||||
|
||||
for exp in self._exps:
|
||||
if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs):
|
||||
logger.debug(f"Got one perfect experience for req '{exp.req[:20]}...'")
|
||||
logger.info(f"Got one perfect experience for req '{exp.req[:20]}...'")
|
||||
return self.serializer.deserialize_resp(exp.resp)
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any
|
|||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.exp_pool_config import ExperiencePoolStorageType
|
||||
from metagpt.configs.exp_pool_config import ExperiencePoolRetrievalType
|
||||
from metagpt.exp_pool.schema import (
|
||||
DEFAULT_COLLECTION_NAME,
|
||||
DEFAULT_SIMILARITY_TOP_K,
|
||||
|
|
@ -95,11 +95,11 @@ class ExperienceManager(BaseModel):
|
|||
"""Selects the appropriate storage creation method based on the configured storage type."""
|
||||
|
||||
storage_creators = {
|
||||
ExperiencePoolStorageType.BM25: self._create_bm25_storage,
|
||||
ExperiencePoolStorageType.CHROMA: self._create_chroma_storage,
|
||||
ExperiencePoolRetrievalType.BM25: self._create_bm25_storage,
|
||||
ExperiencePoolRetrievalType.CHROMA: self._create_chroma_storage,
|
||||
}
|
||||
|
||||
return storage_creators[self.config.exp_pool.storage_type]()
|
||||
return storage_creators[self.config.exp_pool.retrieval_type]()
|
||||
|
||||
def _create_bm25_storage(self) -> "SimpleEngine":
|
||||
"""Creates or loads BM25 storage.
|
||||
|
|
@ -116,8 +116,6 @@ class ExperienceManager(BaseModel):
|
|||
"""
|
||||
|
||||
try:
|
||||
from llama_index.core import VectorStoreIndex
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.schema import (
|
||||
BM25IndexConfig,
|
||||
|
|
@ -133,11 +131,8 @@ class ExperienceManager(BaseModel):
|
|||
if not docstore_path.exists():
|
||||
logger.debug(f"Path `{docstore_path}` not exists, try to create a new bm25 storage.")
|
||||
exps = [Experience(req="req", resp="resp")]
|
||||
nodes = SimpleEngine.get_obj_nodes(exps)
|
||||
embed_model = SimpleEngine._resolve_embed_model(configs=[BM25RetrieverConfig()])
|
||||
index = VectorStoreIndex(nodes, embed_model=embed_model)
|
||||
|
||||
retriever_configs = [BM25RetrieverConfig(index=index)]
|
||||
retriever_configs = [BM25RetrieverConfig(create_index=True)]
|
||||
ranker_configs = [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)]
|
||||
|
||||
storage = SimpleEngine.from_objs(
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import chromadb
|
|||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
|
|
@ -84,9 +85,13 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
|
||||
index = self._extract_index(config, **kwargs)
|
||||
nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs)
|
||||
|
||||
if index and not config.index:
|
||||
config.index = index
|
||||
|
||||
if not config.index and config.create_index:
|
||||
config.index = VectorStoreIndex(nodes, embed_model=MockEmbedding(embed_dim=1))
|
||||
|
||||
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
|
||||
|
||||
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
|
||||
|
|
|
|||
|
|
@ -60,6 +60,11 @@ class FAISSRetrieverConfig(IndexRetrieverConfig):
|
|||
class BM25RetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for BM25-based retrievers."""
|
||||
|
||||
create_index: bool = Field(
|
||||
default=False,
|
||||
description="Indicates whether to create an index for the nodes. It is useful when you need to persist data while only using BM25.",
|
||||
exclude=True,
|
||||
)
|
||||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import pytest
|
|||
from metagpt.config2 import Config
|
||||
from metagpt.configs.exp_pool_config import (
|
||||
ExperiencePoolConfig,
|
||||
ExperiencePoolStorageType,
|
||||
ExperiencePoolRetrievalType,
|
||||
)
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.exp_pool.manager import Experience, ExperienceManager
|
||||
|
|
@ -16,7 +16,7 @@ class TestExperienceManager:
|
|||
return Config(
|
||||
llm=LLMConfig(),
|
||||
exp_pool=ExperiencePoolConfig(
|
||||
enable_write=True, enable_read=True, enabled=True, storage_type=ExperiencePoolStorageType.BM25
|
||||
enable_write=True, enable_read=True, enabled=True, retrieval_type=ExperiencePoolRetrievalType.BM25
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -96,7 +96,7 @@ class TestExperienceManager:
|
|||
assert exp_manager.get_exps_count() == 10
|
||||
|
||||
def test_resolve_storage_bm25(self, mocker, mock_config):
|
||||
mock_config.exp_pool.storage_type = ExperiencePoolStorageType.BM25
|
||||
mock_config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.BM25
|
||||
mocker.patch.object(ExperienceManager, "_create_bm25_storage", return_value=mocker.MagicMock())
|
||||
manager = ExperienceManager(config=mock_config)
|
||||
storage = manager._resolve_storage()
|
||||
|
|
@ -104,7 +104,7 @@ class TestExperienceManager:
|
|||
assert storage is not None
|
||||
|
||||
def test_resolve_storage_chroma(self, mocker, mock_config):
|
||||
mock_config.exp_pool.storage_type = ExperiencePoolStorageType.CHROMA
|
||||
mock_config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.CHROMA
|
||||
mocker.patch.object(ExperienceManager, "_create_chroma_storage", return_value=mocker.MagicMock())
|
||||
manager = ExperienceManager(config=mock_config)
|
||||
storage = manager._resolve_storage()
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ class TestSimpleEngine:
|
|||
)
|
||||
|
||||
# Assert
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files, fs=None)
|
||||
mock_get_retriever.assert_called_once()
|
||||
mock_get_rankers.assert_called_once()
|
||||
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
|
||||
|
|
@ -12,7 +13,10 @@ class TestRAGEmbeddingFactory:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_config(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
config = Config.default().model_copy(deep=True)
|
||||
default = mocker.patch("metagpt.config2.Config.default")
|
||||
default.return_value = config
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
def mock_openai_embedding(mocker):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue