svgp, more c-ordering

This commit is contained in:
James Hensman 2015-05-14 16:49:11 +01:00
parent 0450e228a8
commit 14043be0ce

View file

@ -3,6 +3,7 @@ from ...util import linalg
from ...util import choleskies
import numpy as np
from .posterior import Posterior
from scipy.linalg.blas import dgemm
class SVGP(LatentFunctionInference):
@ -37,16 +38,16 @@ class SVGP(LatentFunctionInference):
#compute kernel related stuff
Kmm = kern.K(Z)
Knm = kern.K(X, Z)
Kmn = kern.K(Z, X)
Knn_diag = kern.Kdiag(X)
Kmmi, Lm, Lmi, logdetKmm = linalg.pdinv(Kmm)
#compute the marginal means and variances of q(f)
A = np.dot(Knm, Kmmi)
mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u)
#v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jlk->ilk', A, S),1)
#v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * linalg.ij_jlk_to_ilk(A, S),1)
v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + (S.dot(A.T)*A.T[None,:,:]).sum(1).T
A = np.dot(Kmmi, Kmn)
mu = prior_mean_f + np.dot(A.T, q_u_mean - prior_mean_u)
LA = L.reshape(-1, num_inducing).dot(A).reshape(num_outputs, num_inducing, num_data)
v = (Knn_diag - np.sum(A*Kmn,0))[:,None] + np.sum(np.square(LA),1).T
#compute the KL term
Kmmim = np.dot(Kmmi, q_u_mean)
@ -80,14 +81,26 @@ class SVGP(LatentFunctionInference):
dF_dthetaL = dF_dthetaL.sum(1).sum(1)*batch_scale
#derivatives of expected likelihood, assuming zero mean function
Adv = A.T[None,:,:]*dF_dv.T[:,None,:] # As if dF_Dv is diagonal, D, M, N
Admu = A.T.dot(dF_dmu)
AdvA = np.dot(Adv, A) # D, M, M
Adv = A[None,:,:]*dF_dv.T[:,None,:] # As if dF_Dv is diagonal, D, M, N
Admu = A.dot(dF_dmu)
#AdvA_ = np.dot(Adv, A) # D, M, M
AdvA = np.dot(Adv.reshape(-1, num_data),A.T).reshape(num_outputs, num_inducing, num_inducing )
#assert np.allclose(AdvA, AdvA_, 1e-9)
tmp = np.sum([np.dot(a,s) for a, s in zip(AdvA, S)],0).dot(Kmmi)
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(0) - tmp - tmp.T
dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
tmp = 2.*(S.dot(Kmmi).swapaxes(1,2) - np.eye(num_inducing)[None, :,:]) # TODO: transpose?
dF_dKmn = np.sum([np.dot(a,b) for a,b in zip(tmp, Adv)],0) + Kmmim.dot(dF_dmu.T)
tmp = S.reshape(-1, num_inducing).dot(Kmmi).reshape(num_outputs, num_inducing , num_inducing )
#tmp_ = S.dot(Kmmi).swapaxes(1,2)
tmp = 2.*(tmp - np.eye(num_inducing)[None, :,:])
#dF_dKmn_ = np.sum([np.dot(a,b) for a,b in zip(tmp, Adv)],0) + Kmmim.dot(dF_dmu.T)
dF_dKnm = Kmmim.dot(dF_dmu.T).T
assert dF_dKnm.flags['F_CONTIGUOUS'] # needed for dgemm in place call:
for a,b in zip(tmp, Adv):
dgemm(1.0, b.T, a.T, beta=1., c=dF_dKnm, overwrite_c=True)
dF_dKmn = dF_dKnm.T
dF_dm = Admu
dF_dS = AdvA