exp_pool add enabled

This commit is contained in:
seehi 2024-08-09 16:05:48 +08:00
parent dc685c000f
commit 36db2b067e
7 changed files with 12 additions and 6 deletions

View file

@ -75,6 +75,7 @@ s3:
bucket: "test"
exp_pool:
enabled: false
enable_read: false
enable_write: false
persist_path: .chroma_exp_data # The directory.

View file

@ -4,6 +4,10 @@ from metagpt.utils.yaml_model import YamlModel
class ExperiencePoolConfig(YamlModel):
enabled: bool = Field(
default=False,
description="Flag to enable or disable the experience pool. When disabled, both reading and writing are ineffective.",
)
enable_read: bool = Field(default=False, description="Enable to read from experience pool.")
enable_write: bool = Field(default=False, description="Enable to write to experience pool.")
persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.")

View file

@ -50,7 +50,7 @@ def exp_cache(
"""
def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
if not config.exp_pool.enable_read:
if not config.exp_pool.enabled:
return func
@functools.wraps(func)

View file

@ -74,7 +74,7 @@ class ExperienceManager(BaseModel):
exp (Experience): The experience to add.
"""
if not self.config.exp_pool.enable_write:
if not self.config.exp_pool.enabled or not self.config.exp_pool.enable_write:
return
self.storage.add_objs([exp])
@ -92,7 +92,7 @@ class ExperienceManager(BaseModel):
list[Experience]: A list of experiences that match the args.
"""
if not self.config.exp_pool.enable_read:
if not self.config.exp_pool.enabled or not self.config.exp_pool.enable_read:
return []
nodes = await self.storage.aretrieve(req)

View file

@ -1,5 +1,5 @@
"""Experience schema."""
import time
from enum import Enum
from typing import Optional
@ -67,6 +67,7 @@ class Experience(BaseModel):
entry_type: EntryType = Field(default=EntryType.AUTOMATIC, description="Type of entry: Manual or Automatic.")
tag: str = Field(default="", description="Tagging experience.")
traj: Optional[Trajectory] = Field(default=None, description="Trajectory.")
timestamp: Optional[float] = Field(default_factory=time.time)
def rag_key(self):
return self.req

View file

@ -159,7 +159,7 @@ class TestExpCache:
@pytest.mark.asyncio
async def test_exp_cache_disabled(self, mock_config, mock_exp_manager):
mock_config.exp_pool.enable_read = False
mock_config.exp_pool.enabled = False
@exp_cache(manager=mock_exp_manager)
async def test_func(req):

View file

@ -10,7 +10,7 @@ from metagpt.exp_pool.schema import QueryType
class TestExperienceManager:
@pytest.fixture
def mock_config(self):
return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True))
return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True, enabled=True))
@pytest.fixture
def mock_storage(self, mocker):