mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-27 13:56:22 +02:00
New model wrappers (#32)
* keras wrapper + blackbox classifier wrapper (fix #7) * fix error in NCP calculation * Update notebooks * Fix #25 (incorrect attack_feature indexes for social feature in notebook) * Consistent naming of internal parameters
This commit is contained in:
parent
fd6be8e778
commit
fe676fa426
15 changed files with 1407 additions and 656 deletions
|
|
@ -5,7 +5,7 @@ Implementation of utility classes for dataset handling
|
|||
"""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Callable, Collection, Any, Union, List, Optional
|
||||
from typing import Callable, Collection, Any, Union, List, Optional, Type
|
||||
|
||||
import tarfile
|
||||
import os
|
||||
|
|
@ -19,9 +19,9 @@ from torch import Tensor
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, List, Tensor]
|
||||
INPUT_DATA_ARRAY_TYPE = Union[np.ndarray, pd.DataFrame, pd.Series, List, Tensor]
|
||||
OUTPUT_DATA_ARRAY_TYPE = np.ndarray
|
||||
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame]
|
||||
DATA_PANDAS_NUMPY_TYPE = Union[np.ndarray, pd.DataFrame, pd.Series]
|
||||
|
||||
|
||||
class Dataset(metaclass=ABCMeta):
|
||||
|
|
@ -323,7 +323,7 @@ class DatasetFactory:
|
|||
:return: a Callable that returns the registered dataset class
|
||||
"""
|
||||
|
||||
def inner_wrapper(wrapped_class: Dataset) -> Any:
|
||||
def inner_wrapper(wrapped_class: Type[Dataset]) -> Any:
|
||||
if name in cls.registry:
|
||||
logger.warning('Dataset %s already exists. Will replace it', name)
|
||||
cls.registry[name] = wrapped_class
|
||||
|
|
@ -414,14 +414,18 @@ class Data:
|
|||
"""
|
||||
Get test set samples
|
||||
|
||||
:return: test samples
|
||||
:return: test samples, or None if no test data provided
|
||||
"""
|
||||
if self.test is None:
|
||||
return None
|
||||
return self.test.get_samples()
|
||||
|
||||
def get_test_labels(self) -> Collection[Any]:
|
||||
"""
|
||||
Get test set labels
|
||||
|
||||
:return: test labels
|
||||
:return: test labels, or None if no test data provided
|
||||
"""
|
||||
if self.test is None:
|
||||
return None
|
||||
return self.test.get_labels()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue