mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-06-26 15:49:37 +02:00
formatting
Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
f85fc87bdd
commit
69e45d99e5
4 changed files with 17 additions and 16 deletions
|
|
@ -43,7 +43,7 @@ def get_nb_classes(y: OUTPUT_DATA_ARRAY_TYPE) -> int:
|
|||
if y is None:
|
||||
return 0
|
||||
|
||||
if type(y) != np.ndarray:
|
||||
if not isinstance(y, np.ndarray):
|
||||
raise ValueError("Input should be numpy array")
|
||||
|
||||
if is_one_hot(y):
|
||||
|
|
@ -339,8 +339,8 @@ class BlackboxClassifierPredictions(BlackboxClassifier):
|
|||
y_test_pred = check_and_transform_label_format(y_test_pred, nb_classes=self._nb_classes)
|
||||
|
||||
if x_train_pred is not None and y_train_pred is not None and x_test_pred is not None and y_test_pred is not None:
|
||||
if type(y_train_pred) != np.ndarray or type(y_test_pred) != np.ndarray \
|
||||
or type(y_train_pred) != np.ndarray or type(y_test_pred) != np.ndarray:
|
||||
if not isinstance(y_train_pred, np.ndarray) or not isinstance(y_test_pred, np.ndarray) \
|
||||
or not isinstance(y_train_pred, np.ndarray) or not isinstance(y_test_pred, np.ndarray):
|
||||
raise NotImplementedError("X/Y Data should be numpy array")
|
||||
x_pred = np.vstack((x_train_pred, x_test_pred))
|
||||
y_pred = np.vstack((y_train_pred, y_test_pred))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue