diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py index 43ef9c4b5..99f94052a 100644 --- a/metagpt/environment/mgx/mgx_env.py +++ b/metagpt/environment/mgx/mgx_env.py @@ -1,5 +1,3 @@ -from typing import ClassVar - from metagpt.actions import ( UserRequirement, WriteDesign, @@ -8,17 +6,15 @@ from metagpt.actions import ( WriteTest, ) from metagpt.actions.summarize_code import SummarizeCode -from metagpt.const import AGENT, SERDESER_PATH +from metagpt.const import AGENT from metagpt.environment.base_env import Environment -from metagpt.logs import get_human_input, logger +from metagpt.logs import get_human_input from metagpt.roles import Architect, ProductManager, ProjectManager, Role -from metagpt.schema import Message +from metagpt.schema import Message, SerializationMixin from metagpt.utils.common import any_to_str, any_to_str_set -from metagpt.utils.exceptions import handle_exception -from metagpt.utils.serialize import deserialize_model, serialize_model -class MGXEnv(Environment): +class MGXEnv(Environment, SerializationMixin): """MGX Environment""" # If True, fixed software sop bypassing TL is allowed, otherwise, TL will fully take over the routing @@ -26,8 +22,6 @@ class MGXEnv(Environment): direct_chat_roles: set[str] = set() # record direct chat: @role_name - default_serialization_path: ClassVar[str] = str(SERDESER_PATH / "mgxenv" / "mgxenv.json") - def _publish_message(self, message: Message, peekable: bool = True) -> bool: return super().publish_message(message, peekable) @@ -132,53 +126,13 @@ class MGXEnv(Environment): def __repr__(self): return "MGXEnv()" - @handle_exception - def serialize(self, file_path: str = None) -> str: - """Serializes the current instance to a JSON file. - - If an exception occurs, `handle_exception` will catch it and return `None`. - - Args: - file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None. - - Returns: - str: The path to the JSON file where the instance was saved. - """ - - file_path = file_path or self.default_serialization_path - - serialize_model(self, file_path, remove_unserializable=self.remove_unserializable) - logger.info(f"MGXEnv serialization successful. File saved at: {file_path}") - - return file_path - - @classmethod - @handle_exception - def deserialize(cls, file_path: str = None) -> "MGXEnv": - """Deserializes a JSON file to an instance of MGXEnv. - - If an exception occurs, `handle_exception` will catch it and return `None`. - - Args: - file_path (str, optional): The path to the JSON file to read from. Defaults to None. - - Returns: - MGXEnv: An instance of MGXEnv. - """ - - file_path = file_path or cls.default_serialization_path - - model = deserialize_model(cls, file_path) - logger.info(f"MGXEnv deserialization successful. Instance created from file: {file_path}") - - return model - def remove_unserializable(self, data: dict): """Removes unserializable content from the data dictionary. Args: data (dict): The data dictionary to clean, obtained from Pydantic's model_dump method. """ + roles = data.get("roles", {}) for role in roles.values(): diff --git a/metagpt/schema.py b/metagpt/schema.py index 648e2bd73..2431304db 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -44,6 +44,7 @@ from metagpt.const import ( MESSAGE_ROUTE_FROM, MESSAGE_ROUTE_TO, MESSAGE_ROUTE_TO_ALL, + SERDESER_PATH, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO, ) @@ -56,6 +57,8 @@ from metagpt.utils.common import ( any_to_str_set, aread, import_class, + read_json_file, + write_json_file, ) from metagpt.utils.exceptions import handle_exception from metagpt.utils.report import TaskReporter @@ -127,6 +130,77 @@ class SerializationMixin(BaseModel, extra="forbid"): cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls super().__init_subclass__(**kwargs) + @handle_exception + def serialize(self, file_path: str = None) -> str: + """Serializes the current instance to a JSON file. + + If an exception occurs, `handle_exception` will catch it and return `None`. + + Args: + file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None. + + Returns: + str: The path to the JSON file where the instance was saved. + """ + + file_path = file_path or self.get_serialization_path() + + serialized_data = self.model_dump() + self.remove_unserializable(serialized_data) + + write_json_file(file_path, serialized_data) + logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}") + + return file_path + + @classmethod + @handle_exception + def deserialize(cls, file_path: str = None) -> BaseModel: + """Deserializes a JSON file to an instance of cls. + + If an exception occurs, `handle_exception` will catch it and return `None`. + + Args: + file_path (str, optional): The path to the JSON file to read from. Defaults to None. + + Returns: + An instance of the cls. + """ + + file_path = file_path or cls.get_serialization_path() + + data: dict = read_json_file(file_path) + + model = cls(**data) + logger.info(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}") + + return model + + @classmethod + def get_serialization_path(cls) -> str: + """Get the serialization path for the class. + + This method constructs a file path for serialization based on the class name. + The default path is constructed as './workspace/storage/ClassName.json', where 'ClassName' + is the name of the class. + + Returns: + str: The path to the serialization file. + """ + + return str(SERDESER_PATH / f"{cls.__qualname__}.json") + + def remove_unserializable(self, data: dict): + """Removes unserializable content from the data dictionary. + + This method removes keys specified in the "unserializable_fields" list from the provided data dictionary. + It is intended to clean the dictionary obtained from Pydantic's `model_dump` method by removing fields + that cannot be serialized. + """ + + for key in data.get("unserializable_fields", []): + data.pop(key, None) + class SimpleMessage(BaseModel): content: str diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 814621377..c6bd8ad75 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -4,11 +4,8 @@ import copy import pickle -from typing import Callable, Optional, Type -from pydantic import BaseModel - -from metagpt.utils.common import import_class, read_json_file, write_json_file +from metagpt.utils.common import import_class def actionoutout_schema_to_mapping(schema: dict) -> dict: @@ -84,36 +81,3 @@ def deserialize_message(message_ser: str) -> "Message": message.instruct_content = ic_new return message - - -def serialize_model(model: BaseModel, file_path: str, remove_unserializable: Optional[Callable[[dict], None]] = None): - """Serializes a Pydantic model to a JSON file. - - Args: - model (BaseModel): The Pydantic model to serialize. - file_path (str): The path to the JSON file where the model will be saved. - remove_unserializable (Optional[Callable[[dict], None]]): Optional function to remove unserializable content from the serialized data. - """ - - serialized_data = model.model_dump() - - if remove_unserializable: - remove_unserializable(serialized_data) - - write_json_file(file_path, serialized_data) - - -def deserialize_model(cls: Type[BaseModel], file_path: str) -> BaseModel: - """Deserializes a JSON file to a Pydantic model. - - Args: - cls (Type[BaseModel]): The Pydantic model class to deserialize into. - file_path (str): The path to the JSON file to read from. - - Returns: - BaseModel: An instance of the Pydantic model. - """ - - data: dict = read_json_file(file_path) - - return cls(**data) diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 48f13f4a2..12e0c7aab 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -11,11 +11,12 @@ import json import pytest +from pydantic import BaseModel from metagpt.actions import Action from metagpt.actions.action_node import ActionNode from metagpt.actions.write_code import WriteCode -from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO +from metagpt.const import SERDESER_PATH, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO from metagpt.schema import ( AIMessage, CodeSummarizeContext, @@ -23,6 +24,7 @@ from metagpt.schema import ( Message, MessageQueue, Plan, + SerializationMixin, SystemMessage, Task, UMLClassAttribute, @@ -398,5 +400,72 @@ def test_create_instruct_value(name, value): assert obj.model_dump() == value +class TestUserModel(SerializationMixin, BaseModel): + name: str + value: int + + +class TestUserModelWithRemove(TestUserModel): + def remove_unserializable(self, data: dict): + for key in ["value", "__module_class_name"]: + data.pop(key, None) + + +class TestSerializationMixin: + @pytest.fixture + def mock_write_json_file(self, mocker): + return mocker.patch("metagpt.schema.write_json_file") + + @pytest.fixture + def mock_read_json_file(self, mocker): + return mocker.patch("metagpt.schema.read_json_file") + + @pytest.fixture + def mock_user_model(self): + return TestUserModel(name="test", value=42) + + def test_serialize(self, mock_write_json_file, mock_user_model): + file_path = "test.json" + + mock_user_model.serialize(file_path) + + mock_write_json_file.assert_called_once_with(file_path, mock_user_model.model_dump()) + + def test_deserialize(self, mock_read_json_file): + file_path = "test.json" + data = {"name": "test", "value": 42} + mock_read_json_file.return_value = data + + model = TestUserModel.deserialize(file_path) + + mock_read_json_file.assert_called_once_with(file_path) + assert model == TestUserModel(**data) + + def test_serialize_with_remove_unserializable(self, mock_write_json_file): + model = TestUserModelWithRemove(name="test", value=42) + file_path = "test.json" + + model.serialize(file_path) + + mock_write_json_file.assert_called_once_with(file_path, {"name": "test"}) + + def test_get_serialization_path(self): + expected_path = str(SERDESER_PATH / "TestUserModel.json") + + assert TestUserModel.get_serialization_path() == expected_path + + def test_remove_unserializable(self, mock_user_model): + data = { + "name": "example", + "unserializable_fields": ["temp_data", "debug_info"], + "temp_data": "some temporary data", + "debug_info": "some debug information", + } + mock_user_model.remove_unserializable(data) + + expected_data = {"name": "example", "unserializable_fields": ["temp_data", "debug_info"]} + assert data == expected_data + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_serialize.py b/tests/metagpt/utils/test_serialize.py index 3b20f3fa0..0ba3a8d41 100644 --- a/tests/metagpt/utils/test_serialize.py +++ b/tests/metagpt/utils/test_serialize.py @@ -6,17 +6,13 @@ from typing import List -from pydantic import BaseModel - from metagpt.actions import WritePRD from metagpt.actions.action_node import ActionNode from metagpt.schema import Message from metagpt.utils.serialize import ( actionoutout_schema_to_mapping, deserialize_message, - deserialize_model, serialize_message, - serialize_model, ) @@ -70,35 +66,3 @@ def test_serialize_and_deserialize_message(): assert new_message.content == message.content assert new_message.cause_by == message.cause_by assert new_message.instruct_content.field1 == out_data["field1"] - - -class TestUserModel(BaseModel): - name: str - value: int - - -def test_serialize_model(mocker): - model = TestUserModel(name="test", value=42) - file_path = "test.json" - mock_write_json_file = mocker.patch("metagpt.utils.serialize.write_json_file") - - # Test without remove_unserializable - serialize_model(model, file_path) - mock_write_json_file.assert_called_once_with(file_path, model.model_dump()) - - # Test with remove_unserializable - def remove_unserializable(data: dict): - data.pop("value", None) - - serialize_model(model, file_path, remove_unserializable) - mock_write_json_file.assert_called_with(file_path, {"name": "test"}) - - -def test_deserialize_model(mocker): - file_path = "test.json" - data = {"name": "test", "value": 42} - mock_read_json_file = mocker.patch("metagpt.utils.serialize.read_json_file", return_value=data) - - model = deserialize_model(TestUserModel, file_path) - mock_read_json_file.assert_called_once_with(file_path) - assert model == TestUserModel(**data)