mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-06 02:24:17 +02:00
[statespace] tests mote thorough and numerically stable
This commit is contained in:
parent
bfd0ee0db2
commit
754106e471
1 changed files with 12 additions and 13 deletions
|
|
@ -28,23 +28,22 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
kalman_filter_type=kalman_filter_type,
|
||||
use_cython=use_cython)
|
||||
|
||||
m1.likelihood[:] = Y.var()/10.
|
||||
|
||||
if check_gradients:
|
||||
self.assertTrue(m1.checkgrad())
|
||||
|
||||
if 1:#optimize:
|
||||
m1.optimize(optimizer='lbfgsb',max_iters=2)
|
||||
m1.optimize(optimizer='lbfgsb', max_iters=1)
|
||||
|
||||
if compare_with_GP and (predict_X is None):
|
||||
predict_X = X
|
||||
|
||||
self.assertTrue(compare_with_GP)
|
||||
if compare_with_GP:
|
||||
np.random.seed(254856)
|
||||
m2 = GPy.models.GPRegression(X,Y, gp_kernel)
|
||||
#m2.randomize()
|
||||
m2.optimize(max_iters=optimize_max_iters)
|
||||
|
||||
m1[:] = m2[:]
|
||||
m2[:] = m1[:]
|
||||
|
||||
if (predict_X is not None):
|
||||
x_pred_reg_1 = m1.predict(predict_X)
|
||||
|
|
@ -53,12 +52,12 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
x_pred_reg_2 = m2.predict(predict_X)
|
||||
x_quant_reg_2 = m2.predict_quantiles(predict_X)
|
||||
|
||||
np.testing.assert_array_almost_equal(x_pred_reg_1[0], x_pred_reg_2[0], 3)
|
||||
np.testing.assert_array_almost_equal(x_pred_reg_1[1], x_pred_reg_2[1], 3)
|
||||
np.testing.assert_array_almost_equal(x_quant_reg_1[0], x_quant_reg_2[0], 3)
|
||||
np.testing.assert_array_almost_equal(x_quant_reg_1[1], x_quant_reg_2[1], 3)
|
||||
np.testing.assert_almost_equal(m1.log_likelihood(), m2.log_likelihood(), 3)
|
||||
np.testing.assert_array_almost_equal(m1.gradient, m2.gradient, 2)
|
||||
np.testing.assert_array_almost_equal(x_pred_reg_1[0], x_pred_reg_2[0], mean_compare_decimal)
|
||||
np.testing.assert_array_almost_equal(x_pred_reg_1[1], x_pred_reg_2[1], var_compare_decimal)
|
||||
np.testing.assert_array_almost_equal(x_quant_reg_1[0], x_quant_reg_2[0], mean_compare_decimal)
|
||||
np.testing.assert_array_almost_equal(x_quant_reg_1[1], x_quant_reg_2[1], mean_compare_decimal)
|
||||
np.testing.assert_array_almost_equal(m1.gradient, m2.gradient, var_compare_decimal)
|
||||
np.testing.assert_almost_equal(m1.log_likelihood(), m2.log_likelihood(), var_compare_decimal)
|
||||
|
||||
|
||||
def test_Matern32_kernel(self,):
|
||||
|
|
@ -103,7 +102,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
predict_X=X,
|
||||
gp_kernel=gp_kernel,
|
||||
optimize_max_iters=1000,
|
||||
mean_compare_decimal=2, var_compare_decimal=2)
|
||||
mean_compare_decimal=2, var_compare_decimal=1)
|
||||
|
||||
def test_periodic_kernel(self,):
|
||||
np.random.seed(322) # seed the random number generator
|
||||
|
|
@ -172,7 +171,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
self.run_for_model(X, Y, ss_kernel, check_gradients=True,
|
||||
predict_X=X,
|
||||
gp_kernel=gp_kernel,
|
||||
mean_compare_decimal=5, var_compare_decimal=5)
|
||||
mean_compare_decimal=4, var_compare_decimal=4)
|
||||
|
||||
def test_exponential_kernel(self,):
|
||||
np.random.seed(12345) # seed the random number generator
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue