add load dataset

This commit is contained in:
stellahsr 2024-03-22 16:27:00 +08:00
parent e4d02ca68c
commit f26a5cd1de

35
data/load_dataset.py Normal file
View file

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc : Load one split of a dataset, sorted by text length and filtered by instance id.
from pathlib import Path
import numpy as np
from datasets import load_dataset, load_from_disk
from data.inference.const import SCIKIT_LEARN_IDS
def load_oracle_dataset(
    dataset_name_or_path: str = "",
    split: str = "test",
    existing_ids: "list | None" = None,
):
    """Load one split of a dataset, sort it by text length and filter by instance id.

    The dataset is read with ``load_from_disk`` when ``dataset_name_or_path``
    is an existing filesystem path, and with ``load_dataset`` otherwise. The
    selected split is reordered by ascending ``len`` of its ``"text"`` column,
    then filtered in two steps:

    1. rows whose ``"instance_id"`` is in ``existing_ids`` are dropped;
    2. if ``SCIKIT_LEARN_IDS`` is non-empty, only rows whose ``"instance_id"``
       is in it are kept.

    Args:
        dataset_name_or_path: Local path of a saved dataset, or a name to hand
            to ``load_dataset``.
        split: Name of the split to return.
        existing_ids: Instance ids to exclude. ``None`` or an empty collection
            disables this filter.

    Returns:
        The selected split after sorting and filtering.

    Raises:
        ValueError: If ``split`` is not present in the loaded dataset.
    """
    # NOTE: Path("") is the current directory, which exists, so the default
    # empty name takes the load_from_disk branch.
    if Path(dataset_name_or_path).exists():
        dataset = load_from_disk(dataset_name_or_path)
    else:
        dataset = load_dataset(dataset_name_or_path)

    if split not in dataset:
        raise ValueError(f"Invalid split {split} for dataset {dataset_name_or_path}")
    dataset = dataset[split]

    # Reorder rows from shortest to longest "text" entry.
    lens = np.array(list(map(len, dataset["text"])))
    dataset = dataset.select(np.argsort(lens))

    # Build the lookup sets once so each per-row membership test is O(1)
    # instead of a linear scan (assumes instance ids are hashable, e.g. str).
    excluded_ids = set() if existing_ids is None else set(existing_ids)
    if excluded_ids:
        dataset = dataset.filter(
            lambda x: x["instance_id"] not in excluded_ids,
            desc="Filtering out existing ids",
            load_from_cache_file=False,
        )

    allowed_ids = set(SCIKIT_LEARN_IDS)
    if allowed_ids:
        # Despite the progress-bar text, this keeps only the listed ids.
        dataset = dataset.filter(
            lambda x: x["instance_id"] in allowed_ids,
            desc="Filtering out subset_instance_ids",
            load_from_cache_file=False,
        )
    return dataset