diff --git a/metagpt/schema.py b/metagpt/schema.py index e5df6fb10..1bb07aa95 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -108,9 +108,9 @@ class Message(BaseModel): send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL}) def __init__(self, **kwargs): - instruct_content = kwargs.get("instruct_content", None) - if instruct_content and not isinstance(instruct_content, BaseModel): - ic = instruct_content + ic = kwargs.get("instruct_content", None) + if ic and not isinstance(ic, BaseModel) and "class" in ic: + # compatible with custom-defined ActionOutput mapping = actionoutput_str_to_mapping(ic["mapping"]) actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import @@ -140,13 +140,17 @@ class Message(BaseModel): def dict(self, *args, **kwargs) -> "DictStrAny": """ overwrite the `dict` to dump dynamic pydantic model""" obj_dict = super(Message, self).dict(*args, **kwargs) - ic = self.instruct_content # deal custom-defined action + ic = self.instruct_content if ic: + # compatible with custom-defined ActionOutput schema = ic.schema() - mapping = actionoutout_schema_to_mapping(schema) - mapping = actionoutput_mapping_to_str(mapping) + # `Documents` contain definitions + if "definitions" not in schema: + # TODO refine with nested BaseModel + mapping = actionoutout_schema_to_mapping(schema) + mapping = actionoutput_mapping_to_str(mapping) - obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} return obj_dict def __str__(self): diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index a445c9f31..ab7a3d99e 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -450,14 +450,12 @@ def serialize_decorator(func): async def wrapper(self, *args, **kwargs): try: result = await func(self, *args, **kwargs) - self.serialize() # Team.serialize return result except KeyboardInterrupt as kbi: logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") - self.serialize() # Team.serialize except Exception as exp: logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}") - self.serialize() # Team.serialize + self.serialize() # Team.serialize return wrapper diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 1d90e8de8..a52dc8f45 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -62,7 +62,7 @@ def serialize_general_message(message: "Message") -> dict: message_cp = copy.deepcopy(message) ic = message_cp.instruct_content if ic: - # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly + # model create by pydantic create_model like `pydantic.main.prd`, can't load directly schema = ic.schema() mapping = actionoutout_schema_to_mapping(schema) mapping = actionoutput_mapping_to_str(mapping) diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index e87df9b52..d6a477b0e 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -10,6 +10,7 @@ import pytest from metagpt.const import SERDESER_PATH from metagpt.roles import ProjectManager, ProductManager, Architect from metagpt.team import Team +from metagpt.logs import logger from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path, ActionOK @@ -120,6 +121,8 @@ async def test_team_recover_multi_roles_save(): company.run_project(idea) await company.run(n_round=4) + logger.info("Team recovered") + new_company = Team.recover(stg_path) new_company.run_project(idea)