mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-02 16:22:37 +02:00
Add more to wrappers
This commit is contained in:
parent
f2df2fcc8c
commit
45cc9180b8
6 changed files with 74 additions and 30 deletions
|
|
@ -3,8 +3,8 @@ import numpy as np
|
|||
from sklearn.preprocessing import OneHotEncoder
|
||||
from sklearn.base import BaseEstimator
|
||||
|
||||
from apt.utils.models import Model
|
||||
from apt.utils.datasets import Dataset, DATA_ARRAY_TYPE
|
||||
from apt.utils.models import Model, ModelOutputType
|
||||
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
|
||||
|
||||
from art.estimators.classification.scikitlearn import SklearnClassifier as ArtSklearnClassifier
|
||||
from art.estimators.regression.scikitlearn import ScikitlearnRegressor
|
||||
|
|
@ -28,13 +28,13 @@ class SklearnClassifier(SklearnModel):
|
|||
"""
|
||||
Wrapper class for scikitlearn classification models.
|
||||
"""
|
||||
def __init__(self, model: BaseEstimator, **kwargs):
|
||||
def __init__(self, model: BaseEstimator, output_type: ModelOutputType, **kwargs):
|
||||
"""
|
||||
Initialize a `SklearnClassifier` wrapper object.
|
||||
|
||||
:param model: The original sklearn model object
|
||||
"""
|
||||
super().__init__(model, **kwargs)
|
||||
super().__init__(model, output_type, **kwargs)
|
||||
self._art_model = ArtSklearnClassifier(model)
|
||||
|
||||
def fit(self, train_data: Dataset, **kwargs) -> None:
|
||||
|
|
@ -48,7 +48,7 @@ class SklearnClassifier(SklearnModel):
|
|||
y_encoded = encoder.fit_transform(train_data.get_labels().reshape(-1, 1))
|
||||
self._art_model.fit(train_data.get_samples(), y_encoded, **kwargs)
|
||||
|
||||
def predict(self, x: DATA_ARRAY_TYPE, **kwargs) -> DATA_ARRAY_TYPE:
|
||||
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""
|
||||
Perform predictions using the model for input `x`.
|
||||
|
||||
|
|
@ -69,7 +69,7 @@ class SklearnRegressor(SklearnModel):
|
|||
|
||||
:param model: The original sklearn model object
|
||||
"""
|
||||
super().__init__(model, **kwargs)
|
||||
super().__init__(model, ModelOutputType.REGRESSOR_SCALAR, **kwargs)
|
||||
self._art_model = ScikitlearnRegressor(model)
|
||||
|
||||
def fit(self, train_data: Dataset, **kwargs) -> None:
|
||||
|
|
@ -81,7 +81,7 @@ class SklearnRegressor(SklearnModel):
|
|||
"""
|
||||
self._art_model.fit(train_data.get_samples(), train_data.get_labels(), **kwargs)
|
||||
|
||||
def predict(self, x: DATA_ARRAY_TYPE, **kwargs) -> DATA_ARRAY_TYPE:
|
||||
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
|
||||
"""
|
||||
Perform predictions using the model for input `x`.
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue