performance improvement for sslinear kernel

This commit is contained in:
Zhenwen Dai 2014-11-20 14:52:49 +00:00
parent aa473178f7
commit 11f872db33

View file

@ -5,6 +5,8 @@
The package for the Psi statistics computation of the linear kernel for SSGPLVM The package for the Psi statistics computation of the linear kernel for SSGPLVM
""" """
from ....util.linalg import tdot
import numpy as np import numpy as np
def psicomputations(variance, Z, variational_posterior): def psicomputations(variance, Z, variational_posterior):
@ -20,10 +22,9 @@ def psicomputations(variance, Z, variational_posterior):
S = variational_posterior.variance S = variational_posterior.variance
gamma = variational_posterior.binary_prob gamma = variational_posterior.binary_prob
psi0 = np.einsum('q,nq,nq->n',variance,gamma,np.square(mu)+S) psi0 = (gamma*(np.square(mu)+S)*variance).sum(axis=-1)
psi1 = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) psi1 = np.inner(variance*gamma*mu,Z)
psi2 = np.einsum('nq,q,mq,oq,nq->mo',gamma,np.square(variance),Z,Z,(1-gamma)*np.square(mu)+S) +\ psi2 = np.inner(np.square(variance)*(gamma*((1-gamma)*np.square(mu)+S)).sum(axis=0)*Z,Z)+tdot(psi1.T)
np.einsum('nm,no->mo',psi1,psi1)
return psi0, psi1, psi2 return psi0, psi1, psi2
@ -63,9 +64,16 @@ def _psi2computations(dL_dpsi2, variance, Z, mu, S, gamma):
gamma2 = np.square(gamma) gamma2 = np.square(gamma)
variance2 = np.square(variance) variance2 = np.square(variance)
mu2S = mu2+S # NxQ mu2S = mu2+S # NxQ
common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM gvm = np.einsum('nq,nq,q->nq',gamma,mu,variance)
common_sum = np.einsum('nq,mq->nm',gvm,Z)
# common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM
Z_expect = np.einsum('mo,mq,oq->q',dL_dpsi2,Z,Z) Z_expect = np.einsum('mo,mq,oq->q',dL_dpsi2,Z,Z)
common_expect = np.einsum('mo,mq,no->nq',dL_dpsi2+dL_dpsi2.T,Z,common_sum) dL_dpsi2T = dL_dpsi2+dL_dpsi2.T
tmp = np.einsum('mo,oq->mq',dL_dpsi2T,Z)
common_expect = np.einsum('mq,nm->nq',tmp,common_sum)
# common_expect = np.einsum('mo,mq,no->nq',dL_dpsi2+dL_dpsi2.T,Z,common_sum)
Z2_expect = np.einsum('om,nm->no',dL_dpsi2T,common_sum)
Z1_expect = np.einsum('om,mq->oq',dL_dpsi2T,Z)
dL_dvar = np.einsum('nq,q,q->q',2.*(gamma*mu2S-gamma2*mu2),variance,Z_expect)+\ dL_dvar = np.einsum('nq,q,q->q',2.*(gamma*mu2S-gamma2*mu2),variance,Z_expect)+\
np.einsum('nq,nq,nq->q',common_expect,gamma,mu) np.einsum('nq,nq,nq->q',common_expect,gamma,mu)
@ -78,6 +86,7 @@ def _psi2computations(dL_dpsi2, variance, Z, mu, S, gamma):
dL_dS = np.einsum('q,nq,q->nq',Z_expect,gamma,variance2) dL_dS = np.einsum('q,nq,q->nq',Z_expect,gamma,variance2)
dL_dZ = 2.*(np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma,variance2,Z,(mu2S-gamma*mu2))+np.einsum('om,nq,q,nq,nm->oq',dL_dpsi2,gamma,variance,mu,common_sum)) # dL_dZ = 2.*(np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma,variance2,Z,(mu2S-gamma*mu2))+np.einsum('om,nq,q,nq,nm->oq',dL_dpsi2,gamma,variance,mu,common_sum))
dL_dZ = Z1_expect*np.einsum('nq,q,nq->q',gamma,variance2,(mu2S-gamma*mu2))+np.einsum('nq,q,nq,nm->mq',gamma,variance,mu,Z2_expect)
return dL_dvar, dL_dgamma, dL_dmu, dL_dS, dL_dZ return dL_dvar, dL_dgamma, dL_dmu, dL_dS, dL_dZ