use serialize in SerializationMixin

This commit is contained in:
seehi 2024-08-12 14:37:00 +08:00
parent 887f180e58
commit 587dd0cc81
5 changed files with 150 additions and 125 deletions

View file

@ -1,5 +1,3 @@
from typing import ClassVar
from metagpt.actions import (
UserRequirement,
WriteDesign,
@ -8,17 +6,15 @@ from metagpt.actions import (
WriteTest,
)
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.const import AGENT, SERDESER_PATH
from metagpt.const import AGENT
from metagpt.environment.base_env import Environment
from metagpt.logs import get_human_input, logger
from metagpt.logs import get_human_input
from metagpt.roles import Architect, ProductManager, ProjectManager, Role
from metagpt.schema import Message
from metagpt.schema import Message, SerializationMixin
from metagpt.utils.common import any_to_str, any_to_str_set
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.serialize import deserialize_model, serialize_model
class MGXEnv(Environment):
class MGXEnv(Environment, SerializationMixin):
"""MGX Environment"""
# If True, fixed software sop bypassing TL is allowed, otherwise, TL will fully take over the routing
@ -26,8 +22,6 @@ class MGXEnv(Environment):
direct_chat_roles: set[str] = set() # record direct chat: @role_name
default_serialization_path: ClassVar[str] = str(SERDESER_PATH / "mgxenv" / "mgxenv.json")
def _publish_message(self, message: Message, peekable: bool = True) -> bool:
return super().publish_message(message, peekable)
@ -132,53 +126,13 @@ class MGXEnv(Environment):
def __repr__(self):
return "MGXEnv()"
@handle_exception
def serialize(self, file_path: str = None) -> str:
"""Serializes the current instance to a JSON file.
If an exception occurs, `handle_exception` will catch it and return `None`.
Args:
file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None.
Returns:
str: The path to the JSON file where the instance was saved.
"""
file_path = file_path or self.default_serialization_path
serialize_model(self, file_path, remove_unserializable=self.remove_unserializable)
logger.info(f"MGXEnv serialization successful. File saved at: {file_path}")
return file_path
@classmethod
@handle_exception
def deserialize(cls, file_path: str = None) -> "MGXEnv":
"""Deserializes a JSON file to an instance of MGXEnv.
If an exception occurs, `handle_exception` will catch it and return `None`.
Args:
file_path (str, optional): The path to the JSON file to read from. Defaults to None.
Returns:
MGXEnv: An instance of MGXEnv.
"""
file_path = file_path or cls.default_serialization_path
model = deserialize_model(cls, file_path)
logger.info(f"MGXEnv deserialization successful. Instance created from file: {file_path}")
return model
def remove_unserializable(self, data: dict):
"""Removes unserializable content from the data dictionary.
Args:
data (dict): The data dictionary to clean, obtained from Pydantic's model_dump method.
"""
roles = data.get("roles", {})
for role in roles.values():

View file

@ -44,6 +44,7 @@ from metagpt.const import (
MESSAGE_ROUTE_FROM,
MESSAGE_ROUTE_TO,
MESSAGE_ROUTE_TO_ALL,
SERDESER_PATH,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
)
@ -56,6 +57,8 @@ from metagpt.utils.common import (
any_to_str_set,
aread,
import_class,
read_json_file,
write_json_file,
)
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.report import TaskReporter
@ -127,6 +130,77 @@ class SerializationMixin(BaseModel, extra="forbid"):
cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
super().__init_subclass__(**kwargs)
@handle_exception
def serialize(self, file_path: str = None) -> str:
"""Serializes the current instance to a JSON file.
If an exception occurs, `handle_exception` will catch it and return `None`.
Args:
file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None.
Returns:
str: The path to the JSON file where the instance was saved.
"""
file_path = file_path or self.get_serialization_path()
serialized_data = self.model_dump()
self.remove_unserializable(serialized_data)
write_json_file(file_path, serialized_data)
logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}")
return file_path
@classmethod
@handle_exception
def deserialize(cls, file_path: str = None) -> BaseModel:
"""Deserializes a JSON file to an instance of cls.
If an exception occurs, `handle_exception` will catch it and return `None`.
Args:
file_path (str, optional): The path to the JSON file to read from. Defaults to None.
Returns:
An instance of the cls.
"""
file_path = file_path or cls.get_serialization_path()
data: dict = read_json_file(file_path)
model = cls(**data)
logger.info(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}")
return model
@classmethod
def get_serialization_path(cls) -> str:
"""Get the serialization path for the class.
This method constructs a file path for serialization based on the class name.
The default path is constructed as './workspace/storage/ClassName.json', where 'ClassName'
is the name of the class.
Returns:
str: The path to the serialization file.
"""
return str(SERDESER_PATH / f"{cls.__qualname__}.json")
def remove_unserializable(self, data: dict):
"""Removes unserializable content from the data dictionary.
This method removes keys specified in the "unserializable_fields" list from the provided data dictionary.
It is intended to clean the dictionary obtained from Pydantic's `model_dump` method by removing fields
that cannot be serialized.
"""
for key in data.get("unserializable_fields", []):
data.pop(key, None)
class SimpleMessage(BaseModel):
content: str

View file

@ -4,11 +4,8 @@
import copy
import pickle
from typing import Callable, Optional, Type
from pydantic import BaseModel
from metagpt.utils.common import import_class, read_json_file, write_json_file
from metagpt.utils.common import import_class
def actionoutout_schema_to_mapping(schema: dict) -> dict:
@ -84,36 +81,3 @@ def deserialize_message(message_ser: str) -> "Message":
message.instruct_content = ic_new
return message
def serialize_model(model: BaseModel, file_path: str, remove_unserializable: Optional[Callable[[dict], None]] = None):
"""Serializes a Pydantic model to a JSON file.
Args:
model (BaseModel): The Pydantic model to serialize.
file_path (str): The path to the JSON file where the model will be saved.
remove_unserializable (Optional[Callable[[dict], None]]): Optional function to remove unserializable content from the serialized data.
"""
serialized_data = model.model_dump()
if remove_unserializable:
remove_unserializable(serialized_data)
write_json_file(file_path, serialized_data)
def deserialize_model(cls: Type[BaseModel], file_path: str) -> BaseModel:
"""Deserializes a JSON file to a Pydantic model.
Args:
cls (Type[BaseModel]): The Pydantic model class to deserialize into.
file_path (str): The path to the JSON file to read from.
Returns:
BaseModel: An instance of the Pydantic model.
"""
data: dict = read_json_file(file_path)
return cls(**data)