mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-18 13:55:17 +02:00
use serialize in SerializationMixin
This commit is contained in:
parent
887f180e58
commit
587dd0cc81
5 changed files with 150 additions and 125 deletions
|
|
@ -1,5 +1,3 @@
|
|||
from typing import ClassVar
|
||||
|
||||
from metagpt.actions import (
|
||||
UserRequirement,
|
||||
WriteDesign,
|
||||
|
|
@ -8,17 +6,15 @@ from metagpt.actions import (
|
|||
WriteTest,
|
||||
)
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.const import AGENT, SERDESER_PATH
|
||||
from metagpt.const import AGENT
|
||||
from metagpt.environment.base_env import Environment
|
||||
from metagpt.logs import get_human_input, logger
|
||||
from metagpt.logs import get_human_input
|
||||
from metagpt.roles import Architect, ProductManager, ProjectManager, Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.schema import Message, SerializationMixin
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.serialize import deserialize_model, serialize_model
|
||||
|
||||
|
||||
class MGXEnv(Environment):
|
||||
class MGXEnv(Environment, SerializationMixin):
|
||||
"""MGX Environment"""
|
||||
|
||||
# If True, fixed software sop bypassing TL is allowed, otherwise, TL will fully take over the routing
|
||||
|
|
@ -26,8 +22,6 @@ class MGXEnv(Environment):
|
|||
|
||||
direct_chat_roles: set[str] = set() # record direct chat: @role_name
|
||||
|
||||
default_serialization_path: ClassVar[str] = str(SERDESER_PATH / "mgxenv" / "mgxenv.json")
|
||||
|
||||
def _publish_message(self, message: Message, peekable: bool = True) -> bool:
|
||||
return super().publish_message(message, peekable)
|
||||
|
||||
|
|
@ -132,53 +126,13 @@ class MGXEnv(Environment):
|
|||
def __repr__(self):
|
||||
return "MGXEnv()"
|
||||
|
||||
@handle_exception
|
||||
def serialize(self, file_path: str = None) -> str:
|
||||
"""Serializes the current instance to a JSON file.
|
||||
|
||||
If an exception occurs, `handle_exception` will catch it and return `None`.
|
||||
|
||||
Args:
|
||||
file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The path to the JSON file where the instance was saved.
|
||||
"""
|
||||
|
||||
file_path = file_path or self.default_serialization_path
|
||||
|
||||
serialize_model(self, file_path, remove_unserializable=self.remove_unserializable)
|
||||
logger.info(f"MGXEnv serialization successful. File saved at: {file_path}")
|
||||
|
||||
return file_path
|
||||
|
||||
@classmethod
|
||||
@handle_exception
|
||||
def deserialize(cls, file_path: str = None) -> "MGXEnv":
|
||||
"""Deserializes a JSON file to an instance of MGXEnv.
|
||||
|
||||
If an exception occurs, `handle_exception` will catch it and return `None`.
|
||||
|
||||
Args:
|
||||
file_path (str, optional): The path to the JSON file to read from. Defaults to None.
|
||||
|
||||
Returns:
|
||||
MGXEnv: An instance of MGXEnv.
|
||||
"""
|
||||
|
||||
file_path = file_path or cls.default_serialization_path
|
||||
|
||||
model = deserialize_model(cls, file_path)
|
||||
logger.info(f"MGXEnv deserialization successful. Instance created from file: {file_path}")
|
||||
|
||||
return model
|
||||
|
||||
def remove_unserializable(self, data: dict):
|
||||
"""Removes unserializable content from the data dictionary.
|
||||
|
||||
Args:
|
||||
data (dict): The data dictionary to clean, obtained from Pydantic's model_dump method.
|
||||
"""
|
||||
|
||||
roles = data.get("roles", {})
|
||||
|
||||
for role in roles.values():
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ from metagpt.const import (
|
|||
MESSAGE_ROUTE_FROM,
|
||||
MESSAGE_ROUTE_TO,
|
||||
MESSAGE_ROUTE_TO_ALL,
|
||||
SERDESER_PATH,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
)
|
||||
|
|
@ -56,6 +57,8 @@ from metagpt.utils.common import (
|
|||
any_to_str_set,
|
||||
aread,
|
||||
import_class,
|
||||
read_json_file,
|
||||
write_json_file,
|
||||
)
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.report import TaskReporter
|
||||
|
|
@ -127,6 +130,77 @@ class SerializationMixin(BaseModel, extra="forbid"):
|
|||
cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@handle_exception
|
||||
def serialize(self, file_path: str = None) -> str:
|
||||
"""Serializes the current instance to a JSON file.
|
||||
|
||||
If an exception occurs, `handle_exception` will catch it and return `None`.
|
||||
|
||||
Args:
|
||||
file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The path to the JSON file where the instance was saved.
|
||||
"""
|
||||
|
||||
file_path = file_path or self.get_serialization_path()
|
||||
|
||||
serialized_data = self.model_dump()
|
||||
self.remove_unserializable(serialized_data)
|
||||
|
||||
write_json_file(file_path, serialized_data)
|
||||
logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}")
|
||||
|
||||
return file_path
|
||||
|
||||
@classmethod
|
||||
@handle_exception
|
||||
def deserialize(cls, file_path: str = None) -> BaseModel:
|
||||
"""Deserializes a JSON file to an instance of cls.
|
||||
|
||||
If an exception occurs, `handle_exception` will catch it and return `None`.
|
||||
|
||||
Args:
|
||||
file_path (str, optional): The path to the JSON file to read from. Defaults to None.
|
||||
|
||||
Returns:
|
||||
An instance of the cls.
|
||||
"""
|
||||
|
||||
file_path = file_path or cls.get_serialization_path()
|
||||
|
||||
data: dict = read_json_file(file_path)
|
||||
|
||||
model = cls(**data)
|
||||
logger.info(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}")
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def get_serialization_path(cls) -> str:
|
||||
"""Get the serialization path for the class.
|
||||
|
||||
This method constructs a file path for serialization based on the class name.
|
||||
The default path is constructed as './workspace/storage/ClassName.json', where 'ClassName'
|
||||
is the name of the class.
|
||||
|
||||
Returns:
|
||||
str: The path to the serialization file.
|
||||
"""
|
||||
|
||||
return str(SERDESER_PATH / f"{cls.__qualname__}.json")
|
||||
|
||||
def remove_unserializable(self, data: dict):
|
||||
"""Removes unserializable content from the data dictionary.
|
||||
|
||||
This method removes keys specified in the "unserializable_fields" list from the provided data dictionary.
|
||||
It is intended to clean the dictionary obtained from Pydantic's `model_dump` method by removing fields
|
||||
that cannot be serialized.
|
||||
"""
|
||||
|
||||
for key in data.get("unserializable_fields", []):
|
||||
data.pop(key, None)
|
||||
|
||||
|
||||
class SimpleMessage(BaseModel):
|
||||
content: str
|
||||
|
|
|
|||
|
|
@ -4,11 +4,8 @@
|
|||
|
||||
import copy
|
||||
import pickle
|
||||
from typing import Callable, Optional, Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.utils.common import import_class, read_json_file, write_json_file
|
||||
from metagpt.utils.common import import_class
|
||||
|
||||
|
||||
def actionoutout_schema_to_mapping(schema: dict) -> dict:
|
||||
|
|
@ -84,36 +81,3 @@ def deserialize_message(message_ser: str) -> "Message":
|
|||
message.instruct_content = ic_new
|
||||
|
||||
return message
|
||||
|
||||
|
||||
def serialize_model(model: BaseModel, file_path: str, remove_unserializable: Optional[Callable[[dict], None]] = None):
|
||||
"""Serializes a Pydantic model to a JSON file.
|
||||
|
||||
Args:
|
||||
model (BaseModel): The Pydantic model to serialize.
|
||||
file_path (str): The path to the JSON file where the model will be saved.
|
||||
remove_unserializable (Optional[Callable[[dict], None]]): Optional function to remove unserializable content from the serialized data.
|
||||
"""
|
||||
|
||||
serialized_data = model.model_dump()
|
||||
|
||||
if remove_unserializable:
|
||||
remove_unserializable(serialized_data)
|
||||
|
||||
write_json_file(file_path, serialized_data)
|
||||
|
||||
|
||||
def deserialize_model(cls: Type[BaseModel], file_path: str) -> BaseModel:
|
||||
"""Deserializes a JSON file to a Pydantic model.
|
||||
|
||||
Args:
|
||||
cls (Type[BaseModel]): The Pydantic model class to deserialize into.
|
||||
file_path (str): The path to the JSON file to read from.
|
||||
|
||||
Returns:
|
||||
BaseModel: An instance of the Pydantic model.
|
||||
"""
|
||||
|
||||
data: dict = read_json_file(file_path)
|
||||
|
||||
return cls(**data)
|
||||
|
|
|
|||
|
|
@ -11,11 +11,12 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.const import SERDESER_PATH, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.schema import (
|
||||
AIMessage,
|
||||
CodeSummarizeContext,
|
||||
|
|
@ -23,6 +24,7 @@ from metagpt.schema import (
|
|||
Message,
|
||||
MessageQueue,
|
||||
Plan,
|
||||
SerializationMixin,
|
||||
SystemMessage,
|
||||
Task,
|
||||
UMLClassAttribute,
|
||||
|
|
@ -398,5 +400,72 @@ def test_create_instruct_value(name, value):
|
|||
assert obj.model_dump() == value
|
||||
|
||||
|
||||
class TestUserModel(SerializationMixin, BaseModel):
|
||||
name: str
|
||||
value: int
|
||||
|
||||
|
||||
class TestUserModelWithRemove(TestUserModel):
|
||||
def remove_unserializable(self, data: dict):
|
||||
for key in ["value", "__module_class_name"]:
|
||||
data.pop(key, None)
|
||||
|
||||
|
||||
class TestSerializationMixin:
|
||||
@pytest.fixture
|
||||
def mock_write_json_file(self, mocker):
|
||||
return mocker.patch("metagpt.schema.write_json_file")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_read_json_file(self, mocker):
|
||||
return mocker.patch("metagpt.schema.read_json_file")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_model(self):
|
||||
return TestUserModel(name="test", value=42)
|
||||
|
||||
def test_serialize(self, mock_write_json_file, mock_user_model):
|
||||
file_path = "test.json"
|
||||
|
||||
mock_user_model.serialize(file_path)
|
||||
|
||||
mock_write_json_file.assert_called_once_with(file_path, mock_user_model.model_dump())
|
||||
|
||||
def test_deserialize(self, mock_read_json_file):
|
||||
file_path = "test.json"
|
||||
data = {"name": "test", "value": 42}
|
||||
mock_read_json_file.return_value = data
|
||||
|
||||
model = TestUserModel.deserialize(file_path)
|
||||
|
||||
mock_read_json_file.assert_called_once_with(file_path)
|
||||
assert model == TestUserModel(**data)
|
||||
|
||||
def test_serialize_with_remove_unserializable(self, mock_write_json_file):
|
||||
model = TestUserModelWithRemove(name="test", value=42)
|
||||
file_path = "test.json"
|
||||
|
||||
model.serialize(file_path)
|
||||
|
||||
mock_write_json_file.assert_called_once_with(file_path, {"name": "test"})
|
||||
|
||||
def test_get_serialization_path(self):
|
||||
expected_path = str(SERDESER_PATH / "TestUserModel.json")
|
||||
|
||||
assert TestUserModel.get_serialization_path() == expected_path
|
||||
|
||||
def test_remove_unserializable(self, mock_user_model):
|
||||
data = {
|
||||
"name": "example",
|
||||
"unserializable_fields": ["temp_data", "debug_info"],
|
||||
"temp_data": "some temporary data",
|
||||
"debug_info": "some debug information",
|
||||
}
|
||||
mock_user_model.remove_unserializable(data)
|
||||
|
||||
expected_data = {"name": "example", "unserializable_fields": ["temp_data", "debug_info"]}
|
||||
assert data == expected_data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -6,17 +6,13 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.serialize import (
|
||||
actionoutout_schema_to_mapping,
|
||||
deserialize_message,
|
||||
deserialize_model,
|
||||
serialize_message,
|
||||
serialize_model,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -70,35 +66,3 @@ def test_serialize_and_deserialize_message():
|
|||
assert new_message.content == message.content
|
||||
assert new_message.cause_by == message.cause_by
|
||||
assert new_message.instruct_content.field1 == out_data["field1"]
|
||||
|
||||
|
||||
class TestUserModel(BaseModel):
|
||||
name: str
|
||||
value: int
|
||||
|
||||
|
||||
def test_serialize_model(mocker):
|
||||
model = TestUserModel(name="test", value=42)
|
||||
file_path = "test.json"
|
||||
mock_write_json_file = mocker.patch("metagpt.utils.serialize.write_json_file")
|
||||
|
||||
# Test without remove_unserializable
|
||||
serialize_model(model, file_path)
|
||||
mock_write_json_file.assert_called_once_with(file_path, model.model_dump())
|
||||
|
||||
# Test with remove_unserializable
|
||||
def remove_unserializable(data: dict):
|
||||
data.pop("value", None)
|
||||
|
||||
serialize_model(model, file_path, remove_unserializable)
|
||||
mock_write_json_file.assert_called_with(file_path, {"name": "test"})
|
||||
|
||||
|
||||
def test_deserialize_model(mocker):
|
||||
file_path = "test.json"
|
||||
data = {"name": "test", "value": 42}
|
||||
mock_read_json_file = mocker.patch("metagpt.utils.serialize.read_json_file", return_value=data)
|
||||
|
||||
model = deserialize_model(TestUserModel, file_path)
|
||||
mock_read_json_file.assert_called_once_with(file_path)
|
||||
assert model == TestUserModel(**data)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue