mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-07 11:02:38 +02:00
fallback the implementation of spike and slab prior
This commit is contained in:
parent
4f0894b6b7
commit
edbb576bfc
4 changed files with 27 additions and 36 deletions
|
|
@ -22,12 +22,14 @@ try:
|
|||
# _psi1 NxM
|
||||
mu = variational_posterior.mean
|
||||
S = variational_posterior.variance
|
||||
gamma = variational_posterior.binary_prob
|
||||
|
||||
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
|
||||
l2 = np.square(lengthscale)
|
||||
log_denom1 = np.log(S/l2+1)
|
||||
log_denom2 = np.log(2*S/l2+1)
|
||||
log_gamma,log_gamma1 = variational_posterior.gamma_log_prob()
|
||||
log_gamma = np.log(gamma)
|
||||
log_gamma1 = np.log(1.-gamma)
|
||||
variance = float(variance)
|
||||
psi0 = np.empty(N)
|
||||
psi0[:] = variance
|
||||
|
|
@ -37,6 +39,7 @@ try:
|
|||
from ....util.misc import param_to_array
|
||||
S = param_to_array(S)
|
||||
mu = param_to_array(mu)
|
||||
gamma = param_to_array(gamma)
|
||||
Z = param_to_array(Z)
|
||||
|
||||
support_code = """
|
||||
|
|
@ -79,7 +82,7 @@ try:
|
|||
}
|
||||
}
|
||||
"""
|
||||
weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz)
|
||||
weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz)
|
||||
|
||||
psi2 = psi2n.sum(axis=0)
|
||||
return psi0,psi1,psi2,psi2n
|
||||
|
|
@ -94,12 +97,13 @@ try:
|
|||
|
||||
mu = variational_posterior.mean
|
||||
S = variational_posterior.variance
|
||||
gamma = variational_posterior.binary_prob
|
||||
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
|
||||
l2 = np.square(lengthscale)
|
||||
log_denom1 = np.log(S/l2+1)
|
||||
log_denom2 = np.log(2*S/l2+1)
|
||||
log_gamma,log_gamma1 = variational_posterior.gamma_log_prob()
|
||||
gamma, gamma1 = variational_posterior.gamma_probabilities()
|
||||
log_gamma = np.log(gamma)
|
||||
log_gamma1 = np.log(1.-gamma)
|
||||
variance = float(variance)
|
||||
|
||||
dvar = np.zeros(1)
|
||||
|
|
@ -113,6 +117,7 @@ try:
|
|||
from ....util.misc import param_to_array
|
||||
S = param_to_array(S)
|
||||
mu = param_to_array(mu)
|
||||
gamma = param_to_array(gamma)
|
||||
Z = param_to_array(Z)
|
||||
|
||||
support_code = """
|
||||
|
|
@ -130,7 +135,6 @@ try:
|
|||
double Zm1q = Z(m1,q);
|
||||
double Zm2q = Z(m2,q);
|
||||
double gnq = gamma(n,q);
|
||||
double g1nq = gamma1(n,q);
|
||||
double mu_nq = mu(n,q);
|
||||
|
||||
if(m2==0) {
|
||||
|
|
@ -156,7 +160,7 @@ try:
|
|||
|
||||
dmu(n,q) += lpsi1*Zmu*d_exp1/(denom*exp_sum);
|
||||
dS(n,q) += lpsi1*(Zmu2_denom-1.)*d_exp1/(denom*exp_sum)/2.;
|
||||
dgamma(n,q) += lpsi1*(d_exp1*g1nq-d_exp2*gnq)/exp_sum;
|
||||
dgamma(n,q) += lpsi1*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum;
|
||||
dl(q) += lpsi1*((Zmu2_denom+Snq/lq)/denom*d_exp1+Zm1q*Zm1q/(lq*lq)*d_exp2)/(2.*exp_sum);
|
||||
dZ(m1,q) += lpsi1*(-Zmu/denom*d_exp1-Zm1q/lq*d_exp2)/exp_sum;
|
||||
}
|
||||
|
|
@ -184,7 +188,7 @@ try:
|
|||
|
||||
dmu(n,q) += -2.*lpsi2*muZhat/denom*d_exp1/exp_sum;
|
||||
dS(n,q) += lpsi2*(2.*muZhat2_denom-1.)/denom*d_exp1/exp_sum;
|
||||
dgamma(n,q) += lpsi2*(d_exp1*g1nq-d_exp2*gnq)/exp_sum;
|
||||
dgamma(n,q) += lpsi2*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum;
|
||||
dl(q) += lpsi2*(((Snq/lq+muZhat2_denom)/denom+dZm1m2*dZm1m2/(4.*lq*lq))*d_exp1+Z2/(2.*lq*lq)*d_exp2)/exp_sum;
|
||||
dZ(m1,q) += 2.*lpsi2*((muZhat/denom-dZm1m2/(2*lq))*d_exp1-Zm1q/lq*d_exp2)/exp_sum;
|
||||
}
|
||||
|
|
@ -192,7 +196,7 @@ try:
|
|||
}
|
||||
}
|
||||
"""
|
||||
weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','gamma1','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz)
|
||||
weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz)
|
||||
|
||||
dl *= 2.*lengthscale
|
||||
if not ARD:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue