From 98ac5fbce3d744d4c3a91403d73cf2e0bbe7cf8b Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Fri, 9 Aug 2024 10:31:03 +0800 Subject: [PATCH 01/20] serialize mgxenv --- metagpt/environment/mgx/mgx_env.py | 66 ++++++++++++++++++++++-- metagpt/roles/di/role_zero.py | 9 ++-- metagpt/roles/di/swe_agent.py | 3 +- metagpt/strategy/experience_retriever.py | 2 +- metagpt/tools/libs/browser.py | 24 +++++---- metagpt/tools/libs/editor.py | 11 ++-- metagpt/tools/tool_recommend.py | 4 ++ metagpt/utils/common.py | 22 +++++++- metagpt/utils/serialize.py | 38 +++++++++++++- tests/metagpt/utils/test_serialize.py | 36 +++++++++++++ 10 files changed, 185 insertions(+), 30 deletions(-) diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py index 873358252..156b08267 100644 --- a/metagpt/environment/mgx/mgx_env.py +++ b/metagpt/environment/mgx/mgx_env.py @@ -1,3 +1,5 @@ +from typing import ClassVar + from metagpt.actions import ( UserRequirement, WriteDesign, @@ -6,12 +8,14 @@ from metagpt.actions import ( WriteTest, ) from metagpt.actions.summarize_code import SummarizeCode -from metagpt.const import AGENT +from metagpt.const import AGENT, SERDESER_PATH from metagpt.environment.base_env import Environment -from metagpt.logs import get_human_input +from metagpt.logs import get_human_input, logger from metagpt.roles import Architect, ProductManager, ProjectManager, Role from metagpt.schema import Message from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.exceptions import handle_exception +from metagpt.utils.serialize import deserialize_model, serialize_model class MGXEnv(Environment): @@ -22,6 +26,11 @@ class MGXEnv(Environment): direct_chat_roles: set[str] = set() # record direct chat: @role_name + default_serialization_path: ClassVar[str] = str(SERDESER_PATH / "mgxenv" / "mgxenv.json") + + def __repr__(self): + return "MGXEnv()" + def _publish_message(self, message: Message, peekable: bool = True) -> bool: return 
super().publish_message(message, peekable) @@ -121,5 +130,54 @@ class MGXEnv(Environment): converted_msg.content = f"from {sent_from} to {converted_msg.send_to}: {converted_msg.content}" return converted_msg - def __repr__(self): - return "MGXEnv()" + @handle_exception + def serialize(self, file_path: str = None) -> str: + """Serializes the current instance to a JSON file. + + If an exception occurs, `handle_exception` will catch it and return `None`. + + Args: + file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None. + + Returns: + str: The path to the JSON file where the instance was saved. + """ + + file_path = file_path or self.default_serialization_path + + serialize_model(self, file_path, remove_unserializable=self.remove_unserializable) + logger.info(f"MGXEnv serialization successful. File saved at: {file_path}") + + return file_path + + @classmethod + @handle_exception + def deserialize(cls, file_path: str = None) -> "MGXEnv": + """Deserializes a JSON file to an instance of MGXEnv. + + If an exception occurs, `handle_exception` will catch it and return `None`. + + Args: + file_path (str, optional): The path to the JSON file to read from. Defaults to None. + + Returns: + MGXEnv: An instance of MGXEnv. + """ + + file_path = file_path or cls.default_serialization_path + + model = deserialize_model(cls, file_path) + logger.info(f"MGXEnv deserialization successful. Instance created from file: {file_path}") + + return model + + def remove_unserializable(self, data: dict): + """Removes unserializable content from the data dictionary. + + Args: + data (dict): The data dictionary to clean, obtained from Pydantic's model_dump method. 
+ """ + roles = data.get("roles", {}) + + for role in roles.values(): + [role.pop(key, None) for key in role.get("unserializable_fields", [])] diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 4e932649c..b1b524d21 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -4,7 +4,7 @@ import inspect import json import re import traceback -from typing import Callable, Dict, List, Literal, Tuple +from typing import Callable, Dict, List, Literal, Optional, Tuple from pydantic import model_validator @@ -46,11 +46,11 @@ class RoleZero(Role): name: str = "Zero" profile: str = "RoleZero" goal: str = "" - system_msg: list[str] = None # Use None to conform to the default value at llm.aask + system_msg: Optional[list[str]] = None # Use None to conform to the default value at llm.aask cmd_prompt: str = CMD_PROMPT thought_guidance: str = THOUGHT_GUIDANCE instruction: str = ROLE_INSTRUCTION - task_type_desc: str = None + task_type_desc: Optional[str] = None # React Mode react_mode: Literal["react"] = "react" @@ -58,7 +58,7 @@ class RoleZero(Role): # Tools tools: list[str] = [] # Use special symbol [""] to indicate use of all registered tools - tool_recommender: ToolRecommender = None + tool_recommender: Optional[ToolRecommender] = None tool_execution_map: dict[str, Callable] = {} special_tool_commands: list[str] = ["Plan.finish_current_task", "end", "Bash.run"] # Equipped with three basic tools by default for optional use @@ -74,6 +74,7 @@ class RoleZero(Role): memory_k: int = 20 # number of memories (messages) to use as historical context use_fixed_sop: bool = False requirements_constraints: str = "" # the constraints in user requirements + unserializable_fields: list[str] = ["tool_execution_map", "experience_retriever"] @model_validator(mode="after") def set_plan_and_tool(self) -> "RoleZero": diff --git a/metagpt/roles/di/swe_agent.py b/metagpt/roles/di/swe_agent.py index 2384ac147..d85a3863b 100644 --- 
a/metagpt/roles/di/swe_agent.py +++ b/metagpt/roles/di/swe_agent.py @@ -1,4 +1,5 @@ import json +from typing import Optional from pydantic import Field @@ -17,7 +18,7 @@ class SWEAgent(RoleZero): name: str = "Swen" profile: str = "Issue Solver" goal: str = "Resolve GitHub issue or bug in any existing codebase" - system_msg: str = [SWE_AGENT_SYSTEM_TEMPLATE] + system_msg: Optional[list[str]] = [SWE_AGENT_SYSTEM_TEMPLATE] _instruction: str = NEXT_STEP_TEMPLATE tools: list[str] = [ "Bash", diff --git a/metagpt/strategy/experience_retriever.py b/metagpt/strategy/experience_retriever.py index 9587ef9f8..46d0fd862 100644 --- a/metagpt/strategy/experience_retriever.py +++ b/metagpt/strategy/experience_retriever.py @@ -6,7 +6,7 @@ from pydantic import BaseModel class ExpRetriever(BaseModel): """interface for experience retriever""" - def retrieve(self, context: str) -> str: + def retrieve(self, context: str = "") -> str: raise NotImplementedError diff --git a/metagpt/tools/libs/browser.py b/metagpt/tools/libs/browser.py index a458109e6..3efddd2e8 100644 --- a/metagpt/tools/libs/browser.py +++ b/metagpt/tools/libs/browser.py @@ -12,6 +12,7 @@ from playwright.async_api import ( Request, async_playwright, ) +from pydantic import BaseModel, ConfigDict, Field from metagpt.tools.tool_registry import register_tool from metagpt.utils.a11y_tree import ( @@ -43,7 +44,7 @@ from metagpt.utils.report import BrowserReporter "type", ], ) -class Browser: +class Browser(BaseModel): """A tool for browsing the web. Don't initialize a new instance of this class if one already exists. 
Note: If you plan to use the browser to assist you in completing tasks, then using the browser should be a standalone @@ -66,16 +67,17 @@ class Browser: >>> await browser.close_tab() """ - def __init__(self): - self.playwright: Optional[Playwright] = None - self.browser_instance: Optional[Browser_] = None - self.browser_ctx: Optional[BrowserContext] = None - self.page: Optional[Page] = None - self.accessibility_tree: list = [] - self.headless: bool = True - self.proxy = get_proxy_from_env() - self.is_empty_page = True - self.reporter = BrowserReporter() + model_config = ConfigDict(arbitrary_types_allowed=True) + + playwright: Optional[Playwright] = None + browser_instance: Optional[Browser_] = None + browser_ctx: Optional[BrowserContext] = None + page: Optional[Page] = None + accessibility_tree: list = Field(default_factory=list) + headless: bool = True + proxy: Optional[str] = Field(default_factory=get_proxy_from_env) + is_empty_page: bool = True + reporter: BrowserReporter = Field(default_factory=BrowserReporter) async def start(self) -> None: """Starts Playwright and launches a browser""" diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index b964a2741..149482ab2 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -2,9 +2,8 @@ import os import shutil import subprocess -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict -from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.tools.tool_registry import register_tool from metagpt.utils.report import EditorReporter @@ -17,12 +16,12 @@ class FileBlock(BaseModel): @register_tool() -class Editor: +class Editor(BaseModel): """A tool for reading, understanding, writing, and editing files""" - def __init__(self) -> None: - print(f"Editor initialized with root path at: {DEFAULT_WORKSPACE_ROOT}") - self.resource = EditorReporter() + model_config = ConfigDict(arbitrary_types_allowed=True) + + resource: EditorReporter = EditorReporter() def 
write(self, path: str, content: str): """Write the whole content to a file. When used, make sure content arg contains the full content of the file.""" diff --git a/metagpt/tools/tool_recommend.py b/metagpt/tools/tool_recommend.py index cca5cb3ae..25f403c77 100644 --- a/metagpt/tools/tool_recommend.py +++ b/metagpt/tools/tool_recommend.py @@ -64,6 +64,10 @@ class ToolRecommender(BaseModel): @field_validator("tools", mode="before") @classmethod def validate_tools(cls, v: list[str]) -> dict[str, Tool]: + # If `v` is already a dictionary (e.g., during deserialization), return it as is. + if isinstance(v, dict): + return v + # One can use special symbol [""] to indicate use of all registered tools if v == [""]: return TOOL_REGISTRY.get_all_tools() diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 3eead9ed4..2f7867104 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -28,6 +28,7 @@ import time import traceback from asyncio import iscoroutinefunction from datetime import datetime +from functools import partial from io import BytesIO from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -577,13 +578,30 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]: return data -def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4): +def write_json_file(json_file: str, data: Any, encoding: str = None, indent: int = 4, use_fallback: bool = False): folder_path = Path(json_file).parent if not folder_path.exists(): folder_path.mkdir(parents=True, exist_ok=True) + # For debug, if use_fallback, unknown values will be logged instead of raising an exception. 
+ def fallback(x: Any) -> str: + tip = f"PydanticSerializationError occurred while processing file '{json_file}'" + + if inspect.ismethod(x): + logger.error(f"{tip}, Method: {x.__self__.__class__.__name__}.{x.__func__.__name__}") + elif inspect.isfunction(x): + logger.error(f"{tip}, Function: {x.__name__}") + elif hasattr(x, "__class__"): + logger.error(f"{tip}, Instance of: {x.__class__.__name__}") + elif hasattr(x, "__name__"): + logger.error(f"{tip}, Class or module: {x.__name__}") + else: + logger.error(f"{tip}, Unknown type: {type(x)}") + + custom_default = partial(to_jsonable_python, fallback=fallback if use_fallback else None) + with open(json_file, "w", encoding=encoding) as fout: - json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python) + json.dump(data, fout, ensure_ascii=False, indent=indent, default=custom_default) def read_csv_to_list(curr_file: str, header=False, strip_trail=True): diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index c6bd8ad75..814621377 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -4,8 +4,11 @@ import copy import pickle +from typing import Callable, Optional, Type -from metagpt.utils.common import import_class +from pydantic import BaseModel + +from metagpt.utils.common import import_class, read_json_file, write_json_file def actionoutout_schema_to_mapping(schema: dict) -> dict: @@ -81,3 +84,36 @@ def deserialize_message(message_ser: str) -> "Message": message.instruct_content = ic_new return message + + +def serialize_model(model: BaseModel, file_path: str, remove_unserializable: Optional[Callable[[dict], None]] = None): + """Serializes a Pydantic model to a JSON file. + + Args: + model (BaseModel): The Pydantic model to serialize. + file_path (str): The path to the JSON file where the model will be saved. + remove_unserializable (Optional[Callable[[dict], None]]): Optional function to remove unserializable content from the serialized data. 
+ """ + + serialized_data = model.model_dump() + + if remove_unserializable: + remove_unserializable(serialized_data) + + write_json_file(file_path, serialized_data) + + +def deserialize_model(cls: Type[BaseModel], file_path: str) -> BaseModel: + """Deserializes a JSON file to a Pydantic model. + + Args: + cls (Type[BaseModel]): The Pydantic model class to deserialize into. + file_path (str): The path to the JSON file to read from. + + Returns: + BaseModel: An instance of the Pydantic model. + """ + + data: dict = read_json_file(file_path) + + return cls(**data) diff --git a/tests/metagpt/utils/test_serialize.py b/tests/metagpt/utils/test_serialize.py index 0ba3a8d41..3b20f3fa0 100644 --- a/tests/metagpt/utils/test_serialize.py +++ b/tests/metagpt/utils/test_serialize.py @@ -6,13 +6,17 @@ from typing import List +from pydantic import BaseModel + from metagpt.actions import WritePRD from metagpt.actions.action_node import ActionNode from metagpt.schema import Message from metagpt.utils.serialize import ( actionoutout_schema_to_mapping, deserialize_message, + deserialize_model, serialize_message, + serialize_model, ) @@ -66,3 +70,35 @@ def test_serialize_and_deserialize_message(): assert new_message.content == message.content assert new_message.cause_by == message.cause_by assert new_message.instruct_content.field1 == out_data["field1"] + + +class TestUserModel(BaseModel): + name: str + value: int + + +def test_serialize_model(mocker): + model = TestUserModel(name="test", value=42) + file_path = "test.json" + mock_write_json_file = mocker.patch("metagpt.utils.serialize.write_json_file") + + # Test without remove_unserializable + serialize_model(model, file_path) + mock_write_json_file.assert_called_once_with(file_path, model.model_dump()) + + # Test with remove_unserializable + def remove_unserializable(data: dict): + data.pop("value", None) + + serialize_model(model, file_path, remove_unserializable) + mock_write_json_file.assert_called_with(file_path, {"name": 
"test"}) + + +def test_deserialize_model(mocker): + file_path = "test.json" + data = {"name": "test", "value": 42} + mock_read_json_file = mocker.patch("metagpt.utils.serialize.read_json_file", return_value=data) + + model = deserialize_model(TestUserModel, file_path) + mock_read_json_file.assert_called_once_with(file_path) + assert model == TestUserModel(**data) From d98d36fec5b367c1fcf29cbfc50e0bc7739fb62d Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Fri, 9 Aug 2024 10:34:20 +0800 Subject: [PATCH 02/20] serialize mgxenv --- metagpt/environment/mgx/mgx_env.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py index 156b08267..517e185d9 100644 --- a/metagpt/environment/mgx/mgx_env.py +++ b/metagpt/environment/mgx/mgx_env.py @@ -28,9 +28,6 @@ class MGXEnv(Environment): default_serialization_path: ClassVar[str] = str(SERDESER_PATH / "mgxenv" / "mgxenv.json") - def __repr__(self): - return "MGXEnv()" - def _publish_message(self, message: Message, peekable: bool = True) -> bool: return super().publish_message(message, peekable) @@ -130,6 +127,9 @@ class MGXEnv(Environment): converted_msg.content = f"from {sent_from} to {converted_msg.send_to}: {converted_msg.content}" return converted_msg + def __repr__(self): + return "MGXEnv()" + @handle_exception def serialize(self, file_path: str = None) -> str: """Serializes the current instance to a JSON file. 
From 887f180e58f338025eb8bedcd16c662215a58b30 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Fri, 9 Aug 2024 16:38:38 +0800 Subject: [PATCH 03/20] create new func `handle_unknown_serialization` --- metagpt/utils/common.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 7b3cb5224..cf938cb3d 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -579,27 +579,29 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]: return data +def handle_unknown_serialization(x: Any) -> str: + """For `to_jsonable_python` debug, unknown values will be logged instead of raising an exception.""" + + if inspect.ismethod(x): + logger.error(f"Method: {x.__self__.__class__.__name__}.{x.__func__.__name__}") + elif inspect.isfunction(x): + logger.error(f"Function: {x.__name__}") + elif hasattr(x, "__class__"): + logger.error(f"Instance of: {x.__class__.__name__}") + elif hasattr(x, "__name__"): + logger.error(f"Class or module: {x.__name__}") + else: + logger.error(f"Unknown type: {type(x)}") + + return f"" + + def write_json_file(json_file: str, data: Any, encoding: str = None, indent: int = 4, use_fallback: bool = False): folder_path = Path(json_file).parent if not folder_path.exists(): folder_path.mkdir(parents=True, exist_ok=True) - # For debug, if use_fallback, unknown values will be logged instead of raising an exception. 
- def fallback(x: Any) -> str: - tip = f"PydanticSerializationError occurred while processing file '{json_file}'" - - if inspect.ismethod(x): - logger.error(f"{tip}, Method: {x.__self__.__class__.__name__}.{x.__func__.__name__}") - elif inspect.isfunction(x): - logger.error(f"{tip}, Function: {x.__name__}") - elif hasattr(x, "__class__"): - logger.error(f"{tip}, Instance of: {x.__class__.__name__}") - elif hasattr(x, "__name__"): - logger.error(f"{tip}, Class or module: {x.__name__}") - else: - logger.error(f"{tip}, Unknown type: {type(x)}") - - custom_default = partial(to_jsonable_python, fallback=fallback if use_fallback else None) + custom_default = partial(to_jsonable_python, fallback=handle_unknown_serialization if use_fallback else None) with open(json_file, "w", encoding=encoding) as fout: json.dump(data, fout, ensure_ascii=False, indent=indent, default=custom_default) From 43ffa3558b89134c7e0c05445b5e55f160546b7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E4=BC=9F=E9=9F=AC?= Date: Mon, 12 Aug 2024 13:52:32 +0800 Subject: [PATCH 04/20] add repair_escape_error function to parse commands --- metagpt/prompts/di/role_zero.py | 3 +++ metagpt/roles/di/role_zero.py | 36 ++++++++++++++++++++++++++++++--- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/metagpt/prompts/di/role_zero.py b/metagpt/prompts/di/role_zero.py index 8443a7960..25ca4637a 100644 --- a/metagpt/prompts/di/role_zero.py +++ b/metagpt/prompts/di/role_zero.py @@ -100,6 +100,9 @@ JSON_REPAIR_PROMPT = """ ## json data {json_data} +## json decode error +{json_decode_error} + ## Output Format ```json diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 773124dcc..e483f03cc 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -311,10 +311,20 @@ class RoleZero(Role): if commands.endswith("]") and not commands.startswith("["): commands = "[" + commands commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], 
repair_type=RepairType.JSON)) - except json.JSONDecodeError: + except json.JSONDecodeError as e: logger.warning(f"Failed to parse JSON for: {self.command_rsp}. Trying to repair...") - commands = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp)) - commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands)) + commands = await self.llm.aask( + msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp, json_decode_error=str(e)) + ) + try: + commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands)) + except json.JSONDecodeError: + # repair escape error of code and math + commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp) + new_command = self.repair_escape_error(commands) + commands = json.loads( + repair_llm_raw_output(output=new_command, req_keys=[None], repair_type=RepairType.JSON) + ) except Exception as e: tb = traceback.format_exc() print(tb) @@ -327,6 +337,26 @@ class RoleZero(Role): commands = commands["commands"] if "commands" in commands else [commands] return commands, True + def repair_escape_error(self, commands): + """Repairs escape errors in command responses""" + escape_repair_map = { + "\a": "\\\\a", + "\b": "\\\\b", + "\f": "\\\\f", + "\r": "\\\\r", + "\t": "\\\\t", + "\v": "\\\\v", + } + new_command = "" + for index, ch in enumerate(commands): + if ch == "\\" and index + 1 < len(commands): + if commands[index + 1] not in ["n", '"', " "]: + new_command += "\\" + elif ch in escape_repair_map: + ch = escape_repair_map[ch] + new_command += ch + return commands + async def _run_commands(self, commands) -> str: outputs = [] for cmd in commands: From 587dd0cc81d06aa039ff942e2f23162498d77642 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Mon, 12 Aug 2024 14:37:00 +0800 Subject: [PATCH 05/20] use serialize in SerializationMixin --- metagpt/environment/mgx/mgx_env.py | 56 ++------------------ metagpt/schema.py | 74 +++++++++++++++++++++++++++ 
metagpt/utils/serialize.py | 38 +------------- tests/metagpt/test_schema.py | 71 ++++++++++++++++++++++++- tests/metagpt/utils/test_serialize.py | 36 ------------- 5 files changed, 150 insertions(+), 125 deletions(-) diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py index 43ef9c4b5..99f94052a 100644 --- a/metagpt/environment/mgx/mgx_env.py +++ b/metagpt/environment/mgx/mgx_env.py @@ -1,5 +1,3 @@ -from typing import ClassVar - from metagpt.actions import ( UserRequirement, WriteDesign, @@ -8,17 +6,15 @@ from metagpt.actions import ( WriteTest, ) from metagpt.actions.summarize_code import SummarizeCode -from metagpt.const import AGENT, SERDESER_PATH +from metagpt.const import AGENT from metagpt.environment.base_env import Environment -from metagpt.logs import get_human_input, logger +from metagpt.logs import get_human_input from metagpt.roles import Architect, ProductManager, ProjectManager, Role -from metagpt.schema import Message +from metagpt.schema import Message, SerializationMixin from metagpt.utils.common import any_to_str, any_to_str_set -from metagpt.utils.exceptions import handle_exception -from metagpt.utils.serialize import deserialize_model, serialize_model -class MGXEnv(Environment): +class MGXEnv(Environment, SerializationMixin): """MGX Environment""" # If True, fixed software sop bypassing TL is allowed, otherwise, TL will fully take over the routing @@ -26,8 +22,6 @@ class MGXEnv(Environment): direct_chat_roles: set[str] = set() # record direct chat: @role_name - default_serialization_path: ClassVar[str] = str(SERDESER_PATH / "mgxenv" / "mgxenv.json") - def _publish_message(self, message: Message, peekable: bool = True) -> bool: return super().publish_message(message, peekable) @@ -132,53 +126,13 @@ class MGXEnv(Environment): def __repr__(self): return "MGXEnv()" - @handle_exception - def serialize(self, file_path: str = None) -> str: - """Serializes the current instance to a JSON file. 
- - If an exception occurs, `handle_exception` will catch it and return `None`. - - Args: - file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None. - - Returns: - str: The path to the JSON file where the instance was saved. - """ - - file_path = file_path or self.default_serialization_path - - serialize_model(self, file_path, remove_unserializable=self.remove_unserializable) - logger.info(f"MGXEnv serialization successful. File saved at: {file_path}") - - return file_path - - @classmethod - @handle_exception - def deserialize(cls, file_path: str = None) -> "MGXEnv": - """Deserializes a JSON file to an instance of MGXEnv. - - If an exception occurs, `handle_exception` will catch it and return `None`. - - Args: - file_path (str, optional): The path to the JSON file to read from. Defaults to None. - - Returns: - MGXEnv: An instance of MGXEnv. - """ - - file_path = file_path or cls.default_serialization_path - - model = deserialize_model(cls, file_path) - logger.info(f"MGXEnv deserialization successful. Instance created from file: {file_path}") - - return model - def remove_unserializable(self, data: dict): """Removes unserializable content from the data dictionary. Args: data (dict): The data dictionary to clean, obtained from Pydantic's model_dump method. 
""" + roles = data.get("roles", {}) for role in roles.values(): diff --git a/metagpt/schema.py b/metagpt/schema.py index 648e2bd73..2431304db 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -44,6 +44,7 @@ from metagpt.const import ( MESSAGE_ROUTE_FROM, MESSAGE_ROUTE_TO, MESSAGE_ROUTE_TO_ALL, + SERDESER_PATH, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO, ) @@ -56,6 +57,8 @@ from metagpt.utils.common import ( any_to_str_set, aread, import_class, + read_json_file, + write_json_file, ) from metagpt.utils.exceptions import handle_exception from metagpt.utils.report import TaskReporter @@ -127,6 +130,77 @@ class SerializationMixin(BaseModel, extra="forbid"): cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls super().__init_subclass__(**kwargs) + @handle_exception + def serialize(self, file_path: str = None) -> str: + """Serializes the current instance to a JSON file. + + If an exception occurs, `handle_exception` will catch it and return `None`. + + Args: + file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None. + + Returns: + str: The path to the JSON file where the instance was saved. + """ + + file_path = file_path or self.get_serialization_path() + + serialized_data = self.model_dump() + self.remove_unserializable(serialized_data) + + write_json_file(file_path, serialized_data) + logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}") + + return file_path + + @classmethod + @handle_exception + def deserialize(cls, file_path: str = None) -> BaseModel: + """Deserializes a JSON file to an instance of cls. + + If an exception occurs, `handle_exception` will catch it and return `None`. + + Args: + file_path (str, optional): The path to the JSON file to read from. Defaults to None. + + Returns: + An instance of the cls. 
+ """ + + file_path = file_path or cls.get_serialization_path() + + data: dict = read_json_file(file_path) + + model = cls(**data) + logger.info(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}") + + return model + + @classmethod + def get_serialization_path(cls) -> str: + """Get the serialization path for the class. + + This method constructs a file path for serialization based on the class name. + The default path is constructed as './workspace/storage/ClassName.json', where 'ClassName' + is the name of the class. + + Returns: + str: The path to the serialization file. + """ + + return str(SERDESER_PATH / f"{cls.__qualname__}.json") + + def remove_unserializable(self, data: dict): + """Removes unserializable content from the data dictionary. + + This method removes keys specified in the "unserializable_fields" list from the provided data dictionary. + It is intended to clean the dictionary obtained from Pydantic's `model_dump` method by removing fields + that cannot be serialized. + """ + + for key in data.get("unserializable_fields", []): + data.pop(key, None) + class SimpleMessage(BaseModel): content: str diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 814621377..c6bd8ad75 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -4,11 +4,8 @@ import copy import pickle -from typing import Callable, Optional, Type -from pydantic import BaseModel - -from metagpt.utils.common import import_class, read_json_file, write_json_file +from metagpt.utils.common import import_class def actionoutout_schema_to_mapping(schema: dict) -> dict: @@ -84,36 +81,3 @@ def deserialize_message(message_ser: str) -> "Message": message.instruct_content = ic_new return message - - -def serialize_model(model: BaseModel, file_path: str, remove_unserializable: Optional[Callable[[dict], None]] = None): - """Serializes a Pydantic model to a JSON file. 
- - Args: - model (BaseModel): The Pydantic model to serialize. - file_path (str): The path to the JSON file where the model will be saved. - remove_unserializable (Optional[Callable[[dict], None]]): Optional function to remove unserializable content from the serialized data. - """ - - serialized_data = model.model_dump() - - if remove_unserializable: - remove_unserializable(serialized_data) - - write_json_file(file_path, serialized_data) - - -def deserialize_model(cls: Type[BaseModel], file_path: str) -> BaseModel: - """Deserializes a JSON file to a Pydantic model. - - Args: - cls (Type[BaseModel]): The Pydantic model class to deserialize into. - file_path (str): The path to the JSON file to read from. - - Returns: - BaseModel: An instance of the Pydantic model. - """ - - data: dict = read_json_file(file_path) - - return cls(**data) diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 48f13f4a2..12e0c7aab 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -11,11 +11,12 @@ import json import pytest +from pydantic import BaseModel from metagpt.actions import Action from metagpt.actions.action_node import ActionNode from metagpt.actions.write_code import WriteCode -from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO +from metagpt.const import SERDESER_PATH, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO from metagpt.schema import ( AIMessage, CodeSummarizeContext, @@ -23,6 +24,7 @@ from metagpt.schema import ( Message, MessageQueue, Plan, + SerializationMixin, SystemMessage, Task, UMLClassAttribute, @@ -398,5 +400,72 @@ def test_create_instruct_value(name, value): assert obj.model_dump() == value +class TestUserModel(SerializationMixin, BaseModel): + name: str + value: int + + +class TestUserModelWithRemove(TestUserModel): + def remove_unserializable(self, data: dict): + for key in ["value", "__module_class_name"]: + data.pop(key, None) + + +class TestSerializationMixin: + @pytest.fixture + def 
mock_write_json_file(self, mocker): + return mocker.patch("metagpt.schema.write_json_file") + + @pytest.fixture + def mock_read_json_file(self, mocker): + return mocker.patch("metagpt.schema.read_json_file") + + @pytest.fixture + def mock_user_model(self): + return TestUserModel(name="test", value=42) + + def test_serialize(self, mock_write_json_file, mock_user_model): + file_path = "test.json" + + mock_user_model.serialize(file_path) + + mock_write_json_file.assert_called_once_with(file_path, mock_user_model.model_dump()) + + def test_deserialize(self, mock_read_json_file): + file_path = "test.json" + data = {"name": "test", "value": 42} + mock_read_json_file.return_value = data + + model = TestUserModel.deserialize(file_path) + + mock_read_json_file.assert_called_once_with(file_path) + assert model == TestUserModel(**data) + + def test_serialize_with_remove_unserializable(self, mock_write_json_file): + model = TestUserModelWithRemove(name="test", value=42) + file_path = "test.json" + + model.serialize(file_path) + + mock_write_json_file.assert_called_once_with(file_path, {"name": "test"}) + + def test_get_serialization_path(self): + expected_path = str(SERDESER_PATH / "TestUserModel.json") + + assert TestUserModel.get_serialization_path() == expected_path + + def test_remove_unserializable(self, mock_user_model): + data = { + "name": "example", + "unserializable_fields": ["temp_data", "debug_info"], + "temp_data": "some temporary data", + "debug_info": "some debug information", + } + mock_user_model.remove_unserializable(data) + + expected_data = {"name": "example", "unserializable_fields": ["temp_data", "debug_info"]} + assert data == expected_data + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_serialize.py b/tests/metagpt/utils/test_serialize.py index 3b20f3fa0..0ba3a8d41 100644 --- a/tests/metagpt/utils/test_serialize.py +++ b/tests/metagpt/utils/test_serialize.py @@ -6,17 +6,13 @@ from typing import List 
-from pydantic import BaseModel - from metagpt.actions import WritePRD from metagpt.actions.action_node import ActionNode from metagpt.schema import Message from metagpt.utils.serialize import ( actionoutout_schema_to_mapping, deserialize_message, - deserialize_model, serialize_message, - serialize_model, ) @@ -70,35 +66,3 @@ def test_serialize_and_deserialize_message(): assert new_message.content == message.content assert new_message.cause_by == message.cause_by assert new_message.instruct_content.field1 == out_data["field1"] - - -class TestUserModel(BaseModel): - name: str - value: int - - -def test_serialize_model(mocker): - model = TestUserModel(name="test", value=42) - file_path = "test.json" - mock_write_json_file = mocker.patch("metagpt.utils.serialize.write_json_file") - - # Test without remove_unserializable - serialize_model(model, file_path) - mock_write_json_file.assert_called_once_with(file_path, model.model_dump()) - - # Test with remove_unserializable - def remove_unserializable(data: dict): - data.pop("value", None) - - serialize_model(model, file_path, remove_unserializable) - mock_write_json_file.assert_called_with(file_path, {"name": "test"}) - - -def test_deserialize_model(mocker): - file_path = "test.json" - data = {"name": "test", "value": 42} - mock_read_json_file = mocker.patch("metagpt.utils.serialize.read_json_file", return_value=data) - - model = deserialize_model(TestUserModel, file_path) - mock_read_json_file.assert_called_once_with(file_path) - assert model == TestUserModel(**data) From 31764add8a6065316b02df6fdbc0c55f6938d1a0 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Mon, 12 Aug 2024 14:56:15 +0800 Subject: [PATCH 06/20] add an example --- examples/serialize_model.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 examples/serialize_model.py diff --git a/examples/serialize_model.py b/examples/serialize_model.py new file mode 100644 index 000000000..b81afc743 --- /dev/null +++ 
b/examples/serialize_model.py @@ -0,0 +1,24 @@ +from metagpt.logs import logger +from metagpt.roles.product_manager import ProductManager + + +def main(): + """Demonstrates serialization and deserialization using SerializationMixin. + + This example creates an instance of ProductManager, serializes it to a file, + and then deserializes it back to an instance. + + If executed correctly, the following log messages will be output: + ProductManager serialization successful. File saved at: /data/hjt/gitlab_metagpt/workspace/storage/ProductManager.json + ProductManager deserialization successful. Instance created from file: /data/hjt/gitlab_metagpt/workspace/storage/ProductManager.json + The role is Product Manager + """ + role = ProductManager() + role.serialize() + + role: ProductManager = ProductManager.deserialize() + logger.info(f"The role is {role.profile}") + + +if __name__ == "__main__": + main() From 9d40d96535e8325a416c086aa619901b112d8168 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Mon, 12 Aug 2024 15:04:56 +0800 Subject: [PATCH 07/20] add an example --- examples/serialize_model.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/examples/serialize_model.py b/examples/serialize_model.py index b81afc743..2e4431ed2 100644 --- a/examples/serialize_model.py +++ b/examples/serialize_model.py @@ -1,23 +1,24 @@ +from metagpt.environment.mgx.mgx_env import MGXEnv from metagpt.logs import logger -from metagpt.roles.product_manager import ProductManager def main(): """Demonstrates serialization and deserialization using SerializationMixin. - This example creates an instance of ProductManager, serializes it to a file, + This example creates an instance of MGXEnv, serializes it to a file, and then deserializes it back to an instance. If executed correctly, the following log messages will be output: - ProductManager serialization successful. 
File saved at: /data/hjt/gitlab_metagpt/workspace/storage/ProductManager.json - ProductManager deserialization successful. Instance created from file: /data/hjt/gitlab_metagpt/workspace/storage/ProductManager.json - The role is Product Manager + MGXEnv serialization successful. File saved at: /.../workspace/storage/MGXEnv.json + MGXEnv deserialization successful. Instance created from file: /.../workspace/storage/MGXEnv.json + The object is MGXEnv() """ - role = ProductManager() - role.serialize() - role: ProductManager = ProductManager.deserialize() - logger.info(f"The role is {role.profile}") + env = MGXEnv() + env.serialize() + + env: MGXEnv = MGXEnv.deserialize() + logger.info(f"The object is {repr(env)}") if __name__ == "__main__": From 11894d12f3f056053c6c7743c0e6b15ccb643650 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Mon, 12 Aug 2024 15:19:36 +0800 Subject: [PATCH 08/20] add an example --- examples/serialize_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/serialize_model.py b/examples/serialize_model.py index 2e4431ed2..2423efef8 100644 --- a/examples/serialize_model.py +++ b/examples/serialize_model.py @@ -11,14 +11,14 @@ def main(): If executed correctly, the following log messages will be output: MGXEnv serialization successful. File saved at: /.../workspace/storage/MGXEnv.json MGXEnv deserialization successful. 
Instance created from file: /.../workspace/storage/MGXEnv.json - The object is MGXEnv() + The instance is MGXEnv() """ env = MGXEnv() env.serialize() env: MGXEnv = MGXEnv.deserialize() - logger.info(f"The object is {repr(env)}") + logger.info(f"The instance is {repr(env)}") if __name__ == "__main__": From 9deb9f35127925aeca9dab6b8e3fc6cea3b11859 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Mon, 12 Aug 2024 16:40:27 +0800 Subject: [PATCH 09/20] add url params for SERPAPI/SERPER search engine --- metagpt/tools/search_engine_serpapi.py | 19 ++++++++----------- metagpt/tools/search_engine_serper.py | 15 ++++++--------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/metagpt/tools/search_engine_serpapi.py b/metagpt/tools/search_engine_serpapi.py index 5744b1b62..b3ccb0649 100644 --- a/metagpt/tools/search_engine_serpapi.py +++ b/metagpt/tools/search_engine_serpapi.py @@ -6,7 +6,7 @@ @File : search_engine_serpapi.py """ import warnings -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional import aiohttp from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -24,6 +24,7 @@ class SerpAPIWrapper(BaseModel): "hl": "en", } ) + url: str = "https://serpapi.com/search" aiosession: Optional[aiohttp.ClientSession] = None proxy: Optional[str] = None @@ -49,22 +50,18 @@ class SerpAPIWrapper(BaseModel): async def results(self, query: str, max_results: int) -> dict: """Use aiohttp to run query through SerpAPI and return the results async.""" - def construct_url_and_params() -> Tuple[str, Dict[str, str]]: - params = self.get_params(query) - params["source"] = "python" - params["num"] = max_results - params["output"] = "json" - url = "https://serpapi.com/search" - return url, params + params = self.get_params(query) + params["source"] = "python" + params["num"] = max_results + params["output"] = "json" - url, params = construct_url_and_params() if not self.aiosession: async with aiohttp.ClientSession() as session: - 
async with session.get(url, params=params, proxy=self.proxy) as response: + async with session.get(self.url, params=params, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() else: - async with self.aiosession.get(url, params=params, proxy=self.proxy) as response: + async with self.aiosession.get(self.url, params=params, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() diff --git a/metagpt/tools/search_engine_serper.py b/metagpt/tools/search_engine_serper.py index ba2fb4f93..bcb959ed3 100644 --- a/metagpt/tools/search_engine_serper.py +++ b/metagpt/tools/search_engine_serper.py @@ -7,7 +7,7 @@ """ import json import warnings -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional import aiohttp from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -17,6 +17,7 @@ class SerperWrapper(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) api_key: str + url: str = "https://google.serper.dev/search" payload: dict = Field(default_factory=lambda: {"page": 1, "num": 10}) aiosession: Optional[aiohttp.ClientSession] = None proxy: Optional[str] = None @@ -46,20 +47,16 @@ class SerperWrapper(BaseModel): async def results(self, queries: list[str], max_results: int = 8) -> dict: """Use aiohttp to run query through Serper and return the results async.""" - def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]: - payloads = self.get_payloads(queries, max_results) - url = "https://google.serper.dev/search" - headers = self.get_headers() - return url, payloads, headers + payloads = self.get_payloads(queries, max_results) + headers = self.get_headers() - url, payloads, headers = construct_url_and_payload_and_headers() if not self.aiosession: async with aiohttp.ClientSession() as session: - async with session.post(url, data=payloads, headers=headers, proxy=self.proxy) as response: + async with session.post(self.url, data=payloads, 
headers=headers, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() else: - async with self.aiosession.get.post(url, data=payloads, headers=headers, proxy=self.proxy) as response: + async with self.aiosession.get.post(self.url, data=payloads, headers=headers, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() From be6c3b445554bf4e8278392ab9935cc4cd127afd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E4=BC=9F=E9=9F=AC?= Date: Mon, 12 Aug 2024 16:53:48 +0800 Subject: [PATCH 10/20] =?UTF-8?q?=E8=B0=83=E6=95=B4repair=5Fescape=5Ferror?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E4=BD=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- metagpt/roles/di/role_zero.py | 28 ++++++-------------------- metagpt/utils/repair_llm_raw_output.py | 24 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 032bf8101..960dfa805 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -35,7 +35,11 @@ from metagpt.tools.libs.editor import Editor from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.tools.tool_registry import register_tool from metagpt.utils.common import CodeParser, any_to_str -from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output +from metagpt.utils.repair_llm_raw_output import ( + RepairType, + repair_escape_error, + repair_llm_raw_output, +) from metagpt.utils.report import ThoughtReporter @@ -326,7 +330,7 @@ class RoleZero(Role): except json.JSONDecodeError: # repair escape error of code and math commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp) - new_command = self.repair_escape_error(commands) + new_command = repair_escape_error(commands) commands = json.loads( repair_llm_raw_output(output=new_command, req_keys=[None], 
repair_type=RepairType.JSON) ) @@ -342,26 +346,6 @@ class RoleZero(Role): commands = commands["commands"] if "commands" in commands else [commands] return commands, True - def repair_escape_error(self, commands): - """Repaires escape errors in command responses""" - escape_repair_map = { - "\a": "\\\\a", - "\b": "\\\\b", - "\f": "\\\\f", - "\r": "\\\\r", - "\t": "\\\\t", - "\v": "\\\\v", - } - new_command = "" - for index, ch in enumerate(commands): - if ch == "\\" and index + 1 < len(commands): - if commands[index + 1] not in ["n", '"', " "]: - new_command += "\\" - elif ch in escape_repair_map: - ch = escape_repair_map[ch] - new_command += ch - return commands - async def _run_commands(self, commands) -> str: outputs = [] for cmd in commands: diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index 17e095c5f..fc27448eb 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -347,3 +347,27 @@ def extract_state_value_from_output(content: str) -> str: matches = list(set(matches)) state = matches[0] if len(matches) > 0 else "-1" return state + + +def repair_escape_error(commands): + """ + Repaires escape errors in command responses. + When role-zero parses a command, the command may contain unknown escape characters. 
+ """ + escape_repair_map = { + "\a": "\\\\a", + "\b": "\\\\b", + "\f": "\\\\f", + "\r": "\\\\r", + "\t": "\\\\t", + "\v": "\\\\v", + } + new_command = "" + for index, ch in enumerate(commands): + if ch == "\\" and index + 1 < len(commands): + if commands[index + 1] not in ["n", '"', " "]: + new_command += "\\" + elif ch in escape_repair_map: + ch = escape_repair_map[ch] + new_command += ch + return new_command From b4207cec923bd15b5d74ffbbf0e442a937af4aa3 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Mon, 12 Aug 2024 17:16:28 +0800 Subject: [PATCH 11/20] use pydantic's exclude --- metagpt/environment/mgx/mgx_env.py | 12 ----------- metagpt/roles/di/role_zero.py | 9 ++++---- metagpt/schema.py | 12 ----------- tests/metagpt/test_schema.py | 33 ++++++++++++------------------ 4 files changed, 17 insertions(+), 49 deletions(-) diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py index 99f94052a..fae386952 100644 --- a/metagpt/environment/mgx/mgx_env.py +++ b/metagpt/environment/mgx/mgx_env.py @@ -125,15 +125,3 @@ class MGXEnv(Environment, SerializationMixin): def __repr__(self): return "MGXEnv()" - - def remove_unserializable(self, data: dict): - """Removes unserializable content from the data dictionary. - - Args: - data (dict): The data dictionary to clean, obtained from Pydantic's model_dump method. 
- """ - - roles = data.get("roles", {}) - - for role in roles.values(): - [role.pop(key, None) for key in role.get("unserializable_fields", [])] diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 119fdf5a3..6802fdb00 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -4,9 +4,9 @@ import inspect import json import re import traceback -from typing import Callable, Dict, List, Literal, Optional, Tuple +from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple -from pydantic import model_validator +from pydantic import Field, model_validator from metagpt.actions import Action, UserRequirement from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions @@ -59,14 +59,14 @@ class RoleZero(Role): # Tools tools: list[str] = [] # Use special symbol [""] to indicate use of all registered tools tool_recommender: Optional[ToolRecommender] = None - tool_execution_map: dict[str, Callable] = {} + tool_execution_map: Annotated[dict[str, Callable], Field(exclude=True)] = {} special_tool_commands: list[str] = ["Plan.finish_current_task", "end", "Bash.run"] # Equipped with three basic tools by default for optional use editor: Editor = Editor() browser: Browser = Browser() # Experience - experience_retriever: ExpRetriever = DummyExpRetriever() + experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = DummyExpRetriever() # Others command_rsp: str = "" # the raw string containing the commands @@ -74,7 +74,6 @@ class RoleZero(Role): memory_k: int = 20 # number of memories (messages) to use as historical context use_fixed_sop: bool = False requirements_constraints: str = "" # the constraints in user requirements - unserializable_fields: list[str] = ["tool_execution_map", "experience_retriever"] @model_validator(mode="after") def set_plan_and_tool(self) -> "RoleZero": diff --git a/metagpt/schema.py b/metagpt/schema.py index 2431304db..ad8c8f1d7 100644 --- a/metagpt/schema.py +++ 
b/metagpt/schema.py @@ -146,7 +146,6 @@ class SerializationMixin(BaseModel, extra="forbid"): file_path = file_path or self.get_serialization_path() serialized_data = self.model_dump() - self.remove_unserializable(serialized_data) write_json_file(file_path, serialized_data) logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}") @@ -190,17 +189,6 @@ class SerializationMixin(BaseModel, extra="forbid"): return str(SERDESER_PATH / f"{cls.__qualname__}.json") - def remove_unserializable(self, data: dict): - """Removes unserializable content from the data dictionary. - - This method removes keys specified in the "unserializable_fields" list from the provided data dictionary. - It is intended to clean the dictionary obtained from Pydantic's `model_dump` method by removing fields - that cannot be serialized. - """ - - for key in data.get("unserializable_fields", []): - data.pop(key, None) - class SimpleMessage(BaseModel): content: str diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 12e0c7aab..bc2bdd02a 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -9,9 +9,10 @@ """ import json +from typing import Annotated import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field from metagpt.actions import Action from metagpt.actions.action_node import ActionNode @@ -405,10 +406,8 @@ class TestUserModel(SerializationMixin, BaseModel): value: int -class TestUserModelWithRemove(TestUserModel): - def remove_unserializable(self, data: dict): - for key in ["value", "__module_class_name"]: - data.pop(key, None) +class TestUserModelWithExclude(TestUserModel): + age: Annotated[int, Field(exclude=True)] class TestSerializationMixin: @@ -441,31 +440,25 @@ class TestSerializationMixin: mock_read_json_file.assert_called_once_with(file_path) assert model == TestUserModel(**data) - def test_serialize_with_remove_unserializable(self, mock_write_json_file): - model = 
TestUserModelWithRemove(name="test", value=42) + def test_serialize_with_exclude(self, mock_write_json_file): + model = TestUserModelWithExclude(name="test", value=42, age=10) file_path = "test.json" model.serialize(file_path) - mock_write_json_file.assert_called_once_with(file_path, {"name": "test"}) + expected_data = { + "name": "test", + "value": 42, + "__module_class_name": "tests.metagpt.test_schema.TestUserModelWithExclude", + } + + mock_write_json_file.assert_called_once_with(file_path, expected_data) def test_get_serialization_path(self): expected_path = str(SERDESER_PATH / "TestUserModel.json") assert TestUserModel.get_serialization_path() == expected_path - def test_remove_unserializable(self, mock_user_model): - data = { - "name": "example", - "unserializable_fields": ["temp_data", "debug_info"], - "temp_data": "some temporary data", - "debug_info": "some debug information", - } - mock_user_model.remove_unserializable(data) - - expected_data = {"name": "example", "unserializable_fields": ["temp_data", "debug_info"]} - assert data == expected_data - if __name__ == "__main__": pytest.main([__file__, "-s"]) From 2222cf33793577f222e1316b14803d5433cfce22 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Mon, 12 Aug 2024 18:24:10 +0800 Subject: [PATCH 12/20] use exclude --- metagpt/roles/di/data_analyst.py | 4 +++- metagpt/roles/di/team_leader.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/metagpt/roles/di/data_analyst.py b/metagpt/roles/di/data_analyst.py index 3a43f72e0..f65042217 100644 --- a/metagpt/roles/di/data_analyst.py +++ b/metagpt/roles/di/data_analyst.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Annotated + from pydantic import Field, model_validator from metagpt.actions.di.execute_nb_code import ExecuteNbCode @@ -31,7 +33,7 @@ class DataAnalyst(RoleZero): tools: list[str] = ["Plan", "DataAnalyst", "RoleZero", "Browser"] custom_tools: list[str] = ["web scraping", "Terminal"] 
custom_tool_recommender: ToolRecommender = None - experience_retriever: ExpRetriever = KeywordExpRetriever() + experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = KeywordExpRetriever() use_reflection: bool = True write_code: WriteAnalysisCode = Field(default_factory=WriteAnalysisCode, exclude=True) diff --git a/metagpt/roles/di/team_leader.py b/metagpt/roles/di/team_leader.py index 12b4b3a18..353e00620 100644 --- a/metagpt/roles/di/team_leader.py +++ b/metagpt/roles/di/team_leader.py @@ -1,5 +1,9 @@ from __future__ import annotations +from typing import Annotated + +from pydantic import Field + from metagpt.actions.di.run_command import RunCommand from metagpt.prompts.di.team_leader import ( FINISH_CURRENT_TASK_CMD, @@ -24,7 +28,7 @@ class TeamLeader(RoleZero): tools: list[str] = ["Plan", "RoleZero", "TeamLeader"] - experience_retriever: ExpRetriever = SimpleExpRetriever() + experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = SimpleExpRetriever() def _update_tool_execution(self): self.tool_execution_map.update( From 90e1e53bb69a79afd8dafa87cefd6e75e61ef5e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E4=BC=9F=E9=9F=AC?= Date: Mon, 12 Aug 2024 20:07:35 +0800 Subject: [PATCH 13/20] add annotations --- metagpt/utils/repair_llm_raw_output.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index fc27448eb..f1607255e 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -353,6 +353,23 @@ def repair_escape_error(commands): """ Repaires escape errors in command responses. When role-zero parses a command, the command may contain unknown escape characters. + + This function has two steps: + 1. Transform unescaped substrings like "\d" and "\(" to "\\\\d" and "\\\\(". + 2. Transform escaped characters like '\f' to substrings like "\\\\f". 
+ + Example: + When the original JSON string is " {"content":"\\\\( \\\\frac{1}{2} \\\\)"} ", + The "content" will be parsed correctly to "\( \frac{1}{2} \)". + + When there is a wrong JSON string like: " {"content":"\( \frac{1}{2} \)"}", + It will cause a parsing error. + + To repair the wrong JSON string, the following transformations will be used: + "\(" ---> "\\\\(" + '\f' ---> "\\\\f" + "\)" ---> "\\\\)" + """ escape_repair_map = { "\a": "\\\\a", From f1d102e76c97eeca7f33f984414a990b74c8317b Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Mon, 12 Aug 2024 20:22:33 +0800 Subject: [PATCH 14/20] add url params for SERPAPI/SERPER search engine --- metagpt/configs/search_config.py | 5 +++-- metagpt/tools/search_engine_serpapi.py | 19 ++++++++----------- metagpt/tools/search_engine_serper.py | 16 +++++++--------- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/metagpt/configs/search_config.py b/metagpt/configs/search_config.py index 5f7f2d9a3..7b50fb6d3 100644 --- a/metagpt/configs/search_config.py +++ b/metagpt/configs/search_config.py @@ -7,7 +7,7 @@ """ from typing import Callable, Optional -from pydantic import Field +from pydantic import ConfigDict, Field from metagpt.tools import SearchEngineType from metagpt.utils.yaml_model import YamlModel @@ -16,10 +16,11 @@ from metagpt.utils.yaml_model import YamlModel class SearchConfig(YamlModel): """Config for Search""" + model_config = ConfigDict(extra="allow") + api_type: SearchEngineType = SearchEngineType.DUCK_DUCK_GO api_key: str = "" cse_id: str = "" # for google - discovery_service_url: str = "" # for google search_func: Optional[Callable] = None params: dict = Field( default_factory=lambda: { diff --git a/metagpt/tools/search_engine_serpapi.py b/metagpt/tools/search_engine_serpapi.py index 5744b1b62..b3ccb0649 100644 --- a/metagpt/tools/search_engine_serpapi.py +++ b/metagpt/tools/search_engine_serpapi.py @@ -6,7 +6,7 @@ @File : search_engine_serpapi.py """ import warnings -from typing import 
Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional import aiohttp from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -24,6 +24,7 @@ class SerpAPIWrapper(BaseModel): "hl": "en", } ) + url: str = "https://serpapi.com/search" aiosession: Optional[aiohttp.ClientSession] = None proxy: Optional[str] = None @@ -49,22 +50,18 @@ class SerpAPIWrapper(BaseModel): async def results(self, query: str, max_results: int) -> dict: """Use aiohttp to run query through SerpAPI and return the results async.""" - def construct_url_and_params() -> Tuple[str, Dict[str, str]]: - params = self.get_params(query) - params["source"] = "python" - params["num"] = max_results - params["output"] = "json" - url = "https://serpapi.com/search" - return url, params + params = self.get_params(query) + params["source"] = "python" + params["num"] = max_results + params["output"] = "json" - url, params = construct_url_and_params() if not self.aiosession: async with aiohttp.ClientSession() as session: - async with session.get(url, params=params, proxy=self.proxy) as response: + async with session.get(self.url, params=params, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() else: - async with self.aiosession.get(url, params=params, proxy=self.proxy) as response: + async with self.aiosession.get(self.url, params=params, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() diff --git a/metagpt/tools/search_engine_serper.py b/metagpt/tools/search_engine_serper.py index ba2fb4f93..932f2eb44 100644 --- a/metagpt/tools/search_engine_serper.py +++ b/metagpt/tools/search_engine_serper.py @@ -7,7 +7,7 @@ """ import json import warnings -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional import aiohttp from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -17,6 +17,7 @@ class SerperWrapper(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) 
api_key: str + url: str = "https://google.serper.dev/search" payload: dict = Field(default_factory=lambda: {"page": 1, "num": 10}) aiosession: Optional[aiohttp.ClientSession] = None proxy: Optional[str] = None @@ -33,6 +34,7 @@ class SerperWrapper(BaseModel): "To use serper search engine, make sure you provide the `api_key` when constructing an object. You can obtain " "an API key from https://serper.dev/." ) + return values async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: @@ -46,20 +48,16 @@ class SerperWrapper(BaseModel): async def results(self, queries: list[str], max_results: int = 8) -> dict: """Use aiohttp to run query through Serper and return the results async.""" - def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]: - payloads = self.get_payloads(queries, max_results) - url = "https://google.serper.dev/search" - headers = self.get_headers() - return url, payloads, headers + payloads = self.get_payloads(queries, max_results) + headers = self.get_headers() - url, payloads, headers = construct_url_and_payload_and_headers() if not self.aiosession: async with aiohttp.ClientSession() as session: - async with session.post(url, data=payloads, headers=headers, proxy=self.proxy) as response: + async with session.post(self.url, data=payloads, headers=headers, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() else: - async with self.aiosession.get.post(url, data=payloads, headers=headers, proxy=self.proxy) as response: + async with self.aiosession.post(self.url, data=payloads, headers=headers, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() From 42da3b1fe8e2200338b5ad476ca4193e2b1d3e12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E4=BC=9F=E9=9F=AC?= Date: Mon, 12 Aug 2024 20:28:50 +0800 Subject: [PATCH 15/20] update annotations --- metagpt/utils/repair_llm_raw_output.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index f1607255e..2015b2ed7 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -352,7 +352,7 @@ def extract_state_value_from_output(content: str) -> str: def repair_escape_error(commands): """ Repaires escape errors in command responses. - When role-zero parses a command, the command may contain unknown escape characters. + When RoleZero parses a command, the command may contain unknown escape characters. This function has two steps: 1. Transform unescaped substrings like "\d" and "\(" to "\\\\d" and "\\\\(". @@ -362,7 +362,7 @@ def repair_escape_error(commands): When the original JSON string is " {"content":"\\\\( \\\\frac{1}{2} \\\\)"} ", The "content" will be parsed correctly to "\( \frac{1}{2} \)". - When there is a wrong JSON string like: " {"content":"\( \frac{1}{2} \)"}", + However, if the orginal JSON string is " {"content":"\( \frac{1}{2} \)"}" directly. It will cause a parsing error. To repair the wrong JSON string, the following transformations will be used: From ae4b65fdfde1e0a6a112c77d0722d74ad9a89b00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E4=BC=9F=E9=9F=AC?= Date: Mon, 12 Aug 2024 20:30:46 +0800 Subject: [PATCH 16/20] update annotations --- metagpt/utils/repair_llm_raw_output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index 2015b2ed7..68fa73108 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -362,7 +362,7 @@ def repair_escape_error(commands): When the original JSON string is " {"content":"\\\\( \\\\frac{1}{2} \\\\)"} ", The "content" will be parsed correctly to "\( \frac{1}{2} \)". - However, if the orginal JSON string is " {"content":"\( \frac{1}{2} \)"}" directly. 
+ However, if the original JSON string is " {"content":"\( \frac{1}{2} \)"}" directly. It will cause a parsing error. To repair the wrong JSON string, the following transformations will be used: From 46bce295cab987dc5e73b9dadcf21a1fdda9f1bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Tue, 13 Aug 2024 13:43:46 +0800 Subject: [PATCH 17/20] =?UTF-8?q?fixbug:=20linux=E7=8E=AF=E5=A2=83?= =?UTF-8?q?=E5=8F=98=E9=87=8Fkey=E4=B8=8D=E8=83=BD=E4=BD=BF=E7=94=A8-?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- metagpt/prompts/di/role_zero.py | 4 ++-- metagpt/tools/libs/env.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/metagpt/prompts/di/role_zero.py b/metagpt/prompts/di/role_zero.py index 49ad7a05b..b6faa8cb3 100644 --- a/metagpt/prompts/di/role_zero.py +++ b/metagpt/prompts/di/role_zero.py @@ -152,7 +152,7 @@ Respond with a concise thought, then provide the appropriate response category: """ -QUICK_THINK_EXAMPLES =""" +QUICK_THINK_EXAMPLES = """ # Example 1. Request: "How do I design an online document editing platform that supports real-time collaboration?" @@ -190,4 +190,4 @@ Response Category: AMBIGUOUS. 
# Instruction """ -QUICK_THINK_PROMPT = QUICK_THINK_PROMPT.format(examples=QUICK_THINK_EXAMPLES) \ No newline at end of file +QUICK_THINK_PROMPT = QUICK_THINK_PROMPT.format(examples=QUICK_THINK_EXAMPLES) diff --git a/metagpt/tools/libs/env.py b/metagpt/tools/libs/env.py index 1fa265d07..c1757c5f9 100644 --- a/metagpt/tools/libs/env.py +++ b/metagpt/tools/libs/env.py @@ -31,6 +31,10 @@ async def default_get_env(key: str, app_name: str = None) -> str: if app_key in os.environ: return os.environ[app_key] + env_app_key = app_key.replace("-", "_") # "-" is not supported by linux environment variable + if env_app_key in os.environ: + return os.environ[env_app_key] + from metagpt.context import Context context = Context() From 2a4e3730e10cd6eb96d718b71a4eee80610e6508 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Tue, 13 Aug 2024 15:12:24 +0800 Subject: [PATCH 18/20] add search report --- metagpt/actions/search_enhanced_qa.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/metagpt/actions/search_enhanced_qa.py b/metagpt/actions/search_enhanced_qa.py index 1d7944d61..152e615b6 100644 --- a/metagpt/actions/search_enhanced_qa.py +++ b/metagpt/actions/search_enhanced_qa.py @@ -4,7 +4,7 @@ from __future__ import annotations import json -from pydantic import Field, model_validator +from pydantic import Field, PrivateAttr, model_validator from metagpt.actions import Action from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize @@ -90,6 +90,8 @@ class SearchEnhancedQA(Action): description="Maximum number of search results (links) to collect using the collect_links_action. This controls the number of potential sources for answering the question.", ) + _reporter: ThoughtReporter = PrivateAttr(ThoughtReporter()) + @model_validator(mode="after") def initialize(self): if self.web_browse_and_summarize_action is None: @@ -118,13 +120,14 @@ class SearchEnhancedQA(Action): Raises: ValueError: If the query is invalid. 
""" + async with self._reporter: + await self._reporter.async_report({"type": "search", "stage": "init"}) + self._validate_query(query) - self._validate_query(query) + processed_query = await self._process_query(query, rewrite_query) + context = await self._build_context(processed_query) - processed_query = await self._process_query(query, rewrite_query) - context = await self._build_context(processed_query) - - return await self._generate_answer(processed_query, context) + return await self._generate_answer(processed_query, context) def _validate_query(self, query: str) -> None: """Validate the input query. @@ -203,6 +206,7 @@ class SearchEnhancedQA(Action): """ relevant_urls = await self._collect_relevant_links(query) + await self._reporter.async_report({"type": "search", "stage": "searching", "urls": relevant_urls}) if not relevant_urls: logger.warning(f"No relevant URLs found for query: {query}") return [] @@ -245,10 +249,12 @@ class SearchEnhancedQA(Action): contents = await self._fetch_web_contents(urls) summaries = {} + await self._reporter.async_report( + {"type": "search", "stage": "browsing", "pages": [i.model_dump() for i in contents]} + ) for content in contents: url = content.url inner_text = content.inner_text.replace("\n", "") - if self.web_browse_and_summarize_action._is_content_invalid(inner_text): logger.warning(f"Invalid content detected for URL {url}: {inner_text[:10]}...") continue @@ -276,8 +282,7 @@ class SearchEnhancedQA(Action): system_prompt = SEARCH_ENHANCED_QA_SYSTEM_PROMPT.format(context=context) - async with ThoughtReporter(enable_llm_stream=True) as reporter: - await reporter.async_report({"type": "quick"}) + async with ThoughtReporter(uuid=self._reporter.uuid, enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "search", "stage": "answer"}) rsp = await self._aask(query, [system_prompt]) - return rsp From b36f34c1443ecad617b5c56eb585e091a2e6e819 Mon Sep 17 00:00:00 2001 From: zhanglei Date: Tue, 13 Aug 2024 
16:50:09 +0800 Subject: [PATCH 19/20] fix: result process bug --- metagpt/ext/cr/actions/code_review.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metagpt/ext/cr/actions/code_review.py b/metagpt/ext/cr/actions/code_review.py index 473ea8018..ab86ec239 100644 --- a/metagpt/ext/cr/actions/code_review.py +++ b/metagpt/ext/cr/actions/code_review.py @@ -164,6 +164,7 @@ class CodeReview(Action): system_prompt = [CODE_REVIEW_COMFIRM_SYSTEM_PROMPT.format(code_language=code_language)] resp = await self.llm.aask(prompt, system_msgs=system_prompt) if "True" in resp or "true" in resp: + cmt["code"] = get_code_block_from_patch(patch, code_start_line, code_end_line) new_comments.append(cmt) except Exception: logger.info("False") From e6ae39da57d096111c3f32a1b18fbb26289d72a3 Mon Sep 17 00:00:00 2001 From: zhanglei Date: Tue, 13 Aug 2024 16:53:16 +0800 Subject: [PATCH 20/20] fix --- metagpt/ext/cr/actions/code_review.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/metagpt/ext/cr/actions/code_review.py b/metagpt/ext/cr/actions/code_review.py index ab86ec239..e3e6e69f2 100644 --- a/metagpt/ext/cr/actions/code_review.py +++ b/metagpt/ext/cr/actions/code_review.py @@ -164,7 +164,6 @@ class CodeReview(Action): system_prompt = [CODE_REVIEW_COMFIRM_SYSTEM_PROMPT.format(code_language=code_language)] resp = await self.llm.aask(prompt, system_msgs=system_prompt) if "True" in resp or "true" in resp: - cmt["code"] = get_code_block_from_patch(patch, code_start_line, code_end_line) new_comments.append(cmt) except Exception: logger.info("False") @@ -209,6 +208,6 @@ class CodeReview(Action): comments = await self.confirm_comments(patch=patch, comments=comments, points=points) for comment in comments: if comment["code"]: - if not (comment["code"].startswith("-") or comment["code"].isspace()): + if not (comment["code"].isspace()): result.append(comment) return result