Bug fix for issue #161

This commit is contained in:
Mike Croucher 2015-03-15 11:10:45 +00:00
parent 9f51137469
commit c307b989cc
2 changed files with 12 additions and 2 deletions

View file

@ -1,6 +1,6 @@
import numpy as np import numpy as np
import scipy as sp import scipy as sp
from ..util.linalg import jitchol from ..util.linalg import jitchol,trace_dot
class LinalgTests(np.testing.TestCase): class LinalgTests(np.testing.TestCase):
def setUp(self): def setUp(self):
@ -33,3 +33,13 @@ class LinalgTests(np.testing.TestCase):
return False return False
except sp.linalg.LinAlgError: except sp.linalg.LinAlgError:
return True return True
def test_trace_dot(self):
N = 5
A = np.random.rand(N,N)
B = np.random.rand(N,N)
trace = np.trace(A.dot(B))
test_trace = trace_dot(A,B)
np.testing.assert_allclose(trace,test_trace,atol=1e-13)

View file

@ -191,7 +191,7 @@ def trace_dot(a, b):
""" """
Efficiently compute the trace of the matrix product of a and b Efficiently compute the trace of the matrix product of a and b
""" """
return np.sum(a * b) return np.einsum('ij,ji->', a, b)
def mdot(*args): def mdot(*args):
""" """