new ssrbf implementation

This commit is contained in:
Zhenwen Dai 2014-11-10 16:24:24 +00:00
parent c7d0bd2204
commit 22d30d9d39

View file

@ -7,6 +7,205 @@ The package for the psi statistics computation
import numpy as np import numpy as np
try:
from scipy import weave
def _psicomputations(variance, lengthscale, Z, variational_posterior):
    """
    Compute the psi statistics with an inlined C++ (scipy.weave) kernel.

    Z - MxQ
    mu - NxQ
    S - NxQ
    gamma - NxQ

    mu, S and gamma are read from variational_posterior (its mean, variance
    and binary_prob attributes). lengthscale may hold either Q entries or a
    single entry shared by all Q dimensions.

    Returns the tuple (psi0, psi1, psi2, psi2n):
    psi0 - N, every entry equal to variance
    psi1 - NxM
    psi2 - MxM, psi2n summed over its first axis
    psi2n - NxMxM
    """
    # here are the "statistics" for psi0, psi1 and psi2
    # Produced intermediate results:
    # _psi1 NxM
    mu = variational_posterior.mean
    S = variational_posterior.variance
    gamma = variational_posterior.binary_prob
    N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]

    # The C++ kernel below reads l2(q) for every q in [0, Q). With a single
    # shared lengthscale np.square(lengthscale) has only one element, so the
    # kernel would read past the end of the buffer. Materialise a plain
    # length-Q array instead; the assignment broadcasts a single value and
    # copies element-wise when Q values are given.
    l2 = np.empty(Q)
    l2[:] = np.square(lengthscale)

    log_denom1 = np.log(S/l2+1)
    log_denom2 = np.log(2*S/l2+1)
    log_gamma = np.log(gamma)
    log_gamma1 = np.log(1.-gamma)
    variance = float(variance)

    # Output buffers filled in by the C++ kernel (psi0 is filled here).
    psi0 = np.empty(N)
    psi0[:] = variance
    psi1 = np.empty((N,M))
    psi2n = np.empty((N,M,M))

    from ....util.misc import param_to_array
    S = param_to_array(S)
    mu = param_to_array(mu)
    gamma = param_to_array(gamma)
    Z = param_to_array(Z)

    support_code = """
    #include <math.h>
    """
    # Kernel notes:
    # * Each per-dimension factor is a sum of two exponentials; it is
    #   accumulated in log space as max + log1p(exp(min - max)).
    # * psi1 does not depend on m2, so its log is accumulated only on the
    #   m2 == 0 pass of the inner loop.
    # * Only m2 <= m1 is visited; psi2n(n,m2,m1) is filled by mirroring.
    code = """
    for(int n=0; n<N; n++) {
        for(int m1=0;m1<M;m1++) {
            double log_psi1=0;
            for(int m2=0;m2<=m1;m2++) {
                double log_psi2_n=0;
                for(int q=0;q<Q;q++) {
                    double Snq = S(n,q);
                    double lq = l2(q);
                    double Zm1q = Z(m1,q);
                    double Zm2q = Z(m2,q);
                    if(m2==0) {
                        // Compute Psi_1
                        double muZ = mu(n,q)-Z(m1,q);
                        double psi1_exp1 = log_gamma(n,q) - (muZ*muZ/(Snq+lq) +log_denom1(n,q))/2.;
                        double psi1_exp2 = log_gamma1(n,q) -Zm1q*Zm1q/(2.*lq);
                        log_psi1 += (psi1_exp1>psi1_exp2)?psi1_exp1+log1p(exp(psi1_exp2-psi1_exp1)):psi1_exp2+log1p(exp(psi1_exp1-psi1_exp2));
                    }
                    // Compute Psi_2
                    double muZhat = mu(n,q) - (Zm1q+Zm2q)/2.;
                    double Z2 = Zm1q*Zm1q+ Zm2q*Zm2q;
                    double dZ = Zm1q - Zm2q;
                    double psi2_exp1 = dZ*dZ/(-4.*lq)-muZhat*muZhat/(2.*Snq+lq) - log_denom2(n,q)/2. + log_gamma(n,q);
                    double psi2_exp2 = log_gamma1(n,q) - Z2/(2.*lq);
                    log_psi2_n += (psi2_exp1>psi2_exp2)?psi2_exp1+log1p(exp(psi2_exp2-psi2_exp1)):psi2_exp2+log1p(exp(psi2_exp1-psi2_exp2));
                }
                double exp_psi2_n = exp(log_psi2_n);
                psi2n(n,m1,m2) = variance*variance*exp_psi2_n;
                if(m1!=m2) { psi2n(n,m2,m1) = variance*variance*exp_psi2_n;}
            }
            psi1(n,m1) = variance*exp(log_psi1);
        }
    }
    """
    weave.inline(code, support_code=support_code, arg_names=['psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1'], type_converters=weave.converters.blitz)

    psi2 = psi2n.sum(axis=0)
    return psi0,psi1,psi2,psi2n
from GPy.util.caching import Cacher
# Public entry point: _psicomputations wrapped in a Cacher with limit=1.
# NOTE(review): presumably this keeps the most recent result so that
# psiDerivativecomputations can reuse psi1/psi2n without recomputing --
# confirm against GPy.util.caching.Cacher.
psicomputations = Cacher(_psicomputations, limit=1)
def psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, variational_posterior):
    """
    Accumulate the chain-rule contributions of the psi statistics to the
    gradients, using an inlined C++ (scipy.weave) kernel.

    dL_dpsi0 - summed as a whole into the variance gradient
    dL_dpsi1 - NxM (indexed as (n, m1) in the kernel)
    dL_dpsi2 - MxM (indexed as (m1, m2) in the kernel)
    Z - MxQ
    mu, S, gamma - NxQ, read from variational_posterior (mean, variance,
    binary_prob)
    lengthscale - Q entries (ARD) or a single shared entry

    Returns the tuple (dvar, dl, dZ, dmu, dS, dgamma):
    dvar - array of length 1
    dl - length Q when ARD, otherwise the sum over all Q dimensions
    dZ - MxQ
    dmu, dS, dgamma - NxQ
    """
    ARD = (len(lengthscale)!=1)

    # psi1 and psi2n are reused from the forward computation.
    _,psi1,_,psi2n = psicomputations(variance, lengthscale, Z, variational_posterior)

    mu = variational_posterior.mean
    S = variational_posterior.variance
    gamma = variational_posterior.binary_prob
    N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]

    # The C++ kernel below reads l2(q) for every q in [0, Q). With a single
    # shared lengthscale np.square(lengthscale) has only one element, so the
    # kernel would read past the end of the buffer. Materialise a plain
    # length-Q array instead; the assignment broadcasts a single value and
    # copies element-wise when Q values are given.
    l2 = np.empty(Q)
    l2[:] = np.square(lengthscale)

    log_denom1 = np.log(S/l2+1)
    log_denom2 = np.log(2*S/l2+1)
    log_gamma = np.log(gamma)
    log_gamma1 = np.log(1.-gamma)
    variance = float(variance)

    # Gradient accumulators; the C++ kernel adds into them in place.
    dvar = np.zeros(1)
    dmu = np.zeros((N,Q))
    dS = np.zeros((N,Q))
    dgamma = np.zeros((N,Q))
    dl = np.zeros(Q)
    dZ = np.zeros((M,Q))

    # psi0 is filled with variance, so its contribution is the sum of dL_dpsi0.
    dvar += np.sum(dL_dpsi0)

    from ....util.misc import param_to_array
    S = param_to_array(S)
    mu = param_to_array(mu)
    gamma = param_to_array(gamma)
    Z = param_to_array(Z)

    support_code = """
    #include <math.h>
    """
    # Kernel notes:
    # * d_exp1/d_exp2 are the two exponentials of each per-dimension factor,
    #   rescaled by the larger one; dividing by exp_sum gives their weights.
    # * The psi1 terms are added only on the m2 == 0 pass, once per (n, m1, q).
    # * dl is accumulated with respect to the squared lengthscale; the factor
    #   2*lengthscale is applied after the kernel returns.
    code = """
    for(int n=0; n<N; n++) {
        for(int m1=0;m1<M;m1++) {
            double log_psi1=0;
            for(int m2=0;m2<M;m2++) {
                double log_psi2_n=0;
                for(int q=0;q<Q;q++) {
                    double Snq = S(n,q);
                    double lq = l2(q);
                    double Zm1q = Z(m1,q);
                    double Zm2q = Z(m2,q);
                    double gnq = gamma(n,q);
                    double mu_nq = mu(n,q);
                    if(m2==0) {
                        // Compute Psi_1
                        double lpsi1 = psi1(n,m1)*dL_dpsi1(n,m1);
                        if(q==0) {dvar(0) += lpsi1/variance;}
                        double Zmu = Zm1q - mu_nq;
                        double denom = Snq+lq;
                        double Zmu2_denom = Zmu*Zmu/denom;
                        double exp1 = log_gamma(n,q)-(Zmu*Zmu/(Snq+lq)+log_denom1(n,q))/(2.);
                        double exp2 = log_gamma1(n,q)-Zm1q*Zm1q/(2.*lq);
                        double d_exp1,d_exp2;
                        if(exp1>exp2) {
                            d_exp1 = 1.;
                            d_exp2 = exp(exp2-exp1);
                        } else {
                            d_exp1 = exp(exp1-exp2);
                            d_exp2 = 1.;
                        }
                        double exp_sum = d_exp1+d_exp2;
                        dmu(n,q) += lpsi1*Zmu*d_exp1/(denom*exp_sum);
                        dS(n,q) += lpsi1*(Zmu2_denom-1.)*d_exp1/(denom*exp_sum)/2.;
                        dgamma(n,q) += lpsi1*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum;
                        dl(q) += lpsi1*((Zmu2_denom+Snq/lq)/denom*d_exp1+Zm1q*Zm1q/(lq*lq)*d_exp2)/(2.*exp_sum);
                        dZ(m1,q) += lpsi1*(-Zmu/denom*d_exp1-Zm1q/lq*d_exp2)/exp_sum;
                    }
                    // Compute Psi_2
                    double lpsi2 = psi2n(n,m1,m2)*dL_dpsi2(m1,m2);
                    if(q==0) {dvar(0) += lpsi2*2/variance;}
                    double dZm1m2 = Zm1q - Zm2q;
                    double Z2 = Zm1q*Zm1q+Zm2q*Zm2q;
                    double muZhat = mu_nq - (Zm1q + Zm2q)/2.;
                    double denom = 2.*Snq+lq;
                    double muZhat2_denom = muZhat*muZhat/denom;
                    double exp1 = dZm1m2*dZm1m2/(-4.*lq)-muZhat*muZhat/(2.*Snq+lq) - log_denom2(n,q)/2. + log_gamma(n,q);
                    double exp2 = log_gamma1(n,q) - Z2/(2.*lq);
                    double d_exp1,d_exp2;
                    if(exp1>exp2) {
                        d_exp1 = 1.;
                        d_exp2 = exp(exp2-exp1);
                    } else {
                        d_exp1 = exp(exp1-exp2);
                        d_exp2 = 1.;
                    }
                    double exp_sum = d_exp1+d_exp2;
                    dmu(n,q) += -2.*lpsi2*muZhat/denom*d_exp1/exp_sum;
                    dS(n,q) += lpsi2*(2.*muZhat2_denom-1.)/denom*d_exp1/exp_sum;
                    dgamma(n,q) += lpsi2*(d_exp1/gnq-d_exp2/(1.-gnq))/exp_sum;
                    dl(q) += lpsi2*(((Snq/lq+muZhat2_denom)/denom+dZm1m2*dZm1m2/(4.*lq*lq))*d_exp1+Z2/(2.*lq*lq)*d_exp2)/exp_sum;
                    dZ(m1,q) += 2.*lpsi2*((muZhat/denom-dZm1m2/(2*lq))*d_exp1-Zm1q/lq*d_exp2)/exp_sum;
                }
            }
        }
    }
    """
    weave.inline(code, support_code=support_code, arg_names=['dL_dpsi1','dL_dpsi2','psi1','psi2n','N','M','Q','variance','l2','Z','mu','S','gamma','log_denom1','log_denom2','log_gamma','log_gamma1','dvar','dl','dmu','dS','dgamma','dZ'], type_converters=weave.converters.blitz)

    # Chain rule from the squared lengthscale back to the lengthscale itself.
    dl *= 2.*lengthscale
    if not ARD:
        # A shared lengthscale collects the contributions of every dimension.
        dl = dl.sum()

    return dvar, dl, dZ, dmu, dS, dgamma
except:
def psicomputations(variance, lengthscale, Z, variational_posterior): def psicomputations(variance, lengthscale, Z, variational_posterior):
""" """
Z - MxQ Z - MxQ