migrate gpy_kernsl-state_space_tests

This commit is contained in:
Martin Bubel 2023-10-06 08:06:52 +02:00
parent d88ff47e37
commit 6fcb9e48fd

View file

@ -21,10 +21,7 @@ from nose import SkipTest
# generate_linear_data, generate_brownian_data, generate_linear_plus_sin # generate_linear_data, generate_brownian_data, generate_linear_plus_sin
class StateSpaceKernelsTests(np.testing.TestCase): class TestStateSpaceKernels:
def setUp(self):
pass
def run_for_model( def run_for_model(
self, self,
X, X,
@ -52,7 +49,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
m1.likelihood[:] = Y.var() / 100.0 m1.likelihood[:] = Y.var() / 100.0
if check_gradients: if check_gradients:
self.assertTrue(m1.checkgrad()) assert m1.checkgrad()
if 1: # optimize: if 1: # optimize:
m1.optimize(optimizer="lbfgsb", max_iters=1) m1.optimize(optimizer="lbfgsb", max_iters=1)
@ -60,7 +57,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
if compare_with_GP and (predict_X is None): if compare_with_GP and (predict_X is None):
predict_X = X predict_X = X
self.assertTrue(compare_with_GP) assert compare_with_GP
if compare_with_GP: if compare_with_GP:
m2 = GPy.models.GPRegression(X, Y, gp_kernel) m2 = GPy.models.GPRegression(X, Y, gp_kernel)
@ -92,7 +89,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
m1.log_likelihood(), m2.log_likelihood(), var_compare_decimal m1.log_likelihood(), m2.log_likelihood(), var_compare_decimal
) )
def test_Matern32_kernel( def test_matern32_kernel(
self, self,
): ):
np.random.seed(234) # seed the random number generator np.random.seed(234) # seed the random number generator
@ -134,7 +131,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
var_compare_decimal=5, var_compare_decimal=5,
) )
def test_Matern52_kernel( def test_matern52_kernel(
self, self,
): ):
np.random.seed(234) # seed the random number generator np.random.seed(234) # seed the random number generator
@ -177,7 +174,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
var_compare_decimal=5, var_compare_decimal=5,
) )
def test_RBF_kernel( def test_rbf_kernel(
self, self,
): ):
# import pdb;pdb.set_trace() # import pdb;pdb.set_trace()
@ -1026,22 +1023,3 @@ class StateSpaceKernelsTests(np.testing.TestCase):
mean_compare_decimal=2, mean_compare_decimal=2,
var_compare_decimal=2, var_compare_decimal=2,
) )
if __name__ == "__main__":
print("Running state-space inference tests...")
unittest.main()
# tt = StateSpaceKernelsTests('test_RBF_kernel')
# import pdb; pdb.set_trace()
# tt.test_Matern32_kernel()
# tt.test_Matern52_kernel()
# tt.test_RBF_kernel()
# tt.test_periodic_kernel()
# tt.test_quasi_periodic_kernel()
# tt.test_linear_kernel()
# tt.test_brownian_kernel()
# tt.test_exponential_kernel()
# tt.test_kernel_addition()
# tt.test_kernel_multiplication()
# tt.test_forecast()