diff --git a/GPy/core/parameterization/array_core.py b/GPy/core/parameterization/array_core.py index a338ceed..208cd4fb 100644 --- a/GPy/core/parameterization/array_core.py +++ b/GPy/core/parameterization/array_core.py @@ -62,7 +62,7 @@ class ObservableArray(np.ndarray, Observable): def __setitem__(self, s, val): if self._s_not_empty(s): super(ObservableArray, self).__setitem__(s, val) - self._notify_observers() + self._notify_observers(self[s]) def __getslice__(self, start, stop): return self.__getitem__(slice(start, stop)) diff --git a/GPy/core/parameterization/lists_and_dicts.py b/GPy/core/parameterization/lists_and_dicts.py new file mode 100644 index 00000000..cdf9f5f6 --- /dev/null +++ b/GPy/core/parameterization/lists_and_dicts.py @@ -0,0 +1,18 @@ +''' +Created on 27 Feb 2014 + +@author: maxz +''' + +class ParamList(list): + """ + List to store ndarray-likes in. + It will look for 'is' instead of calling __eq__ on each element. + """ + def __contains__(self, other): + for el in self: + if el is other: + return True + return False + + pass diff --git a/GPy/core/parameterization/param.py b/GPy/core/parameterization/param.py index 89d3a4e4..ca9905f7 100644 --- a/GPy/core/parameterization/param.py +++ b/GPy/core/parameterization/param.py @@ -172,6 +172,7 @@ class Param(Constrainable, ObservableArray, Gradcheckable): try: new_arr._current_slice_ = s; new_arr._original_ = self.base is new_arr.base except AttributeError: pass # returning 0d array or float, double etc return new_arr + def __setitem__(self, s, val): super(Param, self).__setitem__(s, val) if self.has_parent(): diff --git a/GPy/core/parameterization/parameter_core.py b/GPy/core/parameterization/parameter_core.py index 6afa94cb..58dd63d8 100644 --- a/GPy/core/parameterization/parameter_core.py +++ b/GPy/core/parameterization/parameter_core.py @@ -2,6 +2,7 @@ # Licensed under the BSD 3-clause license (see LICENSE.txt) from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED +import heapq __updated__ = '2013-12-16' @@ -11,25 +12,29 @@ def adjust_name_for_printing(name): return '' class Observable(object): + _updated = True def __init__(self, *args, **kwargs): - from collections import defaultdict - self._observer_callables_ = defaultdict(list) - - def add_observer(self, observer, callble): - self._observer_callables_[observer].append(callble) + self._observer_callables_ = [] + def add_observer(self, observer, callble, priority=0): + heapq.heappush(self._observer_callables_, (priority, observer, callble)) + def remove_observer(self, observer, callble=None): - if observer in self._observer_callables_: - if callble is None: - del self._observer_callables_[observer] - elif callble in self._observer_callables_[observer]: - self._observer_callables_[observer].remove(callble) - if len(self._observer_callables_[observer]) == 0: - self.remove_observer(observer) - - def _notify_observers(self): - [[callble(self) for callble in callables] - for callables in self._observer_callables_.itervalues()] + to_remove = [] + for p, obs, clble in self._observer_callables_: + if callble is not None: + if (obs == observer) and (callble == clble): + to_remove.append((p, obs, clble)) + else: + if obs is observer: + to_remove.append((p, obs, clble)) + for r in to_remove: + self._observer_callables_.remove(r) + + def _notify_observers(self, which=None): + if which is None: + which = self + [callble(which) for _, _, callble in heapq.nlargest(len(self._observer_callables_), self._observer_callables_)] class Pickleable(object): def _getstate(self): @@ -333,7 +338,7 @@ class Constrainable(Nameable, Indexable): class Parameterizable(Constrainable, Observable): def __init__(self, *args, **kwargs): super(Parameterizable, self).__init__(*args, **kwargs) - from GPy.core.parameterization.array_core import ParamList + from GPy.core.parameterization.lists_and_dicts import ParamList _parameters_ = ParamList() self._added_names_ = set() @@ -398,7 +403,7 @@ class Parameterizable(Constrainable, Observable): """Returns a (deep) copy of the current model""" import copy from .index_operations import ParameterIndexOperations, ParameterIndexOperationsView - from .array_core import ParamList + from .lists_and_dicts import ParamList dc = dict() for k, v in self.__dict__.iteritems(): @@ -427,7 +432,6 @@ class Parameterizable(Constrainable, Observable): def _notify_parameters_changed(self): self.parameters_changed() - self._notify_observers() if self.has_parent(): self._direct_parent_._notify_parameters_changed() diff --git a/GPy/core/parameterization/parameterized.py b/GPy/core/parameterization/parameterized.py index f5fcc6ad..fe8c76e4 100644 --- a/GPy/core/parameterization/parameterized.py +++ b/GPy/core/parameterization/parameterized.py @@ -116,6 +116,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable): self.constraints.update(param.constraints, start) self.priors.update(param.priors, start) self._parameters_.insert(index, param) + param.add_observer(self, self._pass_through_notify, -1) self.size += param.size else: raise RuntimeError, """Parameter exists already added and no copy made""" @@ -169,6 +170,12 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable): self._param_slices_.append(slice(sizes[-2], sizes[-1])) self._add_parameter_name(p) + #=========================================================================== + # notification system + #=========================================================================== + def _pass_through_notify(self, which): + self._notify_observers(which) + #=========================================================================== # Pickling operations #=========================================================================== diff --git a/GPy/testing/observable_tests.py b/GPy/testing/observable_tests.py new file mode 100644 index 00000000..214a678f --- /dev/null +++ b/GPy/testing/observable_tests.py @@ -0,0 +1,108 @@ +''' +Created on 27 Feb 2014 + +@author: maxz +''' +import unittest +from GPy.core.parameterization.parameterized import Parameterized +from GPy.core.parameterization.param import Param +import numpy + + +class ParamTestParent(Parameterized): + parent_changed_count = 0 + def parameters_changed(self): + self.parent_changed_count += 1 + +class ParameterizedTest(Parameterized): + params_changed_count = 0 + def parameters_changed(self): + self.params_changed_count += 1 + +class Test(unittest.TestCase): + + def setUp(self): + self.parent = ParamTestParent('test parent') + self.par = ParameterizedTest('test model') + self.p = Param('test parameter', numpy.random.normal(1,2,(10,3))) + + self.par.add_parameter(self.p) + self.parent.add_parameter(self.par) + + self._observer_triggered = None + self._trigger_count = 0 + self._first = None + self._second = None + + def _trigger(self, which): + self._observer_triggered = float(which) + self._trigger_count += 1 + if self._first is not None: + self._second = self._trigger + else: + self._first = self._trigger + + def _trigger_priority(self, which): + if self._first is not None: + self._second = self._trigger_priority + else: + self._first = self._trigger_priority + + def test_observable(self): + self.par.add_observer(self, self._trigger, -1) + self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet') + self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param') + + self.p[0,1] = 3 # trigger observers + self.assertEqual(self._observer_triggered, 3, 'observer should have triggered') + self.assertEqual(self._trigger_count, 1, 'observer should have triggered once') + self.assertEqual(self.par.params_changed_count, 1, 'params changed once') + self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param') + + self.par.remove_observer(self) + self.p[2,1] = 4 + self.assertEqual(self._observer_triggered, 3, 'observer should not have triggered') + self.assertEqual(self._trigger_count, 1, 'observer should have triggered once') + self.assertEqual(self.par.params_changed_count, 2, 'params changed second') + self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param') + + self.par.add_observer(self, self._trigger, -1) + self.p[2,1] = 4 + self.assertEqual(self._observer_triggered, 4, 'observer should have triggered') + self.assertEqual(self._trigger_count, 2, 'observer should have triggered once') + self.assertEqual(self.par.params_changed_count, 3, 'params changed second') + self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param') + + self.par.remove_observer(self, self._trigger) + self.p[0,1] = 3 + self.assertEqual(self._observer_triggered, 4, 'observer should not have triggered') + self.assertEqual(self._trigger_count, 2, 'observer should have triggered once') + self.assertEqual(self.par.params_changed_count, 4, 'params changed second') + self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param') + + def test_set_params(self): + self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet') + self.par._set_params(numpy.ones(self.par.size)) + self.assertEqual(self.par.params_changed_count, 1, 'now params changed') + self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param') + + def test_priority(self): + self.par.add_observer(self, self._trigger, -1) + self.par.add_observer(self, self._trigger_priority, 0) + self.par._notify_observers(0) + self.assertEqual(self._first, self._trigger_priority, 'priority should be first') + self.assertEqual(self._second, self._trigger, 'priority should be first') + + self.par.remove_observer(self) + self._first = self._second = None + + self.par.add_observer(self, self._trigger, 1) + self.par.add_observer(self, self._trigger_priority, 0) + self.par._notify_observers(0) + self.assertEqual(self._first, self._trigger, 'priority should be second') + self.assertEqual(self._second, self._trigger_priority, 'priority should be second') + + +if __name__ == "__main__": + #import sys;sys.argv = ['', 'Test.testName'] + unittest.main() \ No newline at end of file