mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-18 13:55:14 +02:00
Heaviside transformation fixed
This commit is contained in:
parent
42589a657a
commit
7e4dca7e3a
4 changed files with 37 additions and 7 deletions
|
|
@ -64,7 +64,7 @@ class Binomial(NoiseDistribution):
|
|||
if isinstance(self.gp_link,gp_transformations.Probit):
|
||||
return stats.norm.cdf(mu/np.sqrt(1+sigma**2))
|
||||
elif isinstance(self.gp_link,gp_transformations.Heaviside):
|
||||
return stats.norm.cdf(mu/sigma)
|
||||
return stats.norm.cdf(mu/sigma)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -74,8 +74,6 @@ class Binomial(NoiseDistribution):
|
|||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
def _mass(self,gp,obs):
|
||||
#NOTE obs must be in {0,1}
|
||||
p = self.gp_link.transf(gp)
|
||||
|
|
|
|||
|
|
@ -116,10 +116,10 @@ class Heaviside(GPTransformation):
|
|||
"""
|
||||
def transf(self,f):
|
||||
#transformation goes here
|
||||
return np.where(f>0, 1, -1)
|
||||
return np.where(f>0, 1, 0)
|
||||
|
||||
def dtransf_df(self,f):
|
||||
raise NotImplementedError, "this function is not differentiable!"
|
||||
raise NotImplementedError, "This function is not differentiable!"
|
||||
|
||||
def d2transf_df2(self,f):
|
||||
raise NotImplementedError, "this function is not differentiable!"
|
||||
raise NotImplementedError, "This function is not differentiable!"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue