mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-08 15:05:13 +02:00
Merge with main
This commit is contained in:
parent
dc5cc793ee
commit
64038f76f9
2 changed files with 5 additions and 5 deletions
|
|
@ -273,9 +273,9 @@ class DatasetWithPredictions(Dataset):
|
|||
y: Optional[INPUT_DATA_ARRAY_TYPE] = None, features_names: Optional[list] = None, **kwargs):
|
||||
self.is_pandas = False
|
||||
self.features_names = features_names
|
||||
self._pred = self._array2numpy(pred)
|
||||
self._y = self._array2numpy(y) if y is not None else None
|
||||
self._x = self._array2numpy(x) if x is not None else None
|
||||
self._pred = array2numpy(pred)
|
||||
self._y = array2numpy(y) if y is not None else None
|
||||
self._x = array2numpy(x) if x is not None else None
|
||||
if self.is_pandas and x is not None:
|
||||
if features_names and not np.array_equal(features_names, x.columns):
|
||||
raise ValueError("The supplied features are not the same as in the data features")
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ def test_nursery_pytorch_state_dict():
|
|||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(inner_model.parameters(), lr=0.01)
|
||||
|
||||
model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
|
||||
model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_LOGITS, loss=criterion,
|
||||
optimizer=optimizer, input_shape=(24,),
|
||||
nb_classes=4)
|
||||
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=10)
|
||||
|
|
@ -84,7 +84,7 @@ def test_nursery_pytorch_save_entire_model():
|
|||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.01)
|
||||
|
||||
art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
|
||||
art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_LOGITS, loss=criterion,
|
||||
optimizer=optimizer, input_shape=(24,),
|
||||
nb_classes=4)
|
||||
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True, nb_epochs=10)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue