add SerDeserMixin for child-classes

This commit is contained in:
better629 2023-12-28 16:07:39 +08:00
parent 2dbaee0ff2
commit d0edc555b0
11 changed files with 171 additions and 96 deletions

View file

@ -10,7 +10,7 @@ from __future__ import annotations
from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, Field
from pydantic import ConfigDict, Field
from metagpt.actions.action_node import ActionNode
from metagpt.llm import LLM
@ -19,13 +19,12 @@ from metagpt.schema import (
CodeSummarizeContext,
CodingContext,
RunCodeContext,
SerDeserMixin,
TestingContext,
)
action_subclass_registry = {}
class Action(BaseModel):
class Action(SerDeserMixin, is_polymorphic_base=True):
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
name: str = ""
@ -35,9 +34,6 @@ class Action(BaseModel):
desc: str = "" # for skill manager
node: ActionNode = Field(default=None, exclude=True)
# builtin variables
builtin_class_name: str = ""
def __init_with_instruction(self, instruction: str):
"""Initialize action with instruction"""
self.node = ActionNode(key=self.name, expected_type=str, instruction=instruction, example="", schema="raw")
@ -46,17 +42,9 @@ class Action(BaseModel):
def __init__(self, **data: Any):
super().__init__(**data)
# deserialize child classes dynamically for inherited `action`
object.__setattr__(self, "builtin_class_name", self.__class__.__name__)
self.model_fields["builtin_class_name"].default = self.__class__.__name__
if "instruction" in data:
self.__init_with_instruction(data["instruction"])
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
action_subclass_registry[cls.__name__] = cls
def set_prefix(self, prefix):
"""Set prefix for later usage"""
self.prefix = prefix

View file

@ -13,13 +13,13 @@
"""
import asyncio
from pathlib import Path
from typing import Iterable, Set, Union
from typing import Iterable, Set
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.roles.role import Role, role_subclass_registry
from metagpt.roles.role import Role
from metagpt.schema import Message
from metagpt.utils.common import is_subscribed, read_json_file, write_json_file
@ -32,28 +32,10 @@ class Environment(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
desc: str = Field(default="") # 环境描述
roles: dict[str, Role] = Field(default_factory=dict, validate_default=True)
roles: dict[str, SerializeAsAny[Role]] = Field(default_factory=dict, validate_default=True)
members: dict[Role, Set] = Field(default_factory=dict, exclude=True)
history: str = "" # For debug
@field_validator("roles", mode="before")
@classmethod
def check_roles(cls, roles: dict[str, Union[Role, dict]]) -> dict[str, Role]:
new_roles = dict()
for role_key, role in roles.items():
if isinstance(role, dict):
item_class_name = role.get("builtin_class_name", None)
if item_class_name:
for name, subclass in role_subclass_registry.items():
registery_class_name = subclass.model_fields["builtin_class_name"].default
if item_class_name == registery_class_name:
new_role = subclass(**role)
break
new_roles[role_key] = new_role
else:
new_roles[role_key] = role
return new_roles
@model_validator(mode="after")
def init_roles(self):
self.add_roles(self.roles.values())

View file

@ -8,9 +8,9 @@
"""
from collections import defaultdict
from pathlib import Path
from typing import Iterable, Set
from typing import DefaultDict, Iterable, Set
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SerializeAsAny
from metagpt.const import IGNORED_MESSAGE_ID
from metagpt.schema import Message
@ -25,19 +25,10 @@ from metagpt.utils.common import (
class Memory(BaseModel):
"""The most basic memory: super-memory"""
storage: list[Message] = []
index: dict[str, list[Message]] = Field(default_factory=defaultdict(list))
storage: list[SerializeAsAny[Message]] = []
index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list))
ignore_id: bool = False
def __init__(self, **kwargs):
index = kwargs.get("index", {})
new_index = defaultdict(list)
for action_str, value in index.items():
new_index[action_str] = [Message(**item_dict) for item_dict in value]
kwargs["index"] = new_index
super(Memory, self).__init__(**kwargs)
self.index = new_index
def serialize(self, stg_path: Path):
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
memory_path = stg_path.joinpath("memory.json")

View file

@ -24,12 +24,11 @@ from __future__ import annotations
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, Optional, Set, Type, Union
from typing import Any, Iterable, Optional, Set, Type
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action import action_subclass_registry
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.const import SERDESER_PATH
@ -37,7 +36,7 @@ from metagpt.llm import LLM, HumanProvider
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.schema import Message, MessageQueue
from metagpt.schema import Message, MessageQueue, SerDeserMixin
from metagpt.utils.common import (
any_to_name,
any_to_str,
@ -127,10 +126,7 @@ class RoleContext(BaseModel):
return self.memory.get()
role_subclass_registry = {}
class Role(BaseModel):
class Role(SerDeserMixin, is_polymorphic_base=True):
"""Role/Agent"""
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
@ -147,34 +143,16 @@ class Role(BaseModel):
) # Each role has its own LLM, use different system message
role_id: str = ""
states: list[str] = []
actions: list[Action] = Field(default=[], validate_default=True)
actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)
rc: RoleContext = Field(default_factory=RoleContext)
subscription: set[str] = set()
# builtin variables
recovered: bool = False # to tag if a recovered role
latest_observed_msg: Optional[Message] = None # record the latest observed message when interrupted
builtin_class_name: str = ""
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
@field_validator("actions", mode="before")
@classmethod
def check_actions(cls, actions: list[Union[dict, Action]]) -> list[Action]:
new_actions = []
for action in actions:
new_action = action
if isinstance(action, dict):
item_class_name = action.get("builtin_class_name", None)
if item_class_name:
for name, subclass in action_subclass_registry.items():
registery_class_name = subclass.model_fields["builtin_class_name"].default
if item_class_name == registery_class_name:
new_action = subclass(**action)
break
new_actions.append(new_action)
return new_actions
@model_validator(mode="after")
def check_subscription(self) -> set:
if not self.subscription:
@ -191,20 +169,11 @@ class Role(BaseModel):
super().__init__(**data)
self.llm.system_prompt = self._get_prefix()
# deserialize child classes dynamically for inherited `role`
object.__setattr__(self, "builtin_class_name", self.__class__.__name__)
self.model_fields["builtin_class_name"].default = self.__class__.__name__
self._watch(data.get("watch") or [UserRequirement])
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
role_subclass_registry[cls.__name__] = cls
def _reset(self):
object.__setattr__(self, "states", [])
object.__setattr__(self, "actions", [])
self.states = []
self.actions = []
@property
def _setting(self):

View file

@ -23,7 +23,7 @@ from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from pydantic import (
BaseModel,
@ -33,6 +33,7 @@ from pydantic import (
field_serializer,
field_validator,
)
from pydantic_core import core_schema
from metagpt.config import CONFIG
from metagpt.const import (
@ -53,6 +54,64 @@ from metagpt.utils.serialize import (
)
class SerDeserMixin(BaseModel):
    """SerDeserMixin: serialize/deserialize subclasses through a base-class annotation.

    On serialization, a `__module_class_name` entry ("<module>.<qualname>") is added to the
    dumped dict of classes that have no direct subclasses. On validation of a dict, that
    entry is used to look up the class in `__subclasses_map__` and instantiate it.

    A class opts in to the lookup by passing `is_polymorphic_base=True` in its class
    statement, e.g. `class Action(SerDeserMixin, is_polymorphic_base=True)`.
    """

    # Per-class flag, written by `__init_subclass__` for every subclass.
    __is_polymorphic_base = False
    # Single registry dict shared by all subclasses: "<module>.<qualname>" -> class.
    __subclasses_map__ = {}

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: type["SerDeserMixin"], handler: Callable[[Any], core_schema.CoreSchema]
    ) -> core_schema.CoreSchema:
        """Wrap the generated schema with the custom before-validator and wrap-serializer.

        The inner schema's ref gets a ":mixin" suffix and the wrapping schema takes over the
        original ref.
        """
        schema = handler(source)
        og_schema_ref = schema["ref"]
        schema["ref"] += ":mixin"

        return core_schema.no_info_before_validator_function(
            cls.__deserialize_with_real_type__,
            schema=schema,
            ref=og_schema_ref,
            serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__),
        )

    @classmethod
    def __serialize_add_class_type__(
        cls,
        value,
        handler: core_schema.SerializerFunctionWrapHandler,
    ) -> Any:
        """Serialize `value` with `handler`, tagging the result when `cls` has no subclasses."""
        ret = handler(value)
        if not cls.__subclasses__():
            # only classes without direct subclasses add `__module_class_name`
            ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}"
        return ret

    @classmethod
    def __deserialize_with_real_type__(cls, value: Any):
        """Turn a tagged dict into an instance of the class named by `__module_class_name`.

        Returns `value` unchanged when it is not a dict, when `cls` is not flagged as a
        polymorphic base, or when `cls` has subclasses and the dict carries no tag.

        Raises:
            ValueError: the tag is required but missing from `value`.
            TypeError: the tag names a class that is not in `__subclasses_map__`.
        """
        if not isinstance(value, dict):
            return value

        if not cls.__is_polymorphic_base or (cls.__subclasses__() and "__module_class_name" not in value):
            # the right-hand condition allows initializing the base class itself, like Action()
            return value

        module_class_name = value.get("__module_class_name")
        if module_class_name is None:
            raise ValueError("Missing field: __module_class_name")

        class_type = cls.__subclasses_map__.get(module_class_name)
        if class_type is None:
            # f-string so the offending class name actually appears in the message
            raise TypeError(f"Trying to instantiate {module_class_name} which not defined yet.")

        return class_type(**value)

    def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
        """Record the polymorphic flag on `cls` and register it under "<module>.<qualname>"."""
        cls.__is_polymorphic_base = is_polymorphic_base
        cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
        super().__init_subclass__(**kwargs)
class SimpleMessage(BaseModel):
content: str
role: str

View file

@ -13,6 +13,11 @@ def test_action_serialize():
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
assert "llm" not in ser_action_dict # not export
assert "__module_class_name" not in ser_action_dict
action = Action(name="test")
ser_action_dict = action.model_dump()
assert "test" in ser_action_dict["name"]
@pytest.mark.asyncio

View file

@ -35,6 +35,9 @@ def test_memory_serdeser():
assert new_memory.storage[-1].cause_by == any_to_str(WriteDesign)
assert new_msg2.role == "Boss"
memory = Memory(storage=[msg1, msg2], index={msg1.cause_by: [msg1], msg2.cause_by: [msg2]})
assert memory.count() == 2
def test_memory_serdeser_save():
msg1 = Message(role="User", content="write a 2048 game", cause_by=UserRequirement)

View file

@ -0,0 +1,58 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of polymorphic conditions
from pydantic import BaseModel, ConfigDict, SerializeAsAny
from metagpt.actions import Action
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
ActionOKV2,
ActionPass,
)
class ActionSubClasses(BaseModel):
    # Each element is annotated with `SerializeAsAny`, in contrast to a plain `list[Action]`.
    actions: list[SerializeAsAny[Action]] = []
class ActionSubClassesNoSAA(BaseModel):
    """without SerializeAsAny"""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Plain `Action` annotation: no `SerializeAsAny` wrapper around the element type.
    actions: list[Action] = []
def test_serialize_as_any():
    """Check that a subclass-only field survives `model_dump` of the container."""
    # ActionOKV2 declares an additional field named `extra_field`
    holder = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
    dumped_actions = holder.model_dump()["actions"]
    dumped_extra = dumped_actions[0]["extra_field"]
    assert dumped_extra == ActionOKV2().extra_field
def test_no_serialize_as_any():
    """Check that `extra_field` is absent when the container field is a plain `list[Action]`."""
    # ActionOKV2 declares an additional field named `extra_field`
    holder = ActionSubClassesNoSAA(actions=[ActionOKV2(), ActionPass()])
    first_dumped = holder.model_dump()["actions"][0]
    # without `SerializeAsAny`, the element is serialized as an Action
    assert "extra_field" not in first_dumped
def test_polymorphic():
    """Round-trip a container of Action subclasses and check the concrete types come back."""
    okv2_kwargs = {
        "name": "ActionOKV2",
        "context": "",
        "prefix": "",
        "desc": "",
        "extra_field": "ActionOKV2 Extra Info",
    }
    _ = ActionOKV2(**okv2_kwargs)

    dumped = ActionSubClasses(actions=[ActionOKV2(), ActionPass()]).model_dump()
    assert "__module_class_name" in dumped["actions"][0]

    expected_types = (ActionOKV2, ActionPass)
    # Rebuild twice: first through the constructor, then through `model_validate`.
    rebuilders = (lambda data: ActionSubClasses(**data), ActionSubClasses.model_validate)
    for rebuild in rebuilders:
        restored = rebuild(dumped)
        for position, expected_type in enumerate(expected_types):
            assert isinstance(restored.actions[position], expected_type)

View file

@ -6,6 +6,7 @@
import shutil
import pytest
from pydantic import BaseModel, SerializeAsAny
from metagpt.actions import WriteCode
from metagpt.actions.add_requirement import UserRequirement
@ -37,6 +38,20 @@ def test_roles():
assert len(role_d.actions) == 1
def test_role_subclasses():
    """Dump a container holding RoleA and RoleB, rebuild it, and check both concrete types."""

    class RoleSubClasses(BaseModel):
        roles: list[SerializeAsAny[Role]] = []

    dumped = RoleSubClasses(roles=[RoleA(), RoleB()]).model_dump()
    restored = RoleSubClasses(**dumped)
    for position, expected_type in enumerate((RoleA, RoleB)):
        assert isinstance(restored.roles[position], expected_type)
def test_role_serialize():
role = Role()
ser_role_dict = role.model_dump()

View file

@ -7,8 +7,8 @@ from metagpt.actions.write_code import WriteCode
from metagpt.schema import Document, Documents, Message
from metagpt.utils.common import any_to_str
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
MockICMessage,
MockMessage,
TestICMessage,
)
@ -28,10 +28,10 @@ def test_message_serdeser():
assert new_message.instruct_content != ic_obj(**out_data) # TODO find why `!=`
assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump()
message = Message(content="test_ic", instruct_content=TestICMessage())
message = Message(content="test_ic", instruct_content=MockICMessage())
ser_data = message.model_dump()
new_message = Message(**ser_data)
assert new_message.instruct_content != TestICMessage() # TODO
assert new_message.instruct_content != MockICMessage() # TODO
message = Message(content="test_documents", instruct_content=Documents(docs={"doc1": Document(content="test doc")}))
ser_data = message.model_dump()

View file

@ -16,7 +16,7 @@ from metagpt.roles.role import Role, RoleReactMode
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
class TestICMessage(BaseModel):
class MockICMessage(BaseModel):
content: str = "test_ic"
@ -28,7 +28,7 @@ class MockMessage(BaseModel):
class ActionPass(Action):
name: str = Field(default="ActionPass")
name: str = "ActionPass"
async def run(self, messages: list["Message"]) -> ActionOutput:
await asyncio.sleep(5) # sleep to make other roles can watch the executed Message
@ -40,7 +40,7 @@ class ActionPass(Action):
class ActionOK(Action):
name: str = Field(default="ActionOK")
name: str = "ActionOK"
async def run(self, messages: list["Message"]) -> str:
await asyncio.sleep(5)
@ -48,12 +48,17 @@ class ActionOK(Action):
class ActionRaise(Action):
name: str = Field(default="ActionRaise")
name: str = "ActionRaise"
async def run(self, messages: list["Message"]) -> str:
raise RuntimeError("parse error in ActionRaise")
class ActionOKV2(Action):
    # Overrides the default `name` and declares one additional string field.
    name: str = "ActionOKV2"
    extra_field: str = "ActionOKV2 Extra Info"
class RoleA(Role):
name: str = Field(default="RoleA")
profile: str = Field(default="Role A")