Support for multi-label binary models in minimizer. First test with pytorch model passing.

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2024-03-13 09:59:02 +02:00
parent 076503b248
commit 7e34f0d2ff
2 changed files with 104 additions and 15 deletions

View file

@ -93,6 +93,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
if is_regression:
self.estimator = SklearnRegressor(estimator)
else:
#TODO: maybe we should get model output type from user in this case
self.estimator = SklearnClassifier(estimator,
ModelOutputType.CLASSIFIER_SINGLE_OUTPUT_CLASS_PROBABILITIES)
self.target_accuracy = target_accuracy
@ -679,7 +680,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
# this is a leaf
# if it is a regression problem we do not use label
label = self._calculate_cell_label(node) if not self.is_regression else 1
hist = [int(i) for i in self._dt.tree_.value[node][0]] if not self.is_regression else []
hist = self._dt.tree_.value[node]
cell = {'label': label, 'hist': hist, 'ranges': {}, 'id': int(node)}
return [cell]
@ -710,8 +711,11 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
return cells
def _calculate_cell_label(self, node):
label_hist = self._dt.tree_.value[node][0]
return int(self._dt.classes_[np.argmax(label_hist)])
label_hist = self._dt.tree_.value[node]
if isinstance(self._dt.classes_, list):
return [self._dt.classes_[output][class_index]
for output, class_index in enumerate(np.argmax(label_hist, axis=1))]
return [self._dt.classes_[np.argmax(label_hist[0])]]
def _modify_cells(self):
cells = []
@ -808,9 +812,15 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
# else: nothing to do, stay with previous cells
def _calculate_level_cell_label(self, left_cell, right_cell, new_cell):
new_cell['hist'] = [x + y for x, y in
zip(left_cell['hist'], right_cell['hist'])] if not self.is_regression else []
new_cell['label'] = int(self._dt.classes_[np.argmax(new_cell['hist'])]) if not self.is_regression else 1
new_cell['hist'] = left_cell['hist'] + right_cell['hist']
# [x + y for x, y in
# zip(left_cell['hist'], right_cell['hist'])] if not self.is_regression else []
if isinstance(self._dt.classes_, list):
new_cell['label'] = [self._dt.classes_[output][class_index]
for output, class_index in enumerate(np.argmax(new_cell['hist'], axis=1))]
else:
new_cell['label'] = [self._dt.classes_[np.argmax(new_cell['hist'][0])]]
def _get_nodes_level(self, level):
# level = distance from lowest leaf
@ -838,26 +848,28 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
# return all nodes with depth == level or leaves higher than level
return [i for i, x in enumerate(node_depth) if x == depth or (x < depth and is_leaves[i])]
def _attach_cells_representatives(self, prepared_data, originalTrainFeatures, labelFeature, level_nodes):
def _attach_cells_representatives(self, prepared_data, original_train_features, label_feature, level_nodes):
# prepared data include one hot encoded categorical data,
# if there is no categorical data prepared data is original data
nodeIds = self._find_sample_nodes(prepared_data, level_nodes)
labels_df = pd.DataFrame(labelFeature, columns=['label'])
for cell in self.cells:
cell['representative'] = {}
# get all rows in cell
indexes = [i for i, x in enumerate(nodeIds) if x == cell['id']]
original_rows = originalTrainFeatures.iloc[indexes]
original_rows = original_train_features.iloc[indexes]
sample_rows = prepared_data.iloc[indexes]
sample_labels = labels_df.iloc[indexes]['label'].values.tolist()
# get rows with matching label
if self.is_regression:
if self.is_regression or (len(label_feature.shape) > 1 and label_feature.shape[1] > 1):
match_samples = sample_rows
match_rows = original_rows
else:
indexes = [i for i, label in enumerate(sample_labels) if label == cell['label']]
labels_df = pd.DataFrame(label_feature, columns=['label'])
sample_labels = labels_df.iloc[indexes]['label'].values.tolist()
indexes = [i for i, label in enumerate(sample_labels) if label == cell['label'][0]]
match_samples = sample_rows.iloc[indexes]
match_rows = original_rows.iloc[indexes]
# find the "middle" of the cluster
array = match_samples.values
# Only works with numpy 1.9.0 and higher!!!