Support for multi-label binary models in minimizer. First test with pytorch model passing.

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2024-03-13 09:59:02 +02:00
parent 076503b248
commit 7e34f0d2ff
2 changed files with 104 additions and 15 deletions

View file

@ -93,6 +93,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
if is_regression:
self.estimator = SklearnRegressor(estimator)
else:
#TODO: maybe we should get model output type from user in this case
self.estimator = SklearnClassifier(estimator,
ModelOutputType.CLASSIFIER_SINGLE_OUTPUT_CLASS_PROBABILITIES)
self.target_accuracy = target_accuracy
@ -679,7 +680,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
# this is a leaf
# if it is a regression problem we do not use label
label = self._calculate_cell_label(node) if not self.is_regression else 1
hist = [int(i) for i in self._dt.tree_.value[node][0]] if not self.is_regression else []
hist = self._dt.tree_.value[node]
cell = {'label': label, 'hist': hist, 'ranges': {}, 'id': int(node)}
return [cell]
@ -710,8 +711,11 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
return cells
def _calculate_cell_label(self, node):
label_hist = self._dt.tree_.value[node][0]
return int(self._dt.classes_[np.argmax(label_hist)])
label_hist = self._dt.tree_.value[node]
if isinstance(self._dt.classes_, list):
return [self._dt.classes_[output][class_index]
for output, class_index in enumerate(np.argmax(label_hist, axis=1))]
return [self._dt.classes_[np.argmax(label_hist[0])]]
def _modify_cells(self):
cells = []
@ -808,9 +812,15 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
# else: nothing to do, stay with previous cells
def _calculate_level_cell_label(self, left_cell, right_cell, new_cell):
new_cell['hist'] = [x + y for x, y in
zip(left_cell['hist'], right_cell['hist'])] if not self.is_regression else []
new_cell['label'] = int(self._dt.classes_[np.argmax(new_cell['hist'])]) if not self.is_regression else 1
new_cell['hist'] = left_cell['hist'] + right_cell['hist']
# [x + y for x, y in
# zip(left_cell['hist'], right_cell['hist'])] if not self.is_regression else []
if isinstance(self._dt.classes_, list):
new_cell['label'] = [self._dt.classes_[output][class_index]
for output, class_index in enumerate(np.argmax(new_cell['hist'], axis=1))]
else:
new_cell['label'] = [self._dt.classes_[np.argmax(new_cell['hist'][0])]]
def _get_nodes_level(self, level):
# level = distance from lowest leaf
@ -838,26 +848,28 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
# return all nodes with depth == level or leaves higher than level
return [i for i, x in enumerate(node_depth) if x == depth or (x < depth and is_leaves[i])]
def _attach_cells_representatives(self, prepared_data, originalTrainFeatures, labelFeature, level_nodes):
def _attach_cells_representatives(self, prepared_data, original_train_features, label_feature, level_nodes):
# prepared data include one hot encoded categorical data,
# if there is no categorical data prepared data is original data
nodeIds = self._find_sample_nodes(prepared_data, level_nodes)
labels_df = pd.DataFrame(labelFeature, columns=['label'])
for cell in self.cells:
cell['representative'] = {}
# get all rows in cell
indexes = [i for i, x in enumerate(nodeIds) if x == cell['id']]
original_rows = originalTrainFeatures.iloc[indexes]
original_rows = original_train_features.iloc[indexes]
sample_rows = prepared_data.iloc[indexes]
sample_labels = labels_df.iloc[indexes]['label'].values.tolist()
# get rows with matching label
if self.is_regression:
if self.is_regression or (len(label_feature.shape) > 1 and label_feature.shape[1] > 1):
match_samples = sample_rows
match_rows = original_rows
else:
indexes = [i for i, label in enumerate(sample_labels) if label == cell['label']]
labels_df = pd.DataFrame(label_feature, columns=['label'])
sample_labels = labels_df.iloc[indexes]['label'].values.tolist()
indexes = [i for i, label in enumerate(sample_labels) if label == cell['label'][0]]
match_samples = sample_rows.iloc[indexes]
match_rows = original_rows.iloc[indexes]
# find the "middle" of the cluster
array = match_samples.values
# Only works with numpy 1.9.0 and higher!!!

View file

@ -4,21 +4,25 @@ import pandas as pd
import scipy
from sklearn.compose import ColumnTransformer
from sklearn.datasets import load_diabetes
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from torch import nn, optim
from torch import nn, optim, sigmoid, where
from torch.nn import functional
from scipy.special import expit
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input
from apt.utils.datasets.datasets import PytorchData
from apt.utils.models import ModelOutputType
from apt.utils.models.pytorch_model import PyTorchClassifier
from apt.minimization import GeneralizeToRepresentative
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from apt.utils.dataset_utils import get_iris_dataset_np, get_adult_dataset_pd, get_german_credit_dataset_pd
from apt.utils.datasets import ArrayDataset
from apt.utils.models import SklearnClassifier, ModelOutputType, SklearnRegressor, KerasClassifier
@ -1335,6 +1339,79 @@ def test_minimizer_pytorch_iris():
assert ((rel_accuracy >= target_accuracy) or (target_accuracy - rel_accuracy) <= ACCURACY_DIFF)
def test_minimizer_pytorch_multi_label_binary():
class multi_label_binary_model(nn.Module):
def __init__(self, num_labels, num_features):
super(multi_label_binary_model, self).__init__()
self.fc1 = nn.Sequential(
nn.Linear(num_features, 256),
nn.Tanh(), )
self.classifier1 = nn.Linear(256, num_labels)
def forward(self, x):
return self.classifier1(self.fc1(x))
# missing sigmoid on each output
class FocalLoss(nn.Module):
def __init__(self, gamma=2, alpha=0.5):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
def forward(self, input, target):
bce_loss = functional.binary_cross_entropy_with_logits(input, target, reduction='none')
p = sigmoid(input)
p = where(target >= 0.5, p, 1-p)
modulating_factor = (1 - p)**self.gamma
alpha = self.alpha * target + (1 - self.alpha) * (1 - target)
focal_loss = alpha * modulating_factor * bce_loss
return focal_loss.mean()
(x_train, y_train), _ = get_iris_dataset_np()
features = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
qi = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
# make multi-label binary
y_train = np.column_stack((y_train, y_train, y_train))
y_train[y_train > 1] = 1
x_train = x_train.astype(np.float32)
y_train = y_train.astype(np.float32)
orig_model = multi_label_binary_model(3, 4)
criterion = FocalLoss()
optimizer = optim.RMSprop(orig_model.parameters(), lr=0.01)
model = PyTorchClassifier(model=orig_model,
output_type=ModelOutputType.CLASSIFIER_MULTI_OUTPUT_BINARY_LOGITS,
loss=criterion,
optimizer=optimizer,
input_shape=(24,),
nb_classes=3)
model.fit(PytorchData(x_train, y_train), save_entire_model=False,
nb_epochs=10)
predictions = model.predict(PytorchData(x_train, y_train))
predictions = expit(predictions)
predictions[predictions < 0.5] = 0
predictions[predictions >= 0.5] = 1
target_accuracy = 0.99
gen = GeneralizeToRepresentative(model, target_accuracy=target_accuracy, features_to_minimize=qi)
transformed = gen.fit_transform(dataset=ArrayDataset(x_train, predictions, features_names=features))
gener = gen.generalizations
check_features(features, gener, transformed, x_train)
ncp = gen.ncp.transform_score
check_ncp(ncp, gener)
rel_accuracy = model.score(ArrayDataset(transformed.astype(np.float32), predictions))
assert ((rel_accuracy >= target_accuracy) or (target_accuracy - rel_accuracy) <= ACCURACY_DIFF)
def test_untouched():
cells = [{"id": 1, "ranges": {"age": {"start": None, "end": 38}}, "label": 0,
'categories': {'gender': ['male']}, "representative": {"age": 26, "height": 149}},