fix pytesting pickle

This commit is contained in:
Martin Bubel 2023-10-17 08:12:07 +02:00
parent 3a8b093c65
commit 9006236463

View file

@ -53,8 +53,8 @@ class TestPickleSupport(ListDictTestCase):
assert par.param_array.tolist() == pcopy.param_array.tolist() assert par.param_array.tolist() == pcopy.param_array.tolist()
np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full) np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full)
assert str(par) == str(pcopy) assert str(par) == str(pcopy)
assert np.all(par.param_array != pcopy.param_array) assert np.all(par.param_array == pcopy.param_array)
assert par.gradient_full != pcopy.gradient_full assert np.all(par.gradient_full == pcopy.gradient_full)
assert pcopy.checkgrad() assert pcopy.checkgrad()
assert np.any(pcopy.gradient != 0.0) assert np.any(pcopy.gradient != 0.0)
with tempfile.TemporaryFile("w+b") as f: with tempfile.TemporaryFile("w+b") as f:
@ -72,8 +72,8 @@ class TestPickleSupport(ListDictTestCase):
np.testing.assert_allclose(par.param_array, pcopy.param_array) np.testing.assert_allclose(par.param_array, pcopy.param_array)
np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full) np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full)
assert str(par) == str(pcopy) assert str(par) == str(pcopy)
assert np.all(par.param_array != pcopy.param_array) assert np.all(par.param_array == pcopy.param_array)
assert par.gradient_full != pcopy.gradient_full assert np.all(par.gradient_full == pcopy.gradient_full)
assert pcopy.checkgrad() assert pcopy.checkgrad()
assert np.any(pcopy.gradient != 0.0) assert np.any(pcopy.gradient != 0.0)
np.testing.assert_allclose(pcopy.param_array, par.param_array, atol=1e-6) np.testing.assert_allclose(pcopy.param_array, par.param_array, atol=1e-6)
@ -97,8 +97,8 @@ class TestPickleSupport(ListDictTestCase):
assert par.param_array.tolist() == pcopy.param_array.tolist() assert par.param_array.tolist() == pcopy.param_array.tolist()
assert par.gradient_full.tolist() == pcopy.gradient_full.tolist() assert par.gradient_full.tolist() == pcopy.gradient_full.tolist()
assert str(par) == str(pcopy) assert str(par) == str(pcopy)
assert np.all(par.param_array != pcopy.param_array) assert np.all(par.param_array == pcopy.param_array)
assert par.gradient_full != pcopy.gradient_full assert np.all(par.gradient_full == pcopy.gradient_full)
with tempfile.TemporaryFile("w+b") as f: with tempfile.TemporaryFile("w+b") as f:
par.pickle(f) par.pickle(f)
f.seek(0) f.seek(0)
@ -116,8 +116,8 @@ class TestPickleSupport(ListDictTestCase):
assert par.param_array.tolist() == pcopy.param_array.tolist() assert par.param_array.tolist() == pcopy.param_array.tolist()
assert par.gradient_full.tolist() == pcopy.gradient_full.tolist() assert par.gradient_full.tolist() == pcopy.gradient_full.tolist()
assert str(par) == str(pcopy) assert str(par) == str(pcopy)
assert np.all(par.param_array != pcopy.param_array) assert np.all(par.param_array == pcopy.param_array)
assert par.gradient_full != pcopy.gradient_full assert np.all(par.gradient_full == pcopy.gradient_full)
assert par.checkgrad() assert par.checkgrad()
assert pcopy.checkgrad() assert pcopy.checkgrad()
assert np.any(pcopy.gradient != 0.0) assert np.any(pcopy.gradient != 0.0)