use pydantic's exclude

This commit is contained in:
seehi 2024-08-12 17:16:28 +08:00
parent 11894d12f3
commit b4207cec92
4 changed files with 17 additions and 49 deletions

View file

@ -125,15 +125,3 @@ class MGXEnv(Environment, SerializationMixin):
def __repr__(self):
return "MGXEnv()"
def remove_unserializable(self, data: dict):
"""Removes unserializable content from the data dictionary.
Args:
data (dict): The data dictionary to clean, obtained from Pydantic's model_dump method.
"""
roles = data.get("roles", {})
for role in roles.values():
[role.pop(key, None) for key in role.get("unserializable_fields", [])]

View file

@ -4,9 +4,9 @@ import inspect
import json
import re
import traceback
from typing import Callable, Dict, List, Literal, Optional, Tuple
from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple
from pydantic import model_validator
from pydantic import Field, model_validator
from metagpt.actions import Action, UserRequirement
from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions
@ -59,14 +59,14 @@ class RoleZero(Role):
# Tools
tools: list[str] = [] # Use special symbol ["<all>"] to indicate use of all registered tools
tool_recommender: Optional[ToolRecommender] = None
tool_execution_map: dict[str, Callable] = {}
tool_execution_map: Annotated[dict[str, Callable], Field(exclude=True)] = {}
special_tool_commands: list[str] = ["Plan.finish_current_task", "end", "Bash.run"]
# Equipped with three basic tools by default for optional use
editor: Editor = Editor()
browser: Browser = Browser()
# Experience
experience_retriever: ExpRetriever = DummyExpRetriever()
experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = DummyExpRetriever()
# Others
command_rsp: str = "" # the raw string containing the commands
@ -74,7 +74,6 @@ class RoleZero(Role):
memory_k: int = 20 # number of memories (messages) to use as historical context
use_fixed_sop: bool = False
requirements_constraints: str = "" # the constraints in user requirements
unserializable_fields: list[str] = ["tool_execution_map", "experience_retriever"]
@model_validator(mode="after")
def set_plan_and_tool(self) -> "RoleZero":

View file

@ -146,7 +146,6 @@ class SerializationMixin(BaseModel, extra="forbid"):
file_path = file_path or self.get_serialization_path()
serialized_data = self.model_dump()
self.remove_unserializable(serialized_data)
write_json_file(file_path, serialized_data)
logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}")
@ -190,17 +189,6 @@ class SerializationMixin(BaseModel, extra="forbid"):
return str(SERDESER_PATH / f"{cls.__qualname__}.json")
def remove_unserializable(self, data: dict):
"""Removes unserializable content from the data dictionary.
This method removes keys specified in the "unserializable_fields" list from the provided data dictionary.
It is intended to clean the dictionary obtained from Pydantic's `model_dump` method by removing fields
that cannot be serialized.
"""
for key in data.get("unserializable_fields", []):
data.pop(key, None)
class SimpleMessage(BaseModel):
content: str