mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
add SerDeserMixin for child-classes
This commit is contained in:
parent
2dbaee0ff2
commit
d0edc555b0
11 changed files with 171 additions and 96 deletions
|
|
@ -10,7 +10,7 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.llm import LLM
|
||||
|
|
@ -19,13 +19,12 @@ from metagpt.schema import (
|
|||
CodeSummarizeContext,
|
||||
CodingContext,
|
||||
RunCodeContext,
|
||||
SerDeserMixin,
|
||||
TestingContext,
|
||||
)
|
||||
|
||||
action_subclass_registry = {}
|
||||
|
||||
|
||||
class Action(BaseModel):
|
||||
class Action(SerDeserMixin, is_polymorphic_base=True):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
|
||||
|
||||
name: str = ""
|
||||
|
|
@ -35,9 +34,6 @@ class Action(BaseModel):
|
|||
desc: str = "" # for skill manager
|
||||
node: ActionNode = Field(default=None, exclude=True)
|
||||
|
||||
# builtin variables
|
||||
builtin_class_name: str = ""
|
||||
|
||||
def __init_with_instruction(self, instruction: str):
|
||||
"""Initialize action with instruction"""
|
||||
self.node = ActionNode(key=self.name, expected_type=str, instruction=instruction, example="", schema="raw")
|
||||
|
|
@ -46,17 +42,9 @@ class Action(BaseModel):
|
|||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
|
||||
# deserialize child classes dynamically for inherited `action`
|
||||
object.__setattr__(self, "builtin_class_name", self.__class__.__name__)
|
||||
self.model_fields["builtin_class_name"].default = self.__class__.__name__
|
||||
|
||||
if "instruction" in data:
|
||||
self.__init_with_instruction(data["instruction"])
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
action_subclass_registry[cls.__name__] = cls
|
||||
|
||||
def set_prefix(self, prefix):
|
||||
"""Set prefix for later usage"""
|
||||
self.prefix = prefix
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@
|
|||
"""
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Set, Union
|
||||
from typing import Iterable, Set
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.role import Role, role_subclass_registry
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import is_subscribed, read_json_file, write_json_file
|
||||
|
||||
|
|
@ -32,28 +32,10 @@ class Environment(BaseModel):
|
|||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
desc: str = Field(default="") # 环境描述
|
||||
roles: dict[str, Role] = Field(default_factory=dict, validate_default=True)
|
||||
roles: dict[str, SerializeAsAny[Role]] = Field(default_factory=dict, validate_default=True)
|
||||
members: dict[Role, Set] = Field(default_factory=dict, exclude=True)
|
||||
history: str = "" # For debug
|
||||
|
||||
@field_validator("roles", mode="before")
|
||||
@classmethod
|
||||
def check_roles(cls, roles: dict[str, Union[Role, dict]]) -> dict[str, Role]:
|
||||
new_roles = dict()
|
||||
for role_key, role in roles.items():
|
||||
if isinstance(role, dict):
|
||||
item_class_name = role.get("builtin_class_name", None)
|
||||
if item_class_name:
|
||||
for name, subclass in role_subclass_registry.items():
|
||||
registery_class_name = subclass.model_fields["builtin_class_name"].default
|
||||
if item_class_name == registery_class_name:
|
||||
new_role = subclass(**role)
|
||||
break
|
||||
new_roles[role_key] = new_role
|
||||
else:
|
||||
new_roles[role_key] = role
|
||||
return new_roles
|
||||
|
||||
@model_validator(mode="after")
|
||||
def init_roles(self):
|
||||
self.add_roles(self.roles.values())
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@
|
|||
"""
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Set
|
||||
from typing import DefaultDict, Iterable, Set
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SerializeAsAny
|
||||
|
||||
from metagpt.const import IGNORED_MESSAGE_ID
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -25,19 +25,10 @@ from metagpt.utils.common import (
|
|||
class Memory(BaseModel):
|
||||
"""The most basic memory: super-memory"""
|
||||
|
||||
storage: list[Message] = []
|
||||
index: dict[str, list[Message]] = Field(default_factory=defaultdict(list))
|
||||
storage: list[SerializeAsAny[Message]] = []
|
||||
index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list))
|
||||
ignore_id: bool = False
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
index = kwargs.get("index", {})
|
||||
new_index = defaultdict(list)
|
||||
for action_str, value in index.items():
|
||||
new_index[action_str] = [Message(**item_dict) for item_dict in value]
|
||||
kwargs["index"] = new_index
|
||||
super(Memory, self).__init__(**kwargs)
|
||||
self.index = new_index
|
||||
|
||||
def serialize(self, stg_path: Path):
|
||||
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
|
||||
memory_path = stg_path.joinpath("memory.json")
|
||||
|
|
|
|||
|
|
@ -24,12 +24,11 @@ from __future__ import annotations
|
|||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Optional, Set, Type, Union
|
||||
from typing import Any, Iterable, Optional, Set, Type
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.action import action_subclass_registry
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.const import SERDESER_PATH
|
||||
|
|
@ -37,7 +36,7 @@ from metagpt.llm import LLM, HumanProvider
|
|||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.schema import Message, MessageQueue
|
||||
from metagpt.schema import Message, MessageQueue, SerDeserMixin
|
||||
from metagpt.utils.common import (
|
||||
any_to_name,
|
||||
any_to_str,
|
||||
|
|
@ -127,10 +126,7 @@ class RoleContext(BaseModel):
|
|||
return self.memory.get()
|
||||
|
||||
|
||||
role_subclass_registry = {}
|
||||
|
||||
|
||||
class Role(BaseModel):
|
||||
class Role(SerDeserMixin, is_polymorphic_base=True):
|
||||
"""Role/Agent"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
|
||||
|
|
@ -147,34 +143,16 @@ class Role(BaseModel):
|
|||
) # Each role has its own LLM, use different system message
|
||||
role_id: str = ""
|
||||
states: list[str] = []
|
||||
actions: list[Action] = Field(default=[], validate_default=True)
|
||||
actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)
|
||||
rc: RoleContext = Field(default_factory=RoleContext)
|
||||
subscription: set[str] = set()
|
||||
|
||||
# builtin variables
|
||||
recovered: bool = False # to tag if a recovered role
|
||||
latest_observed_msg: Optional[Message] = None # record the latest observed message when interrupted
|
||||
builtin_class_name: str = ""
|
||||
|
||||
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
|
||||
|
||||
@field_validator("actions", mode="before")
|
||||
@classmethod
|
||||
def check_actions(cls, actions: list[Union[dict, Action]]) -> list[Action]:
|
||||
new_actions = []
|
||||
for action in actions:
|
||||
new_action = action
|
||||
if isinstance(action, dict):
|
||||
item_class_name = action.get("builtin_class_name", None)
|
||||
if item_class_name:
|
||||
for name, subclass in action_subclass_registry.items():
|
||||
registery_class_name = subclass.model_fields["builtin_class_name"].default
|
||||
if item_class_name == registery_class_name:
|
||||
new_action = subclass(**action)
|
||||
break
|
||||
new_actions.append(new_action)
|
||||
return new_actions
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_subscription(self) -> set:
|
||||
if not self.subscription:
|
||||
|
|
@ -191,20 +169,11 @@ class Role(BaseModel):
|
|||
super().__init__(**data)
|
||||
|
||||
self.llm.system_prompt = self._get_prefix()
|
||||
|
||||
# deserialize child classes dynamically for inherited `role`
|
||||
object.__setattr__(self, "builtin_class_name", self.__class__.__name__)
|
||||
self.model_fields["builtin_class_name"].default = self.__class__.__name__
|
||||
|
||||
self._watch(data.get("watch") or [UserRequirement])
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
role_subclass_registry[cls.__name__] = cls
|
||||
|
||||
def _reset(self):
|
||||
object.__setattr__(self, "states", [])
|
||||
object.__setattr__(self, "actions", [])
|
||||
self.states = []
|
||||
self.actions = []
|
||||
|
||||
@property
|
||||
def _setting(self):
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from abc import ABC
|
|||
from asyncio import Queue, QueueEmpty, wait_for
|
||||
from json import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
|
|
@ -33,6 +33,7 @@ from pydantic import (
|
|||
field_serializer,
|
||||
field_validator,
|
||||
)
|
||||
from pydantic_core import core_schema
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
|
|
@ -53,6 +54,64 @@ from metagpt.utils.serialize import (
|
|||
)
|
||||
|
||||
|
||||
class SerDeserMixin(BaseModel):
|
||||
"""SereDeserMixin for subclass' ser&deser"""
|
||||
|
||||
__is_polymorphic_base = False
|
||||
__subclasses_map__ = {}
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source: type["SerDeserMixin"], handler: Callable[[Any], core_schema.CoreSchema]
|
||||
) -> core_schema.CoreSchema:
|
||||
schema = handler(source)
|
||||
og_schema_ref = schema["ref"]
|
||||
schema["ref"] += ":mixin"
|
||||
|
||||
return core_schema.no_info_before_validator_function(
|
||||
cls.__deserialize_with_real_type__,
|
||||
schema=schema,
|
||||
ref=og_schema_ref,
|
||||
serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __serialize_add_class_type__(
|
||||
cls,
|
||||
value,
|
||||
handler: core_schema.SerializerFunctionWrapHandler,
|
||||
) -> Any:
|
||||
ret = handler(value)
|
||||
if not len(cls.__subclasses__()):
|
||||
# only subclass add `__module_class_name`
|
||||
ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}"
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def __deserialize_with_real_type__(cls, value: Any):
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
if not cls.__is_polymorphic_base or (len(cls.__subclasses__()) and "__module_class_name" not in value):
|
||||
# add right condition to init BaseClass like Action()
|
||||
return value
|
||||
module_class_name = value.get("__module_class_name", None)
|
||||
if module_class_name is None:
|
||||
raise ValueError("Missing field: __module_class_name")
|
||||
|
||||
class_type = cls.__subclasses_map__.get(module_class_name, None)
|
||||
|
||||
if class_type is None:
|
||||
raise TypeError("Trying to instantiate {module_class_name} which not defined yet.")
|
||||
|
||||
return class_type(**value)
|
||||
|
||||
def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
|
||||
cls.__is_polymorphic_base = is_polymorphic_base
|
||||
cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
|
||||
class SimpleMessage(BaseModel):
|
||||
content: str
|
||||
role: str
|
||||
|
|
|
|||
|
|
@ -13,6 +13,11 @@ def test_action_serialize():
|
|||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
assert "__module_class_name" not in ser_action_dict
|
||||
|
||||
action = Action(name="test")
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "test" in ser_action_dict["name"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -35,6 +35,9 @@ def test_memory_serdeser():
|
|||
assert new_memory.storage[-1].cause_by == any_to_str(WriteDesign)
|
||||
assert new_msg2.role == "Boss"
|
||||
|
||||
memory = Memory(storage=[msg1, msg2], index={msg1.cause_by: [msg1], msg2.cause_by: [msg2]})
|
||||
assert memory.count() == 2
|
||||
|
||||
|
||||
def test_memory_serdeser_save():
|
||||
msg1 = Message(role="User", content="write a 2048 game", cause_by=UserRequirement)
|
||||
|
|
|
|||
58
tests/metagpt/serialize_deserialize/test_polymorphic.py
Normal file
58
tests/metagpt/serialize_deserialize/test_polymorphic.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of polymorphic conditions
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, SerializeAsAny
|
||||
|
||||
from metagpt.actions import Action
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
ActionOKV2,
|
||||
ActionPass,
|
||||
)
|
||||
|
||||
|
||||
class ActionSubClasses(BaseModel):
|
||||
actions: list[SerializeAsAny[Action]] = []
|
||||
|
||||
|
||||
class ActionSubClassesNoSAA(BaseModel):
|
||||
"""without SerializeAsAny"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
actions: list[Action] = []
|
||||
|
||||
|
||||
def test_serialize_as_any():
|
||||
"""test subclasses of action with different fields in ser&deser"""
|
||||
# ActionOKV2 with a extra field `extra_field`
|
||||
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
|
||||
action_subcls_dict = action_subcls.model_dump()
|
||||
assert action_subcls_dict["actions"][0]["extra_field"] == ActionOKV2().extra_field
|
||||
|
||||
|
||||
def test_no_serialize_as_any():
|
||||
# ActionOKV2 with a extra field `extra_field`
|
||||
action_subcls = ActionSubClassesNoSAA(actions=[ActionOKV2(), ActionPass()])
|
||||
action_subcls_dict = action_subcls.model_dump()
|
||||
# without `SerializeAsAny`, it will serialize as Action
|
||||
assert "extra_field" not in action_subcls_dict["actions"][0]
|
||||
|
||||
|
||||
def test_polymorphic():
|
||||
_ = ActionOKV2(
|
||||
**{"name": "ActionOKV2", "context": "", "prefix": "", "desc": "", "extra_field": "ActionOKV2 Extra Info"}
|
||||
)
|
||||
|
||||
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
|
||||
action_subcls_dict = action_subcls.model_dump()
|
||||
|
||||
assert "__module_class_name" in action_subcls_dict["actions"][0]
|
||||
|
||||
new_action_subcls = ActionSubClasses(**action_subcls_dict)
|
||||
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
|
||||
assert isinstance(new_action_subcls.actions[1], ActionPass)
|
||||
|
||||
new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict)
|
||||
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
|
||||
assert isinstance(new_action_subcls.actions[1], ActionPass)
|
||||
|
|
@ -6,6 +6,7 @@
|
|||
import shutil
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, SerializeAsAny
|
||||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
|
|
@ -37,6 +38,20 @@ def test_roles():
|
|||
assert len(role_d.actions) == 1
|
||||
|
||||
|
||||
def test_role_subclasses():
|
||||
"""test subclasses of role with same fields in ser&deser"""
|
||||
|
||||
class RoleSubClasses(BaseModel):
|
||||
roles: list[SerializeAsAny[Role]] = []
|
||||
|
||||
role_subcls = RoleSubClasses(roles=[RoleA(), RoleB()])
|
||||
role_subcls_dict = role_subcls.model_dump()
|
||||
|
||||
new_role_subcls = RoleSubClasses(**role_subcls_dict)
|
||||
assert isinstance(new_role_subcls.roles[0], RoleA)
|
||||
assert isinstance(new_role_subcls.roles[1], RoleB)
|
||||
|
||||
|
||||
def test_role_serialize():
|
||||
role = Role()
|
||||
ser_role_dict = role.model_dump()
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from metagpt.actions.write_code import WriteCode
|
|||
from metagpt.schema import Document, Documents, Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
MockICMessage,
|
||||
MockMessage,
|
||||
TestICMessage,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -28,10 +28,10 @@ def test_message_serdeser():
|
|||
assert new_message.instruct_content != ic_obj(**out_data) # TODO find why `!=`
|
||||
assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump()
|
||||
|
||||
message = Message(content="test_ic", instruct_content=TestICMessage())
|
||||
message = Message(content="test_ic", instruct_content=MockICMessage())
|
||||
ser_data = message.model_dump()
|
||||
new_message = Message(**ser_data)
|
||||
assert new_message.instruct_content != TestICMessage() # TODO
|
||||
assert new_message.instruct_content != MockICMessage() # TODO
|
||||
|
||||
message = Message(content="test_documents", instruct_content=Documents(docs={"doc1": Document(content="test doc")}))
|
||||
ser_data = message.model_dump()
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from metagpt.roles.role import Role, RoleReactMode
|
|||
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
|
||||
|
||||
|
||||
class TestICMessage(BaseModel):
|
||||
class MockICMessage(BaseModel):
|
||||
content: str = "test_ic"
|
||||
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ class MockMessage(BaseModel):
|
|||
|
||||
|
||||
class ActionPass(Action):
|
||||
name: str = Field(default="ActionPass")
|
||||
name: str = "ActionPass"
|
||||
|
||||
async def run(self, messages: list["Message"]) -> ActionOutput:
|
||||
await asyncio.sleep(5) # sleep to make other roles can watch the executed Message
|
||||
|
|
@ -40,7 +40,7 @@ class ActionPass(Action):
|
|||
|
||||
|
||||
class ActionOK(Action):
|
||||
name: str = Field(default="ActionOK")
|
||||
name: str = "ActionOK"
|
||||
|
||||
async def run(self, messages: list["Message"]) -> str:
|
||||
await asyncio.sleep(5)
|
||||
|
|
@ -48,12 +48,17 @@ class ActionOK(Action):
|
|||
|
||||
|
||||
class ActionRaise(Action):
|
||||
name: str = Field(default="ActionRaise")
|
||||
name: str = "ActionRaise"
|
||||
|
||||
async def run(self, messages: list["Message"]) -> str:
|
||||
raise RuntimeError("parse error in ActionRaise")
|
||||
|
||||
|
||||
class ActionOKV2(Action):
|
||||
name: str = "ActionOKV2"
|
||||
extra_field: str = "ActionOKV2 Extra Info"
|
||||
|
||||
|
||||
class RoleA(Role):
|
||||
name: str = Field(default="RoleA")
|
||||
profile: str = Field(default="Role A")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue