mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-08 15:05:15 +02:00
svgp bugfix
This commit is contained in:
parent
d01545c92b
commit
601efa7525
1 changed files with 7 additions and 7 deletions
|
|
@ -44,9 +44,11 @@ class SVGP(LatentFunctionInference):
|
|||
#compute the marginal means and variances of q(f)
|
||||
A, _ = linalg.dpotrs(Lm, Kmn)
|
||||
mu = prior_mean_f + np.dot(A.T, q_u_mean - prior_mean_u)
|
||||
LA = L.reshape(-1, num_inducing).dot(A).reshape(num_outputs, num_inducing, num_data)
|
||||
#TODO? possibly use dtrmm for the above line?
|
||||
v = (Knn_diag - np.sum(A*Kmn,0))[:,None] + np.sum(np.square(LA),1).T
|
||||
v = np.empty((num_data, num_outputs))
|
||||
for i in range(num_outputs):
|
||||
tmp = dtrmm(1.0,L[i].T, A, lower=0, trans_a=0)
|
||||
v[:,i] = np.sum(np.square(tmp),0)
|
||||
v += (Knn_diag - np.sum(A*Kmn,0))[:,None]
|
||||
|
||||
#compute the KL term
|
||||
Kmmim = np.dot(Kmmi, q_u_mean)
|
||||
|
|
@ -90,11 +92,9 @@ class SVGP(LatentFunctionInference):
|
|||
tmp = S.reshape(-1, num_inducing).dot(Kmmi).reshape(num_outputs, num_inducing , num_inducing )
|
||||
tmp = 2.*(tmp - np.eye(num_inducing)[None, :,:])
|
||||
|
||||
dF_dKnm = Kmmim.dot(dF_dmu.T).T
|
||||
assert dF_dKnm.flags['F_CONTIGUOUS'] # needed for dsymm in place call:
|
||||
dF_dKmn = Kmmim.dot(dF_dmu.T)
|
||||
for a,b in zip(tmp, Adv):
|
||||
dsymm(1.0, a.T, b.T, beta=1., side=1, c=dF_dKnm, overwrite_c=True)
|
||||
dF_dKmn = dF_dKnm.T
|
||||
dF_dKmn += np.dot(a.T, b)
|
||||
|
||||
dF_dm = Admu
|
||||
dF_dS = AdvA
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue