Check for mismatch between model output type and actual output

This commit is contained in:
abigailt 2022-07-19 08:43:19 +03:00 committed by abigailgold
parent bc7ab0cc7f
commit 1cc73b3da1
5 changed files with 75 additions and 39 deletions

View file

@ -7,7 +7,7 @@ import tensorflow as tf
from tensorflow import keras
tf.compat.v1.disable_eager_execution()
from apt.utils.models import Model, ModelOutputType, ScoringMethod
from apt.utils.models import Model, ModelOutputType, ScoringMethod, check_correct_model_output
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
from art.utils import check_and_transform_label_format
@ -68,7 +68,9 @@ class KerasClassifier(KerasModel):
:type x: `Dataset`
:return: Predictions from the model as numpy array (class probabilities, if supported).
"""
return self._art_model.predict(x.get_samples(), **kwargs)
predictions = self._art_model.predict(x.get_samples(), **kwargs)
check_correct_model_output(predictions, self.output_type)
return predictions
def score(self, test_data: Dataset, scoring_method: Optional[ScoringMethod] = ScoringMethod.ACCURACY, **kwargs):
"""