mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-24 14:15:14 +02:00
some further performance improvement for linear kernel psi statistics
This commit is contained in:
parent
3d322301a2
commit
abcea70291
3 changed files with 15 additions and 17 deletions
|
|
@ -113,7 +113,6 @@ class Linear(Kern):
|
||||||
def psi1(self, Z, variational_posterior):
|
def psi1(self, Z, variational_posterior):
|
||||||
return self.psicomp.psicomputations(self.variances, Z, variational_posterior)[1]
|
return self.psicomp.psicomputations(self.variances, Z, variational_posterior)[1]
|
||||||
|
|
||||||
@Cache_this(limit=1)
|
|
||||||
def psi2(self, Z, variational_posterior):
|
def psi2(self, Z, variational_posterior):
|
||||||
return self.psicomp.psicomputations(self.variances, Z, variational_posterior)[2]
|
return self.psicomp.psicomputations(self.variances, Z, variational_posterior)[2]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -56,15 +56,15 @@ def _psi2computations(dL_dpsi2, variance, Z, mu, S):
|
||||||
|
|
||||||
variance2 = np.square(variance)
|
variance2 = np.square(variance)
|
||||||
common_sum = np.einsum('q,mq,nq->nm',variance,Z,mu) # NxM
|
common_sum = np.einsum('q,mq,nq->nm',variance,Z,mu) # NxM
|
||||||
dL_dpsi2_2 = dL_dpsi2+dL_dpsi2.T
|
Z_expect = np.einsum('mo,mq,oq->q',dL_dpsi2,Z,Z)
|
||||||
|
common_expect = np.einsum('mo,mq,no->nq',dL_dpsi2+dL_dpsi2.T,Z,common_sum)
|
||||||
|
|
||||||
dL_dvar = np.einsum('mo,nq,q,mq,oq->q',dL_dpsi2,2.*S,variance,Z,Z)+\
|
dL_dvar = np.einsum('q,nq,q->q',Z_expect,2.*S,variance)+ np.einsum('nq,nq->q',common_expect,mu)
|
||||||
np.einsum('mo,mq,nq,no->q',dL_dpsi2_2,Z,mu,common_sum)
|
|
||||||
|
|
||||||
dL_dmu = np.einsum('mo,q,mq,no->nq',dL_dpsi2_2,variance,Z,common_sum)
|
dL_dmu = np.einsum('nq,q->nq',common_expect,variance)
|
||||||
|
|
||||||
dL_dS = np.empty(S.shape)
|
dL_dS = np.empty(S.shape)
|
||||||
dL_dS[:] = np.einsum('mo,q,mq,oq->q',dL_dpsi2,variance2,Z,Z)
|
dL_dS[:] = np.einsum('q,q->q',Z_expect,variance2)
|
||||||
|
|
||||||
dL_dZ = 2.*(np.einsum('om,q,mq,nq->oq',dL_dpsi2,variance2,Z,S)+np.einsum('om,q,nq,nm->oq',dL_dpsi2,variance,mu,common_sum))
|
dL_dZ = 2.*(np.einsum('om,q,mq,nq->oq',dL_dpsi2,variance2,Z,S)+np.einsum('om,q,nq,nm->oq',dL_dpsi2,variance,mu,common_sum))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -64,21 +64,20 @@ def _psi2computations(dL_dpsi2, variance, Z, mu, S, gamma):
|
||||||
variance2 = np.square(variance)
|
variance2 = np.square(variance)
|
||||||
mu2S = mu2+S # NxQ
|
mu2S = mu2+S # NxQ
|
||||||
common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM
|
common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM
|
||||||
|
Z_expect = np.einsum('mo,mq,oq->q',dL_dpsi2,Z,Z)
|
||||||
|
common_expect = np.einsum('mo,mq,no->nq',dL_dpsi2+dL_dpsi2.T,Z,common_sum)
|
||||||
|
|
||||||
dL_dpsi2_2 = dL_dpsi2+dL_dpsi2.T
|
dL_dvar = np.einsum('nq,q,q->q',2.*(gamma*mu2S-gamma2*mu2),variance,Z_expect)+\
|
||||||
|
np.einsum('nq,nq,nq->q',common_expect,gamma,mu)
|
||||||
dL_dvar = np.einsum('mo,nq,q,mq,oq->q',dL_dpsi2,2.*(gamma*mu2S-gamma2*mu2),variance,Z,Z)+\
|
|
||||||
np.einsum('mo,nq,mq,nq,no->q',dL_dpsi2_2,gamma,Z,mu,common_sum)
|
|
||||||
|
|
||||||
dL_dgamma = np.einsum('mo,q,mq,oq,nq->nq',dL_dpsi2,variance2,Z,Z,(mu2S-2.*gamma*mu2))+\
|
dL_dgamma = np.einsum('q,q,nq->nq',Z_expect,variance2,(mu2S-2.*gamma*mu2))+\
|
||||||
np.einsum('mo,q,mq,nq,no->nq',dL_dpsi2_2,variance,Z,mu,common_sum)
|
np.einsum('nq,q,nq->nq',common_expect,variance,mu)
|
||||||
|
|
||||||
dL_dmu = np.einsum('mo,q,mq,oq,nq,nq->nq',dL_dpsi2,variance2,Z,Z,mu,2.*(gamma-gamma2))+\
|
dL_dmu = np.einsum('q,q,nq,nq->nq',Z_expect,variance2,mu,2.*(gamma-gamma2))+\
|
||||||
np.einsum('mo,nq,q,mq,no->nq',dL_dpsi2_2,gamma,variance,Z,common_sum)
|
np.einsum('nq,nq,q->nq',common_expect,gamma,variance)
|
||||||
|
|
||||||
dL_dS = np.einsum('mo,nq,q,mq,oq->nq',dL_dpsi2,gamma,variance2,Z,Z)
|
dL_dS = np.einsum('q,nq,q->nq',Z_expect,gamma,variance2)
|
||||||
|
|
||||||
dL_dZ = 2.*(np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma,variance2,Z,mu2S)+np.einsum('om,nq,q,nq,nm->oq',dL_dpsi2,gamma,variance,mu,common_sum)
|
dL_dZ = 2.*(np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma,variance2,Z,(mu2S-gamma*mu2))+np.einsum('om,nq,q,nq,nm->oq',dL_dpsi2,gamma,variance,mu,common_sum))
|
||||||
-np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma2,variance2,Z,mu2))
|
|
||||||
|
|
||||||
return dL_dvar, dL_dgamma, dL_dmu, dL_dS, dL_dZ
|
return dL_dvar, dL_dgamma, dL_dmu, dL_dS, dL_dZ
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue