mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-02 20:32:38 +02:00
use serialize in SerializationMixin
This commit is contained in:
parent
887f180e58
commit
587dd0cc81
5 changed files with 150 additions and 125 deletions
|
|
@ -11,11 +11,12 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.const import SERDESER_PATH, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.schema import (
|
||||
AIMessage,
|
||||
CodeSummarizeContext,
|
||||
|
|
@ -23,6 +24,7 @@ from metagpt.schema import (
|
|||
Message,
|
||||
MessageQueue,
|
||||
Plan,
|
||||
SerializationMixin,
|
||||
SystemMessage,
|
||||
Task,
|
||||
UMLClassAttribute,
|
||||
|
|
@ -398,5 +400,72 @@ def test_create_instruct_value(name, value):
|
|||
assert obj.model_dump() == value
|
||||
|
||||
|
||||
class TestUserModel(SerializationMixin, BaseModel):
|
||||
name: str
|
||||
value: int
|
||||
|
||||
|
||||
class TestUserModelWithRemove(TestUserModel):
|
||||
def remove_unserializable(self, data: dict):
|
||||
for key in ["value", "__module_class_name"]:
|
||||
data.pop(key, None)
|
||||
|
||||
|
||||
class TestSerializationMixin:
|
||||
@pytest.fixture
|
||||
def mock_write_json_file(self, mocker):
|
||||
return mocker.patch("metagpt.schema.write_json_file")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_read_json_file(self, mocker):
|
||||
return mocker.patch("metagpt.schema.read_json_file")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_model(self):
|
||||
return TestUserModel(name="test", value=42)
|
||||
|
||||
def test_serialize(self, mock_write_json_file, mock_user_model):
|
||||
file_path = "test.json"
|
||||
|
||||
mock_user_model.serialize(file_path)
|
||||
|
||||
mock_write_json_file.assert_called_once_with(file_path, mock_user_model.model_dump())
|
||||
|
||||
def test_deserialize(self, mock_read_json_file):
|
||||
file_path = "test.json"
|
||||
data = {"name": "test", "value": 42}
|
||||
mock_read_json_file.return_value = data
|
||||
|
||||
model = TestUserModel.deserialize(file_path)
|
||||
|
||||
mock_read_json_file.assert_called_once_with(file_path)
|
||||
assert model == TestUserModel(**data)
|
||||
|
||||
def test_serialize_with_remove_unserializable(self, mock_write_json_file):
|
||||
model = TestUserModelWithRemove(name="test", value=42)
|
||||
file_path = "test.json"
|
||||
|
||||
model.serialize(file_path)
|
||||
|
||||
mock_write_json_file.assert_called_once_with(file_path, {"name": "test"})
|
||||
|
||||
def test_get_serialization_path(self):
|
||||
expected_path = str(SERDESER_PATH / "TestUserModel.json")
|
||||
|
||||
assert TestUserModel.get_serialization_path() == expected_path
|
||||
|
||||
def test_remove_unserializable(self, mock_user_model):
|
||||
data = {
|
||||
"name": "example",
|
||||
"unserializable_fields": ["temp_data", "debug_info"],
|
||||
"temp_data": "some temporary data",
|
||||
"debug_info": "some debug information",
|
||||
}
|
||||
mock_user_model.remove_unserializable(data)
|
||||
|
||||
expected_data = {"name": "example", "unserializable_fields": ["temp_data", "debug_info"]}
|
||||
assert data == expected_data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -6,17 +6,13 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.serialize import (
|
||||
actionoutout_schema_to_mapping,
|
||||
deserialize_message,
|
||||
deserialize_model,
|
||||
serialize_message,
|
||||
serialize_model,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -70,35 +66,3 @@ def test_serialize_and_deserialize_message():
|
|||
assert new_message.content == message.content
|
||||
assert new_message.cause_by == message.cause_by
|
||||
assert new_message.instruct_content.field1 == out_data["field1"]
|
||||
|
||||
|
||||
class TestUserModel(BaseModel):
|
||||
name: str
|
||||
value: int
|
||||
|
||||
|
||||
def test_serialize_model(mocker):
|
||||
model = TestUserModel(name="test", value=42)
|
||||
file_path = "test.json"
|
||||
mock_write_json_file = mocker.patch("metagpt.utils.serialize.write_json_file")
|
||||
|
||||
# Test without remove_unserializable
|
||||
serialize_model(model, file_path)
|
||||
mock_write_json_file.assert_called_once_with(file_path, model.model_dump())
|
||||
|
||||
# Test with remove_unserializable
|
||||
def remove_unserializable(data: dict):
|
||||
data.pop("value", None)
|
||||
|
||||
serialize_model(model, file_path, remove_unserializable)
|
||||
mock_write_json_file.assert_called_with(file_path, {"name": "test"})
|
||||
|
||||
|
||||
def test_deserialize_model(mocker):
|
||||
file_path = "test.json"
|
||||
data = {"name": "test", "value": 42}
|
||||
mock_read_json_file = mocker.patch("metagpt.utils.serialize.read_json_file", return_value=data)
|
||||
|
||||
model = deserialize_model(TestUserModel, file_path)
|
||||
mock_read_json_file.assert_called_once_with(file_path)
|
||||
assert model == TestUserModel(**data)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue