update environment/message to BaseModel, update the ser&deser of roles/actions

This commit is contained in:
better629 2023-11-30 15:18:24 +08:00
parent 9e5c873d77
commit 5e3607f85b
26 changed files with 458 additions and 252 deletions

View file

@ -5,18 +5,17 @@
@Author : alexanderwu
@File : schema.py
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Type, TypedDict
import copy
from typing import Type, TypedDict, Union, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from pydantic.main import ModelMetaclass
from metagpt.logs import logger
# from metagpt.utils.serialize import actionoutout_schema_to_mapping
# from metagpt.actions.action_output import ActionOutput
# from metagpt.actions.action import Action
from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \
actionoutput_str_to_mapping
from metagpt.utils.utils import import_class
class RawMessage(TypedDict):
@ -24,16 +23,72 @@ class RawMessage(TypedDict):
role: str
@dataclass
class Message:
"""list[<role>: <content>]"""
content: str
instruct_content: BaseModel = field(default=None)
role: str = field(default='user') # system / user / assistant
cause_by: Type["Action"] = field(default="")
sent_from: str = field(default="")
send_to: str = field(default="")
restricted_to: str = field(default="")
class Message(BaseModel):
content: str = ""
instruct_content: BaseModel = Field(default=None)
role: str = "user" # system / user / assistant
cause_by: Type["Action"] = Field(default=None)
sent_from: str = ""
send_to: str = ""
restricted_to: str = ""
def __init__(self, **kwargs):
instruct_content = kwargs.get("instruct_content", None)
cause_by = kwargs.get("cause_by", None)
if instruct_content and not isinstance(instruct_content, BaseModel):
ic = instruct_content
mapping = actionoutput_str_to_mapping(ic["mapping"])
actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output")
ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping)
ic_new = ic_obj(**ic["value"])
kwargs["instruct_content"] = ic_new
if cause_by and not isinstance(cause_by, ModelMetaclass):
action_class = import_class("Action", "metagpt.actions.action")
kwargs["cause_by"] = action_class.deser_class(cause_by)
super(Message, self).__init__(**kwargs)
def dict(self,
*,
include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None,
exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False) -> "DictStrAny":
""" overwrite the `dict` to dump dynamic pydantic model"""
obj_dict = super(Message, self).dict(include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none)
ic = self.instruct_content # deal custom-defined action
if ic:
schema = ic.schema()
mapping = actionoutout_schema_to_mapping(schema)
mapping = actionoutput_mapping_to_str(mapping)
obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
cb = self.cause_by
if cb:
obj_dict["cause_by"] = cb.ser_class()
return obj_dict
#
#
# @dataclass
# class Message:
# """list[<role>: <content>]"""
# content: str
# instruct_content: BaseModel = field(default=None)
# role: str = field(default='user') # system / user / assistant
# cause_by: Type["Action"] = field(default="")
# sent_from: str = field(default="")
# send_to: str = field(default="")
# restricted_to: str = field(default="")
def __str__(self):
# prefix = '-'.join([self.role, str(self.cause_by)])
@ -42,45 +97,16 @@ class Message:
def __repr__(self):
return self.__str__()
# def serialize(self):
# message_cp: Message = copy.deepcopy(self)
# ic = message_cp.instruct_content
# if ic:
# # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly
# schema = ic.schema()
# mapping = actionoutout_schema_to_mapping(schema)
#
# message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
# cb = message_cp.cause_by
# if cb:
# message_cp.cause_by = cb.serialize()
#
# return message_cp.dict()
#
# @classmethod
# def deserialize(cls, message_dict: dict):
# instruct_content = message_dict.get("instruct_content")
# if instruct_content:
# ic = instruct_content
# ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
# ic_new = ic_obj(**ic["value"])
# message_dict.instruct_content = ic_new
# cause_by = message_dict.get("cause_by")
# if cause_by:
# message_dict.cause_by = Action.deserialize(cause_by)
#
# return Message(**message_dict)
def dict(self):
return {
"content": self.content,
"instruct_content": self.instruct_content,
"role": self.role,
"cause_by": self.cause_by,
"sent_from": self.sent_from,
"send_to": self.send_to,
"restricted_to": self.restricted_to
}
# def dict(self):
# return {
# "content": self.content,
# "instruct_content": self.instruct_content,
# "role": self.role,
# "cause_by": self.cause_by,
# "sent_from": self.sent_from,
# "send_to": self.send_to,
# "restricted_to": self.restricted_to
# }
def to_dict(self) -> dict:
return {