Merge branch 'dev_code_exec_timeout' into 'dev'

add timeout and retry when code execution

See merge request agents/data_agents_opt!37
This commit is contained in:
林义章 2024-01-10 02:36:08 +00:00
commit bbc23c3db2
2 changed files with 48 additions and 15 deletions

View file

@ -12,6 +12,8 @@ import re
import nbformat
from nbclient import NotebookClient
from nbclient.exceptions import DeadKernelError, CellTimeoutError
from nbformat import NotebookNode
from nbformat.v4 import new_code_cell, new_output
from rich.console import Console
from rich.syntax import Syntax
@ -46,13 +48,21 @@ class ExecuteCode(ABC):
class ExecutePyCode(ExecuteCode, Action):
"""execute code, return result to llm, and display it."""
def __init__(self, name: str = "python_executor", context=None, llm=None, nb=None):
def __init__(
self,
name: str = "python_executor",
context=None,
llm=None,
nb=None,
timeout: int = 600,
):
super().__init__(name, context, llm)
if nb is None:
self.nb = nbformat.v4.new_notebook()
else:
self.nb = nb
self.nb_client = NotebookClient(self.nb)
self.timeout = timeout
self.nb_client = NotebookClient(self.nb, timeout=self.timeout)
self.console = Console()
self.interaction = "ipython" if self.is_ipython() else "terminal"
@ -69,7 +79,8 @@ class ExecutePyCode(ExecuteCode, Action):
async def reset(self):
"""reset NotebookClient"""
await self.terminate()
self.nb_client = NotebookClient(self.nb)
await self.build()
self.nb_client = NotebookClient(self.nb, timeout=self.timeout)
def add_code_cell(self, code):
self.nb.cells.append(new_code_cell(source=code))
@ -160,6 +171,19 @@ class ExecutePyCode(ExecuteCode, Action):
return code, language
async def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str]:
"""set timeout for run code"""
try:
await self.nb_client.async_execute_cell(cell, cell_index)
return True, ""
except CellTimeoutError:
return False, "TimeoutError"
except DeadKernelError:
await self.reset()
return False, "DeadKernelError"
except Exception as e:
return False, f"{traceback.format_exc()}"
async def run(self, code: Union[str, Dict, Message], language: str = "python") -> Tuple[str, bool]:
code, language = self._process_code(code, language)
@ -168,19 +192,19 @@ class ExecutePyCode(ExecuteCode, Action):
if language == "python":
# add code to the notebook
self.add_code_cell(code=code)
try:
# build code executor
await self.build()
# run code
# TODO: add max_tries for run code.
cell_index = len(self.nb.cells) - 1
await self.nb_client.async_execute_cell(self.nb.cells[-1], cell_index)
# build code executor
await self.build()
# run code
cell_index = len(self.nb.cells) - 1
success, error_message = await self.run_cell(self.nb.cells[-1], cell_index)
if success:
outputs = self.parse_outputs(self.nb.cells[-1].outputs)
success = True
except Exception as e:
outputs = traceback.format_exc()
success = False
return truncate(remove_escape_and_color_codes(outputs)), success
return truncate(remove_escape_and_color_codes(outputs)), True
else:
return error_message, False
else:
# TODO: markdown
raise NotImplementedError(f"Not support this code type : {language}, Only support code!")

View file

@ -88,3 +88,12 @@ def test_truncate():
assert truncate(output) == output
output = "hello world"
assert truncate(output, 5) == "Truncated to show only the last 5 characters\nworld"
@pytest.mark.asyncio
async def test_run_with_timeout():
pi = ExecutePyCode(timeout=1)
code = "import time; time.sleep(2)"
message, success = await pi.run(code)
assert not success
assert message == "TimeoutError"