svgp bugfix

This commit is contained in:
James Hensman 2015-05-19 13:33:47 +01:00
parent d01545c92b
commit 601efa7525

View file

@ -44,9 +44,11 @@ class SVGP(LatentFunctionInference):
#compute the marginal means and variances of q(f)
A, _ = linalg.dpotrs(Lm, Kmn)
mu = prior_mean_f + np.dot(A.T, q_u_mean - prior_mean_u)
LA = L.reshape(-1, num_inducing).dot(A).reshape(num_outputs, num_inducing, num_data)
#TODO? possibly use dtrmm for the above line?
v = (Knn_diag - np.sum(A*Kmn,0))[:,None] + np.sum(np.square(LA),1).T
v = np.empty((num_data, num_outputs))
for i in range(num_outputs):
tmp = dtrmm(1.0,L[i].T, A, lower=0, trans_a=0)
v[:,i] = np.sum(np.square(tmp),0)
v += (Knn_diag - np.sum(A*Kmn,0))[:,None]
#compute the KL term
Kmmim = np.dot(Kmmi, q_u_mean)
@ -90,11 +92,9 @@ class SVGP(LatentFunctionInference):
tmp = S.reshape(-1, num_inducing).dot(Kmmi).reshape(num_outputs, num_inducing , num_inducing )
tmp = 2.*(tmp - np.eye(num_inducing)[None, :,:])
dF_dKnm = Kmmim.dot(dF_dmu.T).T
assert dF_dKnm.flags['F_CONTIGUOUS'] # needed for dsymm in place call:
dF_dKmn = Kmmim.dot(dF_dmu.T)
for a,b in zip(tmp, Adv):
dsymm(1.0, a.T, b.T, beta=1., side=1, c=dF_dKnm, overwrite_c=True)
dF_dKmn = dF_dKnm.T
dF_dKmn += np.dot(a.T, b)
dF_dm = Admu
dF_dS = AdvA