fixing a typo-bug in the last commit for ep test case

This commit is contained in:
Akash Kumar Dhaka 2017-06-26 17:33:57 +03:00
parent 38597b1ede
commit 19598cf807

View file

@ -31,7 +31,7 @@ class TestObservationModels(unittest.TestCase):
self.Y_noisy[75:80] += 1.3
self.init_var = 0.3
self.deg_free = 5.
self.deg_free = 4.
censored = np.zeros_like(self.Y)
random_inds = np.random.choice(self.N, int(self.N / 2), replace=True)
censored[random_inds] = 1
@ -83,7 +83,7 @@ class TestObservationModels(unittest.TestCase):
# taking laplace predictions as the ground truth
probs_mean_lap, probs_var_lap = m1.predict(self.X)
probs_mean_ep_alt, probs_var_ep_alt = m2.predict(self.X)
probs_mean_ep_nested, probs_var_ep_nested = m2.predict(self.X)
probs_mean_ep_nested, probs_var_ep_nested = m3.predict(self.X)
# for simple single dimension data , marginal likelihood for laplace and EP approximations should not be so far apart.
self.assertAlmostEqual(m1.log_likelihood(), m2.log_likelihood(),delta=1)
@ -125,6 +125,7 @@ class TestObservationModels(unittest.TestCase):
optimizer='bfgs'
m1.optimize(optimizer=optimizer,max_iters=400)
m2.optimize(optimizer=optimizer, max_iters=500)
# m3.optimize(optimizer=optimizer, max_iters=500)
self.assertAlmostEqual(m1.log_likelihood(), m2.log_likelihood(),delta=10)
# self.assertAlmostEqual(m1.log_likelihood(), m3.log_likelihood(), 3)
@ -132,12 +133,12 @@ class TestObservationModels(unittest.TestCase):
preds_mean_lap, preds_var_lap = m1.predict(self.X)
preds_mean_alt, preds_var_alt = m2.predict(self.X)
# preds_mean_nested, preds_var_nested = m3.predict(self.X)
rmse_lap = self.rmse(preds_mean_lap, self.Y_noisy)
rmse_alt = self.rmse(preds_mean_alt, self.Y_noisy)
rmse_lap = self.rmse(preds_mean_lap, self.Y)
rmse_alt = self.rmse(preds_mean_alt, self.Y)
# rmse_nested = self.rmse(preds_mean_nested, self.Y_noisy)
if rmse_alt > rmse_alt:
self.assertAlmostEqual(rmse_lap, rmse_alt, delta=1.)
if rmse_alt > rmse_lap:
self.assertAlmostEqual(rmse_lap, rmse_alt, delta=1.5)
# m3.optimize(optimizer=optimizer, max_iters=500)