Test for sklearn (currently not passing due to ART dependency)

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2024-04-30 14:51:35 +03:00
parent 8b8b461143
commit b3f87623b1
2 changed files with 25 additions and 1 deletions

View file

@ -37,7 +37,9 @@ class ScoringMethod(Enum):
def is_one_hot(y: OUTPUT_DATA_ARRAY_TYPE) -> bool:
return len(y.shape) == 2 and y.shape[1] > 1 and np.all(np.around(np.sum(y, axis=1), decimals=4) == 1)
if not isinstance(y, list):
return len(y.shape) == 2 and y.shape[1] > 1 and np.all(np.around(np.sum(y, axis=1), decimals=4) == 1)
return False
def is_multi_label(output_type: ModelOutputType) -> bool:

View file

@ -35,6 +35,28 @@ def test_sklearn_classifier():
assert (0.0 <= score <= 1.0)
# This test currently cannot pass due to ART dependency, so sklearn support will need to wait until ART is updated
# def test_sklearn_classifier_predictions_multi_label_binary():
# (x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()
#
# # make multi-label binary
# y_train = np.column_stack((y_train, y_train, y_train))
# y_train[y_train > 1] = 1
# y_test = np.column_stack((y_test, y_test, y_test))
# y_test[y_test > 1] = 1
#
# test = ArrayDataset(x_test, y_test)
#
# underlying_model = RandomForestClassifier()
# underlying_model.fit(x_train, y_train)
# model = SklearnClassifier(underlying_model, ModelOutputType.CLASSIFIER_MULTI_OUTPUT_BINARY_PROBABILITIES)
# pred = model.predict(test)
# assert (pred[0].shape[0] == x_test.shape[0])
#
# score = model.score(test)
# assert (score == 1.0)
def test_sklearn_regressor():
(x_train, y_train), (x_test, y_test) = dataset_utils.get_diabetes_dataset_np()
underlying_model = DecisionTreeRegressor()