diff --git a/GPy/likelihoods/noise_models/bernoulli_noise.py b/GPy/likelihoods/noise_models/bernoulli_noise.py index 2c4116da..17390e55 100644 --- a/GPy/likelihoods/noise_models/bernoulli_noise.py +++ b/GPy/likelihoods/noise_models/bernoulli_noise.py @@ -71,15 +71,19 @@ class Bernoulli(NoiseDistribution): return Z_hat, mu_hat, sigma2_hat - def _predictive_mean_analytical(self,mu,sigma): + def _predictive_mean_analytical(self,mu,variance): + if isinstance(self.gp_link,gp_transformations.Probit): - return stats.norm.cdf(mu/np.sqrt(1+sigma**2)) + return stats.norm.cdf(mu/np.sqrt(1+variance)) + elif isinstance(self.gp_link,gp_transformations.Heaviside): - return stats.norm.cdf(mu/sigma) + return stats.norm.cdf(mu/np.sqrt(variance)) + else: raise NotImplementedError - def _predictive_variance_analytical(self,mu,sigma, pred_mean): + def _predictive_variance_analytical(self,mu,variance, pred_mean): + if isinstance(self.gp_link,gp_transformations.Heaviside): return 0. else: