From 7f7b0da7b9e6dd873323781c2497fc166a57c74c Mon Sep 17 00:00:00 2001 From: Zhenwen Dai Date: Wed, 12 Nov 2014 20:19:56 +0000 Subject: [PATCH] spike and slab binary variable numerical enhancement --- GPy/core/parameterization/variational.py | 30 ++++++++++++++++++++---- GPy/kern/_src/psi_comp/ssrbf_psi_comp.py | 20 +++++++--------- GPy/models/ss_gplvm.py | 5 +--- 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/GPy/core/parameterization/variational.py b/GPy/core/parameterization/variational.py index e2a24008..bb7678ff 100644 --- a/GPy/core/parameterization/variational.py +++ b/GPy/core/parameterization/variational.py @@ -8,6 +8,8 @@ import numpy as np from parameterized import Parameterized from param import Param from transformations import Logexp, Logistic,__fixed__ +from GPy.util.misc import param_to_array +from GPy.util.caching import Cache_this class VariationalPrior(Parameterized): def __init__(self, name='latent space', **kw): @@ -48,7 +50,8 @@ class SpikeAndSlabPrior(VariationalPrior): def KL_divergence(self, variational_posterior): mu = variational_posterior.mean S = variational_posterior.variance - gamma = variational_posterior.binary_prob + gamma,gamma1 = variational_posterior.gamma_probabilities() + log_gamma,log_gamma1 = variational_posterior.gamma_log_prob() if len(self.pi.shape)==2: idx = np.unique(gamma._raveled_index()/gamma.shape[-1]) pi = self.pi[idx] @@ -57,20 +60,21 @@ class SpikeAndSlabPrior(VariationalPrior): var_mean = np.square(mu)/self.variance var_S = (S/self.variance - np.log(S)) - var_gamma = (gamma*np.log(gamma/pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-pi))).sum() + var_gamma = (gamma*(log_gamma-np.log(pi))).sum()+(gamma1*(log_gamma1-np.log(1-pi))).sum() return var_gamma+ (gamma* (np.log(self.variance)-1. +var_mean + var_S)).sum()/2. def update_gradients_KL(self, variational_posterior): mu = variational_posterior.mean S = variational_posterior.variance - gamma = variational_posterior.binary_prob + gamma,gamma1 = variational_posterior.gamma_probabilities() + log_gamma,log_gamma1 = variational_posterior.gamma_log_prob() if len(self.pi.shape)==2: idx = np.unique(gamma._raveled_index()/gamma.shape[-1]) pi = self.pi[idx] else: pi = self.pi - gamma.gradient -= np.log((1-pi)/pi*gamma/(1.-gamma))+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2. + variational_posterior.binary_prob.gradient -= (np.log((1-pi)/pi)+log_gamma-log_gamma1+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2.)*gamma*gamma1 mu.gradient -= gamma*mu/self.variance S.gradient -= (1./self.variance - 1./S) * gamma /2. if self.learnPi: @@ -158,8 +162,24 @@ class SpikeAndSlabPosterior(VariationalPosterior): binary_prob : the probability of the distribution on the slab part. """ super(SpikeAndSlabPosterior, self).__init__(means, variances, name) - self.gamma = Param("binary_prob",binary_prob, Logistic(1e-10,1.-1e-10)) + self.gamma = Param("binary_prob",binary_prob) self.link_parameter(self.gamma) + + @Cache_this(limit=5) + def gamma_probabilities(self): + prob = np.zeros_like(param_to_array(self.gamma)) + prob[self.gamma>-710] = 1./(1.+np.exp(-self.gamma[self.gamma>-710])) + prob1 = np.zeros_like(param_to_array(self.gamma)) + prob1[self.gamma<710] = 1./(1.+np.exp(self.gamma[self.gamma<710])) + return prob, prob1 + + @Cache_this(limit=5) + def gamma_log_prob(self): + loggamma = param_to_array(self.gamma).copy() + loggamma[loggamma>-40] = -np.log1p(np.exp(-loggamma[loggamma>-40])) + loggamma1 = param_to_array(self.gamma).copy() + loggamma1[loggamma1<40] = -np.log1p(np.exp(loggamma1[loggamma1<40])) + return loggamma,loggamma1 def set_gradients(self, grad): self.mean.gradient, self.variance.gradient, self.gamma.gradient = grad diff --git a/GPy/kern/_src/psi_comp/ssrbf_psi_comp.py b/GPy/kern/_src/psi_comp/ssrbf_psi_comp.py index f6a24c86..18a4d751 100644 --- a/GPy/kern/_src/psi_comp/ssrbf_psi_comp.py +++ b/GPy/kern/_src/psi_comp/ssrbf_psi_comp.py @@ -22,14 +22,12 @@ try: # _psi1 NxM mu = variational_posterior.mean S = variational_posterior.variance - gamma = variational_posterior.binary_prob N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1] l2 = np.square(lengthscale) log_denom1 = np.log(S/l2+1) log_denom2 = np.log(2*S/l2+1) - log_gamma = np.log(gamma) - log_gamma1 = np.log(1.-gamma) + log_gamma,log_gamma1 = variational_posterior.gamma_log_prob() variance = float(variance) psi0 = np.empty(N) psi0[:] = variance @@ -39,7 +37,6 @@ try: from ....util.misc import param_to_array S = param_to_array(S) mu = param_to_array(mu) - gamma = param_to_array(gamma) Z = param_to_array(Z) support_code = """ @@ -82,7 +79,7 @@ try: } } """ - weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz) + weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz) psi2 = psi2n.sum(axis=0) return psi0,psi1,psi2,psi2n @@ -97,13 +94,12 @@ try: mu = variational_posterior.mean S = variational_posterior.variance - gamma = variational_posterior.binary_prob N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1] l2 = np.square(lengthscale) log_denom1 = np.log(S/l2+1) log_denom2 = np.log(2*S/l2+1) - log_gamma = np.log(gamma) - log_gamma1 = np.log(1.-gamma) + log_gamma,log_gamma1 = variational_posterior.gamma_log_prob() + gamma, gamma1 = variational_posterior.gamma_probabilities() variance = float(variance) dvar = np.zeros(1) @@ -117,7 +113,6 @@ try: from ....util.misc import param_to_array S = param_to_array(S) mu = param_to_array(mu) - gamma = param_to_array(gamma) Z = param_to_array(Z) support_code = """ @@ -135,6 +130,7 @@ try: double Zm1q = Z(m1,q); double Zm2q = Z(m2,q); double gnq = gamma(n,q); + double g1nq = gamma1(n,q); double mu_nq = mu(n,q); if(m2==0) { @@ -160,7 +156,7 @@ try: dmu(n,q) += lpsi1*Zmu*d_exp1/(denom*exp_sum); dS(n,q) += lpsi1*(Zmu2_denom-1.)*d_exp1/(denom*exp_sum)/2.; - dgamma(n,q) += lpsi1*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum; + dgamma(n,q) += lpsi1*(d_exp1*g1nq-d_exp2*gnq)/exp_sum; dl(q) += lpsi1*((Zmu2_denom+Snq/lq)/denom*d_exp1+Zm1q*Zm1q/(lq*lq)*d_exp2)/(2.*exp_sum); dZ(m1,q) += lpsi1*(-Zmu/denom*d_exp1-Zm1q/lq*d_exp2)/exp_sum; } @@ -188,7 +184,7 @@ try: dmu(n,q) += -2.*lpsi2*muZhat/denom*d_exp1/exp_sum; dS(n,q) += lpsi2*(2.*muZhat2_denom-1.)/denom*d_exp1/exp_sum; - dgamma(n,q) += lpsi2*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum; + dgamma(n,q) += lpsi2*(d_exp1*g1nq-d_exp2*gnq)/exp_sum; dl(q) += lpsi2*(((Snq/lq+muZhat2_denom)/denom+dZm1m2*dZm1m2/(4.*lq*lq))*d_exp1+Z2/(2.*lq*lq)*d_exp2)/exp_sum; dZ(m1,q) += 2.*lpsi2*((muZhat/denom-dZm1m2/(2*lq))*d_exp1-Zm1q/lq*d_exp2)/exp_sum; } @@ -196,7 +192,7 @@ try: } } """ - weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz) + weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','gamma1','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz) dl *= 2.*lengthscale if not ARD: diff --git a/GPy/models/ss_gplvm.py b/GPy/models/ss_gplvm.py index 04006d84..a61ad2a0 100644 --- a/GPy/models/ss_gplvm.py +++ b/GPy/models/ss_gplvm.py @@ -39,10 +39,7 @@ class SSGPLVM(SparseGP_MPI): X_variance = np.random.uniform(0,.1,X.shape) if Gamma is None: - gamma = np.empty_like(X) # The posterior probabilities of the binary variable in the variational approximation - gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim) - gamma[gamma>1.-1e-9] = 1.-1e-9 - gamma[gamma<1e-9] = 1e-9 + gamma = np.random.randn(X.shape[0], input_dim) else: gamma = Gamma.copy()