performance improvement for sslinear

This commit is contained in:
Zhenwen Dai 2014-11-20 14:51:27 +00:00
parent d8a76b89de
commit ac33a75999

View file

@ -5,6 +5,8 @@
The package for the Psi statistics computation of the linear kernel for SSGPLVM The package for the Psi statistics computation of the linear kernel for SSGPLVM
""" """
from ....util.linalg import tdot
import numpy as np import numpy as np
def psicomputations(variance, Z, variational_posterior): def psicomputations(variance, Z, variational_posterior):
@ -20,10 +22,9 @@ def psicomputations(variance, Z, variational_posterior):
S = variational_posterior.variance S = variational_posterior.variance
gamma = variational_posterior.binary_prob gamma = variational_posterior.binary_prob
psi0 = np.einsum('q,nq,nq->n',variance,gamma,np.square(mu)+S) psi0 = (gamma*(np.square(mu)+S)*variance).sum(axis=-1)
psi1 = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) psi1 = np.inner(variance*gamma*mu,Z)
psi2 = np.einsum('nq,q,mq,oq,nq->mo',gamma,np.square(variance),Z,Z,(1-gamma)*np.square(mu)+S) +\ psi2 = np.inner(np.square(variance)*(gamma*((1-gamma)*np.square(mu)+S)).sum(axis=0)*Z,Z)+tdot(psi1.T)
np.einsum('nm,no->mo',psi1,psi1)
return psi0, psi1, psi2 return psi0, psi1, psi2
@ -63,9 +64,16 @@ def _psi2computations(dL_dpsi2, variance, Z, mu, S, gamma):
gamma2 = np.square(gamma) gamma2 = np.square(gamma)
variance2 = np.square(variance) variance2 = np.square(variance)
mu2S = mu2+S # NxQ mu2S = mu2+S # NxQ
common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM gvm = np.einsum('nq,nq,q->nq',gamma,mu,variance)
common_sum = np.einsum('nq,mq->nm',gvm,Z)
# common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM
Z_expect = np.einsum('mo,mq,oq->q',dL_dpsi2,Z,Z) Z_expect = np.einsum('mo,mq,oq->q',dL_dpsi2,Z,Z)
common_expect = np.einsum('mo,mq,no->nq',dL_dpsi2+dL_dpsi2.T,Z,common_sum) dL_dpsi2T = dL_dpsi2+dL_dpsi2.T
tmp = np.einsum('mo,oq->mq',dL_dpsi2T,Z)
common_expect = np.einsum('mq,nm->nq',tmp,common_sum)
# common_expect = np.einsum('mo,mq,no->nq',dL_dpsi2+dL_dpsi2.T,Z,common_sum)
Z2_expect = np.einsum('om,nm->no',dL_dpsi2T,common_sum)
Z1_expect = np.einsum('om,mq->oq',dL_dpsi2T,Z)
dL_dvar = np.einsum('nq,q,q->q',2.*(gamma*mu2S-gamma2*mu2),variance,Z_expect)+\ dL_dvar = np.einsum('nq,q,q->q',2.*(gamma*mu2S-gamma2*mu2),variance,Z_expect)+\
np.einsum('nq,nq,nq->q',common_expect,gamma,mu) np.einsum('nq,nq,nq->q',common_expect,gamma,mu)
@ -78,6 +86,7 @@ def _psi2computations(dL_dpsi2, variance, Z, mu, S, gamma):
dL_dS = np.einsum('q,nq,q->nq',Z_expect,gamma,variance2) dL_dS = np.einsum('q,nq,q->nq',Z_expect,gamma,variance2)
dL_dZ = 2.*(np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma,variance2,Z,(mu2S-gamma*mu2))+np.einsum('om,nq,q,nq,nm->oq',dL_dpsi2,gamma,variance,mu,common_sum)) # dL_dZ = 2.*(np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma,variance2,Z,(mu2S-gamma*mu2))+np.einsum('om,nq,q,nq,nm->oq',dL_dpsi2,gamma,variance,mu,common_sum))
dL_dZ = Z1_expect*np.einsum('nq,q,nq->q',gamma,variance2,(mu2S-gamma*mu2))+np.einsum('nq,q,nq,nm->mq',gamma,variance,mu,Z2_expect)
return dL_dvar, dL_dgamma, dL_dmu, dL_dS, dL_dZ return dL_dvar, dL_dgamma, dL_dmu, dL_dS, dL_dZ