From abcea702911bcc7b8cf7e9875e6dfe861546efb5 Mon Sep 17 00:00:00 2001 From: Zhenwen Dai Date: Mon, 18 Aug 2014 17:30:59 +0100 Subject: [PATCH] some further performance improvement for linear kernel psi statistics --- GPy/kern/_src/linear.py | 1 - GPy/kern/_src/psi_comp/linear_psi_comp.py | 10 +++++----- GPy/kern/_src/psi_comp/sslinear_psi_comp.py | 21 ++++++++++----------- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/GPy/kern/_src/linear.py b/GPy/kern/_src/linear.py index 41ec3cae..6874c0d6 100644 --- a/GPy/kern/_src/linear.py +++ b/GPy/kern/_src/linear.py @@ -113,7 +113,6 @@ class Linear(Kern): def psi1(self, Z, variational_posterior): return self.psicomp.psicomputations(self.variances, Z, variational_posterior)[1] - @Cache_this(limit=1) def psi2(self, Z, variational_posterior): return self.psicomp.psicomputations(self.variances, Z, variational_posterior)[2] diff --git a/GPy/kern/_src/psi_comp/linear_psi_comp.py b/GPy/kern/_src/psi_comp/linear_psi_comp.py index 94cba0f5..93297e7e 100644 --- a/GPy/kern/_src/psi_comp/linear_psi_comp.py +++ b/GPy/kern/_src/psi_comp/linear_psi_comp.py @@ -56,15 +56,15 @@ def _psi2computations(dL_dpsi2, variance, Z, mu, S): variance2 = np.square(variance) common_sum = np.einsum('q,mq,nq->nm',variance,Z,mu) # NxM - dL_dpsi2_2 = dL_dpsi2+dL_dpsi2.T + Z_expect = np.einsum('mo,mq,oq->q',dL_dpsi2,Z,Z) + common_expect = np.einsum('mo,mq,no->nq',dL_dpsi2+dL_dpsi2.T,Z,common_sum) - dL_dvar = np.einsum('mo,nq,q,mq,oq->q',dL_dpsi2,2.*S,variance,Z,Z)+\ - np.einsum('mo,mq,nq,no->q',dL_dpsi2_2,Z,mu,common_sum) + dL_dvar = np.einsum('q,nq,q->q',Z_expect,2.*S,variance)+ np.einsum('nq,nq->q',common_expect,mu) - dL_dmu = np.einsum('mo,q,mq,no->nq',dL_dpsi2_2,variance,Z,common_sum) + dL_dmu = np.einsum('nq,q->nq',common_expect,variance) dL_dS = np.empty(S.shape) - dL_dS[:] = np.einsum('mo,q,mq,oq->q',dL_dpsi2,variance2,Z,Z) + dL_dS[:] = np.einsum('q,q->q',Z_expect,variance2) dL_dZ = 2.*(np.einsum('om,q,mq,nq->oq',dL_dpsi2,variance2,Z,S)+np.einsum('om,q,nq,nm->oq',dL_dpsi2,variance,mu,common_sum)) diff --git a/GPy/kern/_src/psi_comp/sslinear_psi_comp.py b/GPy/kern/_src/psi_comp/sslinear_psi_comp.py index dc864b3c..b505f26f 100644 --- a/GPy/kern/_src/psi_comp/sslinear_psi_comp.py +++ b/GPy/kern/_src/psi_comp/sslinear_psi_comp.py @@ -64,21 +64,20 @@ def _psi2computations(dL_dpsi2, variance, Z, mu, S, gamma): variance2 = np.square(variance) mu2S = mu2+S # NxQ common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM + Z_expect = np.einsum('mo,mq,oq->q',dL_dpsi2,Z,Z) + common_expect = np.einsum('mo,mq,no->nq',dL_dpsi2+dL_dpsi2.T,Z,common_sum) - dL_dpsi2_2 = dL_dpsi2+dL_dpsi2.T - - dL_dvar = np.einsum('mo,nq,q,mq,oq->q',dL_dpsi2,2.*(gamma*mu2S-gamma2*mu2),variance,Z,Z)+\ - np.einsum('mo,nq,mq,nq,no->q',dL_dpsi2_2,gamma,Z,mu,common_sum) + dL_dvar = np.einsum('nq,q,q->q',2.*(gamma*mu2S-gamma2*mu2),variance,Z_expect)+\ + np.einsum('nq,nq,nq->q',common_expect,gamma,mu) - dL_dgamma = np.einsum('mo,q,mq,oq,nq->nq',dL_dpsi2,variance2,Z,Z,(mu2S-2.*gamma*mu2))+\ - np.einsum('mo,q,mq,nq,no->nq',dL_dpsi2_2,variance,Z,mu,common_sum) + dL_dgamma = np.einsum('q,q,nq->nq',Z_expect,variance2,(mu2S-2.*gamma*mu2))+\ + np.einsum('nq,q,nq->nq',common_expect,variance,mu) - dL_dmu = np.einsum('mo,q,mq,oq,nq,nq->nq',dL_dpsi2,variance2,Z,Z,mu,2.*(gamma-gamma2))+\ - np.einsum('mo,nq,q,mq,no->nq',dL_dpsi2_2,gamma,variance,Z,common_sum) + dL_dmu = np.einsum('q,q,nq,nq->nq',Z_expect,variance2,mu,2.*(gamma-gamma2))+\ + np.einsum('nq,nq,q->nq',common_expect,gamma,variance) - dL_dS = np.einsum('mo,nq,q,mq,oq->nq',dL_dpsi2,gamma,variance2,Z,Z) + dL_dS = np.einsum('q,nq,q->nq',Z_expect,gamma,variance2) - dL_dZ = 2.*(np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma,variance2,Z,mu2S)+np.einsum('om,nq,q,nq,nm->oq',dL_dpsi2,gamma,variance,mu,common_sum) - -np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma2,variance2,Z,mu2)) + dL_dZ = 2.*(np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma,variance2,Z,(mu2S-gamma*mu2))+np.einsum('om,nq,q,nq,nm->oq',dL_dpsi2,gamma,variance,mu,common_sum)) return dL_dvar, dL_dgamma, dL_dmu, dL_dS, dL_dZ