Remove redundant code.

Use data wrappers in model wrapper APIs.
Add more type annotations.
This commit is contained in:
abigailt 2022-03-06 21:15:07 +02:00
parent 9f4d649934
commit 3d82db80c4
5 changed files with 57 additions and 166 deletions

View file

@ -1,7 +1,10 @@
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from apt.utils.models import Model, ModelWithLoss, SingleOutputModel
from sklearn.preprocessing import OneHotEncoder
from sklearn.base import BaseEstimator
from apt.utils.models import Model
from apt.utils.datasets import BaseDataset, DATA_ARRAY_TYPE
from art.estimators.classification.scikitlearn import SklearnClassifier as ArtSklearnClassifier
from art.estimators.regression.scikitlearn import ScikitlearnRegressor
@ -11,23 +14,21 @@ class SklearnModel(Model):
"""
Wrapper class for scikitlearn models.
"""
def score(self, x: np.ndarray, y: np.ndarray, **kwargs):
def score(self, test_data: BaseDataset, **kwargs):
"""
Score the model using test data `(x, y)`.
Score the model using test data.
:param x: Test data.
:type x: `np.ndarray` or `pandas.DataFrame`
:param y: True labels.
:type y: `np.ndarray` or `pandas.DataFrame`
:param test_data: Test data.
:type test_data: `BaseDataset`
"""
return self.model.score(x, y, **kwargs)
return self.model.score(test_data.get_samples(), test_data.get_labels(), **kwargs)
class SklearnClassifier(SklearnModel):
"""
Wrapper class for scikitlearn classification models.
"""
def __init__(self, model, **kwargs):
def __init__(self, model: BaseEstimator, **kwargs):
"""
Initialize a `SklearnClassifier` wrapper object.
@ -36,35 +37,33 @@ class SklearnClassifier(SklearnModel):
super().__init__(model, **kwargs)
self._art_model = ArtSklearnClassifier(model)
def fit(self, x: np.ndarray, y: np.ndarray, **kwargs) -> None:
def fit(self, train_data: BaseDataset, **kwargs) -> None:
"""
Fit the model using the training data `(x, y)`.
Fit the model using the training data.
:param x: Training data.
:type x: `np.ndarray` or `pandas.DataFrame`
:param y: True labels.
:type y: `np.ndarray` or `pandas.DataFrame`
:param train_data: Training data.
:type train_data: `BaseDataset`
"""
encoder = OneHotEncoder(sparse=False)
y_encoded = encoder.fit_transform(y.reshape(-1, 1))
self._art_model.fit(x, y_encoded, **kwargs)
y_encoded = encoder.fit_transform(train_data.get_labels().reshape(-1, 1))
self._art_model.fit(train_data.get_samples(), y_encoded, **kwargs)
def predict(self, x: np.ndarray, **kwargs) -> np.ndarray:
def predict(self, x: DATA_ARRAY_TYPE, **kwargs) -> DATA_ARRAY_TYPE:
"""
Perform predictions using the model for input `x`.
:param x: Input samples.
:type x: `np.ndarray` or `pandas.DataFrame`
:return: Predictions from the model.
:return: Predictions from the model (class probabilities, if supported).
"""
return self._art_model.predict(x, **kwargs)
class SklearnRegressor(SklearnModel, SingleOutputModel, ModelWithLoss):
class SklearnRegressor(SklearnModel):
"""
Wrapper class for scikitlearn regression models.
"""
def __init__(self, model, **kwargs):
def __init__(self, model: BaseEstimator, **kwargs):
"""
Initialize a `SklearnRegressor` wrapper object.
@ -73,18 +72,16 @@ class SklearnRegressor(SklearnModel, SingleOutputModel, ModelWithLoss):
super().__init__(model, **kwargs)
self._art_model = ScikitlearnRegressor(model)
def fit(self, x: np.ndarray, y: np.ndarray, **kwargs) -> None:
def fit(self, train_data: BaseDataset, **kwargs) -> None:
"""
Fit the model using the training data `(x, y)`.
Fit the model using the training data.
:param x: Training data.
:type x: `np.ndarray` or `pandas.DataFrame`
:param y: True labels.
:type y: `np.ndarray` or `pandas.DataFrame`
:param train_data: Training data.
:type train_data: `BaseDataset`
"""
self._art_model.fit(x, y, **kwargs)
self._art_model.fit(train_data.get_samples(), train_data.get_labels(), **kwargs)
def predict(self, x: np.ndarray, **kwargs) -> np.ndarray:
def predict(self, x: DATA_ARRAY_TYPE, **kwargs) -> DATA_ARRAY_TYPE:
"""
Perform predictions using the model for input `x`.
@ -93,50 +90,3 @@ class SklearnRegressor(SklearnModel, SingleOutputModel, ModelWithLoss):
:return: Predictions from the model.
"""
return self._art_model.predict(x, **kwargs)
def loss(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
"""
Compute the loss of the model for samples `x`.
:param x: Input samples.
:type x: `np.ndarray` or `pandas.DataFrame`
:param y: True labels.
:type y: `np.ndarray` or `pandas.DataFrame`
:return: Loss values.
"""
return self._art_model.compute_loss(x, y, **kwargs)
# Probably not needed for now, as we will not be using these wrappers directly in ART.
# class SklearnDecisionTreeClassifier(SklearnClassifier, MultipleOutputModel):
# """
# Wrapper class for scikitlearn decision tree classifier models.
# """
# def __init__(self, model):
# """
# Initialize a `DecisionTreeClassifier` wrapper object.
#
# :param model: The original sklearn decision tree model object
# """
# super().__init__(model)
# self._art_model = ScikitlearnDecisionTreeClassifier(model)
#
# def get_decision_path(self, x: np.ndarray) -> np.ndarray:
# """
# Returns the nodes along the path taken in the tree when classifying x. Last node is the leaf, first node is the
# root node.
#
# :param x: Input samples.
# :type x: `np.ndarray` or `pandas.DataFrame`
# :return: The indices of the nodes in the array structure of the tree.
# """
# return self._art_model.get_decision_path(x)
#
# def get_samples_at_node(self, node_id: int) -> int:
# """
# Returns the number of training samples mapped to a node.
#
# :param node_id: The ID of the node.
# :return: Number of samples mapped this node.
# """
# return self._art_model.get_samples_at_node(node_id)