mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-25 00:36:55 +02:00
Merge pull request #1077 from stellaHSR/swebench_di
[bug fix] add code files for load_dataset and oracle_collapsed generation
This commit is contained in:
commit
a44a2c8e49
21 changed files with 181 additions and 15 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -188,3 +188,4 @@ cov.xml
|
|||
*-structure.json
|
||||
*.dot
|
||||
.python-version
|
||||
/data/inference
|
||||
|
|
|
|||
|
Can't render this file because it contains an unexpected character in line 2 and column 280.
|
|
Can't render this file because it has a wrong number of fields in line 2.
|
35
benchmark/swe_bench/data/load_dataset.py
Normal file
35
benchmark/swe_bench/data/load_dataset.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from benchmark.swe_bench.inference.const import SCIKIT_LEARN_IDS
|
||||
|
||||
|
||||
def load_oracle_dataset(dataset_name_or_path: str = "", split: str = "test", existing_ids: list = None):
    """Load a SWE-bench "oracle" dataset split, sorted by prompt length.

    Args:
        dataset_name_or_path: Local path (loaded from disk if it exists) or a
            HuggingFace dataset name (downloaded via ``load_dataset``).
        split: Split to select; must be present in the loaded dataset.
        existing_ids: Instance ids already processed; matching instances are
            filtered out. Defaults to None (no filtering).

    Returns:
        The selected split, sorted ascending by ``len(text)`` and, when
        SCIKIT_LEARN_IDS is non-empty, restricted to that subset.

    Raises:
        ValueError: If ``split`` is not present in the loaded dataset.
    """
    # BUG FIX: the original signature used a mutable default (existing_ids=[]),
    # which is shared across calls; use None as the sentinel instead.
    existing_ids = existing_ids or []
    if Path(dataset_name_or_path).exists():
        dataset = load_from_disk(dataset_name_or_path)
    else:
        dataset = load_dataset(dataset_name_or_path)
    if split not in dataset:
        raise ValueError(f"Invalid split {split} for dataset {dataset_name_or_path}")
    dataset = dataset[split]
    # Sort instances shortest-prompt-first so cheaper items are processed earlier.
    lens = np.array(list(map(len, dataset["text"])))
    dataset = dataset.select(np.argsort(lens))

    if existing_ids:
        dataset = dataset.filter(
            lambda x: x["instance_id"] not in existing_ids,
            desc="Filtering out existing ids",
            load_from_cache_file=False,
        )
    if SCIKIT_LEARN_IDS:
        dataset = dataset.filter(
            lambda x: x["instance_id"] in SCIKIT_LEARN_IDS,
            desc="Filtering out subset_instance_ids",
            load_from_cache_file=False,
        )
    return dataset
|
||||
|
|
@ -3,11 +3,11 @@
|
|||
# @Desc :
|
||||
import pandas as pd
|
||||
|
||||
from metagpt.const import DATA_PATH, METAGPT_ROOT
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
|
||||
SUBSET_DATASET = METAGPT_ROOT / "sub_swebench_dataset" / "sub_swebench.csv"
|
||||
SUBSET_DATASET_SKLERARN = METAGPT_ROOT / "sub_swebench_dataset" / "scikit-learn-68.csv"
|
||||
TESTBED = DATA_PATH / "repos"
|
||||
SUBSET_DATASET = METAGPT_ROOT / "benchmark" / "swe_bench" / "sub_swebench_dataset" / "sub_swebench.csv"
|
||||
SUBSET_DATASET_SKLERARN = METAGPT_ROOT / "benchmark" / "sub_swebench_dataset" / "scikit-learn-68.csv"
|
||||
TESTBED = METAGPT_ROOT / "benchmark" / "swe_bench" / "data" / "repos"
|
||||
|
||||
# SCIKIT_LEARN_IDS: A list of instance identifiers from 'sub_swebench.csv' within SUBSET_DATASET.
|
||||
# This collection represents a subset specifically related to scikit-learn content.
|
||||
|
|
@ -5,12 +5,12 @@ import re
|
|||
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from benchmark.swe_bench.gitagent import GitAgent
|
||||
from benchmark.swe_bench.make_datasets.make_dataset import reset_task_env
|
||||
from benchmark.swe_bench.utils.utils import extract_scripts_from_codetext
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.recovery_util import save_history
|
||||
from swe_bench.gitagent import GitAgent
|
||||
from swe_bench.make_datasets.make_dataset import reset_task_env
|
||||
from swe_bench.utils.utils import extract_scripts_from_codetext
|
||||
|
||||
PATCH_FORMAT = """
|
||||
```diff
|
||||
|
|
@ -2,14 +2,14 @@ import json
|
|||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
from data.load_dataset import load_oracle_dataset
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from benchmark.swe_bench.data.load_dataset import load_oracle_dataset
|
||||
from benchmark.swe_bench.inference.run_agent import run_instance
|
||||
from benchmark.swe_bench.utils.utils import check_existing_ids, extract_diff
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils import count_string_tokens
|
||||
from swe_bench.inference.run_agent import run_instance
|
||||
from swe_bench.utils.utils import check_existing_ids, extract_diff
|
||||
|
||||
# Replace with your own
|
||||
MAX_TOKEN = 128000
|
||||
|
|
@ -56,7 +56,7 @@ async def openai_inference(
|
|||
logger.info(f"{repo_prefix}_{version}")
|
||||
data.append(f"{repo_prefix}_{version}")
|
||||
|
||||
response = await run_instance(instance=datum)
|
||||
response = await run_instance(instance=datum, use_reflection=use_reflection)
|
||||
if response is None:
|
||||
continue
|
||||
logger.info(f"Final response: {response}")
|
||||
|
|
@ -6,11 +6,11 @@ from pathlib import Path
|
|||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from data.inference.const import TESTBED
|
||||
from benchmark.swe_bench.inference.const import TESTBED
|
||||
from benchmark.swe_bench.make_datasets.make_instance import prompt_style_2_edits_only
|
||||
from benchmark.swe_bench.utils.parse_diff import filter_changed_line
|
||||
from benchmark.swe_bench.utils.repo_utils import EnvManager
|
||||
from metagpt.logs import logger
|
||||
from swe_bench.make_datasets.make_instance import prompt_style_2_edits_only
|
||||
from swe_bench.utils.parse_diff import filter_changed_line
|
||||
from swe_bench.utils.repo_utils import EnvManager
|
||||
|
||||
|
||||
def reset_task_env(instance: dict = {}):
|
||||
3
benchmark/swe_bench/utils/__init__.py
Normal file
3
benchmark/swe_bench/utils/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
12
benchmark/swe_bench/utils/enums.py
Normal file
12
benchmark/swe_bench/utils/enums.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from enum import Enum, auto
|
||||
|
||||
|
||||
class FileChangeMode(Enum):
    """How a patch affects a file: created, deleted, or modified in place."""

    create = auto()  # file added ("--- /dev/null" header side)
    delete = auto()  # file removed ("+++ /dev/null" header side)
    change = auto()  # existing file modified
|
||||
|
||||
|
||||
class LineChangeType(Enum):
    """Kind of a single line change inside a diff hunk."""

    addition = auto()  # line present only after the patch ("+" prefix)
    deletion = auto()  # line present only before the patch ("-" prefix)
|
||||
115
benchmark/swe_bench/utils/parse_diff.py
Normal file
115
benchmark/swe_bench/utils/parse_diff.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
from benchmark.swe_bench.utils.enums import FileChangeMode, LineChangeType
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def extract_changes_from_patch(diff: str) -> List[Dict[str, any]]:
    """Parse git-diff text into per-file change records.

    Walks the unified-diff output line by line, opening a new record at each
    ``diff --git`` header, classifying the file as created/deleted/changed,
    and delegating line-level bookkeeping to ``update_file_changes``.

    Args:
        diff: A string containing the output of git diff.

    Returns:
        A list of dictionaries, one per changed file (see
        ``start_new_file_section`` for the record layout).
    """
    header_re = re.compile(r"^diff --git a/(.+) b/(.+)$")
    hunk_re = re.compile(r"^@@ -(\d+),\d+ \+(\d+),\d+ @@.*$")
    created_marker = "--- /dev/null"
    deleted_marker = "+++ /dev/null"

    parsed: List[Dict[str, any]] = []
    record = None

    for line in diff.splitlines():
        header = header_re.match(line)
        if header:
            # New file section: flush the previous record, start a fresh one.
            if record:
                parsed.append(record)
            before, after = header.groups()
            record = start_new_file_section(before, after)
            record["mode"] = FileChangeMode.change
        elif record:
            # "/dev/null" on either side marks file creation or deletion.
            if line == created_marker:
                record["mode"] = FileChangeMode.create
            elif line == deleted_marker:
                record["mode"] = FileChangeMode.delete
            update_file_changes(record, line, hunk_re)

    if record:
        parsed.append(record)

    return parsed
|
||||
|
||||
|
||||
def start_new_file_section(file_before_change: str, file_after_change: str) -> Dict[str, any]:
    """Create a fresh change record for one file in a diff.

    Called whenever a new ``diff --git`` header is encountered, to initialize
    the dictionary that will accumulate that file's change information.

    Args:
        file_before_change: File name on the ``a/`` side of the diff.
        file_after_change: File name on the ``b/`` side, or "/dev/null" if the
            file was deleted.

    Returns:
        A dict holding both file names and an empty ``changes`` list.
    """
    record: Dict[str, any] = dict(
        file_before_change=file_before_change,
        file_after_change=file_after_change,
    )
    record["changes"] = []
    return record
|
||||
|
||||
|
||||
def update_file_changes(current_file: Dict[str, any], line: str, line_change_pattern: re.Pattern):
    """Update one file's change record from a single diff line.

    Tracks hunk headers (resetting the line counters to the hunk start
    positions), records added and removed lines with their positions, and
    advances the counters on context lines.

    Args:
        current_file: The change record being built for the current file.
        line: The current raw line from the diff.
        line_change_pattern: Compiled regex matching ``@@ -a,b +c,d @@`` hunk headers.
    """
    hunk = line_change_pattern.match(line)
    if hunk:
        # A new hunk resets both counters to the hunk's start line numbers.
        current_file["base_line"], current_file["changed_line"] = map(int, hunk.groups())
        return
    # BUG FIX: "--- a/..." and "+++ b/..." file-header lines start with -/+
    # and were previously recorded as content changes; skip them explicitly.
    if line.startswith("+++") or line.startswith("---"):
        return
    if line.startswith("+"):
        current_file["changes"].append(
            {"type": LineChangeType.addition, "line": current_file.get("changed_line", 1), "content": line[1:]}
        )
        current_file["changed_line"] = current_file.get("changed_line", 0) + 1
    elif line.startswith("-"):
        current_file["changes"].append(
            {"type": LineChangeType.deletion, "line": current_file.get("base_line", 1), "content": line[1:]}
        )
        current_file["base_line"] = current_file.get("base_line", 0) + 1
    elif line.startswith(" "):
        # BUG FIX: unchanged context lines advance BOTH sides of the diff;
        # without this, recorded line numbers drift after the first context line.
        current_file["base_line"] = current_file.get("base_line", 0) + 1
        current_file["changed_line"] = current_file.get("changed_line", 0) + 1
|
||||
|
||||
|
||||
def filter_changed_line(patch):
    """Collect the deletion records per file from a git patch.

    Additions are ignored, and newly created files contribute an empty list:
    only lines that existed before the patch (deletions) are kept, keyed by
    the pre-change file name.

    Args:
        patch: The git diff text.
    """
    result = {}
    for file_change in extract_changes_from_patch(patch):
        name = file_change["file_before_change"]
        result[name] = []
        # Newly added files have no pre-existing lines worth keeping.
        if file_change["mode"] is FileChangeMode.create:
            continue
        for entry in file_change["changes"]:
            if entry["type"] is LineChangeType.addition:
                continue
            logger.debug(f" {entry['type']} at line {entry['line']}: {entry['content']}")
            result[name].append(entry)
    return result
|
||||
Loading…
Add table
Add a link
Reference in a new issue