diff --git a/GPy/testing/model_tests.py b/GPy/testing/model_tests.py index c0316cf1..144c6adf 100644 --- a/GPy/testing/model_tests.py +++ b/GPy/testing/model_tests.py @@ -301,11 +301,20 @@ class MiscTests(unittest.TestCase): warp_k = GPy.kern.RBF(1) warp_f = GPy.util.warping_functions.IdentityFunction(closed_inverse=False) - warp_m = GPy.models.WarpedGP(self.X, self.Y, kernel=warp_k, warping_function=warp_f) + warp_m = GPy.models.WarpedGP(self.X, self.Y, kernel=warp_k, + warping_function=warp_f) warp_m.optimize() warp_preds = warp_m.predict(self.X) + + warp_k_exact = GPy.kern.RBF(1) + warp_f_exact = GPy.util.warping_functions.IdentityFunction() + warp_m_exact = GPy.models.WarpedGP(self.X, self.Y, kernel=warp_k_exact, + warping_function=warp_f_exact) + warp_m_exact.optimize() + warp_preds_exact = warp_m_exact.predict(self.X) np.testing.assert_almost_equal(preds, warp_preds, decimal=4) + np.testing.assert_almost_equal(preds, warp_preds_exact, decimal=4) def test_warped_gp_log(self): """ @@ -322,11 +331,20 @@ class MiscTests(unittest.TestCase): warp_k = GPy.kern.RBF(1) warp_f = GPy.util.warping_functions.LogFunction(closed_inverse=False) - warp_m = GPy.models.WarpedGP(self.X, Y, kernel=warp_k, warping_function=warp_f) + warp_m = GPy.models.WarpedGP(self.X, Y, kernel=warp_k, + warping_function=warp_f) warp_m.optimize() warp_preds = warp_m.predict(self.X, median=True)[0] + + warp_k_exact = GPy.kern.RBF(1) + warp_f_exact = GPy.util.warping_functions.LogFunction() + warp_m_exact = GPy.models.WarpedGP(self.X, Y, kernel=warp_k_exact, + warping_function=warp_f_exact) + warp_m_exact.optimize(messages=True) + warp_preds_exact = warp_m_exact.predict(self.X, median=True)[0] np.testing.assert_almost_equal(np.exp(preds), warp_preds, decimal=4) + np.testing.assert_almost_equal(np.exp(preds), warp_preds_exact, decimal=4) def test_warped_gp_cubic_sine(self, max_iters=100): """