rbf psi-statistics speedup

This commit is contained in:
Zhenwen Dai 2015-09-23 18:09:24 +01:00
parent c639b0d126
commit 76220cd4d2

View file

@ -101,25 +101,30 @@ def _psi2compDer(dL_dpsi2, variance, lengthscale, Z, mu, S):
# _dL_dgamma NxQ
# _dL_dmu NxQ
# _dL_dS NxQ
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
lengthscale2 = np.square(lengthscale)
denom = 1./(2*S+lengthscale2)
denom2 = np.square(denom)
_psi2 = _psi2computations(variance, lengthscale, Z, mu, S) # NxMxM
Lpsi2 = dL_dpsi2*_psi2 # dL_dpsi2 is MxM, using broadcast to multiply N out
Lpsi2sum = np.einsum('nmo->n',Lpsi2) #N
Lpsi2Z = np.einsum('nmo,oq->nq',Lpsi2,Z) #NxQ
Lpsi2Z2 = np.einsum('nmo,oq,oq->nq',Lpsi2,Z,Z) #NxQ
Lpsi2Z2p = np.einsum('nmo,mq,oq->nq',Lpsi2,Z,Z) #NxQ
Lpsi2sum = Lpsi2.reshape(N,M*M).sum(1) #N
tmp = Lpsi2.reshape(N*M,M).dot(Z).reshape(N,M,Q)
Lpsi2Z = tmp.sum(1) #NxQ
Lpsi2Z2 = Lpsi2.reshape(N*M,M).dot(np.square(Z)).reshape(N,M,Q).sum(1) #np.einsum('nmo,oq,oq->nq',Lpsi2,Z,Z) #NxQ
Lpsi2Z2p = (tmp*Z[None,:,:]).sum(1) #np.einsum('nmo,mq,oq->nq',Lpsi2,Z,Z) #NxQ
Lpsi2Zhat = Lpsi2Z
Lpsi2Zhat2 = (Lpsi2Z2+Lpsi2Z2p)/2
_dL_dvar = Lpsi2sum.sum()*2/variance
_dL_dmu = (-2*denom) * (mu*Lpsi2sum[:,None]-Lpsi2Zhat)
_dL_dS = (2*np.square(denom))*(np.square(mu)*Lpsi2sum[:,None]-2*mu*Lpsi2Zhat+Lpsi2Zhat2) - denom*Lpsi2sum[:,None]
_dL_dZ = -np.einsum('nmo,oq->oq',Lpsi2,Z)/lengthscale2+np.einsum('nmo,oq->mq',Lpsi2,Z)/lengthscale2+ \
2*np.einsum('nmo,nq,nq->mq',Lpsi2,mu,denom) - np.einsum('nmo,nq,mq->mq',Lpsi2,denom,Z) - np.einsum('nmo,oq,nq->mq',Lpsi2,Z,denom)
# _dL_dZ = -np.einsum('nmo,oq->oq',Lpsi2,Z)/lengthscale2+np.einsum('nmo,oq->mq',Lpsi2,Z)/lengthscale2+ \
# 2*np.einsum('nmo,nq,nq->mq',Lpsi2,mu,denom) - np.einsum('nmo,nq,mq->mq',Lpsi2,denom,Z) - np.einsum('nmo,oq,nq->mq',Lpsi2,Z,denom)
Lpsi2_N = Lpsi2.sum(0)
Lpsi2_M = Lpsi2.sum(2)
_dL_dZ = -Lpsi2_N.sum(0)[:,None]*Z/lengthscale2+Lpsi2_N.dot(Z)/lengthscale2+ \
2*Lpsi2_M.T.dot(mu*denom) - Lpsi2_M.T.dot(denom)*Z - (Lpsi2.reshape(N,M*M).T.dot(denom).reshape(M,M,Q)*Z[None,:,:]).sum(1)#np.einsum('nmo,oq,nq->mq',Lpsi2,Z,denom)
_dL_dl = 2*lengthscale* ((S/lengthscale2*denom+np.square(mu*denom))*Lpsi2sum[:,None]+(Lpsi2Z2-Lpsi2Z2p)/(2*np.square(lengthscale2))-
(2*mu*denom2)*Lpsi2Zhat+denom2*Lpsi2Zhat2).sum(axis=0)