mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-18 13:55:14 +02:00
faster einsums in svgp
This commit is contained in:
parent
3e4f272808
commit
5ad38ac640
1 changed files with 9 additions and 4 deletions
|
|
@ -4,6 +4,8 @@ from ...util import choleskies
|
|||
import numpy as np
|
||||
from .posterior import Posterior
|
||||
|
||||
def ij_ijk_to_ikl
|
||||
|
||||
class SVGP(LatentFunctionInference):
|
||||
|
||||
def inference(self, q_u_mean, q_u_chol, kern, X, Z, likelihood, Y, mean_function=None, Y_metadata=None, KL_scale=1.0, batch_scale=1.0):
|
||||
|
|
@ -41,11 +43,12 @@ class SVGP(LatentFunctionInference):
|
|||
#compute the marginal means and variances of q(f)
|
||||
A = np.dot(Knm, Kmmi)
|
||||
mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u)
|
||||
v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jkl->ikl', A, S),1)
|
||||
#v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jkl->ikl', A, S),1)
|
||||
v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] *A.dot(S.reshape(S.shape[0],-1)).reshape(A.shape[0],S.shape[1],S.shape[2]),1)
|
||||
|
||||
#compute the KL term
|
||||
Kmmim = np.dot(Kmmi, q_u_mean)
|
||||
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.einsum('ij,ijk->k', Kmmi, S) + 0.5*np.sum(q_u_mean*Kmmim,0)
|
||||
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[:,:,None]*S,0).sum(0) + 0.5*np.sum(q_u_mean*Kmmim,0)
|
||||
KL = KLs.sum()
|
||||
#gradient of the KL term (assuming zero mean function)
|
||||
dKL_dm = Kmmim.copy()
|
||||
|
|
@ -78,11 +81,13 @@ class SVGP(LatentFunctionInference):
|
|||
Adv = A.T[:,:,None]*dF_dv[None,:,:] # As if dF_Dv is diagonal
|
||||
Admu = A.T.dot(dF_dmu)
|
||||
AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(num_outputs)])
|
||||
tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi)
|
||||
#tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi)
|
||||
tmp = np.sum([np.dot(AdvA[:,:,k], S[:,:,k]) for k in range(S.shape[-1])],0).dot(Kmmi)
|
||||
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(-1) - tmp - tmp.T
|
||||
dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
|
||||
tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(num_inducing)[:,:,None])
|
||||
dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T)
|
||||
#dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T)
|
||||
dF_dKmn = np.sum([np.dot(tmp[:,:,k], Adv[:,:,k]) for k in range(Adv.shape[-1])],0) + Kmmim.dot(dF_dmu.T)
|
||||
dF_dm = Admu
|
||||
dF_dS = AdvA
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue