Basic framework for serializing GPy models

This commit is contained in:
Moreno 2017-04-11 11:42:58 +01:00
parent d529da3e6c
commit e572bfb746
26 changed files with 828 additions and 64 deletions

View file

@ -33,6 +33,27 @@ class _Norm(object):
"""
raise NotImplementedError
def to_dict(self):
raise NotImplementedError
def _to_dict(self):
input_dict = {}
return input_dict
@staticmethod
def from_dict(input_dict):
import copy
input_dict = copy.deepcopy(input_dict)
normalizer_class = input_dict.pop('class')
import GPy
normalizer_class = eval(normalizer_class)
return normalizer_class._from_dict(normalizer_class, input_dict)
@staticmethod
def _from_dict(normalizer_class, input_dict):
return normalizer_class(**input_dict)
class Standardize(_Norm):
def __init__(self):
self.mean = None
@ -50,9 +71,26 @@ class Standardize(_Norm):
def scaled(self):
return self.mean is not None
def to_dict(self):
input_dict = super(Standardize, self)._to_dict()
input_dict["class"] = "GPy.util.normalizer.Standardize"
if self.mean is not None:
input_dict["mean"] = self.mean.tolist()
input_dict["std"] = self.std.tolist()
return input_dict
@staticmethod
def _from_dict(kernel_class, input_dict):
s = Standardize()
if "mean" in input_dict:
s.mean = np.array(input_dict["mean"])
if "std" in input_dict:
s.std = np.array(input_dict["std"])
return s
# Inverse variance to be implemented, disabling for now
# If someone in the future want to implement this,
# we need to implement the inverse variance for
# we need to implement the inverse variance for
# normalization. This means, we need to know the factor
# for the variance to be multiplied to the variance in
# normalized space. This is easy to compute for standardization
@ -71,7 +109,7 @@ class Standardize(_Norm):
# def inverse_mean(self, X):
# return (X + .5) * (self.ymax - self.ymin) + self.ymin
# def inverse_variance(self, var):
#
#
# return (var*(self.std**2))
# def scaled(self):
# return (self.ymin is not None) and (self.ymax is not None)
# return (self.ymin is not None) and (self.ymax is not None)