From 98ee696cf0fb28874c9b06e697be2b4f824ba61d Mon Sep 17 00:00:00 2001 From: better629 Date: Mon, 8 Jan 2024 22:15:56 +0800 Subject: [PATCH 1/3] rm expicit serialize&deserialize interface and update unittests --- metagpt/actions/action.py | 2 +- metagpt/environment.py | 41 +--------- metagpt/memory/memory.py | 24 +----- metagpt/roles/role.py | 53 ++----------- metagpt/schema.py | 74 +++++++++---------- metagpt/team.py | 15 +--- metagpt/utils/make_sk_kernel.py | 4 +- .../serialize_deserialize/test_action.py | 15 ++-- ...itect_deserialize.py => test_architect.py} | 9 +-- .../serialize_deserialize/test_environment.py | 21 +++--- .../serialize_deserialize/test_memory.py | 12 +-- .../serialize_deserialize/test_polymorphic.py | 9 ++- .../test_prepare_interview.py | 2 +- .../test_product_manager.py | 2 +- .../test_project_manager.py | 9 +-- .../serialize_deserialize/test_reasearcher.py | 2 +- .../serialize_deserialize/test_role.py | 41 +++++----- .../serialize_deserialize/test_sk_agent.py | 9 +-- .../serialize_deserialize/test_team.py | 42 +++++++---- .../test_tutorial_assistant.py | 2 +- .../serialize_deserialize/test_write_code.py | 4 +- .../test_write_code_review.py | 2 +- .../test_write_design.py | 32 +++----- .../test_write_docstring.py | 2 +- .../serialize_deserialize/test_write_prd.py | 10 +-- .../test_write_review.py | 2 +- .../test_write_tutorial.py | 4 +- 27 files changed, 154 insertions(+), 290 deletions(-) rename tests/metagpt/serialize_deserialize/{test_architect_deserialize.py => test_architect.py} (76%) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 24357a700..9f045bbaa 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -27,7 +27,7 @@ from metagpt.schema import ( from metagpt.utils.file_repository import FileRepository -class Action(SerializationMixin, is_polymorphic_base=True): +class Action(SerializationMixin): model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"]) name: str = "" diff --git a/metagpt/environment.py b/metagpt/environment.py index 6511647ef..5a2dd339b 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -12,7 +12,6 @@ functionality is to be consolidated into the `Environment` class. """ import asyncio -from pathlib import Path from typing import Iterable, Set from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator @@ -21,7 +20,7 @@ from metagpt.context import Context from metagpt.logs import logger from metagpt.roles.role import Role from metagpt.schema import Message -from metagpt.utils.common import is_send_to, read_json_file, write_json_file +from metagpt.utils.common import is_send_to class Environment(BaseModel): @@ -42,44 +41,6 @@ class Environment(BaseModel): self.add_roles(self.roles.values()) return self - def serialize(self, stg_path: Path): - roles_path = stg_path.joinpath("roles.json") - roles_info = [] - for role_key, role in self.roles.items(): - roles_info.append( - { - "role_class": role.__class__.__name__, - "module_name": role.__module__, - "role_name": role.name, - "role_sub_tags": list(self.member_addrs.get(role)), - } - ) - role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}")) - write_json_file(roles_path, roles_info) - - history_path = stg_path.joinpath("history.json") - write_json_file(history_path, {"content": self.history}) - - @classmethod - def deserialize(cls, stg_path: Path) -> "Environment": - """stg_path: ./storage/team/environment/""" - roles_path = stg_path.joinpath("roles.json") - roles_info = read_json_file(roles_path) - roles = [] - for role_info in roles_info: - # role stored in ./environment/roles/{role_class}_{role_name} - role_path = stg_path.joinpath(f"roles/{role_info.get('role_class')}_{role_info.get('role_name')}") - role = Role.deserialize(role_path) - roles.append(role) - - history = read_json_file(stg_path.joinpath("history.json")) - history = history.get("content") - - environment = Environment(**{"history": history}) - environment.add_roles(roles) - - return environment - def add_role(self, role: Role): """增加一个在当前环境的角色 Add a role in the current environment diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 593409648..580361d33 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -7,19 +7,13 @@ @Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key. """ from collections import defaultdict -from pathlib import Path from typing import DefaultDict, Iterable, Set from pydantic import BaseModel, Field, SerializeAsAny from metagpt.const import IGNORED_MESSAGE_ID from metagpt.schema import Message -from metagpt.utils.common import ( - any_to_str, - any_to_str_set, - read_json_file, - write_json_file, -) +from metagpt.utils.common import any_to_str, any_to_str_set class Memory(BaseModel): @@ -29,22 +23,6 @@ class Memory(BaseModel): index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list)) ignore_id: bool = False - def serialize(self, stg_path: Path): - """stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" - memory_path = stg_path.joinpath("memory.json") - storage = self.model_dump() - write_json_file(memory_path, storage) - - @classmethod - def deserialize(cls, stg_path: Path) -> "Memory": - """stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" - memory_path = stg_path.joinpath("memory.json") - - memory_dict = read_json_file(memory_path) - memory = Memory(**memory_dict) - - return memory - def add(self, message: Message): """Add a new message to storage, while updating the index""" if self.ignore_id: diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index cdb2da40a..73d82e369 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -23,7 +23,6 @@ from __future__ import annotations from enum import Enum -from pathlib import Path from typing import Any, Iterable, Optional, Set, Type from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator @@ -31,7 +30,6 @@ from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validat from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement -from metagpt.const import SERDESER_PATH from metagpt.context import Context, context from metagpt.llm import LLM from metagpt.logs import logger @@ -39,14 +37,7 @@ from metagpt.memory import Memory from metagpt.provider import HumanProvider from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message, MessageQueue, SerializationMixin -from metagpt.utils.common import ( - any_to_name, - any_to_str, - import_class, - read_json_file, - role_raise_decorator, - write_json_file, -) +from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}. """ @@ -128,7 +119,7 @@ class RoleContext(BaseModel): return self.memory.get() -class Role(SerializationMixin, is_polymorphic_base=True): +class Role(SerializationMixin): """Role/Agent""" model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"]) @@ -217,6 +208,9 @@ class Role(SerializationMixin, is_polymorphic_base=True): self.llm.system_prompt = self._get_prefix() self._watch(data.get("watch") or [UserRequirement]) + if self.latest_observed_msg: + self.recovered = True + def _reset(self): self.states = [] self.actions = [] @@ -225,47 +219,12 @@ class Role(SerializationMixin, is_polymorphic_base=True): def _setting(self): return f"{self.name}({self.profile})" - def serialize(self, stg_path: Path = None): - stg_path = ( - SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") - if stg_path is None - else stg_path - ) - - role_info = self.model_dump(exclude={"rc": {"memory": True, "msg_buffer": True}, "llm": True}) - role_info.update({"role_class": self.__class__.__name__, "module_name": self.__module__}) - role_info_path = stg_path.joinpath("role_info.json") - write_json_file(role_info_path, role_info) - - self.rc.memory.serialize(stg_path) # serialize role's memory alone - - @classmethod - def deserialize(cls, stg_path: Path) -> "Role": - """stg_path = ./storage/team/environment/roles/{role_class}_{role_name}""" - role_info_path = stg_path.joinpath("role_info.json") - role_info = read_json_file(role_info_path) - - role_class_str = role_info.pop("role_class") - module_name = role_info.pop("module_name") - role_class = import_class(class_name=role_class_str, module_name=module_name) - - role = role_class(**role_info) # initiate particular Role - role.set_recovered(True) # set True to make a tag - - role_memory = Memory.deserialize(stg_path) - role.set_memory(role_memory) - - return role - def _init_action_system_message(self, action: Action): action.set_prefix(self._get_prefix()) def refresh_system_message(self): self.llm.system_prompt = self._get_prefix() - def set_recovered(self, recovered: bool = False): - self.recovered = recovered - def set_memory(self, memory: Memory): self.rc.memory = memory @@ -376,7 +335,7 @@ class Role(SerializationMixin, is_polymorphic_base=True): if self.recovered and self.rc.state >= 0: self._set_state(self.rc.state) # action to run from recovered state - self.set_recovered(False) # avoid max_react_loop out of work + self.recovered = False # avoid max_react_loop out of work return True prompt = self._get_prefix() diff --git a/metagpt/schema.py b/metagpt/schema.py index cf24fbc6f..a557951c7 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -23,7 +23,7 @@ from abc import ABC from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Dict, List, Optional, Type, TypeVar, Union from pydantic import ( BaseModel, @@ -32,8 +32,9 @@ from pydantic import ( PrivateAttr, field_serializer, field_validator, + model_serializer, + model_validator, ) -from pydantic_core import core_schema from metagpt.const import ( MESSAGE_ROUTE_CAUSE_BY, @@ -53,7 +54,7 @@ from metagpt.utils.serialize import ( ) -class SerializationMixin(BaseModel): +class SerializationMixin(BaseModel, extra="forbid"): """ PolyMorphic subclasses Serialization / Deserialization Mixin - First of all, we need to know that pydantic is not designed for polymorphism. @@ -68,49 +69,44 @@ class SerializationMixin(BaseModel): __is_polymorphic_base = False __subclasses_map__ = {} - @classmethod - def __get_pydantic_core_schema__( - cls, source: type["SerializationMixin"], handler: Callable[[Any], core_schema.CoreSchema] - ) -> core_schema.CoreSchema: - schema = handler(source) - og_schema_ref = schema["ref"] - schema["ref"] += ":mixin" - - return core_schema.no_info_before_validator_function( - cls.__deserialize_with_real_type__, - schema=schema, - ref=og_schema_ref, - serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__), - ) - - @classmethod - def __serialize_add_class_type__( - cls, - value, - handler: core_schema.SerializerFunctionWrapHandler, - ) -> Any: - ret = handler(value) - if not len(cls.__subclasses__()): - # only subclass add `__module_class_name` - ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}" + @model_serializer(mode="wrap") + def __serialize_with_class_type__(self, default_serializer) -> Any: + # default serializer, then append the `__module_class_name` field and return + ret = default_serializer(self) + ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}" return ret + @model_validator(mode="wrap") @classmethod - def __deserialize_with_real_type__(cls, value: Any): - if not isinstance(value, dict): - return value + def __convert_to_real_type__(cls, value: Any, handler): + if isinstance(value, dict) is False: + return handler(value) - if not cls.__is_polymorphic_base or (len(cls.__subclasses__()) and "__module_class_name" not in value): - # add right condition to init BaseClass like Action() - return value - module_class_name = value.get("__module_class_name", None) - if module_class_name is None: - raise ValueError("Missing field: __module_class_name") + # it is a dict so make sure to remove the __module_class_name + # because we don't allow extra keywords but want to ensure + # e.g Cat.model_validate(cat.model_dump()) works + class_full_name = value.pop("__module_class_name", None) - class_type = cls.__subclasses_map__.get(module_class_name, None) + # if it's not the polymorphic base we construct via default handler + if not cls.__is_polymorphic_base: + if class_full_name is None: + return handler(value) + elif str(cls) == f"": + return handler(value) + else: + # f"Trying to instantiate {class_full_name} but this is not the polymorphic base class") + pass + + # otherwise we lookup the correct polymorphic type and construct that + # instead + if class_full_name is None: + raise ValueError("Missing __module_class_name field") + + class_type = cls.__subclasses_map__.get(class_full_name, None) if class_type is None: - raise TypeError("Trying to instantiate {module_class_name} which not defined yet.") + # TODO could try dynamic import + raise TypeError("Trying to instantiate {class_full_name}, which has not yet been defined!") return class_type(**value) diff --git a/metagpt/team.py b/metagpt/team.py index 87fee8dc7..96a27d482 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -49,28 +49,21 @@ class Team(BaseModel): def serialize(self, stg_path: Path = None): stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path + team_info_path = stg_path.joinpath("team.json") - team_info_path = stg_path.joinpath("team_info.json") - write_json_file(team_info_path, self.model_dump(exclude={"env": True})) - - self.env.serialize(stg_path.joinpath("environment")) # save environment alone + write_json_file(team_info_path, self.model_dump()) @classmethod def deserialize(cls, stg_path: Path) -> "Team": """stg_path = ./storage/team""" # recover team_info - team_info_path = stg_path.joinpath("team_info.json") + team_info_path = stg_path.joinpath("team.json") if not team_info_path.exists(): raise FileNotFoundError( - "recover storage meta file `team_info.json` not exist, " - "not to recover and please start a new project." + "recover storage meta file `team.json` not exist, " "not to recover and please start a new project." ) team_info: dict = read_json_file(team_info_path) - - # recover environment - environment = Environment.deserialize(stg_path=stg_path.joinpath("environment")) - team_info.update({"env": environment}) team = Team(**team_info) return team diff --git a/metagpt/utils/make_sk_kernel.py b/metagpt/utils/make_sk_kernel.py index 319ba3e34..283a682d6 100644 --- a/metagpt/utils/make_sk_kernel.py +++ b/metagpt/utils/make_sk_kernel.py @@ -18,12 +18,12 @@ from metagpt.config2 import config def make_sk_kernel(): kernel = sk.Kernel() - if llm := config.get_openai_llm(): + if llm := config.get_azure_llm(): kernel.add_chat_service( "chat_completion", AzureChatCompletion(llm.model, llm.base_url, llm.api_key), ) - else: + elif llm := config.get_openai_llm(): kernel.add_chat_service( "chat_completion", OpenAIChatCompletion(llm.model, llm.api_key), diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py index 81879e34e..f66900241 100644 --- a/tests/metagpt/serialize_deserialize/test_action.py +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -8,25 +8,20 @@ from metagpt.actions import Action from metagpt.llm import LLM -def test_action_serialize(): +@pytest.mark.asyncio +async def test_action_serdeser(): action = Action() ser_action_dict = action.model_dump() assert "name" in ser_action_dict assert "llm" not in ser_action_dict # not export - assert "__module_class_name" not in ser_action_dict + assert "__module_class_name" in ser_action_dict action = Action(name="test") ser_action_dict = action.model_dump() assert "test" in ser_action_dict["name"] + new_action = Action(**ser_action_dict) -@pytest.mark.asyncio -async def test_action_deserialize(): - action = Action() - serialized_data = action.model_dump() - - new_action = Action(**serialized_data) - - assert new_action.name == "Action" + assert new_action.name == "test" assert isinstance(new_action.llm, type(LLM())) assert len(await new_action._aask("who are you")) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py b/tests/metagpt/serialize_deserialize/test_architect.py similarity index 76% rename from tests/metagpt/serialize_deserialize/test_architect_deserialize.py rename to tests/metagpt/serialize_deserialize/test_architect.py index b113912a7..343662494 100644 --- a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py +++ b/tests/metagpt/serialize_deserialize/test_architect.py @@ -8,20 +8,15 @@ from metagpt.actions.action import Action from metagpt.roles.architect import Architect -def test_architect_serialize(): +@pytest.mark.asyncio +async def test_architect_serdeser(): role = Architect() ser_role_dict = role.model_dump(by_alias=True) assert "name" in ser_role_dict assert "states" in ser_role_dict assert "actions" in ser_role_dict - -@pytest.mark.asyncio -async def test_architect_deserialize(): - role = Architect() - ser_role_dict = role.model_dump(by_alias=True) new_role = Architect(**ser_role_dict) - # new_role = Architect.deserialize(ser_role_dict) assert new_role.name == "Bob" assert len(new_role.actions) == 1 assert isinstance(new_role.actions[0], Action) diff --git a/tests/metagpt/serialize_deserialize/test_environment.py b/tests/metagpt/serialize_deserialize/test_environment.py index 5a68288a6..3e2a3abba 100644 --- a/tests/metagpt/serialize_deserialize/test_environment.py +++ b/tests/metagpt/serialize_deserialize/test_environment.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- # @Desc : -import shutil from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement @@ -10,7 +9,7 @@ from metagpt.actions.project_management import WriteTasks from metagpt.environment import Environment from metagpt.roles.project_manager import ProjectManager from metagpt.schema import Message -from metagpt.utils.common import any_to_str +from metagpt.utils.common import any_to_str, read_json_file, write_json_file from tests.metagpt.serialize_deserialize.test_serdeser_base import ( ActionOK, ActionRaise, @@ -19,17 +18,14 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import ( ) -def test_env_serialize(): +def test_env_serdeser(): env = Environment() + env.publish_message(message=Message(content="test env serialize")) + ser_env_dict = env.model_dump() assert "roles" in ser_env_dict assert len(ser_env_dict["roles"]) == 0 - -def test_env_deserialize(): - env = Environment() - env.publish_message(message=Message(content="test env serialize")) - ser_env_dict = env.model_dump() new_env = Environment(**ser_env_dict) assert len(new_env.roles) == 0 assert len(new_env.history) == 25 @@ -79,12 +75,13 @@ def test_environment_serdeser_save(): environment = Environment() role_c = RoleC() - shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) - stg_path = serdeser_path.joinpath("team", "environment") + env_path = stg_path.joinpath("env.json") environment.add_role(role_c) - environment.serialize(stg_path) - new_env: Environment = Environment.deserialize(stg_path) + write_json_file(env_path, environment.model_dump()) + + env_dict = read_json_file(env_path) + new_env: Environment = Environment(**env_dict) assert len(new_env.roles) == 1 assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK diff --git a/tests/metagpt/serialize_deserialize/test_memory.py b/tests/metagpt/serialize_deserialize/test_memory.py index aa3e2a465..fdaea7861 100644 --- a/tests/metagpt/serialize_deserialize/test_memory.py +++ b/tests/metagpt/serialize_deserialize/test_memory.py @@ -9,7 +9,7 @@ from metagpt.actions.add_requirement import UserRequirement from metagpt.actions.design_api import WriteDesign from metagpt.memory.memory import Memory from metagpt.schema import Message -from metagpt.utils.common import any_to_str +from metagpt.utils.common import any_to_str, read_json_file, write_json_file from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path @@ -53,14 +53,14 @@ def test_memory_serdeser_save(): memory.add_batch([msg1, msg2]) stg_path = serdeser_path.joinpath("team", "environment") - memory.serialize(stg_path) - assert stg_path.joinpath("memory.json").exists() + memory_path = stg_path.joinpath("memory.json") + write_json_file(memory_path, memory.model_dump()) + assert memory_path.exists() - new_memory = Memory.deserialize(stg_path) + memory_dict = read_json_file(memory_path) + new_memory = Memory(**memory_dict) assert new_memory.count() == 2 new_msg2 = new_memory.get(1)[0] assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"] assert new_msg2.cause_by == any_to_str(WriteDesign) assert len(new_memory.index) == 2 - - stg_path.joinpath("memory.json").unlink() diff --git a/tests/metagpt/serialize_deserialize/test_polymorphic.py b/tests/metagpt/serialize_deserialize/test_polymorphic.py index ed0482c34..e5f8ec8d6 100644 --- a/tests/metagpt/serialize_deserialize/test_polymorphic.py +++ b/tests/metagpt/serialize_deserialize/test_polymorphic.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Desc : unittest of polymorphic conditions +import copy from pydantic import BaseModel, ConfigDict, SerializeAsAny @@ -12,6 +13,8 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import ( class ActionSubClasses(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + actions: list[SerializeAsAny[Action]] = [] @@ -40,19 +43,21 @@ def test_no_serialize_as_any(): def test_polymorphic(): - _ = ActionOKV2( + ok_v2 = ActionOKV2( **{"name": "ActionOKV2", "context": "", "prefix": "", "desc": "", "extra_field": "ActionOKV2 Extra Info"} ) action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()]) action_subcls_dict = action_subcls.model_dump() + action_subcls_dict2 = copy.deepcopy(action_subcls_dict) assert "__module_class_name" in action_subcls_dict["actions"][0] new_action_subcls = ActionSubClasses(**action_subcls_dict) assert isinstance(new_action_subcls.actions[0], ActionOKV2) + assert new_action_subcls.actions[0].extra_field == ok_v2.extra_field assert isinstance(new_action_subcls.actions[1], ActionPass) - new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict) + new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict2) assert isinstance(new_action_subcls.actions[0], ActionOKV2) assert isinstance(new_action_subcls.actions[1], ActionPass) diff --git a/tests/metagpt/serialize_deserialize/test_prepare_interview.py b/tests/metagpt/serialize_deserialize/test_prepare_interview.py index cd9912103..3b57aa27e 100644 --- a/tests/metagpt/serialize_deserialize/test_prepare_interview.py +++ b/tests/metagpt/serialize_deserialize/test_prepare_interview.py @@ -8,7 +8,7 @@ from metagpt.actions.prepare_interview import PrepareInterview @pytest.mark.asyncio -async def test_action_deserialize(): +async def test_action_serdeser(): action = PrepareInterview() serialized_data = action.model_dump() assert serialized_data["name"] == "PrepareInterview" diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py index 094943900..1a056f9d4 100644 --- a/tests/metagpt/serialize_deserialize/test_product_manager.py +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -10,7 +10,7 @@ from metagpt.schema import Message @pytest.mark.asyncio -async def test_product_manager_deserialize(new_filename): +async def test_product_manager_serdeser(new_filename): role = ProductManager() ser_role_dict = role.model_dump(by_alias=True) new_role = ProductManager(**ser_role_dict) diff --git a/tests/metagpt/serialize_deserialize/test_project_manager.py b/tests/metagpt/serialize_deserialize/test_project_manager.py index 1088a4461..f2c5af853 100644 --- a/tests/metagpt/serialize_deserialize/test_project_manager.py +++ b/tests/metagpt/serialize_deserialize/test_project_manager.py @@ -9,19 +9,14 @@ from metagpt.actions.project_management import WriteTasks from metagpt.roles.project_manager import ProjectManager -def test_project_manager_serialize(): +@pytest.mark.asyncio +async def test_project_manager_serdeser(): role = ProjectManager() ser_role_dict = role.model_dump(by_alias=True) assert "name" in ser_role_dict assert "states" in ser_role_dict assert "actions" in ser_role_dict - -@pytest.mark.asyncio -async def test_project_manager_deserialize(): - role = ProjectManager() - ser_role_dict = role.model_dump(by_alias=True) - new_role = ProjectManager(**ser_role_dict) assert new_role.name == "Eve" assert len(new_role.actions) == 1 diff --git a/tests/metagpt/serialize_deserialize/test_reasearcher.py b/tests/metagpt/serialize_deserialize/test_reasearcher.py index 1b8dbf2c7..a2d1fa513 100644 --- a/tests/metagpt/serialize_deserialize/test_reasearcher.py +++ b/tests/metagpt/serialize_deserialize/test_reasearcher.py @@ -8,7 +8,7 @@ from metagpt.roles.researcher import Researcher @pytest.mark.asyncio -async def test_tutorial_assistant_deserialize(): +async def test_tutorial_assistant_serdeser(): role = Researcher() ser_role_dict = role.model_dump() assert "name" in ser_role_dict diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index d38797baf..bbfe350b7 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -10,13 +10,12 @@ from pydantic import BaseModel, SerializeAsAny from metagpt.actions import WriteCode from metagpt.actions.add_requirement import UserRequirement -from metagpt.const import SERDESER_PATH from metagpt.logs import logger from metagpt.roles.engineer import Engineer from metagpt.roles.product_manager import ProductManager from metagpt.roles.role import Role from metagpt.schema import Message -from metagpt.utils.common import format_trackback_info +from metagpt.utils.common import format_trackback_info, read_json_file, write_json_file from tests.metagpt.serialize_deserialize.test_serdeser_base import ( ActionOK, RoleA, @@ -60,37 +59,31 @@ def test_role_serialize(): assert "actions" in ser_role_dict -def test_engineer_serialize(): +def test_engineer_serdeser(): role = Engineer() ser_role_dict = role.model_dump() assert "name" in ser_role_dict assert "states" in ser_role_dict assert "actions" in ser_role_dict - -@pytest.mark.asyncio -async def test_engineer_deserialize(): - role = Engineer(use_code_review=True) - ser_role_dict = role.model_dump() - new_role = Engineer(**ser_role_dict) assert new_role.name == "Alex" - assert new_role.use_code_review is True + assert new_role.use_code_review is False assert len(new_role.actions) == 1 assert isinstance(new_role.actions[0], WriteCode) - # await new_role.actions[0].run(context="write a cli snake game", filename="test_code") def test_role_serdeser_save(): - stg_path_prefix = serdeser_path.joinpath("team", "environment", "roles") shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) pm = ProductManager() - role_tag = f"{pm.__class__.__name__}_{pm.name}" - stg_path = stg_path_prefix.joinpath(role_tag) - pm.serialize(stg_path) - new_pm = Role.deserialize(stg_path) + stg_path = serdeser_path.joinpath("team", "environment", "roles", f"{pm.__class__.__name__}_{pm.name}") + role_path = stg_path.joinpath("role.json") + write_json_file(role_path, pm.model_dump()) + + role_dict = read_json_file(role_path) + new_pm = ProductManager(**role_dict) assert new_pm.name == pm.name assert len(new_pm.get_memories(1)) == 0 @@ -98,22 +91,24 @@ def test_role_serdeser_save(): @pytest.mark.asyncio async def test_role_serdeser_interrupt(): role_c = RoleC() - shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True) + shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) - stg_path = SERDESER_PATH.joinpath("team", "environment", "roles", f"{role_c.__class__.__name__}_{role_c.name}") + stg_path = serdeser_path.joinpath("team", "environment", "roles", f"{role_c.__class__.__name__}_{role_c.name}") + role_path = stg_path.joinpath("role.json") try: await role_c.run(with_message=Message(content="demo", cause_by=UserRequirement)) except Exception: - logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}") - role_c.serialize(stg_path) + logger.error(f"Exception in `role_c.run`, detail: {format_trackback_info()}") + write_json_file(role_path, role_c.model_dump()) assert role_c.rc.memory.count() == 1 - new_role_a: Role = Role.deserialize(stg_path) - assert new_role_a.rc.state == 1 + role_dict = read_json_file(role_path) + new_role_c: Role = RoleC(**role_dict) + assert new_role_c.rc.state == 1 with pytest.raises(Exception): - await new_role_a.run(with_message=Message(content="demo", cause_by=UserRequirement)) + await new_role_c.run(with_message=Message(content="demo", cause_by=UserRequirement)) if __name__ == "__main__": diff --git a/tests/metagpt/serialize_deserialize/test_sk_agent.py b/tests/metagpt/serialize_deserialize/test_sk_agent.py index 7f287b8f9..97c0ade99 100644 --- a/tests/metagpt/serialize_deserialize/test_sk_agent.py +++ b/tests/metagpt/serialize_deserialize/test_sk_agent.py @@ -5,15 +5,8 @@ import pytest from metagpt.roles.sk_agent import SkAgent -def test_sk_agent_serialize(): - role = SkAgent() - ser_role_dict = role.model_dump(exclude={"import_semantic_skill_from_directory", "import_skill"}) - assert "name" in ser_role_dict - assert "planner" in ser_role_dict - - @pytest.mark.asyncio -async def test_sk_agent_deserialize(): +async def test_sk_agent_serdeser(): role = SkAgent() ser_role_dict = role.model_dump(exclude={"import_semantic_skill_from_directory", "import_skill"}) assert "name" in ser_role_dict diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index 566f63c3d..57c8a8508 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -4,13 +4,14 @@ # @Desc : import shutil +from pathlib import Path import pytest -from metagpt.const import SERDESER_PATH from metagpt.logs import logger from metagpt.roles import Architect, ProductManager, ProjectManager from metagpt.team import Team +from metagpt.utils.common import write_json_file from tests.metagpt.serialize_deserialize.test_serdeser_base import ( ActionOK, RoleA, @@ -45,9 +46,16 @@ def test_team_deserialize(): assert new_company.env.get_role(arch.profile) is not None -def test_team_serdeser_save(): - company = Team() +def mock_team_serialize(self, stg_path: Path = serdeser_path.joinpath("team")): + team_info_path = stg_path.joinpath("team.json") + write_json_file(team_info_path, self.model_dump()) + + +def test_team_serdeser_save(mocker): + mocker.patch("metagpt.team.Team.serialize", mock_team_serialize) + + company = Team() company.hire([RoleC()]) stg_path = serdeser_path.joinpath("team") @@ -61,9 +69,11 @@ def test_team_serdeser_save(): @pytest.mark.asyncio -async def test_team_recover(): +async def test_team_recover(mocker): + mocker.patch("metagpt.team.Team.serialize", mock_team_serialize) + idea = "write a snake game" - stg_path = SERDESER_PATH.joinpath("team") + stg_path = serdeser_path.joinpath("team") shutil.rmtree(stg_path, ignore_errors=True) company = Team() @@ -75,9 +85,9 @@ async def test_team_recover(): ser_data = company.model_dump() new_company = Team(**ser_data) - new_company.env.get_role(role_c.profile) - # assert new_role_c.rc.memory == role_c.rc.memory # TODO - # assert new_role_c.rc.env != role_c.rc.env # TODO + new_role_c = new_company.env.get_role(role_c.profile) + assert new_role_c.rc.memory == role_c.rc.memory + assert new_role_c.rc.env != role_c.rc.env assert type(list(new_company.env.roles.values())[0].actions[0]) == ActionOK new_company.run_project(idea) @@ -85,9 +95,11 @@ async def test_team_recover(): @pytest.mark.asyncio -async def test_team_recover_save(): +async def test_team_recover_save(mocker): + mocker.patch("metagpt.team.Team.serialize", mock_team_serialize) + idea = "write a 2048 web game" - stg_path = SERDESER_PATH.joinpath("team") + stg_path = serdeser_path.joinpath("team") shutil.rmtree(stg_path, ignore_errors=True) company = Team() @@ -98,8 +110,8 @@ async def test_team_recover_save(): new_company = Team.deserialize(stg_path) new_role_c = new_company.env.get_role(role_c.profile) - # assert new_role_c.rc.memory == role_c.rc.memory - # assert new_role_c.rc.env != role_c.rc.env + assert new_role_c.rc.memory == role_c.rc.memory + assert new_role_c.rc.env != role_c.rc.env assert new_role_c.recovered != role_c.recovered # here cause previous ut is `!=` assert new_role_c.rc.todo != role_c.rc.todo # serialize exclude `rc.todo` assert new_role_c.rc.news != role_c.rc.news # serialize exclude `rc.news` @@ -109,9 +121,11 @@ async def test_team_recover_save(): @pytest.mark.asyncio -async def test_team_recover_multi_roles_save(): +async def test_team_recover_multi_roles_save(mocker): + mocker.patch("metagpt.team.Team.serialize", mock_team_serialize) + idea = "write a snake game" - stg_path = SERDESER_PATH.joinpath("team") + stg_path = serdeser_path.joinpath("team") shutil.rmtree(stg_path, ignore_errors=True) role_a = RoleA() diff --git a/tests/metagpt/serialize_deserialize/test_tutorial_assistant.py b/tests/metagpt/serialize_deserialize/test_tutorial_assistant.py index e642dae54..cb8feec19 100644 --- a/tests/metagpt/serialize_deserialize/test_tutorial_assistant.py +++ b/tests/metagpt/serialize_deserialize/test_tutorial_assistant.py @@ -7,7 +7,7 @@ from metagpt.roles.tutorial_assistant import TutorialAssistant @pytest.mark.asyncio -async def test_tutorial_assistant_deserialize(): +async def test_tutorial_assistant_serdeser(): role = TutorialAssistant() ser_role_dict = role.model_dump() assert "name" in ser_role_dict diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index cb262bb45..12dc49c3b 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -9,7 +9,7 @@ from metagpt.actions import WriteCode from metagpt.schema import CodingContext, Document -def test_write_design_serialize(): +def test_write_design_serdeser(): action = WriteCode() ser_action_dict = action.model_dump() assert ser_action_dict["name"] == "WriteCode" @@ -17,7 +17,7 @@ def test_write_design_serialize(): @pytest.mark.asyncio -async def test_write_code_deserialize(): +async def test_write_code_serdeser(): context = CodingContext( filename="test_code.py", design_doc=Document(content="write add function to calculate two numbers") ) diff --git a/tests/metagpt/serialize_deserialize/test_write_code_review.py b/tests/metagpt/serialize_deserialize/test_write_code_review.py index 991b3c13b..d1a9bff24 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code_review.py +++ b/tests/metagpt/serialize_deserialize/test_write_code_review.py @@ -9,7 +9,7 @@ from metagpt.schema import CodingContext, Document @pytest.mark.asyncio -async def test_write_code_review_deserialize(): +async def test_write_code_review_serdeser(): code_content = """ def div(a: int, b: int = 0): return a / b diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py index 7bcba3fc8..37d505914 100644 --- a/tests/metagpt/serialize_deserialize/test_write_design.py +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -7,33 +7,25 @@ import pytest from metagpt.actions import WriteDesign, WriteTasks -def test_write_design_serialize(): - action = WriteDesign() - ser_action_dict = action.model_dump() - assert "name" in ser_action_dict - assert "llm" not in ser_action_dict # not export - - -def test_write_task_serialize(): - action = WriteTasks() - ser_action_dict = action.model_dump() - assert "name" in ser_action_dict - assert "llm" not in ser_action_dict # not export - - @pytest.mark.asyncio -async def test_write_design_deserialize(): +async def test_write_design_serialize(): action = WriteDesign() - serialized_data = action.model_dump() - new_action = WriteDesign(**serialized_data) + ser_action_dict = action.model_dump() + assert "name" in ser_action_dict + assert "llm" not in ser_action_dict # not export + + new_action = WriteDesign(**ser_action_dict) assert new_action.name == "WriteDesign" await new_action.run(with_messages="write a cli snake game") @pytest.mark.asyncio -async def test_write_task_deserialize(): +async def test_write_task_serialize(): action = WriteTasks() - serialized_data = action.model_dump() - new_action = WriteTasks(**serialized_data) + ser_action_dict = action.model_dump() + assert "name" in ser_action_dict + assert "llm" not in ser_action_dict # not export + + new_action = WriteTasks(**ser_action_dict) assert new_action.name == "WriteTasks" await new_action.run(with_messages="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_write_docstring.py b/tests/metagpt/serialize_deserialize/test_write_docstring.py index e4116ab30..fb927f089 100644 --- a/tests/metagpt/serialize_deserialize/test_write_docstring.py +++ b/tests/metagpt/serialize_deserialize/test_write_docstring.py @@ -29,7 +29,7 @@ class Person: ], ids=["google", "numpy", "sphinx"], ) -async def test_action_deserialize(style: str, part: str): +async def test_action_serdeser(style: str, part: str): action = WriteDocstring() serialized_data = action.model_dump() diff --git a/tests/metagpt/serialize_deserialize/test_write_prd.py b/tests/metagpt/serialize_deserialize/test_write_prd.py index b9eff5a19..820ee237c 100644 --- a/tests/metagpt/serialize_deserialize/test_write_prd.py +++ b/tests/metagpt/serialize_deserialize/test_write_prd.py @@ -9,18 +9,14 @@ from metagpt.actions import WritePRD from metagpt.schema import Message -def test_action_serialize(new_filename): +@pytest.mark.asyncio +async def test_action_serdeser(new_filename): action = WritePRD() ser_action_dict = action.model_dump() assert "name" in ser_action_dict assert "llm" not in ser_action_dict # not export - -@pytest.mark.asyncio -async def test_action_deserialize(new_filename): - action = WritePRD() - serialized_data = action.model_dump() - new_action = WritePRD(**serialized_data) + new_action = WritePRD(**ser_action_dict) assert new_action.name == "WritePRD" action_output = await new_action.run(with_messages=Message(content="write a cli snake game")) assert len(action_output.content) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_write_review.py b/tests/metagpt/serialize_deserialize/test_write_review.py index f02a01910..17e212276 100644 --- a/tests/metagpt/serialize_deserialize/test_write_review.py +++ b/tests/metagpt/serialize_deserialize/test_write_review.py @@ -42,7 +42,7 @@ CONTEXT = """ @pytest.mark.asyncio -async def test_action_deserialize(): +async def test_action_serdeser(): action = WriteReview() serialized_data = action.model_dump() assert serialized_data["name"] == "WriteReview" diff --git a/tests/metagpt/serialize_deserialize/test_write_tutorial.py b/tests/metagpt/serialize_deserialize/test_write_tutorial.py index 606a90f8c..4eeef7e0d 100644 --- a/tests/metagpt/serialize_deserialize/test_write_tutorial.py +++ b/tests/metagpt/serialize_deserialize/test_write_tutorial.py @@ -9,7 +9,7 @@ from metagpt.actions.write_tutorial import WriteContent, WriteDirectory @pytest.mark.asyncio @pytest.mark.parametrize(("language", "topic"), [("English", "Write a tutorial about Python")]) -async def test_write_directory_deserialize(language: str, topic: str): +async def test_write_directory_serdeser(language: str, topic: str): action = WriteDirectory() serialized_data = action.model_dump() assert serialized_data["name"] == "WriteDirectory" @@ -30,7 +30,7 @@ async def test_write_directory_deserialize(language: str, topic: str): ("language", "topic", "directory"), [("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})], ) -async def test_write_content_deserialize(language: str, topic: str, directory: Dict): +async def test_write_content_serdeser(language: str, topic: str, directory: Dict): action = WriteContent(language=language, directory=directory) serialized_data = action.model_dump() assert serialized_data["name"] == "WriteContent" From f9a150bab0a24b212b78469958aa2ee61813c844 Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 9 Jan 2024 15:40:42 +0800 Subject: [PATCH 2/3] make instruct_content support any inherited basemodel ser&deser --- metagpt/schema.py | 25 ++++--- .../serialize_deserialize/test_schema.py | 68 +++++++++++++++---- .../test_serdeser_base.py | 10 +-- 3 files changed, 77 insertions(+), 26 deletions(-) diff --git a/metagpt/schema.py b/metagpt/schema.py index a557951c7..7d1c2b539 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -182,12 +182,16 @@ class Message(BaseModel): @field_validator("instruct_content", mode="before") @classmethod def check_instruct_content(cls, ic: Any) -> BaseModel: - if ic and not isinstance(ic, BaseModel) and "class" in ic: - # compatible with custom-defined ActionOutput - mapping = actionoutput_str_to_mapping(ic["mapping"]) - - actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import - ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping) + if ic and isinstance(ic, dict) and "class" in ic: + if "mapping" in ic: + # compatible with custom-defined ActionOutput + mapping = actionoutput_str_to_mapping(ic["mapping"]) + actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import + ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping) + elif "module" in ic: + ic_obj = import_class(ic["class"], ic["module"]) + else: + raise KeyError("missing required key to init Message.instruct_content from dict") ic = ic_obj(**ic["value"]) return ic @@ -212,13 +216,16 @@ class Message(BaseModel): if ic: # compatible with custom-defined ActionOutput schema = ic.model_json_schema() - # `Documents` contain definitions - if "definitions" not in schema: - # TODO refine with nested BaseModel + ic_type = str(type(ic)) + if " Date: Tue, 9 Jan 2024 16:07:33 +0800 Subject: [PATCH 3/3] update --- metagpt/schema.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metagpt/schema.py b/metagpt/schema.py index 7d1c2b539..853a9c6bb 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -189,6 +189,7 @@ class Message(BaseModel): actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping) elif "module" in ic: + # subclasses of BaseModel ic_obj = import_class(ic["class"], ic["module"]) else: raise KeyError("missing required key to init Message.instruct_content from dict")