mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-14 15:25:17 +02:00
use llm cache to make exp_pool
This commit is contained in:
parent
d902a6f18c
commit
c624c0ffc7
41 changed files with 844 additions and 368 deletions
|
|
@ -91,10 +91,10 @@ async def test_action_node_two_layer():
|
|||
assert node_b in root.children.values()
|
||||
|
||||
# FIXME: ADD MARKDOWN SUPPORT. NEED TO TUNE MARKDOWN SYMBOL FIRST.
|
||||
answer1 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM())
|
||||
answer1 = await root.fill(req="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM())
|
||||
assert "579" in answer1.content
|
||||
|
||||
answer2 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM())
|
||||
answer2 = await root.fill(req="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM())
|
||||
assert "579" in answer2.content
|
||||
|
||||
|
||||
|
|
@ -112,7 +112,7 @@ async def test_action_node_review():
|
|||
with pytest.raises(RuntimeError):
|
||||
_ = await node_a.review()
|
||||
|
||||
_ = await node_a.fill(context=None, llm=LLM())
|
||||
_ = await node_a.fill(req=None, llm=LLM())
|
||||
setattr(node_a.instruct_content, key, "game snake") # wrong content to review
|
||||
|
||||
review_comments = await node_a.review(review_mode=ReviewMode.AUTO)
|
||||
|
|
@ -126,7 +126,7 @@ async def test_action_node_review():
|
|||
with pytest.raises(RuntimeError):
|
||||
_ = await node.review()
|
||||
|
||||
_ = await node.fill(context=None, llm=LLM())
|
||||
_ = await node.fill(req=None, llm=LLM())
|
||||
|
||||
review_comments = await node.review(review_mode=ReviewMode.AUTO)
|
||||
assert len(review_comments) == 1
|
||||
|
|
@ -151,7 +151,7 @@ async def test_action_node_revise():
|
|||
with pytest.raises(RuntimeError):
|
||||
_ = await node_a.review()
|
||||
|
||||
_ = await node_a.fill(context=None, llm=LLM())
|
||||
_ = await node_a.fill(req=None, llm=LLM())
|
||||
setattr(node_a.instruct_content, key, "game snake") # wrong content to revise
|
||||
revise_contents = await node_a.revise(revise_mode=ReviseMode.AUTO)
|
||||
assert len(revise_contents) == 1
|
||||
|
|
@ -164,7 +164,7 @@ async def test_action_node_revise():
|
|||
with pytest.raises(RuntimeError):
|
||||
_ = await node.revise()
|
||||
|
||||
_ = await node.fill(context=None, llm=LLM())
|
||||
_ = await node.fill(req=None, llm=LLM())
|
||||
setattr(node.instruct_content, key, "game snake")
|
||||
revise_contents = await node.revise(revise_mode=ReviseMode.AUTO)
|
||||
assert len(revise_contents) == 1
|
||||
|
|
@ -257,7 +257,7 @@ async def test_action_node_with_image(mocker):
|
|||
invoice_path = Path(__file__).parent.joinpath("..", "..", "data", "invoices", "invoice-2.png")
|
||||
img_base64 = encode_image(invoice_path)
|
||||
mocker.patch("metagpt.provider.openai_api.OpenAILLM._cons_kwargs", _cons_kwargs)
|
||||
node = await invoice.fill(context="", llm=LLM(), images=[img_base64])
|
||||
node = await invoice.fill(req="", llm=LLM(), images=[img_base64])
|
||||
assert node.instruct_content.invoice
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ async def test_write_design_an(mocker):
|
|||
mocker.patch("metagpt.actions.design_api_an.REFINED_DESIGN_NODE.fill", return_value=root)
|
||||
|
||||
prompt = NEW_REQ_TEMPLATE.format(old_design=DESIGN_SAMPLE, context=dict_to_markdown(REFINED_PRD_JSON))
|
||||
node = await REFINED_DESIGN_NODE.fill(prompt, llm)
|
||||
node = await REFINED_DESIGN_NODE.fill(req=prompt, llm=llm)
|
||||
|
||||
assert "Refined Implementation Approach" in node.instruct_content.model_dump()
|
||||
assert "Refined File list" in node.instruct_content.model_dump()
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ async def test_project_management_an(mocker):
|
|||
root.instruct_content.model_dump = mock_task_json
|
||||
mocker.patch("metagpt.actions.project_management_an.PM_NODE.fill", return_value=root)
|
||||
|
||||
node = await PM_NODE.fill(dict_to_markdown(REFINED_DESIGN_JSON), llm)
|
||||
node = await PM_NODE.fill(req=dict_to_markdown(REFINED_DESIGN_JSON), llm=llm)
|
||||
|
||||
assert "Logic Analysis" in node.instruct_content.model_dump()
|
||||
assert "Task list" in node.instruct_content.model_dump()
|
||||
|
|
@ -59,7 +59,7 @@ async def test_project_management_an_inc(mocker):
|
|||
mocker.patch("metagpt.actions.project_management_an.REFINED_PM_NODE.fill", return_value=root)
|
||||
|
||||
prompt = NEW_REQ_TEMPLATE.format(old_task=TASK_SAMPLE, context=dict_to_markdown(REFINED_DESIGN_JSON))
|
||||
node = await REFINED_PM_NODE.fill(prompt, llm)
|
||||
node = await REFINED_PM_NODE.fill(req=prompt, llm=llm)
|
||||
|
||||
assert "Refined Logic Analysis" in node.instruct_content.model_dump()
|
||||
assert "Refined Task list" in node.instruct_content.model_dump()
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ async def test_write_prd_an(mocker):
|
|||
requirements=NEW_REQUIREMENT_SAMPLE,
|
||||
old_prd=PRD_SAMPLE,
|
||||
)
|
||||
node = await REFINED_PRD_NODE.fill(prompt, llm)
|
||||
node = await REFINED_PRD_NODE.fill(req=prompt, llm=llm)
|
||||
|
||||
assert "Refined Requirements" in node.instruct_content.model_dump()
|
||||
assert "Refined Product Goals" in node.instruct_content.model_dump()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.exp_pool.context_builders.base import (
|
||||
EXP_TEMPLATE,
|
||||
BaseContextBuilder,
|
||||
Experience,
|
||||
)
|
||||
from metagpt.exp_pool.schema import Metric, Score
|
||||
|
||||
|
||||
class TestBaseContextBuilder:
    """Unit tests for the helper methods exposed by ``BaseContextBuilder``."""

    class ConcreteContextBuilder(BaseContextBuilder):
        """Subclass with a no-op ``build``, used only to exercise the inherited helpers."""

        async def build(self, *args, **kwargs):
            return None

    @pytest.fixture
    def context_builder(self):
        """Provide a fresh concrete builder instance for each test."""
        builder_cls = self.ConcreteContextBuilder
        return builder_cls()

    def test_format_exps(self, context_builder):
        """Each stored experience is rendered as a numbered ``EXP_TEMPLATE`` line."""
        first = Experience(req="req1", resp="resp1", metric=Metric(score=Score(val=8)))
        second = Experience(req="req2", resp="resp2", metric=Metric(score=Score(val=9)))
        context_builder.exps = [first, second]

        formatted = context_builder.format_exps()

        line_one = f"1. {EXP_TEMPLATE.format(req='req1', resp='resp1', score=8)}"
        line_two = f"2. {EXP_TEMPLATE.format(req='req2', resp='resp2', score=9)}"
        assert formatted == "\n".join([line_one, line_two])

    def test_replace_content_between_markers(self):
        """Text between the start and end markers is swapped for the new content."""
        source = "Start\n# Example\nOld content\n# Instruction\nEnd"
        replacement = "New content"

        replaced = BaseContextBuilder.replace_content_between_markers(
            source, "# Example", "# Instruction", replacement
        )

        assert replaced == "Start\n# Example\nNew content\n\n# Instruction\nEnd"

    def test_replace_content_between_markers_no_match(self):
        """Text without the markers is returned unchanged."""
        source = "Start\nNo markers\nEnd"
        replacement = "New content"

        untouched = BaseContextBuilder.replace_content_between_markers(
            source, "# Example", "# Instruction", replacement
        )

        assert untouched == source
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
|
||||
from metagpt.exp_pool.context_builders.role_zero import RoleZeroContextBuilder
|
||||
|
||||
|
||||
class TestRoleZeroContextBuilder:
    """Unit tests for ``RoleZeroContextBuilder.build`` and ``replace_example_content``."""

    @pytest.fixture
    def context_builder(self):
        """Provide a fresh ``RoleZeroContextBuilder`` for each test."""
        return RoleZeroContextBuilder()

    @pytest.mark.asyncio
    async def test_build_empty_req(self, context_builder):
        """An empty request list is returned as an empty list."""
        built = await context_builder.build(req=[])
        assert built == []

    @pytest.mark.asyncio
    async def test_build_no_experiences(self, context_builder, mocker):
        """When no experiences are formatted, the request comes back unchanged."""
        mocker.patch.object(BaseContextBuilder, "format_exps", return_value="")
        messages = [{"content": "Original content"}]

        built = await context_builder.build(req=messages)

        assert built == messages

    @pytest.mark.asyncio
    async def test_build_with_experiences(self, context_builder, mocker):
        """With formatted experiences, the message content is replaced by the updated text."""
        mocker.patch.object(BaseContextBuilder, "format_exps", return_value="Formatted experiences")
        mocker.patch.object(RoleZeroContextBuilder, "replace_example_content", return_value="Updated content")
        messages = [{"content": "Original content"}]

        built = await context_builder.build(req=messages)

        assert built == [{"content": "Updated content"}]

    def test_replace_example_content(self, context_builder, mocker):
        """``replace_example_content`` delegates to the marker helper with the example/instruction markers."""
        mocker.patch.object(BaseContextBuilder, "replace_content_between_markers", return_value="Replaced content")

        replaced = context_builder.replace_example_content("Original text", "New example content")

        assert replaced == "Replaced content"
        expected_call_args = ("Original text", "# Example", "# Instruction", "New example content")
        context_builder.replace_content_between_markers.assert_called_once_with(*expected_call_args)
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
|
||||
from metagpt.exp_pool.context_builders.simple import (
|
||||
SIMPLE_CONTEXT_TEMPLATE,
|
||||
SimpleContextBuilder,
|
||||
)
|
||||
|
||||
|
||||
class TestSimpleContextBuilder:
    """Unit tests for ``SimpleContextBuilder.build``."""

    @pytest.fixture
    def context_builder(self):
        """Provide a fresh ``SimpleContextBuilder`` for each test."""
        return SimpleContextBuilder()

    @pytest.mark.asyncio
    async def test_build_with_experiences(self, context_builder, mocker):
        """The request and formatted experiences are combined via ``SIMPLE_CONTEXT_TEMPLATE``."""
        # Stub out experience formatting so only the template wiring is checked.
        formatted_exps = "Mocked experiences"
        mocker.patch.object(BaseContextBuilder, "format_exps", return_value=formatted_exps)
        request = "Test request"

        built = await context_builder.build(req=request)

        assert built == SIMPLE_CONTEXT_TEMPLATE.format(req=request, exps=formatted_exps)

    @pytest.mark.asyncio
    async def test_build_without_experiences(self, context_builder, mocker):
        """With no formatted experiences, the request is returned as-is."""
        # An empty string from format_exps stands in for "no experiences".
        mocker.patch.object(BaseContextBuilder, "format_exps", return_value="")
        request = "Test request"

        built = await context_builder.build(req=request)

        assert built == request

    @pytest.mark.asyncio
    async def test_build_without_req(self, context_builder, mocker):
        """Calling ``build`` with no request formats the template with an empty ``req``."""
        formatted_exps = "Mocked experiences"
        mocker.patch.object(BaseContextBuilder, "format_exps", return_value=formatted_exps)

        built = await context_builder.build()

        assert built == SIMPLE_CONTEXT_TEMPLATE.format(req="", exps=formatted_exps)
|
||||
|
|
@ -1,29 +1,17 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.exp_pool.context_builders import SimpleContextBuilder
|
||||
from metagpt.exp_pool.decorator import ExpCacheHandler, exp_cache
|
||||
from metagpt.exp_pool.manager import ExperienceManager
|
||||
from metagpt.exp_pool.perfect_judges import SimplePerfectJudge
|
||||
from metagpt.exp_pool.schema import Experience, QueryType, Score
|
||||
from metagpt.exp_pool.scorers import SimpleScorer
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
|
||||
|
||||
def for_test_function(a, b, c=None):
|
||||
return a + b if c is None else a + b + c
|
||||
|
||||
|
||||
class ForTestClass:
|
||||
def for_test_method(self, x, y):
|
||||
return x * y
|
||||
|
||||
@classmethod
|
||||
def for_test_class_method(cls, x, y):
|
||||
return x**y
|
||||
|
||||
|
||||
class TestExpCache:
|
||||
class TestExpCacheHandler:
|
||||
@pytest.fixture
|
||||
def mock_func(self, mocker):
|
||||
return mocker.AsyncMock()
|
||||
|
|
@ -34,7 +22,6 @@ class TestExpCache:
|
|||
manager.storage = mocker.MagicMock(spec=SimpleEngine)
|
||||
manager.query_exps = mocker.AsyncMock()
|
||||
manager.create_exp = mocker.MagicMock()
|
||||
manager.extract_one_perfect_exp = mocker.MagicMock()
|
||||
return manager
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -44,174 +31,165 @@ class TestExpCache:
|
|||
return scorer
|
||||
|
||||
@pytest.fixture
|
||||
def exp_cache_handler(self, mock_func, mock_exp_manager, mock_scorer):
|
||||
def mock_perfect_judge(self, mocker):
|
||||
return mocker.MagicMock(spec=SimplePerfectJudge)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context_builder(self, mocker):
|
||||
return mocker.MagicMock(spec=SimpleContextBuilder)
|
||||
|
||||
@pytest.fixture
|
||||
def exp_cache_handler(self, mock_func, mock_exp_manager, mock_scorer, mock_perfect_judge, mock_context_builder):
|
||||
return ExpCacheHandler(
|
||||
func=mock_func, args=(), kwargs={}, exp_manager=mock_exp_manager, exp_scorer=mock_scorer, pass_exps=False
|
||||
func=mock_func,
|
||||
args=(),
|
||||
kwargs={"req": "test_req"},
|
||||
exp_manager=mock_exp_manager,
|
||||
exp_scorer=mock_scorer,
|
||||
exp_perfect_judge=mock_perfect_judge,
|
||||
context_builder=mock_context_builder,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_experiences(self, exp_cache_handler, mock_exp_manager):
|
||||
await exp_cache_handler.fetch_experiences(QueryType.SEMANTIC)
|
||||
mock_exp_manager.query_exps.assert_called_once()
|
||||
mock_exp_manager.query_exps.return_value = [Experience(req="test_req", resp="test_resp")]
|
||||
await exp_cache_handler.fetch_experiences()
|
||||
mock_exp_manager.query_exps.assert_called_once_with(
|
||||
"test_req", query_type=QueryType.SEMANTIC, tag=exp_cache_handler.tag
|
||||
)
|
||||
assert len(exp_cache_handler._exps) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perfect_experience_found(self, exp_cache_handler, mock_exp_manager, mock_func):
|
||||
# Setup: Assume perfect experience is found
|
||||
perfect_exp = Experience(req="req", resp="resp")
|
||||
mock_exp_manager.extract_one_perfect_exp.return_value = perfect_exp
|
||||
|
||||
# Exec
|
||||
exp_cache_handler._exps = [perfect_exp] # Simulate fetched experiences
|
||||
result = exp_cache_handler.get_one_perfect_experience()
|
||||
|
||||
# Assert
|
||||
assert result.resp == "resp"
|
||||
mock_func.assert_not_called() # Function should not be called
|
||||
async def test_get_one_perfect_exp(self, exp_cache_handler, mock_perfect_judge):
|
||||
exp = Experience(req="test_req", resp="perfect_resp")
|
||||
exp_cache_handler._exps = [exp]
|
||||
mock_perfect_judge.is_perfect_exp.return_value = True
|
||||
result = await exp_cache_handler.get_one_perfect_exp()
|
||||
assert result == "perfect_resp"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_function_when_no_perfect_exp(self, exp_cache_handler, mock_exp_manager, mock_func):
|
||||
# Setup: No perfect experience
|
||||
mock_exp_manager.extract_one_perfect_exp.return_value = None
|
||||
mock_func.return_value = "Computed result"
|
||||
|
||||
# Exec
|
||||
async def test_execute_function(self, exp_cache_handler, mock_func, mock_context_builder):
|
||||
mock_context_builder.build.return_value = "built_context"
|
||||
mock_func.return_value = "function_result"
|
||||
await exp_cache_handler.execute_function()
|
||||
|
||||
# Assert
|
||||
assert exp_cache_handler._result == "Computed result"
|
||||
mock_func.assert_called_once()
|
||||
mock_context_builder.build.assert_called_once()
|
||||
mock_func.assert_called_once_with(req="built_context")
|
||||
assert exp_cache_handler._raw_resp == "function_result"
|
||||
assert exp_cache_handler._resp == "function_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_evaluate_and_save_experience(self, exp_cache_handler, mock_scorer, mock_exp_manager):
|
||||
# Setup
|
||||
mock_scorer.evaluate.return_value = Score(value=100)
|
||||
exp_cache_handler._result = "Computed result"
|
||||
|
||||
# Exec
|
||||
await exp_cache_handler.evaluate_experience()
|
||||
exp_cache_handler.save_experience()
|
||||
|
||||
# Assert
|
||||
async def test_process_experience(self, exp_cache_handler, mock_scorer, mock_exp_manager):
|
||||
exp_cache_handler._resp = "test_resp"
|
||||
mock_scorer.evaluate.return_value = Score(val=8)
|
||||
await exp_cache_handler.process_experience()
|
||||
mock_scorer.evaluate.assert_called_once()
|
||||
mock_exp_manager.create_exp.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function_execution_with_exps(self, exp_cache_handler, mock_exp_manager, mock_func):
|
||||
# Setup
|
||||
exp_cache_handler.pass_exps_to_func = True
|
||||
mock_func.return_value = "Async result with exps"
|
||||
mock_exp_manager.extract_one_perfect_exp.return_value = None
|
||||
exp_cache_handler._exps = [Experience(req="req", resp="resp")]
|
||||
async def test_evaluate_experience(self, exp_cache_handler, mock_scorer):
|
||||
exp_cache_handler._resp = "test_resp"
|
||||
mock_scorer.evaluate.return_value = Score(val=9)
|
||||
await exp_cache_handler.evaluate_experience()
|
||||
assert exp_cache_handler._score.val == 9
|
||||
|
||||
# Exec
|
||||
await exp_cache_handler.execute_function()
|
||||
|
||||
# Assert
|
||||
mock_func.assert_called_once_with(exps=exp_cache_handler._exps)
|
||||
assert exp_cache_handler._result == "Async result with exps"
|
||||
|
||||
def test_sync_function_execution_with_exps(self, mocker, exp_cache_handler, mock_exp_manager, mock_func):
|
||||
# Setup
|
||||
exp_cache_handler.func = mocker.Mock(return_value="Sync result with exps")
|
||||
exp_cache_handler.pass_exps_to_func = True
|
||||
mock_exp_manager.extract_one_perfect_exp.return_value = None
|
||||
exp_cache_handler._exps = [Experience(req="req", resp="resp")]
|
||||
|
||||
# Exec
|
||||
asyncio.get_event_loop().run_until_complete(exp_cache_handler.execute_function())
|
||||
|
||||
# Assert
|
||||
exp_cache_handler.func.assert_called_once_with(exps=exp_cache_handler._exps)
|
||||
assert exp_cache_handler._result == "Sync result with exps"
|
||||
|
||||
def test_wrapper_selection_async(self, mocker, exp_cache_handler, mock_func):
|
||||
# Setup
|
||||
mock_func = mocker.AsyncMock()
|
||||
|
||||
# Exec
|
||||
wrapper = ExpCacheHandler.choose_wrapper(mock_func, exp_cache_handler.execute_function)
|
||||
|
||||
# Assert
|
||||
assert asyncio.iscoroutinefunction(wrapper), "Wrapper should be asynchronous"
|
||||
|
||||
def test_wrapper_selection_sync(self, exp_cache_handler, mocker):
|
||||
# Setup
|
||||
sync_func = mocker.Mock()
|
||||
|
||||
# Exec
|
||||
wrapper = ExpCacheHandler.choose_wrapper(sync_func, exp_cache_handler.execute_function)
|
||||
|
||||
# Assert
|
||||
assert not asyncio.iscoroutinefunction(wrapper), "Wrapper should be synchronous"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func, args, kwargs, expected",
|
||||
[
|
||||
(for_test_function, (1, 2), {"c": 3}, 'for_test_function@[1~2]@{"c"!3}'),
|
||||
(ForTestClass().for_test_method, (4, 5), {}, "ForTestClass.for_test_method@[4~5]@{}"),
|
||||
(ForTestClass.for_test_class_method, (6, 7), {}, "ForTestClass.for_test_class_method@[6~7]@{}"),
|
||||
(for_test_function, (), {}, "for_test_function@[]@{}"),
|
||||
(
|
||||
for_test_function,
|
||||
("hello", [1, 2]),
|
||||
{"key": "value"},
|
||||
'for_test_function@["hello"~[1~2]]@{"key"!"value"}',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_generate_req_identifier(self, func, args, kwargs, expected):
|
||||
req_identifier = ExpCacheHandler.generate_req_identifier(func, *args, **kwargs)
|
||||
assert req_identifier == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exp_cache_with_perfect_experience(self, mocker, mock_exp_manager):
|
||||
# Mock perfect experience
|
||||
perfect_exp = Experience(req="test_req", resp="perfect_response")
|
||||
mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[perfect_exp])
|
||||
mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=perfect_exp)
|
||||
async_mock_func = mocker.AsyncMock()
|
||||
|
||||
# Setup
|
||||
decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager)
|
||||
|
||||
# Exec
|
||||
result: Experience = await decorated_func()
|
||||
|
||||
# Assert
|
||||
assert result.resp == "perfect_response", "Should return the perfect experience response"
|
||||
async_mock_func.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exp_cache_without_perfect_experience(self, mocker, mock_exp_manager):
|
||||
# Mock
|
||||
mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[])
|
||||
mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=None)
|
||||
async_mock_func = mocker.AsyncMock(return_value="computed_response")
|
||||
async_mock_func.__signature__ = inspect.signature(for_test_function)
|
||||
|
||||
# Setup
|
||||
decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager)
|
||||
|
||||
# Exec
|
||||
result = await decorated_func()
|
||||
|
||||
# Assert
|
||||
assert result == "computed_response", "Should execute and return the function's response"
|
||||
async_mock_func.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exp_cache_saves_new_experience(self, mocker, mock_exp_manager, mock_scorer):
|
||||
# Mock
|
||||
mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[])
|
||||
mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=None)
|
||||
async_mock_func = mocker.AsyncMock(return_value="computed_response")
|
||||
mock_scorer.evaluate = mocker.AsyncMock(return_value=Score(value=100))
|
||||
|
||||
# Setup
|
||||
decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager, scorer=mock_scorer)
|
||||
|
||||
# Exec
|
||||
await decorated_func()
|
||||
|
||||
# Assert
|
||||
def test_save_experience(self, exp_cache_handler, mock_exp_manager):
|
||||
exp_cache_handler._req = "test_req"
|
||||
exp_cache_handler._resp = "test_resp"
|
||||
exp_cache_handler._score = Score(val=7)
|
||||
exp_cache_handler.save_experience()
|
||||
mock_exp_manager.create_exp.assert_called_once()
|
||||
|
||||
def test_choose_wrapper_async(self, mocker):
|
||||
async def async_func():
|
||||
pass
|
||||
|
||||
wrapper = ExpCacheHandler.choose_wrapper(async_func, mocker.AsyncMock())
|
||||
assert asyncio.iscoroutinefunction(wrapper)
|
||||
|
||||
def test_choose_wrapper_sync(self, mocker):
|
||||
def sync_func():
|
||||
pass
|
||||
|
||||
wrapper = ExpCacheHandler.choose_wrapper(sync_func, mocker.AsyncMock())
|
||||
assert not asyncio.iscoroutinefunction(wrapper)
|
||||
|
||||
def test_validate_params(self):
|
||||
with pytest.raises(ValueError):
|
||||
ExpCacheHandler(func=lambda x: x, args=(), kwargs={})
|
||||
|
||||
def test_generate_tag(self):
|
||||
class TestClass:
|
||||
def test_method(self):
|
||||
pass
|
||||
|
||||
handler = ExpCacheHandler(func=TestClass().test_method, args=(TestClass(),), kwargs={"req": "test"})
|
||||
assert handler._generate_tag() == "TestClass.test_method"
|
||||
|
||||
handler = ExpCacheHandler(func=lambda x: x, args=(), kwargs={"req": "test"})
|
||||
assert handler._generate_tag() == "<lambda>"
|
||||
|
||||
|
||||
class TestExpCache:
    """Tests for the ``exp_cache`` decorator with the module-level ``config`` patched out."""

    @pytest.fixture
    def mock_exp_manager(self, mocker):
        """Build an ``ExperienceManager`` mock with mocked storage, query and create hooks."""
        manager = mocker.MagicMock(spec=ExperienceManager)
        manager.configure_mock(
            storage=mocker.MagicMock(spec=SimpleEngine),
            query_exps=mocker.AsyncMock(),
            create_exp=mocker.MagicMock(),
        )
        return manager

    @pytest.fixture
    def mock_scorer(self, mocker):
        """Build a ``SimpleScorer`` mock whose async ``evaluate`` returns a default ``Score``."""
        scorer = mocker.MagicMock(spec=SimpleScorer)
        scorer.configure_mock(evaluate=mocker.AsyncMock(return_value=Score()))
        return scorer

    @pytest.fixture
    def mock_perfect_judge(self, mocker):
        """Build a ``SimplePerfectJudge`` mock."""
        return mocker.MagicMock(spec=SimplePerfectJudge)

    @pytest.fixture
    def mock_config(self, mocker):
        """Patch ``config`` inside the decorator module and return the patched object."""
        return mocker.patch("metagpt.exp_pool.decorator.config")

    @pytest.mark.asyncio
    async def test_exp_cache_disabled(self, mock_config, mock_exp_manager):
        """With reads disabled, the wrapped function's result is returned and nothing is queried."""
        mock_config.exp_pool.enable_read = False

        @exp_cache(manager=mock_exp_manager)
        async def test_func(req):
            return "result"

        outcome = await test_func(req="test")

        assert outcome == "result"
        mock_exp_manager.query_exps.assert_not_called()

    @pytest.mark.asyncio
    async def test_exp_cache_enabled_no_perfect_exp(self, mock_config, mock_exp_manager, mock_scorer):
        """With reads enabled and no stored experiences, the function runs and a new experience is created."""
        mock_config.exp_pool.enable_read = True
        mock_exp_manager.query_exps.return_value = []

        @exp_cache(manager=mock_exp_manager, scorer=mock_scorer)
        async def test_func(req):
            return "computed_result"

        outcome = await test_func(req="test")

        assert outcome == "computed_result"
        mock_exp_manager.query_exps.assert_called()
        mock_exp_manager.create_exp.assert_called()

    @pytest.mark.asyncio
    async def test_exp_cache_enabled_with_perfect_exp(self, mock_config, mock_exp_manager, mock_perfect_judge):
        """When the judge accepts a stored experience, its response is returned and nothing new is created."""
        mock_config.exp_pool.enable_read = True
        stored_exp = Experience(req="test", resp="perfect_result")
        mock_exp_manager.query_exps.return_value = [stored_exp]
        mock_perfect_judge.is_perfect_exp.return_value = True

        @exp_cache(manager=mock_exp_manager, perfect_judge=mock_perfect_judge)
        async def test_func(req):
            return "should_not_be_called"

        outcome = await test_func(req="test")

        assert outcome == "perfect_result"
        mock_exp_manager.query_exps.assert_called_once()
        mock_exp_manager.create_exp.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -4,20 +4,25 @@ from metagpt.config2 import Config
