mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
update run_swe_agent_for_benchmark
This commit is contained in:
parent
2c045d6833
commit
47996da2b0
1 changed file with 44 additions and 17 deletions
|
|
@@ -1,16 +1,21 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.di.swe_agent import SWEAgent
|
||||
from metagpt.roles.di.engineer2 import Engineer2
|
||||
from metagpt.tools.libs.terminal import Terminal
|
||||
from metagpt.tools.swe_agent_commands.swe_agent_utils import load_hf_dataset
|
||||
|
||||
config = Config.default()
|
||||
# Specify by yourself
|
||||
Role = Engineer2
|
||||
MAX_MINUTES_PRE_INSTANCE = 20
|
||||
TEST_REPO_DIR = METAGPT_ROOT / "data" / "test_repo"
|
||||
DATA_DIR = METAGPT_ROOT / "data/hugging_face"
|
||||
|
||||
|
|
@@ -57,13 +62,20 @@ async def run(instance, swe_result_dir):
|
|||
return
|
||||
|
||||
repo_path = TEST_REPO_DIR / (instance["repo"].replace("-", "_").replace("/", "__") + "_" + instance["version"])
|
||||
|
||||
# 前处理
|
||||
# 下载仓库
|
||||
logger.info(f"repo_path:{repo_path}")
|
||||
if os.path.exists(repo_path):
|
||||
# 删除已有的仓库
|
||||
logger.info(f"remove exist repo path:{repo_path}")
|
||||
shutil.rmtree(repo_path)
|
||||
# 下载仓库 并切换分支
|
||||
terminal = Terminal()
|
||||
await terminal.run_command(f"cd {repo_path} && git reset --hard && git clean -n -d && git clean -f -d")
|
||||
await terminal.run_command("BRANCH=$(git remote show origin | awk '/HEAD branch/ {print $NF}')")
|
||||
logger.info(await terminal.run_command("echo $BRANCH"))
|
||||
logger.info(await terminal.run_command('git checkout "$BRANCH"'))
|
||||
repo_identifier = instance["repo"]
|
||||
base_commit = instance["base_commit"]
|
||||
clone_command = f"git clone 'https://github.com/{repo_identifier}.git' {repo_path}"
|
||||
checkout_command = f"cd {repo_path} && git checkout -f {base_commit}" if base_commit else ""
|
||||
await terminal.run_command(clone_command)
|
||||
logger.info(await terminal.run_command(checkout_command))
|
||||
logger.info(await terminal.run_command("git branch"))
|
||||
|
||||
user_requirement_and_issue = INSTANCE_TEMPLATE.format(
|
||||
|
|
@@ -75,18 +87,23 @@ async def run(instance, swe_result_dir):
|
|||
)
|
||||
|
||||
logger.info(f"**** Starting to run {instance['instance_id']}****")
|
||||
swe_agent = SWEAgent()
|
||||
swe_agent.run_eval = True
|
||||
await swe_agent.run(user_requirement_and_issue)
|
||||
save_predictions(swe_agent, instance, swe_result_dir)
|
||||
logger.info("User Requirement", user_requirement_and_issue)
|
||||
try:
|
||||
role = Role(run_eval=True)
|
||||
await asyncio.wait_for(role.run(user_requirement_and_issue), timeout=MAX_MINUTES_PRE_INSTANCE * 60)
|
||||
except:
|
||||
logger.info(f"**** exception lead to end: {instance['instance_id']}****")
|
||||
pass
|
||||
|
||||
save_predictions(role, instance, swe_result_dir)
|
||||
logger.info(f"**** Finished running {instance['instance_id']}****")
|
||||
|
||||
|
||||
def save_predictions(swe_agent: SWEAgent, instance, swe_result_dir):
|
||||
def save_predictions(role, instance, swe_result_dir):
|
||||
output_file = swe_result_dir / "all_preds.jsonl"
|
||||
instance["model_name_or_path"] = swe_agent.config.llm.model
|
||||
instance["model_patch"] = swe_agent.output_diff
|
||||
|
||||
instance["model_name_or_path"] = role.config.llm.model
|
||||
instance["model_patch"] = role.output_diff
|
||||
logger.info("model_patch", role.output_diff)
|
||||
logger.info(f"Preparing to save predictions to {output_file}")
|
||||
|
||||
# Save the predictions to a JSONL file
|
||||
|
|
@@ -102,11 +119,21 @@ async def async_main():
|
|||
dataset = load_hf_dataset(dataset_name_or_path=dataset_path, cache_dir=DATA_DIR, split="test")
|
||||
date_time = datetime.now().strftime("%m%d")
|
||||
_round = "first"
|
||||
# _round = "second"
|
||||
|
||||
exp_name = f"nano_mgx_{date_time}_{_round}"
|
||||
|
||||
# now = datetime.now()
|
||||
# formatted_time = now.strftime("%Y_%m_%d_%H_%M_%S")
|
||||
# swe_result_dir = (
|
||||
# DEFAULT_WORKSPACE_ROOT / f"result_{config.llm.model.replace('/', '_')}_start_time_{formatted_time}" / exp_name
|
||||
# )
|
||||
swe_result_dir = DEFAULT_WORKSPACE_ROOT / f"result_{config.llm.model.replace('/', '_')}" / exp_name
|
||||
swe_result_dir.mkdir(parents=True, exist_ok=True)
|
||||
for instance in dataset:
|
||||
for index, instance in enumerate(dataset):
|
||||
# switch to a new logger file
|
||||
logger.remove()
|
||||
logger.add(sys.stderr, level="INFO")
|
||||
logger.add(swe_result_dir / f"{index+1}_{instance['instance_id']}.log", level="DEBUG")
|
||||
await run(instance, swe_result_dir)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue