SVGP working with reordered Cholesky factors

This commit is contained in:
James Hensman 2015-05-15 08:59:19 +01:00
parent 2249ec06a5
commit 77093a304c
4 changed files with 51 additions and 59 deletions

View file

@ -24,8 +24,6 @@ class SVGP(LatentFunctionInference):
if np.any(np.isinf(Si)):
raise ValueError("Cholesky representation unstable")
#S = S + np.eye(S.shape[0])*1e-5*np.max(np.max(S))
#Si, Lnew, _,_ = linalg.pdinv(S)
#compute mean function stuff
if mean_function is not None:
@ -35,27 +33,21 @@ class SVGP(LatentFunctionInference):
prior_mean_u = np.zeros((num_inducing, num_outputs))
prior_mean_f = np.zeros((num_data, num_outputs))
#compute kernel related stuff
Kmm = kern.K(Z)
Kmn = kern.K(Z, X)
Knn_diag = kern.Kdiag(X)
Kmmi, Lm, Lmi, logdetKmm = linalg.pdinv(Kmm)
Lm = linalg.jitchol(Kmm)
logdetKmm = 2.*np.sum(np.log(np.diag(Lm)))
Kmmi, _ = linalg.dpotri(Lm)
#compute the marginal means and variances of q(f)
A = np.dot(Kmmi, Kmn)
A, _ = linalg.dpotrs(Lm, Kmn)
mu = prior_mean_f + np.dot(A.T, q_u_mean - prior_mean_u)
LA = L.reshape(-1, num_inducing).dot(A).reshape(num_outputs, num_inducing, num_data)
#LA = np.empty((num_outputs, num_inducing, num_data))
#Af = np.asfortranarray(A)
#for Li, LAi in zip(L, LA):
#LAi[:,:] = dtrmm(1., Li.T, Af, side=0, lower=0, trans_a=1, overwrite_b=0)
#stop
#assert np.allclose(LA, LA_)
#TODO? possibly use dtrmm for the above line?
v = (Knn_diag - np.sum(A*Kmn,0))[:,None] + np.sum(np.square(LA),1).T
#compute the KL term
Kmmim = np.dot(Kmmi, q_u_mean)
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[None,:,:]*S,1).sum(1) + 0.5*np.sum(q_u_mean*Kmmim,0)
@ -90,7 +82,7 @@ class SVGP(LatentFunctionInference):
#derivatives of expected likelihood, assuming zero mean function
Adv = A[None,:,:]*dF_dv.T[:,None,:] # As if dF_dv is diagonal; shape D, M, N
Admu = A.dot(dF_dmu)
Adv = np.ascontiguousarray(Adv) # makes for faster operations later...
Adv = np.ascontiguousarray(Adv) # makes for faster operations later...(inc dsymm)
AdvA = np.dot(Adv.reshape(-1, num_data),A.T).reshape(num_outputs, num_inducing, num_inducing )
tmp = np.sum([np.dot(a,s) for a, s in zip(AdvA, S)],0).dot(Kmmi)
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(0) - tmp - tmp.T