Return deserialized models with actual type instead of base type

This commit is contained in:
Keerthana Elango 2018-07-24 10:46:33 +01:00
parent 06441f583f
commit eca5806518
5 changed files with 54 additions and 40 deletions

View file

@ -237,7 +237,9 @@ class Test(unittest.TestCase):
m.save_model("temp_test_gp_classifier_with_data.json", compress=True, save_data=True)
m.save_model("temp_test_gp_classifier_without_data.json", compress=True, save_data=False)
m1_r = GPy.models.GPClassification.load_model("temp_test_gp_classifier_with_data.json.zip")
self.assertTrue(type(m) == type(m1_r), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m1_r)))
m2_r = GPy.models.GPClassification.load_model("temp_test_gp_classifier_without_data.json.zip", (X,Y))
self.assertTrue(type(m) == type(m2_r), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m2_r)))
os.remove("temp_test_gp_classifier_with_data.json.zip")
os.remove("temp_test_gp_classifier_without_data.json.zip")
@ -259,7 +261,9 @@ class Test(unittest.TestCase):
m.save_model("temp_test_sparse_gp_classifier_with_data.json", compress=True, save_data=True)
m.save_model("temp_test_sparse_gp_classifier_without_data.json", compress=True, save_data=False)
m1_r = GPy.models.SparseGPClassification.load_model("temp_test_sparse_gp_classifier_with_data.json.zip")
self.assertTrue(type(m) == type(m1_r), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m1_r)))
m2_r = GPy.models.SparseGPClassification.load_model("temp_test_sparse_gp_classifier_without_data.json.zip", (X,Y))
self.assertTrue(type(m) == type(m2_r), "Incorrect model type. Expected: {} Actual: {}".format(type(m), type(m2_r)))
os.remove("temp_test_sparse_gp_classifier_with_data.json.zip")
os.remove("temp_test_sparse_gp_classifier_without_data.json.zip")