use pydantic's exclude

This commit is contained in:
seehi 2024-08-12 17:16:28 +08:00
parent 11894d12f3
commit b4207cec92
4 changed files with 17 additions and 49 deletions

View file

@ -9,9 +9,10 @@
"""
import json
from typing import Annotated
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Field
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
@ -405,10 +406,8 @@ class TestUserModel(SerializationMixin, BaseModel):
value: int
class TestUserModelWithRemove(TestUserModel):
def remove_unserializable(self, data: dict):
for key in ["value", "__module_class_name"]:
data.pop(key, None)
class TestUserModelWithExclude(TestUserModel):
age: Annotated[int, Field(exclude=True)]
class TestSerializationMixin:
@ -441,31 +440,25 @@ class TestSerializationMixin:
mock_read_json_file.assert_called_once_with(file_path)
assert model == TestUserModel(**data)
def test_serialize_with_remove_unserializable(self, mock_write_json_file):
model = TestUserModelWithRemove(name="test", value=42)
def test_serialize_with_exclude(self, mock_write_json_file):
model = TestUserModelWithExclude(name="test", value=42, age=10)
file_path = "test.json"
model.serialize(file_path)
mock_write_json_file.assert_called_once_with(file_path, {"name": "test"})
expected_data = {
"name": "test",
"value": 42,
"__module_class_name": "tests.metagpt.test_schema.TestUserModelWithExclude",
}
mock_write_json_file.assert_called_once_with(file_path, expected_data)
def test_get_serialization_path(self):
expected_path = str(SERDESER_PATH / "TestUserModel.json")
assert TestUserModel.get_serialization_path() == expected_path
def test_remove_unserializable(self, mock_user_model):
data = {
"name": "example",
"unserializable_fields": ["temp_data", "debug_info"],
"temp_data": "some temporary data",
"debug_info": "some debug information",
}
mock_user_model.remove_unserializable(data)
expected_data = {"name": "example", "unserializable_fields": ["temp_data", "debug_info"]}
assert data == expected_data
if __name__ == "__main__":
pytest.main([__file__, "-s"])