Merge branch 'dev' into fix_truncate

This commit is contained in:
刘棒棒 2023-12-13 16:07:35 +08:00
commit 233b143da8
5 changed files with 112 additions and 12 deletions

View file

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# @Date : 12/12/2023 4:17 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import pytest
import os
import json
import nbformat
from metagpt.actions.write_analysis_code import WriteCodeByGenerate
from metagpt.actions.execute_code import ExecutePyCode
from metagpt.utils.save_code import save_code_file, DATA_PATH
def test_save_code_file_python():
save_code_file("example", "print('Hello, World!')")
file_path = DATA_PATH / "output" / "example" / "code.py"
assert os.path.exists(file_path), f"File does not exist: {file_path}"
def test_save_code_file_python():
save_code_file("example", "print('Hello, World!')")
file_path = DATA_PATH / "output" / "example" / "code.py"
with open(file_path, "r", encoding="utf-8") as fp:
content = fp.read()
assert "print('Hello, World!')" in content, "File content does not match"
def test_save_code_file_json():
save_code_file("example_json", "print('Hello, JSON!')", file_format="json")
file_path = DATA_PATH / "output" / "example_json" / "code.json"
with open(file_path, "r", encoding="utf-8") as fp:
data = json.load(fp)
assert "code" in data, "JSON key 'code' is missing"
assert data["code"] == "print('Hello, JSON!')", "JSON content does not match"
@pytest.mark.asyncio
async def test_save_code_file_notebook():
code = await WriteCodeByGenerate().run(
context="basic python, hello world", plan="", code_steps="", temperature=0.0
)
executor = ExecutePyCode()
await executor.run(code)
# Save as a Notebook file
save_code_file("example_nb", executor.nb, file_format="ipynb")
file_path = DATA_PATH / "output" / "example_nb" / "code.ipynb"
assert os.path.exists(file_path), f"Notebook file does not exist: {file_path}"
# Additional checks specific to notebook format
notebook = nbformat.read(file_path, as_version=4)
assert len(notebook.cells) > 0, "Notebook should have at least one cell"
first_cell_source = notebook.cells[0].source
assert "print('Hello, World!')" in first_cell_source, "Notebook cell content does not match"