mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-24 20:36:21 +02:00
Add model type to blackbox classifier (#49)
This commit is contained in:
parent
bc28f7f26a
commit
bc7ab0cc7f
2 changed files with 26 additions and 2 deletions
|
|
@ -1,5 +1,5 @@
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from typing import Any, Optional, Callable, Tuple
|
from typing import Any, Optional, Callable, Tuple, Union
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
@ -31,6 +31,7 @@ def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE) -> int:
|
||||||
else:
|
else:
|
||||||
return int(np.max(y) + 1)
|
return int(np.max(y) + 1)
|
||||||
|
|
||||||
|
|
||||||
class ModelOutputType(Enum):
|
class ModelOutputType(Enum):
|
||||||
CLASSIFIER_PROBABILITIES = auto() # vector of probabilities
|
CLASSIFIER_PROBABILITIES = auto() # vector of probabilities
|
||||||
CLASSIFIER_LOGITS = auto() # vector of logits
|
CLASSIFIER_LOGITS = auto() # vector of logits
|
||||||
|
|
@ -38,6 +39,11 @@ class ModelOutputType(Enum):
|
||||||
REGRESSOR_SCALAR = auto() # value
|
REGRESSOR_SCALAR = auto() # value
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(Enum):
|
||||||
|
SKLEARN_DECISION_TREE = auto()
|
||||||
|
SKLEARN_GRADIENT_BOOSTING = auto()
|
||||||
|
|
||||||
|
|
||||||
class ScoringMethod(Enum):
|
class ScoringMethod(Enum):
|
||||||
ACCURACY = auto() # number of correct predictions divided by the number of samples
|
ACCURACY = auto() # number of correct predictions divided by the number of samples
|
||||||
MEAN_SQUARED_ERROR = auto() # mean squared error between the predictions and true labels
|
MEAN_SQUARED_ERROR = auto() # mean squared error between the predictions and true labels
|
||||||
|
|
@ -157,13 +163,19 @@ class BlackboxClassifier(Model):
|
||||||
:type black_box_access: boolean, optional
|
:type black_box_access: boolean, optional
|
||||||
:param unlimited_queries: Boolean indicating whether a user can perform unlimited queries to the model API.
|
:param unlimited_queries: Boolean indicating whether a user can perform unlimited queries to the model API.
|
||||||
:type unlimited_queries: boolean, optional
|
:type unlimited_queries: boolean, optional
|
||||||
|
:param model_type: The type of model this BlackboxClassifier represents. Needed in order to build and/or fit
|
||||||
|
similar dummy/shadow models.
|
||||||
|
:type model_type: Either a (unfitted) model object of the underlying framework, or a ModelType representing the
|
||||||
|
type of the model, optional.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model: Any, output_type: ModelOutputType, black_box_access: Optional[bool] = True,
|
def __init__(self, model: Any, output_type: ModelOutputType, black_box_access: Optional[bool] = True,
|
||||||
unlimited_queries: Optional[bool] = True, **kwargs):
|
unlimited_queries: Optional[bool] = True, model_type: Optional[Union[Any, ModelType]] = None,
|
||||||
|
**kwargs):
|
||||||
super().__init__(model, output_type, black_box_access=True, unlimited_queries=unlimited_queries, **kwargs)
|
super().__init__(model, output_type, black_box_access=True, unlimited_queries=unlimited_queries, **kwargs)
|
||||||
self._nb_classes = None
|
self._nb_classes = None
|
||||||
self._input_shape = None
|
self._input_shape = None
|
||||||
|
self._model_type = model_type
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nb_classes(self) -> int:
|
def nb_classes(self) -> int:
|
||||||
|
|
@ -183,6 +195,16 @@ class BlackboxClassifier(Model):
|
||||||
"""
|
"""
|
||||||
return self._input_shape
|
return self._input_shape
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_type(self) -> Optional[Union[Any, ModelType]]:
|
||||||
|
"""
|
||||||
|
Return the type of the model.
|
||||||
|
|
||||||
|
:return: Either a (unfitted) model object of the underlying framework, or a ModelType representing the type of
|
||||||
|
the model, or None (of none provided at init).
|
||||||
|
"""
|
||||||
|
return self._model_type
|
||||||
|
|
||||||
def fit(self, train_data: Dataset, **kwargs) -> None:
|
def fit(self, train_data: Dataset, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
A blackbox model cannot be fit.
|
A blackbox model cannot be fit.
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,8 @@ def test_blackbox_classifier():
|
||||||
score = model.score(test)
|
score = model.score(test)
|
||||||
assert(score == 1.0)
|
assert(score == 1.0)
|
||||||
|
|
||||||
|
assert model.model_type is None
|
||||||
|
|
||||||
def test_blackbox_classifier_no_test():
|
def test_blackbox_classifier_no_test():
|
||||||
(x_train, y_train), (_, _) = dataset_utils.get_iris_dataset_np()
|
(x_train, y_train), (_, _) = dataset_utils.get_iris_dataset_np()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue