mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-23 15:48:06 +02:00
Initial version of general model wrappers and methods supporting multi-label classifiers
Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
cb70ca10e6
commit
f197199e54
2 changed files with 80 additions and 7 deletions
|
|
@ -156,6 +156,54 @@ def test_blackbox_classifier_predictions_y():
|
|||
assert model.model_type is None
|
||||
|
||||
|
||||
def test_blackbox_classifier_predictions_multi_label_cat():
|
||||
(x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()
|
||||
|
||||
# make multi-label categorical
|
||||
y_train = np.column_stack((y_train, y_train, y_train))
|
||||
y_test = np.column_stack((y_test, y_test, y_test))
|
||||
|
||||
train = DatasetWithPredictions(y_train, x_train, y_train)
|
||||
test = DatasetWithPredictions(y_test, x_test, y_test)
|
||||
data = Data(train, test)
|
||||
model = BlackboxClassifierPredictions(data, ModelOutputType.CLASSIFIER_SCALAR)
|
||||
pred = model.predict(test)
|
||||
assert (pred.shape[0] == x_test.shape[0])
|
||||
|
||||
score = model.score(test)
|
||||
assert (score == 1.0)
|
||||
|
||||
assert model.model_type is None
|
||||
|
||||
|
||||
def test_blackbox_classifier_predictions_multi_label_binary():
|
||||
(x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()
|
||||
|
||||
# make multi-label categorical
|
||||
y_train = np.column_stack((y_train, y_train, y_train))
|
||||
y_train[y_train > 1] = 1
|
||||
pred_train = y_train.copy().astype(float)
|
||||
pred_train[pred_train == 0] = 0.2
|
||||
pred_train[pred_train == 1] = 0.6
|
||||
y_test = np.column_stack((y_test, y_test, y_test))
|
||||
y_test[y_test > 1] = 1
|
||||
pred_test = y_test.copy().astype(float)
|
||||
pred_test[pred_test == 0] = 0.2
|
||||
pred_test[pred_test == 1] = 0.6
|
||||
|
||||
train = DatasetWithPredictions(pred_train, x_train, y_train)
|
||||
test = DatasetWithPredictions(pred_test, x_test, y_test)
|
||||
data = Data(train, test)
|
||||
model = BlackboxClassifierPredictions(data, ModelOutputType.CLASSIFIER_SCALAR)
|
||||
pred = model.predict(test)
|
||||
assert (pred.shape[0] == x_test.shape[0])
|
||||
|
||||
score = model.score(test)
|
||||
assert (score == 1.0)
|
||||
|
||||
assert model.model_type is None
|
||||
|
||||
|
||||
def test_blackbox_classifier_mismatch():
|
||||
(x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue