mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-27 14:25:16 +02:00
migrate gpy_kernsl-state_space_tests
This commit is contained in:
parent
d88ff47e37
commit
6fcb9e48fd
1 changed files with 6 additions and 28 deletions
|
|
@ -21,10 +21,7 @@ from nose import SkipTest
|
|||
# generate_linear_data, generate_brownian_data, generate_linear_plus_sin
|
||||
|
||||
|
||||
class StateSpaceKernelsTests(np.testing.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
class TestStateSpaceKernels:
|
||||
def run_for_model(
|
||||
self,
|
||||
X,
|
||||
|
|
@ -52,7 +49,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
m1.likelihood[:] = Y.var() / 100.0
|
||||
|
||||
if check_gradients:
|
||||
self.assertTrue(m1.checkgrad())
|
||||
assert m1.checkgrad()
|
||||
|
||||
if 1: # optimize:
|
||||
m1.optimize(optimizer="lbfgsb", max_iters=1)
|
||||
|
|
@ -60,7 +57,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
if compare_with_GP and (predict_X is None):
|
||||
predict_X = X
|
||||
|
||||
self.assertTrue(compare_with_GP)
|
||||
assert compare_with_GP
|
||||
if compare_with_GP:
|
||||
m2 = GPy.models.GPRegression(X, Y, gp_kernel)
|
||||
|
||||
|
|
@ -92,7 +89,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
m1.log_likelihood(), m2.log_likelihood(), var_compare_decimal
|
||||
)
|
||||
|
||||
def test_Matern32_kernel(
|
||||
def test_matern32_kernel(
|
||||
self,
|
||||
):
|
||||
np.random.seed(234) # seed the random number generator
|
||||
|
|
@ -134,7 +131,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
var_compare_decimal=5,
|
||||
)
|
||||
|
||||
def test_Matern52_kernel(
|
||||
def test_matern52_kernel(
|
||||
self,
|
||||
):
|
||||
np.random.seed(234) # seed the random number generator
|
||||
|
|
@ -177,7 +174,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
var_compare_decimal=5,
|
||||
)
|
||||
|
||||
def test_RBF_kernel(
|
||||
def test_rbf_kernel(
|
||||
self,
|
||||
):
|
||||
# import pdb;pdb.set_trace()
|
||||
|
|
@ -1026,22 +1023,3 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
mean_compare_decimal=2,
|
||||
var_compare_decimal=2,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Running state-space inference tests...")
|
||||
unittest.main()
|
||||
|
||||
# tt = StateSpaceKernelsTests('test_RBF_kernel')
|
||||
# import pdb; pdb.set_trace()
|
||||
# tt.test_Matern32_kernel()
|
||||
# tt.test_Matern52_kernel()
|
||||
# tt.test_RBF_kernel()
|
||||
# tt.test_periodic_kernel()
|
||||
# tt.test_quasi_periodic_kernel()
|
||||
# tt.test_linear_kernel()
|
||||
# tt.test_brownian_kernel()
|
||||
# tt.test_exponential_kernel()
|
||||
# tt.test_kernel_addition()
|
||||
# tt.test_kernel_multiplication()
|
||||
# tt.test_forecast()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue