migrate gpy_kernsl-state_space_tests

This commit is contained in:
Martin Bubel 2023-10-06 08:06:52 +02:00
parent d88ff47e37
commit 6fcb9e48fd

View file

@ -21,10 +21,7 @@ from nose import SkipTest
# generate_linear_data, generate_brownian_data, generate_linear_plus_sin
class StateSpaceKernelsTests(np.testing.TestCase):
def setUp(self):
pass
class TestStateSpaceKernels:
def run_for_model(
self,
X,
@ -52,7 +49,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
m1.likelihood[:] = Y.var() / 100.0
if check_gradients:
self.assertTrue(m1.checkgrad())
assert m1.checkgrad()
if 1: # optimize:
m1.optimize(optimizer="lbfgsb", max_iters=1)
@ -60,7 +57,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
if compare_with_GP and (predict_X is None):
predict_X = X
self.assertTrue(compare_with_GP)
assert compare_with_GP
if compare_with_GP:
m2 = GPy.models.GPRegression(X, Y, gp_kernel)
@ -92,7 +89,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
m1.log_likelihood(), m2.log_likelihood(), var_compare_decimal
)
def test_Matern32_kernel(
def test_matern32_kernel(
self,
):
np.random.seed(234) # seed the random number generator
@ -134,7 +131,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
var_compare_decimal=5,
)
def test_Matern52_kernel(
def test_matern52_kernel(
self,
):
np.random.seed(234) # seed the random number generator
@ -177,7 +174,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
var_compare_decimal=5,
)
def test_RBF_kernel(
def test_rbf_kernel(
self,
):
# import pdb;pdb.set_trace()
@ -1026,22 +1023,3 @@ class StateSpaceKernelsTests(np.testing.TestCase):
mean_compare_decimal=2,
var_compare_decimal=2,
)
if __name__ == "__main__":
print("Running state-space inference tests...")
unittest.main()
# tt = StateSpaceKernelsTests('test_RBF_kernel')
# import pdb; pdb.set_trace()
# tt.test_Matern32_kernel()
# tt.test_Matern52_kernel()
# tt.test_RBF_kernel()
# tt.test_periodic_kernel()
# tt.test_quasi_periodic_kernel()
# tt.test_linear_kernel()
# tt.test_brownian_kernel()
# tt.test_exponential_kernel()
# tt.test_kernel_addition()
# tt.test_kernel_multiplication()
# tt.test_forecast()