diff --git a/apt/utils.py b/apt/utils.py index b7aa78a..086492f 100644 --- a/apt/utils.py +++ b/apt/utils.py @@ -2,7 +2,7 @@ from sklearn import datasets, model_selection import sklearn.preprocessing import pandas as pd import ssl -from os import path +from os import path, mkdir from six.moves.urllib.request import urlretrieve @@ -40,9 +40,13 @@ def get_adult_dataset(): 'label'] train_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data' test_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test' + data_dir = '../datasets/adult' train_file = '../datasets/adult/train' test_file = '../datasets/adult/test' + if not path.exists(data_dir): + mkdir(data_dir) + ssl._create_default_https_context = ssl._create_unverified_context if not path.exists(train_file): urlretrieve(train_url, train_file) @@ -139,8 +143,12 @@ def get_nursery_dataset(raw: bool = True, test_set: float = 0.2, transform_socia :return: Dataset and labels as pandas dataframes. """ url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/nursery/nursery.data' + data_dir = '../datasets/nursery' data_file = '../datasets/nursery/data' + if not path.exists(data_dir): + mkdir(data_dir) + ssl._create_default_https_context = ssl._create_unverified_context if not path.exists(data_file): urlretrieve(url, data_file) diff --git a/datasets/.gitignore b/datasets/.gitignore new file mode 100644 index 0000000..86d0cb2 --- /dev/null +++ b/datasets/.gitignore @@ -0,0 +1,4 @@ +# Ignore everything in this directory +* +# Except this file +!.gitignore \ No newline at end of file