diff --git a/GPy/core/parameterization/array_core.py b/GPy/core/parameterization/array_core.py index e3a5b137..6920e894 100644 --- a/GPy/core/parameterization/array_core.py +++ b/GPy/core/parameterization/array_core.py @@ -1,7 +1,7 @@ # Copyright (c) 2012, GPy authors (see AUTHORS.txt). # Licensed under the BSD 3-clause license (see LICENSE.txt) -__updated__ = '2013-12-16' +__updated__ = '2014-03-17' import numpy as np from parameter_core import Observable @@ -18,7 +18,7 @@ class ObservableArray(np.ndarray, Observable): if not isinstance(input_array, ObservableArray): obj = np.atleast_1d(np.require(input_array, dtype=np.float64, requirements=['W', 'C'])).view(cls) else: obj = input_array - cls.__name__ = "ObservableArray\n " + cls.__name__ = "ObsAr" # because of fixed printing of `array` in np printing super(ObservableArray, obj).__init__(*a, **kw) return obj @@ -30,6 +30,14 @@ class ObservableArray(np.ndarray, Observable): def __array_wrap__(self, out_arr, context=None): return out_arr.view(np.ndarray) + def __reduce__(self): + func, args, state = np.ndarray.__reduce__(self) + return func, args, (state, Observable._getstate(self)) + + def __setstate__(self, state): + np.ndarray.__setstate__(self, state[0]) + Observable._setstate(self, state[1]) + def _s_not_empty(self, s): # this checks whether there is something picked by this slice. return True diff --git a/GPy/core/parameterization/param.py b/GPy/core/parameterization/param.py index 2ede8436..ed394806 100644 --- a/GPy/core/parameterization/param.py +++ b/GPy/core/parameterization/param.py @@ -269,6 +269,8 @@ class Param(OptimizationHandlable, ObservableArray): @property def _ties_str(self): return [''] + def _ties_for(self, ravi): + return [['N/A']]*ravi.size def __repr__(self, *args, **kwargs): name = "\033[1m{x:s}\033[0;0m:\n".format( x=self.hierarchy_name()) @@ -312,7 +314,7 @@ class Param(OptimizationHandlable, ObservableArray): ravi = self._raveled_index(filter_) if constr_matrix is None: constr_matrix = self.constraints.properties_for(ravi) if prirs is None: prirs = self.priors.properties_for(ravi) - if ties is None: ties = [['N/A']]*self.size + if ties is None: ties = self._ties_for(ravi) ties = [' '.join(map(lambda x: x, t)) for t in ties] if lc is None: lc = self._max_len_names(constr_matrix, __constraints_name__) if lx is None: lx = self._max_len_values() diff --git a/GPy/core/parameterization/parameter_core.py b/GPy/core/parameterization/parameter_core.py index f58143bd..0aab890c 100644 --- a/GPy/core/parameterization/parameter_core.py +++ b/GPy/core/parameterization/parameter_core.py @@ -16,7 +16,7 @@ Observable Pattern for patameterization from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED import numpy as np -__updated__ = '2014-03-14' +__updated__ = '2014-03-17' class HierarchyError(Exception): """ @@ -56,7 +56,7 @@ class InterfacePickleFunctions(object): """ raise NotImplementedError, "To be able to use pickling you need to implement this method" -class Pickleable(object): +class Pickleable(InterfacePickleFunctions): """ Make an object pickleable (See python doc 'pickling'). @@ -95,7 +95,7 @@ class Pickleable(object): def _has_get_set_state(self): return '_getstate' in vars(self.__class__) and '_setstate' in vars(self.__class__) -class Observable(InterfacePickleFunctions): +class Observable(Pickleable): """ Observable pattern for parameterization. @@ -155,6 +155,7 @@ class Observable(InterfacePickleFunctions): def _getstate(self): return [self._observer_callables_] + def _setstate(self, state): self._observer_callables_ = state.pop() diff --git a/GPy/testing/observable_tests.py b/GPy/testing/observable_tests.py index ebda1630..f8be4a48 100644 --- a/GPy/testing/observable_tests.py +++ b/GPy/testing/observable_tests.py @@ -8,7 +8,7 @@ from GPy.core.parameterization.parameterized import Parameterized from GPy.core.parameterization.param import Param import numpy -# One trigger in init +# One trigger in init _trigger_start = -1 class ParamTestParent(Parameterized): @@ -21,11 +21,9 @@ class ParameterizedTest(Parameterized): params_changed_count = _trigger_start def parameters_changed(self): self.params_changed_count += 1 - def _set_params(self, params, trigger_parent=True): - Parameterized._set_params(self, params, trigger_parent=trigger_parent) class Test(unittest.TestCase): - + def setUp(self): self.parent = ParamTestParent('test parent') self.par = ParameterizedTest('test model') @@ -41,12 +39,12 @@ class Test(unittest.TestCase): self.parent.add_parameter(self.par) self.parent.add_parameter(self.par2) - + self._observer_triggered = None self._trigger_count = 0 self._first = None self._second = None - + def _trigger(self, which): self._observer_triggered = float(which) self._trigger_count += 1 @@ -54,18 +52,18 @@ class Test(unittest.TestCase): self._second = self._trigger else: self._first = self._trigger - + def _trigger_priority(self, which): if self._first is not None: self._second = self._trigger_priority else: self._first = self._trigger_priority - + def test_observable(self): self.par.add_observer(self, self._trigger, -1) self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet') self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param') - + self.p[0,1] = 3 # trigger observers self.assertEqual(self._observer_triggered, 3, 'observer should have triggered') self.assertEqual(self._trigger_count, 1, 'observer should have triggered once') @@ -78,14 +76,14 @@ class Test(unittest.TestCase): self.assertEqual(self._trigger_count, 1, 'observer should have triggered once') self.assertEqual(self.par.params_changed_count, 2, 'params changed second') self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param') - + self.par.add_observer(self, self._trigger, -1) self.p[2,1] = 4 self.assertEqual(self._observer_triggered, 4, 'observer should have triggered') self.assertEqual(self._trigger_count, 2, 'observer should have triggered once') self.assertEqual(self.par.params_changed_count, 3, 'params changed second') self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param') - + self.par.remove_observer(self, self._trigger) self.p[0,1] = 3 self.assertEqual(self._observer_triggered, 4, 'observer should not have triggered') @@ -99,7 +97,7 @@ class Test(unittest.TestCase): self.par._trigger_params_changed() self.assertEqual(self.par.params_changed_count, 1, 'now params changed') self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count) - + self.par._param_array_[:] = 2 self.par._trigger_params_changed() self.assertEqual(self.par.params_changed_count, 2, 'now params changed') @@ -125,13 +123,13 @@ class Test(unittest.TestCase): self.par.remove_observer(self) self._first = self._second = None - + self.par.add_observer(self, self._trigger, 1) self.par.add_observer(self, self._trigger_priority, 0) self.par.notify_observers(0) self.assertEqual(self._first, self._trigger, 'priority should be second') self.assertEqual(self._second, self._trigger_priority, 'priority should be second') - + if __name__ == "__main__": #import sys;sys.argv = ['', 'Test.testName'] diff --git a/GPy/testing/parameterized_tests.py b/GPy/testing/parameterized_tests.py index 5b718cbd..754e95db 100644 --- a/GPy/testing/parameterized_tests.py +++ b/GPy/testing/parameterized_tests.py @@ -108,7 +108,7 @@ class ParameterizedTest(unittest.TestCase): self.assertEqual(self.param.constraints._offset, 3) def test_fixing_randomize(self): - self.white.fix(warning=False) + self.white.fix(warning=True) val = float(self.test1.white.variance) self.test1.randomize() self.assertEqual(val, self.white.variance) @@ -119,6 +119,11 @@ class ParameterizedTest(unittest.TestCase): self.testmodel.randomize() self.assertEqual(val, self.testmodel.kern.lengthscale) + def test_printing(self): + print self.test1 + print self.param + print self.test1[''] + if __name__ == "__main__": #import sys;sys.argv = ['', 'Test.test_add_parameter'] unittest.main() \ No newline at end of file