Remove redundant code.

Use data wrappers in model wrapper APIs.
More typing.
This commit is contained in:
abigailt 2022-03-06 21:15:07 +02:00
parent 9f4d649934
commit 3d82db80c4
5 changed files with 57 additions and 166 deletions

View file

@ -5,17 +5,21 @@ Implementation of utility classes for dataset handling
"""
from abc import ABCMeta, abstractmethod
from typing import Callable, Collection, Any
from typing import Callable, Collection, Any, Union
import tarfile
import os
import urllib.request
import numpy as np
import pandas as pd
import logging
logger = logging.getLogger(__name__)
DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame]
class DatasetABC(metaclass=ABCMeta):
"""Base Abstract Class for Dataset"""
@ -122,7 +126,7 @@ class StoredDatasetABC(DatasetABC):
class BaseDataset(DatasetABC):
"""Base Class for Dataset"""
def __init__(self, x, y, **kwargs):
def __init__(self, x: DATA_ARRAY_TYPE, y: DATA_ARRAY_TYPE, **kwargs):
"""
BaseDataset constructor.
:param x: collection of data samples
@ -135,11 +139,11 @@ class BaseDataset(DatasetABC):
if len(self.x) != len(self.y):
raise ValueError('Non equivalent lengths of x and y')
def get_samples(self) -> Collection[Any]:
def get_samples(self) -> DATA_ARRAY_TYPE:
"""Return data samples"""
return self.x
def get_labels(self) -> Collection[Any]:
def get_labels(self) -> DATA_ARRAY_TYPE:
"""Return labels"""
return self.y
@ -192,7 +196,7 @@ class Data:
The class stores train and test datasets.
If neither of the datasets was provided,
Both train and test datasets will be create using
Factory command to create dataset instance
DatasetFactory to create a dataset instance
"""
if train or test:
self.train = train
@ -209,18 +213,18 @@ class Data:
"""Return test DatasetBase"""
return self.test
def get_train_samples(self):
def get_train_samples(self) -> Collection[Any]:
"""Return train set samples"""
return self.train.get_samples()
def get_train_labels(self):
def get_train_labels(self) -> Collection[Any]:
"""Return train set labels"""
return self.train.get_labels()
def get_test_samples(self):
def get_test_samples(self) -> Collection[Any]:
"""Return test set samples"""
return self.test.get_samples()
def get_test_labels(self):
def get_test_labels(self) -> Collection[Any]:
"""Return test set labels"""
return self.test.get_labels()
return self.test.get_labels()