add SerDeserMixin for child-classes

This commit is contained in:
better629 2023-12-28 16:07:39 +08:00
parent 2dbaee0ff2
commit d0edc555b0
11 changed files with 171 additions and 96 deletions

View file

@ -13,6 +13,11 @@ def test_action_serialize():
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
assert "llm" not in ser_action_dict # not export
assert "__module_class_name" not in ser_action_dict
action = Action(name="test")
ser_action_dict = action.model_dump()
assert "test" in ser_action_dict["name"]
@pytest.mark.asyncio

View file

@ -35,6 +35,9 @@ def test_memory_serdeser():
assert new_memory.storage[-1].cause_by == any_to_str(WriteDesign)
assert new_msg2.role == "Boss"
memory = Memory(storage=[msg1, msg2], index={msg1.cause_by: [msg1], msg2.cause_by: [msg2]})
assert memory.count() == 2
def test_memory_serdeser_save():
msg1 = Message(role="User", content="write a 2048 game", cause_by=UserRequirement)

View file

@ -0,0 +1,58 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of polymorphic conditions
from pydantic import BaseModel, ConfigDict, SerializeAsAny
from metagpt.actions import Action
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
ActionOKV2,
ActionPass,
)
class ActionSubClasses(BaseModel):
actions: list[SerializeAsAny[Action]] = []
class ActionSubClassesNoSAA(BaseModel):
"""without SerializeAsAny"""
model_config = ConfigDict(arbitrary_types_allowed=True)
actions: list[Action] = []
def test_serialize_as_any():
"""test subclasses of action with different fields in ser&deser"""
# ActionOKV2 with a extra field `extra_field`
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
action_subcls_dict = action_subcls.model_dump()
assert action_subcls_dict["actions"][0]["extra_field"] == ActionOKV2().extra_field
def test_no_serialize_as_any():
# ActionOKV2 with a extra field `extra_field`
action_subcls = ActionSubClassesNoSAA(actions=[ActionOKV2(), ActionPass()])
action_subcls_dict = action_subcls.model_dump()
# without `SerializeAsAny`, it will serialize as Action
assert "extra_field" not in action_subcls_dict["actions"][0]
def test_polymorphic():
_ = ActionOKV2(
**{"name": "ActionOKV2", "context": "", "prefix": "", "desc": "", "extra_field": "ActionOKV2 Extra Info"}
)
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
action_subcls_dict = action_subcls.model_dump()
assert "__module_class_name" in action_subcls_dict["actions"][0]
new_action_subcls = ActionSubClasses(**action_subcls_dict)
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
assert isinstance(new_action_subcls.actions[1], ActionPass)
new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict)
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
assert isinstance(new_action_subcls.actions[1], ActionPass)

View file

@ -6,6 +6,7 @@
import shutil
import pytest
from pydantic import BaseModel, SerializeAsAny
from metagpt.actions import WriteCode
from metagpt.actions.add_requirement import UserRequirement
@ -37,6 +38,20 @@ def test_roles():
assert len(role_d.actions) == 1
def test_role_subclasses():
"""test subclasses of role with same fields in ser&deser"""
class RoleSubClasses(BaseModel):
roles: list[SerializeAsAny[Role]] = []
role_subcls = RoleSubClasses(roles=[RoleA(), RoleB()])
role_subcls_dict = role_subcls.model_dump()
new_role_subcls = RoleSubClasses(**role_subcls_dict)
assert isinstance(new_role_subcls.roles[0], RoleA)
assert isinstance(new_role_subcls.roles[1], RoleB)
def test_role_serialize():
role = Role()
ser_role_dict = role.model_dump()

View file

@ -7,8 +7,8 @@ from metagpt.actions.write_code import WriteCode
from metagpt.schema import Document, Documents, Message
from metagpt.utils.common import any_to_str
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
MockICMessage,
MockMessage,
TestICMessage,
)
@ -28,10 +28,10 @@ def test_message_serdeser():
assert new_message.instruct_content != ic_obj(**out_data) # TODO find why `!=`
assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump()
message = Message(content="test_ic", instruct_content=TestICMessage())
message = Message(content="test_ic", instruct_content=MockICMessage())
ser_data = message.model_dump()
new_message = Message(**ser_data)
assert new_message.instruct_content != TestICMessage() # TODO
assert new_message.instruct_content != MockICMessage() # TODO
message = Message(content="test_documents", instruct_content=Documents(docs={"doc1": Document(content="test doc")}))
ser_data = message.model_dump()

View file

@ -16,7 +16,7 @@ from metagpt.roles.role import Role, RoleReactMode
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
class TestICMessage(BaseModel):
class MockICMessage(BaseModel):
content: str = "test_ic"
@ -28,7 +28,7 @@ class MockMessage(BaseModel):
class ActionPass(Action):
name: str = Field(default="ActionPass")
name: str = "ActionPass"
async def run(self, messages: list["Message"]) -> ActionOutput:
await asyncio.sleep(5) # sleep to make other roles can watch the executed Message
@ -40,7 +40,7 @@ class ActionPass(Action):
class ActionOK(Action):
name: str = Field(default="ActionOK")
name: str = "ActionOK"
async def run(self, messages: list["Message"]) -> str:
await asyncio.sleep(5)
@ -48,12 +48,17 @@ class ActionOK(Action):
class ActionRaise(Action):
name: str = Field(default="ActionRaise")
name: str = "ActionRaise"
async def run(self, messages: list["Message"]) -> str:
raise RuntimeError("parse error in ActionRaise")
class ActionOKV2(Action):
name: str = "ActionOKV2"
extra_field: str = "ActionOKV2 Extra Info"
class RoleA(Role):
name: str = Field(default="RoleA")
profile: str = Field(default="Role A")