simplify some ser&desr code

This commit is contained in:
better629 2023-12-01 14:43:45 +08:00
parent 6208400f71
commit f563b2c608
5 changed files with 54 additions and 146 deletions

View file

@ -52,6 +52,12 @@ class Action(BaseModel):
super().__init_subclass__(**kwargs)
action_subclass_registry[cls.__name__] = cls
def dict(self, *args, **kwargs) -> "DictStrAny":
obj_dict = super(Action, self).dict(*args, **kwargs)
if "llm" in obj_dict:
obj_dict.pop("llm")
return obj_dict
def set_prefix(self, prefix, profile):
"""Set prefix for later usage"""
self.prefix = prefix
@ -63,20 +69,6 @@ class Action(BaseModel):
def __repr__(self):
return self.__str__()
def serialize(self):
return {
"action_class": self.__class__.__name__,
"module_name": self.__module__,
"name": self.name
}
@classmethod
def deserialize(cls, action_dict: dict) -> "Action":
action_class_str = action_dict.pop("action_class")
module_name = action_dict.pop("module_name")
action_class = import_class(action_class_str, module_name)
return action_class(**action_dict)
@classmethod
def ser_class(cls) -> dict:
""" serialize class type"""

View file

@ -70,10 +70,8 @@ class Environment(BaseModel):
roles_info = read_json_file(roles_path)
roles = []
for role_info in roles_info:
role_class = role_info.get("role_class")
role_name = role_info.get("role_name")
role_path = stg_path.joinpath(f"roles/{role_class}_{role_name}")
# role stored in ./environment/roles/{role_class}_{role_name}
role_path = stg_path.joinpath(f'roles/{role_info.get("role_class")}_{role_info.get("role_name")}')
role = Role.deserialize(role_path)
roles.append(role)

View file

@ -36,23 +36,9 @@ class Memory(BaseModel):
super(Memory, self).__init__(**kwargs)
self.index = new_index
def dict(self,
*,
include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None,
exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False) -> "DictStrAny":
def dict(self, *args, **kwargs) -> "DictStrAny":
""" overwrite the `dict` to dump dynamic pydantic model"""
obj_dict = super(Memory, self).dict(include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none)
obj_dict = super(Memory, self).dict(*args, **kwargs)
new_obj_dict = copy.deepcopy(obj_dict)
new_obj_dict["index"] = {}
for action, value in obj_dict["index"].items():

View file

@ -93,7 +93,7 @@ class RoleContext(BaseModel):
memory: Memory = Field(default_factory=Memory)
long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory, exclude=True) # TODO not used now
state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None
todo: Action = Field(default=None)
todo: Action = Field(default=None, exclude=True)
watch: set[Type[Action]] = Field(default_factory=set)
news: list[Type[Message]] = Field(default=[], exclude=True) # TODO not used
react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes
@ -101,7 +101,25 @@ class RoleContext(BaseModel):
class Config:
arbitrary_types_allowed = True
def __init__(self, **kwargs):
watch_info = kwargs.get("watch", set())
watch = set()
for item in watch_info:
action = Action.deser_class(item)
watch.update([action])
kwargs["watch"] = watch
super(RoleContext, self).__init__(**kwargs)
def dict(self, *args, **kwargs) -> "DictStrAny":
obj_dict = super(RoleContext, self).dict(*args, **kwargs)
watch = obj_dict.get("watch", set())
watch_info = []
for item in watch:
watch_info.append(item.ser_class())
obj_dict["watch"] = watch_info
return obj_dict
def check(self, role_id: str):
if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory:
self.long_term_memory.recover_memory(role_id, self)
@ -130,7 +148,6 @@ class Role(BaseModel):
is_human: bool = False
_llm: BaseGPTAPI = Field(default_factory=LLM)
_setting: RoleSetting = Field(default_factory=RoleSetting, alias=True)
_role_id: str = ""
_states: list[str] = Field(default=[])
_actions: list[Action] = Field(default=[])
@ -168,18 +185,12 @@ class Role(BaseModel):
# 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655
self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider()
self._private_attributes["_setting"] = RoleSetting(name=self.name, profile=self.profile, goal=self.goal,
desc=self.desc, constraints=self.constraints,
is_human=self.is_human)
self._private_attributes["_role_id"] = str(self._setting)
for key in self._private_attributes.keys():
if key in kwargs:
object.__setattr__(self, key, kwargs[key])
if key == "_setting":
setting = RoleSetting(**kwargs[key])
object.__setattr__(self, "_setting", setting)
elif key == "_rc":
if key == "_rc":
_rc = RoleContext(**kwargs["_rc"])
object.__setattr__(self, "_rc", _rc)
else:
@ -203,41 +214,23 @@ class Role(BaseModel):
object.__setattr__(self, "_actions", [])
# object.__setattr__(self, "_rc", RoleContext())
@property
def _setting(self):
return f"{self.name}({self.profile})"
def serialize(self, stg_path: Path = None):
stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \
if stg_path is None else stg_path
role_info_path = stg_path.joinpath("role_info.json")
role_info = {
role_info = self.dict(exclude={"_rc": {"memory": True}, "_llm": True})
role_info.update({
"role_class": self.__class__.__name__,
"module_name": self.__module__
}
setting = self._setting.dict()
setting.pop("desc")
setting.pop("is_human") # not all inherited roles have this atrr
role_info.update(setting)
})
role_info_path = stg_path.joinpath("role_info.json")
write_json_file(role_info_path, role_info)
actions_info_path = stg_path.joinpath("actions/actions_info.json")
actions_info = []
for action in self._actions:
actions_info.append(action.serialize())
write_json_file(actions_info_path, actions_info)
watches_info_path = stg_path.joinpath("watches/watches_info.json")
watches_info = []
for watch in self._rc.watch:
watches_info.append(watch.ser_class())
write_json_file(watches_info_path, watches_info)
actions_todo_path = stg_path.joinpath("actions/todo.json")
actions_todo = {
"cur_state": self._rc.state,
"react_mode": self._rc.react_mode.value,
"max_react_loop": self._rc.max_react_loop
}
write_json_file(actions_todo_path, actions_todo)
self._rc.memory.serialize(stg_path)
self._rc.memory.serialize(stg_path) # serialize role's memory alone
@classmethod
def deserialize(cls, stg_path: Path) -> "Role":
@ -250,35 +243,7 @@ class Role(BaseModel):
role_class = import_class(class_name=role_class_str, module_name=module_name)
role = role_class(**role_info) # initiate particular Role
actions_info_path = stg_path.joinpath("actions/actions_info.json")
actions = []
actions_info = read_json_file(actions_info_path)
for action_info in actions_info:
action = Action.deser_class(action_info)
actions.append(action)
watches_info_path = stg_path.joinpath("watches/watches_info.json")
watches = []
watches_info = read_json_file(watches_info_path)
for watch_info in watches_info:
action = Action.deser_class(watch_info)
watches.append(action)
role.init_actions(actions)
role.watch(watches)
actions_todo_path = stg_path.joinpath("actions/todo.json")
# recover self._rc.state
actions_todo = read_json_file(actions_todo_path)
max_react_loop = actions_todo.get("max_react_loop", 1)
cur_state = actions_todo.get("cur_state", -1)
role.set_state(cur_state)
role.set_recovered(True)
react_mode_str = actions_todo.get("react_mode", RoleReactMode.REACT.value)
if react_mode_str not in RoleReactMode.values():
logger.warning(f"ReactMode: {react_mode_str} not in {RoleReactMode.values()}, use react as default")
react_mode_str = RoleReactMode.REACT.value
role.set_react_mode(RoleReactMode(react_mode_str), max_react_loop)
role.set_recovered(True) # set True to make a tag
role_memory = Memory.deserialize(stg_path)
role.set_memory(role_memory)
@ -299,9 +264,9 @@ class Role(BaseModel):
for idx, action in enumerate(actions):
if not isinstance(action, Action):
## 默认初始化
i = action(name="", llm=self._llm)
i = action(llm=self._llm)
else:
if self._setting.is_human and not isinstance(action.llm, HumanProvider):
if self.is_human and not isinstance(action.llm, HumanProvider):
logger.warning(f"is_human attribute does not take effect,"
f"as Role's {str(action)} was initialized using LLM, try passing in Action classes instead of initialized instances")
i = action
@ -357,9 +322,14 @@ class Role(BaseModel):
def _get_prefix(self):
"""Get the role prefix"""
if self._setting.desc:
return self._setting.desc
return PREFIX_TEMPLATE.format(**self._setting.dict())
if self.desc:
return self.desc
return PREFIX_TEMPLATE.format(**{
"profile": self.profile,
"name": self.name,
"goal": self.goal,
"constraints": self.constraints
})
async def _think(self) -> None:
"""Think about what to do and decide on the next action"""

View file

@ -48,23 +48,9 @@ class Message(BaseModel):
kwargs["cause_by"] = action_class.deser_class(cause_by)
super(Message, self).__init__(**kwargs)
def dict(self,
*,
include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None,
exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False) -> "DictStrAny":
def dict(self, *args, **kwargs) -> "DictStrAny":
""" overwrite the `dict` to dump dynamic pydantic model"""
obj_dict = super(Message, self).dict(include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none)
obj_dict = super(Message, self).dict(*args, **kwargs)
ic = self.instruct_content # deal custom-defined action
if ic:
schema = ic.schema()
@ -77,19 +63,6 @@ class Message(BaseModel):
obj_dict["cause_by"] = cb.ser_class()
return obj_dict
#
#
# @dataclass
# class Message:
# """list[<role>: <content>]"""
# content: str
# instruct_content: BaseModel = field(default=None)
# role: str = field(default='user') # system / user / assistant
# cause_by: Type["Action"] = field(default="")
# sent_from: str = field(default="")
# send_to: str = field(default="")
# restricted_to: str = field(default="")
def __str__(self):
# prefix = '-'.join([self.role, str(self.cause_by)])
return f"{self.role}: {self.content}"
@ -97,17 +70,6 @@ class Message(BaseModel):
def __repr__(self):
return self.__str__()
# def dict(self):
# return {
# "content": self.content,
# "instruct_content": self.instruct_content,
# "role": self.role,
# "cause_by": self.cause_by,
# "sent_from": self.sent_from,
# "send_to": self.send_to,
# "restricted_to": self.restricted_to
# }
def to_dict(self) -> dict:
return {
"role": self.role,