mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-26 05:16:22 +02:00
Increase version to 0.2.0 (#74)
* Remove tensorflow dependency if not using keras model * Remove xgboost dependency if not using xgboost model * Documentation updates Signed-off-by: abigailt <abigailt@il.ibm.com>
This commit is contained in:
parent
782edabd58
commit
8a9ef80146
25 changed files with 306 additions and 152 deletions
|
|
@ -9,21 +9,20 @@ from apt.utils.datasets import ArrayDataset
|
|||
|
||||
class AttackStrategyUtils(abc.ABC):
|
||||
"""
|
||||
Abstract base class for common utilities of various privacy attack strategies.
|
||||
Abstract base class for common utilities of various privacy attack strategies.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class KNNAttackStrategyUtils(AttackStrategyUtils):
|
||||
"""
|
||||
Common utilities for attack strategy based on KNN distances.
|
||||
Common utilities for attack strategy based on KNN distances.
|
||||
|
||||
:param use_batches: Use batches with a progress meter or not when finding KNNs for query set.
|
||||
:param batch_size: if use_batches=True, the size of batch_size should be > 0.
|
||||
"""
|
||||
|
||||
def __init__(self, use_batches: bool = False, batch_size: int = 10) -> None:
|
||||
"""
|
||||
:param use_batches: Use batches with a progress meter or not when finding KNNs for query set
|
||||
:param batch_size: if use_batches=True, the size of batch_size should be > 0
|
||||
"""
|
||||
self.use_batches = use_batches
|
||||
self.batch_size = batch_size
|
||||
if use_batches:
|
||||
|
|
@ -31,11 +30,18 @@ class KNNAttackStrategyUtils(AttackStrategyUtils):
|
|||
raise ValueError(f"When using batching batch_size should be > 0, and not {batch_size}")
|
||||
|
||||
def fit(self, knn_learner: NearestNeighbors, dataset: ArrayDataset):
|
||||
"""
|
||||
Fit the KNN learner.
|
||||
|
||||
:param knn_learner: The KNN model to fit.
|
||||
:param dataset: The training set to fit the model on.
|
||||
"""
|
||||
knn_learner.fit(dataset.get_samples())
|
||||
|
||||
def find_knn(self, knn_learner: NearestNeighbors, query_samples: ArrayDataset, distance_processor=None):
|
||||
"""
|
||||
Nearest neighbor search function.
|
||||
|
||||
:param query_samples: query samples, to which nearest neighbors are to be found
|
||||
:param knn_learner: unsupervised learner for implementing neighbor searches, after it was fitted
|
||||
:param distance_processor: function for processing the distance into another more relevant metric per sample.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue