serialize mgxenv

This commit is contained in:
seehi 2024-08-09 10:31:03 +08:00
parent 5f86247c0d
commit 98ac5fbce3
10 changed files with 185 additions and 30 deletions

View file

@ -1,3 +1,5 @@
from typing import ClassVar
from metagpt.actions import (
UserRequirement,
WriteDesign,
@ -6,12 +8,14 @@ from metagpt.actions import (
WriteTest,
)
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.const import AGENT
from metagpt.const import AGENT, SERDESER_PATH
from metagpt.environment.base_env import Environment
from metagpt.logs import get_human_input
from metagpt.logs import get_human_input, logger
from metagpt.roles import Architect, ProductManager, ProjectManager, Role
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, any_to_str_set
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.serialize import deserialize_model, serialize_model
class MGXEnv(Environment):
@ -22,6 +26,11 @@ class MGXEnv(Environment):
direct_chat_roles: set[str] = set() # record direct chat: @role_name
default_serialization_path: ClassVar[str] = str(SERDESER_PATH / "mgxenv" / "mgxenv.json")
def __repr__(self):
    """Return the fixed string ``"MGXEnv()"`` as this environment's representation."""
    return "MGXEnv()"
def _publish_message(self, message: Message, peekable: bool = True) -> bool:
    """Forward `message` and `peekable` to the parent class's `publish_message` and return its result."""
    return super().publish_message(message, peekable)
@ -121,5 +130,54 @@ class MGXEnv(Environment):
converted_msg.content = f"from {sent_from} to {converted_msg.send_to}: {converted_msg.content}"
return converted_msg
def __repr__(self):
    """Return the fixed string ``"MGXEnv()"`` as this environment's representation."""
    return "MGXEnv()"
@handle_exception
def serialize(self, file_path: str = None) -> str:
    """Dump this environment to a JSON file and report where it was written.

    The method is decorated with `handle_exception`; as documented by the original author,
    an exception raised here is caught by that decorator and `None` is returned instead.

    Args:
        file_path (str, optional): Destination JSON file. When omitted (or otherwise falsy),
            `default_serialization_path` is used.

    Returns:
        str: The path of the JSON file that was written.
    """
    target_path = file_path if file_path else self.default_serialization_path
    # `remove_unserializable` is handed over as a hook so the dumped dict can be pruned before writing.
    serialize_model(self, target_path, remove_unserializable=self.remove_unserializable)
    logger.info(f"MGXEnv serialization successful. File saved at: {target_path}")
    return target_path
@classmethod
@handle_exception
def deserialize(cls, file_path: str = None) -> "MGXEnv":
    """Build an MGXEnv instance from a JSON file.

    The method is decorated with `handle_exception`; as documented by the original author,
    an exception raised here is caught by that decorator and `None` is returned instead.

    Args:
        file_path (str, optional): JSON file to read. When omitted (or otherwise falsy),
            `default_serialization_path` is used.

    Returns:
        MGXEnv: The instance produced by `deserialize_model`.
    """
    source_path = file_path if file_path else cls.default_serialization_path
    instance = deserialize_model(cls, source_path)
    logger.info(f"MGXEnv deserialization successful. Instance created from file: {source_path}")
    return instance
def remove_unserializable(self, data: dict):
    """Strip the fields each role declares as unserializable from a dumped model dict.

    For every role entry under ``data["roles"]``, the keys named in that role's
    ``"unserializable_fields"`` list are removed from the role entry in place.
    Keys that are listed but absent are ignored.

    Args:
        data (dict): The data dictionary to clean, obtained from Pydantic's model_dump method.
            It is mutated in place; nothing is returned.
    """
    # `or {}` / `or ()` tolerate a key that is present but set to None, not only a missing key.
    roles = data.get("roles") or {}
    for role in roles.values():
        # Explicit loop: the pops are side effects, so no throwaway list is built.
        for key in role.get("unserializable_fields") or ():
            role.pop(key, None)

View file

@ -4,7 +4,7 @@ import inspect
import json
import re
import traceback
from typing import Callable, Dict, List, Literal, Tuple
from typing import Callable, Dict, List, Literal, Optional, Tuple
from pydantic import model_validator
@ -46,11 +46,11 @@ class RoleZero(Role):
name: str = "Zero"
profile: str = "RoleZero"
goal: str = ""
system_msg: list[str] = None # Use None to conform to the default value at llm.aask
system_msg: Optional[list[str]] = None # Use None to conform to the default value at llm.aask
cmd_prompt: str = CMD_PROMPT
thought_guidance: str = THOUGHT_GUIDANCE
instruction: str = ROLE_INSTRUCTION
task_type_desc: str = None
task_type_desc: Optional[str] = None
# React Mode
react_mode: Literal["react"] = "react"
@ -58,7 +58,7 @@ class RoleZero(Role):
# Tools
tools: list[str] = [] # Use special symbol ["<all>"] to indicate use of all registered tools
tool_recommender: ToolRecommender = None
tool_recommender: Optional[ToolRecommender] = None
tool_execution_map: dict[str, Callable] = {}
special_tool_commands: list[str] = ["Plan.finish_current_task", "end", "Bash.run"]
# Equipped with three basic tools by default for optional use
@ -74,6 +74,7 @@ class RoleZero(Role):
memory_k: int = 20 # number of memories (messages) to use as historical context
use_fixed_sop: bool = False
requirements_constraints: str = "" # the constraints in user requirements
unserializable_fields: list[str] = ["tool_execution_map", "experience_retriever"]
@model_validator(mode="after")
def set_plan_and_tool(self) -> "RoleZero":

View file

@ -1,4 +1,5 @@
import json
from typing import Optional
from pydantic import Field
@ -17,7 +18,7 @@ class SWEAgent(RoleZero):
name: str = "Swen"
profile: str = "Issue Solver"
goal: str = "Resolve GitHub issue or bug in any existing codebase"
system_msg: str = [SWE_AGENT_SYSTEM_TEMPLATE]
system_msg: Optional[list[str]] = [SWE_AGENT_SYSTEM_TEMPLATE]
_instruction: str = NEXT_STEP_TEMPLATE
tools: list[str] = [
"Bash",

View file

@ -6,7 +6,7 @@ from pydantic import BaseModel
class ExpRetriever(BaseModel):
"""interface for experience retriever"""
def retrieve(self, context: str) -> str:
def retrieve(self, context: str = "") -> str:
raise NotImplementedError

View file

@ -12,6 +12,7 @@ from playwright.async_api import (
Request,
async_playwright,
)
from pydantic import BaseModel, ConfigDict, Field
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.a11y_tree import (
@ -43,7 +44,7 @@ from metagpt.utils.report import BrowserReporter
"type",
],
)
class Browser:
class Browser(BaseModel):
"""A tool for browsing the web. Don't initialize a new instance of this class if one already exists.
Note: If you plan to use the browser to assist you in completing tasks, then using the browser should be a standalone
@ -66,16 +67,17 @@ class Browser:
>>> await browser.close_tab()
"""
def __init__(self):
self.playwright: Optional[Playwright] = None
self.browser_instance: Optional[Browser_] = None
self.browser_ctx: Optional[BrowserContext] = None
self.page: Optional[Page] = None
self.accessibility_tree: list = []
self.headless: bool = True
self.proxy = get_proxy_from_env()
self.is_empty_page = True
self.reporter = BrowserReporter()
model_config = ConfigDict(arbitrary_types_allowed=True)
playwright: Optional[Playwright] = None
browser_instance: Optional[Browser_] = None
browser_ctx: Optional[BrowserContext] = None
page: Optional[Page] = None
accessibility_tree: list = Field(default_factory=list)
headless: bool = True
proxy: Optional[str] = Field(default_factory=get_proxy_from_env)
is_empty_page: bool = True
reporter: BrowserReporter = Field(default_factory=BrowserReporter)
async def start(self) -> None:
"""Starts Playwright and launches a browser"""

View file

@ -2,9 +2,8 @@ import os
import shutil
import subprocess
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.report import EditorReporter
@ -17,12 +16,12 @@ class FileBlock(BaseModel):
@register_tool()
class Editor:
class Editor(BaseModel):
"""A tool for reading, understanding, writing, and editing files"""
def __init__(self) -> None:
print(f"Editor initialized with root path at: {DEFAULT_WORKSPACE_ROOT}")
self.resource = EditorReporter()
model_config = ConfigDict(arbitrary_types_allowed=True)
resource: EditorReporter = EditorReporter()
def write(self, path: str, content: str):
"""Write the whole content to a file. When used, make sure content arg contains the full content of the file."""

View file

@ -64,6 +64,10 @@ class ToolRecommender(BaseModel):
@field_validator("tools", mode="before")
@classmethod
def validate_tools(cls, v: list[str]) -> dict[str, Tool]:
# If `v` is already a dictionary (e.g., during deserialization), return it as is.
if isinstance(v, dict):
return v
# One can use special symbol ["<all>"] to indicate use of all registered tools
if v == ["<all>"]:
return TOOL_REGISTRY.get_all_tools()

View file

@ -28,6 +28,7 @@ import time
import traceback
from asyncio import iscoroutinefunction
from datetime import datetime
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
@ -577,13 +578,30 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]:
return data
def write_json_file(json_file: str, data: Any, encoding: str = None, indent: int = 4, use_fallback: bool = False):
    """Write `data` to `json_file` as JSON, creating the parent folder when it does not exist.

    Args:
        json_file: Path of the JSON file to write.
        data: Object to dump. Values `json` cannot encode natively are handed to `to_jsonable_python`.
        encoding: Text encoding passed to `open`. Defaults to None.
        indent: Indentation width passed to `json.dump`.
        use_fallback: Debug aid. When True, a value `to_jsonable_python` cannot convert is logged
            through `fallback` instead of raising; when False no fallback is installed.
    """
    folder_path = Path(json_file).parent
    if not folder_path.exists():
        folder_path.mkdir(parents=True, exist_ok=True)

    # For debug, if use_fallback, unknown values will be logged instead of raising an exception.
    def fallback(x: Any) -> None:
        """Log what kind of object could not be serialized.

        Returns None explicitly. NOTE(review): the fallback's return value is presumably what
        ends up in the output (i.e. JSON null) -- confirm against the pydantic_core documentation.
        """
        tip = f"PydanticSerializationError occurred while processing file '{json_file}'"
        if inspect.ismethod(x):
            logger.error(f"{tip}, Method: {x.__self__.__class__.__name__}.{x.__func__.__name__}")
        elif inspect.isfunction(x):
            logger.error(f"{tip}, Function: {x.__name__}")
        elif inspect.isclass(x) or inspect.ismodule(x):
            # Checked before the generic instance branch: every object has `__class__`,
            # so testing for it first made this branch unreachable.
            logger.error(f"{tip}, Class or module: {x.__name__}")
        else:
            logger.error(f"{tip}, Instance of: {x.__class__.__name__}")
        return None

    custom_default = partial(to_jsonable_python, fallback=fallback if use_fallback else None)

    with open(json_file, "w", encoding=encoding) as fout:
        json.dump(data, fout, ensure_ascii=False, indent=indent, default=custom_default)
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):

View file

@ -4,8 +4,11 @@
import copy
import pickle
from typing import Callable, Optional, Type
from metagpt.utils.common import import_class
from pydantic import BaseModel
from metagpt.utils.common import import_class, read_json_file, write_json_file
def actionoutout_schema_to_mapping(schema: dict) -> dict:
@ -81,3 +84,36 @@ def deserialize_message(message_ser: str) -> "Message":
message.instruct_content = ic_new
return message
def serialize_model(model: BaseModel, file_path: str, remove_unserializable: Optional[Callable[[dict], None]] = None):
    """Dump a Pydantic model with `model_dump` and write the result to a JSON file.

    Args:
        model (BaseModel): The Pydantic model to serialize.
        file_path (str): The path to the JSON file where the model will be saved.
        remove_unserializable (Optional[Callable[[dict], None]]): Optional hook that is called with the
            dumped dict and may mutate it in place (e.g. to drop unserializable entries) before writing.
    """
    payload = model.model_dump()

    # The hook's return value is ignored; only its in-place mutation of `payload` matters.
    if remove_unserializable:
        remove_unserializable(payload)

    write_json_file(file_path, payload)
def deserialize_model(cls: Type[BaseModel], file_path: str) -> BaseModel:
    """Read a JSON file and construct a Pydantic model from its contents.

    Args:
        cls (Type[BaseModel]): The Pydantic model class to instantiate.
        file_path (str): The path to the JSON file to read from.

    Returns:
        BaseModel: An instance of `cls` built by passing the loaded mapping as keyword arguments.
    """
    # The loaded content is unpacked with `**`, so it has to be a mapping.
    return cls(**read_json_file(file_path))

View file

@ -6,13 +6,17 @@
from typing import List
from pydantic import BaseModel
from metagpt.actions import WritePRD
from metagpt.actions.action_node import ActionNode
from metagpt.schema import Message
from metagpt.utils.serialize import (
actionoutout_schema_to_mapping,
deserialize_message,
deserialize_model,
serialize_message,
serialize_model,
)
@ -66,3 +70,35 @@ def test_serialize_and_deserialize_message():
assert new_message.content == message.content
assert new_message.cause_by == message.cause_by
assert new_message.instruct_content.field1 == out_data["field1"]
class TestUserModel(BaseModel):
    """Minimal two-field Pydantic model used as input by the model (de)serialization tests."""

    # NOTE(review): the `Test` prefix matches pytest's default class-collection pattern; pytest may
    # emit a collection warning for this class because it is not a real test class -- confirm.
    name: str
    value: int
def test_serialize_model(mocker):
    """serialize_model hands the dumped model, optionally pruned by a hook, to write_json_file."""
    target = "test.json"
    user = TestUserModel(name="test", value=42)
    write_mock = mocker.patch("metagpt.utils.serialize.write_json_file")

    # No hook: the complete dump is written.
    serialize_model(user, target)
    write_mock.assert_called_once_with(target, user.model_dump())

    # With a hook: whatever the hook removes must be missing from what is written.
    def drop_value(dumped: dict):
        dumped.pop("value", None)

    serialize_model(user, target, drop_value)
    write_mock.assert_called_with(target, {"name": "test"})
def test_deserialize_model(mocker):
    """deserialize_model reads the given path once and builds the model from the loaded dict."""
    payload = {"name": "test", "value": 42}
    source = "test.json"
    read_mock = mocker.patch("metagpt.utils.serialize.read_json_file", return_value=payload)

    restored = deserialize_model(TestUserModel, source)

    read_mock.assert_called_once_with(source)
    assert restored == TestUserModel(**payload)