mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-20 15:38:05 +02:00
Test for sklearn (currently not passing due to ART dependency)
Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
8b8b461143
commit
b3f87623b1
2 changed files with 25 additions and 1 deletions
|
|
@ -35,6 +35,28 @@ def test_sklearn_classifier():
|
|||
assert (0.0 <= score <= 1.0)
|
||||
|
||||
|
||||
# This test currently cannot pass due to ART dependency, so sklearn support will need to wait until ART is updated
|
||||
# def test_sklearn_classifier_predictions_multi_label_binary():
|
||||
# (x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()
|
||||
#
|
||||
# # make multi-label binary
|
||||
# y_train = np.column_stack((y_train, y_train, y_train))
|
||||
# y_train[y_train > 1] = 1
|
||||
# y_test = np.column_stack((y_test, y_test, y_test))
|
||||
# y_test[y_test > 1] = 1
|
||||
#
|
||||
# test = ArrayDataset(x_test, y_test)
|
||||
#
|
||||
# underlying_model = RandomForestClassifier()
|
||||
# underlying_model.fit(x_train, y_train)
|
||||
# model = SklearnClassifier(underlying_model, ModelOutputType.CLASSIFIER_MULTI_OUTPUT_BINARY_PROBABILITIES)
|
||||
# pred = model.predict(test)
|
||||
# assert (pred[0].shape[0] == x_test.shape[0])
|
||||
#
|
||||
# score = model.score(test)
|
||||
# assert (score == 1.0)
|
||||
|
||||
|
||||
def test_sklearn_regressor():
|
||||
(x_train, y_train), (x_test, y_test) = dataset_utils.get_diabetes_dataset_np()
|
||||
underlying_model = DecisionTreeRegressor()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue