diff --git a/GPy/kern/_src/psi_comp/rbf_psi_comp.py b/GPy/kern/_src/psi_comp/rbf_psi_comp.py index e7a67d16..16bada72 100644 --- a/GPy/kern/_src/psi_comp/rbf_psi_comp.py +++ b/GPy/kern/_src/psi_comp/rbf_psi_comp.py @@ -55,16 +55,15 @@ def __psi2computations(variance, lengthscale, Z, mu, S): # Produced intermediate results: # _psi2 MxM + N,M,Q = mu.shape[0], Z.shape[0], mu.shape[1] lengthscale2 = np.square(lengthscale) _psi2_logdenom = np.log(2.*S/lengthscale2+1.).sum(axis=-1)/(-2.) # N _psi2_exp1 = (np.square(Z[:,None,:]-Z[None,:,:])/lengthscale2).sum(axis=-1)/(-4.) #MxM Z_hat = (Z[:,None,:]+Z[None,:,:])/2. #MxMxQ denom = 1./(2.*S+lengthscale2) - _psi2_exp2 = -(np.square(mu)*denom).sum(axis=-1)[:,None,None]+2.*np.einsum('nq,moq,nq->nmo',mu,Z_hat,denom)-np.einsum('moq,nq->nmo',np.square(Z_hat),denom) + _psi2_exp2 = -(np.square(mu)*denom).sum(axis=-1)[:,None,None]+(2*(mu*denom).dot(Z_hat.reshape(M*M,Q).T) - denom.dot(np.square(Z_hat).reshape(M*M,Q).T)).reshape(N,M,M) _psi2 = variance*variance*np.exp(_psi2_logdenom[:,None,None]+_psi2_exp1[None,:,:]+_psi2_exp2) - - return _psi2 def psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, variational_posterior): @@ -157,5 +156,5 @@ def _psi2compDer(dL_dpsi2, variance, lengthscale, Z, mu, S): return _dL_dvar, _dL_dl, _dL_dZ, _dL_dmu, _dL_dS -_psi1computations = Cacher(__psi1computations, limit=1) -_psi2computations = Cacher(__psi2computations, limit=1) +_psi1computations = Cacher(__psi1computations, limit=5) +_psi2computations = Cacher(__psi2computations, limit=5)