add serializers to support serialization and deserialization.

This commit is contained in:
seehi 2024-07-10 10:24:04 +08:00
parent 086ef5e805
commit b5934a412b
19 changed files with 234 additions and 144 deletions

View file

@ -30,16 +30,3 @@ class TestBaseContextBuilder:
]
)
assert result == expected
def test_replace_content_between_markers(self):
text = "Start\n# Example\nOld content\n# Instruction\nEnd"
new_content = "New content"
result = BaseContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content)
expected = "Start\n# Example\nNew content\n\n# Instruction\nEnd"
assert result == expected
def test_replace_content_between_markers_no_match(self):
text = "Start\nNo markers\nEnd"
new_content = "New content"
result = BaseContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content)
assert result == text

View file

@ -30,9 +30,22 @@ class TestRoleZeroContextBuilder:
assert result == [{"content": "Updated content"}]
def test_replace_example_content(self, context_builder, mocker):
mocker.patch.object(BaseContextBuilder, "replace_content_between_markers", return_value="Replaced content")
mocker.patch.object(RoleZeroContextBuilder, "replace_content_between_markers", return_value="Replaced content")
result = context_builder.replace_example_content("Original text", "New example content")
assert result == "Replaced content"
context_builder.replace_content_between_markers.assert_called_once_with(
"Original text", "# Example", "# Instruction", "New example content"
)
def test_replace_content_between_markers(self):
text = "Start\n# Example\nOld content\n# Instruction\nEnd"
new_content = "New content"
result = RoleZeroContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content)
expected = "Start\n# Example\nNew content\n\n# Instruction\nEnd"
assert result == expected
def test_replace_content_between_markers_no_match(self):
text = "Start\nNo markers\nEnd"
new_content = "New content"
result = RoleZeroContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content)
assert result == text

View file

@ -32,7 +32,8 @@ class TestSimpleContextBuilder:
req = "Test request"
result = await context_builder.build(req=req)
assert result == req
expected = SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps="")
assert result == expected
@pytest.mark.asyncio
async def test_build_without_req(self, context_builder, mocker):

View file

@ -11,9 +11,7 @@ from metagpt.rag.engines import SimpleEngine
class TestExperienceManager:
@pytest.fixture
def mock_config(self):
return Config(
llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True, init_exp=False)
)
return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True))
@pytest.fixture
def mock_storage(self, mocker):
@ -65,30 +63,6 @@ class TestExperienceManager:
result = await mock_experience_manager.query_exps("query")
assert result == []
def test_init_exp_pool(self, mock_experience_manager, mock_config, mocker):
mock_experience_manager._has_exps = mocker.MagicMock(return_value=False)
mock_experience_manager._init_teamleader_exps = mocker.MagicMock()
mock_experience_manager._init_engineer2_exps = mocker.MagicMock()
mock_config.exp_pool.init_exp = True
mock_experience_manager.init_exp_pool()
mock_experience_manager._has_exps.assert_called_once()
mock_experience_manager._init_teamleader_exps.assert_called_once()
mock_experience_manager._init_engineer2_exps.assert_called_once()
def test_init_exp_pool_already_has_exps(self, mock_experience_manager, mock_config, mocker):
mock_experience_manager._has_exps = mocker.MagicMock(return_value=True)
mock_experience_manager._init_teamleader_exps = mocker.MagicMock()
mock_experience_manager._init_engineer2_exps = mocker.MagicMock()
mock_config.exp_pool.init_exp = True
mock_experience_manager.init_exp_pool()
mock_experience_manager._has_exps.assert_called_once()
mock_experience_manager._init_teamleader_exps.assert_not_called()
mock_experience_manager._init_engineer2_exps.assert_not_called()
def test_has_exps(self, mock_experience_manager, mock_storage):
mock_storage._retriever._vector_store._get.return_value.ids = ["id1"]