Compute generalizations with test data when possible (for computing better representatives).

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2023-08-07 15:59:22 +03:00
parent b48b829a01
commit c2e0fced03
2 changed files with 50 additions and 24 deletions

View file

@ -325,7 +325,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
self._attach_cells_representatives(x_prepared, used_X_train, y_train, nodes)
# self._cells currently holds the generalization created from the tree leaves
self._calculate_generalizations()
self._calculate_generalizations(X_test)
if generalize_using_transform:
generalized = self._generalize_from_tree(X_test, x_prepared_test, nodes, self.cells, self._cells_by_id)
else:
@ -355,7 +355,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
self._attach_cells_representatives(x_prepared, used_X_train, y_train, nodes)
self._calculate_generalizations()
self._calculate_generalizations(X_test)
if generalize_using_transform:
generalized = self._generalize_from_tree(X_test, x_prepared_test, nodes, self.cells,
self._cells_by_id)
@ -385,7 +385,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
if removed_feature is None:
break
self._calculate_generalizations()
self._calculate_generalizations(X_test)
if generalize_using_transform:
generalized = self._generalize_from_tree(X_test, x_prepared_test, nodes, self.cells,
self._cells_by_id)
@ -1084,6 +1084,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
self._generalizations['ranges'],
self._generalizations['categories'])
# categorical - use most common value
old_category_representatives = category_representatives
category_representatives = {}
for feature in self._generalizations['categories']:
category_representatives[feature] = []
@ -1092,34 +1093,42 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
# for c_index in range(len(group)):
# indexes = [i for i, s in enumerate(sample_indexes) if s[feature][g_index] == c_index]
indexes = [i for i, s in enumerate(sample_indexes) if s[feature] == g_index]
rows = samples[indexes]
values = rows[:, feature]
category = Counter(values).most_common(1)[0][0]
category_representatives[feature].append(group[category])
# c_count = len([s for s in sample_indexes if s[feature][g_index] == c_index])
# if c_count > max_count:
# max_count = c_count
# category = c_index
# category_representatives[feature].append(group[category])
if indexes:
rows = samples.iloc[indexes]
values = rows[feature]
category = Counter(values).most_common(1)[0][0]
category_representatives[feature].append(category)
# c_count = len([s for s in sample_indexes if s[feature][g_index] == c_index])
# if c_count > max_count:
# max_count = c_count
# category = c_index
# category_representatives[feature].append(group[category])
else:
category_representatives[feature].append(old_category_representatives[feature][g_index])
# numerical - use actual value closest to median
old_range_representatives = range_representatives
range_representatives = {}
for feature in self._generalizations['ranges']:
range_representatives[feature] = []
# find the actual value closest to the median (per range of each feature)
for index in range(len(self._generalizations['ranges'][feature])):
indexes = [i for i, s in enumerate(sample_indexes) if s[feature] == index]
rows = samples[indexes]
values = rows[:, feature]
median = np.median(values)
min_value = max(values)
min_dist = float("inf")
for value in values:
# euclidean distance between two floating point values
dist = abs(value - median)
if dist < min_dist:
min_dist = dist
min_value = value
range_representatives[feature].append(min_value)
if indexes:
rows = samples.iloc[indexes]
values = rows[feature]
median = np.median(values)
min_value = max(values)
min_dist = float("inf")
for value in values:
# euclidean distance between two floating point values
dist = abs(value - median)
if dist < min_dist:
min_dist = dist
min_value = value
range_representatives[feature].append(min_value)
else:
range_representatives[feature].append(old_range_representatives[feature][index])
self._generalizations['category_representatives'] = category_representatives
self._generalizations['range_representatives'] = range_representatives