mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-03 00:32:37 +02:00
Check for mismatch between model output type and actual output
This commit is contained in:
parent
bc7ab0cc7f
commit
1cc73b3da1
5 changed files with 75 additions and 39 deletions
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||
from sklearn.preprocessing import OneHotEncoder
|
||||
from sklearn.base import BaseEstimator
|
||||
|
||||
from apt.utils.models import Model, ModelOutputType, get_nb_classes
|
||||
from apt.utils.models import Model, ModelOutputType, get_nb_classes, check_correct_model_output
|
||||
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
|
||||
|
||||
from art.estimators.classification.scikitlearn import SklearnClassifier as ArtSklearnClassifier
|
||||
|
|
@ -71,7 +71,9 @@ class SklearnClassifier(SklearnModel):
|
|||
:type x: `Dataset`
|
||||
:return: Predictions from the model as numpy array (class probabilities, if supported).
|
||||
"""
|
||||
return self._art_model.predict(x.get_samples(), **kwargs)
|
||||
predictions = self._art_model.predict(x.get_samples(), **kwargs)
|
||||
check_correct_model_output(predictions, self.output_type)
|
||||
return predictions
|
||||
|
||||
|
||||
class SklearnRegressor(SklearnModel):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue