Return deserialized models with actual type instead of base type

This commit is contained in:
Keerthana Elango 2018-07-24 10:46:33 +01:00
parent 06441f583f
commit eca5806518
5 changed files with 54 additions and 40 deletions

View file

@ -144,14 +144,14 @@ class GP(Model):
return input_dict
@staticmethod
def _build_from_input_dict(input_dict, data=None):
def _format_input_dict(input_dict, data=None):
import GPy
import numpy as np
if (input_dict['X'] is None) or (input_dict['Y'] is None):
assert(data is not None)
input_dict["X"], input_dict["Y"] = np.array(data[0]), np.array(data[1])
elif data is not None:
print("WARNING: The model has been saved with X,Y! The original values are being overriden!")
warnings.warn("WARNING: The model has been saved with X,Y! The original values are being overridden!")
input_dict["X"], input_dict["Y"] = np.array(data[0]), np.array(data[1])
else:
input_dict["X"], input_dict["Y"] = np.array(input_dict['X']), np.array(input_dict['Y'])
@ -173,6 +173,11 @@ class GP(Model):
input_dict["normalizer"] = GPy.util.normalizer._Norm.from_dict(normalizer)
else:
input_dict["normalizer"] = normalizer
return input_dict
@staticmethod
def _build_from_input_dict(input_dict, data=None):
input_dict = GP._format_input_dict(input_dict, data)
return GP(**input_dict)
def save_model(self, output_filename, compress=True, save_data=True):

View file

@ -130,37 +130,13 @@ class SparseGP(GP):
input_dict["Z"] = self.Z.tolist()
return input_dict
@staticmethod
def _format_input_dict(input_dict, data=None):
input_dict = GP._format_input_dict(input_dict, data)
input_dict["Z"] = np.array(input_dict["Z"])
return input_dict
@staticmethod
def _build_from_input_dict(input_dict, data=None):
# Called from the from_dict method.
import GPy
if (input_dict['X'] is None) or (input_dict['Y'] is None):
if data is None:
raise ValueError("The model was serialized whithout the training data. 'data' must be not None!")
input_dict["X"], input_dict["Y"] = np.array(data[0]), np.array(data[1])
elif data is not None:
print("WARNING: The model has been saved with X,Y! The original values are being overriden!")
input_dict["X"], input_dict["Y"] = np.array(data[0]), np.array(data[1])
else:
input_dict["X"], input_dict["Y"] = np.array(input_dict['X']), np.array(input_dict['Y'])
input_dict["Z"] = np.array(input_dict['Z'])
input_dict["kernel"] = GPy.kern.Kern.from_dict(input_dict["kernel"])
input_dict["likelihood"] = GPy.likelihoods.likelihood.Likelihood.from_dict(input_dict["likelihood"])
mean_function = input_dict.get("mean_function")
if mean_function is not None:
input_dict["mean_function"] = GPy.core.mapping.Mapping.from_dict(mean_function)
else:
input_dict["mean_function"] = mean_function
input_dict["inference_method"] = GPy.inference.latent_function_inference.LatentFunctionInference.from_dict(input_dict["inference_method"])
#FIXME: Assumes the Y_metadata is serializable. We should create a Metadata class
Y_metadata = input_dict.get("Y_metadata")
input_dict["Y_metadata"] = Y_metadata
normalizer = input_dict.get("normalizer")
if normalizer is not None:
input_dict["normalizer"] = GPy.util.normalizer._Norm.from_dict(normalizer)
else:
input_dict["normalizer"] = normalizer
input_dict = SparseGP._format_input_dict(input_dict, data)
return SparseGP(**input_dict)