mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-03 08:42:39 +02:00
rbf psi-statistics speed improvement
This commit is contained in:
parent
28c03148ec
commit
6bff33e30c
1 changed files with 4 additions and 5 deletions
|
|
@ -55,16 +55,15 @@ def __psi2computations(variance, lengthscale, Z, mu, S):
|
||||||
# Produced intermediate results:
|
# Produced intermediate results:
|
||||||
# _psi2 MxM
|
# _psi2 MxM
|
||||||
|
|
||||||
|
N,M,Q = mu.shape[0], Z.shape[0], mu.shape[1]
|
||||||
lengthscale2 = np.square(lengthscale)
|
lengthscale2 = np.square(lengthscale)
|
||||||
|
|
||||||
_psi2_logdenom = np.log(2.*S/lengthscale2+1.).sum(axis=-1)/(-2.) # N
|
_psi2_logdenom = np.log(2.*S/lengthscale2+1.).sum(axis=-1)/(-2.) # N
|
||||||
_psi2_exp1 = (np.square(Z[:,None,:]-Z[None,:,:])/lengthscale2).sum(axis=-1)/(-4.) #MxM
|
_psi2_exp1 = (np.square(Z[:,None,:]-Z[None,:,:])/lengthscale2).sum(axis=-1)/(-4.) #MxM
|
||||||
Z_hat = (Z[:,None,:]+Z[None,:,:])/2. #MxMxQ
|
Z_hat = (Z[:,None,:]+Z[None,:,:])/2. #MxMxQ
|
||||||
denom = 1./(2.*S+lengthscale2)
|
denom = 1./(2.*S+lengthscale2)
|
||||||
_psi2_exp2 = -(np.square(mu)*denom).sum(axis=-1)[:,None,None]+2.*np.einsum('nq,moq,nq->nmo',mu,Z_hat,denom)-np.einsum('moq,nq->nmo',np.square(Z_hat),denom)
|
_psi2_exp2 = -(np.square(mu)*denom).sum(axis=-1)[:,None,None]+(2*(mu*denom).dot(Z_hat.reshape(M*M,Q).T) - denom.dot(np.square(Z_hat).reshape(M*M,Q).T)).reshape(N,M,M)
|
||||||
_psi2 = variance*variance*np.exp(_psi2_logdenom[:,None,None]+_psi2_exp1[None,:,:]+_psi2_exp2)
|
_psi2 = variance*variance*np.exp(_psi2_logdenom[:,None,None]+_psi2_exp1[None,:,:]+_psi2_exp2)
|
||||||
|
|
||||||
|
|
||||||
return _psi2
|
return _psi2
|
||||||
|
|
||||||
def psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, variational_posterior):
|
def psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, variational_posterior):
|
||||||
|
|
@ -157,5 +156,5 @@ def _psi2compDer(dL_dpsi2, variance, lengthscale, Z, mu, S):
|
||||||
|
|
||||||
return _dL_dvar, _dL_dl, _dL_dZ, _dL_dmu, _dL_dS
|
return _dL_dvar, _dL_dl, _dL_dZ, _dL_dmu, _dL_dS
|
||||||
|
|
||||||
_psi1computations = Cacher(__psi1computations, limit=1)
|
_psi1computations = Cacher(__psi1computations, limit=5)
|
||||||
_psi2computations = Cacher(__psi2computations, limit=1)
|
_psi2computations = Cacher(__psi2computations, limit=5)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue