diff --git a/data/inference/const.py b/data/inference/const.py new file mode 100644 index 000000000..42e63ec0e --- /dev/null +++ b/data/inference/const.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pandas as pd + +from metagpt.const import METAGPT_ROOT + +SUBSET_DATASET = METAGPT_ROOT / "sub_swebench_dataset" / "sub_swebench.csv" +SUBSET_DATASET_SKLERARN = METAGPT_ROOT / "sub_swebench_dataset" / "scikit-learn-68.csv" + +# SCIKIT_LEARN_IDS: A list of instance identifiers from 'sub_swebench.csv' within SUBSET_DATASET. +# This collection represents a subset specifically related to scikit-learn content. +SCIKIT_LEARN_IDS = [ + "scikit-learn__scikit-learn-11578", + "scikit-learn__scikit-learn-10297", + "scikit-learn__scikit-learn-25747", + "scikit-learn__scikit-learn-15512", + "scikit-learn__scikit-learn-15119", + "scikit-learn__scikit-learn-10870", + "scikit-learn__scikit-learn-15100", + "scikit-learn__scikit-learn-14496", + "scikit-learn__scikit-learn-14890", + "scikit-learn__scikit-learn-10428", + "scikit-learn__scikit-learn-25744", + "scikit-learn__scikit-learn-11542", + "scikit-learn__scikit-learn-10198", + "scikit-learn__scikit-learn-10459", +] + + +def read_sub_set_instance(path=SUBSET_DATASET, tag="scikit-learn"): + try: + df = pd.read_excel(path) + # Filter for instances containing the tag in either column + pass_filter = df["instance_id_pass"].str.contains(tag, na=False) + fail_filter = df["instance_id_fail"].str.contains(tag, na=False) + + # Combine the filters using | (OR operator) for efficiency + combined_filter = pass_filter | fail_filter + + # Apply combined filter and select the specific columns + filtered_df = df[combined_filter][["instance_id_pass", "instance_id_fail"]] + + # Flatten the DataFrame into a list and remove NaN values + subset_instance = filtered_df.stack().dropna().tolist() + + return subset_instance + except FileNotFoundError: + print(f"File not found: {path}") + return [] + except Exception as e: + print(f"An error occurred: {e}") + return [] diff --git a/data/inference/run.py b/data/inference/run.py new file mode 100644 index 000000000..a3f3c54aa --- /dev/null +++ b/data/inference/run.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import runpy +import sys + +original_argv = sys.argv.copy() + +try: + # 设置你想要传递给脚本的命令行参数 + sys.argv = ["run_api.py", "--dataset_name_or_path", "princeton-nlp/SWE-bench_oracle", "--output_dir", "./outputs"] + # 执行脚本 + runpy.run_path(path_name="run_api.py", run_name="__main__") +finally: + # 恢复原始的sys.argv以避免对后续代码的潜在影响 + sys.argv = original_argv diff --git a/data/inference/run_api.py b/data/inference/run_api.py index 66f229f85..7882f13e7 100644 --- a/data/inference/run_api.py +++ b/data/inference/run_api.py @@ -10,6 +10,7 @@ from make_datasets.utils import extract_diff from tenacity import retry, stop_after_attempt, wait_random_exponential from tqdm.auto import tqdm +from data.inference.const import SCIKIT_LEARN_IDS from metagpt.config2 import config from metagpt.logs import logger from metagpt.roles.di.data_interpreter import DataInterpreter @@ -70,6 +71,7 @@ async def openai_inference( for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"): di = DataInterpreter(use_reflection=use_reflection) instance_id = datum["instance_id"] + if instance_id in existing_ids: continue output_dict = {"instance_id": instance_id} @@ -124,12 +126,19 @@ async def main( dataset = dataset[split] lens = np.array(list(map(len, dataset["text"]))) dataset = dataset.select(np.argsort(lens)) + if len(existing_ids) > 0: dataset = dataset.filter( lambda x: x["instance_id"] not in existing_ids, desc="Filtering out existing ids", load_from_cache_file=False, ) + if len(SCIKIT_LEARN_IDS) > 0: + dataset = dataset.filter( + lambda x: x["instance_id"] in SCIKIT_LEARN_IDS, + desc="Filtering out subset_instance_ids", + load_from_cache_file=False, + ) inference_args = { "test_dataset": dataset, "model_name_or_path": model_name_or_path,