mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-02 14:45:15 +02:00
Sparse GP serialization
This commit is contained in:
parent
d85c9d5379
commit
7b2af57aee
4 changed files with 146 additions and 52 deletions
|
|
@ -117,3 +117,42 @@ class SparseGP(GP):
|
|||
self.Z.gradient = self.kern.gradients_X(self.grad_dict['dL_dKmm'], self.Z)
|
||||
self.Z.gradient += self.kern.gradients_X(self.grad_dict['dL_dKnm'].T, self.Z, self.X)
|
||||
self._Zgrad = self.Z.gradient.copy()
|
||||
|
||||
def to_dict(self, save_data=True):
|
||||
input_dict = super(SparseGP, self).to_dict(save_data)
|
||||
input_dict["class"] = "GPy.core.SparseGP"
|
||||
input_dict["Z"] = self.Z.tolist()
|
||||
return input_dict
|
||||
|
||||
@staticmethod
|
||||
def _from_dict(input_dict, data=None):
|
||||
import GPy
|
||||
if (input_dict['X'] is None) or (input_dict['Y'] is None):
|
||||
assert(data is not None)
|
||||
input_dict["X"], input_dict["Y"] = np.array(data[0]), np.array(data[1])
|
||||
elif data is not None:
|
||||
print("WARNING: The model has been saved with X,Y! The original values are being overriden!")
|
||||
input_dict["X"], input_dict["Y"] = np.array(data[0]), np.array(data[1])
|
||||
else:
|
||||
input_dict["X"], input_dict["Y"] = np.array(input_dict['X']), np.array(input_dict['Y'])
|
||||
|
||||
input_dict["Z"] = np.array(input_dict['Z'])
|
||||
input_dict["kernel"] = GPy.kern.Kern.from_dict(input_dict["kernel"])
|
||||
input_dict["likelihood"] = GPy.likelihoods.likelihood.Likelihood.from_dict(input_dict["likelihood"])
|
||||
mean_function = input_dict.get("mean_function")
|
||||
if mean_function is not None:
|
||||
input_dict["mean_function"] = GPy.core.mapping.Mapping.from_dict(mean_function)
|
||||
else:
|
||||
input_dict["mean_function"] = mean_function
|
||||
input_dict["inference_method"] = GPy.inference.latent_function_inference.LatentFunctionInference.from_dict(input_dict["inference_method"])
|
||||
|
||||
#FIXME: Assumes the Y_metadata is serializable. We should create a Metadata class
|
||||
Y_metadata = input_dict.get("Y_metadata")
|
||||
input_dict["Y_metadata"] = Y_metadata
|
||||
|
||||
normalizer = input_dict.get("normalizer")
|
||||
if normalizer is not None:
|
||||
input_dict["normalizer"] = GPy.util.normalizer._Norm.from_dict(normalizer)
|
||||
else:
|
||||
input_dict["normalizer"] = normalizer
|
||||
return SparseGP(**input_dict)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue