Merge pull request #1077 from stellaHSR/swebench_di

[bug fix] add code files for load_dataset and oracle_collapsed generation
This commit is contained in:
better629 2024-03-27 14:29:45 +08:00 committed by GitHub
commit a44a2c8e49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 181 additions and 15 deletions

1
.gitignore vendored
View file

@ -188,3 +188,4 @@ cov.xml
*-structure.json
*.dot
.python-version
/data/inference

View file

Can't render this file because it contains an unexpected character in line 2 and column 280.

View file

Can't render this file because it has a wrong number of fields in line 2.

View file

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
from pathlib import Path
import numpy as np
from datasets import load_dataset, load_from_disk
from benchmark.swe_bench.inference.const import SCIKIT_LEARN_IDS
def load_oracle_dataset(dataset_name_or_path: str = "", split: str = "test", existing_ids=None):
    """Load a SWE-bench "oracle" dataset and return one split, sorted and filtered.

    Args:
        dataset_name_or_path: Either a local directory produced by
            ``Dataset.save_to_disk`` or a dataset name understood by
            ``datasets.load_dataset``.
        split: Split to select; must exist in the loaded dataset.
        existing_ids: Optional list of ``instance_id`` values to exclude
            (e.g. instances that already have predictions). ``None`` means
            no exclusion.

    Returns:
        The selected split, sorted by ascending length of the ``text`` field
        and, when ``SCIKIT_LEARN_IDS`` is non-empty, restricted to that subset.

    Raises:
        ValueError: If ``split`` is not present in the loaded dataset.
    """
    # Fix: the old signature used a mutable default (existing_ids=[]), which
    # is shared across calls; use None as the sentinel instead.
    existing_ids = existing_ids or []

    # A path that exists on disk means a dataset saved locally; anything else
    # is treated as a hub dataset name.
    if Path(dataset_name_or_path).exists():
        dataset = load_from_disk(dataset_name_or_path)
    else:
        dataset = load_dataset(dataset_name_or_path)

    if split not in dataset:
        raise ValueError(f"Invalid split {split} for dataset {dataset_name_or_path}")
    dataset = dataset[split]

    # Process shortest prompts first: sort instances by length of "text".
    lens = np.array(list(map(len, dataset["text"])))
    dataset = dataset.select(np.argsort(lens))

    if existing_ids:
        dataset = dataset.filter(
            lambda x: x["instance_id"] not in existing_ids,
            desc="Filtering out existing ids",
            load_from_cache_file=False,
        )

    # Restrict to the scikit-learn subset when one is configured.
    if SCIKIT_LEARN_IDS:
        dataset = dataset.filter(
            lambda x: x["instance_id"] in SCIKIT_LEARN_IDS,
            desc="Filtering out subset_instance_ids",
            load_from_cache_file=False,
        )
    return dataset

View file

@ -3,11 +3,11 @@
# @Desc :
import pandas as pd
from metagpt.const import DATA_PATH, METAGPT_ROOT
from metagpt.const import METAGPT_ROOT
SUBSET_DATASET = METAGPT_ROOT / "sub_swebench_dataset" / "sub_swebench.csv"
SUBSET_DATASET_SKLERARN = METAGPT_ROOT / "sub_swebench_dataset" / "scikit-learn-68.csv"
TESTBED = DATA_PATH / "repos"
SUBSET_DATASET = METAGPT_ROOT / "benchmark" / "swe_bench" / "sub_swebench_dataset" / "sub_swebench.csv"
SUBSET_DATASET_SKLERARN = METAGPT_ROOT / "benchmark" / "sub_swebench_dataset" / "scikit-learn-68.csv"
TESTBED = METAGPT_ROOT / "benchmark" / "swe_bench" / "data" / "repos"
# SCIKIT_LEARN_IDS: A list of instance identifiers from 'sub_swebench.csv' within SUBSET_DATASET.
# This collection represents a subset specifically related to scikit-learn content.

View file

@ -5,12 +5,12 @@ import re
from tenacity import retry, stop_after_attempt, wait_random_exponential
from benchmark.swe_bench.gitagent import GitAgent
from benchmark.swe_bench.make_datasets.make_dataset import reset_task_env
from benchmark.swe_bench.utils.utils import extract_scripts_from_codetext
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.recovery_util import save_history
from swe_bench.gitagent import GitAgent
from swe_bench.make_datasets.make_dataset import reset_task_env
from swe_bench.utils.utils import extract_scripts_from_codetext
PATCH_FORMAT = """
```diff

View file

@ -2,14 +2,14 @@ import json
from pathlib import Path
import fire
from data.load_dataset import load_oracle_dataset
from tqdm.auto import tqdm
from benchmark.swe_bench.data.load_dataset import load_oracle_dataset
from benchmark.swe_bench.inference.run_agent import run_instance
from benchmark.swe_bench.utils.utils import check_existing_ids, extract_diff
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.utils import count_string_tokens
from swe_bench.inference.run_agent import run_instance
from swe_bench.utils.utils import check_existing_ids, extract_diff
# Replace with your own
MAX_TOKEN = 128000
@ -56,7 +56,7 @@ async def openai_inference(
logger.info(f"{repo_prefix}_{version}")
data.append(f"{repo_prefix}_{version}")
response = await run_instance(instance=datum)
response = await run_instance(instance=datum, use_reflection=use_reflection)
if response is None:
continue
logger.info(f"Final response: {response}")

View file

@ -6,11 +6,11 @@ from pathlib import Path
from tqdm.auto import tqdm
from data.inference.const import TESTBED
from benchmark.swe_bench.inference.const import TESTBED
from benchmark.swe_bench.make_datasets.make_instance import prompt_style_2_edits_only
from benchmark.swe_bench.utils.parse_diff import filter_changed_line
from benchmark.swe_bench.utils.repo_utils import EnvManager
from metagpt.logs import logger
from swe_bench.make_datasets.make_instance import prompt_style_2_edits_only
from swe_bench.utils.parse_diff import filter_changed_line
from swe_bench.utils.repo_utils import EnvManager
def reset_task_env(instance: dict = {}):

View file

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -0,0 +1,12 @@
from enum import Enum, auto
# How a file was affected by a patch: newly added, removed, or modified
# in place. The functional API auto-assigns values in declaration order
# (create=1, delete=2, change=3), identical to the original ``auto()`` form.
FileChangeMode = Enum("FileChangeMode", ["create", "delete", "change"])
# Kind of a single changed line inside a diff hunk: a "+" line or a "-" line.
# Functional API assigns addition=1, deletion=2, same as the ``auto()`` form.
LineChangeType = Enum("LineChangeType", ["addition", "deletion"])

View file

@ -0,0 +1,115 @@
import re
from typing import Any, Dict, List

from benchmark.swe_bench.utils.enums import FileChangeMode, LineChangeType
from metagpt.logs import logger
def extract_changes_from_patch(diff: str) -> List[Dict[str, Any]]:
    """Parse ``git diff`` output into a structured list of per-file changes.

    Args:
        diff: A string containing the output of ``git diff``.

    Returns:
        A list of dictionaries, one per changed file, each holding the file
        names before/after the change, a ``FileChangeMode``, and the recorded
        added/deleted lines.
    """
    changes = []
    current_file = None
    file_pattern = re.compile(r"^diff --git a/(.+) b/(.+)$")
    # Fix: the ",<count>" part of a hunk header is omitted when the range is
    # a single line (e.g. "@@ -1 +1 @@"), so it must be optional here.
    line_change_pattern = re.compile(r"^@@ -(\d+)(?:,\d+)? \+(\d+)(?:,\d+)? @@.*$")
    new_file_flag_line = "--- /dev/null"
    deleted_file_flag_line = "+++ /dev/null"

    for line in diff.splitlines():
        file_section_start = file_pattern.match(line)
        if file_section_start:
            # A new "diff --git" header: flush the previous file's record.
            if current_file:
                changes.append(current_file)
            file_a, file_b = file_section_start.groups()
            current_file = start_new_file_section(file_a, file_b)
            current_file["mode"] = FileChangeMode.change
        elif current_file:
            # "--- /dev/null" means the file did not exist before: a creation.
            if line == new_file_flag_line:
                current_file["mode"] = FileChangeMode.create
            # "+++ /dev/null" means the file no longer exists: a deletion.
            elif line == deleted_file_flag_line:
                current_file["mode"] = FileChangeMode.delete
            update_file_changes(current_file, line, line_change_pattern)

    if current_file:
        changes.append(current_file)
    return changes
def start_new_file_section(file_before_change: str, file_after_change: str) -> Dict[str, Any]:
    """Initialize the record for a newly encountered file change.

    Called when a new ``diff --git`` header is seen; returns the dictionary
    that accumulates that file's change information.

    Args:
        file_before_change: The file name before the change.
        file_after_change: The file name after the change, or "/dev/null"
            if the file was deleted.

    Returns:
        A fresh change record with an empty list of line changes.
    """
    # Fix: annotate with typing.Any instead of the builtin `any` function.
    return {
        "file_before_change": file_before_change,
        "file_after_change": file_after_change,
        "changes": [],
    }
def update_file_changes(current_file: Dict[str, Any], line: str, line_change_pattern: re.Pattern):
    """Update the current file's change record with one raw line of the diff.

    Args:
        current_file: The change record being built for the current file.
        line: The current raw line from the diff.
        line_change_pattern: Compiled regex matching hunk headers ("@@ ... @@").
    """
    line_change_match = line_change_pattern.match(line)
    if line_change_match:
        # Hunk header: reset both counters to the hunk's starting lines.
        current_file["base_line"], current_file["changed_line"] = map(int, line_change_match.groups())
    elif line.startswith(("+++", "---")):
        # Fix: file-header lines ("--- a/...", "+++ b/...") also start with
        # -/+ but are metadata, not content changes; previously they were
        # recorded as spurious deletions/additions.
        return
    elif line.startswith("+"):
        current_file["changes"].append(
            {"type": LineChangeType.addition, "line": current_file.get("changed_line", 1), "content": line[1:]}
        )
        # An added line advances only the post-change line counter.
        current_file["changed_line"] = current_file.get("changed_line", 0) + 1
    elif line.startswith("-"):
        current_file["changes"].append(
            {"type": LineChangeType.deletion, "line": current_file.get("base_line", 1), "content": line[1:]}
        )
        # A deleted line advances only the pre-change line counter.
        current_file["base_line"] = current_file.get("base_line", 0) + 1
    elif line.startswith(" "):
        # Fix: context lines exist on both sides of the diff; without
        # advancing both counters, every change after a context line was
        # reported at the wrong line number.
        current_file["base_line"] = current_file.get("base_line", 0) + 1
        current_file["changed_line"] = current_file.get("changed_line", 0) + 1
def filter_changed_line(patch):
    """Collect, per file, the deleted lines of a patch.

    Args:
        patch: The raw ``git diff`` text.

    Returns:
        A mapping from each file's pre-change name to its list of deletion
        records; newly created files map to an empty list.
    """
    result = {}
    for file_change in extract_changes_from_patch(patch):
        name = file_change["file_before_change"]
        result[name] = []
        # Newly created files have no pre-existing lines to track.
        if file_change["mode"] is FileChangeMode.create:
            continue
        for record in file_change["changes"]:
            # Only deletions are of interest; additions are dropped.
            if record["type"] is LineChangeType.addition:
                continue
            logger.debug(f" {record['type']} at line {record['line']}: {record['content']}")
            result[name].append(record)
    return result