mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
add exp_pool test
This commit is contained in:
parent
96cd6b5f64
commit
d600cc47f4
2 changed files with 80 additions and 2 deletions
|
|
@ -1,7 +1,7 @@
|
|||
"""Experience schema."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from llama_index.core.schema import TextNode
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -38,13 +38,14 @@ class Trajectory(BaseModel):
|
|||
plan: str = Field(default="", description="The plan.")
|
||||
action: str = Field(default="", description="Action for the plan.")
|
||||
observation: str = Field(default="", description="Output of the action.")
|
||||
reward: int = Field(default=0, description="Measure the action.")
|
||||
|
||||
|
||||
class Experience(BaseModel):
|
||||
"""Experience."""
|
||||
|
||||
req: str = Field(..., description="")
|
||||
resp: str = Field(..., description="The type is string/json/code.")
|
||||
resp: Any = Field(..., description="The type is string/json/code.")
|
||||
metric: Optional[Metric] = Field(default=None, description="Metric.")
|
||||
exp_type: ExperienceType = Field(default=ExperienceType.SUCCESS, description="The type of experience.")
|
||||
entry_type: EntryType = Field(default=EntryType.AUTOMATIC, description="Type of entry: Manual or Automatic.")
|
||||
|
|
|
|||
77
tests/metagpt/exp_pool/test_manager.py
Normal file
77
tests/metagpt/exp_pool/test_manager.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.exp_pool_config import ExperiencePoolConfig
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.exp_pool.manager import ExperienceManager
|
||||
from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
|
||||
|
||||
class TestExperienceManager:
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True))
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage(self, mocker):
|
||||
engine = mocker.MagicMock(spec=SimpleEngine)
|
||||
engine.add_objs = mocker.MagicMock()
|
||||
engine.aretrieve = mocker.AsyncMock(return_value=[])
|
||||
return engine
|
||||
|
||||
@pytest.fixture
|
||||
def mock_experience_manager(self, mock_config, mock_storage):
|
||||
return ExperienceManager(config=mock_config, storage=mock_storage)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_experience(self):
|
||||
return Experience(req="req", resp="resp")
|
||||
|
||||
def test_initialize_storage(self, mock_experience_manager, mock_storage):
|
||||
assert mock_experience_manager.storage is mock_storage
|
||||
|
||||
def test_create_exp(self, mock_experience_manager, mock_experience):
|
||||
mock_experience_manager.create_exp(mock_experience)
|
||||
mock_experience_manager.storage.add_objs.assert_called_once_with([mock_experience])
|
||||
|
||||
def test_create_exp_write_disabled(self, mock_experience_manager, mock_experience, mock_config):
|
||||
mock_config.exp_pool.enable_write = False
|
||||
mock_experience_manager.create_exp(mock_experience)
|
||||
mock_experience_manager.storage.add_objs.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_exps(self, mock_experience_manager, mocker):
|
||||
req = "req"
|
||||
resp = "resp"
|
||||
tag = "test"
|
||||
experiences = [Experience(req=req, resp=resp, tag="test"), Experience(req=req, resp=resp, tag="other")]
|
||||
mock_experience_manager.storage.aretrieve.return_value = [
|
||||
mocker.MagicMock(metadata={"obj": exp}) for exp in experiences
|
||||
]
|
||||
|
||||
result = await mock_experience_manager.query_exps(req, tag)
|
||||
assert len(result) == 1
|
||||
assert result[0].tag == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_exps_no_read_permission(self, mock_experience_manager, mock_config):
|
||||
mock_config.exp_pool.enable_read = False
|
||||
result = await mock_experience_manager.query_exps("query")
|
||||
assert result == []
|
||||
|
||||
def test_extract_one_perfect_exp(self, mock_experience_manager):
|
||||
experiences = [
|
||||
Experience(req="req", resp="resp", metric=Metric(score=MAX_SCORE)),
|
||||
Experience(req="req", resp="resp"),
|
||||
]
|
||||
perfect_exp: Experience = mock_experience_manager.extract_one_perfect_exp(experiences)
|
||||
assert perfect_exp is not None
|
||||
assert perfect_exp.metric.score == MAX_SCORE
|
||||
|
||||
def test_is_perfect_exp(self):
|
||||
exp = Experience(req="req", resp="resp", metric=Metric(score=MAX_SCORE))
|
||||
assert ExperienceManager.is_perfect_exp(exp) == True
|
||||
|
||||
exp = Experience(req="req", resp="resp")
|
||||
assert ExperienceManager.is_perfect_exp(exp) == False
|
||||
Loading…
Add table
Add a link
Reference in a new issue