fixes to EP

This commit is contained in:
James Hensman 2014-03-14 11:47:23 +00:00
parent 1ed7d73219
commit 77d08a7d6f
7 changed files with 36 additions and 32 deletions

View file

@ -5,6 +5,7 @@ import numpy as np
from ..util.univariate_Gaussian import std_norm_pdf, std_norm_cdf
import link_functions
from likelihood import Likelihood
from scipy import stats
class Bernoulli(Likelihood):
"""
@ -43,7 +44,7 @@ class Bernoulli(Likelihood):
Y_prep[Y.flatten() == 0] = -1
return Y_prep
def moments_match_ep(self, data_i, tau_i, v_i):
def moments_match_ep(self, Y_i, tau_i, v_i):
"""
Moments match of the marginal approximation in EP algorithm
@ -51,9 +52,9 @@ class Bernoulli(Likelihood):
:param tau_i: precision of the cavity distribution (float)
:param v_i: mean/variance of the cavity distribution (float)
"""
if data_i == 1:
if Y_i == 1:
sign = 1.
elif data_i == 0:
elif Y_i == 0:
sign = -1
else:
raise ValueError("bad value for Bernouilli observation (0, 1)")
@ -76,7 +77,7 @@ class Bernoulli(Likelihood):
return Z_hat, mu_hat, sigma2_hat
def predictive_mean(self, mu, variance):
def predictive_mean(self, mu, variance, Y_metadata=None):
if isinstance(self.gp_link, link_functions.Probit):
return stats.norm.cdf(mu/np.sqrt(1+variance))
@ -87,7 +88,7 @@ class Bernoulli(Likelihood):
else:
raise NotImplementedError
def predictive_variance(self, mu, variance, pred_mean):
def predictive_variance(self, mu, variance, pred_mean, Y_metadata=None):
if isinstance(self.gp_link, link_functions.Heaviside):
return 0.