Basic framework for serializing GPy models

This commit is contained in:
Moreno 2017-04-11 11:42:58 +01:00
parent d529da3e6c
commit e572bfb746
26 changed files with 828 additions and 64 deletions

View file

@ -43,6 +43,25 @@ class GPTransformation(object):
"""
raise NotImplementedError
def to_dict(self):
raise NotImplementedError
def _to_dict(self):
return {}
@staticmethod
def from_dict(input_dict):
import copy
input_dict = copy.deepcopy(input_dict)
link_class = input_dict.pop('class')
import GPy
link_class = eval(link_class)
return link_class._from_dict(link_class, input_dict)
@staticmethod
def _from_dict(link_class, input_dict):
return link_class(**input_dict)
class Identity(GPTransformation):
"""
.. math::
@ -62,6 +81,10 @@ class Identity(GPTransformation):
def d3transf_df3(self,f):
return np.zeros_like(f)
def to_dict(self):
input_dict = super(Identity, self)._to_dict()
input_dict["class"] = "GPy.likelihoods.link_functions.Identity"
return input_dict
class Probit(GPTransformation):
"""
@ -82,6 +105,11 @@ class Probit(GPTransformation):
def d3transf_df3(self,f):
return (safe_square(f)-1.)*std_norm_pdf(f)
def to_dict(self):
input_dict = super(Probit, self)._to_dict()
input_dict["class"] = "GPy.likelihoods.link_functions.Probit"
return input_dict
class Cloglog(GPTransformation):
"""