Add more to wrappers

This commit is contained in:
abigailt 2022-03-15 11:42:57 +02:00 committed by olasaadi
parent a2b560920f
commit a9162fbc43
6 changed files with 86 additions and 38 deletions

View file

@ -1,7 +1,14 @@
from abc import ABCMeta, abstractmethod
from typing import Any
from enum import Enum, auto
from apt.utils.datasets import Dataset, DATA_ARRAY_TYPE
from apt.utils.datasets import Dataset, OUTPUT_DATA_ARRAY_TYPE
class ModelOutputType(Enum):
CLASSIFIER_VECTOR = auto() # probabilities or logits
CLASSIFIER_SCALAR = auto() # label only
REGRESSOR_SCALAR = auto() # value
class Model(metaclass=ABCMeta):
@ -9,13 +16,16 @@ class Model(metaclass=ABCMeta):
Abstract base class for ML model wrappers.
"""
def __init__(self, model: Any, **kwargs):
def __init__(self, model: Any, output_type: ModelOutputType, **kwargs):
"""
Initialize a `Model` wrapper object.
:param model: The original model object (of the underlying ML framework)
:param output_type: The type of output the model yields (vector/label only for classifiers,
value for regressors)
"""
self._model = model
self._output_type = output_type
@abstractmethod
def fit(self, train_data: Dataset, **kwargs) -> None:
@ -28,7 +38,7 @@ class Model(metaclass=ABCMeta):
raise NotImplementedError
@abstractmethod
def predict(self, x: DATA_ARRAY_TYPE, **kwargs) -> DATA_ARRAY_TYPE:
def predict(self, x: Dataset, **kwargs) -> OUTPUT_DATA_ARRAY_TYPE:
"""
Perform predictions using the model for input `x`.
@ -39,10 +49,19 @@ class Model(metaclass=ABCMeta):
raise NotImplementedError
@property
def model(self):
def model(self) -> Any:
"""
Return the model.
:return: The model.
"""
return self._model
@property
def output_type(self) -> ModelOutputType:
"""
Return the model's output type.
:return: The model's output type.
"""
return self._output_type