Regression minimization (#20)

* support regression in minimization and add test

* fix #10
This commit is contained in:
olasaadi 2022-01-27 15:57:55 +02:00 committed by GitHub
parent cb9278ddb5
commit 3feebe8973
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 98 additions and 19 deletions

View file

@ -13,7 +13,7 @@ from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.model_selection import train_test_split
@ -84,7 +84,7 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
def __init__(self, estimator=None, target_accuracy=0.998, features=None,
cells=None, categorical_features=None, features_to_minimize: Union[np.ndarray, list] = None
, train_only_QI=True):
, train_only_QI=True, is_regression=False):
self.estimator = estimator
self.target_accuracy = target_accuracy
self.features = features
@ -94,7 +94,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
self.categorical_features = categorical_features
self.features_to_minimize = features_to_minimize
self.train_only_QI = train_only_QI
self.is_regression = is_regression
def get_params(self, deep=True):
"""Get parameters for this estimator.
@ -227,9 +228,11 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
used_data = X
if self.train_only_QI:
used_data = x_QI
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y,
test_size=0.4,
random_state=18)
if self.is_regression:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=14)
else:
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.4, random_state=18)
X_train_QI = X_train.loc[:, self.features_to_minimize]
X_test_QI = X_test.loc[:, self.features_to_minimize]
used_X_train = X_train
@ -292,7 +295,10 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
self._preprocessor = preprocessor
self.cells_ = {}
self.dt_ = DecisionTreeClassifier(random_state=0, min_samples_split=2,
if self.is_regression:
self.dt_ = DecisionTreeRegressor(random_state=10, min_samples_split=2, min_samples_leaf=1)
else:
self.dt_ = DecisionTreeClassifier(random_state=0, min_samples_split=2,
min_samples_leaf=1)
self.dt_.fit(x_prepared, y_train)
self._modify_categorical_features(used_data)
@ -528,8 +534,9 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
feature_index = self.dt_.tree_.feature[node]
if feature_index == -2:
# this is a leaf
label = self._calculate_cell_label(node)
hist = [int(i) for i in self.dt_.tree_.value[node][0]]
# if it is a regression problem we do not use label
label = self._calculate_cell_label(node) if not self.is_regression else 1
hist = [int(i) for i in self.dt_.tree_.value[node][0]] if not self.is_regression else []
cell = {'label': label, 'hist': hist, 'ranges': {}, 'id': int(node)}
return [cell]
@ -632,8 +639,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
# else: nothing to do, stay with previous cells
def _calculate_level_cell_label(self, left_cell, right_cell, new_cell):
new_cell['hist'] = [x + y for x, y in zip(left_cell['hist'], right_cell['hist'])]
new_cell['label'] = int(self.dt_.classes_[np.argmax(new_cell['hist'])])
new_cell['hist'] = [x + y for x, y in zip(left_cell['hist'], right_cell['hist'])] if not self.is_regression else []
new_cell['label'] = int(self.dt_.classes_[np.argmax(new_cell['hist'])]) if not self.is_regression else 1
def _get_nodes_level(self, level):
# level = distance from lowest leaf
@ -674,9 +681,13 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
sample_rows = prepared_data.iloc[indexes]
sample_labels = labels_df.iloc[indexes]['label'].values.tolist()
# get rows with matching label
indexes = [i for i, label in enumerate(sample_labels) if label == cell['label']]
match_samples = sample_rows.iloc[indexes]
match_rows = original_rows.iloc[indexes]
if self.is_regression:
match_samples = sample_rows
match_rows = original_rows
else:
indexes = [i for i, label in enumerate(sample_labels) if label == cell['label']]
match_samples = sample_rows.iloc[indexes]
match_rows = original_rows.iloc[indexes]
# find the "middle" of the cluster
array = match_samples.values
# Only works with numpy 1.9.0 and higher!!!