diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py index 873358252..156b08267 100644 --- a/metagpt/environment/mgx/mgx_env.py +++ b/metagpt/environment/mgx/mgx_env.py @@ -1,3 +1,5 @@ +from typing import ClassVar + from metagpt.actions import ( UserRequirement, WriteDesign, @@ -6,12 +8,14 @@ from metagpt.actions import ( WriteTest, ) from metagpt.actions.summarize_code import SummarizeCode -from metagpt.const import AGENT +from metagpt.const import AGENT, SERDESER_PATH from metagpt.environment.base_env import Environment -from metagpt.logs import get_human_input +from metagpt.logs import get_human_input, logger from metagpt.roles import Architect, ProductManager, ProjectManager, Role from metagpt.schema import Message from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.exceptions import handle_exception +from metagpt.utils.serialize import deserialize_model, serialize_model class MGXEnv(Environment): @@ -22,6 +26,11 @@ class MGXEnv(Environment): direct_chat_roles: set[str] = set() # record direct chat: @role_name + default_serialization_path: ClassVar[str] = str(SERDESER_PATH / "mgxenv" / "mgxenv.json") + + def __repr__(self): + return "MGXEnv()" + def _publish_message(self, message: Message, peekable: bool = True) -> bool: return super().publish_message(message, peekable) @@ -121,5 +130,54 @@ class MGXEnv(Environment): converted_msg.content = f"from {sent_from} to {converted_msg.send_to}: {converted_msg.content}" return converted_msg - def __repr__(self): - return "MGXEnv()" + @handle_exception + def serialize(self, file_path: str = None) -> str: + """Serializes the current instance to a JSON file. + + If an exception occurs, `handle_exception` will catch it and return `None`. + + Args: + file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None. + + Returns: + str: The path to the JSON file where the instance was saved. + """ + + file_path = file_path or self.default_serialization_path + + serialize_model(self, file_path, remove_unserializable=self.remove_unserializable) + logger.info(f"MGXEnv serialization successful. File saved at: {file_path}") + + return file_path + + @classmethod + @handle_exception + def deserialize(cls, file_path: str = None) -> "MGXEnv": + """Deserializes a JSON file to an instance of MGXEnv. + + If an exception occurs, `handle_exception` will catch it and return `None`. + + Args: + file_path (str, optional): The path to the JSON file to read from. Defaults to None. + + Returns: + MGXEnv: An instance of MGXEnv. + """ + + file_path = file_path or cls.default_serialization_path + + model = deserialize_model(cls, file_path) + logger.info(f"MGXEnv deserialization successful. Instance created from file: {file_path}") + + return model + + def remove_unserializable(self, data: dict): + """Removes unserializable content from the data dictionary. + + Args: + data (dict): The data dictionary to clean, obtained from Pydantic's model_dump method. + """ + roles = data.get("roles", {}) + + for role in roles.values(): + [role.pop(key, None) for key in role.get("unserializable_fields", [])] diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 4e932649c..b1b524d21 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -4,7 +4,7 @@ import inspect import json import re import traceback -from typing import Callable, Dict, List, Literal, Tuple +from typing import Callable, Dict, List, Literal, Optional, Tuple from pydantic import model_validator @@ -46,11 +46,11 @@ class RoleZero(Role): name: str = "Zero" profile: str = "RoleZero" goal: str = "" - system_msg: list[str] = None # Use None to conform to the default value at llm.aask + system_msg: Optional[list[str]] = None # Use None to conform to the default value at llm.aask cmd_prompt: str = CMD_PROMPT thought_guidance: str = THOUGHT_GUIDANCE instruction: str = ROLE_INSTRUCTION - task_type_desc: str = None + task_type_desc: Optional[str] = None # React Mode react_mode: Literal["react"] = "react" @@ -58,7 +58,7 @@ class RoleZero(Role): # Tools tools: list[str] = [] # Use special symbol [""] to indicate use of all registered tools - tool_recommender: ToolRecommender = None + tool_recommender: Optional[ToolRecommender] = None tool_execution_map: dict[str, Callable] = {} special_tool_commands: list[str] = ["Plan.finish_current_task", "end", "Bash.run"] # Equipped with three basic tools by default for optional use @@ -74,6 +74,7 @@ class RoleZero(Role): memory_k: int = 20 # number of memories (messages) to use as historical context use_fixed_sop: bool = False requirements_constraints: str = "" # the constraints in user requirements + unserializable_fields: list[str] = ["tool_execution_map", "experience_retriever"] @model_validator(mode="after") def set_plan_and_tool(self) -> "RoleZero": diff --git a/metagpt/roles/di/swe_agent.py b/metagpt/roles/di/swe_agent.py index 2384ac147..d85a3863b 100644 --- a/metagpt/roles/di/swe_agent.py +++ b/metagpt/roles/di/swe_agent.py @@ -1,4 +1,5 @@ import json +from typing import Optional from pydantic import Field @@ -17,7 +18,7 @@ class SWEAgent(RoleZero): name: str = "Swen" profile: str = "Issue Solver" goal: str = "Resolve GitHub issue or bug in any existing codebase" - system_msg: str = [SWE_AGENT_SYSTEM_TEMPLATE] + system_msg: Optional[list[str]] = [SWE_AGENT_SYSTEM_TEMPLATE] _instruction: str = NEXT_STEP_TEMPLATE tools: list[str] = [ "Bash", diff --git a/metagpt/strategy/experience_retriever.py b/metagpt/strategy/experience_retriever.py index 9587ef9f8..46d0fd862 100644 --- a/metagpt/strategy/experience_retriever.py +++ b/metagpt/strategy/experience_retriever.py @@ -6,7 +6,7 @@ from pydantic import BaseModel class ExpRetriever(BaseModel): """interface for experience retriever""" - def retrieve(self, context: str) -> str: + def retrieve(self, context: str = "") -> str: raise NotImplementedError diff --git a/metagpt/tools/libs/browser.py b/metagpt/tools/libs/browser.py index a458109e6..3efddd2e8 100644 --- a/metagpt/tools/libs/browser.py +++ b/metagpt/tools/libs/browser.py @@ -12,6 +12,7 @@ from playwright.async_api import ( Request, async_playwright, ) +from pydantic import BaseModel, ConfigDict, Field from metagpt.tools.tool_registry import register_tool from metagpt.utils.a11y_tree import ( @@ -43,7 +44,7 @@ from metagpt.utils.report import BrowserReporter "type", ], ) -class Browser: +class Browser(BaseModel): """A tool for browsing the web. Don't initialize a new instance of this class if one already exists. Note: If you plan to use the browser to assist you in completing tasks, then using the browser should be a standalone @@ -66,16 +67,17 @@ class Browser: >>> await browser.close_tab() """ - def __init__(self): - self.playwright: Optional[Playwright] = None - self.browser_instance: Optional[Browser_] = None - self.browser_ctx: Optional[BrowserContext] = None - self.page: Optional[Page] = None - self.accessibility_tree: list = [] - self.headless: bool = True - self.proxy = get_proxy_from_env() - self.is_empty_page = True - self.reporter = BrowserReporter() + model_config = ConfigDict(arbitrary_types_allowed=True) + + playwright: Optional[Playwright] = None + browser_instance: Optional[Browser_] = None + browser_ctx: Optional[BrowserContext] = None + page: Optional[Page] = None + accessibility_tree: list = Field(default_factory=list) + headless: bool = True + proxy: Optional[str] = Field(default_factory=get_proxy_from_env) + is_empty_page: bool = True + reporter: BrowserReporter = Field(default_factory=BrowserReporter) async def start(self) -> None: """Starts Playwright and launches a browser""" diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index b964a2741..149482ab2 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -2,9 +2,8 @@ import os import shutil import subprocess -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict -from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.tools.tool_registry import register_tool from metagpt.utils.report import EditorReporter @@ -17,12 +16,12 @@ class FileBlock(BaseModel): @register_tool() -class Editor: +class Editor(BaseModel): """A tool for reading, understanding, writing, and editing files""" - def __init__(self) -> None: - print(f"Editor initialized with root path at: {DEFAULT_WORKSPACE_ROOT}") - self.resource = EditorReporter() + model_config = ConfigDict(arbitrary_types_allowed=True) + + resource: EditorReporter = EditorReporter() def write(self, path: str, content: str): """Write the whole content to a file. When used, make sure content arg contains the full content of the file.""" diff --git a/metagpt/tools/tool_recommend.py b/metagpt/tools/tool_recommend.py index cca5cb3ae..25f403c77 100644 --- a/metagpt/tools/tool_recommend.py +++ b/metagpt/tools/tool_recommend.py @@ -64,6 +64,10 @@ class ToolRecommender(BaseModel): @field_validator("tools", mode="before") @classmethod def validate_tools(cls, v: list[str]) -> dict[str, Tool]: + # If `v` is already a dictionary (e.g., during deserialization), return it as is. + if isinstance(v, dict): + return v + # One can use special symbol [""] to indicate use of all registered tools if v == [""]: return TOOL_REGISTRY.get_all_tools() diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 3eead9ed4..2f7867104 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -28,6 +28,7 @@ import time import traceback from asyncio import iscoroutinefunction from datetime import datetime +from functools import partial from io import BytesIO from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -577,13 +578,30 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]: return data -def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4): +def write_json_file(json_file: str, data: Any, encoding: str = None, indent: int = 4, use_fallback: bool = False): folder_path = Path(json_file).parent if not folder_path.exists(): folder_path.mkdir(parents=True, exist_ok=True) + # For debug, if use_fallback, unknown values will be logged instead of raising an exception. + def fallback(x: Any) -> str: + tip = f"PydanticSerializationError occurred while processing file '{json_file}'" + + if inspect.ismethod(x): + logger.error(f"{tip}, Method: {x.__self__.__class__.__name__}.{x.__func__.__name__}") + elif inspect.isfunction(x): + logger.error(f"{tip}, Function: {x.__name__}") + elif hasattr(x, "__class__"): + logger.error(f"{tip}, Instance of: {x.__class__.__name__}") + elif hasattr(x, "__name__"): + logger.error(f"{tip}, Class or module: {x.__name__}") + else: + logger.error(f"{tip}, Unknown type: {type(x)}") + + custom_default = partial(to_jsonable_python, fallback=fallback if use_fallback else None) + with open(json_file, "w", encoding=encoding) as fout: - json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python) + json.dump(data, fout, ensure_ascii=False, indent=indent, default=custom_default) def read_csv_to_list(curr_file: str, header=False, strip_trail=True): diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index c6bd8ad75..814621377 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -4,8 +4,11 @@ import copy import pickle +from typing import Callable, Optional, Type -from metagpt.utils.common import import_class +from pydantic import BaseModel + +from metagpt.utils.common import import_class, read_json_file, write_json_file def actionoutout_schema_to_mapping(schema: dict) -> dict: @@ -81,3 +84,36 @@ def deserialize_message(message_ser: str) -> "Message": message.instruct_content = ic_new return message + + +def serialize_model(model: BaseModel, file_path: str, remove_unserializable: Optional[Callable[[dict], None]] = None): + """Serializes a Pydantic model to a JSON file. + + Args: + model (BaseModel): The Pydantic model to serialize. + file_path (str): The path to the JSON file where the model will be saved. + remove_unserializable (Optional[Callable[[dict], None]]): Optional function to remove unserializable content from the serialized data. + """ + + serialized_data = model.model_dump() + + if remove_unserializable: + remove_unserializable(serialized_data) + + write_json_file(file_path, serialized_data) + + +def deserialize_model(cls: Type[BaseModel], file_path: str) -> BaseModel: + """Deserializes a JSON file to a Pydantic model. + + Args: + cls (Type[BaseModel]): The Pydantic model class to deserialize into. + file_path (str): The path to the JSON file to read from. + + Returns: + BaseModel: An instance of the Pydantic model. + """ + + data: dict = read_json_file(file_path) + + return cls(**data) diff --git a/tests/metagpt/utils/test_serialize.py b/tests/metagpt/utils/test_serialize.py index 0ba3a8d41..3b20f3fa0 100644 --- a/tests/metagpt/utils/test_serialize.py +++ b/tests/metagpt/utils/test_serialize.py @@ -6,13 +6,17 @@ from typing import List +from pydantic import BaseModel + from metagpt.actions import WritePRD from metagpt.actions.action_node import ActionNode from metagpt.schema import Message from metagpt.utils.serialize import ( actionoutout_schema_to_mapping, deserialize_message, + deserialize_model, serialize_message, + serialize_model, ) @@ -66,3 +70,35 @@ def test_serialize_and_deserialize_message(): assert new_message.content == message.content assert new_message.cause_by == message.cause_by assert new_message.instruct_content.field1 == out_data["field1"] + + +class TestUserModel(BaseModel): + name: str + value: int + + +def test_serialize_model(mocker): + model = TestUserModel(name="test", value=42) + file_path = "test.json" + mock_write_json_file = mocker.patch("metagpt.utils.serialize.write_json_file") + + # Test without remove_unserializable + serialize_model(model, file_path) + mock_write_json_file.assert_called_once_with(file_path, model.model_dump()) + + # Test with remove_unserializable + def remove_unserializable(data: dict): + data.pop("value", None) + + serialize_model(model, file_path, remove_unserializable) + mock_write_json_file.assert_called_with(file_path, {"name": "test"}) + + +def test_deserialize_model(mocker): + file_path = "test.json" + data = {"name": "test", "value": 42} + mock_read_json_file = mocker.patch("metagpt.utils.serialize.read_json_file", return_value=data) + + model = deserialize_model(TestUserModel, file_path) + mock_read_json_file.assert_called_once_with(file_path) + assert model == TestUserModel(**data)