rm expicit serialize&deserialize interface and update unittests

This commit is contained in:
better629 2024-01-08 22:15:56 +08:00
parent d2233beff4
commit 98ee696cf0
27 changed files with 154 additions and 290 deletions

View file

@ -23,7 +23,7 @@ from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from pydantic import (
BaseModel,
@ -32,8 +32,9 @@ from pydantic import (
PrivateAttr,
field_serializer,
field_validator,
model_serializer,
model_validator,
)
from pydantic_core import core_schema
from metagpt.const import (
MESSAGE_ROUTE_CAUSE_BY,
@ -53,7 +54,7 @@ from metagpt.utils.serialize import (
)
class SerializationMixin(BaseModel):
class SerializationMixin(BaseModel, extra="forbid"):
"""
PolyMorphic subclasses Serialization / Deserialization Mixin
- First of all, we need to know that pydantic is not designed for polymorphism.
@ -68,49 +69,44 @@ class SerializationMixin(BaseModel):
__is_polymorphic_base = False
__subclasses_map__ = {}
@classmethod
def __get_pydantic_core_schema__(
cls, source: type["SerializationMixin"], handler: Callable[[Any], core_schema.CoreSchema]
) -> core_schema.CoreSchema:
schema = handler(source)
og_schema_ref = schema["ref"]
schema["ref"] += ":mixin"
return core_schema.no_info_before_validator_function(
cls.__deserialize_with_real_type__,
schema=schema,
ref=og_schema_ref,
serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__),
)
@classmethod
def __serialize_add_class_type__(
cls,
value,
handler: core_schema.SerializerFunctionWrapHandler,
) -> Any:
ret = handler(value)
if not len(cls.__subclasses__()):
# only subclass add `__module_class_name`
ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}"
@model_serializer(mode="wrap")
def __serialize_with_class_type__(self, default_serializer) -> Any:
# default serializer, then append the `__module_class_name` field and return
ret = default_serializer(self)
ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
return ret
@model_validator(mode="wrap")
@classmethod
def __deserialize_with_real_type__(cls, value: Any):
if not isinstance(value, dict):
return value
def __convert_to_real_type__(cls, value: Any, handler):
if isinstance(value, dict) is False:
return handler(value)
if not cls.__is_polymorphic_base or (len(cls.__subclasses__()) and "__module_class_name" not in value):
# add right condition to init BaseClass like Action()
return value
module_class_name = value.get("__module_class_name", None)
if module_class_name is None:
raise ValueError("Missing field: __module_class_name")
# it is a dict so make sure to remove the __module_class_name
# because we don't allow extra keywords but want to ensure
# e.g Cat.model_validate(cat.model_dump()) works
class_full_name = value.pop("__module_class_name", None)
class_type = cls.__subclasses_map__.get(module_class_name, None)
# if it's not the polymorphic base we construct via default handler
if not cls.__is_polymorphic_base:
if class_full_name is None:
return handler(value)
elif str(cls) == f"<class '{class_full_name}'>":
return handler(value)
else:
# f"Trying to instantiate {class_full_name} but this is not the polymorphic base class")
pass
# otherwise we lookup the correct polymorphic type and construct that
# instead
if class_full_name is None:
raise ValueError("Missing __module_class_name field")
class_type = cls.__subclasses_map__.get(class_full_name, None)
if class_type is None:
raise TypeError("Trying to instantiate {module_class_name} which not defined yet.")
# TODO could try dynamic import
raise TypeError("Trying to instantiate {class_full_name}, which has not yet been defined!")
return class_type(**value)