Rename getstate → _getstate and setstate → _setstate

This commit is contained in:
Max Zwiessele 2014-01-24 15:48:23 +00:00
parent e128059377
commit cec766b723
17 changed files with 73 additions and 71 deletions

View file

@ -309,11 +309,11 @@ class GP(Model):
def getstate(self): def _getstate(self):
""" """
Get the current state of the class, here we return everything that is needed to recompute the model. Get the current state of the class, here we return everything that is needed to recompute the model.
""" """
return Model.getstate(self) + [self.X, return Model._getstate(self) + [self.X,
self.num_data, self.num_data,
self.input_dim, self.input_dim,
self.kern, self.kern,
@ -323,7 +323,7 @@ class GP(Model):
self._Xscale, self._Xscale,
] ]
def setstate(self, state): def _setstate(self, state):
self._Xscale = state.pop() self._Xscale = state.pop()
self._Xoffset = state.pop() self._Xoffset = state.pop()
self.output_dim = state.pop() self.output_dim = state.pop()
@ -332,6 +332,6 @@ class GP(Model):
self.input_dim = state.pop() self.input_dim = state.pop()
self.num_data = state.pop() self.num_data = state.pop()
self.X = state.pop() self.X = state.pop()
Model.setstate(self, state) Model._setstate(self, state)

View file

@ -40,7 +40,7 @@ class Model(Parameterized):
return g return g
raise NotImplementedError, "this needs to be implemented to use the model class" raise NotImplementedError, "this needs to be implemented to use the model class"
def getstate(self): def _getstate(self):
""" """
Get the current state of the class. Get the current state of the class.
Inherited from Parameterized, so add those parameters to the state Inherited from Parameterized, so add those parameters to the state
@ -48,24 +48,24 @@ class Model(Parameterized):
:return: list of states from the model. :return: list of states from the model.
""" """
return Parameterized.getstate(self) + \ return Parameterized._getstate(self) + \
[self.priors, self.optimization_runs, [self.priors, self.optimization_runs,
self.sampling_runs, self.preferred_optimizer] self.sampling_runs, self.preferred_optimizer]
def setstate(self, state): def _setstate(self, state):
""" """
set state from previous call to getstate set state from previous call to _getstate
call Parameterized with the rest of the state call Parameterized with the rest of the state
:param state: the state of the model. :param state: the state of the model.
:type state: list as returned from getstate. :type state: list as returned from _getstate.
""" """
self.preferred_optimizer = state.pop() self.preferred_optimizer = state.pop()
self.sampling_runs = state.pop() self.sampling_runs = state.pop()
self.optimization_runs = state.pop() self.optimization_runs = state.pop()
self.priors = state.pop() self.priors = state.pop()
Parameterized.setstate(self, state) Parameterized._setstate(self, state)
def set_prior(self, regexp, what): def set_prior(self, regexp, what):
""" """

View file

@ -403,7 +403,7 @@ class Param(ObservableArray, Constrainable):
x=self.name_hirarchical) x=self.name_hirarchical)
return name + super(Param, self).__repr__(*args,**kwargs) return name + super(Param, self).__repr__(*args,**kwargs)
def _ties_for(self, rav_index): def _ties_for(self, rav_index):
size = sum(p.size for p in self._tied_to_) #size = sum(p.size for p in self._tied_to_)
ties = numpy.empty(shape=(len(self._tied_to_), numpy.size(rav_index)), dtype=Param) ties = numpy.empty(shape=(len(self._tied_to_), numpy.size(rav_index)), dtype=Param)
for i, tied_to in enumerate(self._tied_to_): for i, tied_to in enumerate(self._tied_to_):
for t, ind in tied_to._tied_to_me_.iteritems(): for t, ind in tied_to._tied_to_me_.iteritems():

View file

@ -21,7 +21,7 @@ class Observable(object):
[callble(self) for callble in self._observers_.itervalues()] [callble(self) for callble in self._observers_.itervalues()]
class Pickleable(object): class Pickleable(object):
def getstate(self): def _getstate(self):
""" """
Returns the state of this class in a memento pattern. Returns the state of this class in a memento pattern.
The state must be a list-like structure of all the fields The state must be a list-like structure of all the fields
@ -30,13 +30,13 @@ class Pickleable(object):
See python doc "pickling" (`__getstate__` and `__setstate__`) for details. See python doc "pickling" (`__getstate__` and `__setstate__`) for details.
""" """
raise NotImplementedError, "To be able to use pickling you need to implement this method" raise NotImplementedError, "To be able to use pickling you need to implement this method"
def setstate(self, state): def _setstate(self, state):
""" """
Set the state (memento pattern) of this class to the given state. Set the state (memento pattern) of this class to the given state.
Usually this is just the counterpart to getstate, such that Usually this is just the counterpart to _getstate, such that
an object is a copy of another when calling an object is a copy of another when calling
copy = <classname>.__new__(*args,**kw).setstate(<to_be_copied>.getstate()) copy = <classname>.__new__(*args,**kw)._setstate(<to_be_copied>._getstate())
See python doc "pickling" (`__getstate__` and `__setstate__`) for details. See python doc "pickling" (`__getstate__` and `__setstate__`) for details.
""" """
@ -54,7 +54,7 @@ class Parentable(object):
self._highest_parent_ = highest_parent self._highest_parent_ = highest_parent
def has_parent(self): def has_parent(self):
return self._direct_parent_ is not None return self._direct_parent_ is not None and self._highest_parent_ is not None
class Nameable(Parentable): class Nameable(Parentable):
_name = None _name = None
@ -96,6 +96,7 @@ class Constrainable(Nameable):
self._add_constrain(p, transform, warning) self._add_constrain(p, transform, warning)
if update: if update:
self.parameters_changed() self.parameters_changed()
def constrain_positive(self, warning=True): def constrain_positive(self, warning=True):
""" """
:param warning: print a warning if re-constraining parameters. :param warning: print a warning if re-constraining parameters.

View file

@ -218,7 +218,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
p._direct_parent_ = self p._direct_parent_ = self
p._parent_index_ = i p._parent_index_ = i
i += 1 i += 1
for pi in p.flattened_parameters: for pi in p._parameters_:
pi._highest_parent_ = self pi._highest_parent_ = self
not_unique = [] not_unique = []
sizes.append(p.size+sizes[-1]) sizes.append(p.size+sizes[-1])
@ -233,6 +233,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
elif not (pname in not_unique): elif not (pname in not_unique):
self.__dict__[pname] = p self.__dict__[pname] = p
self._added_names_.add(pname) self._added_names_.add(pname)
#=========================================================================== #===========================================================================
# Pickling operations # Pickling operations
#=========================================================================== #===========================================================================
@ -253,24 +254,24 @@ class Parameterized(Constrainable, Pickleable, Observable):
return copy.deepcopy(self) return copy.deepcopy(self)
def __getstate__(self): def __getstate__(self):
if self._has_get_set_state(): if self._has_get_set_state():
return self.getstate() return self._getstate()
return self.__dict__ return self.__dict__
def __setstate__(self, state): def __setstate__(self, state):
if self._has_get_set_state(): if self._has_get_set_state():
self.setstate(state) # set state self._setstate(state) # set state
#self._set_params(self._get_params()) # restore all values #self._set_params(self._get_params()) # restore all values
return return
self.__dict__ = state self.__dict__ = state
def _has_get_set_state(self): def _has_get_set_state(self):
return 'getstate' in vars(self.__class__) and 'setstate' in vars(self.__class__) return '_getstate' in vars(self.__class__) and '_setstate' in vars(self.__class__)
def getstate(self): def _getstate(self):
""" """
Get the current state of the class, Get the current state of the class,
here just all the indices, rest can get recomputed here just all the indices, rest can get recomputed
For inheriting from Parameterized: For inheriting from Parameterized:
Allways append the state of the inherited object Allways append the state of the inherited object
and call down to the inherited object in setstate!! and call down to the inherited object in _setstate!!
""" """
return [ return [
self._fixes_, self._fixes_,
@ -280,7 +281,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
self._added_names_, self._added_names_,
] ]
def setstate(self, state): def _setstate(self, state):
self._added_names_ = state.pop() self._added_names_ = state.pop()
self._name = state.pop() self._name = state.pop()
self._parameters_ = state.pop() self._parameters_ = state.pop()

View file

@ -183,20 +183,20 @@ class SparseGP(GP):
else: else:
raise NotImplementedError, "Cannot define a frame with more than two input dimensions" raise NotImplementedError, "Cannot define a frame with more than two input dimensions"
def getstate(self): def _getstate(self):
""" """
Get the current state of the class, Get the current state of the class,
here just all the indices, rest can get recomputed here just all the indices, rest can get recomputed
""" """
return GP.getstate(self) + [self.Z, return GP._getstate(self) + [self.Z,
self.num_inducing, self.num_inducing,
self.has_uncertain_inputs, self.has_uncertain_inputs,
self.X_variance] self.X_variance]
def setstate(self, state): def _setstate(self, state):
self.X_variance = state.pop() self.X_variance = state.pop()
self.has_uncertain_inputs = state.pop() self.has_uncertain_inputs = state.pop()
self.num_inducing = state.pop() self.num_inducing = state.pop()
self.Z = state.pop() self.Z = state.pop()
GP.setstate(self, state) GP._setstate(self, state)

View file

@ -90,9 +90,9 @@ class SVIGP(GP):
self._param_steplength_trace = [] self._param_steplength_trace = []
self._vb_steplength_trace = [] self._vb_steplength_trace = []
def getstate(self): def _getstate(self):
steplength_params = [self.hbar_t, self.tau_t, self.gbar_t, self.gbar_t1, self.gbar_t2, self.hbar_tp, self.tau_tp, self.gbar_tp, self.adapt_param_steplength, self.adapt_vb_steplength, self.vb_steplength, self.param_steplength] steplength_params = [self.hbar_t, self.tau_t, self.gbar_t, self.gbar_t1, self.gbar_t2, self.hbar_tp, self.tau_tp, self.gbar_tp, self.adapt_param_steplength, self.adapt_vb_steplength, self.vb_steplength, self.param_steplength]
return GP.getstate(self) + \ return GP._getstate(self) + \
[self.get_vb_param(), [self.get_vb_param(),
self.Z, self.Z,
self.num_inducing, self.num_inducing,
@ -116,7 +116,7 @@ class SVIGP(GP):
self.iterations self.iterations
] ]
def setstate(self, state): def _setstate(self, state):
self.iterations = state.pop() self.iterations = state.pop()
self._permutation = state.pop() self._permutation = state.pop()
self.Y = state.pop() self.Y = state.pop()
@ -139,7 +139,7 @@ class SVIGP(GP):
self.num_inducing = state.pop() self.num_inducing = state.pop()
self.Z = state.pop() self.Z = state.pop()
vb_param = state.pop() vb_param = state.pop()
GP.setstate(self, state) GP._setstate(self, state)
self.set_vb_param(vb_param) self.set_vb_param(vb_param)
def _compute_kernel_matrices(self): def _compute_kernel_matrices(self):

View file

@ -50,25 +50,25 @@ class kern(Parameterized):
def connect_input(self, Xparam): def connect_input(self, Xparam):
[p.connect_input(Xparam) for p in self._parameters_] [p.connect_input(Xparam) for p in self._parameters_]
def getstate(self): def _getstate(self):
""" """
Get the current state of the class, Get the current state of the class,
here just all the indices, rest can get recomputed here just all the indices, rest can get recomputed
""" """
return Parameterized.getstate(self) + [#self._parameters_, return Parameterized._getstate(self) + [#self._parameters_,
#self.num_params, #self.num_params,
self.input_dim, self.input_dim,
self.input_slices, self.input_slices,
self._param_slices_ self._param_slices_
] ]
def setstate(self, state): def _setstate(self, state):
self._param_slices_ = state.pop() self._param_slices_ = state.pop()
self.input_slices = state.pop() self.input_slices = state.pop()
self.input_dim = state.pop() self.input_dim = state.pop()
#self.num_params = state.pop() #self.num_params = state.pop()
#self._parameters_ = state.pop() #self._parameters_ = state.pop()
Parameterized.setstate(self, state) Parameterized._setstate(self, state)
def plot_ARD(self, fignum=None, ax=None, title='', legend=False): def plot_ARD(self, fignum=None, ax=None, title='', legend=False):

View file

@ -51,17 +51,17 @@ class BayesianGPLVM(SparseGP, GPLVM):
self.add_parameter(self.q, gradient=self._dbound_dmuS, index=0) self.add_parameter(self.q, gradient=self._dbound_dmuS, index=0)
self.ensure_default_constraints() self.ensure_default_constraints()
def getstate(self): def _getstate(self):
""" """
Get the current state of the class, Get the current state of the class,
here just all the indices, rest can get recomputed here just all the indices, rest can get recomputed
""" """
return SparseGP.getstate(self) + [self.init] return SparseGP._getstate(self) + [self.init]
def setstate(self, state): def _setstate(self, state):
self._const_jitter = None self._const_jitter = None
self.init = state.pop() self.init = state.pop()
SparseGP.setstate(self, state) SparseGP._setstate(self, state)
# def _get_param_names(self): # def _get_param_names(self):
# X_names = sum([['X_%i_%i' % (n, q) for q in range(self.input_dim)] for n in range(self.num_data)], []) # X_names = sum([['X_%i_%i' % (n, q) for q in range(self.input_dim)] for n in range(self.num_data)], [])

View file

@ -30,8 +30,8 @@ class GPRegression(GP):
super(GPRegression, self).__init__(X, Y, kernel, likelihood, name='gp_regression') super(GPRegression, self).__init__(X, Y, kernel, likelihood, name='gp_regression')
def getstate(self): def _getstate(self):
return GP.getstate(self) return GP._getstate(self)
def setstate(self, state): def _setstate(self, state):
return GP.setstate(self, state) return GP._setstate(self, state)

View file

@ -42,11 +42,11 @@ class GPLVM(GP):
Xr[:PC.shape[0], :PC.shape[1]] = PC Xr[:PC.shape[0], :PC.shape[1]] = PC
return Xr return Xr
def getstate(self): def _getstate(self):
return GP.getstate(self) return GP._getstate(self)
def setstate(self, state): def _setstate(self, state):
GP.setstate(self, state) GP._setstate(self, state)
# def _get_param_names(self): # def _get_param_names(self):
# return sum([['X_%i_%i' % (n, q) for q in range(self.input_dim)] for n in range(self.num_data)], []) + GP._get_param_names(self) # return sum([['X_%i_%i' % (n, q) for q in range(self.input_dim)] for n in range(self.num_data)], []) + GP._get_param_names(self)

View file

@ -79,8 +79,8 @@ class MRD(Model):
Model.__init__(self) Model.__init__(self)
self.ensure_default_constraints() self.ensure_default_constraints()
def getstate(self): def _getstate(self):
return Model.getstate(self) + [self.names, return Model._getstate(self) + [self.names,
self.bgplvms, self.bgplvms,
self.gref, self.gref,
self.nparams, self.nparams,
@ -90,7 +90,7 @@ class MRD(Model):
self.NQ, self.NQ,
self.MQ] self.MQ]
def setstate(self, state): def _setstate(self, state):
self.MQ = state.pop() self.MQ = state.pop()
self.NQ = state.pop() self.NQ = state.pop()
self.num_data = state.pop() self.num_data = state.pop()
@ -100,7 +100,7 @@ class MRD(Model):
self.gref = state.pop() self.gref = state.pop()
self.bgplvms = state.pop() self.bgplvms = state.pop()
self.names = state.pop() self.names = state.pop()
Model.setstate(self, state) Model._setstate(self, state)
@property @property
def X(self): def X(self):

View file

@ -46,11 +46,11 @@ class SparseGPClassification(SparseGP):
SparseGP.__init__(self, X, likelihood, kernel, Z=Z, normalize_X=normalize_X) SparseGP.__init__(self, X, likelihood, kernel, Z=Z, normalize_X=normalize_X)
self.ensure_default_constraints() self.ensure_default_constraints()
def getstate(self): def _getstate(self):
return SparseGP.getstate(self) return SparseGP._getstate(self)
def setstate(self, state): def _setstate(self, state):
return SparseGP.setstate(self, state) return SparseGP._setstate(self, state)
pass pass

View file

@ -49,11 +49,11 @@ class SparseGPRegression(SparseGP):
self.ensure_default_constraints() self.ensure_default_constraints()
pass pass
def getstate(self): def _getstate(self):
return SparseGP.getstate(self) return SparseGP._getstate(self)
def setstate(self, state): def _setstate(self, state):
return SparseGP.setstate(self, state) return SparseGP._setstate(self, state)
pass pass

View file

@ -28,12 +28,12 @@ class SparseGPLVM(SparseGPRegression, GPLVM):
SparseGPRegression.__init__(self, X, Y, kernel=kernel, num_inducing=num_inducing) SparseGPRegression.__init__(self, X, Y, kernel=kernel, num_inducing=num_inducing)
self.ensure_default_constraints() self.ensure_default_constraints()
def getstate(self): def _getstate(self):
return SparseGPRegression.getstate(self) return SparseGPRegression._getstate(self)
def setstate(self, state): def _setstate(self, state):
return SparseGPRegression.setstate(self, state) return SparseGPRegression._setstate(self, state)
def _get_param_names(self): def _get_param_names(self):

View file

@ -43,10 +43,10 @@ class SVIGPRegression(SVIGP):
SVIGP.__init__(self, X, likelihood, kernel, Z, q_u=q_u, batchsize=batchsize) SVIGP.__init__(self, X, likelihood, kernel, Z, q_u=q_u, batchsize=batchsize)
self.load_batch() self.load_batch()
def getstate(self): def _getstate(self):
return GPBase.getstate(self) return GPBase._getstate(self)
def setstate(self, state): def _setstate(self, state):
return GPBase.setstate(self, state) return GPBase._setstate(self, state)

View file

@ -30,12 +30,12 @@ class WarpedGP(GP):
GP.__init__(self, X, likelihood, kernel, normalize_X=normalize_X) GP.__init__(self, X, likelihood, kernel, normalize_X=normalize_X)
self._set_params(self._get_params()) self._set_params(self._get_params())
def getstate(self): def _getstate(self):
return GP.getstate(self) return GP._getstate(self)
def setstate(self, state): def _setstate(self, state):
return GP.setstate(self, state) return GP._setstate(self, state)
def _scale_data(self, Y): def _scale_data(self, Y):