mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-05 22:02:38 +02:00
Merge pull request #728 from better629/feat_simple_ser
Feat simpler serialization in one file
This commit is contained in:
commit
9a42a14c91
29 changed files with 232 additions and 316 deletions
|
|
@ -27,7 +27,7 @@ from metagpt.schema import (
|
|||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
||||
class Action(SerializationMixin, is_polymorphic_base=True):
|
||||
class Action(SerializationMixin):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
|
||||
|
||||
name: str = ""
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@
|
|||
functionality is to be consolidated into the `Environment` class.
|
||||
"""
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Set
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
|
|
@ -21,7 +20,7 @@ from metagpt.context import Context
|
|||
from metagpt.logs import logger
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import is_send_to, read_json_file, write_json_file
|
||||
from metagpt.utils.common import is_send_to
|
||||
|
||||
|
||||
class Environment(BaseModel):
|
||||
|
|
@ -42,44 +41,6 @@ class Environment(BaseModel):
|
|||
self.add_roles(self.roles.values())
|
||||
return self
|
||||
|
||||
def serialize(self, stg_path: Path):
|
||||
roles_path = stg_path.joinpath("roles.json")
|
||||
roles_info = []
|
||||
for role_key, role in self.roles.items():
|
||||
roles_info.append(
|
||||
{
|
||||
"role_class": role.__class__.__name__,
|
||||
"module_name": role.__module__,
|
||||
"role_name": role.name,
|
||||
"role_sub_tags": list(self.member_addrs.get(role)),
|
||||
}
|
||||
)
|
||||
role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}"))
|
||||
write_json_file(roles_path, roles_info)
|
||||
|
||||
history_path = stg_path.joinpath("history.json")
|
||||
write_json_file(history_path, {"content": self.history})
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Environment":
|
||||
"""stg_path: ./storage/team/environment/"""
|
||||
roles_path = stg_path.joinpath("roles.json")
|
||||
roles_info = read_json_file(roles_path)
|
||||
roles = []
|
||||
for role_info in roles_info:
|
||||
# role stored in ./environment/roles/{role_class}_{role_name}
|
||||
role_path = stg_path.joinpath(f"roles/{role_info.get('role_class')}_{role_info.get('role_name')}")
|
||||
role = Role.deserialize(role_path)
|
||||
roles.append(role)
|
||||
|
||||
history = read_json_file(stg_path.joinpath("history.json"))
|
||||
history = history.get("content")
|
||||
|
||||
environment = Environment(**{"history": history})
|
||||
environment.add_roles(roles)
|
||||
|
||||
return environment
|
||||
|
||||
def add_role(self, role: Role):
|
||||
"""增加一个在当前环境的角色
|
||||
Add a role in the current environment
|
||||
|
|
|
|||
|
|
@ -7,19 +7,13 @@
|
|||
@Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key.
|
||||
"""
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import DefaultDict, Iterable, Set
|
||||
|
||||
from pydantic import BaseModel, Field, SerializeAsAny
|
||||
|
||||
from metagpt.const import IGNORED_MESSAGE_ID
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import (
|
||||
any_to_str,
|
||||
any_to_str_set,
|
||||
read_json_file,
|
||||
write_json_file,
|
||||
)
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
|
|
@ -29,22 +23,6 @@ class Memory(BaseModel):
|
|||
index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list))
|
||||
ignore_id: bool = False
|
||||
|
||||
def serialize(self, stg_path: Path):
|
||||
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
|
||||
memory_path = stg_path.joinpath("memory.json")
|
||||
storage = self.model_dump()
|
||||
write_json_file(memory_path, storage)
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Memory":
|
||||
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
|
||||
memory_path = stg_path.joinpath("memory.json")
|
||||
|
||||
memory_dict = read_json_file(memory_path)
|
||||
memory = Memory(**memory_dict)
|
||||
|
||||
return memory
|
||||
|
||||
def add(self, message: Message):
|
||||
"""Add a new message to storage, while updating the index"""
|
||||
if self.ignore_id:
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Optional, Set, Type
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
|
|
@ -31,7 +30,6 @@ from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validat
|
|||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.context import Context, context
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -39,14 +37,7 @@ from metagpt.memory import Memory
|
|||
from metagpt.provider import HumanProvider
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message, MessageQueue, SerializationMixin
|
||||
from metagpt.utils.common import (
|
||||
any_to_name,
|
||||
any_to_str,
|
||||
import_class,
|
||||
read_json_file,
|
||||
role_raise_decorator,
|
||||
write_json_file,
|
||||
)
|
||||
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
|
||||
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
|
||||
|
||||
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}. """
|
||||
|
|
@ -128,7 +119,7 @@ class RoleContext(BaseModel):
|
|||
return self.memory.get()
|
||||
|
||||
|
||||
class Role(SerializationMixin, is_polymorphic_base=True):
|
||||
class Role(SerializationMixin):
|
||||
"""Role/Agent"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
|
||||
|
|
@ -217,6 +208,9 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
self.llm.system_prompt = self._get_prefix()
|
||||
self._watch(data.get("watch") or [UserRequirement])
|
||||
|
||||
if self.latest_observed_msg:
|
||||
self.recovered = True
|
||||
|
||||
def _reset(self):
|
||||
self.states = []
|
||||
self.actions = []
|
||||
|
|
@ -225,47 +219,12 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
def _setting(self):
|
||||
return f"{self.name}({self.profile})"
|
||||
|
||||
def serialize(self, stg_path: Path = None):
|
||||
stg_path = (
|
||||
SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}")
|
||||
if stg_path is None
|
||||
else stg_path
|
||||
)
|
||||
|
||||
role_info = self.model_dump(exclude={"rc": {"memory": True, "msg_buffer": True}, "llm": True})
|
||||
role_info.update({"role_class": self.__class__.__name__, "module_name": self.__module__})
|
||||
role_info_path = stg_path.joinpath("role_info.json")
|
||||
write_json_file(role_info_path, role_info)
|
||||
|
||||
self.rc.memory.serialize(stg_path) # serialize role's memory alone
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Role":
|
||||
"""stg_path = ./storage/team/environment/roles/{role_class}_{role_name}"""
|
||||
role_info_path = stg_path.joinpath("role_info.json")
|
||||
role_info = read_json_file(role_info_path)
|
||||
|
||||
role_class_str = role_info.pop("role_class")
|
||||
module_name = role_info.pop("module_name")
|
||||
role_class = import_class(class_name=role_class_str, module_name=module_name)
|
||||
|
||||
role = role_class(**role_info) # initiate particular Role
|
||||
role.set_recovered(True) # set True to make a tag
|
||||
|
||||
role_memory = Memory.deserialize(stg_path)
|
||||
role.set_memory(role_memory)
|
||||
|
||||
return role
|
||||
|
||||
def _init_action_system_message(self, action: Action):
|
||||
action.set_prefix(self._get_prefix())
|
||||
|
||||
def refresh_system_message(self):
|
||||
self.llm.system_prompt = self._get_prefix()
|
||||
|
||||
def set_recovered(self, recovered: bool = False):
|
||||
self.recovered = recovered
|
||||
|
||||
def set_memory(self, memory: Memory):
|
||||
self.rc.memory = memory
|
||||
|
||||
|
|
@ -376,7 +335,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
|
||||
if self.recovered and self.rc.state >= 0:
|
||||
self._set_state(self.rc.state) # action to run from recovered state
|
||||
self.set_recovered(False) # avoid max_react_loop out of work
|
||||
self.recovered = False # avoid max_react_loop out of work
|
||||
return True
|
||||
|
||||
prompt = self._get_prefix()
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from abc import ABC
|
|||
from asyncio import Queue, QueueEmpty, wait_for
|
||||
from json import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
|
|
@ -32,8 +32,9 @@ from pydantic import (
|
|||
PrivateAttr,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import core_schema
|
||||
|
||||
from metagpt.const import (
|
||||
MESSAGE_ROUTE_CAUSE_BY,
|
||||
|
|
@ -53,7 +54,7 @@ from metagpt.utils.serialize import (
|
|||
)
|
||||
|
||||
|
||||
class SerializationMixin(BaseModel):
|
||||
class SerializationMixin(BaseModel, extra="forbid"):
|
||||
"""
|
||||
PolyMorphic subclasses Serialization / Deserialization Mixin
|
||||
- First of all, we need to know that pydantic is not designed for polymorphism.
|
||||
|
|
@ -68,49 +69,44 @@ class SerializationMixin(BaseModel):
|
|||
__is_polymorphic_base = False
|
||||
__subclasses_map__ = {}
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source: type["SerializationMixin"], handler: Callable[[Any], core_schema.CoreSchema]
|
||||
) -> core_schema.CoreSchema:
|
||||
schema = handler(source)
|
||||
og_schema_ref = schema["ref"]
|
||||
schema["ref"] += ":mixin"
|
||||
|
||||
return core_schema.no_info_before_validator_function(
|
||||
cls.__deserialize_with_real_type__,
|
||||
schema=schema,
|
||||
ref=og_schema_ref,
|
||||
serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __serialize_add_class_type__(
|
||||
cls,
|
||||
value,
|
||||
handler: core_schema.SerializerFunctionWrapHandler,
|
||||
) -> Any:
|
||||
ret = handler(value)
|
||||
if not len(cls.__subclasses__()):
|
||||
# only subclass add `__module_class_name`
|
||||
ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}"
|
||||
@model_serializer(mode="wrap")
|
||||
def __serialize_with_class_type__(self, default_serializer) -> Any:
|
||||
# default serializer, then append the `__module_class_name` field and return
|
||||
ret = default_serializer(self)
|
||||
ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
|
||||
return ret
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
def __deserialize_with_real_type__(cls, value: Any):
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
def __convert_to_real_type__(cls, value: Any, handler):
|
||||
if isinstance(value, dict) is False:
|
||||
return handler(value)
|
||||
|
||||
if not cls.__is_polymorphic_base or (len(cls.__subclasses__()) and "__module_class_name" not in value):
|
||||
# add right condition to init BaseClass like Action()
|
||||
return value
|
||||
module_class_name = value.get("__module_class_name", None)
|
||||
if module_class_name is None:
|
||||
raise ValueError("Missing field: __module_class_name")
|
||||
# it is a dict so make sure to remove the __module_class_name
|
||||
# because we don't allow extra keywords but want to ensure
|
||||
# e.g Cat.model_validate(cat.model_dump()) works
|
||||
class_full_name = value.pop("__module_class_name", None)
|
||||
|
||||
class_type = cls.__subclasses_map__.get(module_class_name, None)
|
||||
# if it's not the polymorphic base we construct via default handler
|
||||
if not cls.__is_polymorphic_base:
|
||||
if class_full_name is None:
|
||||
return handler(value)
|
||||
elif str(cls) == f"<class '{class_full_name}'>":
|
||||
return handler(value)
|
||||
else:
|
||||
# f"Trying to instantiate {class_full_name} but this is not the polymorphic base class")
|
||||
pass
|
||||
|
||||
# otherwise we lookup the correct polymorphic type and construct that
|
||||
# instead
|
||||
if class_full_name is None:
|
||||
raise ValueError("Missing __module_class_name field")
|
||||
|
||||
class_type = cls.__subclasses_map__.get(class_full_name, None)
|
||||
|
||||
if class_type is None:
|
||||
raise TypeError("Trying to instantiate {module_class_name} which not defined yet.")
|
||||
# TODO could try dynamic import
|
||||
raise TypeError("Trying to instantiate {class_full_name}, which has not yet been defined!")
|
||||
|
||||
return class_type(**value)
|
||||
|
||||
|
|
@ -186,12 +182,17 @@ class Message(BaseModel):
|
|||
@field_validator("instruct_content", mode="before")
|
||||
@classmethod
|
||||
def check_instruct_content(cls, ic: Any) -> BaseModel:
|
||||
if ic and not isinstance(ic, BaseModel) and "class" in ic:
|
||||
# compatible with custom-defined ActionOutput
|
||||
mapping = actionoutput_str_to_mapping(ic["mapping"])
|
||||
|
||||
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
|
||||
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
|
||||
if ic and isinstance(ic, dict) and "class" in ic:
|
||||
if "mapping" in ic:
|
||||
# compatible with custom-defined ActionOutput
|
||||
mapping = actionoutput_str_to_mapping(ic["mapping"])
|
||||
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
|
||||
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
|
||||
elif "module" in ic:
|
||||
# subclasses of BaseModel
|
||||
ic_obj = import_class(ic["class"], ic["module"])
|
||||
else:
|
||||
raise KeyError("missing required key to init Message.instruct_content from dict")
|
||||
ic = ic_obj(**ic["value"])
|
||||
return ic
|
||||
|
||||
|
|
@ -216,13 +217,16 @@ class Message(BaseModel):
|
|||
if ic:
|
||||
# compatible with custom-defined ActionOutput
|
||||
schema = ic.model_json_schema()
|
||||
# `Documents` contain definitions
|
||||
if "definitions" not in schema:
|
||||
# TODO refine with nested BaseModel
|
||||
ic_type = str(type(ic))
|
||||
if "<class 'metagpt.actions.action_node" in ic_type:
|
||||
# instruct_content from AutoNode.create_model_class, for now, it's single level structure.
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
mapping = actionoutput_mapping_to_str(mapping)
|
||||
|
||||
ic_dict = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
|
||||
else:
|
||||
# due to instruct_content can be assigned by subclasses of BaseModel
|
||||
ic_dict = {"class": schema["title"], "module": ic.__module__, "value": ic.model_dump()}
|
||||
return ic_dict
|
||||
|
||||
def __init__(self, content: str = "", **data: Any):
|
||||
|
|
|
|||
|
|
@ -49,28 +49,21 @@ class Team(BaseModel):
|
|||
|
||||
def serialize(self, stg_path: Path = None):
|
||||
stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path
|
||||
team_info_path = stg_path.joinpath("team.json")
|
||||
|
||||
team_info_path = stg_path.joinpath("team_info.json")
|
||||
write_json_file(team_info_path, self.model_dump(exclude={"env": True}))
|
||||
|
||||
self.env.serialize(stg_path.joinpath("environment")) # save environment alone
|
||||
write_json_file(team_info_path, self.model_dump())
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Team":
|
||||
"""stg_path = ./storage/team"""
|
||||
# recover team_info
|
||||
team_info_path = stg_path.joinpath("team_info.json")
|
||||
team_info_path = stg_path.joinpath("team.json")
|
||||
if not team_info_path.exists():
|
||||
raise FileNotFoundError(
|
||||
"recover storage meta file `team_info.json` not exist, "
|
||||
"not to recover and please start a new project."
|
||||
"recover storage meta file `team.json` not exist, " "not to recover and please start a new project."
|
||||
)
|
||||
|
||||
team_info: dict = read_json_file(team_info_path)
|
||||
|
||||
# recover environment
|
||||
environment = Environment.deserialize(stg_path=stg_path.joinpath("environment"))
|
||||
team_info.update({"env": environment})
|
||||
team = Team(**team_info)
|
||||
return team
|
||||
|
||||
|
|
|
|||
|
|
@ -18,12 +18,12 @@ from metagpt.config2 import config
|
|||
|
||||
def make_sk_kernel():
|
||||
kernel = sk.Kernel()
|
||||
if llm := config.get_openai_llm():
|
||||
if llm := config.get_azure_llm():
|
||||
kernel.add_chat_service(
|
||||
"chat_completion",
|
||||
AzureChatCompletion(llm.model, llm.base_url, llm.api_key),
|
||||
)
|
||||
else:
|
||||
elif llm := config.get_openai_llm():
|
||||
kernel.add_chat_service(
|
||||
"chat_completion",
|
||||
OpenAIChatCompletion(llm.model, llm.api_key),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue