mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-27 05:46:22 +02:00
Remove redundant code.
Use data wrappers in model wrapper APIs. More typing.
This commit is contained in:
parent
9f4d649934
commit
3d82db80c4
5 changed files with 57 additions and 166 deletions
|
|
@ -5,17 +5,21 @@ Implementation of utility classes for dataset handling
|
|||
"""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Callable, Collection, Any
|
||||
from typing import Callable, Collection, Any, Union
|
||||
|
||||
import tarfile
|
||||
import os
|
||||
import urllib.request
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame]
|
||||
|
||||
|
||||
class DatasetABC(metaclass=ABCMeta):
|
||||
"""Base Abstract Class for Dataset"""
|
||||
|
||||
|
|
@ -122,7 +126,7 @@ class StoredDatasetABC(DatasetABC):
|
|||
class BaseDataset(DatasetABC):
|
||||
"""Base Class for Dataset"""
|
||||
|
||||
def __init__(self, x, y, **kwargs):
|
||||
def __init__(self, x: DATA_ARRAY_TYPE, y: DATA_ARRAY_TYPE, **kwargs):
|
||||
"""
|
||||
BaseDataset constructor.
|
||||
:param x: collection of data samples
|
||||
|
|
@ -135,11 +139,11 @@ class BaseDataset(DatasetABC):
|
|||
if len(self.x) != len(self.y):
|
||||
raise ValueError('Non equivalent lengths of x and y')
|
||||
|
||||
def get_samples(self) -> Collection[Any]:
|
||||
def get_samples(self) -> DATA_ARRAY_TYPE:
|
||||
"""Return data samples"""
|
||||
return self.x
|
||||
|
||||
def get_labels(self) -> Collection[Any]:
|
||||
def get_labels(self) -> DATA_ARRAY_TYPE:
|
||||
"""Return labels"""
|
||||
return self.y
|
||||
|
||||
|
|
@ -192,7 +196,7 @@ class Data:
|
|||
The class stores train and test datasets.
|
||||
If neither of the datasets was provided,
|
||||
Both train and test datasets will be created using
|
||||
Factory command to create dataset instance
|
||||
DatasetFactory to create a dataset instance
|
||||
"""
|
||||
if train or test:
|
||||
self.train = train
|
||||
|
|
@ -209,18 +213,18 @@ class Data:
|
|||
"""Return test DatasetBase"""
|
||||
return self.test
|
||||
|
||||
def get_train_samples(self):
|
||||
def get_train_samples(self) -> Collection[Any]:
|
||||
"""Return train set samples"""
|
||||
return self.train.get_samples()
|
||||
|
||||
def get_train_labels(self):
|
||||
def get_train_labels(self) -> Collection[Any]:
|
||||
"""Return train set labels"""
|
||||
return self.train.get_labels()
|
||||
|
||||
def get_test_samples(self):
|
||||
def get_test_samples(self) -> Collection[Any]:
|
||||
"""Return test set samples"""
|
||||
return self.test.get_samples()
|
||||
|
||||
def get_test_labels(self):
|
||||
def get_test_labels(self) -> Collection[Any]:
|
||||
"""Return test set labels"""
|
||||
return self.test.get_labels()
|
||||
return self.test.get_labels()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue