diff --git a/metagpt/actions/execute_code.py b/metagpt/actions/execute_code.py index 6e4a6fd6e..d192ca79a 100644 --- a/metagpt/actions/execute_code.py +++ b/metagpt/actions/execute_code.py @@ -12,6 +12,8 @@ import re import nbformat from nbclient import NotebookClient +from nbclient.exceptions import DeadKernelError, CellTimeoutError +from nbformat import NotebookNode from nbformat.v4 import new_code_cell, new_output from rich.console import Console from rich.syntax import Syntax @@ -46,13 +48,21 @@ class ExecuteCode(ABC): class ExecutePyCode(ExecuteCode, Action): """execute code, return result to llm, and display it.""" - def __init__(self, name: str = "python_executor", context=None, llm=None, nb=None): + def __init__( + self, + name: str = "python_executor", + context=None, + llm=None, + nb=None, + timeout: int = 600, + ): super().__init__(name, context, llm) if nb is None: self.nb = nbformat.v4.new_notebook() else: self.nb = nb - self.nb_client = NotebookClient(self.nb) + self.timeout = timeout + self.nb_client = NotebookClient(self.nb, timeout=self.timeout) self.console = Console() self.interaction = "ipython" if self.is_ipython() else "terminal" @@ -69,7 +79,8 @@ class ExecutePyCode(ExecuteCode, Action): async def reset(self): """reset NotebookClient""" await self.terminate() - self.nb_client = NotebookClient(self.nb) + await self.build() + self.nb_client = NotebookClient(self.nb, timeout=self.timeout) def add_code_cell(self, code): self.nb.cells.append(new_code_cell(source=code)) @@ -160,6 +171,19 @@ class ExecutePyCode(ExecuteCode, Action): return code, language + async def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str]: + """set timeout for run code""" + try: + await self.nb_client.async_execute_cell(cell, cell_index) + return True, "" + except CellTimeoutError: + return False, "TimeoutError" + except DeadKernelError: + await self.reset() + return False, "DeadKernelError" + except Exception as e: + return False, f"{traceback.format_exc()}" + async def run(self, code: Union[str, Dict, Message], language: str = "python") -> Tuple[str, bool]: code, language = self._process_code(code, language) @@ -168,19 +192,19 @@ class ExecutePyCode(ExecuteCode, Action): if language == "python": # add code to the notebook self.add_code_cell(code=code) - try: - # build code executor - await self.build() - # run code - # TODO: add max_tries for run code. - cell_index = len(self.nb.cells) - 1 - await self.nb_client.async_execute_cell(self.nb.cells[-1], cell_index) + + # build code executor + await self.build() + + # run code + cell_index = len(self.nb.cells) - 1 + success, error_message = await self.run_cell(self.nb.cells[-1], cell_index) + + if success: outputs = self.parse_outputs(self.nb.cells[-1].outputs) - success = True - except Exception as e: - outputs = traceback.format_exc() - success = False - return truncate(remove_escape_and_color_codes(outputs)), success + return truncate(remove_escape_and_color_codes(outputs)), True + else: + return error_message, False else: # TODO: markdown raise NotImplementedError(f"Not support this code type : {language}, Only support code!") diff --git a/tests/metagpt/actions/test_execute_code.py b/tests/metagpt/actions/test_execute_code.py index 95f883e12..8340272e4 100644 --- a/tests/metagpt/actions/test_execute_code.py +++ b/tests/metagpt/actions/test_execute_code.py @@ -88,3 +88,12 @@ def test_truncate(): assert truncate(output) == output output = "hello world" assert truncate(output, 5) == "Truncated to show only the last 5 characters\nworld" + + +@pytest.mark.asyncio +async def test_run_with_timeout(): + pi = ExecutePyCode(timeout=1) + code = "import time; time.sleep(2)" + message, success = await pi.run(code) + assert not success + assert message == "TimeoutError"