Add more to wrappers

This commit is contained in:
abigailt 2022-03-15 11:42:57 +02:00 committed by olasaadi
parent a2b560920f
commit a9162fbc43
6 changed files with 86 additions and 38 deletions

View file

@ -1,2 +1,2 @@
from apt.utils.models.model import Model
from apt.utils.models.model import Model, ModelOutputType
from apt.utils.models.sklearn_model import SklearnModel, SklearnClassifier, SklearnRegressor

View file

@ -1,7 +1,14 @@
from abc import ABCMeta, abstractmethod
from typing import Any
from enum import Enum, auto
from apt.utils.datasets import Dataset, DATA_ARRAY_TYPE
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
class ModelOutputType(Enum):
CLASSIFIER_VECTOR = auto() # probabilities or logits
CLASSIFIER_SCALAR = auto() # label only
REGRESSOR_SCALAR = auto() # value
class Model(metaclass=ABCMeta):
@ -9,13 +16,16 @@ class Model(metaclass=ABCMeta):
Abstract base class for ML model wrappers.
"""
def __init__(self, model: Any, **kwargs):
def __init__(self, model: Any, output_type: ModelOutputType, **kwargs):
"""
Initialize a `Model` wrapper object.
:param model: The original model object (of the underlying ML framework)
:param output_type: The type of output the model yields (vector/label only for classifiers,
value for regressors)
"""
self._model = model
self._output_type = output_type
@abstractmethod
def fit(self, train_data: Dataset, **kwargs) -> None:
@ -28,7 +38,7 @@ class Model(metaclass=ABCMeta):
raise NotImplementedError
@abstractmethod
def predict(self, x: DATA_ARRAY_TYPE, **kwargs) -> DATA_ARRAY_TYPE:
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
"""
Perform predictions using the model for input `x`.
@ -39,10 +49,19 @@ class Model(metaclass=ABCMeta):
raise NotImplementedError
@property
def model(self):
def model(self) -> Any:
"""
Return the model.
:return: The model.
"""
return self._model
@property
def output_type(self) -> ModelOutputType:
"""
Return the model's output type.
:return: The model's output type.
"""
return self._output_type

View file

@ -3,8 +3,8 @@ import numpy as np
from sklearn.preprocessing import OneHotEncoder
from sklearn.base import BaseEstimator
from apt.utils.models import Model
from apt.utils.datasets import Dataset, DATA_ARRAY_TYPE
from apt.utils.models import Model, ModelOutputType
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
from art.estimators.classification.scikitlearn import SklearnClassifier as ArtSklearnClassifier
from art.estimators.regression.scikitlearn import ScikitlearnRegressor
@ -28,13 +28,13 @@ class SklearnClassifier(SklearnModel):
"""
Wrapper class for scikitlearn classification models.
"""
def __init__(self, model: BaseEstimator, **kwargs):
def __init__(self, model: BaseEstimator, output_type: ModelOutputType, **kwargs):
"""
Initialize a `SklearnClassifier` wrapper object.
:param model: The original sklearn model object
"""
super().__init__(model, **kwargs)
super().__init__(model, output_type, **kwargs)
self._art_model = ArtSklearnClassifier(model)
def fit(self, train_data: Dataset, **kwargs) -> None:
@ -48,7 +48,7 @@ class SklearnClassifier(SklearnModel):
y_encoded = encoder.fit_transform(train_data.get_labels().reshape(-1, 1))
self._art_model.fit(train_data.get_samples(), y_encoded, **kwargs)
def predict(self, x: DATA_ARRAY_TYPE, **kwargs) -> DATA_ARRAY_TYPE:
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
"""
Perform predictions using the model for input `x`.
@ -69,7 +69,7 @@ class SklearnRegressor(SklearnModel):
:param model: The original sklearn model object
"""
super().__init__(model, **kwargs)
super().__init__(model, ModelOutputType.REGRESSOR_SCALAR, **kwargs)
self._art_model = ScikitlearnRegressor(model)
def fit(self, train_data: Dataset, **kwargs) -> None:
@ -81,7 +81,7 @@ class SklearnRegressor(SklearnModel):
"""
self._art_model.fit(train_data.get_samples(), train_data.get_labels(), **kwargs)
def predict(self, x: DATA_ARRAY_TYPE, **kwargs) -> DATA_ARRAY_TYPE:
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
"""
Perform predictions using the model for input `x`.