add exp_pool test

This commit is contained in:
seehi 2024-06-04 22:08:40 +08:00
parent 96cd6b5f64
commit d600cc47f4
2 changed files with 80 additions and 2 deletions

View file

@ -1,7 +1,7 @@
"""Experience schema."""
from enum import Enum
from typing import Optional
from typing import Any, Optional
from llama_index.core.schema import TextNode
from pydantic import BaseModel, Field
@ -38,13 +38,14 @@ class Trajectory(BaseModel):
plan: str = Field(default="", description="The plan.")
action: str = Field(default="", description="Action for the plan.")
observation: str = Field(default="", description="Output of the action.")
reward: int = Field(default=0, description="Measure the action.")
class Experience(BaseModel):
"""Experience."""
req: str = Field(..., description="")
resp: str = Field(..., description="The type is string/json/code.")
resp: Any = Field(..., description="The type is string/json/code.")
metric: Optional[Metric] = Field(default=None, description="Metric.")
exp_type: ExperienceType = Field(default=ExperienceType.SUCCESS, description="The type of experience.")
entry_type: EntryType = Field(default=EntryType.AUTOMATIC, description="Type of entry: Manual or Automatic.")

View file

@ -0,0 +1,77 @@
import pytest
from metagpt.config2 import Config
from metagpt.configs.exp_pool_config import ExperiencePoolConfig
from metagpt.configs.llm_config import LLMConfig
from metagpt.exp_pool.manager import ExperienceManager
from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric
from metagpt.rag.engines import SimpleEngine
class TestExperienceManager:
@pytest.fixture
def mock_config(self):
return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True))
@pytest.fixture
def mock_storage(self, mocker):
engine = mocker.MagicMock(spec=SimpleEngine)
engine.add_objs = mocker.MagicMock()
engine.aretrieve = mocker.AsyncMock(return_value=[])
return engine
@pytest.fixture
def mock_experience_manager(self, mock_config, mock_storage):
return ExperienceManager(config=mock_config, storage=mock_storage)
@pytest.fixture
def mock_experience(self):
return Experience(req="req", resp="resp")
def test_initialize_storage(self, mock_experience_manager, mock_storage):
assert mock_experience_manager.storage is mock_storage
def test_create_exp(self, mock_experience_manager, mock_experience):
mock_experience_manager.create_exp(mock_experience)
mock_experience_manager.storage.add_objs.assert_called_once_with([mock_experience])
def test_create_exp_write_disabled(self, mock_experience_manager, mock_experience, mock_config):
mock_config.exp_pool.enable_write = False
mock_experience_manager.create_exp(mock_experience)
mock_experience_manager.storage.add_objs.assert_not_called()
@pytest.mark.asyncio
async def test_query_exps(self, mock_experience_manager, mocker):
req = "req"
resp = "resp"
tag = "test"
experiences = [Experience(req=req, resp=resp, tag="test"), Experience(req=req, resp=resp, tag="other")]
mock_experience_manager.storage.aretrieve.return_value = [
mocker.MagicMock(metadata={"obj": exp}) for exp in experiences
]
result = await mock_experience_manager.query_exps(req, tag)
assert len(result) == 1
assert result[0].tag == "test"
@pytest.mark.asyncio
async def test_query_exps_no_read_permission(self, mock_experience_manager, mock_config):
mock_config.exp_pool.enable_read = False
result = await mock_experience_manager.query_exps("query")
assert result == []
def test_extract_one_perfect_exp(self, mock_experience_manager):
experiences = [
Experience(req="req", resp="resp", metric=Metric(score=MAX_SCORE)),
Experience(req="req", resp="resp"),
]
perfect_exp: Experience = mock_experience_manager.extract_one_perfect_exp(experiences)
assert perfect_exp is not None
assert perfect_exp.metric.score == MAX_SCORE
def test_is_perfect_exp(self):
exp = Experience(req="req", resp="resp", metric=Metric(score=MAX_SCORE))
assert ExperienceManager.is_perfect_exp(exp) == True
exp = Experience(req="req", resp="resp")
assert ExperienceManager.is_perfect_exp(exp) == False