From 78bfc9b2898bf122d21ed8842f2916f506618a71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E4=BC=9F=E9=9F=AC?= Date: Fri, 6 Sep 2024 20:00:24 +0800 Subject: [PATCH] update run_swe_agent_benchmark --- .../roles/di/run_swe_agent_for_benchmark.py | 82 +++++++++++-------- 1 file changed, 47 insertions(+), 35 deletions(-) diff --git a/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py b/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py index c0a11c609..54d16d146 100644 --- a/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py +++ b/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py @@ -17,7 +17,6 @@ from metagpt.tools.swe_agent_commands.swe_agent_utils import load_hf_dataset config = Config.default() # Specify by yourself -GLOBAL_TERMINAL = Terminal() TEST_REPO_DIR = METAGPT_ROOT / "data" / "test_repo" DATA_DIR = METAGPT_ROOT / "data/hugging_face" @@ -58,51 +57,65 @@ def check_instance_status(instance, swe_result_dir): return True -async def terminal_run_command(cmd): - cmd_output = await GLOBAL_TERMINAL.run_command(cmd) +async def terminal_run_command(cmd, terminal): + cmd_output = await terminal.run_command(cmd) logger.info(f"command:{cmd} output:\n {cmd_output}") return cmd_output async def refresh_repo(instance, test_repo_dir, reclone_existing_repo=False): - repo_path = Path(test_repo_dir) / ( - instance["repo"].replace("-", "_").replace("/", "__") + "_" + instance["version"] - ) - repo_identifier = instance["repo"] - base_commit = instance["base_commit"] - if os.path.exists(repo_path) and reclone_existing_repo is True: - logger.info(f"remove exist repo path:{repo_path.absolute()}") - shutil.rmtree(repo_path) - - if os.path.exists(repo_path): - logger.info(f"reset exist repo path:{repo_path.absolute()}") - await terminal_run_command( - f"cd {repo_path.absolute()} && git reset --hard && git clean -n -d && git clean -f -d" + terminal = Terminal() + try: + repo_path = Path(test_repo_dir) / ( + instance["repo"].replace("-", "_").replace("/", "__") + "_" + instance["version"] ) - await terminal_run_command("BRANCH=$(git remote show origin | awk '/HEAD branch/ {print $NF}')") - await terminal_run_command("echo $BRANCH") - await terminal_run_command('git checkout "$BRANCH"') - else: - logger.info(f"clone repo to path:{repo_path}") - clone_command = f"git clone 'https://github.com/{repo_identifier}.git' {repo_path.absolute()}" - checkout_command = f"cd {repo_path.absolute()} " + f"&& git checkout -f {base_commit}" if base_commit else "" - await terminal_run_command(clone_command) - await terminal_run_command(checkout_command) - - await terminal_run_command("git branch") + repo_identifier = instance["repo"] + base_commit = instance["base_commit"] + if os.path.exists(repo_path) and reclone_existing_repo is True: + logger.info(f"remove exist repo path:{repo_path.absolute()}") + shutil.rmtree(repo_path) + if os.path.exists(repo_path): + logger.info(f"reset exist repo path:{repo_path.absolute()}") + await terminal_run_command( + f"cd {repo_path.absolute()} && git reset --hard && git clean -n -d && git clean -f -d", terminal + ) + await terminal_run_command("BRANCH=$(git remote show origin | awk '/HEAD branch/ {print $NF}')", terminal) + await terminal_run_command("echo $BRANCH", terminal) + await terminal_run_command('git checkout "$BRANCH"', terminal) + else: + logger.info(f"clone repo to path:{repo_path}") + clone_command = f"git clone 'https://github.com/{repo_identifier}.git' {repo_path.absolute()}" + checkout_command = ( + f"cd {repo_path.absolute()} " + f"&& git checkout -f {base_commit}" if base_commit else "" + ) + await terminal_run_command(clone_command, terminal) + await terminal_run_command(checkout_command, terminal) + await terminal_run_command("git branch", terminal) + await terminal_run_command("pwd", terminal) + except Exception as e: + logger.warning(e) + finally: + terminal.close() return repo_path -async def get_git_diff(): +async def get_git_diff(instance, test_repo_dir): git_diff = "" + terminal = Terminal() try: + repo_path = Path(test_repo_dir) / ( + instance["repo"].replace("-", "_").replace("/", "__") + "_" + instance["version"] + ) # ignore backup file - await terminal_run_command("echo '.backup.*' >> .gitignore") - await terminal_run_command("git add -A") - git_diff = await terminal_run_command("git diff --cached") + await terminal_run_command(f"cd {repo_path.absolute()} ", terminal) + await terminal_run_command("echo '.backup.*' >> .gitignore", terminal) + await terminal_run_command("git add -A", terminal) + git_diff = await terminal_run_command("git diff --cached", terminal) except Exception as e: logger.error(f"Error during submission: {e}") + finally: + terminal.close() return git_diff @@ -133,14 +146,14 @@ async def run(instance, swe_result_dir, args): except Exception as e: logger.warning(f"**** exception lead to end: {instance['instance_id']}****\n\nerror:{e}") # save the difference of repo - await save_predictions(engineer, instance, swe_result_dir) + await save_predictions(engineer, instance, test_repo_dir, swe_result_dir) logger.info(f"**** Finished running {instance['instance_id']}****") -async def save_predictions(engineer, instance, swe_result_dir): +async def save_predictions(engineer, instance, test_repo_dir, swe_result_dir): output_file = swe_result_dir / "all_preds.jsonl" instance["model_name_or_path"] = engineer.config.llm.model - instance["model_patch"] = await get_git_diff() + instance["model_patch"] = await get_git_diff(instance, test_repo_dir) logger.info(f"'model_patch':\n{instance['model_patch']}") logger.info(f"Preparing to save predictions to {output_file}") @@ -207,5 +220,4 @@ python tests/metagpt/roles/di/run_swe_agent_for_benchmark.py \ --test_repo_dir "./data/test_repo" \ --save_folder "./workspace/deepseek_coder_0907" \ --max_wait_time_per_case 10 \ ---reclone_existing_repo """