mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-05-06 10:32:38 +02:00
Add more to wrappers
This commit is contained in:
parent
f2df2fcc8c
commit
45cc9180b8
6 changed files with 74 additions and 30 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from apt.utils.models import SklearnClassifier, SklearnRegressor
|
||||
from apt.utils.models import SklearnClassifier, SklearnRegressor, ModelOutputType
|
||||
from apt.utils.datasets import ArrayDataset
|
||||
from apt.utils import dataset_utils
|
||||
|
||||
|
|
@ -11,7 +11,7 @@ from sklearn.ensemble import RandomForestClassifier
|
|||
def test_sklearn_classifier():
|
||||
(x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset()
|
||||
underlying_model = RandomForestClassifier()
|
||||
model = SklearnClassifier(underlying_model)
|
||||
model = SklearnClassifier(underlying_model, ModelOutputType.CLASSIFIER_VECTOR)
|
||||
train = ArrayDataset(x_train, y_train)
|
||||
test = ArrayDataset(x_test, y_test)
|
||||
model.fit(train)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue