mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-25 04:46:21 +02:00
fix
This commit is contained in:
parent
521a2ccda9
commit
74ce92acc4
2 changed files with 8 additions and 8 deletions
|
|
@ -94,9 +94,9 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
if self._optimizer is None: # pragma: no cover
|
||||
raise ValueError("An optimizer is needed to train the model, but none for provided.")
|
||||
|
||||
_y = check_and_transform_label_format(y, self.nb_classes)
|
||||
y = check_and_transform_label_format(y, self.nb_classes)
|
||||
# Apply preprocessing
|
||||
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, _y, fit=True)
|
||||
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
|
||||
# Check label shape
|
||||
y_preprocessed = self.reduce_labels(y_preprocessed)
|
||||
|
||||
|
|
@ -107,8 +107,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
|
|||
val_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
|
||||
logger.info("Using train set for validation")
|
||||
else:
|
||||
_y_val = check_and_transform_label_format(y_validation, self.nb_classes)
|
||||
x_val_preprocessed, y_val_preprocessed = self._apply_preprocessing(x_validation, _y_val, fit=False)
|
||||
y_val = check_and_transform_label_format(y_validation, self.nb_classes)
|
||||
x_val_preprocessed, y_val_preprocessed = self._apply_preprocessing(x_validation, y_val, fit=False)
|
||||
# Check label shape
|
||||
y_val_preprocessed = self.reduce_labels(y_val_preprocessed)
|
||||
val_dataset = TensorDataset(torch.from_numpy(x_val_preprocessed), torch.from_numpy(y_val_preprocessed))
|
||||
|
|
|
|||
|
|
@ -59,12 +59,12 @@ def test_nursery_pytorch_state_dict():
|
|||
nb_classes=4)
|
||||
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=10)
|
||||
model.load_latest_state_dict_checkpoint()
|
||||
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
||||
score = model.score(PytorchData(x_test.astype(np.float32), y_test))
|
||||
print('Base model accuracy: ', score)
|
||||
assert (0 <= score <= 1)
|
||||
# python pytorch numpy
|
||||
model.load_best_state_dict_checkpoint()
|
||||
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
||||
score = model.score(PytorchData(x_test.astype(np.float32), y_test))
|
||||
print('best model accuracy: ', score)
|
||||
assert (0 <= score <= 1)
|
||||
|
||||
|
|
@ -89,10 +89,10 @@ def test_nursery_pytorch_save_entire_model():
|
|||
nb_classes=4)
|
||||
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True, nb_epochs=10)
|
||||
|
||||
score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
||||
score = art_model.score(PytorchData(x_test.astype(np.float32), y_test))
|
||||
print('Base model accuracy: ', score)
|
||||
assert (0 <= score <= 1)
|
||||
art_model.load_best_model_checkpoint()
|
||||
score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
|
||||
score = art_model.score(PytorchData(x_test.astype(np.float32), y_test))
|
||||
print('best model accuracy: ', score)
|
||||
assert (0 <= score <= 1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue