Merge pull request #728 from better629/feat_simple_ser

Feat simpler serialization in one file
This commit is contained in:
geekan 2024-01-09 16:31:50 +08:00 committed by GitHub
commit 9a42a14c91
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 232 additions and 316 deletions

View file

@ -27,7 +27,7 @@ from metagpt.schema import (
from metagpt.utils.file_repository import FileRepository
class Action(SerializationMixin, is_polymorphic_base=True):
class Action(SerializationMixin):
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
name: str = ""

View file

@ -12,7 +12,6 @@
functionality is to be consolidated into the `Environment` class.
"""
import asyncio
from pathlib import Path
from typing import Iterable, Set
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
@ -21,7 +20,7 @@ from metagpt.context import Context
from metagpt.logs import logger
from metagpt.roles.role import Role
from metagpt.schema import Message
from metagpt.utils.common import is_send_to, read_json_file, write_json_file
from metagpt.utils.common import is_send_to
class Environment(BaseModel):
@ -42,44 +41,6 @@ class Environment(BaseModel):
self.add_roles(self.roles.values())
return self
def serialize(self, stg_path: Path):
roles_path = stg_path.joinpath("roles.json")
roles_info = []
for role_key, role in self.roles.items():
roles_info.append(
{
"role_class": role.__class__.__name__,
"module_name": role.__module__,
"role_name": role.name,
"role_sub_tags": list(self.member_addrs.get(role)),
}
)
role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}"))
write_json_file(roles_path, roles_info)
history_path = stg_path.joinpath("history.json")
write_json_file(history_path, {"content": self.history})
@classmethod
def deserialize(cls, stg_path: Path) -> "Environment":
"""stg_path: ./storage/team/environment/"""
roles_path = stg_path.joinpath("roles.json")
roles_info = read_json_file(roles_path)
roles = []
for role_info in roles_info:
# role stored in ./environment/roles/{role_class}_{role_name}
role_path = stg_path.joinpath(f"roles/{role_info.get('role_class')}_{role_info.get('role_name')}")
role = Role.deserialize(role_path)
roles.append(role)
history = read_json_file(stg_path.joinpath("history.json"))
history = history.get("content")
environment = Environment(**{"history": history})
environment.add_roles(roles)
return environment
def add_role(self, role: Role):
"""增加一个在当前环境的角色
Add a role in the current environment

View file

@ -7,19 +7,13 @@
@Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key.
"""
from collections import defaultdict
from pathlib import Path
from typing import DefaultDict, Iterable, Set
from pydantic import BaseModel, Field, SerializeAsAny
from metagpt.const import IGNORED_MESSAGE_ID
from metagpt.schema import Message
from metagpt.utils.common import (
any_to_str,
any_to_str_set,
read_json_file,
write_json_file,
)
from metagpt.utils.common import any_to_str, any_to_str_set
class Memory(BaseModel):
@ -29,22 +23,6 @@ class Memory(BaseModel):
index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list))
ignore_id: bool = False
def serialize(self, stg_path: Path):
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
memory_path = stg_path.joinpath("memory.json")
storage = self.model_dump()
write_json_file(memory_path, storage)
@classmethod
def deserialize(cls, stg_path: Path) -> "Memory":
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
memory_path = stg_path.joinpath("memory.json")
memory_dict = read_json_file(memory_path)
memory = Memory(**memory_dict)
return memory
def add(self, message: Message):
"""Add a new message to storage, while updating the index"""
if self.ignore_id:

View file

@ -23,7 +23,6 @@
from __future__ import annotations
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, Optional, Set, Type
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
@ -31,7 +30,6 @@ from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validat
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.const import SERDESER_PATH
from metagpt.context import Context, context
from metagpt.llm import LLM
from metagpt.logs import logger
@ -39,14 +37,7 @@ from metagpt.memory import Memory
from metagpt.provider import HumanProvider
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message, MessageQueue, SerializationMixin
from metagpt.utils.common import (
any_to_name,
any_to_str,
import_class,
read_json_file,
role_raise_decorator,
write_json_file,
)
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}. """
@ -128,7 +119,7 @@ class RoleContext(BaseModel):
return self.memory.get()
class Role(SerializationMixin, is_polymorphic_base=True):
class Role(SerializationMixin):
"""Role/Agent"""
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
@ -217,6 +208,9 @@ class Role(SerializationMixin, is_polymorphic_base=True):
self.llm.system_prompt = self._get_prefix()
self._watch(data.get("watch") or [UserRequirement])
if self.latest_observed_msg:
self.recovered = True
def _reset(self):
self.states = []
self.actions = []
@ -225,47 +219,12 @@ class Role(SerializationMixin, is_polymorphic_base=True):
def _setting(self):
return f"{self.name}({self.profile})"
def serialize(self, stg_path: Path = None):
stg_path = (
SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}")
if stg_path is None
else stg_path
)
role_info = self.model_dump(exclude={"rc": {"memory": True, "msg_buffer": True}, "llm": True})
role_info.update({"role_class": self.__class__.__name__, "module_name": self.__module__})
role_info_path = stg_path.joinpath("role_info.json")
write_json_file(role_info_path, role_info)
self.rc.memory.serialize(stg_path) # serialize role's memory alone
@classmethod
def deserialize(cls, stg_path: Path) -> "Role":
"""stg_path = ./storage/team/environment/roles/{role_class}_{role_name}"""
role_info_path = stg_path.joinpath("role_info.json")
role_info = read_json_file(role_info_path)
role_class_str = role_info.pop("role_class")
module_name = role_info.pop("module_name")
role_class = import_class(class_name=role_class_str, module_name=module_name)
role = role_class(**role_info) # initiate particular Role
role.set_recovered(True) # set True to make a tag
role_memory = Memory.deserialize(stg_path)
role.set_memory(role_memory)
return role
def _init_action_system_message(self, action: Action):
action.set_prefix(self._get_prefix())
def refresh_system_message(self):
self.llm.system_prompt = self._get_prefix()
def set_recovered(self, recovered: bool = False):
self.recovered = recovered
def set_memory(self, memory: Memory):
self.rc.memory = memory
@ -376,7 +335,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
if self.recovered and self.rc.state >= 0:
self._set_state(self.rc.state) # action to run from recovered state
self.set_recovered(False) # avoid max_react_loop out of work
self.recovered = False # avoid max_react_loop out of work
return True
prompt = self._get_prefix()

View file

@ -23,7 +23,7 @@ from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from pydantic import (
BaseModel,
@ -32,8 +32,9 @@ from pydantic import (
PrivateAttr,
field_serializer,
field_validator,
model_serializer,
model_validator,
)
from pydantic_core import core_schema
from metagpt.const import (
MESSAGE_ROUTE_CAUSE_BY,
@ -53,7 +54,7 @@ from metagpt.utils.serialize import (
)
class SerializationMixin(BaseModel):
class SerializationMixin(BaseModel, extra="forbid"):
"""
PolyMorphic subclasses Serialization / Deserialization Mixin
- First of all, we need to know that pydantic is not designed for polymorphism.
@ -68,49 +69,44 @@ class SerializationMixin(BaseModel):
__is_polymorphic_base = False
__subclasses_map__ = {}
@classmethod
def __get_pydantic_core_schema__(
cls, source: type["SerializationMixin"], handler: Callable[[Any], core_schema.CoreSchema]
) -> core_schema.CoreSchema:
schema = handler(source)
og_schema_ref = schema["ref"]
schema["ref"] += ":mixin"
return core_schema.no_info_before_validator_function(
cls.__deserialize_with_real_type__,
schema=schema,
ref=og_schema_ref,
serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__),
)
@classmethod
def __serialize_add_class_type__(
cls,
value,
handler: core_schema.SerializerFunctionWrapHandler,
) -> Any:
ret = handler(value)
if not len(cls.__subclasses__()):
# only subclass add `__module_class_name`
ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}"
@model_serializer(mode="wrap")
def __serialize_with_class_type__(self, default_serializer) -> Any:
# default serializer, then append the `__module_class_name` field and return
ret = default_serializer(self)
ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
return ret
@model_validator(mode="wrap")
@classmethod
def __deserialize_with_real_type__(cls, value: Any):
if not isinstance(value, dict):
return value
def __convert_to_real_type__(cls, value: Any, handler):
if isinstance(value, dict) is False:
return handler(value)
if not cls.__is_polymorphic_base or (len(cls.__subclasses__()) and "__module_class_name" not in value):
# add right condition to init BaseClass like Action()
return value
module_class_name = value.get("__module_class_name", None)
if module_class_name is None:
raise ValueError("Missing field: __module_class_name")
# it is a dict so make sure to remove the __module_class_name
# because we don't allow extra keywords but want to ensure
# e.g Cat.model_validate(cat.model_dump()) works
class_full_name = value.pop("__module_class_name", None)
class_type = cls.__subclasses_map__.get(module_class_name, None)
# if it's not the polymorphic base we construct via default handler
if not cls.__is_polymorphic_base:
if class_full_name is None:
return handler(value)
elif str(cls) == f"<class '{class_full_name}'>":
return handler(value)
else:
# f"Trying to instantiate {class_full_name} but this is not the polymorphic base class")
pass
# otherwise we lookup the correct polymorphic type and construct that
# instead
if class_full_name is None:
raise ValueError("Missing __module_class_name field")
class_type = cls.__subclasses_map__.get(class_full_name, None)
if class_type is None:
raise TypeError("Trying to instantiate {module_class_name} which not defined yet.")
# TODO could try dynamic import
raise TypeError("Trying to instantiate {class_full_name}, which has not yet been defined!")
return class_type(**value)
@ -186,12 +182,17 @@ class Message(BaseModel):
@field_validator("instruct_content", mode="before")
@classmethod
def check_instruct_content(cls, ic: Any) -> BaseModel:
if ic and not isinstance(ic, BaseModel) and "class" in ic:
# compatible with custom-defined ActionOutput
mapping = actionoutput_str_to_mapping(ic["mapping"])
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
if ic and isinstance(ic, dict) and "class" in ic:
if "mapping" in ic:
# compatible with custom-defined ActionOutput
mapping = actionoutput_str_to_mapping(ic["mapping"])
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
elif "module" in ic:
# subclasses of BaseModel
ic_obj = import_class(ic["class"], ic["module"])
else:
raise KeyError("missing required key to init Message.instruct_content from dict")
ic = ic_obj(**ic["value"])
return ic
@ -216,13 +217,16 @@ class Message(BaseModel):
if ic:
# compatible with custom-defined ActionOutput
schema = ic.model_json_schema()
# `Documents` contain definitions
if "definitions" not in schema:
# TODO refine with nested BaseModel
ic_type = str(type(ic))
if "<class 'metagpt.actions.action_node" in ic_type:
# instruct_content from AutoNode.create_model_class, for now, it's single level structure.
mapping = actionoutout_schema_to_mapping(schema)
mapping = actionoutput_mapping_to_str(mapping)
ic_dict = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
else:
# due to instruct_content can be assigned by subclasses of BaseModel
ic_dict = {"class": schema["title"], "module": ic.__module__, "value": ic.model_dump()}
return ic_dict
def __init__(self, content: str = "", **data: Any):

View file

@ -49,28 +49,21 @@ class Team(BaseModel):
def serialize(self, stg_path: Path = None):
stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path
team_info_path = stg_path.joinpath("team.json")
team_info_path = stg_path.joinpath("team_info.json")
write_json_file(team_info_path, self.model_dump(exclude={"env": True}))
self.env.serialize(stg_path.joinpath("environment")) # save environment alone
write_json_file(team_info_path, self.model_dump())
@classmethod
def deserialize(cls, stg_path: Path) -> "Team":
"""stg_path = ./storage/team"""
# recover team_info
team_info_path = stg_path.joinpath("team_info.json")
team_info_path = stg_path.joinpath("team.json")
if not team_info_path.exists():
raise FileNotFoundError(
"recover storage meta file `team_info.json` not exist, "
"not to recover and please start a new project."
"recover storage meta file `team.json` not exist, " "not to recover and please start a new project."
)
team_info: dict = read_json_file(team_info_path)
# recover environment
environment = Environment.deserialize(stg_path=stg_path.joinpath("environment"))
team_info.update({"env": environment})
team = Team(**team_info)
return team

View file

@ -18,12 +18,12 @@ from metagpt.config2 import config
def make_sk_kernel():
kernel = sk.Kernel()
if llm := config.get_openai_llm():
if llm := config.get_azure_llm():
kernel.add_chat_service(
"chat_completion",
AzureChatCompletion(llm.model, llm.base_url, llm.api_key),
)
else:
elif llm := config.get_openai_llm():
kernel.add_chat_service(
"chat_completion",
OpenAIChatCompletion(llm.model, llm.api_key),

View file

@ -8,25 +8,20 @@ from metagpt.actions import Action
from metagpt.llm import LLM
def test_action_serialize():
@pytest.mark.asyncio
async def test_action_serdeser():
action = Action()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
assert "llm" not in ser_action_dict # not export
assert "__module_class_name" not in ser_action_dict
assert "__module_class_name" in ser_action_dict
action = Action(name="test")
ser_action_dict = action.model_dump()
assert "test" in ser_action_dict["name"]
new_action = Action(**ser_action_dict)
@pytest.mark.asyncio
async def test_action_deserialize():
action = Action()
serialized_data = action.model_dump()
new_action = Action(**serialized_data)
assert new_action.name == "Action"
assert new_action.name == "test"
assert isinstance(new_action.llm, type(LLM()))
assert len(await new_action._aask("who are you")) > 0

View file

@ -8,20 +8,15 @@ from metagpt.actions.action import Action
from metagpt.roles.architect import Architect
def test_architect_serialize():
@pytest.mark.asyncio
async def test_architect_serdeser():
role = Architect()
ser_role_dict = role.model_dump(by_alias=True)
assert "name" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
@pytest.mark.asyncio
async def test_architect_deserialize():
role = Architect()
ser_role_dict = role.model_dump(by_alias=True)
new_role = Architect(**ser_role_dict)
# new_role = Architect.deserialize(ser_role_dict)
assert new_role.name == "Bob"
assert len(new_role.actions) == 1
assert isinstance(new_role.actions[0], Action)

View file

@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-
# @Desc :
import shutil
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
@ -10,7 +9,7 @@ from metagpt.actions.project_management import WriteTasks
from metagpt.environment import Environment
from metagpt.roles.project_manager import ProjectManager
from metagpt.schema import Message
from metagpt.utils.common import any_to_str
from metagpt.utils.common import any_to_str, read_json_file, write_json_file
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
ActionOK,
ActionRaise,
@ -19,17 +18,14 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
)
def test_env_serialize():
def test_env_serdeser():
env = Environment()
env.publish_message(message=Message(content="test env serialize"))
ser_env_dict = env.model_dump()
assert "roles" in ser_env_dict
assert len(ser_env_dict["roles"]) == 0
def test_env_deserialize():
env = Environment()
env.publish_message(message=Message(content="test env serialize"))
ser_env_dict = env.model_dump()
new_env = Environment(**ser_env_dict)
assert len(new_env.roles) == 0
assert len(new_env.history) == 25
@ -79,12 +75,13 @@ def test_environment_serdeser_save():
environment = Environment()
role_c = RoleC()
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
stg_path = serdeser_path.joinpath("team", "environment")
env_path = stg_path.joinpath("env.json")
environment.add_role(role_c)
environment.serialize(stg_path)
new_env: Environment = Environment.deserialize(stg_path)
write_json_file(env_path, environment.model_dump())
env_dict = read_json_file(env_path)
new_env: Environment = Environment(**env_dict)
assert len(new_env.roles) == 1
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK

View file

@ -9,7 +9,7 @@ from metagpt.actions.add_requirement import UserRequirement
from metagpt.actions.design_api import WriteDesign
from metagpt.memory.memory import Memory
from metagpt.schema import Message
from metagpt.utils.common import any_to_str
from metagpt.utils.common import any_to_str, read_json_file, write_json_file
from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path
@ -53,14 +53,14 @@ def test_memory_serdeser_save():
memory.add_batch([msg1, msg2])
stg_path = serdeser_path.joinpath("team", "environment")
memory.serialize(stg_path)
assert stg_path.joinpath("memory.json").exists()
memory_path = stg_path.joinpath("memory.json")
write_json_file(memory_path, memory.model_dump())
assert memory_path.exists()
new_memory = Memory.deserialize(stg_path)
memory_dict = read_json_file(memory_path)
new_memory = Memory(**memory_dict)
assert new_memory.count() == 2
new_msg2 = new_memory.get(1)[0]
assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"]
assert new_msg2.cause_by == any_to_str(WriteDesign)
assert len(new_memory.index) == 2
stg_path.joinpath("memory.json").unlink()

View file

@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of polymorphic conditions
import copy
from pydantic import BaseModel, ConfigDict, SerializeAsAny
@ -12,6 +13,8 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
class ActionSubClasses(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
actions: list[SerializeAsAny[Action]] = []
@ -40,19 +43,21 @@ def test_no_serialize_as_any():
def test_polymorphic():
_ = ActionOKV2(
ok_v2 = ActionOKV2(
**{"name": "ActionOKV2", "context": "", "prefix": "", "desc": "", "extra_field": "ActionOKV2 Extra Info"}
)
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
action_subcls_dict = action_subcls.model_dump()
action_subcls_dict2 = copy.deepcopy(action_subcls_dict)
assert "__module_class_name" in action_subcls_dict["actions"][0]
new_action_subcls = ActionSubClasses(**action_subcls_dict)
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
assert new_action_subcls.actions[0].extra_field == ok_v2.extra_field
assert isinstance(new_action_subcls.actions[1], ActionPass)
new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict)
new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict2)
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
assert isinstance(new_action_subcls.actions[1], ActionPass)

View file

@ -8,7 +8,7 @@ from metagpt.actions.prepare_interview import PrepareInterview
@pytest.mark.asyncio
async def test_action_deserialize():
async def test_action_serdeser():
action = PrepareInterview()
serialized_data = action.model_dump()
assert serialized_data["name"] == "PrepareInterview"

View file

@ -10,7 +10,7 @@ from metagpt.schema import Message
@pytest.mark.asyncio
async def test_product_manager_deserialize(new_filename):
async def test_product_manager_serdeser(new_filename):
role = ProductManager()
ser_role_dict = role.model_dump(by_alias=True)
new_role = ProductManager(**ser_role_dict)

View file

@ -9,19 +9,14 @@ from metagpt.actions.project_management import WriteTasks
from metagpt.roles.project_manager import ProjectManager
def test_project_manager_serialize():
@pytest.mark.asyncio
async def test_project_manager_serdeser():
role = ProjectManager()
ser_role_dict = role.model_dump(by_alias=True)
assert "name" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
@pytest.mark.asyncio
async def test_project_manager_deserialize():
role = ProjectManager()
ser_role_dict = role.model_dump(by_alias=True)
new_role = ProjectManager(**ser_role_dict)
assert new_role.name == "Eve"
assert len(new_role.actions) == 1

View file

@ -8,7 +8,7 @@ from metagpt.roles.researcher import Researcher
@pytest.mark.asyncio
async def test_tutorial_assistant_deserialize():
async def test_tutorial_assistant_serdeser():
role = Researcher()
ser_role_dict = role.model_dump()
assert "name" in ser_role_dict

View file

@ -10,13 +10,12 @@ from pydantic import BaseModel, SerializeAsAny
from metagpt.actions import WriteCode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.const import SERDESER_PATH
from metagpt.logs import logger
from metagpt.roles.engineer import Engineer
from metagpt.roles.product_manager import ProductManager
from metagpt.roles.role import Role
from metagpt.schema import Message
from metagpt.utils.common import format_trackback_info
from metagpt.utils.common import format_trackback_info, read_json_file, write_json_file
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
ActionOK,
RoleA,
@ -60,37 +59,31 @@ def test_role_serialize():
assert "actions" in ser_role_dict
def test_engineer_serialize():
def test_engineer_serdeser():
role = Engineer()
ser_role_dict = role.model_dump()
assert "name" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
@pytest.mark.asyncio
async def test_engineer_deserialize():
role = Engineer(use_code_review=True)
ser_role_dict = role.model_dump()
new_role = Engineer(**ser_role_dict)
assert new_role.name == "Alex"
assert new_role.use_code_review is True
assert new_role.use_code_review is False
assert len(new_role.actions) == 1
assert isinstance(new_role.actions[0], WriteCode)
# await new_role.actions[0].run(context="write a cli snake game", filename="test_code")
def test_role_serdeser_save():
stg_path_prefix = serdeser_path.joinpath("team", "environment", "roles")
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
pm = ProductManager()
role_tag = f"{pm.__class__.__name__}_{pm.name}"
stg_path = stg_path_prefix.joinpath(role_tag)
pm.serialize(stg_path)
new_pm = Role.deserialize(stg_path)
stg_path = serdeser_path.joinpath("team", "environment", "roles", f"{pm.__class__.__name__}_{pm.name}")
role_path = stg_path.joinpath("role.json")
write_json_file(role_path, pm.model_dump())
role_dict = read_json_file(role_path)
new_pm = ProductManager(**role_dict)
assert new_pm.name == pm.name
assert len(new_pm.get_memories(1)) == 0
@ -98,22 +91,24 @@ def test_role_serdeser_save():
@pytest.mark.asyncio
async def test_role_serdeser_interrupt():
role_c = RoleC()
shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True)
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
stg_path = SERDESER_PATH.joinpath("team", "environment", "roles", f"{role_c.__class__.__name__}_{role_c.name}")
stg_path = serdeser_path.joinpath("team", "environment", "roles", f"{role_c.__class__.__name__}_{role_c.name}")
role_path = stg_path.joinpath("role.json")
try:
await role_c.run(with_message=Message(content="demo", cause_by=UserRequirement))
except Exception:
logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}")
role_c.serialize(stg_path)
logger.error(f"Exception in `role_c.run`, detail: {format_trackback_info()}")
write_json_file(role_path, role_c.model_dump())
assert role_c.rc.memory.count() == 1
new_role_a: Role = Role.deserialize(stg_path)
assert new_role_a.rc.state == 1
role_dict = read_json_file(role_path)
new_role_c: Role = RoleC(**role_dict)
assert new_role_c.rc.state == 1
with pytest.raises(Exception):
await new_role_a.run(with_message=Message(content="demo", cause_by=UserRequirement))
await new_role_c.run(with_message=Message(content="demo", cause_by=UserRequirement))
if __name__ == "__main__":

View file

@ -1,10 +1,11 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of schema ser&deser
import pytest
from metagpt.actions.action_node import ActionNode
from metagpt.actions.write_code import WriteCode
from metagpt.schema import Document, Documents, Message
from metagpt.schema import CodingContext, Document, Documents, Message, TestingContext
from metagpt.utils.common import any_to_str
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
MockICMessage,
@ -12,12 +13,16 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
)
def test_message_serdeser():
def test_message_serdeser_from_create_model():
with pytest.raises(KeyError):
_ = Message(content="code", instruct_content={"class": "test", "key": "value"})
out_mapping = {"field3": (str, ...), "field4": (list[str], ...)}
out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]}
ic_obj = ActionNode.create_model_class("code", out_mapping)
ic_inst = ic_obj(**out_data)
message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode)
message = Message(content="code", instruct_content=ic_inst, role="engineer", cause_by=WriteCode)
ser_data = message.model_dump()
assert ser_data["cause_by"] == "metagpt.actions.write_code.WriteCode"
assert ser_data["instruct_content"]["class"] == "code"
@ -25,28 +30,67 @@ def test_message_serdeser():
new_message = Message(**ser_data)
assert new_message.cause_by == any_to_str(WriteCode)
assert new_message.cause_by in [any_to_str(WriteCode)]
assert new_message.instruct_content != ic_obj(**out_data) # TODO find why `!=`
assert new_message.instruct_content != ic_inst
assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump()
message = Message(content="test_ic", instruct_content=MockICMessage())
mock_msg = MockMessage()
message = Message(content="test_ic", instruct_content=mock_msg)
ser_data = message.model_dump()
new_message = Message(**ser_data)
assert new_message.instruct_content != MockICMessage() # TODO
message = Message(content="test_documents", instruct_content=Documents(docs={"doc1": Document(content="test doc")}))
ser_data = message.model_dump()
assert "class" in ser_data["instruct_content"]
assert new_message.instruct_content == mock_msg
def test_message_without_postprocess():
"""to explain `instruct_content` should be postprocessed"""
"""to explain `instruct_content` from `create_model_class` should be postprocessed"""
out_mapping = {"field1": (list[str], ...)}
out_data = {"field1": ["field1 value1", "field1 value2"]}
ic_obj = ActionNode.create_model_class("code", out_mapping)
message = MockMessage(content="code", instruct_content=ic_obj(**out_data))
message = MockICMessage(content="code", instruct_content=ic_obj(**out_data))
ser_data = message.model_dump()
assert ser_data["instruct_content"] == {}
ser_data["instruct_content"] = None
new_message = MockMessage(**ser_data)
new_message = MockICMessage(**ser_data)
assert new_message.instruct_content != ic_obj(**out_data)
def test_message_serdeser_from_basecontext():
doc_msg = Message(content="test_document", instruct_content=Document(content="test doc"))
ser_data = doc_msg.model_dump()
assert ser_data["instruct_content"]["value"]["content"] == "test doc"
assert ser_data["instruct_content"]["value"]["filename"] == ""
docs_msg = Message(
content="test_documents", instruct_content=Documents(docs={"doc1": Document(content="test doc")})
)
ser_data = docs_msg.model_dump()
assert ser_data["instruct_content"]["class"] == "Documents"
assert ser_data["instruct_content"]["value"]["docs"]["doc1"]["content"] == "test doc"
assert ser_data["instruct_content"]["value"]["docs"]["doc1"]["filename"] == ""
code_ctxt = CodingContext(
filename="game.py",
design_doc=Document(root_path="docs/system_design", filename="xx.json", content="xxx"),
task_doc=Document(root_path="docs/tasks", filename="xx.json", content="xxx"),
code_doc=Document(root_path="xxx", filename="game.py", content="xxx"),
)
code_ctxt_msg = Message(content="coding_context", instruct_content=code_ctxt)
ser_data = code_ctxt_msg.model_dump()
assert ser_data["instruct_content"]["class"] == "CodingContext"
new_code_ctxt_msg = Message(**ser_data)
assert new_code_ctxt_msg.instruct_content == code_ctxt
assert new_code_ctxt_msg.instruct_content.code_doc.filename == "game.py"
testing_ctxt = TestingContext(
filename="test.py",
code_doc=Document(root_path="xxx", filename="game.py", content="xxx"),
test_doc=Document(root_path="docs/tests", filename="test.py", content="xxx"),
)
testing_ctxt_msg = Message(content="testing_context", instruct_content=testing_ctxt)
ser_data = testing_ctxt_msg.model_dump()
new_testing_ctxt_msg = Message(**ser_data)
assert new_testing_ctxt_msg.instruct_content == testing_ctxt
assert new_testing_ctxt_msg.instruct_content.test_doc.filename == "test.py"

View file

@ -16,14 +16,14 @@ from metagpt.roles.role import Role, RoleReactMode
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
class MockICMessage(BaseModel):
content: str = "test_ic"
class MockMessage(BaseModel):
content: str = "test_msg"
class MockICMessage(BaseModel):
"""to test normal dict without postprocess"""
content: str = ""
content: str = "test_ic_msg"
instruct_content: Optional[BaseModel] = Field(default=None)

View file

@ -5,15 +5,8 @@ import pytest
from metagpt.roles.sk_agent import SkAgent
def test_sk_agent_serialize():
role = SkAgent()
ser_role_dict = role.model_dump(exclude={"import_semantic_skill_from_directory", "import_skill"})
assert "name" in ser_role_dict
assert "planner" in ser_role_dict
@pytest.mark.asyncio
async def test_sk_agent_deserialize():
async def test_sk_agent_serdeser():
role = SkAgent()
ser_role_dict = role.model_dump(exclude={"import_semantic_skill_from_directory", "import_skill"})
assert "name" in ser_role_dict

View file

@ -4,13 +4,14 @@
# @Desc :
import shutil
from pathlib import Path
import pytest
from metagpt.const import SERDESER_PATH
from metagpt.logs import logger
from metagpt.roles import Architect, ProductManager, ProjectManager
from metagpt.team import Team
from metagpt.utils.common import write_json_file
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
ActionOK,
RoleA,
@ -45,9 +46,16 @@ def test_team_deserialize():
assert new_company.env.get_role(arch.profile) is not None
def test_team_serdeser_save():
company = Team()
def mock_team_serialize(self, stg_path: Path = serdeser_path.joinpath("team")):
team_info_path = stg_path.joinpath("team.json")
write_json_file(team_info_path, self.model_dump())
def test_team_serdeser_save(mocker):
mocker.patch("metagpt.team.Team.serialize", mock_team_serialize)
company = Team()
company.hire([RoleC()])
stg_path = serdeser_path.joinpath("team")
@ -61,9 +69,11 @@ def test_team_serdeser_save():
@pytest.mark.asyncio
async def test_team_recover():
async def test_team_recover(mocker):
mocker.patch("metagpt.team.Team.serialize", mock_team_serialize)
idea = "write a snake game"
stg_path = SERDESER_PATH.joinpath("team")
stg_path = serdeser_path.joinpath("team")
shutil.rmtree(stg_path, ignore_errors=True)
company = Team()
@ -75,9 +85,9 @@ async def test_team_recover():
ser_data = company.model_dump()
new_company = Team(**ser_data)
new_company.env.get_role(role_c.profile)
# assert new_role_c.rc.memory == role_c.rc.memory # TODO
# assert new_role_c.rc.env != role_c.rc.env # TODO
new_role_c = new_company.env.get_role(role_c.profile)
assert new_role_c.rc.memory == role_c.rc.memory
assert new_role_c.rc.env != role_c.rc.env
assert type(list(new_company.env.roles.values())[0].actions[0]) == ActionOK
new_company.run_project(idea)
@ -85,9 +95,11 @@ async def test_team_recover():
@pytest.mark.asyncio
async def test_team_recover_save():
async def test_team_recover_save(mocker):
mocker.patch("metagpt.team.Team.serialize", mock_team_serialize)
idea = "write a 2048 web game"
stg_path = SERDESER_PATH.joinpath("team")
stg_path = serdeser_path.joinpath("team")
shutil.rmtree(stg_path, ignore_errors=True)
company = Team()
@ -98,8 +110,8 @@ async def test_team_recover_save():
new_company = Team.deserialize(stg_path)
new_role_c = new_company.env.get_role(role_c.profile)
# assert new_role_c.rc.memory == role_c.rc.memory
# assert new_role_c.rc.env != role_c.rc.env
assert new_role_c.rc.memory == role_c.rc.memory
assert new_role_c.rc.env != role_c.rc.env
assert new_role_c.recovered != role_c.recovered # here cause previous ut is `!=`
assert new_role_c.rc.todo != role_c.rc.todo # serialize exclude `rc.todo`
assert new_role_c.rc.news != role_c.rc.news # serialize exclude `rc.news`
@ -109,9 +121,11 @@ async def test_team_recover_save():
@pytest.mark.asyncio
async def test_team_recover_multi_roles_save():
async def test_team_recover_multi_roles_save(mocker):
mocker.patch("metagpt.team.Team.serialize", mock_team_serialize)
idea = "write a snake game"
stg_path = SERDESER_PATH.joinpath("team")
stg_path = serdeser_path.joinpath("team")
shutil.rmtree(stg_path, ignore_errors=True)
role_a = RoleA()

View file

@ -7,7 +7,7 @@ from metagpt.roles.tutorial_assistant import TutorialAssistant
@pytest.mark.asyncio
async def test_tutorial_assistant_deserialize():
async def test_tutorial_assistant_serdeser():
role = TutorialAssistant()
ser_role_dict = role.model_dump()
assert "name" in ser_role_dict

View file

@ -9,7 +9,7 @@ from metagpt.actions import WriteCode
from metagpt.schema import CodingContext, Document
def test_write_design_serialize():
def test_write_design_serdeser():
action = WriteCode()
ser_action_dict = action.model_dump()
assert ser_action_dict["name"] == "WriteCode"
@ -17,7 +17,7 @@ def test_write_design_serialize():
@pytest.mark.asyncio
async def test_write_code_deserialize():
async def test_write_code_serdeser():
context = CodingContext(
filename="test_code.py", design_doc=Document(content="write add function to calculate two numbers")
)

View file

@ -9,7 +9,7 @@ from metagpt.schema import CodingContext, Document
@pytest.mark.asyncio
async def test_write_code_review_deserialize():
async def test_write_code_review_serdeser():
code_content = """
def div(a: int, b: int = 0):
return a / b

View file

@ -7,33 +7,25 @@ import pytest
from metagpt.actions import WriteDesign, WriteTasks
def test_write_design_serialize():
action = WriteDesign()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
assert "llm" not in ser_action_dict # not export
def test_write_task_serialize():
action = WriteTasks()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
assert "llm" not in ser_action_dict # not export
@pytest.mark.asyncio
async def test_write_design_deserialize():
async def test_write_design_serialize():
action = WriteDesign()
serialized_data = action.model_dump()
new_action = WriteDesign(**serialized_data)
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
assert "llm" not in ser_action_dict # not export
new_action = WriteDesign(**ser_action_dict)
assert new_action.name == "WriteDesign"
await new_action.run(with_messages="write a cli snake game")
@pytest.mark.asyncio
async def test_write_task_deserialize():
async def test_write_task_serialize():
action = WriteTasks()
serialized_data = action.model_dump()
new_action = WriteTasks(**serialized_data)
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
assert "llm" not in ser_action_dict # not export
new_action = WriteTasks(**ser_action_dict)
assert new_action.name == "WriteTasks"
await new_action.run(with_messages="write a cli snake game")

View file

@ -29,7 +29,7 @@ class Person:
],
ids=["google", "numpy", "sphinx"],
)
async def test_action_deserialize(style: str, part: str):
async def test_action_serdeser(style: str, part: str):
action = WriteDocstring()
serialized_data = action.model_dump()

View file

@ -9,18 +9,14 @@ from metagpt.actions import WritePRD
from metagpt.schema import Message
def test_action_serialize(new_filename):
@pytest.mark.asyncio
async def test_action_serdeser(new_filename):
action = WritePRD()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
assert "llm" not in ser_action_dict # not export
@pytest.mark.asyncio
async def test_action_deserialize(new_filename):
action = WritePRD()
serialized_data = action.model_dump()
new_action = WritePRD(**serialized_data)
new_action = WritePRD(**ser_action_dict)
assert new_action.name == "WritePRD"
action_output = await new_action.run(with_messages=Message(content="write a cli snake game"))
assert len(action_output.content) > 0

View file

@ -42,7 +42,7 @@ CONTEXT = """
@pytest.mark.asyncio
async def test_action_deserialize():
async def test_action_serdeser():
action = WriteReview()
serialized_data = action.model_dump()
assert serialized_data["name"] == "WriteReview"

View file

@ -9,7 +9,7 @@ from metagpt.actions.write_tutorial import WriteContent, WriteDirectory
@pytest.mark.asyncio
@pytest.mark.parametrize(("language", "topic"), [("English", "Write a tutorial about Python")])
async def test_write_directory_deserialize(language: str, topic: str):
async def test_write_directory_serdeser(language: str, topic: str):
action = WriteDirectory()
serialized_data = action.model_dump()
assert serialized_data["name"] == "WriteDirectory"
@ -30,7 +30,7 @@ async def test_write_directory_deserialize(language: str, topic: str):
("language", "topic", "directory"),
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})],
)
async def test_write_content_deserialize(language: str, topic: str, directory: Dict):
async def test_write_content_serdeser(language: str, topic: str, directory: Dict):
action = WriteContent(language=language, directory=directory)
serialized_data = action.model_dump()
assert serialized_data["name"] == "WriteContent"