From d0edc555b0b9f35f8099e5612e61d277959bd23a Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 28 Dec 2023 16:07:39 +0800 Subject: [PATCH] add SerDeserMixin for child-classes --- metagpt/actions/action.py | 18 +----- metagpt/environment.py | 26 ++------ metagpt/memory/memory.py | 17 ++---- metagpt/roles/role.py | 45 +++----------- metagpt/schema.py | 61 ++++++++++++++++++- .../serialize_deserialize/test_action.py | 5 ++ .../serialize_deserialize/test_memory.py | 3 + .../serialize_deserialize/test_polymorphic.py | 58 ++++++++++++++++++ .../serialize_deserialize/test_role.py | 15 +++++ .../serialize_deserialize/test_schema.py | 6 +- .../test_serdeser_base.py | 13 ++-- 11 files changed, 171 insertions(+), 96 deletions(-) create mode 100644 tests/metagpt/serialize_deserialize/test_polymorphic.py diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index f8b857d16..5dbb36332 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -10,7 +10,7 @@ from __future__ import annotations from typing import Any, Optional, Union -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ConfigDict, Field from metagpt.actions.action_node import ActionNode from metagpt.llm import LLM @@ -19,13 +19,12 @@ from metagpt.schema import ( CodeSummarizeContext, CodingContext, RunCodeContext, + SerDeserMixin, TestingContext, ) -action_subclass_registry = {} - -class Action(BaseModel): +class Action(SerDeserMixin, is_polymorphic_base=True): model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"]) name: str = "" @@ -35,9 +34,6 @@ class Action(BaseModel): desc: str = "" # for skill manager node: ActionNode = Field(default=None, exclude=True) - # builtin variables - builtin_class_name: str = "" - def __init_with_instruction(self, instruction: str): """Initialize action with instruction""" self.node = ActionNode(key=self.name, expected_type=str, instruction=instruction, example="", schema="raw") @@ -46,17 +42,9 @@ class Action(BaseModel): def __init__(self, **data: Any): super().__init__(**data) - # deserialize child classes dynamically for inherited `action` - object.__setattr__(self, "builtin_class_name", self.__class__.__name__) - self.model_fields["builtin_class_name"].default = self.__class__.__name__ - if "instruction" in data: self.__init_with_instruction(data["instruction"]) - def __init_subclass__(cls, **kwargs: Any) -> None: - super().__init_subclass__(**kwargs) - action_subclass_registry[cls.__name__] = cls - def set_prefix(self, prefix): """Set prefix for later usage""" self.prefix = prefix diff --git a/metagpt/environment.py b/metagpt/environment.py index b9353d9d9..ddb9ad9dd 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -13,13 +13,13 @@ """ import asyncio from pathlib import Path -from typing import Iterable, Set, Union +from typing import Iterable, Set -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator from metagpt.config import CONFIG from metagpt.logs import logger -from metagpt.roles.role import Role, role_subclass_registry +from metagpt.roles.role import Role from metagpt.schema import Message from metagpt.utils.common import is_subscribed, read_json_file, write_json_file @@ -32,28 +32,10 @@ class Environment(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) desc: str = Field(default="") # 环境描述 - roles: dict[str, Role] = Field(default_factory=dict, validate_default=True) + roles: dict[str, SerializeAsAny[Role]] = Field(default_factory=dict, validate_default=True) members: dict[Role, Set] = Field(default_factory=dict, exclude=True) history: str = "" # For debug - @field_validator("roles", mode="before") - @classmethod - def check_roles(cls, roles: dict[str, Union[Role, dict]]) -> dict[str, Role]: - new_roles = dict() - for role_key, role in roles.items(): - if isinstance(role, dict): - item_class_name = role.get("builtin_class_name", None) - if item_class_name: - for name, subclass in role_subclass_registry.items(): - registery_class_name = subclass.model_fields["builtin_class_name"].default - if item_class_name == registery_class_name: - new_role = subclass(**role) - break - new_roles[role_key] = new_role - else: - new_roles[role_key] = role - return new_roles - @model_validator(mode="after") def init_roles(self): self.add_roles(self.roles.values()) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 93f1774dc..593409648 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -8,9 +8,9 @@ """ from collections import defaultdict from pathlib import Path -from typing import Iterable, Set +from typing import DefaultDict, Iterable, Set -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, SerializeAsAny from metagpt.const import IGNORED_MESSAGE_ID from metagpt.schema import Message @@ -25,19 +25,10 @@ from metagpt.utils.common import ( class Memory(BaseModel): """The most basic memory: super-memory""" - storage: list[Message] = [] - index: dict[str, list[Message]] = Field(default_factory=defaultdict(list)) + storage: list[SerializeAsAny[Message]] = [] + index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list)) ignore_id: bool = False - def __init__(self, **kwargs): - index = kwargs.get("index", {}) - new_index = defaultdict(list) - for action_str, value in index.items(): - new_index[action_str] = [Message(**item_dict) for item_dict in value] - kwargs["index"] = new_index - super(Memory, self).__init__(**kwargs) - self.index = new_index - def serialize(self, stg_path: Path): """stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" memory_path = stg_path.joinpath("memory.json") diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 1d37228e3..623832083 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -24,12 +24,11 @@ from __future__ import annotations from enum import Enum from pathlib import Path -from typing import Any, Iterable, Optional, Set, Type, Union +from typing import Any, Iterable, Optional, Set, Type -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator from metagpt.actions import Action, ActionOutput -from metagpt.actions.action import action_subclass_registry from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement from metagpt.const import SERDESER_PATH @@ -37,7 +36,7 @@ from metagpt.llm import LLM, HumanProvider from metagpt.logs import logger from metagpt.memory import Memory from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.schema import Message, MessageQueue +from metagpt.schema import Message, MessageQueue, SerDeserMixin from metagpt.utils.common import ( any_to_name, any_to_str, @@ -127,10 +126,7 @@ class RoleContext(BaseModel): return self.memory.get() -role_subclass_registry = {} - - -class Role(BaseModel): +class Role(SerDeserMixin, is_polymorphic_base=True): """Role/Agent""" model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"]) @@ -147,34 +143,16 @@ class Role(BaseModel): ) # Each role has its own LLM, use different system message role_id: str = "" states: list[str] = [] - actions: list[Action] = Field(default=[], validate_default=True) + actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True) rc: RoleContext = Field(default_factory=RoleContext) subscription: set[str] = set() # builtin variables recovered: bool = False # to tag if a recovered role latest_observed_msg: Optional[Message] = None # record the latest observed message when interrupted - builtin_class_name: str = "" __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` - @field_validator("actions", mode="before") - @classmethod - def check_actions(cls, actions: list[Union[dict, Action]]) -> list[Action]: - new_actions = [] - for action in actions: - new_action = action - if isinstance(action, dict): - item_class_name = action.get("builtin_class_name", None) - if item_class_name: - for name, subclass in action_subclass_registry.items(): - registery_class_name = subclass.model_fields["builtin_class_name"].default - if item_class_name == registery_class_name: - new_action = subclass(**action) - break - new_actions.append(new_action) - return new_actions - @model_validator(mode="after") def check_subscription(self) -> set: if not self.subscription: @@ -191,20 +169,11 @@ class Role(BaseModel): super().__init__(**data) self.llm.system_prompt = self._get_prefix() - - # deserialize child classes dynamically for inherited `role` - object.__setattr__(self, "builtin_class_name", self.__class__.__name__) - self.model_fields["builtin_class_name"].default = self.__class__.__name__ - self._watch(data.get("watch") or [UserRequirement]) - def __init_subclass__(cls, **kwargs: Any) -> None: - super().__init_subclass__(**kwargs) - role_subclass_registry[cls.__name__] = cls - def _reset(self): - object.__setattr__(self, "states", []) - object.__setattr__(self, "actions", []) + self.states = [] + self.actions = [] @property def _setting(self): diff --git a/metagpt/schema.py b/metagpt/schema.py index 2ceba2251..46064472f 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -23,7 +23,7 @@ from abc import ABC from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path -from typing import Any, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union from pydantic import ( BaseModel, @@ -33,6 +33,7 @@ from pydantic import ( field_serializer, field_validator, ) +from pydantic_core import core_schema from metagpt.config import CONFIG from metagpt.const import ( @@ -53,6 +54,64 @@ from metagpt.utils.serialize import ( ) +class SerDeserMixin(BaseModel): + """SereDeserMixin for subclass' ser&deser""" + + __is_polymorphic_base = False + __subclasses_map__ = {} + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type["SerDeserMixin"], handler: Callable[[Any], core_schema.CoreSchema] + ) -> core_schema.CoreSchema: + schema = handler(source) + og_schema_ref = schema["ref"] + schema["ref"] += ":mixin" + + return core_schema.no_info_before_validator_function( + cls.__deserialize_with_real_type__, + schema=schema, + ref=og_schema_ref, + serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__), + ) + + @classmethod + def __serialize_add_class_type__( + cls, + value, + handler: core_schema.SerializerFunctionWrapHandler, + ) -> Any: + ret = handler(value) + if not len(cls.__subclasses__()): + # only subclass add `__module_class_name` + ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}" + return ret + + @classmethod + def __deserialize_with_real_type__(cls, value: Any): + if not isinstance(value, dict): + return value + + if not cls.__is_polymorphic_base or (len(cls.__subclasses__()) and "__module_class_name" not in value): + # add right condition to init BaseClass like Action() + return value + module_class_name = value.get("__module_class_name", None) + if module_class_name is None: + raise ValueError("Missing field: __module_class_name") + + class_type = cls.__subclasses_map__.get(module_class_name, None) + + if class_type is None: + raise TypeError("Trying to instantiate {module_class_name} which not defined yet.") + + return class_type(**value) + + def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs): + cls.__is_polymorphic_base = is_polymorphic_base + cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls + super().__init_subclass__(**kwargs) + + class SimpleMessage(BaseModel): content: str role: str diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py index 4afe1b33e..b3206696b 100644 --- a/tests/metagpt/serialize_deserialize/test_action.py +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -13,6 +13,11 @@ def test_action_serialize(): ser_action_dict = action.model_dump() assert "name" in ser_action_dict assert "llm" not in ser_action_dict # not export + assert "__module_class_name" not in ser_action_dict + + action = Action(name="test") + ser_action_dict = action.model_dump() + assert "test" in ser_action_dict["name"] @pytest.mark.asyncio diff --git a/tests/metagpt/serialize_deserialize/test_memory.py b/tests/metagpt/serialize_deserialize/test_memory.py index 2a66434e1..aa3e2a465 100644 --- a/tests/metagpt/serialize_deserialize/test_memory.py +++ b/tests/metagpt/serialize_deserialize/test_memory.py @@ -35,6 +35,9 @@ def test_memory_serdeser(): assert new_memory.storage[-1].cause_by == any_to_str(WriteDesign) assert new_msg2.role == "Boss" + memory = Memory(storage=[msg1, msg2], index={msg1.cause_by: [msg1], msg2.cause_by: [msg2]}) + assert memory.count() == 2 + def test_memory_serdeser_save(): msg1 = Message(role="User", content="write a 2048 game", cause_by=UserRequirement) diff --git a/tests/metagpt/serialize_deserialize/test_polymorphic.py b/tests/metagpt/serialize_deserialize/test_polymorphic.py new file mode 100644 index 000000000..ed0482c34 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_polymorphic.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of polymorphic conditions + +from pydantic import BaseModel, ConfigDict, SerializeAsAny + +from metagpt.actions import Action +from tests.metagpt.serialize_deserialize.test_serdeser_base import ( + ActionOKV2, + ActionPass, +) + + +class ActionSubClasses(BaseModel): + actions: list[SerializeAsAny[Action]] = [] + + +class ActionSubClassesNoSAA(BaseModel): + """without SerializeAsAny""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + actions: list[Action] = [] + + +def test_serialize_as_any(): + """test subclasses of action with different fields in ser&deser""" + # ActionOKV2 with a extra field `extra_field` + action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()]) + action_subcls_dict = action_subcls.model_dump() + assert action_subcls_dict["actions"][0]["extra_field"] == ActionOKV2().extra_field + + +def test_no_serialize_as_any(): + # ActionOKV2 with a extra field `extra_field` + action_subcls = ActionSubClassesNoSAA(actions=[ActionOKV2(), ActionPass()]) + action_subcls_dict = action_subcls.model_dump() + # without `SerializeAsAny`, it will serialize as Action + assert "extra_field" not in action_subcls_dict["actions"][0] + + +def test_polymorphic(): + _ = ActionOKV2( + **{"name": "ActionOKV2", "context": "", "prefix": "", "desc": "", "extra_field": "ActionOKV2 Extra Info"} + ) + + action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()]) + action_subcls_dict = action_subcls.model_dump() + + assert "__module_class_name" in action_subcls_dict["actions"][0] + + new_action_subcls = ActionSubClasses(**action_subcls_dict) + assert isinstance(new_action_subcls.actions[0], ActionOKV2) + assert isinstance(new_action_subcls.actions[1], ActionPass) + + new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict) + assert isinstance(new_action_subcls.actions[0], ActionOKV2) + assert isinstance(new_action_subcls.actions[1], ActionPass) diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index 3e3d04dbc..d38797baf 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -6,6 +6,7 @@ import shutil import pytest +from pydantic import BaseModel, SerializeAsAny from metagpt.actions import WriteCode from metagpt.actions.add_requirement import UserRequirement @@ -37,6 +38,20 @@ def test_roles(): assert len(role_d.actions) == 1 +def test_role_subclasses(): + """test subclasses of role with same fields in ser&deser""" + + class RoleSubClasses(BaseModel): + roles: list[SerializeAsAny[Role]] = [] + + role_subcls = RoleSubClasses(roles=[RoleA(), RoleB()]) + role_subcls_dict = role_subcls.model_dump() + + new_role_subcls = RoleSubClasses(**role_subcls_dict) + assert isinstance(new_role_subcls.roles[0], RoleA) + assert isinstance(new_role_subcls.roles[1], RoleB) + + def test_role_serialize(): role = Role() ser_role_dict = role.model_dump() diff --git a/tests/metagpt/serialize_deserialize/test_schema.py b/tests/metagpt/serialize_deserialize/test_schema.py index 6aec298a0..e793079f0 100644 --- a/tests/metagpt/serialize_deserialize/test_schema.py +++ b/tests/metagpt/serialize_deserialize/test_schema.py @@ -7,8 +7,8 @@ from metagpt.actions.write_code import WriteCode from metagpt.schema import Document, Documents, Message from metagpt.utils.common import any_to_str from tests.metagpt.serialize_deserialize.test_serdeser_base import ( + MockICMessage, MockMessage, - TestICMessage, ) @@ -28,10 +28,10 @@ def test_message_serdeser(): assert new_message.instruct_content != ic_obj(**out_data) # TODO find why `!=` assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump() - message = Message(content="test_ic", instruct_content=TestICMessage()) + message = Message(content="test_ic", instruct_content=MockICMessage()) ser_data = message.model_dump() new_message = Message(**ser_data) - assert new_message.instruct_content != TestICMessage() # TODO + assert new_message.instruct_content != MockICMessage() # TODO message = Message(content="test_documents", instruct_content=Documents(docs={"doc1": Document(content="test doc")})) ser_data = message.model_dump() diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index dc8cc76d6..daa46c99c 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -16,7 +16,7 @@ from metagpt.roles.role import Role, RoleReactMode serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage") -class TestICMessage(BaseModel): +class MockICMessage(BaseModel): content: str = "test_ic" @@ -28,7 +28,7 @@ class MockMessage(BaseModel): class ActionPass(Action): - name: str = Field(default="ActionPass") + name: str = "ActionPass" async def run(self, messages: list["Message"]) -> ActionOutput: await asyncio.sleep(5) # sleep to make other roles can watch the executed Message @@ -40,7 +40,7 @@ class ActionPass(Action): class ActionOK(Action): - name: str = Field(default="ActionOK") + name: str = "ActionOK" async def run(self, messages: list["Message"]) -> str: await asyncio.sleep(5) @@ -48,12 +48,17 @@ class ActionOK(Action): class ActionRaise(Action): - name: str = Field(default="ActionRaise") + name: str = "ActionRaise" async def run(self, messages: list["Message"]) -> str: raise RuntimeError("parse error in ActionRaise") +class ActionOKV2(Action): + name: str = "ActionOKV2" + extra_field: str = "ActionOKV2 Extra Info" + + class RoleA(Role): name: str = Field(default="RoleA") profile: str = Field(default="Role A")