Support for multi-label binary models in minimizer. First test with pytorch model passing.

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2024-03-13 09:59:02 +02:00
parent 076503b248
commit 7e34f0d2ff
2 changed files with 104 additions and 15 deletions

View file

@ -4,21 +4,25 @@ import pandas as pd
import scipy
from sklearn.compose import ColumnTransformer
from sklearn.datasets import load_diabetes
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from torch import nn, optim
from torch import nn, optim, sigmoid, where
from torch.nn import functional
from scipy.special import expit
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input
from apt.utils.datasets.datasets import PytorchData
from apt.utils.models import ModelOutputType
from apt.utils.models.pytorch_model import PyTorchClassifier
from apt.minimization import GeneralizeToRepresentative
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from apt.utils.dataset_utils import get_iris_dataset_np, get_adult_dataset_pd, get_german_credit_dataset_pd
from apt.utils.datasets import ArrayDataset
from apt.utils.models import SklearnClassifier, ModelOutputType, SklearnRegressor, KerasClassifier
@ -1335,6 +1339,79 @@ def test_minimizer_pytorch_iris():
assert ((rel_accuracy >= target_accuracy) or (target_accuracy - rel_accuracy) <= ACCURACY_DIFF)
def test_minimizer_pytorch_multi_label_binary():
class multi_label_binary_model(nn.Module):
def __init__(self, num_labels, num_features):
super(multi_label_binary_model, self).__init__()
self.fc1 = nn.Sequential(
nn.Linear(num_features, 256),
nn.Tanh(), )
self.classifier1 = nn.Linear(256, num_labels)
def forward(self, x):
return self.classifier1(self.fc1(x))
# missing sigmoid on each output
class FocalLoss(nn.Module):
def __init__(self, gamma=2, alpha=0.5):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
def forward(self, input, target):
bce_loss = functional.binary_cross_entropy_with_logits(input, target, reduction='none')
p = sigmoid(input)
p = where(target >= 0.5, p, 1-p)
modulating_factor = (1 - p)**self.gamma
alpha = self.alpha * target + (1 - self.alpha) * (1 - target)
focal_loss = alpha * modulating_factor * bce_loss
return focal_loss.mean()
(x_train, y_train), _ = get_iris_dataset_np()
features = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
qi = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
# make multi-label binary
y_train = np.column_stack((y_train, y_train, y_train))
y_train[y_train > 1] = 1
x_train = x_train.astype(np.float32)
y_train = y_train.astype(np.float32)
orig_model = multi_label_binary_model(3, 4)
criterion = FocalLoss()
optimizer = optim.RMSprop(orig_model.parameters(), lr=0.01)
model = PyTorchClassifier(model=orig_model,
output_type=ModelOutputType.CLASSIFIER_MULTI_OUTPUT_BINARY_LOGITS,
loss=criterion,
optimizer=optimizer,
input_shape=(24,),
nb_classes=3)
model.fit(PytorchData(x_train, y_train), save_entire_model=False,
nb_epochs=10)
predictions = model.predict(PytorchData(x_train, y_train))
predictions = expit(predictions)
predictions[predictions < 0.5] = 0
predictions[predictions >= 0.5] = 1
target_accuracy = 0.99
gen = GeneralizeToRepresentative(model, target_accuracy=target_accuracy, features_to_minimize=qi)
transformed = gen.fit_transform(dataset=ArrayDataset(x_train, predictions, features_names=features))
gener = gen.generalizations
check_features(features, gener, transformed, x_train)
ncp = gen.ncp.transform_score
check_ncp(ncp, gener)
rel_accuracy = model.score(ArrayDataset(transformed.astype(np.float32), predictions))
assert ((rel_accuracy >= target_accuracy) or (target_accuracy - rel_accuracy) <= ACCURACY_DIFF)
def test_untouched():
cells = [{"id": 1, "ranges": {"age": {"start": None, "end": 38}}, "label": 0,
'categories': {'gender': ['male']}, "representative": {"age": 26, "height": 149}},