Test for sklearn (currently not passing due to ART dependency)

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2024-04-30 14:51:35 +03:00
parent 8b8b461143
commit b3f87623b1
2 changed files with 25 additions and 1 deletions

View file

@ -35,6 +35,28 @@ def test_sklearn_classifier():
assert (0.0 <= score <= 1.0)
# This test currently cannot pass due to ART dependency, so sklearn support will need to wait until ART is updated
# def test_sklearn_classifier_predictions_multi_label_binary():
# (x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()
#
# # make multi-label binary
# y_train = np.column_stack((y_train, y_train, y_train))
# y_train[y_train > 1] = 1
# y_test = np.column_stack((y_test, y_test, y_test))
# y_test[y_test > 1] = 1
#
# test = ArrayDataset(x_test, y_test)
#
# underlying_model = RandomForestClassifier()
# underlying_model.fit(x_train, y_train)
# model = SklearnClassifier(underlying_model, ModelOutputType.CLASSIFIER_MULTI_OUTPUT_BINARY_PROBABILITIES)
# pred = model.predict(test)
# assert (pred[0].shape[0] == x_test.shape[0])
#
# score = model.score(test)
# assert (score == 1.0)
def test_sklearn_regressor():
(x_train, y_train), (x_test, y_test) = dataset_utils.get_diabetes_dataset_np()
underlying_model = DecisionTreeRegressor()