migrate serialzation_tests to pytest

This commit is contained in:
Martin Bubel 2023-10-10 19:56:11 +02:00
parent 96d8ac0975
commit c69f68feba

View file

@ -10,7 +10,7 @@ import os
fixed_seed = 11
class Test(unittest.TestCase):
class TestSerialization:
def test_serialize_deserialize_kernels(self):
k1 = GPy.kern.RBF(2, variance=1.0, lengthscale=[1.0, 1.0], ARD=True)
k2 = GPy.kern.RatQuad(
@ -371,23 +371,19 @@ class Test(unittest.TestCase):
m1_r = GPy.models.GPClassification.load_model(
"temp_test_gp_classifier_with_data.json.zip"
)
self.assertTrue(
type(m) == type(m1_r),
"Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m1_r)),
)
assert type(m) == type(
m1_r
), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m1_r))
m2_r = GPy.models.GPClassification.load_model(
"temp_test_gp_classifier_without_data.json.zip", (X, Y)
)
self.assertTrue(
type(m) == type(m2_r),
"Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m2_r)),
)
assert type(m) == type(m2_r), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m2_r)),
os.remove("temp_test_gp_classifier_with_data.json.zip")
os.remove("temp_test_gp_classifier_without_data.json.zip")
var = m.predict(X)[0]
var1_r = m1_r.predict(X)[0]
var2_r = m2_r.predict(X)[0]
_var2_r = m2_r.predict(X)[0]
np.testing.assert_array_equal(
np.array(var).flatten(), np.array(var1_r).flatten()
)
@ -419,17 +415,11 @@ class Test(unittest.TestCase):
m1_r = GPy.models.SparseGPClassification.load_model(
"temp_test_sparse_gp_classifier_with_data.json.zip"
)
self.assertTrue(
type(m) == type(m1_r),
"Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m1_r)),
)
assert type(m) == type(m1_r), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m1_r))
m2_r = GPy.models.SparseGPClassification.load_model(
"temp_test_sparse_gp_classifier_without_data.json.zip", (X, Y)
)
self.assertTrue(
type(m) == type(m2_r),
"Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m2_r)),
)
assert type(m) == type(m2_r), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m2_r)),
os.remove("temp_test_sparse_gp_classifier_with_data.json.zip")
os.remove("temp_test_sparse_gp_classifier_without_data.json.zip")
@ -442,8 +432,3 @@ class Test(unittest.TestCase):
np.testing.assert_array_equal(
np.array(var).flatten(), np.array(var1_r).flatten()
)
if __name__ == "__main__":
# import sys;sys.argv = ['', 'Test.test_parameter_index_operations']
unittest.main()