migrate state_space_main_tests to pytest

This commit is contained in:
Martin Bubel 2023-10-10 20:00:18 +02:00
parent c8a6ca4e4d
commit 5f08c2c139

View file

@ -262,10 +262,7 @@ def generate_random_y_data(samples, dim, ts_no):
return Y
class StateSpaceKernelsTests(np.testing.TestCase):
def setUp(self):
pass
class TestStateSpaceKernels:
def run_descr_model(
self,
measurements,
@ -290,12 +287,14 @@ class StateSpaceKernelsTests(np.testing.TestCase):
state_dim = 1 if not isinstance(A, np.ndarray) else A.shape[0]
ts_no = 1 if (len(measurements.shape) < 3) else measurements.shape[2]
import importlib
grad_params_no = None if dA is None else dA.shape[2]
ss_setup.use_cython = use_cython
global ssm
if (ssm.cython_code_available) and (ssm.use_cython != use_cython):
reload(ssm)
importlib.reload(ssm.DescreteStateSpace)
grad_calc_params = None
if calc_grad_log_likelihood:
@ -328,7 +327,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
)
f_mean_squeezed = np.squeeze(f_mean[1:, :]) # exclude initial value
f_var_squeezed = np.squeeze(f_var[1:, :]) # exclude initial value
_f_var_squeezed = np.squeeze(f_var[1:, :]) # exclude initial value
if true_states is not None:
# print np.max(np.abs(f_mean_squeezed-true_states))
@ -345,7 +344,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
f_var.shape, (measurements.shape[0] + 1, state_dim, state_dim)
)
(M_smooth, P_smooth) = ssm.DescreteStateSpace.rts_smoother(
(_M_smooth, _P_smooth) = ssm.DescreteStateSpace.rts_smoother(
state_dim, dynamic_callables_smoother, f_mean, f_var
)
@ -376,10 +375,12 @@ class StateSpaceKernelsTests(np.testing.TestCase):
state_dim = 1 if not isinstance(F, np.ndarray) else F.shape[0]
ts_no = 1 if (len(Y_data.shape) < 3) else Y_data.shape[2]
import importlib
ss_setup.use_cython = use_cython
global ssm
if (ssm.cython_code_available) and (ssm.use_cython != use_cython):
reload(ssm)
importlib.reload(ssm)
(
f_mean,
@ -406,15 +407,15 @@ class StateSpaceKernelsTests(np.testing.TestCase):
grad_calc_params=grad_calc_params,
)
f_mean_squeezed = np.squeeze(f_mean[1:, :]) # exclude initial value
f_var_squeezed = np.squeeze(f_var[1:, :]) # exclude initial value
_f_mean_squeezed = np.squeeze(f_mean[1:, :]) # exclude initial value
_f_var_squeezed = np.squeeze(f_var[1:, :]) # exclude initial value
np.testing.assert_equal(f_mean.shape, (Y_data.shape[0] + 1, state_dim, ts_no))
np.testing.assert_equal(
f_var.shape, (Y_data.shape[0] + 1, state_dim, state_dim)
)
(M_smooth, P_smooth) = ssm.ContDescrStateSpace.cont_discr_rts_smoother(
(_M_smooth, _P_smooth) = ssm.ContDescrStateSpace.cont_discr_rts_smoother(
state_dim, f_mean, f_var, dynamic_callables_smoother
)
@ -1516,13 +1517,3 @@ class StateSpaceKernelsTests(np.testing.TestCase):
# plt.show()
# # plotting <-
# # 1D measurements, 1 ts_no <-
if __name__ == "__main__":
print("Running state-space inference tests...")
unittest.main()
# tt = StateSpaceKernelsTests('test_discrete_ss_first')
# res = tt.test_discrete_ss_first(plot=True)
# res = tt.test_discrete_ss_1D(plot=True)
# res = tt.test_discrete_ss_2D(plot=False)
# res = tt.test_continuos_ss(plot=True)