Add model type to blackbox classifier (#49)

This commit is contained in:
abigailt 2022-07-18 19:27:14 +03:00 committed by abigailgold
parent bc28f7f26a
commit bc7ab0cc7f
2 changed files with 26 additions and 2 deletions

View file

@ -1,5 +1,5 @@
from abc import ABCMeta, abstractmethod
from typing import Any, Optional, Callable, Tuple
from typing import Any, Optional, Callable, Tuple, Union
from enum import Enum, auto
import numpy as np
@ -31,6 +31,7 @@ def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE) -> int:
else:
return int(np.max(y) + 1)
class ModelOutputType(Enum):
CLASSIFIER_PROBABILITIES = auto() # vector of probabilities
CLASSIFIER_LOGITS = auto() # vector of logits
@ -38,6 +39,11 @@ class ModelOutputType(Enum):
REGRESSOR_SCALAR = auto() # value
class ModelType(Enum):
SKLEARN_DECISION_TREE = auto()
SKLEARN_GRADIENT_BOOSTING = auto()
class ScoringMethod(Enum):
ACCURACY = auto() # number of correct predictions divided by the number of samples
MEAN_SQUARED_ERROR = auto() # mean squared error between the predictions and true labels
@ -157,13 +163,19 @@ class BlackboxClassifier(Model):
:type black_box_access: boolean, optional
:param unlimited_queries: Boolean indicating whether a user can perform unlimited queries to the model API.
:type unlimited_queries: boolean, optional
:param model_type: The type of model this BlackboxClassifier represents. Needed in order to build and/or fit
similar dummy/shadow models.
:type model_type: Either a (unfitted) model object of the underlying framework, or a ModelType representing the
type of the model, optional.
"""
def __init__(self, model: Any, output_type: ModelOutputType, black_box_access: Optional[bool] = True,
unlimited_queries: Optional[bool] = True, **kwargs):
unlimited_queries: Optional[bool] = True, model_type: Optional[Union[Any, ModelType]] = None,
**kwargs):
super().__init__(model, output_type, black_box_access=True, unlimited_queries=unlimited_queries, **kwargs)
self._nb_classes = None
self._input_shape = None
self._model_type = model_type
@property
def nb_classes(self) -> int:
@ -183,6 +195,16 @@ class BlackboxClassifier(Model):
"""
return self._input_shape
@property
def model_type(self) -> Optional[Union[Any, ModelType]]:
"""
Return the type of the model.
:return: Either a (unfitted) model object of the underlying framework, or a ModelType representing the type of
the model, or None (of none provided at init).
"""
return self._model_type
def fit(self, train_data: Dataset, **kwargs) -> None:
"""
A blackbox model cannot be fit.

View file

@ -79,6 +79,8 @@ def test_blackbox_classifier():
score = model.score(test)
assert(score == 1.0)
assert model.model_type is None
def test_blackbox_classifier_no_test():
(x_train, y_train), (_, _) = dataset_utils.get_iris_dataset_np()