Formatting

Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
abigailt 2024-07-03 13:42:29 +03:00
parent 367cae679b
commit bcb7c47cc6
3 changed files with 12 additions and 14 deletions

View file

@ -25,7 +25,7 @@ class ModelOutputType(Flag):
CLASSIFIER_SINGLE_OUTPUT_CATEGORICAL = ModelOutputType.CLASSIFIER
# single binary probability
CLASSIFIER_SINGLE_OUTPUT_BINARY_PROBABILITIES = ModelOutputType.CLASSIFIER | ModelOutputType.BINARY | \
ModelOutputType.PROBABILITIES
ModelOutputType.PROBABILITIES
# vector of class probabilities
CLASSIFIER_SINGLE_OUTPUT_CLASS_PROBABILITIES = ModelOutputType.CLASSIFIER | ModelOutputType.PROBABILITIES
# single binary logit
@ -36,16 +36,16 @@ CLASSIFIER_SINGLE_OUTPUT_CLASS_LOGITS = ModelOutputType.CLASSIFIER | ModelOutput
CLASSIFIER_MULTI_OUTPUT_CATEGORICAL = ModelOutputType.MULTI_OUTPUT | ModelOutputType.CLASSIFIER
# vector of binary probabilities, 1 per output
CLASSIFIER_MULTI_OUTPUT_BINARY_PROBABILITIES = ModelOutputType.MULTI_OUTPUT | ModelOutputType.CLASSIFIER | \
ModelOutputType.BINARY | ModelOutputType.PROBABILITIES
ModelOutputType.BINARY | ModelOutputType.PROBABILITIES
# vector of class probabilities for multiple outputs
CLASSIFIER_MULTI_OUTPUT_CLASS_PROBABILITIES = ModelOutputType.MULTI_OUTPUT | ModelOutputType.CLASSIFIER | \
ModelOutputType.PROBABILITIES
ModelOutputType.PROBABILITIES
# vector of binary logits
CLASSIFIER_MULTI_OUTPUT_BINARY_LOGITS = ModelOutputType.MULTI_OUTPUT | ModelOutputType.CLASSIFIER | \
ModelOutputType.BINARY | ModelOutputType.LOGITS
ModelOutputType.BINARY | ModelOutputType.LOGITS
# vector of logits for multiple outputs
CLASSIFIER_MULTI_OUTPUT_CLASS_LOGITS = ModelOutputType.MULTI_OUTPUT | ModelOutputType.CLASSIFIER | \
ModelOutputType.LOGITS
ModelOutputType.LOGITS
class ModelType(Enum):
@ -79,9 +79,9 @@ def is_binary(output_type: ModelOutputType) -> bool:
def is_categorical(output_type: ModelOutputType) -> bool:
return (ModelOutputType.CLASSIFIER in output_type
and not ModelOutputType.BINARY in output_type
and not ModelOutputType.PROBABILITIES in output_type
and not ModelOutputType.LOGITS in output_type)
and ModelOutputType.BINARY not in output_type
and ModelOutputType.PROBABILITIES not in output_type
and ModelOutputType.LOGITS not in output_type)
def is_probabilities(output_type: ModelOutputType) -> bool: