mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-04-25 04:46:23 +02:00
migrate state_space_main_tests to pytest
This commit is contained in:
parent
c8a6ca4e4d
commit
5f08c2c139
1 changed files with 12 additions and 21 deletions
|
|
@ -262,10 +262,7 @@ def generate_random_y_data(samples, dim, ts_no):
|
|||
return Y
|
||||
|
||||
|
||||
class StateSpaceKernelsTests(np.testing.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
class TestStateSpaceKernels:
|
||||
def run_descr_model(
|
||||
self,
|
||||
measurements,
|
||||
|
|
@ -290,12 +287,14 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
|
||||
state_dim = 1 if not isinstance(A, np.ndarray) else A.shape[0]
|
||||
ts_no = 1 if (len(measurements.shape) < 3) else measurements.shape[2]
|
||||
import importlib
|
||||
|
||||
grad_params_no = None if dA is None else dA.shape[2]
|
||||
|
||||
ss_setup.use_cython = use_cython
|
||||
global ssm
|
||||
if (ssm.cython_code_available) and (ssm.use_cython != use_cython):
|
||||
reload(ssm)
|
||||
importlib.reload(ssm.DescreteStateSpace)
|
||||
|
||||
grad_calc_params = None
|
||||
if calc_grad_log_likelihood:
|
||||
|
|
@ -328,7 +327,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
)
|
||||
|
||||
f_mean_squeezed = np.squeeze(f_mean[1:, :]) # exclude initial value
|
||||
f_var_squeezed = np.squeeze(f_var[1:, :]) # exclude initial value
|
||||
_f_var_squeezed = np.squeeze(f_var[1:, :]) # exclude initial value
|
||||
|
||||
if true_states is not None:
|
||||
# print np.max(np.abs(f_mean_squeezed-true_states))
|
||||
|
|
@ -345,7 +344,7 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
f_var.shape, (measurements.shape[0] + 1, state_dim, state_dim)
|
||||
)
|
||||
|
||||
(M_smooth, P_smooth) = ssm.DescreteStateSpace.rts_smoother(
|
||||
(_M_smooth, _P_smooth) = ssm.DescreteStateSpace.rts_smoother(
|
||||
state_dim, dynamic_callables_smoother, f_mean, f_var
|
||||
)
|
||||
|
||||
|
|
@ -376,10 +375,12 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
state_dim = 1 if not isinstance(F, np.ndarray) else F.shape[0]
|
||||
ts_no = 1 if (len(Y_data.shape) < 3) else Y_data.shape[2]
|
||||
|
||||
import importlib
|
||||
|
||||
ss_setup.use_cython = use_cython
|
||||
global ssm
|
||||
if (ssm.cython_code_available) and (ssm.use_cython != use_cython):
|
||||
reload(ssm)
|
||||
importlib.reload(ssm)
|
||||
|
||||
(
|
||||
f_mean,
|
||||
|
|
@ -406,15 +407,15 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
grad_calc_params=grad_calc_params,
|
||||
)
|
||||
|
||||
f_mean_squeezed = np.squeeze(f_mean[1:, :]) # exclude initial value
|
||||
f_var_squeezed = np.squeeze(f_var[1:, :]) # exclude initial value
|
||||
_f_mean_squeezed = np.squeeze(f_mean[1:, :]) # exclude initial value
|
||||
_f_var_squeezed = np.squeeze(f_var[1:, :]) # exclude initial value
|
||||
|
||||
np.testing.assert_equal(f_mean.shape, (Y_data.shape[0] + 1, state_dim, ts_no))
|
||||
np.testing.assert_equal(
|
||||
f_var.shape, (Y_data.shape[0] + 1, state_dim, state_dim)
|
||||
)
|
||||
|
||||
(M_smooth, P_smooth) = ssm.ContDescrStateSpace.cont_discr_rts_smoother(
|
||||
(_M_smooth, _P_smooth) = ssm.ContDescrStateSpace.cont_discr_rts_smoother(
|
||||
state_dim, f_mean, f_var, dynamic_callables_smoother
|
||||
)
|
||||
|
||||
|
|
@ -1516,13 +1517,3 @@ class StateSpaceKernelsTests(np.testing.TestCase):
|
|||
# plt.show()
|
||||
# # plotting <-
|
||||
# # 1D measurements, 1 ts_no <-
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Running state-space inference tests...")
|
||||
unittest.main()
|
||||
|
||||
# tt = StateSpaceKernelsTests('test_discrete_ss_first')
|
||||
# res = tt.test_discrete_ss_first(plot=True)
|
||||
# res = tt.test_discrete_ss_1D(plot=True)
|
||||
# res = tt.test_discrete_ss_2D(plot=False)
|
||||
# res = tt.test_continuos_ss(plot=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue