Fixed bernoulli likelihood divide by 0 and log of 0

This commit is contained in:
Alan Saul 2014-02-12 16:48:57 +00:00
parent c788c463d8
commit 46ce76dee8
5 changed files with 33 additions and 20 deletions

View file

@ -116,7 +116,8 @@ class Bernoulli(Likelihood):
Each y_i must be in {0, 1}
"""
assert np.atleast_1d(link_f).shape == np.atleast_1d(y).shape
objective = (link_f**y) * ((1.-link_f)**(1.-y))
#objective = (link_f**y) * ((1.-link_f)**(1.-y))
objective = np.where(y, link_f, 1.-link_f)
return np.exp(np.sum(np.log(objective)))
def logpdf_link(self, link_f, y, extra_data=None):
@ -136,7 +137,9 @@ class Bernoulli(Likelihood):
"""
assert np.atleast_1d(link_f).shape == np.atleast_1d(y).shape
#objective = y*np.log(link_f) + (1.-y)*np.log(link_f)
state = np.seterr(divide='ignore')
objective = np.where(y==1, np.log(link_f), np.log(1-link_f))
np.seterr(**state)
return np.sum(objective)
def dlogpdf_dlink(self, link_f, y, extra_data=None):
@ -155,7 +158,10 @@ class Bernoulli(Likelihood):
:rtype: Nx1 array
"""
assert np.atleast_1d(link_f).shape == np.atleast_1d(y).shape
grad = (y/link_f) - (1.-y)/(1-link_f)
#grad = (y/link_f) - (1.-y)/(1-link_f)
state = np.seterr(divide='ignore')
grad = np.where(y, 1./link_f, -1./(1-link_f))
np.seterr(**state)
return grad
def d2logpdf_dlink2(self, link_f, y, extra_data=None):
@ -180,7 +186,10 @@ class Bernoulli(Likelihood):
(the distribution for y_i depends only on link(f_i) not on link(f_(j!=i))
"""
assert np.atleast_1d(link_f).shape == np.atleast_1d(y).shape
d2logpdf_dlink2 = -y/(link_f**2) - (1-y)/((1-link_f)**2)
#d2logpdf_dlink2 = -y/(link_f**2) - (1-y)/((1-link_f)**2)
state = np.seterr(divide='ignore')
d2logpdf_dlink2 = np.where(y, -1./np.square(link_f), -1./np.square(1.-link_f))
np.seterr(**state)
return d2logpdf_dlink2
def d3logpdf_dlink3(self, link_f, y, extra_data=None):
@ -199,7 +208,10 @@ class Bernoulli(Likelihood):
:rtype: Nx1 array
"""
assert np.atleast_1d(link_f).shape == np.atleast_1d(y).shape
d3logpdf_dlink3 = 2*(y/(link_f**3) - (1-y)/((1-link_f)**3))
#d3logpdf_dlink3 = 2*(y/(link_f**3) - (1-y)/((1-link_f)**3))
state = np.seterr(divide='ignore')
d3logpdf_dlink3 = np.where(y, 2./(link_f**3), -2./((1.-link_f)**3))
np.seterr(**state)
return d3logpdf_dlink3
def samples(self, gp):