update: 添加nb支持

This commit is contained in:
stellahsr 2023-12-12 17:17:40 +08:00
parent 35c9d744a4
commit a4cef261e0
4 changed files with 32 additions and 1 deletions

1
.gitignore vendored
View file

@ -165,3 +165,4 @@ output.wav
metagpt/roles/idea_agent.py
.aider*
/config/config.yaml
/tests/metagpt/actions/check_data.py

View file

@ -93,7 +93,7 @@ class MLEngineer(Role):
summary = await SummarizeAnalysis().run(self.plan)
rsp = Message(content=summary, cause_by=SummarizeAnalysis)
self._rc.memory.add(rsp)
return rsp
async def _write_and_exec_code(self, max_retry: int = 3):

View file

@ -5,6 +5,8 @@
import os
import json
import nbformat
from metagpt.const import DATA_PATH
def save_code_file(name: str, code_context: str, file_format: str = "py") -> None:
@ -32,6 +34,8 @@ def save_code_file(name: str, code_context: str, file_format: str = "py") -> Non
data = {"code": code_context}
with open(file_path, "w", encoding="utf-8") as fp:
json.dump(data, fp, indent=2)
elif file_format == "ipynb":
nbformat.write(code_context, file_path)
else:
raise ValueError("Unsupported file format. Please choose 'py' or 'json'.")

View file

@ -2,8 +2,13 @@
# @Date : 12/12/2023 4:17 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import pytest
import os
import json
import nbformat
from metagpt.actions.write_analysis_code import WriteCodeByGenerate
from metagpt.actions.execute_code import ExecutePyCode
from metagpt.utils.save_code import save_code_file, DATA_PATH
@ -21,6 +26,7 @@ def test_save_code_file_python():
content = fp.read()
assert "print('Hello, World!')" in content, "File content does not match"
def test_save_code_file_json():
save_code_file("example_json", "print('Hello, JSON!')", file_format="json")
file_path = DATA_PATH / "output" / "example_json" / "code.json"
@ -28,3 +34,23 @@ def test_save_code_file_json():
data = json.load(fp)
assert "code" in data, "JSON key 'code' is missing"
assert data["code"] == "print('Hello, JSON!')", "JSON content does not match"
@pytest.mark.asyncio
async def test_save_code_file_notebook():
code = await WriteCodeByGenerate().run(
context="basic python, hello world", plan="", code_steps="", temperature=0.0
)
executor = ExecutePyCode()
await executor.run(code)
# Save as a Notebook file
save_code_file("example_nb", executor.nb, file_format="ipynb")
file_path = DATA_PATH / "output" / "example_nb" / "code.ipynb"
assert os.path.exists(file_path), f"Notebook file does not exist: {file_path}"
# Additional checks specific to notebook format
notebook = nbformat.read(file_path, as_version=4)
assert len(notebook.cells) > 0, "Notebook should have at least one cell"
first_cell_source = notebook.cells[0].source
assert "print('Hello, World!')" in first_cell_source, "Notebook cell content does not match"