mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-27 14:25:16 +02:00
migrate serialzation_tests to pytest
This commit is contained in:
parent
96d8ac0975
commit
c69f68feba
1 changed files with 8 additions and 23 deletions
|
|
@ -10,7 +10,7 @@ import os
|
|||
fixed_seed = 11
|
||||
|
||||
|
||||
class Test(unittest.TestCase):
|
||||
class TestSerialization:
|
||||
def test_serialize_deserialize_kernels(self):
|
||||
k1 = GPy.kern.RBF(2, variance=1.0, lengthscale=[1.0, 1.0], ARD=True)
|
||||
k2 = GPy.kern.RatQuad(
|
||||
|
|
@ -371,23 +371,19 @@ class Test(unittest.TestCase):
|
|||
m1_r = GPy.models.GPClassification.load_model(
|
||||
"temp_test_gp_classifier_with_data.json.zip"
|
||||
)
|
||||
self.assertTrue(
|
||||
type(m) == type(m1_r),
|
||||
"Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m1_r)),
|
||||
)
|
||||
assert type(m) == type(
|
||||
m1_r
|
||||
), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m1_r))
|
||||
m2_r = GPy.models.GPClassification.load_model(
|
||||
"temp_test_gp_classifier_without_data.json.zip", (X, Y)
|
||||
)
|
||||
self.assertTrue(
|
||||
type(m) == type(m2_r),
|
||||
"Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m2_r)),
|
||||
)
|
||||
assert type(m) == type(m2_r), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m2_r)),
|
||||
os.remove("temp_test_gp_classifier_with_data.json.zip")
|
||||
os.remove("temp_test_gp_classifier_without_data.json.zip")
|
||||
|
||||
var = m.predict(X)[0]
|
||||
var1_r = m1_r.predict(X)[0]
|
||||
var2_r = m2_r.predict(X)[0]
|
||||
_var2_r = m2_r.predict(X)[0]
|
||||
np.testing.assert_array_equal(
|
||||
np.array(var).flatten(), np.array(var1_r).flatten()
|
||||
)
|
||||
|
|
@ -419,17 +415,11 @@ class Test(unittest.TestCase):
|
|||
m1_r = GPy.models.SparseGPClassification.load_model(
|
||||
"temp_test_sparse_gp_classifier_with_data.json.zip"
|
||||
)
|
||||
self.assertTrue(
|
||||
type(m) == type(m1_r),
|
||||
"Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m1_r)),
|
||||
)
|
||||
assert type(m) == type(m1_r), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m1_r))
|
||||
m2_r = GPy.models.SparseGPClassification.load_model(
|
||||
"temp_test_sparse_gp_classifier_without_data.json.zip", (X, Y)
|
||||
)
|
||||
self.assertTrue(
|
||||
type(m) == type(m2_r),
|
||||
"Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m2_r)),
|
||||
)
|
||||
assert type(m) == type(m2_r), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m2_r)),
|
||||
os.remove("temp_test_sparse_gp_classifier_with_data.json.zip")
|
||||
os.remove("temp_test_sparse_gp_classifier_without_data.json.zip")
|
||||
|
||||
|
|
@ -442,8 +432,3 @@ class Test(unittest.TestCase):
|
|||
np.testing.assert_array_equal(
|
||||
np.array(var).flatten(), np.array(var1_r).flatten()
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# import sys;sys.argv = ['', 'Test.test_parameter_index_operations']
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue