This commit is contained in:
abigailt 2022-03-07 19:09:31 +02:00
parent 3d82db80c4
commit f2df2fcc8c
6 changed files with 35 additions and 43 deletions

View file

@ -1,7 +1,7 @@
from abc import ABCMeta, abstractmethod
from typing import Any
from apt.utils.datasets import BaseDataset, DATA_ARRAY_TYPE
from apt.utils.datasets import Dataset, DATA_ARRAY_TYPE
class Model(metaclass=ABCMeta):
@ -18,12 +18,12 @@ class Model(metaclass=ABCMeta):
self._model = model
@abstractmethod
def fit(self, train_data: BaseDataset, **kwargs) -> None:
def fit(self, train_data: Dataset, **kwargs) -> None:
"""
Fit the model using the training data.
:param train_data: Training data.
:type train_data: `BaseDataset`
:type train_data: `Dataset`
"""
raise NotImplementedError