Test for sklearn (currently not passing due to ART dependency)

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2024-04-30 14:51:35 +03:00
parent 8b8b461143
commit b3f87623b1
2 changed files with 25 additions and 1 deletions

View file

@ -37,7 +37,9 @@ class ScoringMethod(Enum):
def is_one_hot(y: OUTPUT_DATA_ARRAY_TYPE) -> bool:
return len(y.shape) == 2 and y.shape[1] > 1 and np.all(np.around(np.sum(y, axis=1), decimals=4) == 1)
if not isinstance(y, list):
return len(y.shape) == 2 and y.shape[1] > 1 and np.all(np.around(np.sum(y, axis=1), decimals=4) == 1)
return False
def is_multi_label(output_type: ModelOutputType) -> bool: