rbf psi-statistics speed improvement

This commit is contained in:
Zhenwen Dai 2015-08-12 13:51:43 +01:00
parent 28c03148ec
commit 6bff33e30c

View file

@ -55,16 +55,15 @@ def __psi2computations(variance, lengthscale, Z, mu, S):
# Produced intermediate results: # Produced intermediate results:
# _psi2 MxM # _psi2 MxM
N,M,Q = mu.shape[0], Z.shape[0], mu.shape[1]
lengthscale2 = np.square(lengthscale) lengthscale2 = np.square(lengthscale)
_psi2_logdenom = np.log(2.*S/lengthscale2+1.).sum(axis=-1)/(-2.) # N _psi2_logdenom = np.log(2.*S/lengthscale2+1.).sum(axis=-1)/(-2.) # N
_psi2_exp1 = (np.square(Z[:,None,:]-Z[None,:,:])/lengthscale2).sum(axis=-1)/(-4.) #MxM _psi2_exp1 = (np.square(Z[:,None,:]-Z[None,:,:])/lengthscale2).sum(axis=-1)/(-4.) #MxM
Z_hat = (Z[:,None,:]+Z[None,:,:])/2. #MxMxQ Z_hat = (Z[:,None,:]+Z[None,:,:])/2. #MxMxQ
denom = 1./(2.*S+lengthscale2) denom = 1./(2.*S+lengthscale2)
_psi2_exp2 = -(np.square(mu)*denom).sum(axis=-1)[:,None,None]+2.*np.einsum('nq,moq,nq->nmo',mu,Z_hat,denom)-np.einsum('moq,nq->nmo',np.square(Z_hat),denom) _psi2_exp2 = -(np.square(mu)*denom).sum(axis=-1)[:,None,None]+(2*(mu*denom).dot(Z_hat.reshape(M*M,Q).T) - denom.dot(np.square(Z_hat).reshape(M*M,Q).T)).reshape(N,M,M)
_psi2 = variance*variance*np.exp(_psi2_logdenom[:,None,None]+_psi2_exp1[None,:,:]+_psi2_exp2) _psi2 = variance*variance*np.exp(_psi2_logdenom[:,None,None]+_psi2_exp1[None,:,:]+_psi2_exp2)
return _psi2 return _psi2
def psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, variational_posterior): def psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, variational_posterior):
@ -157,5 +156,5 @@ def _psi2compDer(dL_dpsi2, variance, lengthscale, Z, mu, S):
return _dL_dvar, _dL_dl, _dL_dZ, _dL_dmu, _dL_dS return _dL_dvar, _dL_dl, _dL_dZ, _dL_dmu, _dL_dS
_psi1computations = Cacher(__psi1computations, limit=1) _psi1computations = Cacher(__psi1computations, limit=5)
_psi2computations = Cacher(__psi2computations, limit=1) _psi2computations = Cacher(__psi2computations, limit=5)