add pydantic v2 support and change role's private fields into public

This commit is contained in:
better629 2023-12-27 14:00:54 +08:00
parent 66925dd791
commit afaa7385c4
67 changed files with 518 additions and 555 deletions

View file

@ -23,9 +23,16 @@ from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Type, TypeVar
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
field_serializer,
field_validator,
)
from metagpt.config import CONFIG
from metagpt.const import (
@ -102,33 +109,64 @@ class Documents(BaseModel):
class Message(BaseModel):
"""list[<role>: <content>]"""
id: str # According to Section 2.2.3.1.1 of RFC 135
id: str = Field(default="", validate_default=True) # According to Section 2.2.3.1.1 of RFC 135
content: str
instruct_content: BaseModel = None
instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True)
role: str = "user" # system / user / assistant
cause_by: str = ""
sent_from: str = ""
send_to: Set = Field(default={MESSAGE_ROUTE_TO_ALL})
cause_by: str = Field(default="", validate_default=True)
sent_from: str = Field(default="", validate_default=True)
send_to: set = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True)
def __init__(self, content: str = "", **kwargs):
ic = kwargs.get("instruct_content", None)
@field_validator("id", mode="before")
@classmethod
def check_id(cls, id: str) -> str:
return id if id else uuid.uuid4().hex
@field_validator("instruct_content", mode="before")
@classmethod
def check_instruct_content(cls, ic: Any) -> BaseModel:
if ic and not isinstance(ic, BaseModel) and "class" in ic:
# compatible with custom-defined ActionOutput
mapping = actionoutput_str_to_mapping(ic["mapping"])
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
ic_new = ic_obj(**ic["value"])
kwargs["instruct_content"] = ic_new
ic = ic_obj(**ic["value"])
return ic
kwargs["id"] = kwargs.get("id", uuid.uuid4().hex)
kwargs["content"] = kwargs.get("content", content)
kwargs["cause_by"] = any_to_str(
kwargs.get("cause_by", import_class("UserRequirement", "metagpt.actions.add_requirement"))
)
kwargs["sent_from"] = any_to_str(kwargs.get("sent_from", ""))
kwargs["send_to"] = any_to_str_set(kwargs.get("send_to", {MESSAGE_ROUTE_TO_ALL}))
super(Message, self).__init__(**kwargs)
@field_validator("cause_by", mode="before")
@classmethod
def check_cause_by(cls, cause_by: Any) -> str:
return any_to_str(cause_by if cause_by else import_class("UserRequirement", "metagpt.actions.add_requirement"))
@field_validator("sent_from", mode="before")
@classmethod
def check_sent_from(cls, sent_from: Any) -> str:
return any_to_str(sent_from if sent_from else "")
@field_validator("send_to", mode="before")
@classmethod
def check_send_to(cls, send_to: Any) -> set:
return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})
@field_serializer("instruct_content", mode="plain")
def ser_instruct_content(self, ic: BaseModel) -> Union[str, None]:
ic_dict = None
if ic:
# compatible with custom-defined ActionOutput
schema = ic.model_json_schema()
# `Documents` contain definitions
if "definitions" not in schema:
# TODO refine with nested BaseModel
mapping = actionoutout_schema_to_mapping(schema)
mapping = actionoutput_mapping_to_str(mapping)
ic_dict = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
return ic_dict
def __init__(self, content: str = "", **data: Any):
data["content"] = data.get("content", content)
super().__init__(**data)
def __setattr__(self, key, val):
"""Override `@property.setter`, convert non-string parameters into string parameters."""
@ -142,22 +180,6 @@ class Message(BaseModel):
new_val = val
super().__setattr__(key, new_val)
def dict(self, *args, **kwargs) -> dict[str, Any]:
"""overwrite the `dict` to dump dynamic pydantic model"""
obj_dict = super(Message, self).model_dump(*args, **kwargs)
ic = self.instruct_content
if ic:
# compatible with custom-defined ActionOutput
schema = ic.model_json_schema()
# `Documents` contain definitions
if "definitions" not in schema:
# TODO refine with nested BaseModel
mapping = actionoutout_schema_to_mapping(schema)
mapping = actionoutput_mapping_to_str(mapping)
obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
return obj_dict
def __str__(self):
# prefix = '-'.join([self.role, str(self.cause_by)])
if self.instruct_content:
@ -173,7 +195,7 @@ class Message(BaseModel):
def dump(self) -> str:
"""Convert the object to json string"""
return self.json(exclude_none=True)
return self.model_dump_json(exclude_none=True)
@staticmethod
@handle_exception(exception_type=JSONDecodeError, default_return=None)