Merge pull request #728 from better629/feat_simple_ser

Feat simpler serialization in one file
This commit is contained in:
geekan 2024-01-09 16:31:50 +08:00 committed by GitHub
commit 9a42a14c91
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 232 additions and 316 deletions

View file

@ -27,7 +27,7 @@ from metagpt.schema import (
from metagpt.utils.file_repository import FileRepository
class Action(SerializationMixin, is_polymorphic_base=True):
class Action(SerializationMixin):
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
name: str = ""

View file

@ -12,7 +12,6 @@
functionality is to be consolidated into the `Environment` class.
"""
import asyncio
from pathlib import Path
from typing import Iterable, Set
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
@ -21,7 +20,7 @@ from metagpt.context import Context
from metagpt.logs import logger
from metagpt.roles.role import Role
from metagpt.schema import Message
from metagpt.utils.common import is_send_to, read_json_file, write_json_file
from metagpt.utils.common import is_send_to
class Environment(BaseModel):
@ -42,44 +41,6 @@ class Environment(BaseModel):
self.add_roles(self.roles.values())
return self
def serialize(self, stg_path: Path):
roles_path = stg_path.joinpath("roles.json")
roles_info = []
for role_key, role in self.roles.items():
roles_info.append(
{
"role_class": role.__class__.__name__,
"module_name": role.__module__,
"role_name": role.name,
"role_sub_tags": list(self.member_addrs.get(role)),
}
)
role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}"))
write_json_file(roles_path, roles_info)
history_path = stg_path.joinpath("history.json")
write_json_file(history_path, {"content": self.history})
@classmethod
def deserialize(cls, stg_path: Path) -> "Environment":
"""stg_path: ./storage/team/environment/"""
roles_path = stg_path.joinpath("roles.json")
roles_info = read_json_file(roles_path)
roles = []
for role_info in roles_info:
# role stored in ./environment/roles/{role_class}_{role_name}
role_path = stg_path.joinpath(f"roles/{role_info.get('role_class')}_{role_info.get('role_name')}")
role = Role.deserialize(role_path)
roles.append(role)
history = read_json_file(stg_path.joinpath("history.json"))
history = history.get("content")
environment = Environment(**{"history": history})
environment.add_roles(roles)
return environment
def add_role(self, role: Role):
"""增加一个在当前环境的角色
Add a role in the current environment

View file

@ -7,19 +7,13 @@
@Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key.
"""
from collections import defaultdict
from pathlib import Path
from typing import DefaultDict, Iterable, Set
from pydantic import BaseModel, Field, SerializeAsAny
from metagpt.const import IGNORED_MESSAGE_ID
from metagpt.schema import Message
from metagpt.utils.common import (
any_to_str,
any_to_str_set,
read_json_file,
write_json_file,
)
from metagpt.utils.common import any_to_str, any_to_str_set
class Memory(BaseModel):
@ -29,22 +23,6 @@ class Memory(BaseModel):
index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list))
ignore_id: bool = False
def serialize(self, stg_path: Path):
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
memory_path = stg_path.joinpath("memory.json")
storage = self.model_dump()
write_json_file(memory_path, storage)
@classmethod
def deserialize(cls, stg_path: Path) -> "Memory":
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
memory_path = stg_path.joinpath("memory.json")
memory_dict = read_json_file(memory_path)
memory = Memory(**memory_dict)
return memory
def add(self, message: Message):
"""Add a new message to storage, while updating the index"""
if self.ignore_id:

View file

@ -23,7 +23,6 @@
from __future__ import annotations
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, Optional, Set, Type
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
@ -31,7 +30,6 @@ from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validat
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.const import SERDESER_PATH
from metagpt.context import Context, context
from metagpt.llm import LLM
from metagpt.logs import logger
@ -39,14 +37,7 @@ from metagpt.memory import Memory
from metagpt.provider import HumanProvider
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message, MessageQueue, SerializationMixin
from metagpt.utils.common import (
any_to_name,
any_to_str,
import_class,
read_json_file,
role_raise_decorator,
write_json_file,
)
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}. """
@ -128,7 +119,7 @@ class RoleContext(BaseModel):
return self.memory.get()
class Role(SerializationMixin, is_polymorphic_base=True):
class Role(SerializationMixin):
"""Role/Agent"""
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
@ -217,6 +208,9 @@ class Role(SerializationMixin, is_polymorphic_base=True):
self.llm.system_prompt = self._get_prefix()
self._watch(data.get("watch") or [UserRequirement])
if self.latest_observed_msg:
self.recovered = True
def _reset(self):
self.states = []
self.actions = []
@ -225,47 +219,12 @@ class Role(SerializationMixin, is_polymorphic_base=True):
def _setting(self):
return f"{self.name}({self.profile})"
def serialize(self, stg_path: Path = None):
stg_path = (
SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}")
if stg_path is None
else stg_path
)
role_info = self.model_dump(exclude={"rc": {"memory": True, "msg_buffer": True}, "llm": True})
role_info.update({"role_class": self.__class__.__name__, "module_name": self.__module__})
role_info_path = stg_path.joinpath("role_info.json")
write_json_file(role_info_path, role_info)
self.rc.memory.serialize(stg_path) # serialize role's memory alone
@classmethod
def deserialize(cls, stg_path: Path) -> "Role":
"""stg_path = ./storage/team/environment/roles/{role_class}_{role_name}"""
role_info_path = stg_path.joinpath("role_info.json")
role_info = read_json_file(role_info_path)
role_class_str = role_info.pop("role_class")
module_name = role_info.pop("module_name")
role_class = import_class(class_name=role_class_str, module_name=module_name)
role = role_class(**role_info) # initiate particular Role
role.set_recovered(True) # set True to make a tag
role_memory = Memory.deserialize(stg_path)
role.set_memory(role_memory)
return role
def _init_action_system_message(self, action: Action):
action.set_prefix(self._get_prefix())
def refresh_system_message(self):
self.llm.system_prompt = self._get_prefix()
def set_recovered(self, recovered: bool = False):
self.recovered = recovered
def set_memory(self, memory: Memory):
self.rc.memory = memory
@ -376,7 +335,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
if self.recovered and self.rc.state >= 0:
self._set_state(self.rc.state) # action to run from recovered state
self.set_recovered(False) # avoid max_react_loop out of work
self.recovered = False # avoid max_react_loop out of work
return True
prompt = self._get_prefix()

View file

@ -23,7 +23,7 @@ from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from pydantic import (
BaseModel,
@ -32,8 +32,9 @@ from pydantic import (
PrivateAttr,
field_serializer,
field_validator,
model_serializer,
model_validator,
)
from pydantic_core import core_schema
from metagpt.const import (
MESSAGE_ROUTE_CAUSE_BY,
@ -53,7 +54,7 @@ from metagpt.utils.serialize import (
)
class SerializationMixin(BaseModel):
class SerializationMixin(BaseModel, extra="forbid"):
"""
PolyMorphic subclasses Serialization / Deserialization Mixin
- First of all, we need to know that pydantic is not designed for polymorphism.
@ -68,49 +69,44 @@ class SerializationMixin(BaseModel):
__is_polymorphic_base = False
__subclasses_map__ = {}
@classmethod
def __get_pydantic_core_schema__(
cls, source: type["SerializationMixin"], handler: Callable[[Any], core_schema.CoreSchema]
) -> core_schema.CoreSchema:
schema = handler(source)
og_schema_ref = schema["ref"]
schema["ref"] += ":mixin"
return core_schema.no_info_before_validator_function(
cls.__deserialize_with_real_type__,
schema=schema,
ref=og_schema_ref,
serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__),
)
@classmethod
def __serialize_add_class_type__(
cls,
value,
handler: core_schema.SerializerFunctionWrapHandler,
) -> Any:
ret = handler(value)
if not len(cls.__subclasses__()):
# only subclass add `__module_class_name`
ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}"
@model_serializer(mode="wrap")
def __serialize_with_class_type__(self, default_serializer) -> Any:
# default serializer, then append the `__module_class_name` field and return
ret = default_serializer(self)
ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
return ret
@model_validator(mode="wrap")
@classmethod
def __deserialize_with_real_type__(cls, value: Any):
if not isinstance(value, dict):
return value
def __convert_to_real_type__(cls, value: Any, handler):
if isinstance(value, dict) is False:
return handler(value)
if not cls.__is_polymorphic_base or (len(cls.__subclasses__()) and "__module_class_name" not in value):
# add right condition to init BaseClass like Action()
return value
module_class_name = value.get("__module_class_name", None)
if module_class_name is None:
raise ValueError("Missing field: __module_class_name")
# value is a dict, so remove the `__module_class_name` key from it:
# extra keywords are forbidden, but we still want round-trips such as
# Cat.model_validate(cat.model_dump()) to work
class_full_name = value.pop("__module_class_name", None)
class_type = cls.__subclasses_map__.get(module_class_name, None)
# if it's not the polymorphic base we construct via default handler
if not cls.__is_polymorphic_base:
if class_full_name is None:
return handler(value)
elif str(cls) == f"<class '{class_full_name}'>":
return handler(value)
else:
# f"Trying to instantiate {class_full_name} but this is not the polymorphic base class")
pass
# otherwise we lookup the correct polymorphic type and construct that
# instead
if class_full_name is None:
raise ValueError("Missing __module_class_name field")
class_type = cls.__subclasses_map__.get(class_full_name, None)
if class_type is None:
raise TypeError("Trying to instantiate {module_class_name} which not defined yet.")
# TODO could try dynamic import
raise TypeError("Trying to instantiate {class_full_name}, which has not yet been defined!")
return class_type(**value)
@ -186,12 +182,17 @@ class Message(BaseModel):
@field_validator("instruct_content", mode="before")
@classmethod
def check_instruct_content(cls, ic: Any) -> BaseModel:
if ic and not isinstance(ic, BaseModel) and "class" in ic:
# compatible with custom-defined ActionOutput
mapping = actionoutput_str_to_mapping(ic["mapping"])
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
if ic and isinstance(ic, dict) and "class" in ic:
if "mapping" in ic:
# compatible with custom-defined ActionOutput
mapping = actionoutput_str_to_mapping(ic["mapping"])
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
elif "module" in ic:
# subclasses of BaseModel
ic_obj = import_class(ic["class"], ic["module"])
else:
raise KeyError("missing required key to init Message.instruct_content from dict")
ic = ic_obj(**ic["value"])
return ic
@ -216,13 +217,16 @@ class Message(BaseModel):
if ic:
# compatible with custom-defined ActionOutput
schema = ic.model_json_schema()
# `Documents` contain definitions
if "definitions" not in schema:
# TODO refine with nested BaseModel
ic_type = str(type(ic))
if "<class 'metagpt.actions.action_node" in ic_type:
# instruct_content created by ActionNode.create_model_class; for now, it is a single-level structure.
mapping = actionoutout_schema_to_mapping(schema)
mapping = actionoutput_mapping_to_str(mapping)
ic_dict = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
else:
# instruct_content can also be an instance of another BaseModel subclass, so store its module to re-import the class later
ic_dict = {"class": schema["title"], "module": ic.__module__, "value": ic.model_dump()}
return ic_dict
def __init__(self, content: str = "", **data: Any):

View file

@ -49,28 +49,21 @@ class Team(BaseModel):
def serialize(self, stg_path: Path = None):
stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path
team_info_path = stg_path.joinpath("team.json")
team_info_path = stg_path.joinpath("team_info.json")
write_json_file(team_info_path, self.model_dump(exclude={"env": True}))
self.env.serialize(stg_path.joinpath("environment")) # save environment alone
write_json_file(team_info_path, self.model_dump())
@classmethod
def deserialize(cls, stg_path: Path) -> "Team":
"""stg_path = ./storage/team"""
# recover team_info
team_info_path = stg_path.joinpath("team_info.json")
team_info_path = stg_path.joinpath("team.json")
if not team_info_path.exists():
raise FileNotFoundError(
"recover storage meta file `team_info.json` not exist, "
"not to recover and please start a new project."
"recover storage meta file `team.json` not exist, " "not to recover and please start a new project."
)
team_info: dict = read_json_file(team_info_path)
# recover environment
environment = Environment.deserialize(stg_path=stg_path.joinpath("environment"))
team_info.update({"env": environment})
team = Team(**team_info)
return team

View file

@ -18,12 +18,12 @@ from metagpt.config2 import config
def make_sk_kernel():
kernel = sk.Kernel()
if llm := config.get_openai_llm():
if llm := config.get_azure_llm():
kernel.add_chat_service(
"chat_completion",
AzureChatCompletion(llm.model, llm.base_url, llm.api_key),
)
else:
elif llm := config.get_openai_llm():
kernel.add_chat_service(
"chat_completion",
OpenAIChatCompletion(llm.model, llm.api_key),