Patch summary (free text ahead of the first "diff --git" header; patch tools skip it):
  * apt/utils/models/model.py: is_one_hot() returns False when given a Python list
    instead of reading y.shape; every other input is evaluated exactly as before.
  * tests/test_model.py: adds a multi-label binary sklearn test, left commented out;
    its own comment says it cannot pass until the ART dependency is updated.

diff --git a/apt/utils/models/model.py b/apt/utils/models/model.py
index fc48633..2be6f8f 100644
--- a/apt/utils/models/model.py
+++ b/apt/utils/models/model.py
@@ -37,7 +37,9 @@ class ScoringMethod(Enum):
 
 
 def is_one_hot(y: OUTPUT_DATA_ARRAY_TYPE) -> bool:
-    return len(y.shape) == 2 and y.shape[1] > 1 and np.all(np.around(np.sum(y, axis=1), decimals=4) == 1)
+    if not isinstance(y, list):
+        return len(y.shape) == 2 and y.shape[1] > 1 and np.all(np.around(np.sum(y, axis=1), decimals=4) == 1)
+    return False
 
 
 def is_multi_label(output_type: ModelOutputType) -> bool:
diff --git a/tests/test_model.py b/tests/test_model.py
index efbc4c0..eba8b79 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -35,6 +35,28 @@ def test_sklearn_classifier():
     assert (0.0 <= score <= 1.0)
 
 
+# This test currently cannot pass due to ART dependency, so sklearn support will need to wait until ART is updated
+# def test_sklearn_classifier_predictions_multi_label_binary():
+#     (x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()
+#
+#     # make multi-label binary
+#     y_train = np.column_stack((y_train, y_train, y_train))
+#     y_train[y_train > 1] = 1
+#     y_test = np.column_stack((y_test, y_test, y_test))
+#     y_test[y_test > 1] = 1
+#
+#     test = ArrayDataset(x_test, y_test)
+#
+#     underlying_model = RandomForestClassifier()
+#     underlying_model.fit(x_train, y_train)
+#     model = SklearnClassifier(underlying_model, ModelOutputType.CLASSIFIER_MULTI_OUTPUT_BINARY_PROBABILITIES)
+#     pred = model.predict(test)
+#     assert (pred[0].shape[0] == x_test.shape[0])
+#
+#     score = model.score(test)
+#     assert (score == 1.0)
+
+
 def test_sklearn_regressor():
     (x_train, y_train), (x_test, y_test) = dataset_utils.get_diabetes_dataset_np()
     underlying_model = DecisionTreeRegressor()