From 601efa7525583632936610dd127ce870a104aa70 Mon Sep 17 00:00:00 2001 From: James Hensman Date: Tue, 19 May 2015 13:33:47 +0100 Subject: [PATCH] svgp bugfix --- GPy/inference/latent_function_inference/svgp.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/GPy/inference/latent_function_inference/svgp.py b/GPy/inference/latent_function_inference/svgp.py index 3c4602cc..b04ca609 100644 --- a/GPy/inference/latent_function_inference/svgp.py +++ b/GPy/inference/latent_function_inference/svgp.py @@ -44,9 +44,11 @@ class SVGP(LatentFunctionInference): #compute the marginal means and variances of q(f) A, _ = linalg.dpotrs(Lm, Kmn) mu = prior_mean_f + np.dot(A.T, q_u_mean - prior_mean_u) - LA = L.reshape(-1, num_inducing).dot(A).reshape(num_outputs, num_inducing, num_data) - #TODO? possibly use dtrmm for the above line? - v = (Knn_diag - np.sum(A*Kmn,0))[:,None] + np.sum(np.square(LA),1).T + v = np.empty((num_data, num_outputs)) + for i in range(num_outputs): + tmp = dtrmm(1.0,L[i].T, A, lower=0, trans_a=0) + v[:,i] = np.sum(np.square(tmp),0) + v += (Knn_diag - np.sum(A*Kmn,0))[:,None] #compute the KL term Kmmim = np.dot(Kmmi, q_u_mean) @@ -90,11 +92,9 @@ class SVGP(LatentFunctionInference): tmp = S.reshape(-1, num_inducing).dot(Kmmi).reshape(num_outputs, num_inducing , num_inducing ) tmp = 2.*(tmp - np.eye(num_inducing)[None, :,:]) - dF_dKnm = Kmmim.dot(dF_dmu.T).T - assert dF_dKnm.flags['F_CONTIGUOUS'] # needed for dsymm in place call: + dF_dKmn = Kmmim.dot(dF_dmu.T) for a,b in zip(tmp, Adv): - dsymm(1.0, a.T, b.T, beta=1., side=1, c=dF_dKnm, overwrite_c=True) - dF_dKmn = dF_dKnm.T + dF_dKmn += np.dot(a.T, b) dF_dm = Admu dF_dS = AdvA