mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-12 13:32:39 +02:00
Bug fix for issue #161
This commit is contained in:
parent
9f51137469
commit
c307b989cc
2 changed files with 12 additions and 2 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy as sp
|
import scipy as sp
|
||||||
from ..util.linalg import jitchol
|
from ..util.linalg import jitchol,trace_dot
|
||||||
|
|
||||||
class LinalgTests(np.testing.TestCase):
|
class LinalgTests(np.testing.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|
@ -33,3 +33,13 @@ class LinalgTests(np.testing.TestCase):
|
||||||
return False
|
return False
|
||||||
except sp.linalg.LinAlgError:
|
except sp.linalg.LinAlgError:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def test_trace_dot(self):
|
||||||
|
N = 5
|
||||||
|
A = np.random.rand(N,N)
|
||||||
|
B = np.random.rand(N,N)
|
||||||
|
trace = np.trace(A.dot(B))
|
||||||
|
test_trace = trace_dot(A,B)
|
||||||
|
np.testing.assert_allclose(trace,test_trace,atol=1e-13)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -191,7 +191,7 @@ def trace_dot(a, b):
|
||||||
"""
|
"""
|
||||||
Efficiently compute the trace of the matrix product of a and b
|
Efficiently compute the trace of the matrix product of a and b
|
||||||
"""
|
"""
|
||||||
return np.sum(a * b)
|
return np.einsum('ij,ji->', a, b)
|
||||||
|
|
||||||
def mdot(*args):
|
def mdot(*args):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue