Initial version of general model wrappers and methods supporting multi-label classifiers

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2024-02-12 09:45:36 +02:00
parent cb70ca10e6
commit f197199e54
2 changed files with 80 additions and 7 deletions

View file

@ -156,6 +156,54 @@ def test_blackbox_classifier_predictions_y():
assert model.model_type is None
def test_blackbox_classifier_predictions_multi_label_cat():
(x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()
# make multi-label categorical
y_train = np.column_stack((y_train, y_train, y_train))
y_test = np.column_stack((y_test, y_test, y_test))
train = DatasetWithPredictions(y_train, x_train, y_train)
test = DatasetWithPredictions(y_test, x_test, y_test)
data = Data(train, test)
model = BlackboxClassifierPredictions(data, ModelOutputType.CLASSIFIER_SCALAR)
pred = model.predict(test)
assert (pred.shape[0] == x_test.shape[0])
score = model.score(test)
assert (score == 1.0)
assert model.model_type is None
def test_blackbox_classifier_predictions_multi_label_binary():
(x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()
# make multi-label categorical
y_train = np.column_stack((y_train, y_train, y_train))
y_train[y_train > 1] = 1
pred_train = y_train.copy().astype(float)
pred_train[pred_train == 0] = 0.2
pred_train[pred_train == 1] = 0.6
y_test = np.column_stack((y_test, y_test, y_test))
y_test[y_test > 1] = 1
pred_test = y_test.copy().astype(float)
pred_test[pred_test == 0] = 0.2
pred_test[pred_test == 1] = 0.6
train = DatasetWithPredictions(pred_train, x_train, y_train)
test = DatasetWithPredictions(pred_test, x_test, y_test)
data = Data(train, test)
model = BlackboxClassifierPredictions(data, ModelOutputType.CLASSIFIER_SCALAR)
pred = model.predict(test)
assert (pred.shape[0] == x_test.shape[0])
score = model.score(test)
assert (score == 1.0)
assert model.model_type is None
def test_blackbox_classifier_mismatch():
(x_train, y_train), (x_test, y_test) = dataset_utils.get_iris_dataset_np()