diff --git a/GPy/core/parameterization/param.py b/GPy/core/parameterization/param.py
index 20ffd5db..7212c658 100644
--- a/GPy/core/parameterization/param.py
+++ b/GPy/core/parameterization/param.py
@@ -249,7 +249,7 @@ class Param(Parameterizable, ObsAr):
         try:
             indices = np.indices(self._realshape_, dtype=int)
             indices = indices[(slice(None),)+slice_index]
-            indices = np.rollaxis(indices, 0, indices.ndim).reshape(-1,2)
+            indices = np.rollaxis(indices, 0, indices.ndim).reshape(-1,self._realndim_)
             #print indices_
             #if not np.all(indices==indices__):
             #    import ipdb; ipdb.set_trace()
diff --git a/GPy/core/parameterization/parameter_core.py b/GPy/core/parameterization/parameter_core.py
index 82c494b2..c851e5d8 100644
--- a/GPy/core/parameterization/parameter_core.py
+++ b/GPy/core/parameterization/parameter_core.py
@@ -50,30 +50,50 @@ class Observable(object):
     as an observer. Every time the observable changes, it sends
     a notification with self as only argument to all its observers.
     """
-    _updates = True
     def __init__(self, *args, **kwargs):
         super(Observable, self).__init__()
        from lists_and_dicts import ObserverList
         self.observers = ObserverList()
+        self._updates = True
 
-    @property
-    def updates(self):
+    def updates(self, updates=None):
+        """
+        Get or set whether automatic updates are performed. When updates are
+        off, the model might be in a non-working state. To make the model work,
+        turn updates on again.
+
+        :param bool|None updates:
+
+            bool: whether to do updates
+            None: get the current update state
+        """
+        if updates is None:
+            p = getattr(self, '_highest_parent_', None)
+            if p is not None:
+                self._updates = p._updates
+            return self._updates
+        assert isinstance(updates, bool), "updates are either on (True) or off (False)"
         p = getattr(self, '_highest_parent_', None)
         if p is not None:
-            self._updates = p._updates
-        return self._updates
-
-    @updates.setter
-    def updates(self, ups):
-        assert isinstance(ups, bool), "updates are either on (True) or off (False)"
-        p = getattr(self, '_highest_parent_', None)
-        if p is not None:
-            p._updates = ups
+            p._updates = updates
         else:
-            self._updates = ups
-        if ups:
-            self._trigger_params_changed()
+            self._updates = updates
+        self.update_model()
+
+    def toggle_updates(self):
+        self.updates(not self.updates())
+    def update_model(self):
+        """
+        Update the model from the current state.
+        Make sure that updates are on, otherwise this
+        method will do nothing.
+        """
+        if not self.updates():
+            #print "Warning: updates are off, updating the model will do nothing"
+            return
+        self._trigger_params_changed()
+
     def add_observer(self, observer, callble, priority=0):
         """
         Add an observer `observer` with the callback `callble`
@@ -110,7 +130,7 @@ class Observable(object):
         :param min_priority: only notify observers with priority > min_priority
                              if min_priority is None, notify all observers in order
         """
-        if not self.updates:
+        if not self.updates():
             return
         if which is None:
             which = self
@@ -798,7 +818,7 @@ class OptimizationHandlable(Indexable):
         """
         # first take care of all parameters (from N(0,1))
         x = rand_gen(size=self._size_transformed(), *args, **kwargs)
-        self.updates = False # Switch off the updates
+        self.updates(False) # Switch off the updates
         self.optimizer_array = x # makes sure all of the tied parameters get the same init (since there's only one prior object...)
         # now draw from prior where possible
         x = self.param_array.copy()
@@ -806,7 +826,7 @@ class OptimizationHandlable(Indexable):
         unfixlist = np.ones((self.size,),dtype=np.bool)
         unfixlist[self.constraints[__fixed__]] = False
         self.param_array[unfixlist] = x[unfixlist]
-        self.updates = True
+        self.updates(True)
 
     #===========================================================================
     # For shared memory arrays. This does nothing in Param, but sets the memory
diff --git a/GPy/testing/parameterized_tests.py b/GPy/testing/parameterized_tests.py
index c647c6eb..a96ac64d 100644
--- a/GPy/testing/parameterized_tests.py
+++ b/GPy/testing/parameterized_tests.py
@@ -152,6 +152,12 @@ class ParameterizedTest(unittest.TestCase):
         self.test1.kern.randomize()
         self.assertEqual(val, self.rbf.variance)
 
+    def test_updates(self):
+        self.test1.updates(False)
+        val = float(self.rbf.variance)
+        self.test1.kern.randomize()
+        self.assertEqual(val, self.rbf.variance)
+
     def test_fixing_optimize(self):
         self.testmodel.kern.lengthscale.fix()
         val = float(self.testmodel.kern.lengthscale)
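
Note on the API change: `updates` is now a plain method rather than a property, so call sites switch from `m.updates = False` to `m.updates(False)` (the new test above was adjusted accordingly). A minimal usage sketch of the new interface follows; it is not part of the patch, and the `GPRegression` model, the random toy data, and the `lengthscale`/`variance` parameter names are illustrative assumptions, not something this diff adds:

    import numpy as np
    import GPy

    # Illustrative toy model; any Parameterized model exposes the same API.
    X = np.random.rand(20, 1)
    Y = np.sin(X) + np.random.randn(20, 1) * 0.05
    m = GPy.models.GPRegression(X, Y)

    m.updates(False)           # switch automatic updates off
    m.kern.lengthscale = 0.5   # cheap: notify_observers() returns early,
    m.kern.variance = 2.0      # so nothing is recomputed in between
    m.updates(True)            # switch back on; calls update_model() once
    assert m.updates()         # calling with no argument queries the state
    m.toggle_updates()         # flip the flag: updates are off again

The point of the getter/setter method over the old property is that toggling updates off lets several parameter changes be batched with a single recomputation when updates are switched back on.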