This commit is contained in:
olasaadi 2022-05-23 12:15:35 +03:00
parent e0385b0d04
commit 019f49861d

View file

@ -155,11 +155,11 @@ class PyTorchClassifierWrapper(ArtPyTorchClassifier):
class PyTorchClassifier(PyTorchModel):
"""
Wrapper class for scikitlearn classification models.
Wrapper class for pytorch classification models.
"""
def __init__(self, model: "torch.nn.Module", output_type: ModelOutputType, loss: "torch.nn.modules.loss._Loss",
input_shape: Tuple[int, ...], nb_classes: int, optimizer: Optional["torch.optim.Optimizer"] = None,
input_shape: Tuple[int, ...], nb_classes: int, optimizer: "torch.optim.Optimizer",
black_box_access: Optional[bool] = True, unlimited_queries: Optional[bool] = True, **kwargs):
"""
Initialization specifically for the PyTorch-based implementation.