|
|||
from metagpt.configs.exp_pool_config import ExperiencePoolConfig
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.exp_pool.manager import ExperienceManager
|
||||
from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric, Score
|
||||
from metagpt.exp_pool.schema import Experience
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
|
||||
|
||||
class TestExperienceManager:
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True))
|
||||
return Config(
|
||||
llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True, init_exp=False)
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage(self, mocker):
|
||||
engine = mocker.MagicMock(spec=SimpleEngine)
|
||||
engine.add_objs = mocker.MagicMock()
|
||||
engine.aretrieve = mocker.AsyncMock(return_value=[])
|
||||
engine._retriever = mocker.MagicMock()
|
||||
engine._retriever._vector_store = mocker.MagicMock()
|
||||
engine._retriever._vector_store._get = mocker.MagicMock(return_value=mocker.MagicMock(ids=[]))
|
||||
return engine
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -33,7 +38,7 @@ class TestExperienceManager:
|
|||
|
||||
def test_create_exp(self, mock_experience_manager, mock_experience):
|
||||
mock_experience_manager.create_exp(mock_experience)
|
||||
mock_experience_manager.storage.add_objs.assert_called_once_with([mock_experience])
|
||||
mock_experience_manager.storage.add_objs.assert_called_with([mock_experience])
|
||||
|
||||
def test_create_exp_write_disabled(self, mock_experience_manager, mock_experience, mock_config):
|
||||
mock_config.exp_pool.enable_write = False
|
||||
|
|
@ -60,18 +65,44 @@ class TestExperienceManager:
|
|||
result = await mock_experience_manager.query_exps("query")
|
||||
assert result == []
|
||||
|
||||
def test_extract_one_perfect_exp(self, mock_experience_manager):
|
||||
experiences = [
|
||||
Experience(req="req", resp="resp", metric=Metric(score=Score(val=MAX_SCORE))),
|
||||
Experience(req="req", resp="resp"),
|
||||
]
|
||||
perfect_exp: Experience = mock_experience_manager.extract_one_perfect_exp(experiences)
|
||||
assert perfect_exp is not None
|
||||
assert perfect_exp.metric.score.val == MAX_SCORE
|
||||
def test_init_exp_pool(self, mock_experience_manager, mock_config, mocker):
|
||||
mock_experience_manager._has_exps = mocker.MagicMock(return_value=False)
|
||||
mock_experience_manager._init_teamleader_exps = mocker.MagicMock()
|
||||
mock_experience_manager._init_engineer2_exps = mocker.MagicMock()
|
||||
|
||||
def test_is_perfect_exp(self):
|
||||
exp = Experience(req="req", resp="resp", metric=Metric(score=Score(val=MAX_SCORE)))
|
||||
assert ExperienceManager.is_perfect_exp(exp) == True
|
||||
mock_config.exp_pool.init_exp = True
|
||||
mock_experience_manager.init_exp_pool()
|
||||
|
||||
exp = Experience(req="req", resp="resp")
|
||||
assert ExperienceManager.is_perfect_exp(exp) == False
|
||||
mock_experience_manager._has_exps.assert_called_once()
|
||||
mock_experience_manager._init_teamleader_exps.assert_called_once()
|
||||
mock_experience_manager._init_engineer2_exps.assert_called_once()
|
||||
|
||||
def test_init_exp_pool_already_has_exps(self, mock_experience_manager, mock_config, mocker):
|
||||
mock_experience_manager._has_exps = mocker.MagicMock(return_value=True)
|
||||
mock_experience_manager._init_teamleader_exps = mocker.MagicMock()
|
||||
mock_experience_manager._init_engineer2_exps = mocker.MagicMock()
|
||||
|
||||
mock_config.exp_pool.init_exp = True
|
||||
mock_experience_manager.init_exp_pool()
|
||||
|
||||
mock_experience_manager._has_exps.assert_called_once()
|
||||
mock_experience_manager._init_teamleader_exps.assert_not_called()
|
||||
mock_experience_manager._init_engineer2_exps.assert_not_called()
|
||||
|
||||
def test_has_exps(self, mock_experience_manager, mock_storage):
|
||||
mock_storage._retriever._vector_store._get.return_value.ids = ["id1"]
|
||||
|
||||
assert mock_experience_manager._has_exps() is True
|
||||
|
||||
mock_storage._retriever._vector_store._get.return_value.ids = []
|
||||
assert mock_experience_manager._has_exps() is False
|
||||
|
||||
def test_init_teamleader_exps(self, mock_experience_manager, mocker):
|
||||
mock_experience_manager._init_exp = mocker.MagicMock()
|
||||
mock_experience_manager._init_teamleader_exps()
|
||||
mock_experience_manager._init_exp.assert_called_once()
|
||||
|
||||
def test_init_engineer2_exps(self, mock_experience_manager, mocker):
|
||||
mock_experience_manager._init_exp = mocker.MagicMock()
|
||||
mock_experience_manager._init_engineer2_exps()
|
||||
mock_experience_manager._init_exp.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.exp_pool.perfect_judges import SimplePerfectJudge
|
||||
from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric, Score
|
||||
|
||||
|
||||
class TestSimplePerfectJudge:
    """Unit tests for ``SimplePerfectJudge.is_perfect_exp``."""

    @pytest.fixture
    def simple_perfect_judge(self):
        """Provide a fresh ``SimplePerfectJudge`` for each test."""
        return SimplePerfectJudge()

    @pytest.mark.asyncio
    async def test_is_perfect_exp_perfect_match(self, simple_perfect_judge):
        """Same request plus ``MAX_SCORE`` is judged perfect."""
        top_metric = Metric(score=Score(val=MAX_SCORE))
        experience = Experience(req="test_request", resp="resp", metric=top_metric)

        verdict = await simple_perfect_judge.is_perfect_exp(experience, "test_request")

        assert verdict is True

    @pytest.mark.asyncio
    async def test_is_perfect_exp_imperfect_score(self, simple_perfect_judge):
        """A score one below ``MAX_SCORE`` is not perfect, even for the same request."""
        lower_metric = Metric(score=Score(val=MAX_SCORE - 1))
        experience = Experience(req="test_request", resp="resp", metric=lower_metric)

        verdict = await simple_perfect_judge.is_perfect_exp(experience, "test_request")

        assert verdict is False

    @pytest.mark.asyncio
    async def test_is_perfect_exp_mismatched_request(self, simple_perfect_judge):
        """A ``MAX_SCORE`` experience is rejected when the serving request differs."""
        top_metric = Metric(score=Score(val=MAX_SCORE))
        experience = Experience(req="test_request", resp="resp", metric=top_metric)

        verdict = await simple_perfect_judge.is_perfect_exp(experience, "different_request")

        assert verdict is False

    @pytest.mark.asyncio
    async def test_is_perfect_exp_no_metric(self, simple_perfect_judge):
        """An experience built without a metric is not perfect."""
        experience = Experience(req="test_request", resp="resp")

        verdict = await simple_perfect_judge.is_perfect_exp(experience, "test_request")

        assert verdict is False

    @pytest.mark.asyncio
    async def test_is_perfect_exp_no_score(self, simple_perfect_judge):
        """An experience whose metric is built without a score is not perfect."""
        experience = Experience(req="test_request", resp="resp", metric=Metric())

        verdict = await simple_perfect_judge.is_perfect_exp(experience, "test_request")

        assert verdict is False
