Merge branch 'feat-exp-pool-bm25' into 'mgx_ops'

Feat exp pool bm25

See merge request pub/MetaGPT!327
This commit is contained in:
王金淋 2024-08-19 06:45:03 +00:00
commit c6cc5e2da3
14 changed files with 244 additions and 54 deletions

View file

@ -79,6 +79,7 @@ exp_pool:
enable_read: false
enable_write: false
persist_path: .chroma_exp_data # The directory where experience-pool data is persisted.
retrieval_type: bm25 # Default is `bm25`, can be set to `chroma` for vector storage, which requires setting up embedding.
azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY"
azure_tts_region: "eastus"

View file

@ -3,7 +3,7 @@ # Experience Pool
## Prerequisites
- Ensure the RAG module is installed: https://docs.deepwisdom.ai/main/en/guide/in_depth_guides/rag_module.html
- Set embedding: https://docs.deepwisdom.ai/main/en/guide/in_depth_guides/rag_module.html
- Set both `enable_read` and `enable_write` to `true` in the `exp_pool` section of `config2.yaml`
- Set `enabled`, `enable_read` and `enable_write` to `true` in the `exp_pool` section of `config2.yaml`
## Example Files

View file

@ -1,8 +1,15 @@
from enum import Enum
from pydantic import Field
from metagpt.utils.yaml_model import YamlModel
class ExperiencePoolRetrievalType(Enum):
    """Supported retrieval backends for the experience pool."""

    BM25 = "bm25"  # keyword-based retrieval; the default, needs no embedding setup
    CHROMA = "chroma"  # vector retrieval backed by Chroma; requires embedding configuration
class ExperiencePoolConfig(YamlModel):
enabled: bool = Field(
default=False,
@ -11,3 +18,6 @@ class ExperiencePoolConfig(YamlModel):
enable_read: bool = Field(default=False, description="Enable to read from experience pool.")
enable_write: bool = Field(default=False, description="Enable to write to experience pool.")
persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.")
retrieval_type: ExperiencePoolRetrievalType = Field(
default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for experience pool."
)

View file

@ -134,14 +134,14 @@ class ExpCacheHandler(BaseModel):
"""Fetch experiences by query_type."""
self._exps = await self.exp_manager.query_exps(self._req, query_type=self.query_type, tag=self.tag)
logger.debug(f"Found {len(self._exps)} experiences for req '{self._req[:20]}...' and tag '{self.tag}'")
logger.info(f"Found {len(self._exps)} experiences for tag '{self.tag}'")
async def get_one_perfect_exp(self) -> Optional[Any]:
    """Get a potentially perfect experience, and resolve resp.

    Iterates the previously fetched experiences and asks the perfect-exp
    judge whether each one matches the current request; the first match
    wins and its stored response is deserialized and returned.

    Returns:
        The deserialized response of a perfect experience, or None when
        no experience qualifies.
    """
    for exp in self._exps:
        if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs):
            # Log the hit once; the duplicate debug-level line with the same
            # message was redundant.
            logger.info(f"Got one perfect experience for req '{exp.req[:20]}...'")
            return self.serializer.deserialize_resp(exp.resp)
    return None

View file

@ -1,10 +1,12 @@
"""Experience Manager."""
from pathlib import Path
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, ConfigDict, Field
from metagpt.config2 import Config
from metagpt.configs.exp_pool_config import ExperiencePoolRetrievalType
from metagpt.exp_pool.schema import (
DEFAULT_COLLECTION_NAME,
DEFAULT_SIMILARITY_TOP_K,
@ -15,7 +17,7 @@ from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
if TYPE_CHECKING:
from llama_index.vector_stores.chroma import ChromaVectorStore
from metagpt.rag.engines import SimpleEngine
class ExperienceManager(BaseModel):
@ -32,40 +34,16 @@ class ExperienceManager(BaseModel):
config: Config = Field(default_factory=Config.default)
_storage: Any = None
_vector_store: Any = None
@property
def storage(self):
    """Lazily build and cache the underlying storage engine.

    The engine is created on first access via `_resolve_storage`, which
    picks the implementation (BM25 or Chroma) based on the configured
    `exp_pool.retrieval_type`.
    """
    if self._storage is None:
        # The previous inline Chroma construction ran first and was then
        # immediately overwritten by _resolve_storage(): wasted work that
        # also ignored the configured retrieval_type. Only the resolver
        # path is kept.
        logger.info(f"exp_pool config: {self.config.exp_pool}")
        self._storage = self._resolve_storage()
    return self._storage
@property
def vector_store(self):
    """Lazily expose the vector store behind the storage engine.

    NOTE(review): reaches into private attributes of the retriever
    (`storage._retriever._vector_store`) — assumes a single Chroma-backed
    retriever; verify against how the engine is constructed.
    """
    if not self._vector_store:
        self._vector_store: ChromaVectorStore = self.storage._retriever._vector_store
    return self._vector_store
@handle_exception
def create_exp(self, exp: Experience):
"""Adds an experience to the storage if writing is enabled.
@ -78,6 +56,7 @@ class ExperienceManager(BaseModel):
return
self.storage.add_objs([exp])
self.storage.persist(self.config.exp_pool.persist_path)
@handle_exception(default_return=[])
async def query_exps(self, req: str, tag: str = "", query_type: QueryType = QueryType.SEMANTIC) -> list[Experience]:
@ -110,7 +89,92 @@ class ExperienceManager(BaseModel):
def get_exps_count(self) -> int:
    """Get the total number of experiences.

    Returns:
        int: count reported by the storage engine.
    """
    # Delegate to the engine's public count(); the stacked second return
    # that read the Chroma collection directly was unreachable and only
    # worked for vector storage anyway.
    return self.storage.count()
def _resolve_storage(self) -> "SimpleEngine":
    """Dispatch to the storage factory matching the configured retrieval type."""
    factories = {
        ExperiencePoolRetrievalType.BM25: self._create_bm25_storage,
        ExperiencePoolRetrievalType.CHROMA: self._create_chroma_storage,
    }
    create = factories[self.config.exp_pool.retrieval_type]
    return create()
def _create_bm25_storage(self) -> "SimpleEngine":
    """Create or load BM25 storage.

    Loads an existing BM25 index when the persisted docstore file is found;
    otherwise bootstraps a fresh engine seeded with a placeholder experience.

    Returns:
        SimpleEngine: An instance of SimpleEngine configured with BM25 storage.

    Raises:
        ImportError: If required modules are not installed.
    """
    try:
        from metagpt.rag.engines import SimpleEngine
        from metagpt.rag.schema import (
            BM25IndexConfig,
            BM25RetrieverConfig,
            LLMRankerConfig,
        )
    except ImportError:
        raise ImportError("To use the experience pool, you need to install the rag module.")

    persist_path = Path(self.config.exp_pool.persist_path)
    docstore_path = persist_path / "docstore.json"

    if docstore_path.exists():
        # A persisted docstore is present: rebuild the engine from disk.
        logger.debug(f"Path `{docstore_path}` exists, try to load bm25 storage.")
        return SimpleEngine.from_index(
            BM25IndexConfig(persist_path=persist_path), retriever_configs=[BM25RetrieverConfig()]
        )

    # No persisted index yet: seed a brand-new one with a placeholder pair.
    logger.debug(f"Path `{docstore_path}` not exists, try to create a new bm25 storage.")
    seed_exps = [Experience(req="req", resp="resp")]
    return SimpleEngine.from_objs(
        objs=seed_exps,
        retriever_configs=[BM25RetrieverConfig(create_index=True)],
        ranker_configs=[LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)],
    )
def _create_chroma_storage(self) -> "SimpleEngine":
    """Create Chroma-backed storage.

    Returns:
        SimpleEngine: An instance of SimpleEngine configured with Chroma storage.

    Raises:
        ImportError: If required modules are not installed.
    """
    try:
        from metagpt.rag.engines import SimpleEngine
        from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
    except ImportError:
        raise ImportError("To use the experience pool, you need to install the rag module.")

    chroma_config = ChromaRetrieverConfig(
        persist_path=self.config.exp_pool.persist_path,
        collection_name=DEFAULT_COLLECTION_NAME,
        similarity_top_k=DEFAULT_SIMILARITY_TOP_K,
    )
    ranker_config = LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)
    return SimpleEngine.from_objs(retriever_configs=[chroma_config], ranker_configs=[ranker_config])
_exp_manager = None

View file

@ -37,7 +37,11 @@ from metagpt.rag.factories import (
get_retriever,
)
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.base import (
ModifiableRAGRetriever,
PersistableRAGRetriever,
QueryableRAGRetriever,
)
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BaseIndexConfig,
@ -144,7 +148,7 @@ class SimpleEngine(RetrieverQueryEngine):
if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
nodes = cls.get_obj_nodes(objs)
return cls._from_nodes(
nodes=nodes,
@ -201,7 +205,7 @@ class SimpleEngine(RetrieverQueryEngine):
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
self._ensure_retriever_modifiable()
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
nodes = self.get_obj_nodes(objs)
self._save_nodes(nodes)
def persist(self, persist_dir: Union[str, os.PathLike], **kwargs):
@ -210,6 +214,18 @@ class SimpleEngine(RetrieverQueryEngine):
self._persist(str(persist_dir), **kwargs)
def count(self) -> int:
    """Return the total number of items held by the retriever."""
    self._ensure_retriever_queryable()
    total = self._retriever.query_total_count()
    return total
@staticmethod
def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]:
    """Converts a list of RAGObjects to a list of ObjectNodes.

    Args:
        objs: Objects to convert; None (the default) yields an empty list.

    Returns:
        One ObjectNode per object, carrying the object's rag_key as text and
        its serialized form in metadata for later retrieval.
    """
    if objs is None:
        # Guard: the declared default of None previously crashed with
        # TypeError when iterated.
        return []
    return [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
@classmethod
def _from_nodes(
cls,
@ -258,6 +274,9 @@ class SimpleEngine(RetrieverQueryEngine):
def _ensure_retriever_persistable(self):
self._ensure_retriever_of_type(PersistableRAGRetriever)
def _ensure_retriever_queryable(self):
    # Raise unless the retriever (or one of its components) supports query_total_count.
    self._ensure_retriever_of_type(QueryableRAGRetriever)
def _ensure_retriever_of_type(self, required_type: BaseRetriever):
"""Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever.

View file

@ -7,6 +7,7 @@ import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.schema import BaseNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
@ -85,6 +86,12 @@ class RetrieverFactory(ConfigBasedFactory):
index = self._extract_index(config, **kwargs)
nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs)
if index and not config.index:
config.index = index
if not config.index and config.create_index:
config.index = VectorStoreIndex(nodes, embed_model=MockEmbedding(embed_dim=1))
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:

