mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 20:42:39 +02:00
added more stable expectations for Bernoulli
This commit is contained in:
parent
47cbdc265e
commit
393b9e94ba
1 changed files with 26 additions and 0 deletions
|
|
@ -77,6 +77,32 @@ class Bernoulli(Likelihood):
|
||||||
|
|
||||||
return Z_hat, mu_hat, sigma2_hat
|
return Z_hat, mu_hat, sigma2_hat
|
||||||
|
|
||||||
|
def variational_expectations(self, Y, m, v, gh_points=None):
|
||||||
|
if isinstance(self.gp_link, link_functions.Probit):
|
||||||
|
|
||||||
|
if gh_points is None:
|
||||||
|
gh_x, gh_w = np.polynomial.hermite.hermgauss(20)
|
||||||
|
else:
|
||||||
|
gh_x, gh_w = gh_points
|
||||||
|
|
||||||
|
from scipy import stats
|
||||||
|
|
||||||
|
shape = m.shape
|
||||||
|
m,v,Y = m.flatten(), v.flatten(), Y.flatten()
|
||||||
|
Ysign = np.where(Y==1,1,-1)
|
||||||
|
X = gh_x[None,:]*np.sqrt(2.*v[:,None]) + (m*Ysign)[:,None]
|
||||||
|
p = stats.norm.cdf(X)
|
||||||
|
p = np.clip(p, 1e-9, 1.-1e-9) # for numerical stability
|
||||||
|
N = stats.norm.pdf(X)
|
||||||
|
F = np.log(p).dot(gh_w)
|
||||||
|
NoverP = N/p
|
||||||
|
dF_dm = (NoverP*Ysign[:,None]).dot(gh_w)
|
||||||
|
dF_dv = -0.5*(NoverP**2 + NoverP*X).dot(gh_w)
|
||||||
|
return F.reshape(*shape), dF_dm.reshape(*shape), dF_dv.reshape(*shape), None
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
def predictive_mean(self, mu, variance, Y_metadata=None):
|
def predictive_mean(self, mu, variance, Y_metadata=None):
|
||||||
|
|
||||||
if isinstance(self.gp_link, link_functions.Probit):
|
if isinstance(self.gp_link, link_functions.Probit):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue