diff --git a/GPy/likelihoods/bernoulli.py b/GPy/likelihoods/bernoulli.py index ff2ab30a..26de274b 100644 --- a/GPy/likelihoods/bernoulli.py +++ b/GPy/likelihoods/bernoulli.py @@ -77,6 +77,32 @@ class Bernoulli(Likelihood): return Z_hat, mu_hat, sigma2_hat + def variational_expectations(self, Y, m, v, gh_points=None): + if isinstance(self.gp_link, link_functions.Probit): + + if gh_points is None: + gh_x, gh_w = np.polynomial.hermite.hermgauss(20) + else: + gh_x, gh_w = gh_points + + from scipy import stats + + shape = m.shape + m,v,Y = m.flatten(), v.flatten(), Y.flatten() + Ysign = np.where(Y==1,1,-1) + X = gh_x[None,:]*np.sqrt(2.*v[:,None]) + (m*Ysign)[:,None] + p = stats.norm.cdf(X) + p = np.clip(p, 1e-9, 1.-1e-9) # for numerical stability + N = stats.norm.pdf(X) + F = np.log(p).dot(gh_w) + NoverP = N/p + dF_dm = (NoverP*Ysign[:,None]).dot(gh_w) + dF_dv = -0.5*(NoverP**2 + NoverP*X).dot(gh_w) + return F.reshape(*shape), dF_dm.reshape(*shape), dF_dv.reshape(*shape), None + else: + raise NotImplementedError + + def predictive_mean(self, mu, variance, Y_metadata=None): if isinstance(self.gp_link, link_functions.Probit):