mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-12 05:22:38 +02:00
Bug fix for issue #161
This commit is contained in:
parent
9f51137469
commit
c307b989cc
2 changed files with 12 additions and 2 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import numpy as np
|
||||
import scipy as sp
|
||||
from ..util.linalg import jitchol
|
||||
from ..util.linalg import jitchol,trace_dot
|
||||
|
||||
class LinalgTests(np.testing.TestCase):
|
||||
def setUp(self):
|
||||
|
|
@ -33,3 +33,13 @@ class LinalgTests(np.testing.TestCase):
|
|||
return False
|
||||
except sp.linalg.LinAlgError:
|
||||
return True
|
||||
|
||||
def test_trace_dot(self):
|
||||
N = 5
|
||||
A = np.random.rand(N,N)
|
||||
B = np.random.rand(N,N)
|
||||
trace = np.trace(A.dot(B))
|
||||
test_trace = trace_dot(A,B)
|
||||
np.testing.assert_allclose(trace,test_trace,atol=1e-13)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -191,7 +191,7 @@ def trace_dot(a, b):
|
|||
"""
|
||||
Efficiently compute the trace of the matrix product of a and b
|
||||
"""
|
||||
return np.sum(a * b)
|
||||
return np.einsum('ij,ji->', a, b)
|
||||
|
||||
def mdot(*args):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue