diff --git a/GPy/core/parameterization/observable.py b/GPy/core/parameterization/observable.py index 4782d2ea..8a85c6ca 100644 --- a/GPy/core/parameterization/observable.py +++ b/GPy/core/parameterization/observable.py @@ -14,6 +14,10 @@ class Observable(object): super(Observable, self).__init__() from lists_and_dicts import ObserverList self.observers = ObserverList() + self._update_on = True + + def set_updates(self, on=True): + self._update_on = on def add_observer(self, observer, callble, priority=0): """ @@ -51,15 +55,16 @@ class Observable(object): :param min_priority: only notify observers with priority > min_priority if min_priority is None, notify all observers in order """ - if which is None: - which = self - if min_priority is None: - [callble(self, which=which) for _, _, callble in self.observers] - else: - for p, _, callble in self.observers: - if p <= min_priority: - break - callble(self, which=which) + if self._update_on: + if which is None: + which = self + if min_priority is None: + [callble(self, which=which) for _, _, callble in self.observers] + else: + for p, _, callble in self.observers: + if p <= min_priority: + break + callble(self, which=which) def change_priority(self, observer, callble, priority): self.remove_observer(observer, callble) diff --git a/GPy/core/parameterization/param.py b/GPy/core/parameterization/param.py index 2fbb5df5..e9a42cb5 100644 --- a/GPy/core/parameterization/param.py +++ b/GPy/core/parameterization/param.py @@ -84,6 +84,7 @@ class Param(Parameterizable, ObsAr): self._original_ = getattr(obj, '_original_', None) self._name = getattr(obj, '_name', None) self._gradient_array_ = getattr(obj, '_gradient_array_', None) + self._update_on = getattr(obj, '_update_on', None) self.constraints = getattr(obj, 'constraints', None) self.priors = getattr(obj, 'priors', None) @@ -360,7 +361,7 @@ class ParamConcatenation(object): #=========================================================================== def update_all_params(self): for par in self.parents: - par.notify_observers() + par.trigger_update(trigger_parent=False) def constrain(self, constraint, warning=True): [param.constrain(constraint, trigger_parent=False) for param in self.params] diff --git a/GPy/core/parameterization/parameter_core.py b/GPy/core/parameterization/parameter_core.py index 656bd1c5..9a903079 100644 --- a/GPy/core/parameterization/parameter_core.py +++ b/GPy/core/parameterization/parameter_core.py @@ -471,7 +471,7 @@ class Indexable(Nameable, Updateable): self.param_array[...] = transform.initialize(self.param_array) reconstrained = self.unconstrain() added = self._add_to_index_operations(self.constraints, reconstrained, transform, warning) - self.notify_observers(self, None if trigger_parent else -np.inf) + self.trigger_update(trigger_parent) return added def unconstrain(self, *transforms): diff --git a/GPy/core/parameterization/parameterized.py b/GPy/core/parameterization/parameterized.py index 28b58973..317f8f47 100644 --- a/GPy/core/parameterization/parameterized.py +++ b/GPy/core/parameterization/parameterized.py @@ -156,7 +156,7 @@ class Parameterized(Parameterizable): p._parent_index_ += 1 self.parameters.insert(index, param) - param.add_observer(self, self._pass_through_notify_observers, -1000) + param.add_observer(self, self._pass_through_notify_observers, -np.inf) parent = self while parent is not None: @@ -296,7 +296,7 @@ class Parameterized(Parameterizable): self.param_array[name] = value except: raise ValueError, "Setting by slice or index only allowed with array-like" - self._trigger_params_changed() + self.trigger_update() else: try: param = self.__getitem__(name, paramlist) except: raise diff --git a/GPy/core/parameterization/updateable.py b/GPy/core/parameterization/updateable.py index 593f3c05..278ba8cd 100644 --- a/GPy/core/parameterization/updateable.py +++ b/GPy/core/parameterization/updateable.py @@ -27,18 +27,18 @@ class Updateable(Observable): None: get the current update state """ if updates is None: - p = getattr(self, '_highest_parent_', None) - if p is not None: - self._updates = p._updates - return self._updates + return self._update_on assert isinstance(updates, bool), "updates are either on (True) or off (False)" p = getattr(self, '_highest_parent_', None) - if p is not None: - p._updates = updates - self._updates = updates + def turn_updates(s): + s._update_on = updates + p.traverse(turn_updates) self.trigger_update() def toggle_update(self): + print "deprecated: toggle_update was renamed to update_toggle for easier access" + self.update_toggle() + def update_toggle(self): self.update_model(not self.update_model()) def trigger_update(self, trigger_parent=True): diff --git a/GPy/testing/model_tests.py b/GPy/testing/model_tests.py index 521baeb3..559014f7 100644 --- a/GPy/testing/model_tests.py +++ b/GPy/testing/model_tests.py @@ -178,6 +178,24 @@ class MiscTests(unittest.TestCase): m.optimize() print m + def test_model_updates(self): + Y1 = np.random.normal(0, 1, (40, 13)) + Y2 = np.random.normal(0, 1, (40, 6)) + m = GPy.models.MRD([Y1, Y2], 5) + self.count = 0 + m.add_observer(self, self._count_updates, -2000) + m.update_model(False) + m['.*Gaussian'] = .001 + self.assertEquals(self.count, 0) + m['.*Gaussian'].constrain_bounded(0,.01) + self.assertEquals(self.count, 0) + m.Z.fix() + self.assertEquals(self.count, 0) + m.update_model(True) + self.assertEquals(self.count, 1) + def _count_updates(self, me, which): + self.count+=1 + def test_model_optimize(self): X = np.random.uniform(-3., 3., (20, 1)) Y = np.sin(X) + np.random.randn(20, 1) * 0.05