migrate state_space_main_tests to pytest

This commit is contained in:
Martin Bubel 2023-10-10 20:00:18 +02:00
parent c8a6ca4e4d
commit 5f08c2c139

View file

@ -262,10 +262,7 @@ def generate_random_y_data(samples, dim, ts_no):
return Y return Y
class StateSpaceKernelsTests(np.testing.TestCase): class TestStateSpaceKernels:
def setUp(self):
pass
def run_descr_model( def run_descr_model(
self, self,
measurements, measurements,
@ -290,12 +287,14 @@ class StateSpaceKernelsTests(np.testing.TestCase):
state_dim = 1 if not isinstance(A, np.ndarray) else A.shape[0] state_dim = 1 if not isinstance(A, np.ndarray) else A.shape[0]
ts_no = 1 if (len(measurements.shape) < 3) else measurements.shape[2] ts_no = 1 if (len(measurements.shape) < 3) else measurements.shape[2]
import importlib
grad_params_no = None if dA is None else dA.shape[2] grad_params_no = None if dA is None else dA.shape[2]
ss_setup.use_cython = use_cython ss_setup.use_cython = use_cython
global ssm global ssm
if (ssm.cython_code_available) and (ssm.use_cython != use_cython): if (ssm.cython_code_available) and (ssm.use_cython != use_cython):
reload(ssm) importlib.reload(ssm.DescreteStateSpace)
grad_calc_params = None grad_calc_params = None
if calc_grad_log_likelihood: if calc_grad_log_likelihood:
@ -328,7 +327,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
) )
f_mean_squeezed = np.squeeze(f_mean[1:, :]) # exclude initial value f_mean_squeezed = np.squeeze(f_mean[1:, :]) # exclude initial value
f_var_squeezed = np.squeeze(f_var[1:, :]) # exclude initial value _f_var_squeezed = np.squeeze(f_var[1:, :]) # exclude initial value
if true_states is not None: if true_states is not None:
# print np.max(np.abs(f_mean_squeezed-true_states)) # print np.max(np.abs(f_mean_squeezed-true_states))
@ -345,7 +344,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
f_var.shape, (measurements.shape[0] + 1, state_dim, state_dim) f_var.shape, (measurements.shape[0] + 1, state_dim, state_dim)
) )
(M_smooth, P_smooth) = ssm.DescreteStateSpace.rts_smoother( (_M_smooth, _P_smooth) = ssm.DescreteStateSpace.rts_smoother(
state_dim, dynamic_callables_smoother, f_mean, f_var state_dim, dynamic_callables_smoother, f_mean, f_var
) )
@ -376,10 +375,12 @@ class StateSpaceKernelsTests(np.testing.TestCase):
state_dim = 1 if not isinstance(F, np.ndarray) else F.shape[0] state_dim = 1 if not isinstance(F, np.ndarray) else F.shape[0]
ts_no = 1 if (len(Y_data.shape) < 3) else Y_data.shape[2] ts_no = 1 if (len(Y_data.shape) < 3) else Y_data.shape[2]
import importlib
ss_setup.use_cython = use_cython ss_setup.use_cython = use_cython
global ssm global ssm
if (ssm.cython_code_available) and (ssm.use_cython != use_cython): if (ssm.cython_code_available) and (ssm.use_cython != use_cython):
reload(ssm) importlib.reload(ssm)
( (
f_mean, f_mean,
@ -406,15 +407,15 @@ class StateSpaceKernelsTests(np.testing.TestCase):
grad_calc_params=grad_calc_params, grad_calc_params=grad_calc_params,
) )
f_mean_squeezed = np.squeeze(f_mean[1:, :]) # exclude initial value _f_mean_squeezed = np.squeeze(f_mean[1:, :]) # exclude initial value
f_var_squeezed = np.squeeze(f_var[1:, :]) # exclude initial value _f_var_squeezed = np.squeeze(f_var[1:, :]) # exclude initial value
np.testing.assert_equal(f_mean.shape, (Y_data.shape[0] + 1, state_dim, ts_no)) np.testing.assert_equal(f_mean.shape, (Y_data.shape[0] + 1, state_dim, ts_no))
np.testing.assert_equal( np.testing.assert_equal(
f_var.shape, (Y_data.shape[0] + 1, state_dim, state_dim) f_var.shape, (Y_data.shape[0] + 1, state_dim, state_dim)
) )
(M_smooth, P_smooth) = ssm.ContDescrStateSpace.cont_discr_rts_smoother( (_M_smooth, _P_smooth) = ssm.ContDescrStateSpace.cont_discr_rts_smoother(
state_dim, f_mean, f_var, dynamic_callables_smoother state_dim, f_mean, f_var, dynamic_callables_smoother
) )
@ -1516,13 +1517,3 @@ class StateSpaceKernelsTests(np.testing.TestCase):
# plt.show() # plt.show()
# # plotting <- # # plotting <-
# # 1D measurements, 1 ts_no <- # # 1D measurements, 1 ts_no <-
if __name__ == "__main__":
print("Running state-space inference tests...")
unittest.main()
# tt = StateSpaceKernelsTests('test_discrete_ss_first')
# res = tt.test_discrete_ss_first(plot=True)
# res = tt.test_discrete_ss_1D(plot=True)
# res = tt.test_discrete_ss_2D(plot=False)
# res = tt.test_continuos_ss(plot=True)