mirror of
https://github.com/IBM/ai-privacy-toolkit.git
synced 2026-04-27 13:56:22 +02:00
41 lines
1.1 KiB
Python
41 lines
1.1 KiB
Python
|
|
import numpy as np
|
||
|
|
|
||
|
|
from apt.utils.datasets import Data, DatasetWithPredictions
|
||
|
|
from apt.utils import dataset_utils
|
||
|
|
|
||
|
|
|
||
|
|
def test_dataset_predictions():
    """Check that a predictions-only dataset hands back its predictions.

    A ``DatasetWithPredictions`` is built from predictions alone, wrapped in
    ``Data`` as the training set, and the array returned by
    ``get_predictions()`` is compared with the array that was passed in.
    """
    # No features or labels are attached in this test, so no dataset needs
    # to be loaded; 105 rows of one fixed three-value vector are enough.
    pred = np.tile([0.23, 0.56, 0.21], (105, 1))

    dataset = DatasetWithPredictions(pred)
    data = Data(train=dataset)

    new_pred = data.get_train_set().get_predictions()

    # array_equal compares shape as well as values; np.equal(...).all()
    # would broadcast and accept e.g. a single (3,) row as a match.
    assert np.array_equal(pred, new_pred)
|
||
|
|
|
||
|
|
|
||
|
|
def test_dataset_predictions_x():
    """Check predictions round-trip when features are attached.

    A ``DatasetWithPredictions`` is built from predictions plus the iris
    training features, wrapped in ``Data`` as the training set, and the array
    returned by ``get_predictions()`` is compared with the array passed in.
    """
    (x_train, _), (_, _) = dataset_utils.get_iris_dataset_np()

    # One fixed three-value vector per training sample. The row count is
    # taken from x_train so the test does not depend on the split size.
    pred = np.tile([0.23, 0.56, 0.21], (x_train.shape[0], 1))

    dataset = DatasetWithPredictions(pred, x=x_train)
    data = Data(train=dataset)

    new_pred = data.get_train_set().get_predictions()

    # array_equal compares shape as well as values; np.equal(...).all()
    # would broadcast and accept e.g. a single (3,) row as a match.
    assert np.array_equal(pred, new_pred)
|
||
|
|
|
||
|
|
|
||
|
|
def test_dataset_predictions_y():
    """Check predictions round-trip when features and labels are attached.

    A ``DatasetWithPredictions`` is built from predictions plus the iris
    training features and labels, wrapped in ``Data`` as the training set,
    and the array returned by ``get_predictions()`` is compared with the
    array passed in.
    """
    (x_train, y_train), (_, _) = dataset_utils.get_iris_dataset_np()

    # One fixed three-value vector per training sample. The row count is
    # taken from x_train so the test does not depend on the split size.
    pred = np.tile([0.23, 0.56, 0.21], (x_train.shape[0], 1))

    dataset = DatasetWithPredictions(pred, x=x_train, y=y_train)
    data = Data(train=dataset)

    new_pred = data.get_train_set().get_predictions()

    # array_equal compares shape as well as values; np.equal(...).all()
    # would broadcast and accept e.g. a single (3,) row as a match.
    assert np.array_equal(pred, new_pred)
|