diff --git a/GPy/inference/latent_function_inference/svgp.py b/GPy/inference/latent_function_inference/svgp.py
index 9c5599bd..e416d0a5 100644
--- a/GPy/inference/latent_function_inference/svgp.py
+++ b/GPy/inference/latent_function_inference/svgp.py
@@ -4,8 +4,6 @@
 from ...util import choleskies
 import numpy as np
 from .posterior import Posterior
 
-def ij_ijk_to_ikl
-
 class SVGP(LatentFunctionInference):
     def inference(self, q_u_mean, q_u_chol, kern, X, Z, likelihood, Y, mean_function=None, Y_metadata=None, KL_scale=1.0, batch_scale=1.0):
@@ -43,8 +41,8 @@
         #compute the marginal means and variances of q(f)
         A = np.dot(Knm, Kmmi)
         mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u)
-        #v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jkl->ikl', A, S),1)
-        v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] *A.dot(S.reshape(S.shape[0],-1)).reshape(A.shape[0],S.shape[1],S.shape[2]),1)
+        #v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jlk->ilk', A, S),1)
+        v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * linalg.ij_jlk_to_ilk(A, S),1)
 
         #compute the KL term
         Kmmim = np.dot(Kmmi, q_u_mean)
@@ -82,12 +80,13 @@
         Admu = A.T.dot(dF_dmu)
         AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(num_outputs)])
         #tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi)
-        tmp = np.sum([np.dot(AdvA[:,:,k], S[:,:,k]) for k in range(S.shape[-1])],0).dot(Kmmi)
+        tmp = linalg.ijk_jlk_to_il(AdvA, S).dot(Kmmi)
         dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(-1) - tmp - tmp.T
         dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
 
-        tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(num_inducing)[:,:,None])
+        #tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(num_inducing)[:,:,None])
+        tmp = 2.*(linalg.ij_jlk_to_ilk(Kmmi, S) - np.eye(num_inducing)[:,:,None])
         #dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T)
-        dF_dKmn = np.sum([np.dot(tmp[:,:,k], Adv[:,:,k]) for k in range(Adv.shape[-1])],0) + Kmmim.dot(dF_dmu.T)
+        dF_dKmn = linalg.ijk_jlk_to_il(tmp, Adv) + Kmmim.dot(dF_dmu.T)
         dF_dm = Admu
         dF_dS = AdvA
diff --git a/GPy/testing/linalg_test.py b/GPy/testing/linalg_test.py
index 8e103795..81cb0368 100644
--- a/GPy/testing/linalg_test.py
+++ b/GPy/testing/linalg_test.py
@@ -1,6 +1,7 @@
 import numpy as np
 import scipy as sp
-from ..util.linalg import jitchol
+from GPy.util.linalg import jitchol
+import GPy
 
 class LinalgTests(np.testing.TestCase):
     def setUp(self):
@@ -35,3 +36,17 @@
             return False
         except sp.linalg.LinAlgError:
             return True
+
+    def test_einsum_ijk_jlk_to_il(self):
+        A = np.random.randn(50, 150, 5)
+        B = np.random.randn(150, 100, 5)
+        pure = np.einsum('ijk,jlk->il', A, B)
+        quick = GPy.util.linalg.ijk_jlk_to_il(A, B)
+        np.testing.assert_allclose(pure, quick)
+
+    def test_einsum_ij_jlk_to_ilk(self):
+        A = np.random.randn(15, 150)
+        B = np.random.randn(150, 50, 5)
+        pure = np.einsum('ij,jlk->ilk', A, B)
+        quick = GPy.util.linalg.ij_jlk_to_ilk(A, B)
+        np.testing.assert_allclose(pure, quick)
diff --git a/GPy/util/linalg.py b/GPy/util/linalg.py
index 29744b9f..285701e7 100644
--- a/GPy/util/linalg.py
+++ b/GPy/util/linalg.py
@@ -452,3 +452,17 @@
     tmp, _ = dtrtrs(L, X, lower=1, trans=0)
     return dtrtrs(L, tmp.T, lower=1, trans=0)[0].T
 
+def ij_jlk_to_ilk(A, B):
+    """
+    Faster version of np.einsum('ij,jlk->ilk', A, B)
+    """
+    return A.dot(B.reshape(B.shape[0], -1)).reshape(A.shape[0], B.shape[1], B.shape[2])
+
+def ijk_jlk_to_il(A, B):
+    """
+    Faster version of np.einsum('ijk,jlk->il', A, B)
+    """
+    res = np.zeros((A.shape[0], B.shape[1]))
+    for k in range(B.shape[-1]):
+        res += np.dot(A[:,:,k], B[:,:,k])
+    return res