use pydantic's exclude

This commit is contained in:
seehi 2024-08-12 17:16:28 +08:00
parent 11894d12f3
commit b4207cec92
4 changed files with 17 additions and 49 deletions

View file

@ -125,15 +125,3 @@ class MGXEnv(Environment, SerializationMixin):
def __repr__(self):
return "MGXEnv()"
def remove_unserializable(self, data: dict):
"""Removes unserializable content from the data dictionary.
Args:
data (dict): The data dictionary to clean, obtained from Pydantic's model_dump method.
"""
roles = data.get("roles", {})
for role in roles.values():
[role.pop(key, None) for key in role.get("unserializable_fields", [])]

View file

@ -4,9 +4,9 @@ import inspect
import json
import re
import traceback
from typing import Callable, Dict, List, Literal, Optional, Tuple
from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple
from pydantic import model_validator
from pydantic import Field, model_validator
from metagpt.actions import Action, UserRequirement
from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions
@ -59,14 +59,14 @@ class RoleZero(Role):
# Tools
tools: list[str] = [] # Use special symbol ["<all>"] to indicate use of all registered tools
tool_recommender: Optional[ToolRecommender] = None
tool_execution_map: dict[str, Callable] = {}
tool_execution_map: Annotated[dict[str, Callable], Field(exclude=True)] = {}
special_tool_commands: list[str] = ["Plan.finish_current_task", "end", "Bash.run"]
# Equipped with three basic tools by default for optional use
editor: Editor = Editor()
browser: Browser = Browser()
# Experience
experience_retriever: ExpRetriever = DummyExpRetriever()
experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = DummyExpRetriever()
# Others
command_rsp: str = "" # the raw string containing the commands
@ -74,7 +74,6 @@ class RoleZero(Role):
memory_k: int = 20 # number of memories (messages) to use as historical context
use_fixed_sop: bool = False
requirements_constraints: str = "" # the constraints in user requirements
unserializable_fields: list[str] = ["tool_execution_map", "experience_retriever"]
@model_validator(mode="after")
def set_plan_and_tool(self) -> "RoleZero":

View file

@ -146,7 +146,6 @@ class SerializationMixin(BaseModel, extra="forbid"):
file_path = file_path or self.get_serialization_path()
serialized_data = self.model_dump()
self.remove_unserializable(serialized_data)
write_json_file(file_path, serialized_data)
logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}")
@ -190,17 +189,6 @@ class SerializationMixin(BaseModel, extra="forbid"):
return str(SERDESER_PATH / f"{cls.__qualname__}.json")
def remove_unserializable(self, data: dict):
"""Removes unserializable content from the data dictionary.
This method removes keys specified in the "unserializable_fields" list from the provided data dictionary.
It is intended to clean the dictionary obtained from Pydantic's `model_dump` method by removing fields
that cannot be serialized.
"""
for key in data.get("unserializable_fields", []):
data.pop(key, None)
class SimpleMessage(BaseModel):
content: str

View file

@ -9,9 +9,10 @@
"""
import json
from typing import Annotated
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Field
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
@ -405,10 +406,8 @@ class TestUserModel(SerializationMixin, BaseModel):
value: int
class TestUserModelWithRemove(TestUserModel):
def remove_unserializable(self, data: dict):
for key in ["value", "__module_class_name"]:
data.pop(key, None)
class TestUserModelWithExclude(TestUserModel):
age: Annotated[int, Field(exclude=True)]
class TestSerializationMixin:
@ -441,31 +440,25 @@ class TestSerializationMixin:
mock_read_json_file.assert_called_once_with(file_path)
assert model == TestUserModel(**data)
def test_serialize_with_remove_unserializable(self, mock_write_json_file):
model = TestUserModelWithRemove(name="test", value=42)
def test_serialize_with_exclude(self, mock_write_json_file):
model = TestUserModelWithExclude(name="test", value=42, age=10)
file_path = "test.json"
model.serialize(file_path)
mock_write_json_file.assert_called_once_with(file_path, {"name": "test"})
expected_data = {
"name": "test",
"value": 42,
"__module_class_name": "tests.metagpt.test_schema.TestUserModelWithExclude",
}
mock_write_json_file.assert_called_once_with(file_path, expected_data)
def test_get_serialization_path(self):
expected_path = str(SERDESER_PATH / "TestUserModel.json")
assert TestUserModel.get_serialization_path() == expected_path
def test_remove_unserializable(self, mock_user_model):
data = {
"name": "example",
"unserializable_fields": ["temp_data", "debug_info"],
"temp_data": "some temporary data",
"debug_info": "some debug information",
}
mock_user_model.remove_unserializable(data)
expected_data = {"name": "example", "unserializable_fields": ["temp_data", "debug_info"]}
assert data == expected_data
if __name__ == "__main__":
pytest.main([__file__, "-s"])