mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-07 19:12:40 +02:00
spike and slab binary variable numerical enhancement
This commit is contained in:
parent
3c642a5600
commit
7f7b0da7b9
3 changed files with 34 additions and 21 deletions
|
|
@ -8,6 +8,8 @@ import numpy as np
|
||||||
from parameterized import Parameterized
|
from parameterized import Parameterized
|
||||||
from param import Param
|
from param import Param
|
||||||
from transformations import Logexp, Logistic,__fixed__
|
from transformations import Logexp, Logistic,__fixed__
|
||||||
|
from GPy.util.misc import param_to_array
|
||||||
|
from GPy.util.caching import Cache_this
|
||||||
|
|
||||||
class VariationalPrior(Parameterized):
|
class VariationalPrior(Parameterized):
|
||||||
def __init__(self, name='latent space', **kw):
|
def __init__(self, name='latent space', **kw):
|
||||||
|
|
@ -48,7 +50,8 @@ class SpikeAndSlabPrior(VariationalPrior):
|
||||||
def KL_divergence(self, variational_posterior):
|
def KL_divergence(self, variational_posterior):
|
||||||
mu = variational_posterior.mean
|
mu = variational_posterior.mean
|
||||||
S = variational_posterior.variance
|
S = variational_posterior.variance
|
||||||
gamma = variational_posterior.binary_prob
|
gamma,gamma1 = variational_posterior.gamma_probabilities()
|
||||||
|
log_gamma,log_gamma1 = variational_posterior.gamma_log_prob()
|
||||||
if len(self.pi.shape)==2:
|
if len(self.pi.shape)==2:
|
||||||
idx = np.unique(gamma._raveled_index()/gamma.shape[-1])
|
idx = np.unique(gamma._raveled_index()/gamma.shape[-1])
|
||||||
pi = self.pi[idx]
|
pi = self.pi[idx]
|
||||||
|
|
@ -57,20 +60,21 @@ class SpikeAndSlabPrior(VariationalPrior):
|
||||||
|
|
||||||
var_mean = np.square(mu)/self.variance
|
var_mean = np.square(mu)/self.variance
|
||||||
var_S = (S/self.variance - np.log(S))
|
var_S = (S/self.variance - np.log(S))
|
||||||
var_gamma = (gamma*np.log(gamma/pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-pi))).sum()
|
var_gamma = (gamma*(log_gamma-np.log(pi))).sum()+(gamma1*(log_gamma1-np.log(1-pi))).sum()
|
||||||
return var_gamma+ (gamma* (np.log(self.variance)-1. +var_mean + var_S)).sum()/2.
|
return var_gamma+ (gamma* (np.log(self.variance)-1. +var_mean + var_S)).sum()/2.
|
||||||
|
|
||||||
def update_gradients_KL(self, variational_posterior):
|
def update_gradients_KL(self, variational_posterior):
|
||||||
mu = variational_posterior.mean
|
mu = variational_posterior.mean
|
||||||
S = variational_posterior.variance
|
S = variational_posterior.variance
|
||||||
gamma = variational_posterior.binary_prob
|
gamma,gamma1 = variational_posterior.gamma_probabilities()
|
||||||
|
log_gamma,log_gamma1 = variational_posterior.gamma_log_prob()
|
||||||
if len(self.pi.shape)==2:
|
if len(self.pi.shape)==2:
|
||||||
idx = np.unique(gamma._raveled_index()/gamma.shape[-1])
|
idx = np.unique(gamma._raveled_index()/gamma.shape[-1])
|
||||||
pi = self.pi[idx]
|
pi = self.pi[idx]
|
||||||
else:
|
else:
|
||||||
pi = self.pi
|
pi = self.pi
|
||||||
|
|
||||||
gamma.gradient -= np.log((1-pi)/pi*gamma/(1.-gamma))+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2.
|
variational_posterior.binary_prob.gradient -= (np.log((1-pi)/pi)+log_gamma-log_gamma1+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2.)*gamma*gamma1
|
||||||
mu.gradient -= gamma*mu/self.variance
|
mu.gradient -= gamma*mu/self.variance
|
||||||
S.gradient -= (1./self.variance - 1./S) * gamma /2.
|
S.gradient -= (1./self.variance - 1./S) * gamma /2.
|
||||||
if self.learnPi:
|
if self.learnPi:
|
||||||
|
|
@ -158,8 +162,24 @@ class SpikeAndSlabPosterior(VariationalPosterior):
|
||||||
binary_prob : the probability of the distribution on the slab part.
|
binary_prob : the probability of the distribution on the slab part.
|
||||||
"""
|
"""
|
||||||
super(SpikeAndSlabPosterior, self).__init__(means, variances, name)
|
super(SpikeAndSlabPosterior, self).__init__(means, variances, name)
|
||||||
self.gamma = Param("binary_prob",binary_prob, Logistic(1e-10,1.-1e-10))
|
self.gamma = Param("binary_prob",binary_prob)
|
||||||
self.link_parameter(self.gamma)
|
self.link_parameter(self.gamma)
|
||||||
|
|
||||||
|
@Cache_this(limit=5)
|
||||||
|
def gamma_probabilities(self):
|
||||||
|
prob = np.zeros_like(param_to_array(self.gamma))
|
||||||
|
prob[self.gamma>-710] = 1./(1.+np.exp(-self.gamma[self.gamma>-710]))
|
||||||
|
prob1 = np.zeros_like(param_to_array(self.gamma))
|
||||||
|
prob1[self.gamma<710] = 1./(1.+np.exp(self.gamma[self.gamma<710]))
|
||||||
|
return prob, prob1
|
||||||
|
|
||||||
|
@Cache_this(limit=5)
|
||||||
|
def gamma_log_prob(self):
|
||||||
|
loggamma = param_to_array(self.gamma).copy()
|
||||||
|
loggamma[loggamma>-40] = -np.log1p(np.exp(-loggamma[loggamma>-40]))
|
||||||
|
loggamma1 = param_to_array(self.gamma).copy()
|
||||||
|
loggamma1[loggamma1<40] = -np.log1p(np.exp(loggamma1[loggamma1<40]))
|
||||||
|
return loggamma,loggamma1
|
||||||
|
|
||||||
def set_gradients(self, grad):
|
def set_gradients(self, grad):
|
||||||
self.mean.gradient, self.variance.gradient, self.gamma.gradient = grad
|
self.mean.gradient, self.variance.gradient, self.gamma.gradient = grad
|
||||||
|
|
|
||||||
|
|
@ -22,14 +22,12 @@ try:
|
||||||
# _psi1 NxM
|
# _psi1 NxM
|
||||||
mu = variational_posterior.mean
|
mu = variational_posterior.mean
|
||||||
S = variational_posterior.variance
|
S = variational_posterior.variance
|
||||||
gamma = variational_posterior.binary_prob
|
|
||||||
|
|
||||||
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
|
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
|
||||||
l2 = np.square(lengthscale)
|
l2 = np.square(lengthscale)
|
||||||
log_denom1 = np.log(S/l2+1)
|
log_denom1 = np.log(S/l2+1)
|
||||||
log_denom2 = np.log(2*S/l2+1)
|
log_denom2 = np.log(2*S/l2+1)
|
||||||
log_gamma = np.log(gamma)
|
log_gamma,log_gamma1 = variational_posterior.gamma_log_prob()
|
||||||
log_gamma1 = np.log(1.-gamma)
|
|
||||||
variance = float(variance)
|
variance = float(variance)
|
||||||
psi0 = np.empty(N)
|
psi0 = np.empty(N)
|
||||||
psi0[:] = variance
|
psi0[:] = variance
|
||||||
|
|
@ -39,7 +37,6 @@ try:
|
||||||
from ....util.misc import param_to_array
|
from ....util.misc import param_to_array
|
||||||
S = param_to_array(S)
|
S = param_to_array(S)
|
||||||
mu = param_to_array(mu)
|
mu = param_to_array(mu)
|
||||||
gamma = param_to_array(gamma)
|
|
||||||
Z = param_to_array(Z)
|
Z = param_to_array(Z)
|
||||||
|
|
||||||
support_code = """
|
support_code = """
|
||||||
|
|
@ -82,7 +79,7 @@ try:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz)
|
weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz)
|
||||||
|
|
||||||
psi2 = psi2n.sum(axis=0)
|
psi2 = psi2n.sum(axis=0)
|
||||||
return psi0,psi1,psi2,psi2n
|
return psi0,psi1,psi2,psi2n
|
||||||
|
|
@ -97,13 +94,12 @@ try:
|
||||||
|
|
||||||
mu = variational_posterior.mean
|
mu = variational_posterior.mean
|
||||||
S = variational_posterior.variance
|
S = variational_posterior.variance
|
||||||
gamma = variational_posterior.binary_prob
|
|
||||||
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
|
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
|
||||||
l2 = np.square(lengthscale)
|
l2 = np.square(lengthscale)
|
||||||
log_denom1 = np.log(S/l2+1)
|
log_denom1 = np.log(S/l2+1)
|
||||||
log_denom2 = np.log(2*S/l2+1)
|
log_denom2 = np.log(2*S/l2+1)
|
||||||
log_gamma = np.log(gamma)
|
log_gamma,log_gamma1 = variational_posterior.gamma_log_prob()
|
||||||
log_gamma1 = np.log(1.-gamma)
|
gamma, gamma1 = variational_posterior.gamma_probabilities()
|
||||||
variance = float(variance)
|
variance = float(variance)
|
||||||
|
|
||||||
dvar = np.zeros(1)
|
dvar = np.zeros(1)
|
||||||
|
|
@ -117,7 +113,6 @@ try:
|
||||||
from ....util.misc import param_to_array
|
from ....util.misc import param_to_array
|
||||||
S = param_to_array(S)
|
S = param_to_array(S)
|
||||||
mu = param_to_array(mu)
|
mu = param_to_array(mu)
|
||||||
gamma = param_to_array(gamma)
|
|
||||||
Z = param_to_array(Z)
|
Z = param_to_array(Z)
|
||||||
|
|
||||||
support_code = """
|
support_code = """
|
||||||
|
|
@ -135,6 +130,7 @@ try:
|
||||||
double Zm1q = Z(m1,q);
|
double Zm1q = Z(m1,q);
|
||||||
double Zm2q = Z(m2,q);
|
double Zm2q = Z(m2,q);
|
||||||
double gnq = gamma(n,q);
|
double gnq = gamma(n,q);
|
||||||
|
double g1nq = gamma1(n,q);
|
||||||
double mu_nq = mu(n,q);
|
double mu_nq = mu(n,q);
|
||||||
|
|
||||||
if(m2==0) {
|
if(m2==0) {
|
||||||
|
|
@ -160,7 +156,7 @@ try:
|
||||||
|
|
||||||
dmu(n,q) += lpsi1*Zmu*d_exp1/(denom*exp_sum);
|
dmu(n,q) += lpsi1*Zmu*d_exp1/(denom*exp_sum);
|
||||||
dS(n,q) += lpsi1*(Zmu2_denom-1.)*d_exp1/(denom*exp_sum)/2.;
|
dS(n,q) += lpsi1*(Zmu2_denom-1.)*d_exp1/(denom*exp_sum)/2.;
|
||||||
dgamma(n,q) += lpsi1*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum;
|
dgamma(n,q) += lpsi1*(d_exp1*g1nq-d_exp2*gnq)/exp_sum;
|
||||||
dl(q) += lpsi1*((Zmu2_denom+Snq/lq)/denom*d_exp1+Zm1q*Zm1q/(lq*lq)*d_exp2)/(2.*exp_sum);
|
dl(q) += lpsi1*((Zmu2_denom+Snq/lq)/denom*d_exp1+Zm1q*Zm1q/(lq*lq)*d_exp2)/(2.*exp_sum);
|
||||||
dZ(m1,q) += lpsi1*(-Zmu/denom*d_exp1-Zm1q/lq*d_exp2)/exp_sum;
|
dZ(m1,q) += lpsi1*(-Zmu/denom*d_exp1-Zm1q/lq*d_exp2)/exp_sum;
|
||||||
}
|
}
|
||||||
|
|
@ -188,7 +184,7 @@ try:
|
||||||
|
|
||||||
dmu(n,q) += -2.*lpsi2*muZhat/denom*d_exp1/exp_sum;
|
dmu(n,q) += -2.*lpsi2*muZhat/denom*d_exp1/exp_sum;
|
||||||
dS(n,q) += lpsi2*(2.*muZhat2_denom-1.)/denom*d_exp1/exp_sum;
|
dS(n,q) += lpsi2*(2.*muZhat2_denom-1.)/denom*d_exp1/exp_sum;
|
||||||
dgamma(n,q) += lpsi2*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum;
|
dgamma(n,q) += lpsi2*(d_exp1*g1nq-d_exp2*gnq)/exp_sum;
|
||||||
dl(q) += lpsi2*(((Snq/lq+muZhat2_denom)/denom+dZm1m2*dZm1m2/(4.*lq*lq))*d_exp1+Z2/(2.*lq*lq)*d_exp2)/exp_sum;
|
dl(q) += lpsi2*(((Snq/lq+muZhat2_denom)/denom+dZm1m2*dZm1m2/(4.*lq*lq))*d_exp1+Z2/(2.*lq*lq)*d_exp2)/exp_sum;
|
||||||
dZ(m1,q) += 2.*lpsi2*((muZhat/denom-dZm1m2/(2*lq))*d_exp1-Zm1q/lq*d_exp2)/exp_sum;
|
dZ(m1,q) += 2.*lpsi2*((muZhat/denom-dZm1m2/(2*lq))*d_exp1-Zm1q/lq*d_exp2)/exp_sum;
|
||||||
}
|
}
|
||||||
|
|
@ -196,7 +192,7 @@ try:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz)
|
weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','gamma1','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz)
|
||||||
|
|
||||||
dl *= 2.*lengthscale
|
dl *= 2.*lengthscale
|
||||||
if not ARD:
|
if not ARD:
|
||||||
|
|
|
||||||
|
|
@ -39,10 +39,7 @@ class SSGPLVM(SparseGP_MPI):
|
||||||
X_variance = np.random.uniform(0,.1,X.shape)
|
X_variance = np.random.uniform(0,.1,X.shape)
|
||||||
|
|
||||||
if Gamma is None:
|
if Gamma is None:
|
||||||
gamma = np.empty_like(X) # The posterior probabilities of the binary variable in the variational approximation
|
gamma = np.random.randn(X.shape[0], input_dim)
|
||||||
gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim)
|
|
||||||
gamma[gamma>1.-1e-9] = 1.-1e-9
|
|
||||||
gamma[gamma<1e-9] = 1e-9
|
|
||||||
else:
|
else:
|
||||||
gamma = Gamma.copy()
|
gamma = Gamma.copy()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue