mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-17 15:35:21 +02:00
Merge pull request #1046 from stellaHSR/swebench_di
Add instance filter
This commit is contained in:
commit
c52dcc77d7
3 changed files with 78 additions and 0 deletions
53
data/inference/const.py
Normal file
53
data/inference/const.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pandas as pd
|
||||
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
|
||||
SUBSET_DATASET = METAGPT_ROOT / "sub_swebench_dataset" / "sub_swebench.csv"
|
||||
SUBSET_DATASET_SKLERARN = METAGPT_ROOT / "sub_swebench_dataset" / "scikit-learn-68.csv"
|
||||
|
||||
# SCIKIT_LEARN_IDS: A list of instance identifiers from 'sub_swebench.csv' within SUBSET_DATASET.
|
||||
# This collection represents a subset specifically related to scikit-learn content.
|
||||
SCIKIT_LEARN_IDS = [
|
||||
"scikit-learn__scikit-learn-11578",
|
||||
"scikit-learn__scikit-learn-10297",
|
||||
"scikit-learn__scikit-learn-25747",
|
||||
"scikit-learn__scikit-learn-15512",
|
||||
"scikit-learn__scikit-learn-15119",
|
||||
"scikit-learn__scikit-learn-10870",
|
||||
"scikit-learn__scikit-learn-15100",
|
||||
"scikit-learn__scikit-learn-14496",
|
||||
"scikit-learn__scikit-learn-14890",
|
||||
"scikit-learn__scikit-learn-10428",
|
||||
"scikit-learn__scikit-learn-25744",
|
||||
"scikit-learn__scikit-learn-11542",
|
||||
"scikit-learn__scikit-learn-10198",
|
||||
"scikit-learn__scikit-learn-10459",
|
||||
]
|
||||
|
||||
|
||||
def read_sub_set_instance(path=SUBSET_DATASET, tag="scikit-learn"):
|
||||
try:
|
||||
df = pd.read_excel(path)
|
||||
# Filter for instances containing the tag in either column
|
||||
pass_filter = df["instance_id_pass"].str.contains(tag, na=False)
|
||||
fail_filter = df["instance_id_fail"].str.contains(tag, na=False)
|
||||
|
||||
# Combine the filters using | (OR operator) for efficiency
|
||||
combined_filter = pass_filter | fail_filter
|
||||
|
||||
# Apply combined filter and select the specific columns
|
||||
filtered_df = df[combined_filter][["instance_id_pass", "instance_id_fail"]]
|
||||
|
||||
# Flatten the DataFrame into a list and remove NaN values
|
||||
subset_instance = filtered_df.stack().dropna().tolist()
|
||||
|
||||
return subset_instance
|
||||
except FileNotFoundError:
|
||||
print(f"File not found: {path}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
return []
|
||||
16
data/inference/run.py
Normal file
16
data/inference/run.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import runpy
|
||||
import sys
|
||||
|
||||
original_argv = sys.argv.copy()
|
||||
|
||||
try:
|
||||
# 设置你想要传递给脚本的命令行参数
|
||||
sys.argv = ["run_api.py", "--dataset_name_or_path", "princeton-nlp/SWE-bench_oracle", "--output_dir", "./outputs"]
|
||||
# 执行脚本
|
||||
runpy.run_path(path_name="run_api.py", run_name="__main__")
|
||||
finally:
|
||||
# 恢复原始的sys.argv以避免对后续代码的潜在影响
|
||||
sys.argv = original_argv
|
||||
|
|
@ -10,6 +10,7 @@ from make_datasets.utils import extract_diff
|
|||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from data.inference.const import SCIKIT_LEARN_IDS
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
|
|
@ -70,6 +71,7 @@ async def openai_inference(
|
|||
for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
|
||||
di = DataInterpreter(use_reflection=use_reflection)
|
||||
instance_id = datum["instance_id"]
|
||||
|
||||
if instance_id in existing_ids:
|
||||
continue
|
||||
output_dict = {"instance_id": instance_id}
|
||||
|
|
@ -124,12 +126,19 @@ async def main(
|
|||
dataset = dataset[split]
|
||||
lens = np.array(list(map(len, dataset["text"])))
|
||||
dataset = dataset.select(np.argsort(lens))
|
||||
|
||||
if len(existing_ids) > 0:
|
||||
dataset = dataset.filter(
|
||||
lambda x: x["instance_id"] not in existing_ids,
|
||||
desc="Filtering out existing ids",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
if len(SCIKIT_LEARN_IDS) > 0:
|
||||
dataset = dataset.filter(
|
||||
lambda x: x["instance_id"] in SCIKIT_LEARN_IDS,
|
||||
desc="Filtering out subset_instance_ids",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
inference_args = {
|
||||
"test_dataset": dataset,
|
||||
"model_name_or_path": model_name_or_path,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue