mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-08 11:32:37 +02:00
Add more to wrappers
This commit is contained in:
parent
a2b560920f
commit
a9162fbc43
6 changed files with 86 additions and 38 deletions
|
|
@ -1,7 +1,14 @@
|
|||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any
|
||||
from enum import Enum, auto
|
||||
|
||||
from apt.utils.datasets import Dataset, DATA_ARRAY_TYPE
|
||||
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
|
||||
|
||||
|
||||
class ModelOutputType(Enum):
|
||||
CLASSIFIER_VECTOR = auto() # probabilities or logits
|
||||
CLASSIFIER_SCALAR = auto() # label only
|
||||
REGRESSOR_SCALAR = auto() # value
|
||||
|
||||
|
||||
class Model(metaclass=ABCMeta):
|
||||
|
|
@ -9,13 +16,16 @@ class Model(metaclass=ABCMeta):
|
|||
Abstract base class for ML model wrappers.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Any, **kwargs):
|
||||
def __init__(self, model: Any, output_type: ModelOutputType, **kwargs):
|
||||
"""
|
||||
Initialize a `Model` wrapper object.
|
||||
|
||||
:param model: The original model object (of the underlying ML framework)
|
||||
:param output_type: The type of output the model yields (vector/label only for classifiers,
|
||||
value for regressors)
|
||||
"""
|
||||
self._model = model
|
||||
self._output_type = output_type
|
||||
|
||||
@abstractmethod
|
||||
def fit(self, train_data: Dataset, **kwargs) -> None:
|
||||
|
|
@ -28,7 +38,7 @@ class Model(metaclass=ABCMeta):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, x: DATA_ARRAY_TYPE, **kwargs) -> DATA_ARRAY_TYPE:
|
||||
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""
|
||||
Perform predictions using the model for input `x`.
|
||||
|
||||
|
|
@ -39,10 +49,19 @@ class Model(metaclass=ABCMeta):
|
|||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
def model(self) -> Any:
|
||||
"""
|
||||
Return the model.
|
||||
|
||||
:return: The model.
|
||||
"""
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def output_type(self) -> ModelOutputType:
|
||||
"""
|
||||
Return the model's output type.
|
||||
|
||||
:return: The model's output type.
|
||||
"""
|
||||
return self._output_type
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue