Made EP faster for Binomial likelihood with Probit or ScaledProbit link functions by adding exact moments

This commit is contained in:
Siivola Eero 2018-09-05 15:42:40 +03:00
parent 936fd1ff71
commit 101e93b73e

View file

@ -66,7 +66,6 @@ class Binomial(Likelihood):
np.testing.assert_array_equal(N.shape, y.shape)
nchoosey = special.gammaln(N+1) - special.gammaln(y+1) - special.gammaln(N-y+1)
Ny = N-y
t1 = np.zeros(y.shape)
t2 = np.zeros(y.shape)
@ -177,6 +176,41 @@ class Binomial(Likelihood):
def exact_inference_gradients(self, dL_dKdiag,Y_metadata=None):
pass
def moments_match_ep(self,obs,tau,v,Y_metadata_i=None):
"""
Calculation of moments using quadrature
:param obs: observed output
:param tau: cavity distribution 1st natural parameter (precision)
:param v: cavity distribution 2nd natural paramenter (mu*precision)
"""
#Compute first integral for zeroth moment.
#NOTE constant np.sqrt(2*pi/tau) added at the end of the function
if (isinstance(self.gp_link, link_functions.Probit) or isinstance(self.gp_link, link_functions.ScaledProbit)) and (Y_metadata_i is None or int(Y_metadata_i.get('trials', 1)) == int(1)): #Special case for probit likelihood. Can be found from Riihimaki et Vehtari 2010
if isinstance(self.gp_link, link_functions.ScaledProbit):
nu = self.gp_link.nu
else:
nu = 1.0
nu = self.gp_link.nu
mu = v/tau
sigma2 = 1./tau
t = np.asarray(1 + sigma2*(nu**2))
t[t<1e-20] = 1e-20
a = np.sqrt(t)
z = obs*mu/a
normc_z = max(self.gp_link.transf(z), 1e-20)
m0 = normc_z
normp_z = self.gp_link.dtransf_df(z)
m1 = mu + (obs*sigma2*normp_z)/(normc_z*a)
#print('tau: {}, v: {}, nu: {}, z: {}, normc_z: {}, normp_z: {}'.format(tau, v, nu.values, z, normc_z, normp_z))
m2 = sigma2 - ((sigma2**2)*normp_z)/((1./(nu**2)+sigma2)*normc_z)*(z + normp_z/(nu**2)/normc_z)
#print("m0: {}, m1: {}, m2: {}".format(m0,m1,m2))
#m0a, m1a, m2a = super(Binomial, self).moments_match_ep(obs,tau,v,Y_metadata_i)
#print("m0a: {}, m1a: {}, m2a: {}".format(m0a,m1a,m2a))
return m0, m1, m2
else:
return super(Binomial, self).moments_match_ep(obs,tau,v,Y_metadata_i)
def variational_expectations(self, Y, m, v, gh_points=None, Y_metadata=None):
if isinstance(self.gp_link, link_functions.Probit):