diff --git a/GPy/inference/latent_function_inference/svgp.py b/GPy/inference/latent_function_inference/svgp.py index b23a09b0..b3e62118 100644 --- a/GPy/inference/latent_function_inference/svgp.py +++ b/GPy/inference/latent_function_inference/svgp.py @@ -3,6 +3,7 @@ from ...util import linalg from ...util import choleskies import numpy as np from .posterior import Posterior +from scipy.linalg.blas import dgemm class SVGP(LatentFunctionInference): @@ -37,16 +38,16 @@ class SVGP(LatentFunctionInference): #compute kernel related stuff Kmm = kern.K(Z) - Knm = kern.K(X, Z) + Kmn = kern.K(Z, X) Knn_diag = kern.Kdiag(X) Kmmi, Lm, Lmi, logdetKmm = linalg.pdinv(Kmm) #compute the marginal means and variances of q(f) - A = np.dot(Knm, Kmmi) - mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u) - #v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jlk->ilk', A, S),1) - #v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * linalg.ij_jlk_to_ilk(A, S),1) - v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + (S.dot(A.T)*A.T[None,:,:]).sum(1).T + A = np.dot(Kmmi, Kmn) + mu = prior_mean_f + np.dot(A.T, q_u_mean - prior_mean_u) + LA = L.reshape(-1, num_inducing).dot(A).reshape(num_outputs, num_inducing, num_data) + v = (Knn_diag - np.sum(A*Kmn,0))[:,None] + np.sum(np.square(LA),1).T + #compute the KL term Kmmim = np.dot(Kmmi, q_u_mean) @@ -80,14 +81,26 @@ class SVGP(LatentFunctionInference): dF_dthetaL = dF_dthetaL.sum(1).sum(1)*batch_scale #derivatives of expected likelihood, assuming zero mean function - Adv = A.T[None,:,:]*dF_dv.T[:,None,:] # As if dF_Dv is diagonal, D, M, N - Admu = A.T.dot(dF_dmu) - AdvA = np.dot(Adv, A) # D, M, M + Adv = A[None,:,:]*dF_dv.T[:,None,:] # As if dF_Dv is diagonal, D, M, N + Admu = A.dot(dF_dmu) + #AdvA_ = np.dot(Adv, A) # D, M, M + AdvA = np.dot(Adv.reshape(-1, num_data),A.T).reshape(num_outputs, num_inducing, num_inducing ) + #assert np.allclose(AdvA, AdvA_, 1e-9) + tmp = np.sum([np.dot(a,s) for a, s in zip(AdvA, S)],0).dot(Kmmi) dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(0) - tmp - tmp.T dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug? - tmp = 2.*(S.dot(Kmmi).swapaxes(1,2) - np.eye(num_inducing)[None, :,:]) # TODO: transpose? - dF_dKmn = np.sum([np.dot(a,b) for a,b in zip(tmp, Adv)],0) + Kmmim.dot(dF_dmu.T) + tmp = S.reshape(-1, num_inducing).dot(Kmmi).reshape(num_outputs, num_inducing , num_inducing ) + #tmp_ = S.dot(Kmmi).swapaxes(1,2) + tmp = 2.*(tmp - np.eye(num_inducing)[None, :,:]) + + #dF_dKmn_ = np.sum([np.dot(a,b) for a,b in zip(tmp, Adv)],0) + Kmmim.dot(dF_dmu.T) + dF_dKnm = Kmmim.dot(dF_dmu.T).T + assert dF_dKnm.flags['F_CONTIGUOUS'] # needed for dgemm in place call: + for a,b in zip(tmp, Adv): + dgemm(1.0, b.T, a.T, beta=1., c=dF_dKnm, overwrite_c=True) + dF_dKmn = dF_dKnm.T + dF_dm = Admu dF_dS = AdvA