mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-06 10:32:39 +02:00
addedHeraviside functionality to EP
This commit is contained in:
parent
aafce736f8
commit
f2fa9bd74d
3 changed files with 26 additions and 9 deletions
|
|
@ -19,7 +19,7 @@ def binomial(gp_link=None):
|
||||||
analytical_mean = True
|
analytical_mean = True
|
||||||
analytical_variance = False
|
analytical_variance = False
|
||||||
|
|
||||||
elif isinstance(gp_link,noise_models.gp_transformations.Step):
|
elif isinstance(gp_link,noise_models.gp_transformations.Heaviside):
|
||||||
analytical_mean = True
|
analytical_mean = True
|
||||||
analytical_variance = True
|
analytical_variance = True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,15 +49,32 @@ class Binomial(NoiseDistribution):
|
||||||
mu_hat = v_i/tau_i + data_i*phi/(Z_hat*np.sqrt(tau_i**2 + tau_i))
|
mu_hat = v_i/tau_i + data_i*phi/(Z_hat*np.sqrt(tau_i**2 + tau_i))
|
||||||
sigma2_hat = 1./tau_i - (phi/((tau_i**2+tau_i)*Z_hat))*(z+phi/Z_hat)
|
sigma2_hat = 1./tau_i - (phi/((tau_i**2+tau_i)*Z_hat))*(z+phi/Z_hat)
|
||||||
|
|
||||||
elif isinstance(self.gp_link,gp_transformations.Step):
|
elif isinstance(self.gp_link,gp_transformations.Heaviside):
|
||||||
Z_hat = None
|
a = data_i*v_i/np.sqrt(tau_i)
|
||||||
mu_hat = None
|
Z_hat = std_norm_cdf(a)
|
||||||
sigma2_hat = None
|
N = std_norm_pdf(a)
|
||||||
|
mu_hat = v_i/tau_i + data_i*N/Z_hat/np.sqrt(tau_i)
|
||||||
|
sigma2_hat = (1. - a*N/Z_hat - np.square(N/Z_hat))/tau_i
|
||||||
|
if np.any(np.isnan([Z_hat, mu_hat, sigma2_hat])):
|
||||||
|
stop
|
||||||
|
|
||||||
return Z_hat, mu_hat, sigma2_hat
|
return Z_hat, mu_hat, sigma2_hat
|
||||||
|
|
||||||
def _predictive_mean_analytical(self,mu,sigma):
|
def _predictive_mean_analytical(self,mu,sigma):
|
||||||
return stats.norm.cdf(mu/np.sqrt(1+sigma**2))
|
if isinstance(self.gp_link,gp_transformations.Probit):
|
||||||
|
return stats.norm.cdf(mu/np.sqrt(1+sigma**2))
|
||||||
|
elif isinstance(self.gp_link,gp_transformations.Heaviside):
|
||||||
|
return stats.norm.cdf(mu/sigma)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _predictive_variance_analytical(self,mu,sigma, pred_mean):
|
||||||
|
if isinstance(self.gp_link,gp_transformations.Heaviside):
|
||||||
|
return 0.
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _mass(self,gp,obs):
|
def _mass(self,gp,obs):
|
||||||
#NOTE obs must be in {0,1}
|
#NOTE obs must be in {0,1}
|
||||||
|
|
|
||||||
|
|
@ -108,7 +108,7 @@ class Reciprocal(GPTransformation):
|
||||||
def d2transf_df2(self,f):
|
def d2transf_df2(self,f):
|
||||||
return 2./f**3
|
return 2./f**3
|
||||||
|
|
||||||
class Step(GPTransformation):
|
class Heaviside(GPTransformation):
|
||||||
"""
|
"""
|
||||||
$$
|
$$
|
||||||
g(f) = I_{x \in A}
|
g(f) = I_{x \in A}
|
||||||
|
|
@ -119,7 +119,7 @@ class Step(GPTransformation):
|
||||||
return np.where(f>0, 1, -1)
|
return np.where(f>0, 1, -1)
|
||||||
|
|
||||||
def dtransf_df(self,f):
|
def dtransf_df(self,f):
|
||||||
pass
|
raise NotImplementedError, "this function is not differentiable!"
|
||||||
|
|
||||||
def d2transf_df2(self,f):
|
def d2transf_df2(self,f):
|
||||||
pass
|
raise NotImplementedError, "this function is not differentiable!"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue