diff --git a/GPy/kern/_src/psi_comp/rbf_psi_comp.py b/GPy/kern/_src/psi_comp/rbf_psi_comp.py index 735a354d..5667eec6 100644 --- a/GPy/kern/_src/psi_comp/rbf_psi_comp.py +++ b/GPy/kern/_src/psi_comp/rbf_psi_comp.py @@ -101,25 +101,30 @@ def _psi2compDer(dL_dpsi2, variance, lengthscale, Z, mu, S): # _dL_dgamma NxQ # _dL_dmu NxQ # _dL_dS NxQ - + N,M,Q = mu.shape[0],Z.shape[0],mu.shape[1] lengthscale2 = np.square(lengthscale) denom = 1./(2*S+lengthscale2) denom2 = np.square(denom) _psi2 = _psi2computations(variance, lengthscale, Z, mu, S) # NxMxM Lpsi2 = dL_dpsi2*_psi2 # dL_dpsi2 is MxM, using broadcast to multiply N out - Lpsi2sum = np.einsum('nmo->n',Lpsi2) #N - Lpsi2Z = np.einsum('nmo,oq->nq',Lpsi2,Z) #NxQ - Lpsi2Z2 = np.einsum('nmo,oq,oq->nq',Lpsi2,Z,Z) #NxQ - Lpsi2Z2p = np.einsum('nmo,mq,oq->nq',Lpsi2,Z,Z) #NxQ + Lpsi2sum = Lpsi2.reshape(N,M*M).sum(1) #N + tmp = Lpsi2.reshape(N*M,M).dot(Z).reshape(N,M,Q) + Lpsi2Z = tmp.sum(1) #NxQ + Lpsi2Z2 = Lpsi2.reshape(N*M,M).dot(np.square(Z)).reshape(N,M,Q).sum(1) #np.einsum('nmo,oq,oq->nq',Lpsi2,Z,Z) #NxQ + Lpsi2Z2p = (tmp*Z[None,:,:]).sum(1) #np.einsum('nmo,mq,oq->nq',Lpsi2,Z,Z) #NxQ Lpsi2Zhat = Lpsi2Z Lpsi2Zhat2 = (Lpsi2Z2+Lpsi2Z2p)/2 _dL_dvar = Lpsi2sum.sum()*2/variance _dL_dmu = (-2*denom) * (mu*Lpsi2sum[:,None]-Lpsi2Zhat) _dL_dS = (2*np.square(denom))*(np.square(mu)*Lpsi2sum[:,None]-2*mu*Lpsi2Zhat+Lpsi2Zhat2) - denom*Lpsi2sum[:,None] - _dL_dZ = -np.einsum('nmo,oq->oq',Lpsi2,Z)/lengthscale2+np.einsum('nmo,oq->mq',Lpsi2,Z)/lengthscale2+ \ - 2*np.einsum('nmo,nq,nq->mq',Lpsi2,mu,denom) - np.einsum('nmo,nq,mq->mq',Lpsi2,denom,Z) - np.einsum('nmo,oq,nq->mq',Lpsi2,Z,denom) +# _dL_dZ = -np.einsum('nmo,oq->oq',Lpsi2,Z)/lengthscale2+np.einsum('nmo,oq->mq',Lpsi2,Z)/lengthscale2+ \ +# 2*np.einsum('nmo,nq,nq->mq',Lpsi2,mu,denom) - np.einsum('nmo,nq,mq->mq',Lpsi2,denom,Z) - np.einsum('nmo,oq,nq->mq',Lpsi2,Z,denom) + Lpsi2_N = Lpsi2.sum(0) + Lpsi2_M = Lpsi2.sum(2) + _dL_dZ = -Lpsi2_N.sum(0)[:,None]*Z/lengthscale2+Lpsi2_N.dot(Z)/lengthscale2+ \ + 2*Lpsi2_M.T.dot(mu*denom) - Lpsi2_M.T.dot(denom)*Z - (Lpsi2.reshape(N,M*M).T.dot(denom).reshape(M,M,Q)*Z[None,:,:]).sum(1)#np.einsum('nmo,oq,nq->mq',Lpsi2,Z,denom) _dL_dl = 2*lengthscale* ((S/lengthscale2*denom+np.square(mu*denom))*Lpsi2sum[:,None]+(Lpsi2Z2-Lpsi2Z2p)/(2*np.square(lengthscale2))- (2*mu*denom2)*Lpsi2Zhat+denom2*Lpsi2Zhat2).sum(axis=0)