From 11f872db33ea37661f5cc74a5ad8fbf7ff3f0e72 Mon Sep 17 00:00:00 2001 From: Zhenwen Dai Date: Thu, 20 Nov 2014 14:52:49 +0000 Subject: [PATCH] performance improvement for sslinear kernel --- GPy/kern/_src/psi_comp/sslinear_psi_comp.py | 23 ++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/GPy/kern/_src/psi_comp/sslinear_psi_comp.py b/GPy/kern/_src/psi_comp/sslinear_psi_comp.py index b505f26f..5f261785 100644 --- a/GPy/kern/_src/psi_comp/sslinear_psi_comp.py +++ b/GPy/kern/_src/psi_comp/sslinear_psi_comp.py @@ -5,6 +5,8 @@ The package for the Psi statistics computation of the linear kernel for SSGPLVM """ +from ....util.linalg import tdot + import numpy as np def psicomputations(variance, Z, variational_posterior): @@ -20,10 +22,9 @@ def psicomputations(variance, Z, variational_posterior): S = variational_posterior.variance gamma = variational_posterior.binary_prob - psi0 = np.einsum('q,nq,nq->n',variance,gamma,np.square(mu)+S) - psi1 = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) - psi2 = np.einsum('nq,q,mq,oq,nq->mo',gamma,np.square(variance),Z,Z,(1-gamma)*np.square(mu)+S) +\ - np.einsum('nm,no->mo',psi1,psi1) + psi0 = (gamma*(np.square(mu)+S)*variance).sum(axis=-1) + psi1 = np.inner(variance*gamma*mu,Z) + psi2 = np.inner(np.square(variance)*(gamma*((1-gamma)*np.square(mu)+S)).sum(axis=0)*Z,Z)+tdot(psi1.T) return psi0, psi1, psi2 @@ -63,9 +64,16 @@ def _psi2computations(dL_dpsi2, variance, Z, mu, S, gamma): gamma2 = np.square(gamma) variance2 = np.square(variance) mu2S = mu2+S # NxQ - common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM + gvm = np.einsum('nq,nq,q->nq',gamma,mu,variance) + common_sum = np.einsum('nq,mq->nm',gvm,Z) +# common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM Z_expect = np.einsum('mo,mq,oq->q',dL_dpsi2,Z,Z) - common_expect = np.einsum('mo,mq,no->nq',dL_dpsi2+dL_dpsi2.T,Z,common_sum) + dL_dpsi2T = dL_dpsi2+dL_dpsi2.T + tmp = np.einsum('mo,oq->mq',dL_dpsi2T,Z) + common_expect = np.einsum('mq,nm->nq',tmp,common_sum) +# common_expect = np.einsum('mo,mq,no->nq',dL_dpsi2+dL_dpsi2.T,Z,common_sum) + Z2_expect = np.einsum('om,nm->no',dL_dpsi2T,common_sum) + Z1_expect = np.einsum('om,mq->oq',dL_dpsi2T,Z) dL_dvar = np.einsum('nq,q,q->q',2.*(gamma*mu2S-gamma2*mu2),variance,Z_expect)+\ np.einsum('nq,nq,nq->q',common_expect,gamma,mu) @@ -78,6 +86,7 @@ def _psi2computations(dL_dpsi2, variance, Z, mu, S, gamma): dL_dS = np.einsum('q,nq,q->nq',Z_expect,gamma,variance2) - dL_dZ = 2.*(np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma,variance2,Z,(mu2S-gamma*mu2))+np.einsum('om,nq,q,nq,nm->oq',dL_dpsi2,gamma,variance,mu,common_sum)) +# dL_dZ = 2.*(np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma,variance2,Z,(mu2S-gamma*mu2))+np.einsum('om,nq,q,nq,nm->oq',dL_dpsi2,gamma,variance,mu,common_sum)) + dL_dZ = Z1_expect*np.einsum('nq,q,nq->q',gamma,variance2,(mu2S-gamma*mu2))+np.einsum('nq,q,nq,nm->mq',gamma,variance,mu,Z2_expect) return dL_dvar, dL_dgamma, dL_dmu, dL_dS, dL_dZ