added tests for closed inverse in identity and log

This commit is contained in:
beckdaniel 2016-04-07 18:41:48 +01:00
parent ff700bbebf
commit 2d9bdfdbfb

View file

@ -301,11 +301,20 @@ class MiscTests(unittest.TestCase):
warp_k = GPy.kern.RBF(1) warp_k = GPy.kern.RBF(1)
warp_f = GPy.util.warping_functions.IdentityFunction(closed_inverse=False) warp_f = GPy.util.warping_functions.IdentityFunction(closed_inverse=False)
warp_m = GPy.models.WarpedGP(self.X, self.Y, kernel=warp_k, warping_function=warp_f) warp_m = GPy.models.WarpedGP(self.X, self.Y, kernel=warp_k,
warping_function=warp_f)
warp_m.optimize() warp_m.optimize()
warp_preds = warp_m.predict(self.X) warp_preds = warp_m.predict(self.X)
warp_k_exact = GPy.kern.RBF(1)
warp_f_exact = GPy.util.warping_functions.IdentityFunction()
warp_m_exact = GPy.models.WarpedGP(self.X, self.Y, kernel=warp_k_exact,
warping_function=warp_f_exact)
warp_m_exact.optimize()
warp_preds_exact = warp_m_exact.predict(self.X)
np.testing.assert_almost_equal(preds, warp_preds, decimal=4) np.testing.assert_almost_equal(preds, warp_preds, decimal=4)
np.testing.assert_almost_equal(preds, warp_preds_exact, decimal=4)
def test_warped_gp_log(self): def test_warped_gp_log(self):
""" """
@ -322,11 +331,20 @@ class MiscTests(unittest.TestCase):
warp_k = GPy.kern.RBF(1) warp_k = GPy.kern.RBF(1)
warp_f = GPy.util.warping_functions.LogFunction(closed_inverse=False) warp_f = GPy.util.warping_functions.LogFunction(closed_inverse=False)
warp_m = GPy.models.WarpedGP(self.X, Y, kernel=warp_k, warping_function=warp_f) warp_m = GPy.models.WarpedGP(self.X, Y, kernel=warp_k,
warping_function=warp_f)
warp_m.optimize() warp_m.optimize()
warp_preds = warp_m.predict(self.X, median=True)[0] warp_preds = warp_m.predict(self.X, median=True)[0]
warp_k_exact = GPy.kern.RBF(1)
warp_f_exact = GPy.util.warping_functions.LogFunction()
warp_m_exact = GPy.models.WarpedGP(self.X, Y, kernel=warp_k_exact,
warping_function=warp_f_exact)
warp_m_exact.optimize(messages=True)
warp_preds_exact = warp_m_exact.predict(self.X, median=True)[0]
np.testing.assert_almost_equal(np.exp(preds), warp_preds, decimal=4) np.testing.assert_almost_equal(np.exp(preds), warp_preds, decimal=4)
np.testing.assert_almost_equal(np.exp(preds), warp_preds_exact, decimal=4)
def test_warped_gp_cubic_sine(self, max_iters=100): def test_warped_gp_cubic_sine(self, max_iters=100):
""" """