diff --git a/GPy/testing/linalg_test.py b/GPy/testing/linalg_test.py index b734f6af..3a65735c 100644 --- a/GPy/testing/linalg_test.py +++ b/GPy/testing/linalg_test.py @@ -1,6 +1,6 @@ import numpy as np import scipy as sp -from ..util.linalg import jitchol +from ..util.linalg import jitchol,trace_dot class LinalgTests(np.testing.TestCase): def setUp(self): @@ -33,3 +33,13 @@ class LinalgTests(np.testing.TestCase): return False except sp.linalg.LinAlgError: return True + + def test_trace_dot(self): + N = 5 + A = np.random.rand(N,N) + B = np.random.rand(N,N) + trace = np.trace(A.dot(B)) + test_trace = trace_dot(A,B) + np.testing.assert_allclose(trace,test_trace,atol=1e-13) + + diff --git a/GPy/util/linalg.py b/GPy/util/linalg.py index 2c02357c..66254685 100644 --- a/GPy/util/linalg.py +++ b/GPy/util/linalg.py @@ -191,7 +191,7 @@ def trace_dot(a, b): """ Efficiently compute the trace of the matrix product of a and b """ - return np.sum(a * b) + return np.einsum('ij,ji->', a, b) def mdot(*args): """