From ff5dbfbc521d6e89c418acee21426b4e906a74a3 Mon Sep 17 00:00:00 2001
From: seeker
Date: Fri, 5 Jul 2024 19:53:05 +0800
Subject: [PATCH] =?UTF-8?q?update:=20terminal=20run=20command=20=E6=94=B9?=
 =?UTF-8?q?=E4=B8=BA=E5=BC=82=E6=AD=A5?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Review notes (placed below the "---" marker, so they stay out of the commit
message when the patch is applied):

* Terminal._start_process: stderr is redirected into stdout
  (asyncio.subprocess.STDOUT), as the removed Popen call did. With
  stderr=PIPE nothing in this patch reads that pipe, so error output
  would be lost.
* Terminal._check_state: the `pwd` output is interpolated into the log
  message. Passed as an extra positional argument it never reaches the
  text, because the message contains no placeholder for it.
* Terminal.run_command: the background task calls
  _read_and_process_output(cmd, daemon=True). Without the flag,
  stdout_queue is never filled and get_stdout_output() always returns "".

 metagpt/roles/di/swe_agent.py                 |   8 +-
 metagpt/tools/libs/terminal.py                | 118 ++++++++++--------
 .../roles/di/run_swe_agent_for_benchmark.py   |  10 +-
 tests/metagpt/tools/libs/test_terminal.py     |  11 +-
 4 files changed, 81 insertions(+), 66 deletions(-)

diff --git a/metagpt/roles/di/swe_agent.py b/metagpt/roles/di/swe_agent.py
index 2e1fb6412..f31300c3f 100644
--- a/metagpt/roles/di/swe_agent.py
+++ b/metagpt/roles/di/swe_agent.py
@@ -29,7 +29,7 @@ class SWEAgent(RoleZero):
 
     async def _think(self) -> bool:
         self._update_system_msg()
-        self._format_instruction()
+        await self._format_instruction()
         res = await super()._think()
         if self.run_eval:
             await self._parse_commands_for_eval()
@@ -46,14 +46,14 @@ class SWEAgent(RoleZero):
         self._bash_window_size = int(os.getenv("WINDOW"))
         self.system_msg = [self._system_msg.format(WINDOW=self._bash_window_size)]
 
-    def _format_instruction(self):
+    async def _format_instruction(self):
         """
         Formats the instruction message for the SWE agent.
 
         Runs the "state" command in the terminal, parses its output as JSON,
         and uses it to format the `_instruction` template.
         """
-        state_output = self.terminal.run("state")
+        state_output = await self.terminal.run("state")
         bash_state = json.loads(state_output)
 
         self.instruction = self._instruction.format(
@@ -81,7 +81,7 @@ class SWEAgent(RoleZero):
             if "end" != cmd.get("command_name", ""):
                 return
             try:
-                diff_output = self.terminal.run("git diff --cached")
+                diff_output = await self.terminal.run("git diff --cached")
                 clear_diff = extract_patch(diff_output)
                 logger.info(f"Diff output: \n{clear_diff}")
                 if clear_diff:
diff --git a/metagpt/tools/libs/terminal.py b/metagpt/tools/libs/terminal.py
index bcf039a5e..73d5e72cf 100644
--- a/metagpt/tools/libs/terminal.py
+++ b/metagpt/tools/libs/terminal.py
@@ -1,8 +1,10 @@
-import subprocess
-import threading
-from queue import Queue
+import asyncio
+from asyncio import Queue
+from asyncio.subprocess import PIPE
+from typing import Optional
 
 from metagpt.const import DEFAULT_WORKSPACE_ROOT, SWE_SETUP_PATH
+from metagpt.logs import logger
 from metagpt.tools.tool_registry import register_tool
 from metagpt.utils.report import END_MARKER_VALUE, TerminalReporter
 
@@ -19,62 +21,54 @@ class Terminal:
     def __init__(self):
         self.shell_command = ["bash"]  # FIXME: should consider windows support later
         self.command_terminator = "\n"
-
-        # Start a persistent shell process
-        self.process = subprocess.Popen(
-            self.shell_command,
-            shell=True,
-            stdin=subprocess.PIPE,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            executable="/bin/bash",
-        )
-        self.stdout_queue = Queue()
+        self.stdout_queue = Queue(maxsize=1000)
         self.observer = TerminalReporter()
+        self.process: Optional[asyncio.subprocess.Process] = None
 
-        self._check_state()
+    async def _start_process(self):
+        # Start a persistent shell process
+        self.process = await asyncio.create_subprocess_exec(
+            *self.shell_command, stdin=PIPE, stdout=PIPE, stderr=asyncio.subprocess.STDOUT, executable="/bin/bash"
+        )
+        await self._check_state()
 
-    def _check_state(self):
-        """Check the state of the terminal, e.g. the current directory of the terminal process. Useful for agent to understand."""
-        print("The terminal is at:", self.run_command("pwd"))
+    async def _check_state(self):
+        """
+        Check the state of the terminal, e.g. the current directory of the terminal process. Useful for agent to understand.
+        """
+        output = await self.run_command("pwd")
+        logger.info(f"The terminal is at: {output}")
 
-    def run_command(self, cmd: str, daemon=False) -> str:
+    async def run_command(self, cmd: str, daemon=False) -> str:
         """
         Executes a specified command in the terminal and streams the output back in real time.
         This command maintains state across executions, such as the current directory,
-        allowing for sequential commands to be contextually aware. The output from the
-        command execution is placed into `stdout_queue`, which can be consumed as needed.
+        allowing for sequential commands to be contextually aware.
 
         Args:
             cmd (str): The command to execute in the terminal.
-            daemon (bool): If True, executes the command in a background thread, allowing
-                           the main program to continue execution. The command's output is
-                           collected asynchronously in daemon mode and placed into `stdout_queue`.
-
+            daemon (bool): If True, executes the command in an asynchronous task, allowing
+                           the main program to continue execution.
         Returns:
             str: The command's output or an empty string if `daemon` is True. Remember that
-                 when `daemon` is True, the output is collected into `stdout_queue` and must
-                 be consumed from there.
-
-        Note:
-            If `stdout_queue` is not periodically consumed, it could potentially grow indefinitely,
-            consuming memory. Ensure that there's a mechanism in place to consume this queue,
-            especially during long-running or output-heavy command executions.
+                 when `daemon` is True, use the `get_stdout_output` method to get the output.
         """
+        if self.process is None:
+            await self._start_process()
 
         # Send the command
         self.process.stdin.write((cmd + self.command_terminator).encode())
         self.process.stdin.write(
-            (f'echo "{END_MARKER_VALUE}"{self.command_terminator}').encode()  # write EOF
+            f'echo "{END_MARKER_VALUE}"{self.command_terminator}'.encode()  # write EOF
         )  # Unique marker to signal command end
-        self.process.stdin.flush()
+        await self.process.stdin.drain()
         if daemon:
-            threading.Thread(target=self._read_and_process_output, args=(cmd,), daemon=True).start()
+            asyncio.create_task(self._read_and_process_output(cmd, daemon=True))
             return ""
         else:
-            return self._read_and_process_output(cmd)
+            return await self._read_and_process_output(cmd)
 
-    def execute_in_conda_env(self, cmd: str, env, daemon=False) -> str:
+    async def execute_in_conda_env(self, cmd: str, env, daemon=False) -> str:
         """
         Executes a given command within a specified Conda environment automatically without
         the need for manual activation. Users just need to provide the name of the Conda
@@ -84,7 +78,7 @@ class Terminal:
             cmd (str): The command to execute within the Conda environment.
             env (str, optional): The name of the Conda environment to activate before executing the command.
                                  If not specified, the command will run in the current active environment.
-            daemon (bool): If True, the command is run in a background thread, similar to `run_command`,
+            daemon (bool): If True, the command is run in an asynchronous task, similar to `run_command`,
                            affecting error logging and handling in the same manner.
 
         Returns:
@@ -96,19 +90,32 @@ class Terminal:
             to ensure the specified environment is active for the command's execution.
         """
         cmd = f"conda run -n {env} {cmd}"
-        return self.run_command(cmd, daemon=daemon)
+        return await self.run_command(cmd, daemon=daemon)
 
-    def _read_and_process_output(self, cmd):
-        with self.observer as observer:
+    async def get_stdout_output(self) -> str:
+        """
+        Retrieves all collected output from background running commands and returns it as a string.
+
+        Returns:
+            str: The collected output from background running commands, returned as a string.
+        """
+        output_lines = []
+        while not self.stdout_queue.empty():
+            line = await self.stdout_queue.get()
+            output_lines.append(line)
+        return "\n".join(output_lines)
+
+    async def _read_and_process_output(self, cmd, daemon=False) -> str:
+        async with self.observer as observer:
             cmd_output = []
-            observer.report(cmd + self.command_terminator, "cmd")
-            # report the comman
+            await observer.async_report(cmd + self.command_terminator, "cmd")
+            # report the command
             # Read the output until the unique marker is found.
             # We read bytes directly from stdout instead of text because when reading text,
             # '\r' is changed to '\n', resulting in excessive output.
             tmp = b""
             while True:
-                output = tmp + self.process.stdout.read(1)
+                output = tmp + await self.process.stdout.read(1)
                 *lines, tmp = output.splitlines(True)
                 for line in lines:
                     line = line.decode()
@@ -123,13 +130,13 @@ class Terminal:
                     # log stdout in real-time
                     observer.report(line, "output")
                     cmd_output.append(line)
-                    self.stdout_queue.put(line)
+                    if daemon:
+                        await self.stdout_queue.put(line)
 
-    def close(self):
+    async def close(self):
         """Close the persistent shell process."""
         self.process.stdin.close()
-        self.process.terminate()
-        self.process.wait()
+        await self.process.wait()
 
 
 @register_tool(include_functions=["run"])
@@ -142,10 +149,13 @@ class Bash(Terminal):
     def __init__(self):
         """init"""
         super().__init__()
-        self.run_command(f"cd {DEFAULT_WORKSPACE_ROOT}")
-        self.run_command(f"source {SWE_SETUP_PATH}")
+        self.start_flag = False
 
-    def run(self, cmd) -> str:
+    async def start(self):
+        await self.run_command(f"cd {DEFAULT_WORKSPACE_ROOT}")
+        await self.run_command(f"source {SWE_SETUP_PATH}")
+
+    async def run(self, cmd) -> str:
         """
         Executes a bash command.
 
@@ -222,4 +232,8 @@ class Bash(Terminal):
 
         Note: Make sure to use these functions as per their defined arguments and behaviors.
         """
-        return self.run_command(cmd)
+        if not self.start_flag:
+            await self.start()
+            self.start_flag = True
+
+        return await self.run_command(cmd)
diff --git a/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py b/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py
index 54b3623a4..e2aa3d17f 100644
--- a/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py
+++ b/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py
@@ -59,11 +59,11 @@ async def run(instance, swe_result_dir):
 
     # 前处理
     terminal = Terminal()
-    terminal.run_command(f"cd {repo_path} && git reset --hard && git clean -n -d && git clean -f -d")
-    terminal.run_command("BRANCH=$(git remote show origin | awk '/HEAD branch/ {print $NF}')")
-    logger.info(terminal.run_command("echo $BRANCH"))
-    logger.info(terminal.run_command('git checkout "$BRANCH"'))
-    logger.info(terminal.run_command("git branch"))
+    await terminal.run_command(f"cd {repo_path} && git reset --hard && git clean -n -d && git clean -f -d")
+    await terminal.run_command("BRANCH=$(git remote show origin | awk '/HEAD branch/ {print $NF}')")
+    logger.info(await terminal.run_command("echo $BRANCH"))
+    logger.info(await terminal.run_command('git checkout "$BRANCH"'))
+    logger.info(await terminal.run_command("git branch"))
 
     user_requirement_and_issue = INSTANCE_TEMPLATE.format(
         issue=instance["problem_statement"],
diff --git a/tests/metagpt/tools/libs/test_terminal.py b/tests/metagpt/tools/libs/test_terminal.py
index 98ed63dd8..9c64009ae 100644
--- a/tests/metagpt/tools/libs/test_terminal.py
+++ b/tests/metagpt/tools/libs/test_terminal.py
@@ -4,16 +4,17 @@ from metagpt.const import DATA_PATH, METAGPT_ROOT
 from metagpt.tools.libs.terminal import Terminal
 
 
-def test_terminal():
+@pytest.mark.asyncio
+async def test_terminal():
     terminal = Terminal()
 
-    terminal.run_command(f"cd {METAGPT_ROOT}")
-    output = terminal.run_command("pwd")
+    await terminal.run_command(f"cd {METAGPT_ROOT}")
+    output = await terminal.run_command("pwd")
     assert output.strip() == str(METAGPT_ROOT)
 
     # pwd now should be METAGPT_ROOT, cd data should land in DATA_PATH
-    terminal.run_command("cd data")
-    output = terminal.run_command("pwd")
+    await terminal.run_command("cd data")
+    output = await terminal.run_command("pwd")
     assert output.strip() == str(DATA_PATH)
 
 