ai-privacy-toolkit/apt/utils/models/model.py

49 lines
1.2 KiB
Python
Raw Normal View History

from abc import ABCMeta, abstractmethod
from typing import Any
from apt.utils.datasets import BaseDataset, DATA_ARRAY_TYPE
class Model(metaclass=ABCMeta):
    """
    Abstract interface shared by all ML model wrappers.

    A concrete wrapper keeps a reference to a model object from some
    underlying ML framework and exposes a uniform ``fit`` / ``predict``
    API on top of it. Subclasses must implement both abstract methods
    before they can be instantiated.
    """

    def __init__(self, model: Any, **kwargs):
        """
        Wrap an existing model object.

        :param model: The native model object of the underlying ML framework.
                      It is stored unchanged and exposed via the ``model``
                      property.
        :param kwargs: Additional keyword arguments. This base implementation
                       ignores them; they are accepted so that subclass
                       constructors can forward extra options.
        """
        self._model = model

    @property
    def model(self):
        """
        Access the wrapped framework model.

        :return: The object stored in ``self._model``. By default, this is the
                 object given to the constructor.
        """
        return self._model

    @abstractmethod
    def fit(self, train_data: BaseDataset, **kwargs) -> None:
        """
        Train the wrapped model.

        Concrete subclasses must override this method.

        :param train_data: Dataset to train on.
        :type train_data: `BaseDataset`
        :param kwargs: Additional keyword arguments, interpreted by the
                       concrete subclass.
        :raises NotImplementedError: If this base implementation is invoked
                                     (e.g. via ``super().fit(...)``).
        """
        raise NotImplementedError()

    @abstractmethod
    def predict(self, x: DATA_ARRAY_TYPE, **kwargs) -> DATA_ARRAY_TYPE:
        """
        Run the wrapped model on the input samples ``x``.

        Concrete subclasses must override this method.

        :param x: Samples to predict on.
        :type x: `np.ndarray` or `pandas.DataFrame`
        :param kwargs: Additional keyword arguments, interpreted by the
                       concrete subclass.
        :return: The model's predictions for ``x``.
        :raises NotImplementedError: If this base implementation is invoked
                                     (e.g. via ``super().predict(...)``).
        """
        raise NotImplementedError()