ai-privacy-toolkit/tests/test_datasets.py
abigailgold d52fcd0041
Formatting (#68)
Fix most flake/lint errors and ignore a few others

Signed-off-by: abigailt <abigailt@il.ibm.com>
2022-12-25 15:13:57 +02:00

40 lines
1.1 KiB
Python

import numpy as np
from apt.utils.datasets import Data, DatasetWithPredictions
from apt.utils import dataset_utils
def test_dataset_predictions():
    """Predictions given without x/y come back unchanged from the train set."""
    (x_train, _), (_, _) = dataset_utils.get_iris_dataset_np()
    # One constant 3-value prediction row per training sample. Deriving the
    # row count from x_train avoids hard-coding the split size (the previous
    # literal 105 presumably matched the default iris train split).
    pred = np.tile([0.23, 0.56, 0.21], (len(x_train), 1))
    dataset = DatasetWithPredictions(pred)
    data = Data(train=dataset)
    new_pred = data.get_train_set().get_predictions()
    # array_equal also compares shapes; np.equal(...).all() broadcasts and
    # could pass for a returned array of a different (broadcastable) shape.
    assert np.array_equal(pred, new_pred)
def test_dataset_predictions_x():
    """Predictions given together with x come back unchanged from the train set."""
    (x_train, _), (_, _) = dataset_utils.get_iris_dataset_np()
    # One constant 3-value prediction row per sample in x_train, so the
    # prediction count always matches the features passed alongside them.
    pred = np.tile([0.23, 0.56, 0.21], (len(x_train), 1))
    dataset = DatasetWithPredictions(pred, x=x_train)
    data = Data(train=dataset)
    new_pred = data.get_train_set().get_predictions()
    # array_equal also compares shapes; np.equal(...).all() broadcasts and
    # could pass for a returned array of a different (broadcastable) shape.
    assert np.array_equal(pred, new_pred)
def test_dataset_predictions_y():
    """Predictions given together with x and y come back unchanged from the train set."""
    (x_train, y_train), (_, _) = dataset_utils.get_iris_dataset_np()
    # One constant 3-value prediction row per sample in x_train, so the
    # prediction count always matches the features/labels passed alongside.
    pred = np.tile([0.23, 0.56, 0.21], (len(x_train), 1))
    dataset = DatasetWithPredictions(pred, x=x_train, y=y_train)
    data = Data(train=dataset)
    new_pred = data.get_train_set().get_predictions()
    # array_equal also compares shapes; np.equal(...).all() broadcasts and
    # could pass for a returned array of a different (broadcastable) shape.
    assert np.array_equal(pred, new_pred)