update run_swe_bechmark script

This commit is contained in:
黄伟韬 2024-09-06 12:04:40 +08:00
parent 32365ad85c
commit 4063186836
3 changed files with 65 additions and 56 deletions

View file

@ -68,26 +68,31 @@ class Engineer2(RoleZero):
def _update_tool_execution(self):
# validate = ValidateAndRewriteCode()
cr = CodeReview()
self.tool_execution_map.update(
{
"Terminal.run_command": self.terminal.run_command,
"git_create_pull": git_create_pull,
"Engineer2.write_new_code": self.write_new_code,
"CodeReview.review": cr.review,
"CodeReview.fix": cr.fix,
# "ValidateAndRewriteCode.run": validate.run,
# "ValidateAndRewriteCode": validate.run,
}
)
self.exclusive_tool_commands.append("Engineer2.write_new_code")
if self.run_eval:
if self.run_eval is True:
# Evalute tool map
self.tool_execution_map.update(
{
"git_create_pull": git_create_pull,
"Engineer2.write_new_code": self.write_new_code,
"CodeReview.review": cr.review,
"CodeReview.fix": cr.fix,
"Terminal.run_command": self._eval_terminal_run,
"RoleZero.ask_human": self._end,
"RoleZero.reply_to_human": self._end,
}
)
else:
# Default tool map
self.tool_execution_map.update(
{
"git_create_pull": git_create_pull,
"Engineer2.write_new_code": self.write_new_code,
"CodeReview.review": cr.review,
"CodeReview.fix": cr.fix,
"Terminal.run_command": self.terminal.run_command,
}
)
async def _act(self) -> Message:
message = await super()._act()
@ -108,6 +113,7 @@ class Engineer2(RoleZero):
async def write_new_code(self, path: str, instruction: str = "") -> str:
"""Write a new code file.
Args:
path (str): The absolute path of the file to be created.
instruction (optional, str): Further hints or notice other than the current task instruction, must be very concise and can be empty. Defaults to "".

View file

@ -537,15 +537,14 @@ class Editor(BaseModel):
content = "".join(new_lines)
return content, n_added_lines
def _get_indentation_info(self, content, first_error_line):
def _get_indentation_info(self, content, first_line):
"""
Information about the current edit's indentation.
Includes guidance on how to fix it.
The indentation of the first insert line and the previous line, along with guidance for the next attempt.
"""
content_lines = content.split("\n")
pre_line = content_lines[first_error_line - 2] if first_error_line - 2 >= 0 else ""
pre_line = content_lines[first_line - 2] if first_line - 2 >= 0 else ""
pre_line_indent = len(pre_line) - len(pre_line.lstrip())
insert_line = content_lines[first_error_line - 1]
insert_line = content_lines[first_line - 1]
insert_line_indent = len(insert_line) - len(insert_line.lstrip())
ret_str = INDENTATION_INFO.format(
pre_line=pre_line,
@ -802,8 +801,8 @@ class Editor(BaseModel):
new_content: str: The new content to replace the old content with.
NOTE:
This tool is exclusive. If you use this tool, you cannot use any other commands in the current response.
If you need to use it multiple times, wait for the next turn.
This tool is exclusive. If you use this tool, you cannot use any other commands in the current response.
If you need to use it multiple times, wait for the next turn.
"""
# FIXME: support replacing *all* occurrences
if to_replace.strip() == "":
@ -881,8 +880,8 @@ class Editor(BaseModel):
line_number: int: The line number (starting from 1) to insert the content after.
content: str: The content to insert.
NOTE:
This tool is exclusive. If you use this tool, you cannot use any other commands in the current response.
If you need to use it multiple times, wait for the next turn.
This tool is exclusive. If you use this tool, you cannot use any other commands in the current response.
If you need to use it multiple times, wait for the next turn.
"""
file_name = self._try_fix_path(file_name)
@ -904,8 +903,8 @@ class Editor(BaseModel):
file_name: str: The name of the file to edit.
content: str: The content to insert.
NOTE:
This tool is exclusive. If you use this tool, you cannot use any other commands in the current response.
If you need to use it multiple times, wait for the next turn.
This tool is exclusive. If you use this tool, you cannot use any other commands in the current response.
If you need to use it multiple times, wait for the next turn.
"""
file_name = self._try_fix_path(file_name)

View file

@ -64,23 +64,31 @@ async def terminal_run_command(cmd):
return cmd_output
async def refresh_repo(instance, test_repo_dir):
async def refresh_repo(instance, test_repo_dir, reclone_existing_repo=False):
repo_path = Path(test_repo_dir) / (
instance["repo"].replace("-", "_").replace("/", "__") + "_" + instance["version"]
)
repo_identifier = instance["repo"]
base_commit = instance["base_commit"]
clone_command = f"git clone 'https://github.com/{repo_identifier}.git' {repo_path}"
checkout_command = f"cd {repo_path} && git checkout -f {base_commit}" if base_commit else ""
if os.path.exists(repo_path):
# 删除已有的仓库
if os.path.exists(repo_path) and reclone_existing_repo is True:
logger.info(f"remove exist repo path:{repo_path}")
shutil.rmtree(repo_path)
await terminal_run_command(clone_command)
await terminal_run_command(checkout_command)
if os.path.exists(repo_path):
logger.info(f"reset exist repo path:{repo_path}")
await terminal_run_command(f"cd {repo_path} && git reset --hard && git clean -n -d && git clean -f -d")
await terminal_run_command("BRANCH=$(git remote show origin | awk '/HEAD branch/ {print $NF}')")
await terminal_run_command("echo $BRANCH")
await terminal_run_command('git checkout "$BRANCH"')
else:
logger.info(f"clone repo to path:{repo_path}")
clone_command = f"git clone 'https://github.com/{repo_identifier}.git' {repo_path}"
checkout_command = f"cd {repo_path} " + "&& git checkout -f {base_commit}" if base_commit else ""
await terminal_run_command(clone_command)
await terminal_run_command(checkout_command)
await terminal_run_command("git branch")
# ignore backup file
await terminal_run_command("echo '.backup.*' >> .gitignore")
return repo_path
@ -97,14 +105,14 @@ async def get_git_diff():
async def run(instance, swe_result_dir, args):
if not check_instance_status(instance, swe_result_dir) and not args.cover:
if not check_instance_status(instance, swe_result_dir):
logger.info(f"Instance {instance['instance_id']} already exists, skipping execution.")
return
# preparation for the repo
logger.info(f"**** Preparing to run {instance['instance_id']}****")
test_repo_dir = args.test_repo_dir
repo_path = await refresh_repo(instance, test_repo_dir)
repo_path = await refresh_repo(instance, test_repo_dir, args.reclone_existing_repo)
user_requirement_and_issue = INSTANCE_TEMPLATE.format(
issue=instance["problem_statement"],
@ -117,22 +125,20 @@ async def run(instance, swe_result_dir, args):
logger.info(f"**** Starting to run {instance['instance_id']}****")
logger.info("User Requirement", user_requirement_and_issue)
try:
role = Engineer2(run_eval=True, editor=Editor(enable_auto_lint=True))
await asyncio.wait_for(role.run(user_requirement_and_issue), timeout=args.max_wait_time_per_case * 60)
engineer = Engineer2(run_eval=True, editor=Editor(enable_auto_lint=True))
await asyncio.wait_for(engineer.run(user_requirement_and_issue), timeout=args.max_wait_time_per_case * 60)
except Exception as e:
print(e)
logger.info(f"**** exception lead to end: {instance['instance_id']}****")
pass
logger.warning(f"**** exception lead to end: {instance['instance_id']}****\n\nerror:{e}")
# save the difference of repo
await save_predictions(role, instance, swe_result_dir)
await save_predictions(engineer, instance, swe_result_dir)
logger.info(f"**** Finished running {instance['instance_id']}****")
async def save_predictions(role, instance, swe_result_dir):
async def save_predictions(engineer, instance, swe_result_dir):
output_file = swe_result_dir / "all_preds.jsonl"
instance["model_name_or_path"] = role.config.llm.model
instance["model_name_or_path"] = engineer.config.llm.model
instance["model_patch"] = await get_git_diff()
logger.info(f"{instance['model_patch']=}")
logger.info(f"'model_patch':\n{instance['model_patch']}")
logger.info(f"Preparing to save predictions to {output_file}")
# Save the predictions to a JSONL file
@ -147,16 +153,9 @@ async def async_main(args):
dataset = load_hf_dataset(dataset_name_or_path=dataset_path, cache_dir=DATA_DIR, split="test")
swe_result_dir = Path(args.save_folder)
if swe_result_dir.exists():
if args.cover:
logger.info(f"{swe_result_dir} exists and original result remove")
shutil.rmtree(swe_result_dir.absolute())
else:
logger.info(f"{swe_result_dir} exists and continue test")
logger.info(f"{swe_result_dir} exists; resuming test from last checkpoint.")
swe_result_dir.mkdir(parents=True, exist_ok=True)
for index, instance in enumerate(dataset):
if index < args.ignore_first_n:
continue
# switch to a new logger file
logger.remove()
logger.add(sys.stderr, level="INFO")
@ -180,25 +179,30 @@ if __name__ == "__main__":
parser.add_argument(
"-mwtc", "--max_wait_time_per_case", help="Maximum wait time allowed per test case (in minutes)", type=int
)
parser.add_argument("-n", "--ignore_first_n", default=0, help="Cover the original flag", type=int)
parser.add_argument("-c", "--cover", default=False, help="Cover the original flag", type=bool)
parser.add_argument(
"-o",
"--reclone_existing_repo",
action="store_true",
help="If set, the existing repository will be removed and recloned.",
)
# 解析命令行参数
args = parser.parse_args()
asyncio.run(async_main(args))
"""
#
python tests/metagpt/roles/di/run_swe_agent_for_benchmark.py \
--test_repo_dir "./data/test_repo" \
--save_folder "./workspace/deepseek_coder_test1" \
--save_folder "./workspace/deepseek_coder_0907" \
--max_wait_time_per_case 10
"""
"""
Cover Mode:
# 重新克隆仓库
python tests/metagpt/roles/di/run_swe_agent_for_benchmark.py \
--test_repo_dir "./data/test_repo" \
--save_folder "./workspace/deepseek_coder_test1" \
--save_folder "./workspace/deepseek_coder_0907" \
--max_wait_time_per_case 10 \
--cover
--reclone_existing_repo
"""