New model wrappers (#32)

* keras wrapper + blackbox classifier wrapper (fix #7)

* fix error in NCP calculation

* Update notebooks

* Fix #25 (incorrect attack_feature indexes for social feature in notebook)

* Consistent naming of internal parameters
This commit is contained in:
abigailgold 2022-05-12 15:44:29 +03:00 committed by GitHub
parent fd6be8e778
commit fe676fa426
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 1407 additions and 656 deletions

View file

@ -5,7 +5,7 @@ Implementation of utility classes for dataset handling
"""
from abc import ABCMeta, abstractmethod
from typing import Callable, Collection, Any, Union, List, Optional
from typing import Callable, Collection, Any, Union, List, Optional, Type
import tarfile
import os
@ -19,9 +19,9 @@ from torch import Tensor
logger = logging.getLogger(__name__)
INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, List, Tensor]
INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, pd.Series, List, Tensor]
OUTPUT_DATA_ARRAY_TYPE = np.ndarray
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame]
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame, pd.Series]
class Dataset(metaclass=ABCMeta):
@ -323,7 +323,7 @@ class DatasetFactory:
:return: a Callable that returns the registered dataset class
"""
def inner_wrapper(wrapped_class: Dataset) -> Any:
def inner_wrapper(wrapped_class: Type[Dataset]) -> Any:
if name in cls.registry:
logger.warning('Dataset %s already exists. Will replace it', name)
cls.registry[name] = wrapped_class
@ -414,14 +414,18 @@ class Data:
"""
Get test set samples
:return: test samples
:return: test samples, or None if no test data provided
"""
if self.test is None:
return None
return self.test.get_samples()
def get_test_labels(self) -> Collection[Any]:
"""
Get test set labels
:return: test labels
:return: test labels, or None if no test data provided
"""
if self.test is None:
return None
return self.test.get_labels()