mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-20 15:38:09 +02:00
add SerDeserMixin for child-classes
This commit is contained in:
parent
2dbaee0ff2
commit
d0edc555b0
11 changed files with 171 additions and 96 deletions
|
|
@ -13,6 +13,11 @@ def test_action_serialize():
|
|||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
assert "__module_class_name" not in ser_action_dict
|
||||
|
||||
action = Action(name="test")
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "test" in ser_action_dict["name"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -35,6 +35,9 @@ def test_memory_serdeser():
|
|||
assert new_memory.storage[-1].cause_by == any_to_str(WriteDesign)
|
||||
assert new_msg2.role == "Boss"
|
||||
|
||||
memory = Memory(storage=[msg1, msg2], index={msg1.cause_by: [msg1], msg2.cause_by: [msg2]})
|
||||
assert memory.count() == 2
|
||||
|
||||
|
||||
def test_memory_serdeser_save():
|
||||
msg1 = Message(role="User", content="write a 2048 game", cause_by=UserRequirement)
|
||||
|
|
|
|||
58
tests/metagpt/serialize_deserialize/test_polymorphic.py
Normal file
58
tests/metagpt/serialize_deserialize/test_polymorphic.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of polymorphic conditions
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, SerializeAsAny
|
||||
|
||||
from metagpt.actions import Action
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
ActionOKV2,
|
||||
ActionPass,
|
||||
)
|
||||
|
||||
|
||||
class ActionSubClasses(BaseModel):
|
||||
actions: list[SerializeAsAny[Action]] = []
|
||||
|
||||
|
||||
class ActionSubClassesNoSAA(BaseModel):
|
||||
"""without SerializeAsAny"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
actions: list[Action] = []
|
||||
|
||||
|
||||
def test_serialize_as_any():
|
||||
"""test subclasses of action with different fields in ser&deser"""
|
||||
# ActionOKV2 with a extra field `extra_field`
|
||||
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
|
||||
action_subcls_dict = action_subcls.model_dump()
|
||||
assert action_subcls_dict["actions"][0]["extra_field"] == ActionOKV2().extra_field
|
||||
|
||||
|
||||
def test_no_serialize_as_any():
|
||||
# ActionOKV2 with a extra field `extra_field`
|
||||
action_subcls = ActionSubClassesNoSAA(actions=[ActionOKV2(), ActionPass()])
|
||||
action_subcls_dict = action_subcls.model_dump()
|
||||
# without `SerializeAsAny`, it will serialize as Action
|
||||
assert "extra_field" not in action_subcls_dict["actions"][0]
|
||||
|
||||
|
||||
def test_polymorphic():
|
||||
_ = ActionOKV2(
|
||||
**{"name": "ActionOKV2", "context": "", "prefix": "", "desc": "", "extra_field": "ActionOKV2 Extra Info"}
|
||||
)
|
||||
|
||||
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
|
||||
action_subcls_dict = action_subcls.model_dump()
|
||||
|
||||
assert "__module_class_name" in action_subcls_dict["actions"][0]
|
||||
|
||||
new_action_subcls = ActionSubClasses(**action_subcls_dict)
|
||||
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
|
||||
assert isinstance(new_action_subcls.actions[1], ActionPass)
|
||||
|
||||
new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict)
|
||||
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
|
||||
assert isinstance(new_action_subcls.actions[1], ActionPass)
|
||||
|
|
@ -6,6 +6,7 @@
|
|||
import shutil
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, SerializeAsAny
|
||||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
|
|
@ -37,6 +38,20 @@ def test_roles():
|
|||
assert len(role_d.actions) == 1
|
||||
|
||||
|
||||
def test_role_subclasses():
|
||||
"""test subclasses of role with same fields in ser&deser"""
|
||||
|
||||
class RoleSubClasses(BaseModel):
|
||||
roles: list[SerializeAsAny[Role]] = []
|
||||
|
||||
role_subcls = RoleSubClasses(roles=[RoleA(), RoleB()])
|
||||
role_subcls_dict = role_subcls.model_dump()
|
||||
|
||||
new_role_subcls = RoleSubClasses(**role_subcls_dict)
|
||||
assert isinstance(new_role_subcls.roles[0], RoleA)
|
||||
assert isinstance(new_role_subcls.roles[1], RoleB)
|
||||
|
||||
|
||||
def test_role_serialize():
|
||||
role = Role()
|
||||
ser_role_dict = role.model_dump()
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from metagpt.actions.write_code import WriteCode
|
|||
from metagpt.schema import Document, Documents, Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
MockICMessage,
|
||||
MockMessage,
|
||||
TestICMessage,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -28,10 +28,10 @@ def test_message_serdeser():
|
|||
assert new_message.instruct_content != ic_obj(**out_data) # TODO find why `!=`
|
||||
assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump()
|
||||
|
||||
message = Message(content="test_ic", instruct_content=TestICMessage())
|
||||
message = Message(content="test_ic", instruct_content=MockICMessage())
|
||||
ser_data = message.model_dump()
|
||||
new_message = Message(**ser_data)
|
||||
assert new_message.instruct_content != TestICMessage() # TODO
|
||||
assert new_message.instruct_content != MockICMessage() # TODO
|
||||
|
||||
message = Message(content="test_documents", instruct_content=Documents(docs={"doc1": Document(content="test doc")}))
|
||||
ser_data = message.model_dump()
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from metagpt.roles.role import Role, RoleReactMode
|
|||
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
|
||||
|
||||
|
||||
class TestICMessage(BaseModel):
|
||||
class MockICMessage(BaseModel):
|
||||
content: str = "test_ic"
|
||||
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ class MockMessage(BaseModel):
|
|||
|
||||
|
||||
class ActionPass(Action):
|
||||
name: str = Field(default="ActionPass")
|
||||
name: str = "ActionPass"
|
||||
|
||||
async def run(self, messages: list["Message"]) -> ActionOutput:
|
||||
await asyncio.sleep(5) # sleep to make other roles can watch the executed Message
|
||||
|
|
@ -40,7 +40,7 @@ class ActionPass(Action):
|
|||
|
||||
|
||||
class ActionOK(Action):
|
||||
name: str = Field(default="ActionOK")
|
||||
name: str = "ActionOK"
|
||||
|
||||
async def run(self, messages: list["Message"]) -> str:
|
||||
await asyncio.sleep(5)
|
||||
|
|
@ -48,12 +48,17 @@ class ActionOK(Action):
|
|||
|
||||
|
||||
class ActionRaise(Action):
|
||||
name: str = Field(default="ActionRaise")
|
||||
name: str = "ActionRaise"
|
||||
|
||||
async def run(self, messages: list["Message"]) -> str:
|
||||
raise RuntimeError("parse error in ActionRaise")
|
||||
|
||||
|
||||
class ActionOKV2(Action):
|
||||
name: str = "ActionOKV2"
|
||||
extra_field: str = "ActionOKV2 Extra Info"
|
||||
|
||||
|
||||
class RoleA(Role):
|
||||
name: str = Field(default="RoleA")
|
||||
profile: str = Field(default="Role A")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue