use serialize in SerializationMixin

This commit is contained in:
seehi 2024-08-12 14:37:00 +08:00
parent 887f180e58
commit 587dd0cc81
5 changed files with 150 additions and 125 deletions

View file

@ -11,11 +11,12 @@
import json
import pytest
from pydantic import BaseModel
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
from metagpt.actions.write_code import WriteCode
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.const import SERDESER_PATH, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.schema import (
AIMessage,
CodeSummarizeContext,
@ -23,6 +24,7 @@ from metagpt.schema import (
Message,
MessageQueue,
Plan,
SerializationMixin,
SystemMessage,
Task,
UMLClassAttribute,
@ -398,5 +400,72 @@ def test_create_instruct_value(name, value):
assert obj.model_dump() == value
class TestUserModel(SerializationMixin, BaseModel):
name: str
value: int
class TestUserModelWithRemove(TestUserModel):
def remove_unserializable(self, data: dict):
for key in ["value", "__module_class_name"]:
data.pop(key, None)
class TestSerializationMixin:
@pytest.fixture
def mock_write_json_file(self, mocker):
return mocker.patch("metagpt.schema.write_json_file")
@pytest.fixture
def mock_read_json_file(self, mocker):
return mocker.patch("metagpt.schema.read_json_file")
@pytest.fixture
def mock_user_model(self):
return TestUserModel(name="test", value=42)
def test_serialize(self, mock_write_json_file, mock_user_model):
file_path = "test.json"
mock_user_model.serialize(file_path)
mock_write_json_file.assert_called_once_with(file_path, mock_user_model.model_dump())
def test_deserialize(self, mock_read_json_file):
file_path = "test.json"
data = {"name": "test", "value": 42}
mock_read_json_file.return_value = data
model = TestUserModel.deserialize(file_path)
mock_read_json_file.assert_called_once_with(file_path)
assert model == TestUserModel(**data)
def test_serialize_with_remove_unserializable(self, mock_write_json_file):
model = TestUserModelWithRemove(name="test", value=42)
file_path = "test.json"
model.serialize(file_path)
mock_write_json_file.assert_called_once_with(file_path, {"name": "test"})
def test_get_serialization_path(self):
expected_path = str(SERDESER_PATH / "TestUserModel.json")
assert TestUserModel.get_serialization_path() == expected_path
def test_remove_unserializable(self, mock_user_model):
data = {
"name": "example",
"unserializable_fields": ["temp_data", "debug_info"],
"temp_data": "some temporary data",
"debug_info": "some debug information",
}
mock_user_model.remove_unserializable(data)
expected_data = {"name": "example", "unserializable_fields": ["temp_data", "debug_info"]}
assert data == expected_data
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -6,17 +6,13 @@
from typing import List
from pydantic import BaseModel
from metagpt.actions import WritePRD
from metagpt.actions.action_node import ActionNode
from metagpt.schema import Message
from metagpt.utils.serialize import (
actionoutout_schema_to_mapping,
deserialize_message,
deserialize_model,
serialize_message,
serialize_model,
)
@ -70,35 +66,3 @@ def test_serialize_and_deserialize_message():
assert new_message.content == message.content
assert new_message.cause_by == message.cause_by
assert new_message.instruct_content.field1 == out_data["field1"]
class TestUserModel(BaseModel):
name: str
value: int
def test_serialize_model(mocker):
model = TestUserModel(name="test", value=42)
file_path = "test.json"
mock_write_json_file = mocker.patch("metagpt.utils.serialize.write_json_file")
# Test without remove_unserializable
serialize_model(model, file_path)
mock_write_json_file.assert_called_once_with(file_path, model.model_dump())
# Test with remove_unserializable
def remove_unserializable(data: dict):
data.pop("value", None)
serialize_model(model, file_path, remove_unserializable)
mock_write_json_file.assert_called_with(file_path, {"name": "test"})
def test_deserialize_model(mocker):
file_path = "test.json"
data = {"name": "test", "value": 42}
mock_read_json_file = mocker.patch("metagpt.utils.serialize.read_json_file", return_value=data)
model = deserialize_model(TestUserModel, file_path)
mock_read_json_file.assert_called_once_with(file_path)
assert model == TestUserModel(**data)