use serialize in SerializationMixin

This commit is contained in:
seehi 2024-08-12 14:37:00 +08:00
parent 887f180e58
commit 587dd0cc81
5 changed files with 150 additions and 125 deletions

View file

@ -44,6 +44,7 @@ from metagpt.const import (
MESSAGE_ROUTE_FROM,
MESSAGE_ROUTE_TO,
MESSAGE_ROUTE_TO_ALL,
SERDESER_PATH,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
)
@ -56,6 +57,8 @@ from metagpt.utils.common import (
any_to_str_set,
aread,
import_class,
read_json_file,
write_json_file,
)
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.report import TaskReporter
@ -127,6 +130,77 @@ class SerializationMixin(BaseModel, extra="forbid"):
cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
super().__init_subclass__(**kwargs)
@handle_exception
def serialize(self, file_path: str = None) -> str:
"""Serializes the current instance to a JSON file.
If an exception occurs, `handle_exception` will catch it and return `None`.
Args:
file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None.
Returns:
str: The path to the JSON file where the instance was saved.
"""
file_path = file_path or self.get_serialization_path()
serialized_data = self.model_dump()
self.remove_unserializable(serialized_data)
write_json_file(file_path, serialized_data)
logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}")
return file_path
@classmethod
@handle_exception
def deserialize(cls, file_path: str = None) -> BaseModel:
"""Deserializes a JSON file to an instance of cls.
If an exception occurs, `handle_exception` will catch it and return `None`.
Args:
file_path (str, optional): The path to the JSON file to read from. Defaults to None.
Returns:
An instance of the cls.
"""
file_path = file_path or cls.get_serialization_path()
data: dict = read_json_file(file_path)
model = cls(**data)
logger.info(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}")
return model
@classmethod
def get_serialization_path(cls) -> str:
"""Get the serialization path for the class.
This method constructs a file path for serialization based on the class name.
The default path is constructed as './workspace/storage/ClassName.json', where 'ClassName'
is the name of the class.
Returns:
str: The path to the serialization file.
"""
return str(SERDESER_PATH / f"{cls.__qualname__}.json")
def remove_unserializable(self, data: dict):
"""Removes unserializable content from the data dictionary.
This method removes keys specified in the "unserializable_fields" list from the provided data dictionary.
It is intended to clean the dictionary obtained from Pydantic's `model_dump` method by removing fields
that cannot be serialized.
"""
for key in data.get("unserializable_fields", []):
data.pop(key, None)
class SimpleMessage(BaseModel):
content: str