mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-12 05:22:37 +02:00
Check for mismatch between model output type and actual output
This commit is contained in:
parent
bc7ab0cc7f
commit
1cc73b3da1
5 changed files with 75 additions and 39 deletions
|
|
@ -8,30 +8,6 @@ from art.estimators.classification import BlackBoxClassifier
|
|||
from art.utils import check_and_transform_label_format
|
||||
|
||||
|
||||
def is_one_hot(y: OUTPUT_DATA_ARRAY_TYPE) -> bool:
|
||||
return len(y.shape) == 2 and y.shape[1] > 1
|
||||
|
||||
|
||||
def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE) -> int:
|
||||
"""
|
||||
Get the number of classes from an array of labels
|
||||
|
||||
:param y: the labels
|
||||
:type y: numpy array
|
||||
:return: the number of classes as integer
|
||||
"""
|
||||
if y is None:
|
||||
return 0
|
||||
|
||||
if type(y) != np.ndarray:
|
||||
raise ValueError("Input should be numpy array")
|
||||
|
||||
if is_one_hot(y):
|
||||
return y.shape[1]
|
||||
else:
|
||||
return int(np.max(y) + 1)
|
||||
|
||||
|
||||
class ModelOutputType(Enum):
|
||||
CLASSIFIER_PROBABILITIES = auto() # vector of probabilities
|
||||
CLASSIFIER_LOGITS = auto() # vector of logits
|
||||
|
|
@ -49,6 +25,45 @@ class ScoringMethod(Enum):
|
|||
MEAN_SQUARED_ERROR = auto() # mean squared error between the predictions and true labels
|
||||
|
||||
|
||||
def is_one_hot(y: OUTPUT_DATA_ARRAY_TYPE) -> bool:
|
||||
return len(y.shape) == 2 and y.shape[1] > 1
|
||||
|
||||
|
||||
def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE) -> int:
|
||||
"""
|
||||
Get the number of classes from an array of labels
|
||||
|
||||
:param y: The labels
|
||||
:type y: numpy array
|
||||
:return: The number of classes as integer
|
||||
"""
|
||||
if y is None:
|
||||
return 0
|
||||
|
||||
if type(y) != np.ndarray:
|
||||
raise ValueError("Input should be numpy array")
|
||||
|
||||
if is_one_hot(y):
|
||||
return y.shape[1]
|
||||
else:
|
||||
return int(np.max(y) + 1)
|
||||
|
||||
|
||||
def check_correct_model_output(y: OUTPUT_DATA_ARRAY_TYPE, output_type: ModelOutputType):
|
||||
"""
|
||||
Checks whether there is a mismatch between the declared model output type and its actual output.
|
||||
:param y: Model output
|
||||
:type y: numpy array
|
||||
:param output_type: Declared output type (provided at init)
|
||||
:type output_type: ModelOutputType
|
||||
:raises: ValueError (in case of mismatch)
|
||||
"""
|
||||
if not is_one_hot(y): # 1D array
|
||||
if output_type == ModelOutputType.CLASSIFIER_PROBABILITIES or output_type == ModelOutputType.CLASSIFIER_LOGITS:
|
||||
raise ValueError("Incompatible model output types. Model outputs 1D array of categorical scalars while "
|
||||
"output type is set to ", output_type)
|
||||
|
||||
|
||||
class Model(metaclass=ABCMeta):
|
||||
"""
|
||||
Abstract base class for ML model wrappers.
|
||||
|
|
@ -147,8 +162,6 @@ class Model(metaclass=ABCMeta):
|
|||
return self._unlimited_queries
|
||||
|
||||
|
||||
|
||||
|
||||
class BlackboxClassifier(Model):
|
||||
"""
|
||||
Wrapper for black-box ML classification models.
|
||||
|
|
@ -168,7 +181,6 @@ class BlackboxClassifier(Model):
|
|||
:type model_type: Either a (unfitted) model object of the underlying framework, or a ModelType representing the
|
||||
type of the model, optional.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Any, output_type: ModelOutputType, black_box_access: Optional[bool] = True,
|
||||
unlimited_queries: Optional[bool] = True, model_type: Optional[Union[Any, ModelType]] = None,
|
||||
**kwargs):
|
||||
|
|
@ -220,7 +232,9 @@ class BlackboxClassifier(Model):
|
|||
:type x: `Dataset`
|
||||
:return: Predictions from the model as numpy array.
|
||||
"""
|
||||
return self._art_model.predict(x.get_samples())
|
||||
predictions = self._art_model.predict(x.get_samples())
|
||||
check_correct_model_output(predictions, self.output_type)
|
||||
return predictions
|
||||
|
||||
def score(self, test_data: Dataset, scoring_method: Optional[ScoringMethod] = ScoringMethod.ACCURACY, **kwargs):
|
||||
"""
|
||||
|
|
@ -266,6 +280,11 @@ class BlackboxClassifierPredictions(BlackboxClassifier):
|
|||
x_test_pred = model.get_test_samples()
|
||||
y_test_pred = model.get_test_labels()
|
||||
|
||||
if y_train_pred is not None:
|
||||
check_correct_model_output(y_train_pred, self.output_type)
|
||||
if y_test_pred is not None:
|
||||
check_correct_model_output(y_test_pred, self.output_type)
|
||||
|
||||
if y_train_pred is not None and len(y_train_pred.shape) == 1:
|
||||
self._nb_classes = get_nb_classes(y_train_pred)
|
||||
y_train_pred = check_and_transform_label_format(y_train_pred, nb_classes=self._nb_classes)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue