Many fixes, some tests pass

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2023-05-29 19:13:35 +03:00
parent 4a252e65fe
commit 2c01187f1f
2 changed files with 200 additions and 58 deletions

View file

@ -70,13 +70,42 @@ def test_minimizer_params_not_transform(data):
[45, 158],
[18, 190]])
y = [1, 1, 0]
samples = ArrayDataset(X, y, features)
base_est = DecisionTreeClassifier(random_state=0, min_samples_split=2,
min_samples_leaf=1)
model = SklearnClassifier(base_est, ModelOutputType.CLASSIFIER_PROBABILITIES)
model.fit(ArrayDataset(X, y))
gen = GeneralizeToRepresentative(model, cells=cells, generalize_using_transform=False)
gen.calculate_ncp(X)
gen.calculate_ncp(samples)
ncp = gen.ncp
assert (ncp > 0.0)
def test_minimizer_params_not_transform_no_data(data):
# Assume two features, age and height, and boolean label
cells = [{"id": 1, "ranges": {"age": {"start": None, "end": 38}, "height": {"start": None, "end": 170}}, "label": 0,
'categories': {}, "representative": {"age": 26, "height": 149}},
{"id": 2, "ranges": {"age": {"start": 39, "end": None}, "height": {"start": None, "end": 170}}, "label": 1,
'categories': {}, "representative": {"age": 58, "height": 163}},
{"id": 3, "ranges": {"age": {"start": None, "end": 38}, "height": {"start": 171, "end": None}}, "label": 0,
'categories': {}, "representative": {"age": 31, "height": 184}},
{"id": 4, "ranges": {"age": {"start": 39, "end": None}, "height": {"start": 171, "end": None}}, "label": 1,
'categories': {}, "representative": {"age": 45, "height": 176}}
]
features = ['age', 'height']
X = np.array([[23, 165],
[45, 158],
[18, 190]])
y = [1, 1, 0]
samples = ArrayDataset(X, y, features)
base_est = DecisionTreeClassifier(random_state=0, min_samples_split=2,
min_samples_leaf=1)
model = SklearnClassifier(base_est, ModelOutputType.CLASSIFIER_PROBABILITIES)
model.fit(ArrayDataset(X, y))
gen = GeneralizeToRepresentative(model, cells=cells, generalize_using_transform=False)
gen.calculate_ncp(samples)
ncp = gen.ncp
assert (ncp > 0.0)