Merge branch 'mgx_ops' of https://gitlab.deepwisdomai.com/pub/MetaGPT into check_role_zero

This commit is contained in:
garylin2099 2024-08-13 19:38:31 +08:00
commit e52672aea5
20 changed files with 311 additions and 70 deletions

View file

@ -0,0 +1,25 @@
from metagpt.environment.mgx.mgx_env import MGXEnv
from metagpt.logs import logger
def main():
    """Demonstrate serialization round-tripping with SerializationMixin.

    An MGXEnv instance is serialized to the default storage file and then
    deserialized back into a fresh instance.

    On success the following log lines appear:
        MGXEnv serialization successful. File saved at: /.../workspace/storage/MGXEnv.json
        MGXEnv deserialization successful. Instance created from file: /.../workspace/storage/MGXEnv.json
        The instance is MGXEnv()
    """
    original = MGXEnv()
    original.serialize()
    restored: MGXEnv = MGXEnv.deserialize()
    logger.info(f"The instance is {repr(restored)}")


if __name__ == "__main__":
    main()

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import json
from pydantic import Field, model_validator
from pydantic import Field, PrivateAttr, model_validator
from metagpt.actions import Action
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize
@ -90,6 +90,8 @@ class SearchEnhancedQA(Action):
description="Maximum number of search results (links) to collect using the collect_links_action. This controls the number of potential sources for answering the question.",
)
_reporter: ThoughtReporter = PrivateAttr(ThoughtReporter())
@model_validator(mode="after")
def initialize(self):
if self.web_browse_and_summarize_action is None:
@ -118,13 +120,14 @@ class SearchEnhancedQA(Action):
Raises:
ValueError: If the query is invalid.
"""
async with self._reporter:
await self._reporter.async_report({"type": "search", "stage": "init"})
self._validate_query(query)
self._validate_query(query)
processed_query = await self._process_query(query, rewrite_query)
context = await self._build_context(processed_query)
processed_query = await self._process_query(query, rewrite_query)
context = await self._build_context(processed_query)
return await self._generate_answer(processed_query, context)
return await self._generate_answer(processed_query, context)
def _validate_query(self, query: str) -> None:
"""Validate the input query.
@ -203,6 +206,7 @@ class SearchEnhancedQA(Action):
"""
relevant_urls = await self._collect_relevant_links(query)
await self._reporter.async_report({"type": "search", "stage": "searching", "urls": relevant_urls})
if not relevant_urls:
logger.warning(f"No relevant URLs found for query: {query}")
return []
@ -245,10 +249,12 @@ class SearchEnhancedQA(Action):
contents = await self._fetch_web_contents(urls)
summaries = {}
await self._reporter.async_report(
{"type": "search", "stage": "browsing", "pages": [i.model_dump() for i in contents]}
)
for content in contents:
url = content.url
inner_text = content.inner_text.replace("\n", "")
if self.web_browse_and_summarize_action._is_content_invalid(inner_text):
logger.warning(f"Invalid content detected for URL {url}: {inner_text[:10]}...")
continue
@ -276,8 +282,7 @@ class SearchEnhancedQA(Action):
system_prompt = SEARCH_ENHANCED_QA_SYSTEM_PROMPT.format(context=context)
async with ThoughtReporter(enable_llm_stream=True) as reporter:
await reporter.async_report({"type": "quick"})
async with ThoughtReporter(uuid=self._reporter.uuid, enable_llm_stream=True) as reporter:
await reporter.async_report({"type": "search", "stage": "answer"})
rsp = await self._aask(query, [system_prompt])
return rsp

View file

@ -7,7 +7,7 @@
"""
from typing import Callable, Optional
from pydantic import Field
from pydantic import ConfigDict, Field
from metagpt.tools import SearchEngineType
from metagpt.utils.yaml_model import YamlModel
@ -16,10 +16,11 @@ from metagpt.utils.yaml_model import YamlModel
class SearchConfig(YamlModel):
"""Config for Search"""
model_config = ConfigDict(extra="allow")
api_type: SearchEngineType = SearchEngineType.DUCK_DUCK_GO
api_key: str = ""
cse_id: str = "" # for google
discovery_service_url: str = "" # for google
search_func: Optional[Callable] = None
params: dict = Field(
default_factory=lambda: {

View file

@ -10,11 +10,11 @@ from metagpt.const import AGENT
from metagpt.environment.base_env import Environment
from metagpt.logs import get_human_input
from metagpt.roles import Architect, ProductManager, ProjectManager, Role
from metagpt.schema import Message
from metagpt.schema import Message, SerializationMixin
from metagpt.utils.common import any_to_str, any_to_str_set
class MGXEnv(Environment):
class MGXEnv(Environment, SerializationMixin):
"""MGX Environment"""
# If True, fixed software sop bypassing TL is allowed, otherwise, TL will fully take over the routing

View file

@ -208,6 +208,6 @@ class CodeReview(Action):
comments = await self.confirm_comments(patch=patch, comments=comments, points=points)
for comment in comments:
if comment["code"]:
if not (comment["code"].startswith("-") or comment["code"].isspace()):
if not (comment["code"].isspace()):
result.append(comment)
return result

View file

@ -152,7 +152,7 @@ Respond with a concise thought, then provide the appropriate response category:
"""
QUICK_THINK_EXAMPLES ="""
QUICK_THINK_EXAMPLES = """
# Example
1. Request: "How do I design an online document editing platform that supports real-time collaboration?"
@ -190,4 +190,4 @@ Response Category: AMBIGUOUS.
# Instruction
"""
QUICK_THINK_PROMPT = QUICK_THINK_PROMPT.format(examples=QUICK_THINK_EXAMPLES)
QUICK_THINK_PROMPT = QUICK_THINK_PROMPT.format(examples=QUICK_THINK_EXAMPLES)

View file

@ -1,5 +1,7 @@
from __future__ import annotations
from typing import Annotated
from pydantic import Field, model_validator
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
@ -31,7 +33,7 @@ class DataAnalyst(RoleZero):
tools: list[str] = ["Plan", "DataAnalyst", "RoleZero", "Browser"]
custom_tools: list[str] = ["web scraping", "Terminal"]
custom_tool_recommender: ToolRecommender = None
experience_retriever: ExpRetriever = KeywordExpRetriever()
experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = KeywordExpRetriever()
use_reflection: bool = True
write_code: WriteAnalysisCode = Field(default_factory=WriteAnalysisCode, exclude=True)

View file

@ -4,9 +4,9 @@ import inspect
import json
import re
import traceback
from typing import Callable, Dict, List, Literal, Tuple
from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple
from pydantic import model_validator
from pydantic import Field, model_validator
from metagpt.actions import Action, UserRequirement
from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions
@ -41,7 +41,11 @@ from metagpt.utils.common import (
extract_image_paths,
is_support_image_input,
)
from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output
from metagpt.utils.repair_llm_raw_output import (
RepairType,
repair_escape_error,
repair_llm_raw_output,
)
from metagpt.utils.report import ThoughtReporter
@ -53,12 +57,13 @@ class RoleZero(Role):
name: str = "Zero"
profile: str = "RoleZero"
goal: str = ""
system_msg: Optional[list[str]] = None # Use None to conform to the default value at llm.aask
system_prompt: str = SYSTEM_PROMPT # Use None to conform to the default value at llm.aask
cmd_prompt: str = CMD_PROMPT
cmd_prompt_current_state: str = ""
thought_guidance: str = THOUGHT_GUIDANCE
instruction: str = ROLE_INSTRUCTION
task_type_desc: str = None
task_type_desc: Optional[str] = None
# React Mode
react_mode: Literal["react"] = "react"
@ -66,15 +71,15 @@ class RoleZero(Role):
# Tools
tools: list[str] = [] # Use special symbol ["<all>"] to indicate use of all registered tools
tool_recommender: ToolRecommender = None
tool_execution_map: dict[str, Callable] = {}
tool_recommender: Optional[ToolRecommender] = None
tool_execution_map: Annotated[dict[str, Callable], Field(exclude=True)] = {}
special_tool_commands: list[str] = ["Plan.finish_current_task", "end", "Bash.run"]
# Equipped with three basic tools by default for optional use
editor: Editor = Editor()
browser: Browser = Browser()
# Experience
experience_retriever: ExpRetriever = DummyExpRetriever()
experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = DummyExpRetriever()
# Others
command_rsp: str = "" # the raw string containing the commands
@ -330,10 +335,20 @@ class RoleZero(Role):
if commands.endswith("]") and not commands.startswith("["):
commands = "[" + commands
commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON))
except json.JSONDecodeError:
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON for: {self.command_rsp}. Trying to repair...")
commands = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp))
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
commands = await self.llm.aask(
msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp, json_decode_error=str(e))
)
try:
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
except json.JSONDecodeError:
# repair escape error of code and math
commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp)
new_command = repair_escape_error(commands)
commands = json.loads(
repair_llm_raw_output(output=new_command, req_keys=[None], repair_type=RepairType.JSON)
)
except Exception as e:
tb = traceback.format_exc()
print(tb)

View file

@ -1,5 +1,9 @@
from __future__ import annotations
from typing import Annotated
from pydantic import Field
from metagpt.actions.di.run_command import RunCommand
from metagpt.prompts.di.team_leader import (
FINISH_CURRENT_TASK_CMD,
@ -24,7 +28,7 @@ class TeamLeader(RoleZero):
tools: list[str] = ["Plan", "RoleZero", "TeamLeader"]
experience_retriever: ExpRetriever = SimpleExpRetriever()
experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = SimpleExpRetriever()
def _update_tool_execution(self):
self.tool_execution_map.update(

View file

@ -44,6 +44,7 @@ from metagpt.const import (
MESSAGE_ROUTE_FROM,
MESSAGE_ROUTE_TO,
MESSAGE_ROUTE_TO_ALL,
SERDESER_PATH,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
)
@ -56,6 +57,8 @@ from metagpt.utils.common import (
any_to_str_set,
aread,
import_class,
read_json_file,
write_json_file,
)
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.report import TaskReporter
@ -127,6 +130,65 @@ class SerializationMixin(BaseModel, extra="forbid"):
cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
super().__init_subclass__(**kwargs)
@handle_exception
def serialize(self, file_path: str = None) -> str:
    """Write this instance to a JSON file and return the file path.

    Any exception raised during serialization is intercepted by
    `handle_exception`, making the call return `None` instead of raising.

    Args:
        file_path (str, optional): Destination JSON file. When omitted, the
            class default from `get_serialization_path()` is used.

    Returns:
        str: The path of the JSON file that was written.
    """
    target = file_path if file_path else self.get_serialization_path()
    payload = self.model_dump()
    write_json_file(target, payload)
    logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {target}")
    return target
@classmethod
@handle_exception
def deserialize(cls, file_path: str = None) -> BaseModel:
    """Build an instance of `cls` from a JSON file.

    Any exception raised during deserialization is intercepted by
    `handle_exception`, making the call return `None` instead of raising.

    Args:
        file_path (str, optional): JSON file to read. When omitted, the
            class default from `get_serialization_path()` is used.

    Returns:
        The reconstructed `cls` instance.
    """
    source = file_path if file_path else cls.get_serialization_path()
    payload: dict = read_json_file(source)
    instance = cls(**payload)
    logger.info(f"{cls.__qualname__} deserialization successful. Instance created from file: {source}")
    return instance
@classmethod
def get_serialization_path(cls) -> str:
    """Return the default JSON file path used for (de)serialization.

    The path is `<SERDESER_PATH>/<ClassName>.json`, e.g.
    './workspace/storage/MGXEnv.json' for a class named MGXEnv.

    Returns:
        str: The default serialization file path for this class.
    """
    filename = f"{cls.__qualname__}.json"
    return str(SERDESER_PATH / filename)
class SimpleMessage(BaseModel):
content: str

View file

@ -6,7 +6,7 @@ from pydantic import BaseModel
class ExpRetriever(BaseModel):
"""interface for experience retriever"""
def retrieve(self, context: str) -> str:
def retrieve(self, context: str = "") -> str:
raise NotImplementedError

View file

@ -12,6 +12,7 @@ from playwright.async_api import (
Request,
async_playwright,
)
from pydantic import BaseModel, ConfigDict, Field
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.a11y_tree import (
@ -43,7 +44,7 @@ from metagpt.utils.report import BrowserReporter
"type",
],
)
class Browser:
class Browser(BaseModel):
"""A tool for browsing the web. Don't initialize a new instance of this class if one already exists.
Note: If you plan to use the browser to assist you in completing tasks, then using the browser should be a standalone
@ -66,16 +67,17 @@ class Browser:
>>> await browser.close_tab()
"""
def __init__(self):
self.playwright: Optional[Playwright] = None
self.browser_instance: Optional[Browser_] = None
self.browser_ctx: Optional[BrowserContext] = None
self.page: Optional[Page] = None
self.accessibility_tree: list = []
self.headless: bool = True
self.proxy = get_proxy_from_env()
self.is_empty_page = True
self.reporter = BrowserReporter()
model_config = ConfigDict(arbitrary_types_allowed=True)
playwright: Optional[Playwright] = None
browser_instance: Optional[Browser_] = None
browser_ctx: Optional[BrowserContext] = None
page: Optional[Page] = None
accessibility_tree: list = Field(default_factory=list)
headless: bool = True
proxy: Optional[str] = Field(default_factory=get_proxy_from_env)
is_empty_page: bool = True
reporter: BrowserReporter = Field(default_factory=BrowserReporter)
async def start(self) -> None:
"""Starts Playwright and launches a browser"""

View file

@ -5,9 +5,8 @@ import subprocess
from pathlib import Path
from typing import List, Optional, Union
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.tools.tool_registry import register_tool
from metagpt.utils import read_docx
@ -24,12 +23,12 @@ class FileBlock(BaseModel):
@register_tool()
class Editor:
class Editor(BaseModel):
"""A tool for reading, understanding, writing, and editing files"""
def __init__(self) -> None:
print(f"Editor initialized with root path at: {DEFAULT_WORKSPACE_ROOT}")
self.resource = EditorReporter()
model_config = ConfigDict(arbitrary_types_allowed=True)
resource: EditorReporter = EditorReporter()
def write(self, path: str, content: str):
"""Write the whole content to a file. When used, make sure content arg contains the full content of the file."""

View file

@ -31,6 +31,10 @@ async def default_get_env(key: str, app_name: str = None) -> str:
if app_key in os.environ:
return os.environ[app_key]
env_app_key = app_key.replace("-", "_") # "-" is not supported by linux environment variable
if env_app_key in os.environ:
return os.environ[env_app_key]
from metagpt.context import Context
context = Context()

View file

@ -6,7 +6,7 @@
@File : search_engine_serpapi.py
"""
import warnings
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional
import aiohttp
from pydantic import BaseModel, ConfigDict, Field, model_validator
@ -24,6 +24,7 @@ class SerpAPIWrapper(BaseModel):
"hl": "en",
}
)
url: str = "https://serpapi.com/search"
aiosession: Optional[aiohttp.ClientSession] = None
proxy: Optional[str] = None
@ -49,22 +50,18 @@ class SerpAPIWrapper(BaseModel):
async def results(self, query: str, max_results: int) -> dict:
"""Use aiohttp to run query through SerpAPI and return the results async."""
def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
params = self.get_params(query)
params["source"] = "python"
params["num"] = max_results
params["output"] = "json"
url = "https://serpapi.com/search"
return url, params
params = self.get_params(query)
params["source"] = "python"
params["num"] = max_results
params["output"] = "json"
url, params = construct_url_and_params()
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, proxy=self.proxy) as response:
async with session.get(self.url, params=params, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()
else:
async with self.aiosession.get(url, params=params, proxy=self.proxy) as response:
async with self.aiosession.get(self.url, params=params, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()

View file

@ -7,7 +7,7 @@
"""
import json
import warnings
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional
import aiohttp
from pydantic import BaseModel, ConfigDict, Field, model_validator
@ -17,6 +17,7 @@ class SerperWrapper(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
api_key: str
url: str = "https://google.serper.dev/search"
payload: dict = Field(default_factory=lambda: {"page": 1, "num": 10})
aiosession: Optional[aiohttp.ClientSession] = None
proxy: Optional[str] = None
@ -33,6 +34,7 @@ class SerperWrapper(BaseModel):
"To use serper search engine, make sure you provide the `api_key` when constructing an object. You can obtain "
"an API key from https://serper.dev/."
)
return values
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
@ -46,20 +48,16 @@ class SerperWrapper(BaseModel):
async def results(self, queries: list[str], max_results: int = 8) -> dict:
"""Use aiohttp to run query through Serper and return the results async."""
def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]:
payloads = self.get_payloads(queries, max_results)
url = "https://google.serper.dev/search"
headers = self.get_headers()
return url, payloads, headers
payloads = self.get_payloads(queries, max_results)
headers = self.get_headers()
url, payloads, headers = construct_url_and_payload_and_headers()
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
async with session.post(self.url, data=payloads, headers=headers, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()
else:
async with self.aiosession.get.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
async with self.aiosession.post(self.url, data=payloads, headers=headers, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()

View file

@ -64,6 +64,10 @@ class ToolRecommender(BaseModel):
@field_validator("tools", mode="before")
@classmethod
def validate_tools(cls, v: list[str]) -> dict[str, Tool]:
# If `v` is already a dictionary (e.g., during deserialization), return it as is.
if isinstance(v, dict):
return v
# One can use special symbol ["<all>"] to indicate use of all registered tools
if v == ["<all>"]:
return TOOL_REGISTRY.get_all_tools()

View file

@ -28,6 +28,7 @@ import time
import traceback
from asyncio import iscoroutinefunction
from datetime import datetime
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
@ -577,13 +578,32 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]:
return data
def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4):
def handle_unknown_serialization(x: Any) -> str:
    """Fallback for `to_jsonable_python`: log the unserializable value, return a placeholder.

    Used for debugging serialization: instead of raising, the offending object is
    described in the error log and replaced by a placeholder string so the rest of
    the payload can still be written.

    Args:
        x: The value that could not be converted to a JSON-compatible type.

    Returns:
        str: A placeholder string naming the value's type.
    """
    if inspect.ismethod(x):
        logger.error(f"Method: {x.__self__.__class__.__name__}.{x.__func__.__name__}")
    elif inspect.isfunction(x):
        logger.error(f"Function: {x.__name__}")
    elif inspect.isclass(x) or inspect.ismodule(x):
        # Must be checked before the generic `__class__` test below: every object
        # (including classes and modules) has `__class__`, so the original
        # `hasattr(x, "__name__")` branch was unreachable and classes/modules were
        # mislogged as "Instance of".
        logger.error(f"Class or module: {x.__name__}")
    elif hasattr(x, "__class__"):
        logger.error(f"Instance of: {x.__class__.__name__}")
    else:
        logger.error(f"Unknown type: {type(x)}")
    return f"<Unserializable {type(x).__name__} object>"
def write_json_file(json_file: str, data: Any, encoding: str = None, indent: int = 4, use_fallback: bool = False):
    """Serialize `data` to a JSON file, creating parent directories as needed.

    Args:
        json_file: Destination file path.
        data: Any value convertible by pydantic's `to_jsonable_python`.
        encoding: Text encoding passed to `open`. Defaults to the platform default.
        indent: JSON indentation width.
        use_fallback: When True, values that `to_jsonable_python` cannot handle are
            logged and replaced by a placeholder string (see
            `handle_unknown_serialization`) instead of raising.
    """
    # `exist_ok=True` already tolerates a pre-existing directory, so a separate
    # exists() check is redundant (and race-prone).
    Path(json_file).parent.mkdir(parents=True, exist_ok=True)
    custom_default = partial(to_jsonable_python, fallback=handle_unknown_serialization if use_fallback else None)
    with open(json_file, "w", encoding=encoding) as fout:
        json.dump(data, fout, ensure_ascii=False, indent=indent, default=custom_default)
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):

View file

@ -347,3 +347,44 @@ def extract_state_value_from_output(content: str) -> str:
matches = list(set(matches))
state = matches[0] if len(matches) > 0 else "-1"
return state
def repair_escape_error(commands):
    """
    Repairs escape errors in command responses.

    When RoleZero parses a command, the command may contain unknown escape characters.

    This function has two steps:
    1. Transform unescaped substrings like "\d" and "\(" to "\\\\d" and "\\\\(".
    2. Transform escaped characters like '\f' to substrings like "\\\\f".

    Example:
        When the original JSON string is " {"content":"\\\\( \\\\frac{1}{2} \\\\)"} ",
        the "content" will be parsed correctly to "\( \frac{1}{2} \)".

        However, if the original JSON string is " {"content":"\( \frac{1}{2} \)"}" directly,
        it will cause a parsing error.

        To repair the wrong JSON string, the following transformations will be used:
        "\(" ---> "\\\\("
        '\f' ---> "\\\\f"
        "\)" ---> "\\\\)"

    Args:
        commands: The raw command string that may contain invalid escapes.

    Returns:
        The repaired string, safe(r) to pass to a JSON parser.
    """
    # Control characters produced by a mis-parsed escape, mapped back to a
    # doubly-escaped textual form.
    escape_repair_map = {
        "\a": "\\\\a",
        "\b": "\\\\b",
        "\f": "\\\\f",
        "\r": "\\\\r",
        "\t": "\\\\t",
        "\v": "\\\\v",
    }
    # Escapes that are legitimate inside JSON and must be left untouched.
    preserved_next_chars = ("n", '"', " ")
    # Accumulate characters in a list and join once: repeated `+=` on str is quadratic.
    repaired_chars = []
    for index, ch in enumerate(commands):
        if ch == "\\" and index + 1 < len(commands) and commands[index + 1] not in preserved_next_chars:
            # Unknown escape like "\d" or "\(": double the backslash.
            repaired_chars.append("\\")
            repaired_chars.append(ch)
        else:
            # Replace stray control characters; everything else passes through
            # (including a trailing backslash and "\n" / '\"' / "\ " escapes).
            repaired_chars.append(escape_repair_map.get(ch, ch))
    return "".join(repaired_chars)

View file

@ -9,13 +9,15 @@
"""
import json
from typing import Annotated
import pytest
from pydantic import BaseModel, Field
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
from metagpt.actions.write_code import WriteCode
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.const import SERDESER_PATH, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.schema import (
AIMessage,
CodeSummarizeContext,
@ -23,6 +25,7 @@ from metagpt.schema import (
Message,
MessageQueue,
Plan,
SerializationMixin,
SystemMessage,
Task,
UMLClassAttribute,
@ -398,5 +401,64 @@ def test_create_instruct_value(name, value):
assert obj.model_dump() == value
class TestUserModel(SerializationMixin, BaseModel):
    # Minimal concrete model used to exercise SerializationMixin round-trips.
    name: str
    value: int


class TestUserModelWithExclude(TestUserModel):
    # `age` carries Field(exclude=True), so it is omitted from model_dump() output.
    age: Annotated[int, Field(exclude=True)]
class TestSerializationMixin:
    """Unit tests for SerializationMixin serialize/deserialize with file I/O mocked out."""

    @pytest.fixture
    def mock_write_json_file(self, mocker):
        # Patch at the usage site (metagpt.schema) so serialize() hits the mock.
        return mocker.patch("metagpt.schema.write_json_file")

    @pytest.fixture
    def mock_read_json_file(self, mocker):
        # Patch at the usage site (metagpt.schema) so deserialize() hits the mock.
        return mocker.patch("metagpt.schema.read_json_file")

    @pytest.fixture
    def mock_user_model(self):
        return TestUserModel(name="test", value=42)

    def test_serialize(self, mock_write_json_file, mock_user_model):
        # serialize() should forward the explicit path and the model dump to write_json_file.
        file_path = "test.json"
        mock_user_model.serialize(file_path)
        mock_write_json_file.assert_called_once_with(file_path, mock_user_model.model_dump())

    def test_deserialize(self, mock_read_json_file):
        # deserialize() should read the given path and rebuild an equal instance.
        file_path = "test.json"
        data = {"name": "test", "value": 42}
        mock_read_json_file.return_value = data
        model = TestUserModel.deserialize(file_path)
        mock_read_json_file.assert_called_once_with(file_path)
        assert model == TestUserModel(**data)

    def test_serialize_with_exclude(self, mock_write_json_file):
        model = TestUserModelWithExclude(name="test", value=42, age=10)
        file_path = "test.json"
        model.serialize(file_path)
        # `age` is excluded by its Field(exclude=True) annotation; per this expectation
        # the mixin also injects a `__module_class_name` marker into the dump.
        expected_data = {
            "name": "test",
            "value": 42,
            "__module_class_name": "tests.metagpt.test_schema.TestUserModelWithExclude",
        }
        mock_write_json_file.assert_called_once_with(file_path, expected_data)

    def test_get_serialization_path(self):
        # Default path is <SERDESER_PATH>/<ClassName>.json.
        expected_path = str(SERDESER_PATH / "TestUserModel.json")
        assert TestUserModel.get_serialization_path() == expected_path
# Allow running this test module directly, with stdout capture disabled (-s).
if __name__ == "__main__":
    pytest.main([__file__, "-s"])