Added faster einsums to linalg, with a couple of tests

This commit is contained in:
Alan Saul 2015-04-30 11:16:29 +01:00
parent 5ad38ac640
commit 762e1e75b0
3 changed files with 35 additions and 8 deletions

View file

@ -4,8 +4,6 @@ from ...util import choleskies
import numpy as np import numpy as np
from .posterior import Posterior from .posterior import Posterior
def ij_ijk_to_ikl
class SVGP(LatentFunctionInference): class SVGP(LatentFunctionInference):
def inference(self, q_u_mean, q_u_chol, kern, X, Z, likelihood, Y, mean_function=None, Y_metadata=None, KL_scale=1.0, batch_scale=1.0): def inference(self, q_u_mean, q_u_chol, kern, X, Z, likelihood, Y, mean_function=None, Y_metadata=None, KL_scale=1.0, batch_scale=1.0):
@ -43,8 +41,8 @@ class SVGP(LatentFunctionInference):
#compute the marginal means and variances of q(f) #compute the marginal means and variances of q(f)
A = np.dot(Knm, Kmmi) A = np.dot(Knm, Kmmi)
mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u) mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u)
#v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jkl->ikl', A, S),1) #v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jlk->ilk', A, S),1)
v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] *A.dot(S.reshape(S.shape[0],-1)).reshape(A.shape[0],S.shape[1],S.shape[2]),1) v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * linalg.ij_jlk_to_ilk(A, S),1)
#compute the KL term #compute the KL term
Kmmim = np.dot(Kmmi, q_u_mean) Kmmim = np.dot(Kmmi, q_u_mean)
@ -82,12 +80,13 @@ class SVGP(LatentFunctionInference):
Admu = A.T.dot(dF_dmu) Admu = A.T.dot(dF_dmu)
AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(num_outputs)]) AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(num_outputs)])
#tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi) #tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi)
tmp = np.sum([np.dot(AdvA[:,:,k], S[:,:,k]) for k in range(S.shape[-1])],0).dot(Kmmi) tmp = linalg.ijk_jlk_to_il(AdvA, S).dot(Kmmi)
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(-1) - tmp - tmp.T dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(-1) - tmp - tmp.T
dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug? dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(num_inducing)[:,:,None]) #tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(num_inducing)[:,:,None])
tmp = 2.*(linalg.ij_jlk_to_ilk(Kmmi, S) - np.eye(num_inducing)[:,:,None])
#dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T) #dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T)
dF_dKmn = np.sum([np.dot(tmp[:,:,k], Adv[:,:,k]) for k in range(Adv.shape[-1])],0) + Kmmim.dot(dF_dmu.T) dF_dKmn = linalg.ijk_jlk_to_il(tmp, Adv) + Kmmim.dot(dF_dmu.T)
dF_dm = Admu dF_dm = Admu
dF_dS = AdvA dF_dS = AdvA

View file

@ -1,6 +1,7 @@
import numpy as np import numpy as np
import scipy as sp import scipy as sp
from ..util.linalg import jitchol from GPy.util.linalg import jitchol
import GPy
class LinalgTests(np.testing.TestCase): class LinalgTests(np.testing.TestCase):
def setUp(self): def setUp(self):
@ -35,3 +36,17 @@ class LinalgTests(np.testing.TestCase):
return False return False
except sp.linalg.LinAlgError: except sp.linalg.LinAlgError:
return True return True
def test_einsum_ijk_jlk_to_il(self):
    """Check GPy.util.linalg.ijk_jlk_to_il against the plain np.einsum contraction."""
    lhs = np.random.randn(50, 150, 5)
    rhs = np.random.randn(150, 100, 5)
    # Reference value from einsum, then the hand-rolled replacement.
    expected = np.einsum('ijk,jlk->il', lhs, rhs)
    actual = GPy.util.linalg.ijk_jlk_to_il(lhs, rhs)
    np.testing.assert_allclose(expected, actual)
def test_einsum_ij_jlk_to_ilk(self):
    """Check GPy.util.linalg.ij_jlk_to_ilk against np.einsum('ij,jlk->ilk').

    A is 2-d (i, j) and B is 3-d (j, l, k), matching the subscripts in the
    helper's name. The previous body re-ran the 'ijk,jlk->il' check, so
    ij_jlk_to_ilk itself was never exercised.
    """
    A = np.random.randn(15, 150)
    B = np.random.randn(150, 50, 5)
    pure = np.einsum('ij,jlk->ilk', A, B)
    quick = GPy.util.linalg.ij_jlk_to_ilk(A, B)
    np.testing.assert_allclose(pure, quick)

View file

@ -452,3 +452,16 @@ def backsub_both_sides(L, X, transpose='left'):
tmp, _ = dtrtrs(L, X, lower=1, trans=0) tmp, _ = dtrtrs(L, X, lower=1, trans=0)
return dtrtrs(L, tmp.T, lower=1, trans=0)[0].T return dtrtrs(L, tmp.T, lower=1, trans=0)[0].T
def ij_jlk_to_ilk(A, B):
    """
    Drop-in replacement for np.einsum('ij,jlk->ilk', A, B), intended to be faster.

    The two trailing axes of B are flattened so the contraction over j
    becomes a single matrix product, and the result is reshaped back.

    :param A: 2-d array indexed (i, j)
    :param B: 3-d array indexed (j, l, k)
    :returns: array of shape (A.shape[0], B.shape[1], B.shape[2])
    """
    out_shape = (A.shape[0], B.shape[1], B.shape[2])
    # Collapse (l, k) into one axis: B is viewed as (j, l*k).
    B_flat = B.reshape(B.shape[0], -1)
    product = A.dot(B_flat)
    # Restore the separate l and k axes on the result.
    return product.reshape(*out_shape)
def ijk_jlk_to_il(A, B):
    """
    Drop-in replacement for np.einsum('ijk,jlk->il', A, B), intended to be faster.

    For each index k along the last axis, the matrix product of the
    slices A[:, :, k] and B[:, :, k] is accumulated into the result.

    :param A: 3-d array indexed (i, j, k)
    :param B: 3-d array indexed (j, l, k)
    :returns: float array of shape (A.shape[0], B.shape[1])
    """
    res = np.zeros((A.shape[0], B.shape[1]))
    # Plain loop with in-place accumulation; a list comprehension here would
    # only build a throwaway list of references to the output array.
    for k in range(B.shape[-1]):
        res += np.dot(A[:, :, k], B[:, :, k])
    return res