Addressing review comments

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2024-06-19 11:17:17 +03:00
parent 846de0f753
commit 2895b40f05
3 changed files with 16 additions and 25 deletions

View file

@ -63,7 +63,6 @@ class KerasClassifier(KerasModel):
:return: Predictions from the model as numpy array (class probabilities, if supported).
"""
predictions = self._art_model.predict(x.get_samples(), **kwargs)
# check_correct_model_output(predictions, self.output_type)
return predictions
def score(self, test_data: Dataset, scoring_method: Optional[ScoringMethod] = ScoringMethod.ACCURACY, **kwargs):

View file

@ -175,27 +175,27 @@ class Model(metaclass=ABCMeta):
:param test_data: Test data.
:type test_data: `Dataset`
:param predictions: Model predictions to score. If provided, these will be used instead of calling the model's
:keyword predictions: Model predictions to score. If provided, these will be used instead of calling the model's
`predict` method.
:type predictions: `DatasetWithPredictions` with the `pred` field filled.
:param scoring_method: The method for scoring predictions. Default is ACCURACY.
:keyword scoring_method: The method for scoring predictions. Default is ACCURACY.
:type scoring_method: `ScoringMethod`, optional
:param binary_threshold: The threshold to use on binary classification probabilities to assign the positive
:keyword binary_threshold: The threshold to use on binary classification probabilities to assign the positive
class.
:type binary_threshold: float, optional. Default is 0.5.
:param apply_non_linearity: A non-linear function to apply to the result of the 'predict' method, in case the
:keyword apply_non_linearity: A non-linear function to apply to the result of the 'predict' method, in case the
model outputs logits (e.g., sigmoid).
:type apply_non_linearity: Callable, should be possible to apply directly to the numpy output of the 'predict'
method, optional.
:param nb_classes: number of classes (for classification models).
:keyword nb_classes: number of classes (for classification models).
:type nb_classes: int, optional.
:return: the score as float (for classifiers, between 0 and 1)
"""
predictions = kwargs['predictions'] if 'predictions' in kwargs else None
nb_classes = kwargs['nb_classes'] if 'nb_classes' in kwargs else None
scoring_method = kwargs['scoring_method'] if 'scoring_method' in kwargs else ScoringMethod.ACCURACY
binary_threshold = kwargs['binary_threshold'] if 'binary_threshold' in kwargs else 0.5
apply_non_linearity = kwargs['apply_non_linearity'] if 'apply_non_linearity' in kwargs else expit
predictions = kwargs.get('predictions')
nb_classes = kwargs.get('nb_classes')
scoring_method = kwargs.get('scoring_method', ScoringMethod.ACCURACY)
binary_threshold = kwargs.get('binary_threshold', 0.5)
apply_non_linearity = kwargs.get('apply_non_linearity', expit)
if test_data.get_samples() is None and predictions is None:
raise ValueError('score can only be computed when test data or predictions are available')
@ -435,17 +435,9 @@ class BlackboxClassifierPredictions(BlackboxClassifier):
if y_test_pred is None:
y_test_pred = model.get_test_labels()
# if y_train_pred is not None:
# check_correct_model_output(y_train_pred, self.output_type)
# if y_test_pred is not None:
# check_correct_model_output(y_test_pred, self.output_type)
if y_train_pred is not None and len(y_train_pred.shape) == 1:
# self._nb_classes = get_nb_classes(y_train_pred, self.output_type)
y_train_pred = check_and_transform_label_format(y_train_pred, nb_classes=self._nb_classes)
if y_test_pred is not None and len(y_test_pred.shape) == 1:
# if self._nb_classes is None:
# self._nb_classes = get_nb_classes(y_test_pred, self.output_type)
y_test_pred = check_and_transform_label_format(y_test_pred, nb_classes=self._nb_classes)
if x_train_pred is not None and y_train_pred is not None and x_test_pred is not None and y_test_pred is not None:

View file

@ -252,10 +252,10 @@ def test_pytorch_predictions_single_label_binary_prob():
def test_pytorch_predictions_multi_label_cat():
# This kind of model requires special training and will not be supported using the 'fit' method.
class multi_label_cat_model(nn.Module):
class MultiLabelCatModel(nn.Module):
def __init__(self, num_classes, num_features):
super(multi_label_cat_model, self).__init__()
super(MultiLabelCatModel, self).__init__()
self.fc1 = nn.Sequential(
nn.Linear(num_features, 256),
@ -279,7 +279,7 @@ def test_pytorch_predictions_multi_label_cat():
y_test = np.stack([y_test, y_test], axis=1)
test = PytorchData(x_test.astype(np.float32), y_test.astype(np.float32))
inner_model = multi_label_cat_model(num_classes, 4)
inner_model = MultiLabelCatModel(num_classes, 4)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(inner_model.parameters(), lr=0.01)
@ -321,9 +321,9 @@ def test_pytorch_predictions_multi_label_cat():
def test_pytorch_predictions_multi_label_binary():
class multi_label_binary_model(nn.Module):
class MultiLabelBinaryModel(nn.Module):
def __init__(self, num_labels, num_features):
super(multi_label_binary_model, self).__init__()
super(MultiLabelBinaryModel, self).__init__()
self.fc1 = nn.Sequential(
nn.Linear(num_features, 256),
@ -343,7 +343,7 @@ def test_pytorch_predictions_multi_label_binary():
y_test[y_test > 1] = 1
test = PytorchData(x_test.astype(np.float32), y_test)
inner_model = multi_label_binary_model(3, 4)
inner_model = MultiLabelBinaryModel(3, 4)
criterion = FocalLoss()
optimizer = optim.RMSprop(inner_model.parameters(), lr=0.01)