speed ups for normal cdf

This commit is contained in:
James Hensman 2015-04-09 15:42:02 +01:00
parent 337bf67559
commit 1e30ffd730
7 changed files with 38 additions and 96 deletions

View file

@ -2,10 +2,10 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import numpy as np
from ..util.univariate_Gaussian import std_norm_pdf, std_norm_cdf
from ..util.univariate_Gaussian import std_norm_cdf, std_norm_pdf
import link_functions
from likelihood import Likelihood
from scipy import stats
class Bernoulli(Likelihood):
"""
@ -81,19 +81,18 @@ class Bernoulli(Likelihood):
if isinstance(self.gp_link, link_functions.Probit):
if gh_points is None:
gh_x, gh_w = np.polynomial.hermite.hermgauss(20)
gh_x, gh_w = self._gh_points()
else:
gh_x, gh_w = gh_points
from scipy import stats
shape = m.shape
m,v,Y = m.flatten(), v.flatten(), Y.flatten()
Ysign = np.where(Y==1,1,-1)
X = gh_x[None,:]*np.sqrt(2.*v[:,None]) + (m*Ysign)[:,None]
p = stats.norm.cdf(X)
p = std_norm_cdf(X)
p = np.clip(p, 1e-9, 1.-1e-9) # for numerical stability
N = stats.norm.pdf(X)
N = std_norm_pdf(X)
F = np.log(p).dot(gh_w)
NoverP = N/p
dF_dm = (NoverP*Ysign[:,None]).dot(gh_w)
@ -106,10 +105,10 @@ class Bernoulli(Likelihood):
def predictive_mean(self, mu, variance, Y_metadata=None):
if isinstance(self.gp_link, link_functions.Probit):
return stats.norm.cdf(mu/np.sqrt(1+variance))
return std_norm_cdf(mu/np.sqrt(1+variance))
elif isinstance(self.gp_link, link_functions.Heaviside):
return stats.norm.cdf(mu/np.sqrt(variance))
return std_norm_cdf(mu/np.sqrt(variance))
else:
raise NotImplementedError