mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-08 15:05:13 +02:00
Test for sklearn (currently not passing due to ART dependency)
Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
8b8b461143
commit
b3f87623b1
2 changed files with 25 additions and 1 deletions
|
|
@ -37,7 +37,9 @@ class ScoringMethod(Enum):
|
|||
|
||||
|
||||
def is_one_hot(y: OUTPUT_DATA_ARRAY_TYPE) -> bool:
|
||||
return len(y.shape) == 2 and y.shape[1] > 1 and np.all(np.around(np.sum(y, axis=1), decimals=4) == 1)
|
||||
if not isinstance(y, list):
|
||||
return len(y.shape) == 2 and y.shape[1] > 1 and np.all(np.around(np.sum(y, axis=1), decimals=4) == 1)
|
||||
return False
|
||||
|
||||
|
||||
def is_multi_label(output_type: ModelOutputType) -> bool:
|
||||
|
|
|
|||
|
|
@ -35,6 +35,28 @@ def test_sklearn_classifier():
|
|||
assert (0.0 <= score <= 1.0)
|
||||
|
||||
|
||||
# This test currently cannot pass due to ART dependency, so sklearn support will need to wait until ART is updated
|
||||
# def test_sklearn_classifier_predictions_multi_label_binary():
|
||||
# (x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()
|
||||
#
|
||||
# # make multi-label binary
|
||||
# y_train = np.column_stack((y_train, y_train, y_train))
|
||||
# y_train[y_train > 1] = 1
|
||||
# y_test = np.column_stack((y_test, y_test, y_test))
|
||||
# y_test[y_test > 1] = 1
|
||||
#
|
||||
# test = ArrayDataset(x_test, y_test)
|
||||
#
|
||||
# underlying_model = RandomForestClassifier()
|
||||
# underlying_model.fit(x_train, y_train)
|
||||
# model = SklearnClassifier(underlying_model, ModelOutputType.CLASSIFIER_MULTI_OUTPUT_BINARY_PROBABILITIES)
|
||||
# pred = model.predict(test)
|
||||
# assert (pred[0].shape[0] == x_test.shape[0])
|
||||
#
|
||||
# score = model.score(test)
|
||||
# assert (score == 1.0)
|
||||
|
||||
|
||||
def test_sklearn_regressor():
|
||||
(x_train, y_train), (x_test, y_test) = dataset_utils.get_diabetes_dataset_np()
|
||||
underlying_model = DecisionTreeRegressor()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue