Merge with main

This commit is contained in:
abigailt 2022-08-01 18:12:03 +03:00
parent dc5cc793ee
commit 64038f76f9
2 changed files with 5 additions and 5 deletions

View file

@ -273,9 +273,9 @@ class DatasetWithPredictions(Dataset):
y: Optional[INPUT_DATA_ARRAY_TYPE] = None, features_names: Optional[list] = None, **kwargs):
self.is_pandas = False
self.features_names = features_names
self._pred = self._array2numpy(pred)
self._y = self._array2numpy(y) if y is not None else None
self._x = self._array2numpy(x) if x is not None else None
self._pred = array2numpy(pred)
self._y = array2numpy(y) if y is not None else None
self._x = array2numpy(x) if x is not None else None
if self.is_pandas and x is not None:
if features_names and not np.array_equal(features_names, x.columns):
raise ValueError("The supplied features are not the same as in the data features")

View file

@ -54,7 +54,7 @@ def test_nursery_pytorch_state_dict():
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(inner_model.parameters(), lr=0.01)
model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_LOGITS, loss=criterion,
optimizer=optimizer, input_shape=(24,),
nb_classes=4)
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=10)
@ -84,7 +84,7 @@ def test_nursery_pytorch_save_entire_model():
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_LOGITS, loss=criterion,
optimizer=optimizer, input_shape=(24,),
nb_classes=4)
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True, nb_epochs=10)