migrate linalg test to pytest

This commit is contained in:
Martin Bubel 2023-10-06 18:51:48 +02:00
parent 8af7c8286c
commit ef7d2f299c

View file

@ -3,7 +3,7 @@ import scipy as sp
from ..util.linalg import jitchol, trace_dot, ijk_jlk_to_il, ijk_ljk_to_ilk from ..util.linalg import jitchol, trace_dot, ijk_jlk_to_il, ijk_ljk_to_ilk
class LinalgTests: class TestLinalg:
def setup(self): def setup(self):
# Create PD matrix # Create PD matrix
A = np.random.randn(20, 100) A = np.random.randn(20, 100)
@ -21,6 +21,7 @@ class LinalgTests:
Expect 5 rounds of jitter to be added and for the recovered matrix to be Expect 5 rounds of jitter to be added and for the recovered matrix to be
identical to the corrupted matrix apart from the jitter added to the diagonal identical to the corrupted matrix apart from the jitter added to the diagonal
""" """
self.setup()
L = jitchol(self.A_corrupt, maxtries=5) L = jitchol(self.A_corrupt, maxtries=5)
A_new = L.dot(L.T) A_new = L.dot(L.T)
diff = A_new - self.A_corrupt diff = A_new - self.A_corrupt
@ -29,6 +30,7 @@ class LinalgTests:
) )
def test_jitchol_failure(self): def test_jitchol_failure(self):
self.setup()
try: try:
""" """
Expecting an exception to be thrown as we expect it to require Expecting an exception to be thrown as we expect it to require
@ -40,6 +42,7 @@ class LinalgTests:
return True return True
def test_trace_dot(self): def test_trace_dot(self):
self.setup()
N = 5 N = 5
A = np.random.rand(N, N) A = np.random.rand(N, N)
B = np.random.rand(N, N) B = np.random.rand(N, N)
@ -48,6 +51,7 @@ class LinalgTests:
np.testing.assert_allclose(trace, test_trace, atol=1e-13) np.testing.assert_allclose(trace, test_trace, atol=1e-13)
def test_einsum_ij_jlk_to_ilk(self): def test_einsum_ij_jlk_to_ilk(self):
self.setup()
A = np.random.randn(15, 150, 5) A = np.random.randn(15, 150, 5)
B = np.random.randn(150, 50, 5) B = np.random.randn(150, 50, 5)
pure = np.einsum("ijk,jlk->il", A, B) pure = np.einsum("ijk,jlk->il", A, B)
@ -55,6 +59,7 @@ class LinalgTests:
np.testing.assert_allclose(pure, quick) np.testing.assert_allclose(pure, quick)
def test_einsum_ijk_ljk_to_ilk(self): def test_einsum_ijk_ljk_to_ilk(self):
self.setup()
A = np.random.randn(150, 20, 5) A = np.random.randn(150, 20, 5)
B = np.random.randn(150, 20, 5) B = np.random.randn(150, 20, 5)
# B = A.copy() # B = A.copy()