Check for mismatch between model output type and actual output

This commit is contained in:
abigailt 2022-07-19 08:43:19 +03:00 committed by abigailgold
parent bc7ab0cc7f
commit 1cc73b3da1
5 changed files with 75 additions and 39 deletions

View file

@ -3,7 +3,7 @@ from typing import Optional
from sklearn.preprocessing import OneHotEncoder
from sklearn.base import BaseEstimator
from apt.utils.models import Model, ModelOutputType, get_nb_classes
from apt.utils.models import Model, ModelOutputType, get_nb_classes, check_correct_model_output
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
from art.estimators.classification.scikitlearn import SklearnClassifier as ArtSklearnClassifier
@ -71,7 +71,9 @@ class SklearnClassifier(SklearnModel):
:type x: `Dataset`
:return: Predictions from the model as numpy array (class probabilities, if supported).
"""
return self._art_model.predict(x.get_samples(), **kwargs)
predictions = self._art_model.predict(x.get_samples(), **kwargs)
check_correct_model_output(predictions, self.output_type)
return predictions
class SklearnRegressor(SklearnModel):