mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-28 10:26:32 +02:00
add load dataset
This commit is contained in:
parent
e4d02ca68c
commit
f26a5cd1de
1 changed file with 35 additions and 0 deletions
35
data/load_dataset.py
Normal file
35
data/load_dataset.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from data.inference.const import SCIKIT_LEARN_IDS
|
||||
|
||||
|
||||
def load_oracle_dataset(dataset_name_or_path: str = "", split: str = "test", existing_ids: list = None):
    """Load an oracle dataset from disk or the HuggingFace hub, sorted by text length.

    Args:
        dataset_name_or_path: Path to a dataset saved on disk (loaded via
            ``load_from_disk``) or a hub dataset name (loaded via ``load_dataset``).
        split: Split to select when loading from the hub. For an on-disk dataset
            the saved data is used as-is (assumed to already be a single split).
        existing_ids: Instance ids to exclude, e.g. items already processed.
            Defaults to no exclusions.

    Returns:
        The dataset split, sorted ascending by length of the ``text`` field,
        minus ``existing_ids`` and restricted to ``SCIKIT_LEARN_IDS`` (when that
        constant is non-empty).

    Raises:
        ValueError: If ``split`` is not present in the hub dataset.
    """
    # Was a mutable default (``[]``); a None sentinel avoids the shared
    # mutable-default pitfall while keeping the call signature compatible.
    existing_ids = existing_ids if existing_ids is not None else []

    if Path(dataset_name_or_path).exists():
        dataset = load_from_disk(dataset_name_or_path)
    else:
        dataset = load_dataset(dataset_name_or_path)
        if split not in dataset:
            raise ValueError(f"Invalid split {split} for dataset {dataset_name_or_path}")
        dataset = dataset[split]

    # Order examples shortest-text-first so cheap instances run before long ones.
    lens = np.array(list(map(len, dataset["text"])))
    dataset = dataset.select(np.argsort(lens))

    if existing_ids:
        dataset = dataset.filter(
            lambda x: x["instance_id"] not in existing_ids,
            desc="Filtering out existing ids",
            load_from_cache_file=False,
        )
    if len(SCIKIT_LEARN_IDS) > 0:
        dataset = dataset.filter(
            lambda x: x["instance_id"] in SCIKIT_LEARN_IDS,
            desc="Filtering out subset_instance_ids",
            load_from_cache_file=False,
        )
    return dataset
Loading…
Add table
Add a link
Reference in a new issue