serialize mgxenv

This commit is contained in:
seehi 2024-08-09 10:31:03 +08:00
parent 5f86247c0d
commit 98ac5fbce3
10 changed files with 185 additions and 30 deletions

View file

@ -1,3 +1,5 @@
from typing import ClassVar
from metagpt.actions import (
UserRequirement,
WriteDesign,
@ -6,12 +8,14 @@ from metagpt.actions import (
WriteTest,
)
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.const import AGENT
from metagpt.const import AGENT, SERDESER_PATH
from metagpt.environment.base_env import Environment
from metagpt.logs import get_human_input
from metagpt.logs import get_human_input, logger
from metagpt.roles import Architect, ProductManager, ProjectManager, Role
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, any_to_str_set
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.serialize import deserialize_model, serialize_model
class MGXEnv(Environment):
@ -22,6 +26,11 @@ class MGXEnv(Environment):
direct_chat_roles: set[str] = set() # record direct chat: @role_name
default_serialization_path: ClassVar[str] = str(SERDESER_PATH / "mgxenv" / "mgxenv.json")
def __repr__(self):
return "MGXEnv()"
def _publish_message(self, message: Message, peekable: bool = True) -> bool:
return super().publish_message(message, peekable)
@ -121,5 +130,54 @@ class MGXEnv(Environment):
converted_msg.content = f"from {sent_from} to {converted_msg.send_to}: {converted_msg.content}"
return converted_msg
def __repr__(self):
return "MGXEnv()"
@handle_exception
def serialize(self, file_path: str = None) -> str:
"""Serializes the current instance to a JSON file.
If an exception occurs, `handle_exception` will catch it and return `None`.
Args:
file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None.
Returns:
str: The path to the JSON file where the instance was saved.
"""
file_path = file_path or self.default_serialization_path
serialize_model(self, file_path, remove_unserializable=self.remove_unserializable)
logger.info(f"MGXEnv serialization successful. File saved at: {file_path}")
return file_path
@classmethod
@handle_exception
def deserialize(cls, file_path: str = None) -> "MGXEnv":
"""Deserializes a JSON file to an instance of MGXEnv.
If an exception occurs, `handle_exception` will catch it and return `None`.
Args:
file_path (str, optional): The path to the JSON file to read from. Defaults to None.
Returns:
MGXEnv: An instance of MGXEnv.
"""
file_path = file_path or cls.default_serialization_path
model = deserialize_model(cls, file_path)
logger.info(f"MGXEnv deserialization successful. Instance created from file: {file_path}")
return model
def remove_unserializable(self, data: dict):
"""Removes unserializable content from the data dictionary.
Args:
data (dict): The data dictionary to clean, obtained from Pydantic's model_dump method.
"""
roles = data.get("roles", {})
for role in roles.values():
[role.pop(key, None) for key in role.get("unserializable_fields", [])]