spike and slab binary variable numerical enhancement

This commit is contained in:
Zhenwen Dai 2014-11-12 20:19:56 +00:00
parent 3c642a5600
commit 7f7b0da7b9
3 changed files with 34 additions and 21 deletions

View file

@ -22,14 +22,12 @@ try:
# _psi1 NxM
mu = variational_posterior.mean
S = variational_posterior.variance
gamma = variational_posterior.binary_prob
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
l2 = np.square(lengthscale)
log_denom1 = np.log(S/l2+1)
log_denom2 = np.log(2*S/l2+1)
log_gamma = np.log(gamma)
log_gamma1 = np.log(1.-gamma)
log_gamma,log_gamma1 = variational_posterior.gamma_log_prob()
variance = float(variance)
psi0 = np.empty(N)
psi0[:] = variance
@ -39,7 +37,6 @@ try:
from ....util.misc import param_to_array
S = param_to_array(S)
mu = param_to_array(mu)
gamma = param_to_array(gamma)
Z = param_to_array(Z)
support_code = """
@ -82,7 +79,7 @@ try:
}
}
"""
weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz)
weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz)
psi2 = psi2n.sum(axis=0)
return psi0,psi1,psi2,psi2n
@ -97,13 +94,12 @@ try:
mu = variational_posterior.mean
S = variational_posterior.variance
gamma = variational_posterior.binary_prob
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
l2 = np.square(lengthscale)
log_denom1 = np.log(S/l2+1)
log_denom2 = np.log(2*S/l2+1)
log_gamma = np.log(gamma)
log_gamma1 = np.log(1.-gamma)
log_gamma,log_gamma1 = variational_posterior.gamma_log_prob()
gamma, gamma1 = variational_posterior.gamma_probabilities()
variance = float(variance)
dvar = np.zeros(1)
@ -117,7 +113,6 @@ try:
from ....util.misc import param_to_array
S = param_to_array(S)
mu = param_to_array(mu)
gamma = param_to_array(gamma)
Z = param_to_array(Z)
support_code = """
@ -135,6 +130,7 @@ try:
double Zm1q = Z(m1,q);
double Zm2q = Z(m2,q);
double gnq = gamma(n,q);
double g1nq = gamma1(n,q);
double mu_nq = mu(n,q);
if(m2==0) {
@ -160,7 +156,7 @@ try:
dmu(n,q) += lpsi1*Zmu*d_exp1/(denom*exp_sum);
dS(n,q) += lpsi1*(Zmu2_denom-1.)*d_exp1/(denom*exp_sum)/2.;
dgamma(n,q) += lpsi1*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum;
dgamma(n,q) += lpsi1*(d_exp1*g1nq-d_exp2*gnq)/exp_sum;
dl(q) += lpsi1*((Zmu2_denom+Snq/lq)/denom*d_exp1+Zm1q*Zm1q/(lq*lq)*d_exp2)/(2.*exp_sum);
dZ(m1,q) += lpsi1*(-Zmu/denom*d_exp1-Zm1q/lq*d_exp2)/exp_sum;
}
@ -188,7 +184,7 @@ try:
dmu(n,q) += -2.*lpsi2*muZhat/denom*d_exp1/exp_sum;
dS(n,q) += lpsi2*(2.*muZhat2_denom-1.)/denom*d_exp1/exp_sum;
dgamma(n,q) += lpsi2*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum;
dgamma(n,q) += lpsi2*(d_exp1*g1nq-d_exp2*gnq)/exp_sum;
dl(q) += lpsi2*(((Snq/lq+muZhat2_denom)/denom+dZm1m2*dZm1m2/(4.*lq*lq))*d_exp1+Z2/(2.*lq*lq)*d_exp2)/exp_sum;
dZ(m1,q) += 2.*lpsi2*((muZhat/denom-dZm1m2/(2*lq))*d_exp1-Zm1q/lq*d_exp2)/exp_sum;
}
@ -196,7 +192,7 @@ try:
}
}
"""
weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz)
weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','gamma1','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz)
dl *= 2.*lengthscale
if not ARD: