Merge pull request #184 from mikecroucher/master

Bug fix for issue #161
This commit is contained in:
mikecroucher 2015-03-28 08:36:43 +00:00
commit 775ce9e64c
2 changed files with 12 additions and 2 deletions

View file

@ -1,6 +1,6 @@
import numpy as np
import scipy as sp
from ..util.linalg import jitchol
from ..util.linalg import jitchol,trace_dot
class LinalgTests(np.testing.TestCase):
def setUp(self):
@ -33,3 +33,13 @@ class LinalgTests(np.testing.TestCase):
return False
except sp.linalg.LinAlgError:
return True
def test_trace_dot(self):
N = 5
A = np.random.rand(N,N)
B = np.random.rand(N,N)
trace = np.trace(A.dot(B))
test_trace = trace_dot(A,B)
np.testing.assert_allclose(trace,test_trace,atol=1e-13)

View file

@ -191,7 +191,7 @@ def trace_dot(a, b):
"""
Efficiently compute the trace of the matrix product of a and b
"""
return np.sum(a * b)
return np.einsum('ij,ji->', a, b)
def mdot(*args):
"""