migrate pickle_tests to pytest

This commit is contained in:
Martin Bubel 2023-10-10 19:39:19 +02:00
parent 1464c1253f
commit e4ea3bc8b2

View file

@ -5,6 +5,7 @@ Created on 13 Mar 2014
""" """
# import cPickle as pickle # import cPickle as pickle
import pickle import pickle
import pytest
import numpy as np import numpy as np
import tempfile import tempfile
from GPy.examples.dimensionality_reduction import mrd_simulation from GPy.examples.dimensionality_reduction import mrd_simulation
@ -20,7 +21,7 @@ def toy_model():
return m return m
class ListDictTestCase(unittest.TestCase): class ListDictTestCase:
def assertListDictEquals(self, d1, d2, msg=None): def assertListDictEquals(self, d1, d2, msg=None):
# py3 fix # py3 fix
# for k,v in d1.iteritems(): # for k,v in d1.iteritems():
@ -32,8 +33,9 @@ class ListDictTestCase(unittest.TestCase):
np.testing.assert_array_equal(a1, a2) np.testing.assert_array_equal(a1, a2)
class Test(ListDictTestCase): class TestPickleSupport(ListDictTestCase):
@SkipTest # TODO: why is this test skipped?
@pytest.mark.skip("") # TODO
def test_load_pickle(self): def test_load_pickle(self):
import os import os
@ -43,37 +45,37 @@ class Test(ListDictTestCase):
) )
) )
assert m.checkgrad() assert m.checkgrad()
self.assertEqual(m.log_likelihood(), -4.7351019830022087) assert m.log_likelihood(), -4.7351019830022087
def test_model(self): def test_model(self):
par = toy_model() par = toy_model()
pcopy = par.copy() pcopy = par.copy()
self.assertListEqual(par.param_array.tolist(), pcopy.param_array.tolist()) assert par.param_array.tolist() == pcopy.param_array.tolist()
np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full) np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full)
self.assertSequenceEqual(str(par), str(pcopy)) assert str(par) == str(pcopy)
self.assertIsNot(par.param_array, pcopy.param_array) assert par.param_array != pcopy.param_array
self.assertIsNot(par.gradient_full, pcopy.gradient_full) assert par.gradient_full != pcopy.gradient_full
self.assertTrue(pcopy.checkgrad()) assert pcopy.checkgrad()
self.assert_(np.any(pcopy.gradient != 0.0)) assert np.any(pcopy.gradient != 0.0)
with tempfile.TemporaryFile("w+b") as f: with tempfile.TemporaryFile("w+b") as f:
par.pickle(f) par.pickle(f)
f.seek(0) f.seek(0)
pcopy = pickle.load(f) pcopy = pickle.load(f)
self.assertListEqual(par.param_array.tolist(), pcopy.param_array.tolist()) assert par.param_array.tolist() == pcopy.param_array.tolist()
np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full) np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full)
self.assertSequenceEqual(str(par), str(pcopy)) assert str(par) == str(pcopy)
self.assert_(pcopy.checkgrad()) assert pcopy.checkgrad()
def test_modelrecreation(self): def test_modelrecreation(self):
par = toy_model() par = toy_model()
pcopy = GPRegression(par.X.copy(), par.Y.copy(), kernel=par.kern.copy()) pcopy = GPRegression(par.X.copy(), par.Y.copy(), kernel=par.kern.copy())
np.testing.assert_allclose(par.param_array, pcopy.param_array) np.testing.assert_allclose(par.param_array, pcopy.param_array)
np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full) np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full)
self.assertSequenceEqual(str(par), str(pcopy)) assert str(par) == str(pcopy)
self.assertIsNot(par.param_array, pcopy.param_array) assert par.param_array != pcopy.param_array
self.assertIsNot(par.gradient_full, pcopy.gradient_full) assert par.gradient_full != pcopy.gradient_full
self.assertTrue(pcopy.checkgrad()) assert pcopy.checkgrad()
self.assert_(np.any(pcopy.gradient != 0.0)) assert np.any(pcopy.gradient != 0.0)
np.testing.assert_allclose(pcopy.param_array, par.param_array, atol=1e-6) np.testing.assert_allclose(pcopy.param_array, par.param_array, atol=1e-6)
par.randomize() par.randomize()
with tempfile.TemporaryFile("w+b") as f: with tempfile.TemporaryFile("w+b") as f:
@ -82,8 +84,8 @@ class Test(ListDictTestCase):
pcopy = pickle.load(f) pcopy = pickle.load(f)
np.testing.assert_allclose(par.param_array, pcopy.param_array) np.testing.assert_allclose(par.param_array, pcopy.param_array)
np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full, atol=1e-6) np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full, atol=1e-6)
self.assertSequenceEqual(str(par), str(pcopy)) assert str(par) == str(pcopy)
self.assert_(pcopy.checkgrad()) assert pcopy.checkgrad()
def test_posterior(self): def test_posterior(self):
X = np.random.randn(3, 5) X = np.random.randn(3, 5)
@ -92,46 +94,41 @@ class Test(ListDictTestCase):
par.gradient = 10 par.gradient = 10
pcopy = par.copy() pcopy = par.copy()
pcopy.gradient = 10 pcopy.gradient = 10
self.assertListEqual(par.param_array.tolist(), pcopy.param_array.tolist()) assert par.param_array.tolist() == pcopy.param_array.tolist()
self.assertListEqual(par.gradient_full.tolist(), pcopy.gradient_full.tolist()) assert par.gradient_full.tolist() == pcopy.gradient_full.tolist()
self.assertSequenceEqual(str(par), str(pcopy)) assert str(par) == str(pcopy)
self.assertIsNot(par.param_array, pcopy.param_array) assert par.param_array != pcopy.param_array
self.assertIsNot(par.gradient_full, pcopy.gradient_full) assert par.gradient_full != pcopy.gradient_full
with tempfile.TemporaryFile("w+b") as f: with tempfile.TemporaryFile("w+b") as f:
par.pickle(f) par.pickle(f)
f.seek(0) f.seek(0)
pcopy = pickle.load(f) pcopy = pickle.load(f)
self.assertListEqual(par.param_array.tolist(), pcopy.param_array.tolist()) assert par.param_array.tolist() == pcopy.param_array.tolist()
pcopy.gradient = 10 pcopy.gradient = 10
np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full) np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full)
np.testing.assert_allclose(pcopy.mean.gradient_full, 10) np.testing.assert_allclose(pcopy.mean.gradient_full, 10)
self.assertSequenceEqual(str(par), str(pcopy)) assert str(par) == str(pcopy)
def test_model_concat(self): def test_model_concat(self):
par = mrd_simulation(optimize=0, plot=0, plot_sim=0) par = mrd_simulation(optimize=0, plot=0, plot_sim=0)
par.randomize() par.randomize()
pcopy = par.copy() pcopy = par.copy()
self.assertListEqual(par.param_array.tolist(), pcopy.param_array.tolist()) assert par.param_array.tolist() == pcopy.param_array.tolist()
self.assertListEqual(par.gradient_full.tolist(), pcopy.gradient_full.tolist()) assert par.gradient_full.tolist() == pcopy.gradient_full.tolist()
self.assertSequenceEqual(str(par), str(pcopy)) assert str(par) == str(pcopy)
self.assertIsNot(par.param_array, pcopy.param_array) assert par.param_array != pcopy.param_array
self.assertIsNot(par.gradient_full, pcopy.gradient_full) assert par.gradient_full != pcopy.gradient_full
self.assertTrue(par.checkgrad()) assert par.checkgrad()
self.assertTrue(pcopy.checkgrad()) assert pcopy.checkgrad()
self.assert_(np.any(pcopy.gradient != 0.0)) assert np.any(pcopy.gradient != 0.0)
with tempfile.TemporaryFile("w+b") as f: with tempfile.TemporaryFile("w+b") as f:
par.pickle(f) par.pickle(f)
f.seek(0) f.seek(0)
pcopy = pickle.load(f) pcopy = pickle.load(f)
self.assertListEqual(par.param_array.tolist(), pcopy.param_array.tolist()) assert par.param_array.tolist() == pcopy.param_array.tolist()
np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full) np.testing.assert_allclose(par.gradient_full, pcopy.gradient_full)
self.assertSequenceEqual(str(par), str(pcopy)) assert str(par) == str(pcopy)
self.assert_(pcopy.checkgrad()) assert pcopy.checkgrad()
def _callback(self, what, which): def _callback(self, what, which):
what.count += 1 what.count += 1
if __name__ == "__main__":
# import sys;sys.argv = ['', 'Test.test_parameter_index_operations']
unittest.main()