Formatting (#68)

Fix most flake/lint errors and ignore a few others

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailgold 2022-12-25 15:13:57 +02:00 committed by GitHub
parent b47ba24906
commit d52fcd0041
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 91 additions and 92 deletions

View file

@ -16,7 +16,7 @@ from sklearn.utils.validation import check_is_fitted
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from apt.utils.datasets import ArrayDataset, Data, DATA_PANDAS_NUMPY_TYPE
from apt.utils.datasets import ArrayDataset, DATA_PANDAS_NUMPY_TYPE
from apt.utils.models import Model, SklearnRegressor, ModelOutputType, SklearnClassifier
@ -268,14 +268,14 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
if self.encoder is None:
numeric_features = [f for f in self._features if f not in self.categorical_features]
numeric_transformer = Pipeline(
steps=[('imputer', SimpleImputer(strategy='constant', fill_value=0))]
steps=[('imputer', SimpleImputer(strategy='constant', fill_value=0))]
)
categorical_transformer = OneHotEncoder(handle_unknown="ignore", sparse=False)
self.encoder = ColumnTransformer(
transformers=[
("num", numeric_transformer, numeric_features),
("cat", categorical_transformer, self.categorical_features),
]
transformers=[
("num", numeric_transformer, numeric_features),
("cat", categorical_transformer, self.categorical_features),
]
)
self.encoder.fit(x)
@ -345,7 +345,6 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
print('Pruned tree to level: %d, new relative accuracy: %f' % (level, accuracy))
level += 1
# if accuracy below threshold, improve accuracy by removing features from generalization
elif accuracy < self.target_accuracy:
print('Improving accuracy')
@ -599,8 +598,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
new_cell['ranges'][feature]['end'] = right_cell['ranges'][feature]['start']
for feature in left_cell['categories'].keys():
new_cell['categories'][feature] = \
list(set(left_cell['categories'][feature]) |
set(right_cell['categories'][feature]))
list(set(left_cell['categories'][feature])
| set(right_cell['categories'][feature]))
for feature in left_cell['untouched']:
if feature in right_cell['untouched']:
new_cell['untouched'].append(feature)
@ -707,8 +706,8 @@ class GeneralizeToRepresentative(BaseEstimator, MetaEstimatorMixin, TransformerM
for feature in self._features:
# if feature has a representative value in the cell and should not be left untouched,
# take the representative value
if feature in cells[i]['representative'] and ('untouched' not in cells[i] or
feature not in cells[i]['untouched']):
if feature in cells[i]['representative'] \
and ('untouched' not in cells[i] or feature not in cells[i]['untouched']):
representatives.loc[i, feature] = cells[i]['representative'][feature]
# else, drop the feature (removes from representatives columns that do not have a
# representative value or should remain untouched)