pickling working for array-likes, but observers not yet connected back

This commit is contained in:
Max Zwiessele 2014-03-17 16:55:21 +00:00
parent 19dc7cecf4
commit 2ce3a93b3f
5 changed files with 35 additions and 21 deletions

View file

@ -1,7 +1,7 @@
# Copyright (c) 2012, GPy authors (see AUTHORS.txt). # Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt) # Licensed under the BSD 3-clause license (see LICENSE.txt)
__updated__ = '2013-12-16' __updated__ = '2014-03-17'
import numpy as np import numpy as np
from parameter_core import Observable from parameter_core import Observable
@ -18,7 +18,7 @@ class ObservableArray(np.ndarray, Observable):
if not isinstance(input_array, ObservableArray): if not isinstance(input_array, ObservableArray):
obj = np.atleast_1d(np.require(input_array, dtype=np.float64, requirements=['W', 'C'])).view(cls) obj = np.atleast_1d(np.require(input_array, dtype=np.float64, requirements=['W', 'C'])).view(cls)
else: obj = input_array else: obj = input_array
cls.__name__ = "ObservableArray\n " cls.__name__ = "ObsAr" # because of fixed printing of `array` in np printing
super(ObservableArray, obj).__init__(*a, **kw) super(ObservableArray, obj).__init__(*a, **kw)
return obj return obj
@ -30,6 +30,14 @@ class ObservableArray(np.ndarray, Observable):
def __array_wrap__(self, out_arr, context=None): def __array_wrap__(self, out_arr, context=None):
return out_arr.view(np.ndarray) return out_arr.view(np.ndarray)
def __reduce__(self):
func, args, state = np.ndarray.__reduce__(self)
return func, args, (state, Observable._getstate(self))
def __setstate__(self, state):
np.ndarray.__setstate__(self, state[0])
Observable._setstate(self, state[1])
def _s_not_empty(self, s): def _s_not_empty(self, s):
# this checks whether there is something picked by this slice. # this checks whether there is something picked by this slice.
return True return True

View file

@ -269,6 +269,8 @@ class Param(OptimizationHandlable, ObservableArray):
@property @property
def _ties_str(self): def _ties_str(self):
return [''] return ['']
def _ties_for(self, ravi):
return [['N/A']]*ravi.size
def __repr__(self, *args, **kwargs): def __repr__(self, *args, **kwargs):
name = "\033[1m{x:s}\033[0;0m:\n".format( name = "\033[1m{x:s}\033[0;0m:\n".format(
x=self.hierarchy_name()) x=self.hierarchy_name())
@ -312,7 +314,7 @@ class Param(OptimizationHandlable, ObservableArray):
ravi = self._raveled_index(filter_) ravi = self._raveled_index(filter_)
if constr_matrix is None: constr_matrix = self.constraints.properties_for(ravi) if constr_matrix is None: constr_matrix = self.constraints.properties_for(ravi)
if prirs is None: prirs = self.priors.properties_for(ravi) if prirs is None: prirs = self.priors.properties_for(ravi)
if ties is None: ties = [['N/A']]*self.size if ties is None: ties = self._ties_for(ravi)
ties = [' '.join(map(lambda x: x, t)) for t in ties] ties = [' '.join(map(lambda x: x, t)) for t in ties]
if lc is None: lc = self._max_len_names(constr_matrix, __constraints_name__) if lc is None: lc = self._max_len_names(constr_matrix, __constraints_name__)
if lx is None: lx = self._max_len_values() if lx is None: lx = self._max_len_values()

View file

@ -16,7 +16,7 @@ Observable Pattern for patameterization
from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
import numpy as np import numpy as np
__updated__ = '2014-03-14' __updated__ = '2014-03-17'
class HierarchyError(Exception): class HierarchyError(Exception):
""" """
@ -56,7 +56,7 @@ class InterfacePickleFunctions(object):
""" """
raise NotImplementedError, "To be able to use pickling you need to implement this method" raise NotImplementedError, "To be able to use pickling you need to implement this method"
class Pickleable(object): class Pickleable(InterfacePickleFunctions):
""" """
Make an object pickleable (See python doc 'pickling'). Make an object pickleable (See python doc 'pickling').
@ -95,7 +95,7 @@ class Pickleable(object):
def _has_get_set_state(self): def _has_get_set_state(self):
return '_getstate' in vars(self.__class__) and '_setstate' in vars(self.__class__) return '_getstate' in vars(self.__class__) and '_setstate' in vars(self.__class__)
class Observable(InterfacePickleFunctions): class Observable(Pickleable):
""" """
Observable pattern for parameterization. Observable pattern for parameterization.
@ -155,6 +155,7 @@ class Observable(InterfacePickleFunctions):
def _getstate(self): def _getstate(self):
return [self._observer_callables_] return [self._observer_callables_]
def _setstate(self, state): def _setstate(self, state):
self._observer_callables_ = state.pop() self._observer_callables_ = state.pop()

View file

@ -21,8 +21,6 @@ class ParameterizedTest(Parameterized):
params_changed_count = _trigger_start params_changed_count = _trigger_start
def parameters_changed(self): def parameters_changed(self):
self.params_changed_count += 1 self.params_changed_count += 1
def _set_params(self, params, trigger_parent=True):
Parameterized._set_params(self, params, trigger_parent=trigger_parent)
class Test(unittest.TestCase): class Test(unittest.TestCase):

View file

@ -108,7 +108,7 @@ class ParameterizedTest(unittest.TestCase):
self.assertEqual(self.param.constraints._offset, 3) self.assertEqual(self.param.constraints._offset, 3)
def test_fixing_randomize(self): def test_fixing_randomize(self):
self.white.fix(warning=False) self.white.fix(warning=True)
val = float(self.test1.white.variance) val = float(self.test1.white.variance)
self.test1.randomize() self.test1.randomize()
self.assertEqual(val, self.white.variance) self.assertEqual(val, self.white.variance)
@ -119,6 +119,11 @@ class ParameterizedTest(unittest.TestCase):
self.testmodel.randomize() self.testmodel.randomize()
self.assertEqual(val, self.testmodel.kern.lengthscale) self.assertEqual(val, self.testmodel.kern.lengthscale)
def test_printing(self):
print self.test1
print self.param
print self.test1['']
if __name__ == "__main__": if __name__ == "__main__":
#import sys;sys.argv = ['', 'Test.test_add_parameter'] #import sys;sys.argv = ['', 'Test.test_add_parameter']
unittest.main() unittest.main()