mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-24 14:15:13 +02:00
Check for mismatch between model output type and actual output
This commit is contained in:
parent
bc7ab0cc7f
commit
1cc73b3da1
5 changed files with 75 additions and 39 deletions
|
|
@ -7,7 +7,7 @@ import tensorflow as tf
|
|||
from tensorflow import keras
|
||||
tf.compat.v1.disable_eager_execution()
|
||||
|
||||
from apt.utils.models import Model, ModelOutputType, ScoringMethod
|
||||
from apt.utils.models import Model, ModelOutputType, ScoringMethod, check_correct_model_output
|
||||
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
|
||||
|
||||
from art.utils import check_and_transform_label_format
|
||||
|
|
@ -68,7 +68,9 @@ class KerasClassifier(KerasModel):
|
|||
:type x: `Dataset`
|
||||
:return: Predictions from the model as numpy array (class probabilities, if supported).
|
||||
"""
|
||||
return self._art_model.predict(x.get_samples(), **kwargs)
|
||||
predictions = self._art_model.predict(x.get_samples(), **kwargs)
|
||||
check_correct_model_output(predictions, self.output_type)
|
||||
return predictions
|
||||
|
||||
def score(self, test_data: Dataset, scoring_method: Optional[ScoringMethod] = ScoringMethod.ACCURACY, **kwargs):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue