From b4207cec923bd15b5d74ffbbf0e442a937af4aa3 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Mon, 12 Aug 2024 17:16:28 +0800 Subject: [PATCH] use pydantic's exclude --- metagpt/environment/mgx/mgx_env.py | 12 ----------- metagpt/roles/di/role_zero.py | 9 ++++---- metagpt/schema.py | 12 ----------- tests/metagpt/test_schema.py | 33 ++++++++++++------------------ 4 files changed, 17 insertions(+), 49 deletions(-) diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py index 99f94052a..fae386952 100644 --- a/metagpt/environment/mgx/mgx_env.py +++ b/metagpt/environment/mgx/mgx_env.py @@ -125,15 +125,3 @@ class MGXEnv(Environment, SerializationMixin): def __repr__(self): return "MGXEnv()" - - def remove_unserializable(self, data: dict): - """Removes unserializable content from the data dictionary. - - Args: - data (dict): The data dictionary to clean, obtained from Pydantic's model_dump method. - """ - - roles = data.get("roles", {}) - - for role in roles.values(): - [role.pop(key, None) for key in role.get("unserializable_fields", [])] diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 119fdf5a3..6802fdb00 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -4,9 +4,9 @@ import inspect import json import re import traceback -from typing import Callable, Dict, List, Literal, Optional, Tuple +from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple -from pydantic import model_validator +from pydantic import Field, model_validator from metagpt.actions import Action, UserRequirement from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions @@ -59,14 +59,14 @@ class RoleZero(Role): # Tools tools: list[str] = [] # Use special symbol [""] to indicate use of all registered tools tool_recommender: Optional[ToolRecommender] = None - tool_execution_map: dict[str, Callable] = {} + tool_execution_map: Annotated[dict[str, Callable], Field(exclude=True)] = {} special_tool_commands: list[str] = ["Plan.finish_current_task", "end", "Bash.run"] # Equipped with three basic tools by default for optional use editor: Editor = Editor() browser: Browser = Browser() # Experience - experience_retriever: ExpRetriever = DummyExpRetriever() + experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = DummyExpRetriever() # Others command_rsp: str = "" # the raw string containing the commands @@ -74,7 +74,6 @@ class RoleZero(Role): memory_k: int = 20 # number of memories (messages) to use as historical context use_fixed_sop: bool = False requirements_constraints: str = "" # the constraints in user requirements - unserializable_fields: list[str] = ["tool_execution_map", "experience_retriever"] @model_validator(mode="after") def set_plan_and_tool(self) -> "RoleZero": diff --git a/metagpt/schema.py b/metagpt/schema.py index 2431304db..ad8c8f1d7 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -146,7 +146,6 @@ class SerializationMixin(BaseModel, extra="forbid"): file_path = file_path or self.get_serialization_path() serialized_data = self.model_dump() - self.remove_unserializable(serialized_data) write_json_file(file_path, serialized_data) logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}") @@ -190,17 +189,6 @@ class SerializationMixin(BaseModel, extra="forbid"): return str(SERDESER_PATH / f"{cls.__qualname__}.json") - def remove_unserializable(self, data: dict): - """Removes unserializable content from the data dictionary. - - This method removes keys specified in the "unserializable_fields" list from the provided data dictionary. - It is intended to clean the dictionary obtained from Pydantic's `model_dump` method by removing fields - that cannot be serialized. - """ - - for key in data.get("unserializable_fields", []): - data.pop(key, None) - class SimpleMessage(BaseModel): content: str diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 12e0c7aab..bc2bdd02a 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -9,9 +9,10 @@ """ import json +from typing import Annotated import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field from metagpt.actions import Action from metagpt.actions.action_node import ActionNode @@ -405,10 +406,8 @@ class TestUserModel(SerializationMixin, BaseModel): value: int -class TestUserModelWithRemove(TestUserModel): - def remove_unserializable(self, data: dict): - for key in ["value", "__module_class_name"]: - data.pop(key, None) +class TestUserModelWithExclude(TestUserModel): + age: Annotated[int, Field(exclude=True)] class TestSerializationMixin: @@ -441,31 +440,25 @@ class TestSerializationMixin: mock_read_json_file.assert_called_once_with(file_path) assert model == TestUserModel(**data) - def test_serialize_with_remove_unserializable(self, mock_write_json_file): - model = TestUserModelWithRemove(name="test", value=42) + def test_serialize_with_exclude(self, mock_write_json_file): + model = TestUserModelWithExclude(name="test", value=42, age=10) file_path = "test.json" model.serialize(file_path) - mock_write_json_file.assert_called_once_with(file_path, {"name": "test"}) + expected_data = { + "name": "test", + "value": 42, + "__module_class_name": "tests.metagpt.test_schema.TestUserModelWithExclude", + } + + mock_write_json_file.assert_called_once_with(file_path, expected_data) def test_get_serialization_path(self): expected_path = str(SERDESER_PATH / "TestUserModel.json") assert TestUserModel.get_serialization_path() == expected_path - def test_remove_unserializable(self, mock_user_model): - data = { - "name": "example", - "unserializable_fields": ["temp_data", "debug_info"], - "temp_data": "some temporary data", - "debug_info": "some debug information", - } - mock_user_model.remove_unserializable(data) - - expected_data = {"name": "example", "unserializable_fields": ["temp_data", "debug_info"]} - assert data == expected_data - if __name__ == "__main__": pytest.main([__file__, "-s"])