From 2eedc23a827acc892c4928bf14c7e1b99f081c59 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 4 Jun 2024 22:08:40 +0800 Subject: [PATCH] add exp_pool test --- metagpt/exp_pool/schema.py | 5 +- tests/metagpt/exp_pool/test_manager.py | 77 ++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 2 deletions(-) create mode 100644 tests/metagpt/exp_pool/test_manager.py diff --git a/metagpt/exp_pool/schema.py b/metagpt/exp_pool/schema.py index e6ae4ee1d..1afcc1508 100644 --- a/metagpt/exp_pool/schema.py +++ b/metagpt/exp_pool/schema.py @@ -1,7 +1,7 @@ """Experience schema.""" from enum import Enum -from typing import Optional +from typing import Any, Optional from llama_index.core.schema import TextNode from pydantic import BaseModel, Field @@ -38,13 +38,14 @@ class Trajectory(BaseModel): plan: str = Field(default="", description="The plan.") action: str = Field(default="", description="Action for the plan.") observation: str = Field(default="", description="Output of the action.") + reward: int = Field(default=0, description="Measure the action.") class Experience(BaseModel): """Experience.""" req: str = Field(..., description="") - resp: str = Field(..., description="The type is string/json/code.") + resp: Any = Field(..., description="The type is string/json/code.") metric: Optional[Metric] = Field(default=None, description="Metric.") exp_type: ExperienceType = Field(default=ExperienceType.SUCCESS, description="The type of experience.") entry_type: EntryType = Field(default=EntryType.AUTOMATIC, description="Type of entry: Manual or Automatic.") diff --git a/tests/metagpt/exp_pool/test_manager.py b/tests/metagpt/exp_pool/test_manager.py new file mode 100644 index 000000000..a0d7005f5 --- /dev/null +++ b/tests/metagpt/exp_pool/test_manager.py @@ -0,0 +1,77 @@ +import pytest + +from metagpt.config2 import Config +from metagpt.configs.exp_pool_config import ExperiencePoolConfig +from metagpt.configs.llm_config import LLMConfig +from metagpt.exp_pool.manager import ExperienceManager +from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric +from metagpt.rag.engines import SimpleEngine + + +class TestExperienceManager: + @pytest.fixture + def mock_config(self): + return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True)) + + @pytest.fixture + def mock_storage(self, mocker): + engine = mocker.MagicMock(spec=SimpleEngine) + engine.add_objs = mocker.MagicMock() + engine.aretrieve = mocker.AsyncMock(return_value=[]) + return engine + + @pytest.fixture + def mock_experience_manager(self, mock_config, mock_storage): + return ExperienceManager(config=mock_config, storage=mock_storage) + + @pytest.fixture + def mock_experience(self): + return Experience(req="req", resp="resp") + + def test_initialize_storage(self, mock_experience_manager, mock_storage): + assert mock_experience_manager.storage is mock_storage + + def test_create_exp(self, mock_experience_manager, mock_experience): + mock_experience_manager.create_exp(mock_experience) + mock_experience_manager.storage.add_objs.assert_called_once_with([mock_experience]) + + def test_create_exp_write_disabled(self, mock_experience_manager, mock_experience, mock_config): + mock_config.exp_pool.enable_write = False + mock_experience_manager.create_exp(mock_experience) + mock_experience_manager.storage.add_objs.assert_not_called() + + @pytest.mark.asyncio + async def test_query_exps(self, mock_experience_manager, mocker): + req = "req" + resp = "resp" + tag = "test" + experiences = [Experience(req=req, resp=resp, tag="test"), Experience(req=req, resp=resp, tag="other")] + mock_experience_manager.storage.aretrieve.return_value = [ + mocker.MagicMock(metadata={"obj": exp}) for exp in experiences + ] + + result = await mock_experience_manager.query_exps(req, tag) + assert len(result) == 1 + assert result[0].tag == "test" + + @pytest.mark.asyncio + async def test_query_exps_no_read_permission(self, mock_experience_manager, mock_config): + mock_config.exp_pool.enable_read = False + result = await mock_experience_manager.query_exps("query") + assert result == [] + + def test_extract_one_perfect_exp(self, mock_experience_manager): + experiences = [ + Experience(req="req", resp="resp", metric=Metric(score=MAX_SCORE)), + Experience(req="req", resp="resp"), + ] + perfect_exp: Experience = mock_experience_manager.extract_one_perfect_exp(experiences) + assert perfect_exp is not None + assert perfect_exp.metric.score == MAX_SCORE + + def test_is_perfect_exp(self): + exp = Experience(req="req", resp="resp", metric=Metric(score=MAX_SCORE)) + assert ExperienceManager.is_perfect_exp(exp) == True + + exp = Experience(req="req", resp="resp") + assert ExperienceManager.is_perfect_exp(exp) == False