From 6fcb9e48fdde4d969bed2ada06565aade5696956 Mon Sep 17 00:00:00 2001 From: Martin Bubel Date: Fri, 6 Oct 2023 08:06:52 +0200 Subject: [PATCH] migrate gpy_kernsl-state_space_tests --- GPy/testing/gpy_kernels_state_space_tests.py | 34 ++++---------------- 1 file changed, 6 insertions(+), 28 deletions(-) diff --git a/GPy/testing/gpy_kernels_state_space_tests.py b/GPy/testing/gpy_kernels_state_space_tests.py index f6013f79..f5a3f89e 100644 --- a/GPy/testing/gpy_kernels_state_space_tests.py +++ b/GPy/testing/gpy_kernels_state_space_tests.py @@ -21,10 +21,7 @@ from nose import SkipTest # generate_linear_data, generate_brownian_data, generate_linear_plus_sin -class StateSpaceKernelsTests(np.testing.TestCase): - def setUp(self): - pass - +class TestStateSpaceKernels: def run_for_model( self, X, @@ -52,7 +49,7 @@ class StateSpaceKernelsTests(np.testing.TestCase): m1.likelihood[:] = Y.var() / 100.0 if check_gradients: - self.assertTrue(m1.checkgrad()) + assert m1.checkgrad() if 1: # optimize: m1.optimize(optimizer="lbfgsb", max_iters=1) @@ -60,7 +57,7 @@ class StateSpaceKernelsTests(np.testing.TestCase): if compare_with_GP and (predict_X is None): predict_X = X - self.assertTrue(compare_with_GP) + assert compare_with_GP if compare_with_GP: m2 = GPy.models.GPRegression(X, Y, gp_kernel) @@ -92,7 +89,7 @@ class StateSpaceKernelsTests(np.testing.TestCase): m1.log_likelihood(), m2.log_likelihood(), var_compare_decimal ) - def test_Matern32_kernel( + def test_matern32_kernel( self, ): np.random.seed(234) # seed the random number generator @@ -134,7 +131,7 @@ class StateSpaceKernelsTests(np.testing.TestCase): var_compare_decimal=5, ) - def test_Matern52_kernel( + def test_matern52_kernel( self, ): np.random.seed(234) # seed the random number generator @@ -177,7 +174,7 @@ class StateSpaceKernelsTests(np.testing.TestCase): var_compare_decimal=5, ) - def test_RBF_kernel( + def test_rbf_kernel( self, ): # import pdb;pdb.set_trace() @@ -1026,22 +1023,3 @@ class StateSpaceKernelsTests(np.testing.TestCase): mean_compare_decimal=2, var_compare_decimal=2, ) - - -if __name__ == "__main__": - print("Running state-space inference tests...") - unittest.main() - - # tt = StateSpaceKernelsTests('test_RBF_kernel') - # import pdb; pdb.set_trace() - # tt.test_Matern32_kernel() - # tt.test_Matern52_kernel() - # tt.test_RBF_kernel() - # tt.test_periodic_kernel() - # tt.test_quasi_periodic_kernel() - # tt.test_linear_kernel() - # tt.test_brownian_kernel() - # tt.test_exponential_kernel() - # tt.test_kernel_addition() - # tt.test_kernel_multiplication() - # tt.test_forecast()