This commit is contained in:
olasaadi 2022-07-26 18:37:44 +03:00
parent 521a2ccda9
commit 74ce92acc4
2 changed files with 8 additions and 8 deletions

View file

@ -94,9 +94,9 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
if self._optimizer is None: # pragma: no cover
raise ValueError("An optimizer is needed to train the model, but none for provided.")
_y = check_and_transform_label_format(y, self.nb_classes)
y = check_and_transform_label_format(y, self.nb_classes)
# Apply preprocessing
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, _y, fit=True)
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
# Check label shape
y_preprocessed = self.reduce_labels(y_preprocessed)
@ -107,8 +107,8 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
val_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
logger.info("Using train set for validation")
else:
_y_val = check_and_transform_label_format(y_validation, self.nb_classes)
x_val_preprocessed, y_val_preprocessed = self._apply_preprocessing(x_validation, _y_val, fit=False)
y_val = check_and_transform_label_format(y_validation, self.nb_classes)
x_val_preprocessed, y_val_preprocessed = self._apply_preprocessing(x_validation, y_val, fit=False)
# Check label shape
y_val_preprocessed = self.reduce_labels(y_val_preprocessed)
val_dataset = TensorDataset(torch.from_numpy(x_val_preprocessed), torch.from_numpy(y_val_preprocessed))

View file

@ -59,12 +59,12 @@ def test_nursery_pytorch_state_dict():
nb_classes=4)
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=10)
model.load_latest_state_dict_checkpoint()
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
score = model.score(PytorchData(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)
assert (0 <= score <= 1)
# python pytorch numpy
model.load_best_state_dict_checkpoint()
score = model.score(ArrayDataset(x_test.astype(np.float32), y_test))
score = model.score(PytorchData(x_test.astype(np.float32), y_test))
print('best model accuracy: ', score)
assert (0 <= score <= 1)
@ -89,10 +89,10 @@ def test_nursery_pytorch_save_entire_model():
nb_classes=4)
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True, nb_epochs=10)
score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
score = art_model.score(PytorchData(x_test.astype(np.float32), y_test))
print('Base model accuracy: ', score)
assert (0 <= score <= 1)
art_model.load_best_model_checkpoint()
score = art_model.score(ArrayDataset(x_test.astype(np.float32), y_test))
score = art_model.score(PytorchData(x_test.astype(np.float32), y_test))
print('best model accuracy: ', score)
assert (0 <= score <= 1)