mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-08 15:05:15 +02:00
rbf psi-statistics speedup
This commit is contained in:
parent
c639b0d126
commit
76220cd4d2
1 changed files with 12 additions and 7 deletions
|
|
@ -101,25 +101,30 @@ def _psi2compDer(dL_dpsi2, variance, lengthscale, Z, mu, S):
|
|||
# _dL_dgamma NxQ
|
||||
# _dL_dmu NxQ
|
||||
# _dL_dS NxQ
|
||||
|
||||
N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1]
|
||||
lengthscale2 = np.square(lengthscale)
|
||||
denom = 1./(2*S+lengthscale2)
|
||||
denom2 = np.square(denom)
|
||||
|
||||
_psi2 = _psi2computations(variance, lengthscale, Z, mu, S) # NxMxM
|
||||
Lpsi2 = dL_dpsi2*_psi2 # dL_dpsi2 is MxM, using broadcast to multiply N out
|
||||
Lpsi2sum = np.einsum('nmo->n',Lpsi2) #N
|
||||
Lpsi2Z = np.einsum('nmo,oq->nq',Lpsi2,Z) #NxQ
|
||||
Lpsi2Z2 = np.einsum('nmo,oq,oq->nq',Lpsi2,Z,Z) #NxQ
|
||||
Lpsi2Z2p = np.einsum('nmo,mq,oq->nq',Lpsi2,Z,Z) #NxQ
|
||||
Lpsi2sum = Lpsi2.reshape(N,M*M).sum(1) #N
|
||||
tmp = Lpsi2.reshape(N*M,M).dot(Z).reshape(N,M,Q)
|
||||
Lpsi2Z = tmp.sum(1) #NxQ
|
||||
Lpsi2Z2 = Lpsi2.reshape(N*M,M).dot(np.square(Z)).reshape(N,M,Q).sum(1) #np.einsum('nmo,oq,oq->nq',Lpsi2,Z,Z) #NxQ
|
||||
Lpsi2Z2p = (tmp*Z[None,:,:]).sum(1) #np.einsum('nmo,mq,oq->nq',Lpsi2,Z,Z) #NxQ
|
||||
Lpsi2Zhat = Lpsi2Z
|
||||
Lpsi2Zhat2 = (Lpsi2Z2+Lpsi2Z2p)/2
|
||||
|
||||
_dL_dvar = Lpsi2sum.sum()*2/variance
|
||||
_dL_dmu = (-2*denom) * (mu*Lpsi2sum[:,None]-Lpsi2Zhat)
|
||||
_dL_dS = (2*np.square(denom))*(np.square(mu)*Lpsi2sum[:,None]-2*mu*Lpsi2Zhat+Lpsi2Zhat2) - denom*Lpsi2sum[:,None]
|
||||
_dL_dZ = -np.einsum('nmo,oq->oq',Lpsi2,Z)/lengthscale2+np.einsum('nmo,oq->mq',Lpsi2,Z)/lengthscale2+ \
|
||||
2*np.einsum('nmo,nq,nq->mq',Lpsi2,mu,denom) - np.einsum('nmo,nq,mq->mq',Lpsi2,denom,Z) - np.einsum('nmo,oq,nq->mq',Lpsi2,Z,denom)
|
||||
# _dL_dZ = -np.einsum('nmo,oq->oq',Lpsi2,Z)/lengthscale2+np.einsum('nmo,oq->mq',Lpsi2,Z)/lengthscale2+ \
|
||||
# 2*np.einsum('nmo,nq,nq->mq',Lpsi2,mu,denom) - np.einsum('nmo,nq,mq->mq',Lpsi2,denom,Z) - np.einsum('nmo,oq,nq->mq',Lpsi2,Z,denom)
|
||||
Lpsi2_N = Lpsi2.sum(0)
|
||||
Lpsi2_M = Lpsi2.sum(2)
|
||||
_dL_dZ = -Lpsi2_N.sum(0)[:,None]*Z/lengthscale2+Lpsi2_N.dot(Z)/lengthscale2+ \
|
||||
2*Lpsi2_M.T.dot(mu*denom) - Lpsi2_M.T.dot(denom)*Z - (Lpsi2.reshape(N,M*M).T.dot(denom).reshape(M,M,Q)*Z[None,:,:]).sum(1)#np.einsum('nmo,oq,nq->mq',Lpsi2,Z,denom)
|
||||
_dL_dl = 2*lengthscale* ((S/lengthscale2*denom+np.square(mu*denom))*Lpsi2sum[:,None]+(Lpsi2Z2-Lpsi2Z2p)/(2*np.square(lengthscale2))-
|
||||
(2*mu*denom2)*Lpsi2Zhat+denom2*Lpsi2Zhat2).sum(axis=0)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue