addedHeraviside functionality to EP

This commit is contained in:
James Hensman 2013-09-16 16:20:26 +01:00
parent aafce736f8
commit f2fa9bd74d
3 changed files with 26 additions and 9 deletions

View file

@ -49,15 +49,32 @@ class Binomial(NoiseDistribution):
mu_hat = v_i/tau_i + data_i*phi/(Z_hat*np.sqrt(tau_i**2 + tau_i))
sigma2_hat = 1./tau_i - (phi/((tau_i**2+tau_i)*Z_hat))*(z+phi/Z_hat)
elif isinstance(self.gp_link,gp_transformations.Step):
Z_hat = None
mu_hat = None
sigma2_hat = None
elif isinstance(self.gp_link,gp_transformations.Heaviside):
a = data_i*v_i/np.sqrt(tau_i)
Z_hat = std_norm_cdf(a)
N = std_norm_pdf(a)
mu_hat = v_i/tau_i + data_i*N/Z_hat/np.sqrt(tau_i)
sigma2_hat = (1. - a*N/Z_hat - np.square(N/Z_hat))/tau_i
if np.any(np.isnan([Z_hat, mu_hat, sigma2_hat])):
stop
return Z_hat, mu_hat, sigma2_hat
def _predictive_mean_analytical(self,mu,sigma):
return stats.norm.cdf(mu/np.sqrt(1+sigma**2))
if isinstance(self.gp_link,gp_transformations.Probit):
return stats.norm.cdf(mu/np.sqrt(1+sigma**2))
elif isinstance(self.gp_link,gp_transformations.Heaviside):
return stats.norm.cdf(mu/sigma)
else:
raise NotImplementedError
def _predictive_variance_analytical(self,mu,sigma, pred_mean):
if isinstance(self.gp_link,gp_transformations.Heaviside):
return 0.
else:
raise NotImplementedError
def _mass(self,gp,obs):
#NOTE obs must be in {0,1}