pickling unified with __getstate__ and __setstate__

This commit is contained in:
Max Zwiessele 2013-06-25 18:01:18 +01:00
parent 1e06ca2d40
commit 05e8e75c58
5 changed files with 35 additions and 13 deletions

View file

@ -6,7 +6,7 @@ import numpy as np
import pylab as pb import pylab as pb
from .. import kern from .. import kern
from ..util.linalg import pdinv, mdot, tdot, dpotrs, dtrtrs from ..util.linalg import pdinv, mdot, tdot, dpotrs, dtrtrs
#from ..util.plot import gpplot, Tango # from ..util.plot import gpplot, Tango
from ..likelihoods import EP from ..likelihoods import EP
from gp_base import GPBase from gp_base import GPBase
@ -46,12 +46,12 @@ class GP(GPBase):
# the gradient of the likelihood wrt the covariance matrix # the gradient of the likelihood wrt the covariance matrix
if self.likelihood.YYT is None: if self.likelihood.YYT is None:
#alpha = np.dot(self.Ki, self.likelihood.Y) # alpha = np.dot(self.Ki, self.likelihood.Y)
alpha,_ = dpotrs(self.L, self.likelihood.Y,lower=1) alpha, _ = dpotrs(self.L, self.likelihood.Y, lower=1)
self.dL_dK = 0.5 * (tdot(alpha) - self.output_dim * self.Ki) self.dL_dK = 0.5 * (tdot(alpha) - self.output_dim * self.Ki)
else: else:
#tmp = mdot(self.Ki, self.likelihood.YYT, self.Ki) # tmp = mdot(self.Ki, self.likelihood.YYT, self.Ki)
tmp, _ = dpotrs(self.L, np.asfortranarray(self.likelihood.YYT), lower=1) tmp, _ = dpotrs(self.L, np.asfortranarray(self.likelihood.YYT), lower=1)
tmp, _ = dpotrs(self.L, np.asfortranarray(tmp.T), lower=1) tmp, _ = dpotrs(self.L, np.asfortranarray(tmp.T), lower=1)
self.dL_dK = 0.5 * (tmp - self.output_dim * self.Ki) self.dL_dK = 0.5 * (tmp - self.output_dim * self.Ki)
@ -72,7 +72,7 @@ class GP(GPBase):
""" """
self.likelihood.restart() self.likelihood.restart()
self.likelihood.fit_full(self.kern.K(self.X)) self.likelihood.fit_full(self.kern.K(self.X))
self._set_params(self._get_params()) # update the GP self._set_params(self._get_params()) # update the GP
def _model_fit_term(self): def _model_fit_term(self):
""" """
@ -81,7 +81,7 @@ class GP(GPBase):
if self.likelihood.YYT is None: if self.likelihood.YYT is None:
tmp, _ = dtrtrs(self.L, np.asfortranarray(self.likelihood.Y), lower=1) tmp, _ = dtrtrs(self.L, np.asfortranarray(self.likelihood.Y), lower=1)
return -0.5 * np.sum(np.square(tmp)) return -0.5 * np.sum(np.square(tmp))
#return -0.5 * np.sum(np.square(np.dot(self.Li, self.likelihood.Y))) # return -0.5 * np.sum(np.square(np.dot(self.Li, self.likelihood.Y)))
else: else:
return -0.5 * np.sum(np.multiply(self.Ki, self.likelihood.YYT)) return -0.5 * np.sum(np.multiply(self.Ki, self.likelihood.YYT))
@ -104,13 +104,13 @@ class GP(GPBase):
""" """
return np.hstack((self.kern.dK_dtheta(dL_dK=self.dL_dK, X=self.X), self.likelihood._gradients(partial=np.diag(self.dL_dK)))) return np.hstack((self.kern.dK_dtheta(dL_dK=self.dL_dK, X=self.X), self.likelihood._gradients(partial=np.diag(self.dL_dK))))
def _raw_predict(self, _Xnew, which_parts='all', full_cov=False,stop=False): def _raw_predict(self, _Xnew, which_parts='all', full_cov=False, stop=False):
""" """
Internal helper function for making predictions, does not account Internal helper function for making predictions, does not account
for normalization or likelihood for normalization or likelihood
""" """
Kx = self.kern.K(_Xnew,self.X,which_parts=which_parts).T Kx = self.kern.K(_Xnew, self.X, which_parts=which_parts).T
#KiKx = np.dot(self.Ki, Kx) # KiKx = np.dot(self.Ki, Kx)
KiKx, _ = dpotrs(self.L, np.asfortranarray(Kx), lower=1) KiKx, _ = dpotrs(self.L, np.asfortranarray(Kx), lower=1)
mu = np.dot(KiKx.T, self.likelihood.Y) mu = np.dot(KiKx.T, self.likelihood.Y)
if full_cov: if full_cov:

View file

@ -29,7 +29,7 @@ class GPBase(Model):
self._Xscale = np.ones((1, self.input_dim)) self._Xscale = np.ones((1, self.input_dim))
super(GPBase, self).__init__() super(GPBase, self).__init__()
#Model.__init__(self) # Model.__init__(self)
# All leaf nodes should call self._set_params(self._get_params()) at # All leaf nodes should call self._set_params(self._get_params()) at
# the end # the end
@ -57,7 +57,6 @@ class GPBase(Model):
self.num_data = state.pop() self.num_data = state.pop()
self.X = state.pop() self.X = state.pop()
Model.__setstate__(self, state) Model.__setstate__(self, state)
self._set_params(self._get_params())
def plot_f(self, samples=0, plot_limits=None, which_data='all', which_parts='all', resolution=None, full_cov=False, fignum=None, ax=None): def plot_f(self, samples=0, plot_limits=None, which_data='all', which_parts='all', resolution=None, full_cov=False, fignum=None, ax=None):
""" """

View file

@ -43,6 +43,28 @@ class kern(Parameterised):
Parameterised.__init__(self) Parameterised.__init__(self)
def __getstate__(self):
"""
Get the current state of the class,
here just all the indices, rest can get recomputed
"""
return Parameterised.__getstate__(self) + [self.parts,
self.Nparts,
self.num_params,
self.input_dim,
self.input_slices,
self.param_slices
]
def __setstate__(self, state):
self.param_slices = state.pop()
self.input_slices = state.pop()
self.input_dim = state.pop()
self.num_params = state.pop()
self.Nparts = state.pop()
self.parts = state.pop()
Parameterised.__setstate__(self, state)
def plot_ARD(self, fignum=None, ax=None, title=None): def plot_ARD(self, fignum=None, ax=None, title=None):
"""If an ARD kernel is present, it bar-plots the ARD parameters""" """If an ARD kernel is present, it bar-plots the ARD parameters"""

View file

@ -53,7 +53,7 @@ class BayesianGPLVM(SparseGP, GPLVM):
Get the current state of the class, Get the current state of the class,
here just all the indices, rest can get recomputed here just all the indices, rest can get recomputed
""" """
return [self.init] + SparseGP.__getstate__(self) return SparseGP.__getstate__(self) + [self.init]
def __setstate__(self, state): def __setstate__(self, state):
self.init = state.pop() self.init = state.pop()

View file

@ -85,7 +85,7 @@ class MRD(Model):
self.ensure_default_constraints() self.ensure_default_constraints()
def __getstate__(self): def __getstate__(self):
return [self.names, return Model.__getstate__(self) + [self.names,
self.bgplvms, self.bgplvms,
self.gref, self.gref,
self.nparams, self.nparams,
@ -105,6 +105,7 @@ class MRD(Model):
self.gref = state.pop() self.gref = state.pop()
self.bgplvms = state.pop() self.bgplvms = state.pop()
self.names = state.pop() self.names = state.pop()
Model.__setstate__(self, state)
@property @property
def X(self): def X(self):