diff --git a/examples/serialize_model.py b/examples/serialize_model.py new file mode 100644 index 000000000..2423efef8 --- /dev/null +++ b/examples/serialize_model.py @@ -0,0 +1,25 @@ +from metagpt.environment.mgx.mgx_env import MGXEnv +from metagpt.logs import logger + + +def main(): + """Demonstrates serialization and deserialization using SerializationMixin. + + This example creates an instance of MGXEnv, serializes it to a file, + and then deserializes it back to an instance. + + If executed correctly, the following log messages will be output: + MGXEnv serialization successful. File saved at: /.../workspace/storage/MGXEnv.json + MGXEnv deserialization successful. Instance created from file: /.../workspace/storage/MGXEnv.json + The instance is MGXEnv() + """ + + env = MGXEnv() + env.serialize() + + env: MGXEnv = MGXEnv.deserialize() + logger.info(f"The instance is {repr(env)}") + + +if __name__ == "__main__": + main() diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py index fc1b1cfc5..fae386952 100644 --- a/metagpt/environment/mgx/mgx_env.py +++ b/metagpt/environment/mgx/mgx_env.py @@ -10,11 +10,11 @@ from metagpt.const import AGENT from metagpt.environment.base_env import Environment from metagpt.logs import get_human_input from metagpt.roles import Architect, ProductManager, ProjectManager, Role -from metagpt.schema import Message +from metagpt.schema import Message, SerializationMixin from metagpt.utils.common import any_to_str, any_to_str_set -class MGXEnv(Environment): +class MGXEnv(Environment, SerializationMixin): """MGX Environment""" # If True, fixed software sop bypassing TL is allowed, otherwise, TL will fully take over the routing diff --git a/metagpt/roles/di/data_analyst.py b/metagpt/roles/di/data_analyst.py index 3a43f72e0..f65042217 100644 --- a/metagpt/roles/di/data_analyst.py +++ b/metagpt/roles/di/data_analyst.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Annotated + from pydantic import Field, model_validator from metagpt.actions.di.execute_nb_code import ExecuteNbCode @@ -31,7 +33,7 @@ class DataAnalyst(RoleZero): tools: list[str] = ["Plan", "DataAnalyst", "RoleZero", "Browser"] custom_tools: list[str] = ["web scraping", "Terminal"] custom_tool_recommender: ToolRecommender = None - experience_retriever: ExpRetriever = KeywordExpRetriever() + experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = KeywordExpRetriever() use_reflection: bool = True write_code: WriteAnalysisCode = Field(default_factory=WriteAnalysisCode, exclude=True) diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 98f6be62d..8d46b9c02 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -4,9 +4,9 @@ import inspect import json import re import traceback -from typing import Callable, Dict, List, Literal, Tuple +from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple -from pydantic import model_validator +from pydantic import Field, model_validator from metagpt.actions import Action, UserRequirement from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions @@ -47,12 +47,13 @@ class RoleZero(Role): name: str = "Zero" profile: str = "RoleZero" goal: str = "" + system_msg: Optional[list[str]] = None # Use None to conform to the default value at llm.aask system_prompt: str = SYSTEM_PROMPT # Use None to conform to the default value at llm.aask cmd_prompt: str = CMD_PROMPT cmd_prompt_current_state: str = "" thought_guidance: str = THOUGHT_GUIDANCE instruction: str = ROLE_INSTRUCTION - task_type_desc: str = None + task_type_desc: Optional[str] = None # React Mode react_mode: Literal["react"] = "react" @@ -60,15 +61,15 @@ class RoleZero(Role): # Tools tools: list[str] = [] # Use special symbol [""] to indicate use of all registered tools - tool_recommender: ToolRecommender = None - tool_execution_map: dict[str, Callable] = {} + tool_recommender: Optional[ToolRecommender] = None + tool_execution_map: Annotated[dict[str, Callable], Field(exclude=True)] = {} special_tool_commands: list[str] = ["Plan.finish_current_task", "end", "Bash.run"] # Equipped with three basic tools by default for optional use editor: Editor = Editor() browser: Browser = Browser() # Experience - experience_retriever: ExpRetriever = DummyExpRetriever() + experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = DummyExpRetriever() # Others command_rsp: str = "" # the raw string containing the commands @@ -129,7 +130,7 @@ class RoleZero(Role): def _update_tool_execution(self): pass - + async def _think(self) -> bool: """Useful in 'react' mode. Use LLM to decide whether and what to do next.""" # Compatibility @@ -195,7 +196,7 @@ class RoleZero(Role): The `RoleZeroSerializer` extracts essential parts of `req` for the experience pool, trimming lengthy entries to retain only necessary parts. """ return await self.llm.aask(req, system_msgs=system_msgs) - + async def parse_browser_actions(self, memory: List[Message]) -> List[Message]: if not self.browser.is_empty_page: pattern = re.compile(r"Command Browser\.(\w+) executed") @@ -261,7 +262,7 @@ class RoleZero(Role): context = self.llm.format_msg(memory + [UserMessage(content=QUICK_THINK_PROMPT)]) intent_result = await self.llm.aask(context) - if "QUICK" in intent_result or "AMBIGUOUS " in intent_result: # llm call with the original context + if "QUICK" in intent_result or "AMBIGUOUS " in intent_result: # llm call with the original context async with ThoughtReporter(enable_llm_stream=True) as reporter: await reporter.async_report({"type": "quick"}) answer = await self.llm.aask(self.llm.format_msg(memory)) diff --git a/metagpt/roles/di/team_leader.py b/metagpt/roles/di/team_leader.py index 12b4b3a18..353e00620 100644 --- a/metagpt/roles/di/team_leader.py +++ b/metagpt/roles/di/team_leader.py @@ -1,5 +1,9 @@ from __future__ import annotations +from typing import Annotated + +from pydantic import Field + from metagpt.actions.di.run_command import RunCommand from metagpt.prompts.di.team_leader import ( FINISH_CURRENT_TASK_CMD, @@ -24,7 +28,7 @@ class TeamLeader(RoleZero): tools: list[str] = ["Plan", "RoleZero", "TeamLeader"] - experience_retriever: ExpRetriever = SimpleExpRetriever() + experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = SimpleExpRetriever() def _update_tool_execution(self): self.tool_execution_map.update( diff --git a/metagpt/schema.py b/metagpt/schema.py index 648e2bd73..ad8c8f1d7 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -44,6 +44,7 @@ from metagpt.const import ( MESSAGE_ROUTE_FROM, MESSAGE_ROUTE_TO, MESSAGE_ROUTE_TO_ALL, + SERDESER_PATH, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO, ) @@ -56,6 +57,8 @@ from metagpt.utils.common import ( any_to_str_set, aread, import_class, + read_json_file, + write_json_file, ) from metagpt.utils.exceptions import handle_exception from metagpt.utils.report import TaskReporter @@ -127,6 +130,65 @@ class SerializationMixin(BaseModel, extra="forbid"): cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls super().__init_subclass__(**kwargs) + @handle_exception + def serialize(self, file_path: str = None) -> str: + """Serializes the current instance to a JSON file. + + If an exception occurs, `handle_exception` will catch it and return `None`. + + Args: + file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None. + + Returns: + str: The path to the JSON file where the instance was saved. + """ + + file_path = file_path or self.get_serialization_path() + + serialized_data = self.model_dump() + + write_json_file(file_path, serialized_data) + logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}") + + return file_path + + @classmethod + @handle_exception + def deserialize(cls, file_path: str = None) -> BaseModel: + """Deserializes a JSON file to an instance of cls. + + If an exception occurs, `handle_exception` will catch it and return `None`. + + Args: + file_path (str, optional): The path to the JSON file to read from. Defaults to None. + + Returns: + An instance of the cls. + """ + + file_path = file_path or cls.get_serialization_path() + + data: dict = read_json_file(file_path) + + model = cls(**data) + logger.info(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}") + + return model + + @classmethod + def get_serialization_path(cls) -> str: + """Get the serialization path for the class. + + This method constructs a file path for serialization based on the class name. + The default path is constructed as './workspace/storage/ClassName.json', where 'ClassName' + is the name of the class. + + Returns: + str: The path to the serialization file. + """ + + return str(SERDESER_PATH / f"{cls.__qualname__}.json") + class SimpleMessage(BaseModel): content: str diff --git a/metagpt/strategy/experience_retriever.py b/metagpt/strategy/experience_retriever.py index 9587ef9f8..46d0fd862 100644 --- a/metagpt/strategy/experience_retriever.py +++ b/metagpt/strategy/experience_retriever.py @@ -6,7 +6,7 @@ from pydantic import BaseModel class ExpRetriever(BaseModel): """interface for experience retriever""" - def retrieve(self, context: str) -> str: + def retrieve(self, context: str = "") -> str: raise NotImplementedError diff --git a/metagpt/tools/libs/browser.py b/metagpt/tools/libs/browser.py index a458109e6..3efddd2e8 100644 --- a/metagpt/tools/libs/browser.py +++ b/metagpt/tools/libs/browser.py @@ -12,6 +12,7 @@ from playwright.async_api import ( Request, async_playwright, ) +from pydantic import BaseModel, ConfigDict, Field from metagpt.tools.tool_registry import register_tool from metagpt.utils.a11y_tree import ( @@ -43,7 +44,7 @@ from metagpt.utils.report import BrowserReporter "type", ], ) -class Browser: +class Browser(BaseModel): """A tool for browsing the web. Don't initialize a new instance of this class if one already exists. Note: If you plan to use the browser to assist you in completing tasks, then using the browser should be a standalone @@ -66,16 +67,17 @@ class Browser: >>> await browser.close_tab() """ - def __init__(self): - self.playwright: Optional[Playwright] = None - self.browser_instance: Optional[Browser_] = None - self.browser_ctx: Optional[BrowserContext] = None - self.page: Optional[Page] = None - self.accessibility_tree: list = [] - self.headless: bool = True - self.proxy = get_proxy_from_env() - self.is_empty_page = True - self.reporter = BrowserReporter() + model_config = ConfigDict(arbitrary_types_allowed=True) + + playwright: Optional[Playwright] = None + browser_instance: Optional[Browser_] = None + browser_ctx: Optional[BrowserContext] = None + page: Optional[Page] = None + accessibility_tree: list = Field(default_factory=list) + headless: bool = True + proxy: Optional[str] = Field(default_factory=get_proxy_from_env) + is_empty_page: bool = True + reporter: BrowserReporter = Field(default_factory=BrowserReporter) async def start(self) -> None: """Starts Playwright and launches a browser""" diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 40625a992..c2fdcb859 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -5,9 +5,8 @@ import subprocess from pathlib import Path from typing import List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict -from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.logs import logger from metagpt.tools.tool_registry import register_tool from metagpt.utils import read_docx @@ -24,12 +23,12 @@ class FileBlock(BaseModel): @register_tool() -class Editor: +class Editor(BaseModel): """A tool for reading, understanding, writing, and editing files""" - def __init__(self) -> None: - print(f"Editor initialized with root path at: {DEFAULT_WORKSPACE_ROOT}") - self.resource = EditorReporter() + model_config = ConfigDict(arbitrary_types_allowed=True) + + resource: EditorReporter = EditorReporter() def write(self, path: str, content: str): """Write the whole content to a file. When used, make sure content arg contains the full content of the file.""" diff --git a/metagpt/tools/tool_recommend.py b/metagpt/tools/tool_recommend.py index cca5cb3ae..25f403c77 100644 --- a/metagpt/tools/tool_recommend.py +++ b/metagpt/tools/tool_recommend.py @@ -64,6 +64,10 @@ class ToolRecommender(BaseModel): @field_validator("tools", mode="before") @classmethod def validate_tools(cls, v: list[str]) -> dict[str, Tool]: + # If `v` is already a dictionary (e.g., during deserialization), return it as is. + if isinstance(v, dict): + return v + # One can use special symbol [""] to indicate use of all registered tools if v == [""]: return TOOL_REGISTRY.get_all_tools() diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 42905c649..65bfa480d 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -28,6 +28,7 @@ import time import traceback from asyncio import iscoroutinefunction from datetime import datetime +from functools import partial from io import BytesIO from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -577,13 +578,32 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]: return data -def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4): +def handle_unknown_serialization(x: Any) -> str: + """For `to_jsonable_python` debug, unknown values will be logged instead of raising an exception.""" + + if inspect.ismethod(x): + logger.error(f"Method: {x.__self__.__class__.__name__}.{x.__func__.__name__}") + elif inspect.isfunction(x): + logger.error(f"Function: {x.__name__}") + elif hasattr(x, "__class__"): + logger.error(f"Instance of: {x.__class__.__name__}") + elif hasattr(x, "__name__"): + logger.error(f"Class or module: {x.__name__}") + else: + logger.error(f"Unknown type: {type(x)}") + + return f"" + + +def write_json_file(json_file: str, data: Any, encoding: str = None, indent: int = 4, use_fallback: bool = False): folder_path = Path(json_file).parent if not folder_path.exists(): folder_path.mkdir(parents=True, exist_ok=True) + custom_default = partial(to_jsonable_python, fallback=handle_unknown_serialization if use_fallback else None) + with open(json_file, "w", encoding=encoding) as fout: - json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python) + json.dump(data, fout, ensure_ascii=False, indent=indent, default=custom_default) def read_csv_to_list(curr_file: str, header=False, strip_trail=True): diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 48f13f4a2..bc2bdd02a 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -9,13 +9,15 @@ """ import json +from typing import Annotated import pytest +from pydantic import BaseModel, Field from metagpt.actions import Action from metagpt.actions.action_node import ActionNode from metagpt.actions.write_code import WriteCode -from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO +from metagpt.const import SERDESER_PATH, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO from metagpt.schema import ( AIMessage, CodeSummarizeContext, @@ -23,6 +25,7 @@ from metagpt.schema import ( Message, MessageQueue, Plan, + SerializationMixin, SystemMessage, Task, UMLClassAttribute, @@ -398,5 +401,64 @@ def test_create_instruct_value(name, value): assert obj.model_dump() == value +class TestUserModel(SerializationMixin, BaseModel): + name: str + value: int + + +class TestUserModelWithExclude(TestUserModel): + age: Annotated[int, Field(exclude=True)] + + +class TestSerializationMixin: + @pytest.fixture + def mock_write_json_file(self, mocker): + return mocker.patch("metagpt.schema.write_json_file") + + @pytest.fixture + def mock_read_json_file(self, mocker): + return mocker.patch("metagpt.schema.read_json_file") + + @pytest.fixture + def mock_user_model(self): + return TestUserModel(name="test", value=42) + + def test_serialize(self, mock_write_json_file, mock_user_model): + file_path = "test.json" + + mock_user_model.serialize(file_path) + + mock_write_json_file.assert_called_once_with(file_path, mock_user_model.model_dump()) + + def test_deserialize(self, mock_read_json_file): + file_path = "test.json" + data = {"name": "test", "value": 42} + mock_read_json_file.return_value = data + + model = TestUserModel.deserialize(file_path) + + mock_read_json_file.assert_called_once_with(file_path) + assert model == TestUserModel(**data) + + def test_serialize_with_exclude(self, mock_write_json_file): + model = TestUserModelWithExclude(name="test", value=42, age=10) + file_path = "test.json" + + model.serialize(file_path) + + expected_data = { + "name": "test", + "value": 42, + "__module_class_name": "tests.metagpt.test_schema.TestUserModelWithExclude", + } + + mock_write_json_file.assert_called_once_with(file_path, expected_data) + + def test_get_serialization_path(self): + expected_path = str(SERDESER_PATH / "TestUserModel.json") + + assert TestUserModel.get_serialization_path() == expected_path + + if __name__ == "__main__": pytest.main([__file__, "-s"])