mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-07-02 16:01:00 +02:00
Merge with main
This commit is contained in:
parent
dc5cc793ee
commit
64038f76f9
2 changed files with 5 additions and 5 deletions
|
|
@ -54,7 +54,7 @@ def test_nursery_pytorch_state_dict():
|
|||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(inner_model.parameters(), lr=0.01)
|
||||
|
||||
model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
|
||||
model = PyTorchClassifier(model=inner_model, output_type=ModelOutputType.CLASSIFIER_LOGITS, loss=criterion,
|
||||
optimizer=optimizer, input_shape=(24,),
|
||||
nb_classes=4)
|
||||
model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=False, nb_epochs=10)
|
||||
|
|
@ -84,7 +84,7 @@ def test_nursery_pytorch_save_entire_model():
|
|||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.01)
|
||||
|
||||
art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_VECTOR, loss=criterion,
|
||||
art_model = PyTorchClassifier(model=model, output_type=ModelOutputType.CLASSIFIER_LOGITS, loss=criterion,
|
||||
optimizer=optimizer, input_shape=(24,),
|
||||
nb_classes=4)
|
||||
art_model.fit(PytorchData(x_train.astype(np.float32), y_train), save_entire_model=True, nb_epochs=10)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue