diff --git a/metagpt/actions/execute_nb_code.py b/metagpt/actions/execute_nb_code.py index 835233dfa..ee2faa0cb 100644 --- a/metagpt/actions/execute_nb_code.py +++ b/metagpt/actions/execute_nb_code.py @@ -8,8 +8,7 @@ import asyncio import base64 import re import traceback -from pathlib import Path -from typing import Any, Dict, List, Tuple, Union +from typing import List, Literal, Tuple import nbformat from nbclient import NotebookClient @@ -25,14 +24,13 @@ from rich.syntax import Syntax from metagpt.actions import Action from metagpt.logs import logger -from metagpt.schema import Message class ExecuteNbCode(Action): """execute notebook code block, return result to llm, and display it.""" - nb: Any - nb_client: Any + nb: NotebookNode + nb_client: NotebookClient console: Console interaction: str timeout: int = 600 @@ -70,13 +68,13 @@ class ExecuteNbCode(Action): await self.build() self.nb_client = NotebookClient(self.nb, timeout=self.timeout) - def add_code_cell(self, code): + def add_code_cell(self, code: str): self.nb.cells.append(new_code_cell(source=code)) - def add_markdown_cell(self, markdown): + def add_markdown_cell(self, markdown: str): self.nb.cells.append(new_markdown_cell(source=markdown)) - def _display(self, code, language: str = "python"): + def _display(self, code: str, language: Literal["python", "markdown"] = "python"): if language == "python": code = Syntax(code, "python", theme="paraiso-dark", line_numbers=True) self.console.print(code) @@ -85,21 +83,18 @@ class ExecuteNbCode(Action): else: raise ValueError(f"Only support for python, markdown, but got {language}") - def add_output_to_cell(self, cell, output): + def add_output_to_cell(self, cell: NotebookNode, output: str): + """add outputs of code execution to notebook cell.""" if "outputs" not in cell: cell["outputs"] = [] - # TODO: show figures else: cell["outputs"].append(new_output(output_type="stream", name="stdout", text=str(output))) - def parse_outputs(self, outputs: List) -> str: + def parse_outputs(self, outputs: List[str]) -> str: + """Parses the outputs received from notebook execution.""" assert isinstance(outputs, list) parsed_output = "" - # empty outputs: such as 'x=1\ny=2' - if not outputs: - return parsed_output - for i, output in enumerate(outputs): if output["output_type"] == "stream" and not any( tag in output["text"] @@ -117,7 +112,7 @@ class ExecuteNbCode(Action): parsed_output += output["data"]["text/plain"] return parsed_output - def show_bytes_figure(self, image_base64: str, interaction_type: str = "ipython"): + def show_bytes_figure(self, image_base64: str, interaction_type: Literal["ipython", None]): image_bytes = base64.b64decode(image_base64) if interaction_type == "ipython": from IPython.display import Image, display @@ -141,25 +136,12 @@ class ExecuteNbCode(Action): else: return False except NameError: - # 如果在Python脚本中运行,__file__ 变量存在 return False - def _process_code(self, code: Union[str, Dict], language: str = "python") -> Tuple: - """handle different code response formats, support str or dict""" - if isinstance(code, str) and Path(code).suffix in (".py", ".txt"): - code = Path(code).read_text(encoding="utf-8") - return code, language - - if isinstance(code, str): - return code, language - - if isinstance(code, dict): - assert "code" in code - code = code["code"] - return code, language - async def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str]: - """set timeout for run code""" + """set timeout for run code. + returns the success or failure of the cell execution, and an optional error message. + """ try: await self.nb_client.async_execute_cell(cell, cell_index) return True, "" @@ -175,9 +157,10 @@ class ExecuteNbCode(Action): except Exception: return False, f"{traceback.format_exc()}" - async def run(self, code: Union[str, Dict, Message], language: str = "python") -> Tuple[str, bool]: - code, language = self._process_code(code, language) - + async def run(self, code: str, language: Literal["python", "markdown"] = "python") -> Tuple[str, bool]: + """ + return the output of code execution, and a success indicator (bool) of code execution. + """ self._display(code, language) if language == "python": @@ -198,8 +181,9 @@ class ExecuteNbCode(Action): outputs = self.parse_outputs(self.nb.cells[-1].outputs) return truncate(remove_escape_and_color_codes(outputs), is_success=success) elif language == "markdown": - # markdown + # add markdown content to markdown cell in a notebook. self.add_markdown_cell(code) + # return True, beacuse there is no execution failure for markdown cell. return code, True else: raise ValueError(f"Only support for language: python, markdown, but got {language}, ") @@ -230,7 +214,7 @@ def truncate(result: str, keep_len: int = 2000, is_success: bool = True): return result if not is_same_desc else desc + result, is_success -def remove_escape_and_color_codes(input_str): +def remove_escape_and_color_codes(input_str: str): # 使用正则表达式去除转义字符和颜色代码 pattern = re.compile(r"\x1b\[[0-9;]*[mK]") result = pattern.sub("", input_str) diff --git a/tests/metagpt/actions/test_execute_nb_code.py b/tests/metagpt/actions/test_execute_nb_code.py index 719d14089..d1b40c350 100644 --- a/tests/metagpt/actions/test_execute_nb_code.py +++ b/tests/metagpt/actions/test_execute_nb_code.py @@ -8,8 +8,6 @@ async def test_code_running(): executor = ExecuteNbCode() output, is_success = await executor.run("print('hello world!')") assert is_success - output, is_success = await executor.run({"code": "print('hello world!')", "language": "python"}) - assert is_success @pytest.mark.asyncio