|
||||
49
tests/metagpt/exp_pool/test_scorers/test_simple_scorer.py
Normal file
49
tests/metagpt/exp_pool/test_scorers/test_simple_scorer.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.exp_pool.schema import Score
|
||||
from metagpt.exp_pool.scorers.simple import SIMPLE_SCORER_TEMPLATE, SimpleScorer
|
||||
from metagpt.llm import BaseLLM
|
||||
|
||||
|
||||
class TestSimpleScorer:
    """Unit tests for ``SimpleScorer`` backed by a mocked LLM."""

    @pytest.fixture
    def mock_llm(self, mocker):
        """Provide a mock constrained to the ``BaseLLM`` interface."""
        return mocker.MagicMock(spec=BaseLLM)

    @pytest.fixture
    def simple_scorer(self, mock_llm):
        """Provide a ``SimpleScorer`` wired to the mocked LLM."""
        return SimpleScorer(llm=mock_llm)

    def test_init(self, mock_llm):
        """The scorer keeps the LLM it was constructed with."""
        assert isinstance(SimpleScorer(llm=mock_llm).llm, BaseLLM)

    @pytest.mark.asyncio
    async def test_evaluate(self, simple_scorer, mock_llm):
        """``evaluate`` sends the templated prompt to the LLM and parses the fenced JSON reply into a ``Score``."""

        # Function under evaluation; its name and docstring feed the prompt.
        def mock_func(a, b):
            """This is a mock function."""
            return a + b

        call_args = (2, 3)
        call_kwargs = {}

        # Canned LLM reply: a fenced JSON block carrying the score.
        mock_llm.aask.return_value = '```json\n{"val": 8, "reason": "Good performance"}\n```'

        score = await simple_scorer.evaluate(mock_func, 5, args=call_args, kwargs=call_kwargs)

        # The LLM must have been asked exactly once, with the fully formatted prompt.
        prompt = SIMPLE_SCORER_TEMPLATE.format(
            func_name=mock_func.__name__,
            func_doc=mock_func.__doc__,
            func_signature="(a, b)",
            func_args=(2, 3),
            func_kwargs={},
            func_result=5,
        )
        mock_llm.aask.assert_called_once_with(prompt)

        # The reply must have been parsed into a Score.
        assert isinstance(score, Score)
        assert score.val == 8
        assert score.reason == "Good performance"
|
||||
Loading…
Add table
Add a link
Reference in a new issue