diff --git a/GPy/testing/serialization_tests.py b/GPy/testing/serialization_tests.py index f2af89d3..f08148f8 100644 --- a/GPy/testing/serialization_tests.py +++ b/GPy/testing/serialization_tests.py @@ -10,7 +10,7 @@ import os fixed_seed = 11 -class Test(unittest.TestCase): +class TestSerialization: def test_serialize_deserialize_kernels(self): k1 = GPy.kern.RBF(2, variance=1.0, lengthscale=[1.0, 1.0], ARD=True) k2 = GPy.kern.RatQuad( @@ -371,23 +371,19 @@ class Test(unittest.TestCase): m1_r = GPy.models.GPClassification.load_model( "temp_test_gp_classifier_with_data.json.zip" ) - self.assertTrue( - type(m) == type(m1_r), - "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m1_r)), - ) + assert type(m) == type( + m1_r + ), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m1_r)) m2_r = GPy.models.GPClassification.load_model( "temp_test_gp_classifier_without_data.json.zip", (X, Y) ) - self.assertTrue( - type(m) == type(m2_r), - "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m2_r)), - ) + assert type(m) == type(m2_r), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m2_r)), os.remove("temp_test_gp_classifier_with_data.json.zip") os.remove("temp_test_gp_classifier_without_data.json.zip") var = m.predict(X)[0] var1_r = m1_r.predict(X)[0] - var2_r = m2_r.predict(X)[0] + _var2_r = m2_r.predict(X)[0] np.testing.assert_array_equal( np.array(var).flatten(), np.array(var1_r).flatten() ) @@ -419,17 +415,11 @@ class Test(unittest.TestCase): m1_r = GPy.models.SparseGPClassification.load_model( "temp_test_sparse_gp_classifier_with_data.json.zip" ) - self.assertTrue( - type(m) == type(m1_r), - "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m1_r)), - ) + assert type(m) == type(m1_r), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m1_r)) m2_r = GPy.models.SparseGPClassification.load_model( "temp_test_sparse_gp_classifier_without_data.json.zip", (X, Y) ) - self.assertTrue( - type(m) == type(m2_r), - "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m2_r)), - ) + assert type(m) == type(m2_r), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m2_r)), os.remove("temp_test_sparse_gp_classifier_with_data.json.zip") os.remove("temp_test_sparse_gp_classifier_without_data.json.zip") @@ -442,8 +432,3 @@ class Test(unittest.TestCase): np.testing.assert_array_equal( np.array(var).flatten(), np.array(var1_r).flatten() ) - - -if __name__ == "__main__": - # import sys;sys.argv = ['', 'Test.test_parameter_index_operations'] - unittest.main()