diff --git a/GPy/likelihoods/binomial.py b/GPy/likelihoods/binomial.py index 61440ec9..122cbcff 100644 --- a/GPy/likelihoods/binomial.py +++ b/GPy/likelihoods/binomial.py @@ -66,7 +66,6 @@ class Binomial(Likelihood): np.testing.assert_array_equal(N.shape, y.shape) nchoosey = special.gammaln(N+1) - special.gammaln(y+1) - special.gammaln(N-y+1) - Ny = N-y t1 = np.zeros(y.shape) t2 = np.zeros(y.shape) @@ -177,6 +176,41 @@ class Binomial(Likelihood): def exact_inference_gradients(self, dL_dKdiag,Y_metadata=None): pass + + def moments_match_ep(self,obs,tau,v,Y_metadata_i=None): + """ + Calculation of moments using quadrature + :param obs: observed output + :param tau: cavity distribution 1st natural parameter (precision) + :param v: cavity distribution 2nd natural paramenter (mu*precision) + """ + #Compute first integral for zeroth moment. + #NOTE constant np.sqrt(2*pi/tau) added at the end of the function + if (isinstance(self.gp_link, link_functions.Probit) or isinstance(self.gp_link, link_functions.ScaledProbit)) and (Y_metadata_i is None or int(Y_metadata_i.get('trials', 1)) == int(1)): #Special case for probit likelihood. Can be found from Riihimaki et Vehtari 2010 + if isinstance(self.gp_link, link_functions.ScaledProbit): + nu = self.gp_link.nu + else: + nu = 1.0 + nu = self.gp_link.nu + mu = v/tau + sigma2 = 1./tau + t = np.asarray(1 + sigma2*(nu**2)) + t[t<1e-20] = 1e-20 + a = np.sqrt(t) + z = obs*mu/a + normc_z = max(self.gp_link.transf(z), 1e-20) + m0 = normc_z + normp_z = self.gp_link.dtransf_df(z) + m1 = mu + (obs*sigma2*normp_z)/(normc_z*a) + #print('tau: {}, v: {}, nu: {}, z: {}, normc_z: {}, normp_z: {}'.format(tau, v, nu.values, z, normc_z, normp_z)) + m2 = sigma2 - ((sigma2**2)*normp_z)/((1./(nu**2)+sigma2)*normc_z)*(z + normp_z/(nu**2)/normc_z) + #print("m0: {}, m1: {}, m2: {}".format(m0,m1,m2)) + #m0a, m1a, m2a = super(Binomial, self).moments_match_ep(obs,tau,v,Y_metadata_i) + #print("m0a: {}, m1a: {}, m2a: {}".format(m0a,m1a,m2a)) + return m0, m1, m2 + else: + return super(Binomial, self).moments_match_ep(obs,tau,v,Y_metadata_i) + def variational_expectations(self, Y, m, v, gh_points=None, Y_metadata=None): if isinstance(self.gp_link, link_functions.Probit):