Spike-and-slab binary variable numerical enhancement

Zhenwen Dai 2014-11-12 20:19:56 +00:00
parent 3c642a5600
commit 7f7b0da7b9
3 changed files with 34 additions and 21 deletions


@@ -8,6 +8,8 @@ import numpy as np
 from parameterized import Parameterized
 from param import Param
 from transformations import Logexp, Logistic,__fixed__
+from GPy.util.misc import param_to_array
+from GPy.util.caching import Cache_this
 
 class VariationalPrior(Parameterized):
     def __init__(self, name='latent space', **kw):
@@ -48,7 +50,8 @@ class SpikeAndSlabPrior(VariationalPrior):
     def KL_divergence(self, variational_posterior):
         mu = variational_posterior.mean
         S = variational_posterior.variance
-        gamma = variational_posterior.binary_prob
+        gamma,gamma1 = variational_posterior.gamma_probabilities()
+        log_gamma,log_gamma1 = variational_posterior.gamma_log_prob()
         if len(self.pi.shape)==2:
             idx = np.unique(gamma._raveled_index()/gamma.shape[-1])
             pi = self.pi[idx]
@@ -57,20 +60,21 @@ class SpikeAndSlabPrior(VariationalPrior):
         var_mean = np.square(mu)/self.variance
         var_S = (S/self.variance - np.log(S))
-        var_gamma = (gamma*np.log(gamma/pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-pi))).sum()
+        var_gamma = (gamma*(log_gamma-np.log(pi))).sum()+(gamma1*(log_gamma1-np.log(1-pi))).sum()
         return var_gamma+ (gamma* (np.log(self.variance)-1. +var_mean + var_S)).sum()/2.
 
     def update_gradients_KL(self, variational_posterior):
         mu = variational_posterior.mean
         S = variational_posterior.variance
-        gamma = variational_posterior.binary_prob
+        gamma,gamma1 = variational_posterior.gamma_probabilities()
+        log_gamma,log_gamma1 = variational_posterior.gamma_log_prob()
         if len(self.pi.shape)==2:
             idx = np.unique(gamma._raveled_index()/gamma.shape[-1])
             pi = self.pi[idx]
         else:
             pi = self.pi
-        gamma.gradient -= np.log((1-pi)/pi*gamma/(1.-gamma))+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2.
+        variational_posterior.binary_prob.gradient -= (np.log((1-pi)/pi)+log_gamma-log_gamma1+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2.)*gamma*gamma1
         mu.gradient -= gamma*mu/self.variance
         S.gradient -= (1./self.variance - 1./S) * gamma /2.
         if self.learnPi:
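
For reference, the Bernoulli part of the KL term being rewritten here, with the variational parameter now stored as the logit $\eta=\log\frac{\gamma}{1-\gamma}$ (a sketch derived from the expressions in the hunk above, not code from the commit):

$$\mathrm{KL}_\gamma = \gamma\log\frac{\gamma}{\pi}+(1-\gamma)\log\frac{1-\gamma}{1-\pi},\qquad \frac{\partial\gamma}{\partial\eta}=\gamma(1-\gamma),$$

$$\frac{\partial\,\mathrm{KL}_\gamma}{\partial\eta}=\gamma(1-\gamma)\Big(\log\frac{1-\pi}{\pi}+\log\gamma-\log(1-\gamma)\Big).$$

The chain-rule factor $\gamma(1-\gamma)$ is the trailing `*gamma*gamma1` in the new gradient line, and $\log\gamma-\log(1-\gamma)$ is `log_gamma-log_gamma1`; `binary_prob` now holds $\eta$ rather than $\gamma$.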
@@ -158,8 +162,24 @@ class SpikeAndSlabPosterior(VariationalPosterior):
         binary_prob : the probability of the distribution on the slab part.
         """
         super(SpikeAndSlabPosterior, self).__init__(means, variances, name)
-        self.gamma = Param("binary_prob",binary_prob, Logistic(1e-10,1.-1e-10))
+        self.gamma = Param("binary_prob",binary_prob)
         self.link_parameter(self.gamma)
 
+    @Cache_this(limit=5)
+    def gamma_probabilities(self):
+        prob = np.zeros_like(param_to_array(self.gamma))
+        prob[self.gamma>-710] = 1./(1.+np.exp(-self.gamma[self.gamma>-710]))
+        prob1 = np.zeros_like(param_to_array(self.gamma))
+        prob1[self.gamma<710] = 1./(1.+np.exp(self.gamma[self.gamma<710]))
+        return prob, prob1
+
+    @Cache_this(limit=5)
+    def gamma_log_prob(self):
+        loggamma = param_to_array(self.gamma).copy()
+        loggamma[loggamma>-40] = -np.log1p(np.exp(-loggamma[loggamma>-40]))
+        loggamma1 = param_to_array(self.gamma).copy()
+        loggamma1[loggamma1<40] = -np.log1p(np.exp(loggamma1[loggamma1<40]))
+        return loggamma,loggamma1
+
     def set_gradients(self, grad):
         self.mean.gradient, self.variance.gradient, self.gamma.gradient = grad
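
A minimal standalone sketch of the numerics behind the two new methods (the function names here are illustrative, not GPy API): for float64, `np.exp` overflows just above 709, and `log1p(np.exp(-x))` drops below double precision around x ≈ 40, which is where the 710 and 40 thresholds come from.

import numpy as np

def logit_to_probs(eta):
    # sigmoid(eta) and sigmoid(-eta) without overflow: exp(x) overflows
    # float64 for x > ~709, and the entries left unset are already ~0.
    prob = np.zeros_like(eta)
    mask = eta > -710
    prob[mask] = 1./(1. + np.exp(-eta[mask]))
    prob1 = np.zeros_like(eta)
    mask1 = eta < 710
    prob1[mask1] = 1./(1. + np.exp(eta[mask1]))
    return prob, prob1

def logit_to_log_probs(eta):
    # log sigmoid(eta) = -log1p(exp(-eta)); for eta <= -40 it equals eta
    # to double precision, so the raw logit can stand in for the tail.
    log_p = eta.copy()
    mask = eta > -40
    log_p[mask] = -np.log1p(np.exp(-eta[mask]))
    # log sigmoid(-eta) = -log1p(exp(eta)); for eta >= 40 it equals -eta.
    log_p1 = -eta
    mask1 = eta < 40
    log_p1[mask1] = -np.log1p(np.exp(eta[mask1]))
    return log_p, log_p1

# Quick check against the naive formulas in the safe range:
eta = np.linspace(-30, 30, 7)
p, _ = logit_to_probs(eta)
assert np.allclose(np.log(p), logit_to_log_probs(eta)[0])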


@@ -22,14 +22,12 @@ try:
         # _psi1 NxM
         mu = variational_posterior.mean
         S = variational_posterior.variance
-        gamma = variational_posterior.binary_prob
         N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
         l2 = np.square(lengthscale)
         log_denom1 = np.log(S/l2+1)
         log_denom2 = np.log(2*S/l2+1)
-        log_gamma = np.log(gamma)
-        log_gamma1 = np.log(1.-gamma)
+        log_gamma,log_gamma1 = variational_posterior.gamma_log_prob()
         variance = float(variance)
 
         psi0 = np.empty(N)
         psi0[:] = variance
@@ -39,7 +37,6 @@ try:
         from ....util.misc import param_to_array
         S = param_to_array(S)
         mu = param_to_array(mu)
-        gamma = param_to_array(gamma)
         Z = param_to_array(Z)
 
         support_code = """
@@ -82,7 +79,7 @@ try:
             }
         }
         """
-        weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz)
+        weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz)
         psi2 = psi2n.sum(axis=0)
 
         return psi0,psi1,psi2,psi2n
@@ -97,13 +94,12 @@ try:
         mu = variational_posterior.mean
         S = variational_posterior.variance
-        gamma = variational_posterior.binary_prob
         N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
         l2 = np.square(lengthscale)
         log_denom1 = np.log(S/l2+1)
         log_denom2 = np.log(2*S/l2+1)
-        log_gamma = np.log(gamma)
-        log_gamma1 = np.log(1.-gamma)
+        log_gamma,log_gamma1 = variational_posterior.gamma_log_prob()
+        gamma, gamma1 = variational_posterior.gamma_probabilities()
         variance = float(variance)
 
         dvar = np.zeros(1)
@@ -117,7 +113,6 @@ try:
         from ....util.misc import param_to_array
         S = param_to_array(S)
         mu = param_to_array(mu)
-        gamma = param_to_array(gamma)
         Z = param_to_array(Z)
 
         support_code = """
@@ -135,6 +130,7 @@ try:
                 double Zm1q = Z(m1,q);
                 double Zm2q = Z(m2,q);
                 double gnq = gamma(n,q);
+                double g1nq = gamma1(n,q);
                 double mu_nq = mu(n,q);
 
                 if(m2==0) {
@@ -160,7 +156,7 @@ try:
                     dmu(n,q) += lpsi1*Zmu*d_exp1/(denom*exp_sum);
                     dS(n,q) += lpsi1*(Zmu2_denom-1.)*d_exp1/(denom*exp_sum)/2.;
-                    dgamma(n,q) += lpsi1*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum;
+                    dgamma(n,q) += lpsi1*(d_exp1*g1nq-d_exp2*gnq)/exp_sum;
                     dl(q) += lpsi1*((Zmu2_denom+Snq/lq)/denom*d_exp1+Zm1q*Zm1q/(lq*lq)*d_exp2)/(2.*exp_sum);
                     dZ(m1,q) += lpsi1*(-Zmu/denom*d_exp1-Zm1q/lq*d_exp2)/exp_sum;
                 }
@@ -188,7 +184,7 @@ try:
                     dmu(n,q) += -2.*lpsi2*muZhat/denom*d_exp1/exp_sum;
                     dS(n,q) += lpsi2*(2.*muZhat2_denom-1.)/denom*d_exp1/exp_sum;
-                    dgamma(n,q) += lpsi2*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum;
+                    dgamma(n,q) += lpsi2*(d_exp1*g1nq-d_exp2*gnq)/exp_sum;
                     dl(q) += lpsi2*(((Snq/lq+muZhat2_denom)/denom+dZm1m2*dZm1m2/(4.*lq*lq))*d_exp1+Z2/(2.*lq*lq)*d_exp2)/exp_sum;
                     dZ(m1,q) += 2.*lpsi2*((muZhat/denom-dZm1m2/(2*lq))*d_exp1-Zm1q/lq*d_exp2)/exp_sum;
                 }
@@ -196,7 +192,7 @@ try:
             }
         }
         """
-        weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz)
+        weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','gamma1','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz)
         dl *= 2.*lengthscale
 
         if not ARD:
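
Both `dgamma` rewrites above are the same reparameterization as in the KL gradient: multiplying the old gradient with respect to $\gamma$ by the chain-rule factor $\gamma(1-\gamma)$ turns the divisions, which blow up as $\gamma$ approaches 0 or 1, into multiplications,

$$\Big(\frac{d_1}{\gamma}-\frac{d_2}{1-\gamma}\Big)\,\gamma(1-\gamma)=d_1(1-\gamma)-d_2\,\gamma,$$

which is exactly `d_exp1*g1nq - d_exp2*gnq`, with `g1nq` holding $1-\gamma$.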


@@ -39,10 +39,7 @@ class SSGPLVM(SparseGP_MPI):
             X_variance = np.random.uniform(0,.1,X.shape)
 
         if Gamma is None:
-            gamma = np.empty_like(X) # The posterior probabilities of the binary variable in the variational approximation
-            gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim)
-            gamma[gamma>1.-1e-9] = 1.-1e-9
-            gamma[gamma<1e-9] = 1e-9
+            gamma = np.random.randn(X.shape[0], input_dim)
         else:
             gamma = Gamma.copy()
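
A hypothetical sanity check (not from the commit) of the new initialization: drawing the logits from a standard normal puts the implied slab probabilities near 0.5, mirroring the old `0.5 + 0.1*randn` draw without any clipping at the [1e-9, 1-1e-9] boundaries.

import numpy as np

np.random.seed(0)
eta = np.random.randn(100, 5)      # logit-space init, as in the new code
gamma = 1./(1. + np.exp(-eta))     # implied slab probabilities
print(gamma.min(), gamma.mean(), gamma.max())   # strictly inside (0,1), mean ~0.5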