diff --git a/apt/utils/models/model.py b/apt/utils/models/model.py index 8ea837a..9e8379d 100644 --- a/apt/utils/models/model.py +++ b/apt/utils/models/model.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import Any, Optional, Callable, Tuple +from typing import Any, Optional, Callable, Tuple, Union from enum import Enum, auto import numpy as np @@ -31,6 +31,7 @@ def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE) -> int: else: return int(np.max(y) + 1) + class ModelOutputType(Enum): CLASSIFIER_PROBABILITIES = auto() # vector of probabilities CLASSIFIER_LOGITS = auto() # vector of logits @@ -38,6 +39,11 @@ class ModelOutputType(Enum): REGRESSOR_SCALAR = auto() # value +class ModelType(Enum): + SKLEARN_DECISION_TREE = auto() + SKLEARN_GRADIENT_BOOSTING = auto() + + class ScoringMethod(Enum): ACCURACY = auto() # number of correct predictions divided by the number of samples MEAN_SQUARED_ERROR = auto() # mean squared error between the predictions and true labels @@ -157,13 +163,19 @@ class BlackboxClassifier(Model): :type black_box_access: boolean, optional :param unlimited_queries: Boolean indicating whether a user can perform unlimited queries to the model API. :type unlimited_queries: boolean, optional + :param model_type: The type of model this BlackboxClassifier represents. Needed in order to build and/or fit + similar dummy/shadow models. + :type model_type: Either a (unfitted) model object of the underlying framework, or a ModelType representing the + type of the model, optional. """ def __init__(self, model: Any, output_type: ModelOutputType, black_box_access: Optional[bool] = True, - unlimited_queries: Optional[bool] = True, **kwargs): + unlimited_queries: Optional[bool] = True, model_type: Optional[Union[Any, ModelType]] = None, + **kwargs): super().__init__(model, output_type, black_box_access=True, unlimited_queries=unlimited_queries, **kwargs) self._nb_classes = None self._input_shape = None + self._model_type = model_type @property def nb_classes(self) -> int: @@ -183,6 +195,16 @@ class BlackboxClassifier(Model): """ return self._input_shape + @property + def model_type(self) -> Optional[Union[Any, ModelType]]: + """ + Return the type of the model. + + :return: Either a (unfitted) model object of the underlying framework, or a ModelType representing the type of + the model, or None (of none provided at init). + """ + return self._model_type + def fit(self, train_data: Dataset, **kwargs) -> None: """ A blackbox model cannot be fit. diff --git a/tests/test_model.py b/tests/test_model.py index 138c11c..195ad81 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -79,6 +79,8 @@ def test_blackbox_classifier(): score = model.score(test) assert(score == 1.0) + assert model.model_type is None + def test_blackbox_classifier_no_test(): (x_train, y_train), (_, _) = dataset_utils.get_iris_dataset_np()