diff --git a/data/inference/run_api.py b/data/inference/run_api.py index 66f229f85..9202d6a42 100644 --- a/data/inference/run_api.py +++ b/data/inference/run_api.py @@ -15,6 +15,7 @@ from metagpt.logs import logger from metagpt.roles.di.data_interpreter import DataInterpreter from metagpt.utils import count_string_tokens from metagpt.utils.recovery_util import save_history +from data.inference.const import SCIKIT_LEARN_IDS # Replace with your own MAX_TOKEN = 128000 @@ -70,6 +71,7 @@ async def openai_inference( for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"): di = DataInterpreter(use_reflection=use_reflection) instance_id = datum["instance_id"] + if instance_id in existing_ids: continue output_dict = {"instance_id": instance_id} @@ -124,12 +126,19 @@ async def main( dataset = dataset[split] lens = np.array(list(map(len, dataset["text"]))) dataset = dataset.select(np.argsort(lens)) + if len(existing_ids) > 0: dataset = dataset.filter( lambda x: x["instance_id"] not in existing_ids, desc="Filtering out existing ids", load_from_cache_file=False, ) + if len(SCIKIT_LEARN_IDS) > 0: + dataset = dataset.filter( + lambda x: x["instance_id"] in SCIKIT_LEARN_IDS, + desc="Filtering out subset_instance_ids", + load_from_cache_file=False, + ) inference_args = { "test_dataset": dataset, "model_name_or_path": model_name_or_path,