View file

@ -45,3 +45,17 @@ class PersistableRAGRetriever(RAGRetriever):
@abstractmethod
def persist(self, persist_dir: str, **kwargs) -> None:
    """To support persist, must implement this func"""
class QueryableRAGRetriever(RAGRetriever):
    """Support querying total count."""

    @classmethod
    def __subclasshook__(cls, C):
        # Duck-typed ABC: any class that provides query_total_count is
        # treated as a (virtual) subclass via isinstance/issubclass.
        if cls is QueryableRAGRetriever:
            return check_methods(C, "query_total_count")
        return NotImplemented

    @abstractmethod
    def query_total_count(self) -> int:
        """To support querying total count, must implement this func"""

View file

@ -47,3 +47,8 @@ class DynamicBM25Retriever(BM25Retriever):
"""Support persist."""
if self._index:
self._index.storage_context.persist(persist_dir)
def query_total_count(self) -> int:
    """Support query total count."""
    # The in-memory node list is the authoritative count for this retriever.
    return len(self._nodes)

View file

@ -2,6 +2,7 @@
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
from llama_index.vector_stores.chroma import ChromaVectorStore
class ChromaRetriever(VectorIndexRetriever):
@ -15,3 +16,10 @@ class ChromaRetriever(VectorIndexRetriever):
"""Support persist.
Chromadb automatically saves, so there is no need to implement."""
def query_total_count(self) -> int:
    """Support query total count."""
    # NOTE(review): relies on ChromaVectorStore's private `_collection`
    # attribute — verify this still holds on llama-index upgrades.
    vector_store: ChromaVectorStore = self._vector_store
    return vector_store._collection.count()

View file

@ -60,6 +60,11 @@ class FAISSRetrieverConfig(IndexRetrieverConfig):
class BM25RetrieverConfig(IndexRetrieverConfig):
    """Config for BM25-based retrievers."""

    # When True, an index is still built for the nodes so data can be
    # persisted even though BM25 itself is keyword-based; excluded from
    # serialization.
    create_index: bool = Field(
        default=False,
        description="Indicates whether to create an index for the nodes. It is useful when you need to persist data while only using BM25.",
        exclude=True,
    )
    # BM25 operates on raw text, so embedding setup is skipped for this config.
    _no_embedding: bool = PrivateAttr(default=True)

View file

@ -1,7 +1,10 @@
import pytest
from metagpt.config2 import Config
from metagpt.configs.exp_pool_config import ExperiencePoolConfig
from metagpt.configs.exp_pool_config import (
ExperiencePoolConfig,
ExperiencePoolRetrievalType,
)
from metagpt.configs.llm_config import LLMConfig
from metagpt.exp_pool.manager import Experience, ExperienceManager
from metagpt.exp_pool.schema import QueryType
@ -10,17 +13,19 @@ from metagpt.exp_pool.schema import QueryType
class TestExperienceManager:
@pytest.fixture
def mock_config(self):
    """Config with the experience pool fully enabled and BM25 retrieval pinned."""
    # A stale duplicate one-line `return Config(...)` made the body below
    # unreachable; only the version that pins retrieval_type is kept.
    return Config(
        llm=LLMConfig(),
        exp_pool=ExperiencePoolConfig(
            enable_write=True, enable_read=True, enabled=True, retrieval_type=ExperiencePoolRetrievalType.BM25
        ),
    )
@pytest.fixture
def mock_storage(self, mocker):
    """Engine stub exposing the methods ExperienceManager relies on."""
    engine = mocker.MagicMock()
    engine.add_objs = mocker.MagicMock()
    engine.aretrieve = mocker.AsyncMock(return_value=[])
    # count() is the public counting API; the stale mocks of the private
    # _retriever._vector_store._collection chain were redundant and dropped.
    engine.count = mocker.MagicMock(return_value=10)
    return engine
@pytest.fixture
@ -29,8 +34,33 @@ class TestExperienceManager:
manager._storage = mock_storage
return manager
def test_vector_store_property(self, exp_manager):
assert exp_manager.vector_store == exp_manager.storage._retriever._vector_store
def test_storage_property(self, exp_manager, mock_storage):
    # The fixture injects mock_storage into _storage; the property must return it unchanged.
    assert exp_manager.storage == mock_storage
def test_storage_property_initialization(self, mocker, mock_config):
    """storage must stay unset until first accessed, then be cached."""
    mocker.patch.object(ExperienceManager, "_resolve_storage", return_value=mocker.MagicMock())
    mgr = ExperienceManager(config=mock_config)
    assert mgr._storage is None
    _ = mgr.storage
    assert mgr._storage is not None
def test_create_exp_write_disabled(self, exp_manager, mock_config):
    """When writing is disabled, create_exp must not touch storage."""
    mock_config.exp_pool.enable_write = False
    experience = Experience(req="test", resp="response")
    exp_manager.create_exp(experience)
    exp_manager.storage.add_objs.assert_not_called()
def test_create_exp_write_enabled(self, exp_manager):
    """With writing enabled, the experience is stored and persisted exactly once."""
    experience = Experience(req="test", resp="response")
    exp_manager.create_exp(experience)
    exp_manager.storage.add_objs.assert_called_once_with([experience])
    exp_manager.storage.persist.assert_called_once_with(exp_manager.config.exp_pool.persist_path)
@pytest.mark.asyncio
async def test_query_exps_read_disabled(self, exp_manager, mock_config):
    """With reading disabled, query_exps short-circuits to an empty list."""
    mock_config.exp_pool.enable_read = False
    assert await exp_manager.query_exps("query") == []
@pytest.mark.asyncio
async def test_query_exps_with_exact_match(self, exp_manager, mocker):
@ -65,14 +95,37 @@ class TestExperienceManager:
def test_get_exps_count(self, exp_manager):
    # mock_storage fixes count() at 10, so the manager must report exactly that.
    assert exp_manager.get_exps_count() == 10
def test_create_exp_write_disabled(self, exp_manager, mock_config):
mock_config.exp_pool.enable_write = False
exp = Experience(req="test", resp="response")
exp_manager.create_exp(exp)
exp_manager.storage.add_objs.assert_not_called()
def test_resolve_storage_bm25(self, mocker, mock_config):
    """BM25 retrieval type must route to _create_bm25_storage."""
    mock_config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.BM25
    mocker.patch.object(ExperienceManager, "_create_bm25_storage", return_value=mocker.MagicMock())
    mgr = ExperienceManager(config=mock_config)
    resolved = mgr._resolve_storage()
    mgr._create_bm25_storage.assert_called_once()
    assert resolved is not None
@pytest.mark.asyncio
async def test_query_exps_read_disabled(self, exp_manager, mock_config):
mock_config.exp_pool.enable_read = False
result = await exp_manager.query_exps("query")
assert result == []
def test_resolve_storage_chroma(self, mocker, mock_config):
    """CHROMA retrieval type must route to _create_chroma_storage."""
    mock_config.exp_pool.retrieval_type = ExperiencePoolRetrievalType.CHROMA
    mocker.patch.object(ExperienceManager, "_create_chroma_storage", return_value=mocker.MagicMock())
    mgr = ExperienceManager(config=mock_config)
    resolved = mgr._resolve_storage()
    mgr._create_chroma_storage.assert_called_once()
    assert resolved is not None
def test_create_bm25_storage(self, mocker, mock_config):
    """A fresh (non-persisted) path must still yield a usable BM25 engine."""
    # No docstore on disk, so the from_objs (bootstrap) branch is exercised.
    mocker.patch("pathlib.Path.exists", return_value=False)
    mocker.patch("metagpt.rag.engines.SimpleEngine.from_objs", return_value=mocker.MagicMock())
    mocker.patch("metagpt.rag.engines.SimpleEngine.from_index", return_value=mocker.MagicMock())
    mocker.patch("metagpt.rag.engines.SimpleEngine.get_obj_nodes", return_value=[])
    mocker.patch("metagpt.rag.engines.SimpleEngine._resolve_embed_model", return_value=mocker.MagicMock())
    mocker.patch("llama_index.core.VectorStoreIndex", return_value=mocker.MagicMock())
    mocker.patch("metagpt.rag.schema.BM25RetrieverConfig", return_value=mocker.MagicMock())
    mgr = ExperienceManager(config=mock_config)
    assert mgr._create_bm25_storage() is not None
def test_create_chroma_storage(self, mocker, mock_config):
    """Chroma storage creation must hand back whatever from_objs builds."""
    mocker.patch("metagpt.rag.engines.SimpleEngine.from_objs", return_value=mocker.MagicMock())
    mgr = ExperienceManager(config=mock_config)
    assert mgr._create_chroma_storage() is not None

View file

@ -75,7 +75,7 @@ class TestSimpleEngine:
)
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files, fs=None)
mock_get_retriever.assert_called_once()
mock_get_rankers.assert_called_once()
mock_get_response_synthesizer.assert_called_once_with(llm=llm)

View file

@ -1,5 +1,6 @@
import pytest
from metagpt.config2 import Config
from metagpt.configs.embedding_config import EmbeddingType
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
@ -12,7 +13,10 @@ class TestRAGEmbeddingFactory:
@pytest.fixture
def mock_config(self, mocker):
    """Patch Config.default to hand out one deep copy that tests can mutate safely."""
    # A stale duplicate `return mocker.patch(...)` line made the deep-copy
    # setup below unreachable; only the copy-based version is kept.
    config = Config.default().model_copy(deep=True)
    default = mocker.patch("metagpt.config2.Config.default")
    default.return_value = config
    return config
@staticmethod
def mock_openai_embedding(mocker):