fallback the implementation of spike and slab prior

This commit is contained in:
Zhenwen Dai 2015-03-30 21:49:02 +01:00
parent 4f0894b6b7
commit edbb576bfc
4 changed files with 27 additions and 36 deletions

View file

@ -22,12 +22,14 @@ try:
# _psi1 NxM
mu = variational_posterior.mean
S = variational_posterior.variance
gamma = variational_posterior.binary_prob
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
l2 = np.square(lengthscale)
log_denom1 = np.log(S/l2+1)
log_denom2 = np.log(2*S/l2+1)
log_gamma,log_gamma1 = variational_posterior.gamma_log_prob()
log_gamma = np.log(gamma)
log_gamma1 = np.log(1.-gamma)
variance = float(variance)
psi0 = np.empty(N)
psi0[:] = variance
@ -37,6 +39,7 @@ try:
from ....util.misc import param_to_array
S = param_to_array(S)
mu = param_to_array(mu)
gamma = param_to_array(gamma)
Z = param_to_array(Z)
support_code = """
@ -79,7 +82,7 @@ try:
}
}
"""
weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz)
weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz)
psi2 = psi2n.sum(axis=0)
return psi0,psi1,psi2,psi2n
@ -94,12 +97,13 @@ try:
mu = variational_posterior.mean
S = variational_posterior.variance
gamma = variational_posterior.binary_prob
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
l2 = np.square(lengthscale)
log_denom1 = np.log(S/l2+1)
log_denom2 = np.log(2*S/l2+1)
log_gamma,log_gamma1 = variational_posterior.gamma_log_prob()
gamma, gamma1 = variational_posterior.gamma_probabilities()
log_gamma = np.log(gamma)
log_gamma1 = np.log(1.-gamma)
variance = float(variance)
dvar = np.zeros(1)
@ -113,6 +117,7 @@ try:
from ....util.misc import param_to_array
S = param_to_array(S)
mu = param_to_array(mu)
gamma = param_to_array(gamma)
Z = param_to_array(Z)
support_code = """
@ -130,7 +135,6 @@ try:
double Zm1q = Z(m1,q);
double Zm2q = Z(m2,q);
double gnq = gamma(n,q);
double g1nq = gamma1(n,q);
double mu_nq = mu(n,q);
if(m2==0) {
@ -156,7 +160,7 @@ try:
dmu(n,q) += lpsi1*Zmu*d_exp1/(denom*exp_sum);
dS(n,q) += lpsi1*(Zmu2_denom-1.)*d_exp1/(denom*exp_sum)/2.;
dgamma(n,q) += lpsi1*(d_exp1*g1nq-d_exp2*gnq)/exp_sum;
dgamma(n,q) += lpsi1*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum;
dl(q) += lpsi1*((Zmu2_denom+Snq/lq)/denom*d_exp1+Zm1q*Zm1q/(lq*lq)*d_exp2)/(2.*exp_sum);
dZ(m1,q) += lpsi1*(-Zmu/denom*d_exp1-Zm1q/lq*d_exp2)/exp_sum;
}
@ -184,7 +188,7 @@ try:
dmu(n,q) += -2.*lpsi2*muZhat/denom*d_exp1/exp_sum;
dS(n,q) += lpsi2*(2.*muZhat2_denom-1.)/denom*d_exp1/exp_sum;
dgamma(n,q) += lpsi2*(d_exp1*g1nq-d_exp2*gnq)/exp_sum;
dgamma(n,q) += lpsi2*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum;
dl(q) += lpsi2*(((Snq/lq+muZhat2_denom)/denom+dZm1m2*dZm1m2/(4.*lq*lq))*d_exp1+Z2/(2.*lq*lq)*d_exp2)/exp_sum;
dZ(m1,q) += 2.*lpsi2*((muZhat/denom-dZm1m2/(2*lq))*d_exp1-Zm1q/lq*d_exp2)/exp_sum;
}
@ -192,7 +196,7 @@ try:
}
}
"""
weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','gamma1','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz)
weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz)
dl *= 2.*lengthscale
if not ARD: