mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-08 15:05:15 +02:00
Made EP faster for Binomial likelihood with Probit or ScaledProbit link functions by adding exact moments
This commit is contained in:
parent
936fd1ff71
commit
101e93b73e
1 changed files with 35 additions and 1 deletions
|
|
@ -66,7 +66,6 @@ class Binomial(Likelihood):
|
|||
np.testing.assert_array_equal(N.shape, y.shape)
|
||||
|
||||
nchoosey = special.gammaln(N+1) - special.gammaln(y+1) - special.gammaln(N-y+1)
|
||||
|
||||
Ny = N-y
|
||||
t1 = np.zeros(y.shape)
|
||||
t2 = np.zeros(y.shape)
|
||||
|
|
@ -177,6 +176,41 @@ class Binomial(Likelihood):
|
|||
|
||||
def exact_inference_gradients(self, dL_dKdiag,Y_metadata=None):
|
||||
pass
|
||||
|
||||
def moments_match_ep(self,obs,tau,v,Y_metadata_i=None):
|
||||
"""
|
||||
Calculation of moments using quadrature
|
||||
:param obs: observed output
|
||||
:param tau: cavity distribution 1st natural parameter (precision)
|
||||
:param v: cavity distribution 2nd natural paramenter (mu*precision)
|
||||
"""
|
||||
#Compute first integral for zeroth moment.
|
||||
#NOTE constant np.sqrt(2*pi/tau) added at the end of the function
|
||||
if (isinstance(self.gp_link, link_functions.Probit) or isinstance(self.gp_link, link_functions.ScaledProbit)) and (Y_metadata_i is None or int(Y_metadata_i.get('trials', 1)) == int(1)): #Special case for probit likelihood. Can be found from Riihimaki et Vehtari 2010
|
||||
if isinstance(self.gp_link, link_functions.ScaledProbit):
|
||||
nu = self.gp_link.nu
|
||||
else:
|
||||
nu = 1.0
|
||||
nu = self.gp_link.nu
|
||||
mu = v/tau
|
||||
sigma2 = 1./tau
|
||||
t = np.asarray(1 + sigma2*(nu**2))
|
||||
t[t<1e-20] = 1e-20
|
||||
a = np.sqrt(t)
|
||||
z = obs*mu/a
|
||||
normc_z = max(self.gp_link.transf(z), 1e-20)
|
||||
m0 = normc_z
|
||||
normp_z = self.gp_link.dtransf_df(z)
|
||||
m1 = mu + (obs*sigma2*normp_z)/(normc_z*a)
|
||||
#print('tau: {}, v: {}, nu: {}, z: {}, normc_z: {}, normp_z: {}'.format(tau, v, nu.values, z, normc_z, normp_z))
|
||||
m2 = sigma2 - ((sigma2**2)*normp_z)/((1./(nu**2)+sigma2)*normc_z)*(z + normp_z/(nu**2)/normc_z)
|
||||
#print("m0: {}, m1: {}, m2: {}".format(m0,m1,m2))
|
||||
#m0a, m1a, m2a = super(Binomial, self).moments_match_ep(obs,tau,v,Y_metadata_i)
|
||||
#print("m0a: {}, m1a: {}, m2a: {}".format(m0a,m1a,m2a))
|
||||
return m0, m1, m2
|
||||
else:
|
||||
return super(Binomial, self).moments_match_ep(obs,tau,v,Y_metadata_i)
|
||||
|
||||
def variational_expectations(self, Y, m, v, gh_points=None, Y_metadata=None):
|
||||
if isinstance(self.gp_link, link_functions.Probit):
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue