[updates] now handled in observable, should have from the begining :/

This commit is contained in:
Max Zwiessele 2015-01-13 09:45:11 +00:00
parent cd8dd9ab98
commit b9b6ce91d8
6 changed files with 44 additions and 20 deletions

View file

@ -14,6 +14,10 @@ class Observable(object):
super(Observable, self).__init__()
from lists_and_dicts import ObserverList
self.observers = ObserverList()
self._update_on = True
def set_updates(self, on=True):
self._update_on = on
def add_observer(self, observer, callble, priority=0):
"""
@ -51,15 +55,16 @@ class Observable(object):
:param min_priority: only notify observers with priority > min_priority
if min_priority is None, notify all observers in order
"""
if which is None:
which = self
if min_priority is None:
[callble(self, which=which) for _, _, callble in self.observers]
else:
for p, _, callble in self.observers:
if p <= min_priority:
break
callble(self, which=which)
if self._update_on:
if which is None:
which = self
if min_priority is None:
[callble(self, which=which) for _, _, callble in self.observers]
else:
for p, _, callble in self.observers:
if p <= min_priority:
break
callble(self, which=which)
def change_priority(self, observer, callble, priority):
self.remove_observer(observer, callble)

View file

@ -84,6 +84,7 @@ class Param(Parameterizable, ObsAr):
self._original_ = getattr(obj, '_original_', None)
self._name = getattr(obj, '_name', None)
self._gradient_array_ = getattr(obj, '_gradient_array_', None)
self._update_on = getattr(obj, '_update_on', None)
self.constraints = getattr(obj, 'constraints', None)
self.priors = getattr(obj, 'priors', None)
@ -360,7 +361,7 @@ class ParamConcatenation(object):
#===========================================================================
def update_all_params(self):
for par in self.parents:
par.notify_observers()
par.trigger_update(trigger_parent=False)
def constrain(self, constraint, warning=True):
[param.constrain(constraint, trigger_parent=False) for param in self.params]

View file

@ -471,7 +471,7 @@ class Indexable(Nameable, Updateable):
self.param_array[...] = transform.initialize(self.param_array)
reconstrained = self.unconstrain()
added = self._add_to_index_operations(self.constraints, reconstrained, transform, warning)
self.notify_observers(self, None if trigger_parent else -np.inf)
self.trigger_update(trigger_parent)
return added
def unconstrain(self, *transforms):

View file

@ -156,7 +156,7 @@ class Parameterized(Parameterizable):
p._parent_index_ += 1
self.parameters.insert(index, param)
param.add_observer(self, self._pass_through_notify_observers, -1000)
param.add_observer(self, self._pass_through_notify_observers, -np.inf)
parent = self
while parent is not None:
@ -296,7 +296,7 @@ class Parameterized(Parameterizable):
self.param_array[name] = value
except:
raise ValueError, "Setting by slice or index only allowed with array-like"
self._trigger_params_changed()
self.trigger_update()
else:
try: param = self.__getitem__(name, paramlist)
except: raise

View file

@ -27,18 +27,18 @@ class Updateable(Observable):
None: get the current update state
"""
if updates is None:
p = getattr(self, '_highest_parent_', None)
if p is not None:
self._updates = p._updates
return self._updates
return self._update_on
assert isinstance(updates, bool), "updates are either on (True) or off (False)"
p = getattr(self, '_highest_parent_', None)
if p is not None:
p._updates = updates
self._updates = updates
def turn_updates(s):
s._update_on = updates
p.traverse(turn_updates)
self.trigger_update()
def toggle_update(self):
print "deprecated: toggle_update was renamed to update_toggle for easier access"
self.update_toggle()
def update_toggle(self):
self.update_model(not self.update_model())
def trigger_update(self, trigger_parent=True):

View file

@ -178,6 +178,24 @@ class MiscTests(unittest.TestCase):
m.optimize()
print m
def test_model_updates(self):
Y1 = np.random.normal(0, 1, (40, 13))
Y2 = np.random.normal(0, 1, (40, 6))
m = GPy.models.MRD([Y1, Y2], 5)
self.count = 0
m.add_observer(self, self._count_updates, -2000)
m.update_model(False)
m['.*Gaussian'] = .001
self.assertEquals(self.count, 0)
m['.*Gaussian'].constrain_bounded(0,.01)
self.assertEquals(self.count, 0)
m.Z.fix()
self.assertEquals(self.count, 0)
m.update_model(True)
self.assertEquals(self.count, 1)
def _count_updates(self, me, which):
self.count+=1
def test_model_optimize(self):
X = np.random.uniform(-3., 3., (20, 1))
Y = np.sin(X) + np.random.randn(20, 1) * 0.05