Remove the check for the correct shape of predictions, which has become too complicated with the newly supported output types.

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2024-05-02 18:56:31 +03:00
parent a4816878f9
commit 846de0f753
6 changed files with 27 additions and 25 deletions

View file

@ -93,7 +93,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
if is_regression:
self.estimator = SklearnRegressor(estimator)
else:
# TODO: maybe we should get model output type from user in this case
# The model output type is not critical here, as it only affects the computation of nb_classes, which is
# currently the same for both single-output and multi-output probabilities.
self.estimator = SklearnClassifier(estimator,
ModelOutputType.CLASSIFIER_SINGLE_OUTPUT_CLASS_PROBABILITIES)
self.target_accuracy = target_accuracy

View file

@ -4,7 +4,7 @@ import numpy as np
from sklearn.metrics import mean_squared_error
from apt.utils.models import Model, ModelOutputType, ScoringMethod, check_correct_model_output, is_logits
from apt.utils.models import Model, ModelOutputType, ScoringMethod, is_logits
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
from art.utils import check_and_transform_label_format
@ -63,7 +63,7 @@ class KerasClassifier(KerasModel):
:return: Predictions from the model as numpy array (class probabilities, if supported).
"""
predictions = self._art_model.predict(x.get_samples(), **kwargs)
check_correct_model_output(predictions, self.output_type)
# check_correct_model_output(predictions, self.output_type)
return predictions
def score(self, test_data: Dataset, scoring_method: Optional[ScoringMethod] = ScoringMethod.ACCURACY, **kwargs):

View file

@ -85,6 +85,8 @@ def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE, output_type: ModelOutputType) -> i
:param y: The labels
:type y: numpy array
:param output_type: The output type of the model, as provided by the user
:type output_type: ModelOutputType
:return: The number of classes as integer, or list of integers for multi-label
"""
if y is None:
@ -96,9 +98,8 @@ def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE, output_type: ModelOutputType) -> i
if is_one_hot(y):
return y.shape[1]
elif is_multi_label(output_type):
# for now just return the number of labels
# for now just return the prediction dimension - this works in most cases
return y.shape[1]
# return [int(np.max(y.T[i]) + 1) for i in range(y.shape[1])]
elif is_categorical(output_type):
return int(np.max(y) + 1)
else: # binary
@ -391,7 +392,7 @@ class BlackboxClassifier(Model):
:return: Predictions from the model as numpy array.
"""
predictions = self._art_model.predict(x.get_samples())
check_correct_model_output(predictions, self.output_type)
# check_correct_model_output(predictions, self.output_type)
return predictions
@abstractmethod
@ -434,17 +435,17 @@ class BlackboxClassifierPredictions(BlackboxClassifier):
if y_test_pred is None:
y_test_pred = model.get_test_labels()
if y_train_pred is not None:
check_correct_model_output(y_train_pred, self.output_type)
if y_test_pred is not None:
check_correct_model_output(y_test_pred, self.output_type)
# if y_train_pred is not None:
# check_correct_model_output(y_train_pred, self.output_type)
# if y_test_pred is not None:
# check_correct_model_output(y_test_pred, self.output_type)
if y_train_pred is not None and len(y_train_pred.shape) == 1:
self._nb_classes = get_nb_classes(y_train_pred, self.output_type)
# self._nb_classes = get_nb_classes(y_train_pred, self.output_type)
y_train_pred = check_and_transform_label_format(y_train_pred, nb_classes=self._nb_classes)
if y_test_pred is not None and len(y_test_pred.shape) == 1:
if self._nb_classes is None:
self._nb_classes = get_nb_classes(y_test_pred, self.output_type)
# if self._nb_classes is None:
# self._nb_classes = get_nb_classes(y_test_pred, self.output_type)
y_test_pred = check_and_transform_label_format(y_test_pred, nb_classes=self._nb_classes)
if x_train_pred is not None and y_train_pred is not None and x_test_pred is not None and y_test_pred is not None:

View file

@ -2,7 +2,7 @@ from typing import Optional
from sklearn.base import BaseEstimator
from apt.utils.models import Model, ModelOutputType, get_nb_classes, check_correct_model_output
from apt.utils.models import Model, ModelOutputType, get_nb_classes
from apt.utils.datasets import Dataset, ArrayDataset, OUTPUT_DATA_ARRAY_TYPE
from art.estimators.classification.scikitlearn import SklearnClassifier as ArtSklearnClassifier
@ -71,7 +71,7 @@ class SklearnClassifier(SklearnModel):
:return: Predictions from the model as numpy array (class probabilities, if supported).
"""
predictions = self._art_model.predict(x.get_samples(), **kwargs)
check_correct_model_output(predictions, self.output_type)
# check_correct_model_output(predictions, self.output_type)
return predictions

View file

@ -1,6 +1,6 @@
from typing import Optional, Tuple
from apt.utils.models import Model, ModelOutputType, ScoringMethod, check_correct_model_output, is_one_hot
from apt.utils.models import Model, ModelOutputType, ScoringMethod, is_one_hot
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
import numpy as np
@ -63,7 +63,7 @@ class XGBoostClassifier(XGBoostModel):
:return: Predictions from the model as numpy array (class probabilities, if supported).
"""
predictions = self._art_model.predict(x.get_samples(), **kwargs)
check_correct_model_output(predictions, self.output_type)
# check_correct_model_output(predictions, self.output_type)
return predictions
def score(self, test_data: Dataset, scoring_method: Optional[ScoringMethod] = ScoringMethod.ACCURACY, **kwargs):

View file

@ -227,14 +227,14 @@ def test_blackbox_classifier_predictions_multi_label_binary():
assert model.model_type is None
def test_blackbox_classifier_mismatch():
(x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()
train = ArrayDataset(x_train, y_train)
test = ArrayDataset(x_test, y_test)
data = Data(train, test)
with pytest.raises(ValueError):
BlackboxClassifierPredictions(data, ModelOutputType.CLASSIFIER_SINGLE_OUTPUT_CLASS_PROBABILITIES)
# def test_blackbox_classifier_mismatch():
# (x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()
#
# train = ArrayDataset(x_train, y_train)
# test = ArrayDataset(x_test, y_test)
# data = Data(train, test)
# with pytest.raises(ValueError):
# BlackboxClassifierPredictions(data, ModelOutputType.CLASSIFIER_SINGLE_OUTPUT_CLASS_PROBABILITIES)
def test_blackbox_classifier_no_test():