diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 55f4ce378..9d6a6bb24 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -485,13 +485,13 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]: return data -def write_json_file(json_file: str, data: list, encoding=None): +def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4): folder_path = Path(json_file).parent if not folder_path.exists(): folder_path.mkdir(parents=True, exist_ok=True) with open(json_file, "w", encoding=encoding) as fout: - json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python) + json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python) def import_class(class_name: str, module_name: str) -> type: diff --git a/metagpt/utils/save_code.py b/metagpt/utils/save_code.py index d55b058e6..18cb5cd62 100644 --- a/metagpt/utils/save_code.py +++ b/metagpt/utils/save_code.py @@ -2,12 +2,12 @@ # @Date : 12/12/2023 4:14 PM # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : -import json import os import nbformat from metagpt.const import DATA_PATH +from metagpt.utils.common import write_json_file def save_code_file(name: str, code_context: str, file_format: str = "py") -> None: @@ -33,8 +33,7 @@ def save_code_file(name: str, code_context: str, file_format: str = "py") -> Non elif file_format == "json": # Parse the code content as JSON and save data = {"code": code_context} - with open(file_path, "w", encoding="utf-8") as fp: - json.dump(data, fp, indent=2) + write_json_file(file_path, data, encoding="utf-8", indent=2) elif file_format == "ipynb": nbformat.write(code_context, file_path) else: diff --git a/tests/metagpt/utils/test_save_code.py b/tests/metagpt/utils/test_save_code.py index bb0b07d63..62724dde5 100644 --- a/tests/metagpt/utils/test_save_code.py +++ b/tests/metagpt/utils/test_save_code.py @@ -2,30 +2,27 @@ # @Date : 12/12/2023 4:17 PM # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : -import json -import os import nbformat import pytest from metagpt.actions.execute_nb_code import ExecuteNbCode +from metagpt.utils.common import read_json_file from metagpt.utils.save_code import DATA_PATH, save_code_file def test_save_code_file_python(): save_code_file("example", "print('Hello, World!')") file_path = DATA_PATH / "output" / "example" / "code.py" - assert os.path.exists(file_path), f"File does not exist: {file_path}" - with open(file_path, "r", encoding="utf-8") as fp: - content = fp.read() + assert file_path.exists, f"File does not exist: {file_path}" + content = file_path.read_text() assert "print('Hello, World!')" in content, "File content does not match" def test_save_code_file_json(): save_code_file("example_json", "print('Hello, JSON!')", file_format="json") file_path = DATA_PATH / "output" / "example_json" / "code.json" - with open(file_path, "r", encoding="utf-8") as fp: - data = json.load(fp) + data = read_json_file(file_path) assert "code" in data, "JSON key 'code' is missing" assert data["code"] == "print('Hello, JSON!')", "JSON content does not match" @@ -38,7 +35,7 @@ async def test_save_code_file_notebook(): # Save as a Notebook file save_code_file("example_nb", executor.nb, file_format="ipynb") file_path = DATA_PATH / "output" / "example_nb" / "code.ipynb" - assert os.path.exists(file_path), f"Notebook file does not exist: {file_path}" + assert file_path.exists, f"Notebook file does not exist: {file_path}" # Additional checks specific to notebook format notebook = nbformat.read(file_path, as_version=4)