Check for mismatch between model output type and actual output

This commit is contained in:
abigailt 2022-07-19 08:43:19 +03:00 committed by abigailgold
parent bc7ab0cc7f
commit 1cc73b3da1
5 changed files with 75 additions and 39 deletions

View file

@ -8,30 +8,6 @@ from art.estimators.classification import BlackBoxClassifier
from art.utils import check_and_transform_label_format
def is_one_hot(y: OUTPUT_DATA_ARRAY_TYPE) -> bool:
    """
    Tell whether an array of labels is laid out as a multi-column 2D array (treated as one-hot encoded).

    The decision is based on the shape only; the values are not inspected.

    :param y: The labels (any object exposing a ``shape`` attribute)
    :return: True when `y` has exactly two dimensions and more than one column, False otherwise
    """
    dims = y.shape
    # Anything that is not strictly 2D cannot be a one-hot matrix
    if len(dims) != 2:
        return False
    return dims[1] > 1
def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE) -> int:
    """
    Get the number of classes from an array of labels.

    For labels that `is_one_hot` recognizes (2D with more than one column) the number of columns is used.
    Otherwise the labels are treated as class indices and the largest value plus one is used.

    :param y: the labels
    :type y: numpy array
    :return: the number of classes as integer (0 when `y` is None)
    :raises: ValueError (in case `y` is not a numpy array)
    """
    if y is None:
        return 0

    # isinstance (rather than an exact type comparison) also admits ndarray subclasses
    if not isinstance(y, np.ndarray):
        raise ValueError("Input should be numpy array")

    if is_one_hot(y):
        return y.shape[1]
    return int(np.max(y) + 1)
class ModelOutputType(Enum):
CLASSIFIER_PROBABILITIES = auto() # vector of probabilities
CLASSIFIER_LOGITS = auto() # vector of logits
@ -49,6 +25,45 @@ class ScoringMethod(Enum):
MEAN_SQUARED_ERROR = auto() # mean squared error between the predictions and true labels
def is_one_hot(y: OUTPUT_DATA_ARRAY_TYPE) -> bool:
    """
    Tell whether an array of labels is laid out as a multi-column 2D array (treated as one-hot encoded).

    The decision is based on the shape only; the values are not inspected.

    :param y: The labels (any object exposing a ``shape`` attribute)
    :return: True when `y` has exactly two dimensions and more than one column, False otherwise
    """
    dims = y.shape
    # Anything that is not strictly 2D cannot be a one-hot matrix
    if len(dims) != 2:
        return False
    return dims[1] > 1
def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE) -> int:
    """
    Get the number of classes from an array of labels.

    For labels that `is_one_hot` recognizes (2D with more than one column) the number of columns is used.
    Otherwise the labels are treated as class indices and the largest value plus one is used.

    :param y: The labels
    :type y: numpy array
    :return: The number of classes as integer (0 when `y` is None)
    :raises: ValueError (in case `y` is not a numpy array)
    """
    if y is None:
        return 0

    # isinstance (rather than an exact type comparison) also admits ndarray subclasses
    if not isinstance(y, np.ndarray):
        raise ValueError("Input should be numpy array")

    if is_one_hot(y):
        return y.shape[1]
    return int(np.max(y) + 1)
def check_correct_model_output(y: OUTPUT_DATA_ARRAY_TYPE, output_type: ModelOutputType):
    """
    Checks whether there is a mismatch between the declared model output type and its actual output.

    A mismatch is reported when `y` is not recognized by `is_one_hot` (i.e., it is not a 2D array with more than
    one column) while the declared output type is a vector of probabilities or logits.

    :param y: Model output
    :type y: numpy array
    :param output_type: Declared output type (provided at init)
    :type output_type: ModelOutputType
    :raises: ValueError (in case of mismatch)
    """
    if is_one_hot(y):
        # Multi-column output: no check is performed on it
        return

    # Output types that declare a vector of scores per sample
    vector_output_types = (ModelOutputType.CLASSIFIER_PROBABILITIES, ModelOutputType.CLASSIFIER_LOGITS)
    if output_type in vector_output_types:
        # Format the declared type into the message itself; passing it as a second ValueError argument
        # would make the exception render as a tuple rather than a readable sentence.
        raise ValueError(
            "Incompatible model output types. Model outputs 1D array of categorical scalars while "
            f"output type is set to {output_type}"
        )
class Model(metaclass=ABCMeta):
"""
Abstract base class for ML model wrappers.
@ -147,8 +162,6 @@ class Model(metaclass=ABCMeta):
return self._unlimited_queries
class BlackboxClassifier(Model):
"""
Wrapper for black-box ML classification models.
@ -168,7 +181,6 @@ class BlackboxClassifier(Model):
:type model_type: Either a (unfitted) model object of the underlying framework, or a ModelType representing the
type of the model, optional.
"""
def __init__(self, model: Any, output_type: ModelOutputType, black_box_access: Optional[bool] = True,
unlimited_queries: Optional[bool] = True, model_type: Optional[Union[Any, ModelType]] = None,
**kwargs):
@ -220,7 +232,9 @@ class BlackboxClassifier(Model):
:type x: `Dataset`
:return: Predictions from the model as numpy array.
"""
return self._art_model.predict(x.get_samples())
predictions = self._art_model.predict(x.get_samples())
check_correct_model_output(predictions, self.output_type)
return predictions
def score(self, test_data: Dataset, scoring_method: Optional[ScoringMethod] = ScoringMethod.ACCURACY, **kwargs):
"""
@ -266,6 +280,11 @@ class BlackboxClassifierPredictions(BlackboxClassifier):
x_test_pred = model.get_test_samples()
y_test_pred = model.get_test_labels()
if y_train_pred is not None:
check_correct_model_output(y_train_pred, self.output_type)
if y_test_pred is not None:
check_correct_model_output(y_test_pred, self.output_type)
if y_train_pred is not None and len(y_train_pred.shape) == 1:
self._nb_classes = get_nb_classes(y_train_pred)
y_train_pred = check_and_transform_label_format(y_train_pred, nb_classes=self._nb_classes)