mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-05 14:55:15 +02:00
[updates] now handled in observable, should have from the begining :/
This commit is contained in:
parent
cd8dd9ab98
commit
b9b6ce91d8
6 changed files with 44 additions and 20 deletions
|
|
@ -14,6 +14,10 @@ class Observable(object):
|
|||
super(Observable, self).__init__()
|
||||
from lists_and_dicts import ObserverList
|
||||
self.observers = ObserverList()
|
||||
self._update_on = True
|
||||
|
||||
def set_updates(self, on=True):
|
||||
self._update_on = on
|
||||
|
||||
def add_observer(self, observer, callble, priority=0):
|
||||
"""
|
||||
|
|
@ -51,15 +55,16 @@ class Observable(object):
|
|||
:param min_priority: only notify observers with priority > min_priority
|
||||
if min_priority is None, notify all observers in order
|
||||
"""
|
||||
if which is None:
|
||||
which = self
|
||||
if min_priority is None:
|
||||
[callble(self, which=which) for _, _, callble in self.observers]
|
||||
else:
|
||||
for p, _, callble in self.observers:
|
||||
if p <= min_priority:
|
||||
break
|
||||
callble(self, which=which)
|
||||
if self._update_on:
|
||||
if which is None:
|
||||
which = self
|
||||
if min_priority is None:
|
||||
[callble(self, which=which) for _, _, callble in self.observers]
|
||||
else:
|
||||
for p, _, callble in self.observers:
|
||||
if p <= min_priority:
|
||||
break
|
||||
callble(self, which=which)
|
||||
|
||||
def change_priority(self, observer, callble, priority):
|
||||
self.remove_observer(observer, callble)
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ class Param(Parameterizable, ObsAr):
|
|||
self._original_ = getattr(obj, '_original_', None)
|
||||
self._name = getattr(obj, '_name', None)
|
||||
self._gradient_array_ = getattr(obj, '_gradient_array_', None)
|
||||
self._update_on = getattr(obj, '_update_on', None)
|
||||
self.constraints = getattr(obj, 'constraints', None)
|
||||
self.priors = getattr(obj, 'priors', None)
|
||||
|
||||
|
|
@ -360,7 +361,7 @@ class ParamConcatenation(object):
|
|||
#===========================================================================
|
||||
def update_all_params(self):
|
||||
for par in self.parents:
|
||||
par.notify_observers()
|
||||
par.trigger_update(trigger_parent=False)
|
||||
|
||||
def constrain(self, constraint, warning=True):
|
||||
[param.constrain(constraint, trigger_parent=False) for param in self.params]
|
||||
|
|
|
|||
|
|
@ -471,7 +471,7 @@ class Indexable(Nameable, Updateable):
|
|||
self.param_array[...] = transform.initialize(self.param_array)
|
||||
reconstrained = self.unconstrain()
|
||||
added = self._add_to_index_operations(self.constraints, reconstrained, transform, warning)
|
||||
self.notify_observers(self, None if trigger_parent else -np.inf)
|
||||
self.trigger_update(trigger_parent)
|
||||
return added
|
||||
|
||||
def unconstrain(self, *transforms):
|
||||
|
|
|
|||
|
|
@ -156,7 +156,7 @@ class Parameterized(Parameterizable):
|
|||
p._parent_index_ += 1
|
||||
self.parameters.insert(index, param)
|
||||
|
||||
param.add_observer(self, self._pass_through_notify_observers, -1000)
|
||||
param.add_observer(self, self._pass_through_notify_observers, -np.inf)
|
||||
|
||||
parent = self
|
||||
while parent is not None:
|
||||
|
|
@ -296,7 +296,7 @@ class Parameterized(Parameterizable):
|
|||
self.param_array[name] = value
|
||||
except:
|
||||
raise ValueError, "Setting by slice or index only allowed with array-like"
|
||||
self._trigger_params_changed()
|
||||
self.trigger_update()
|
||||
else:
|
||||
try: param = self.__getitem__(name, paramlist)
|
||||
except: raise
|
||||
|
|
|
|||
|
|
@ -27,18 +27,18 @@ class Updateable(Observable):
|
|||
None: get the current update state
|
||||
"""
|
||||
if updates is None:
|
||||
p = getattr(self, '_highest_parent_', None)
|
||||
if p is not None:
|
||||
self._updates = p._updates
|
||||
return self._updates
|
||||
return self._update_on
|
||||
assert isinstance(updates, bool), "updates are either on (True) or off (False)"
|
||||
p = getattr(self, '_highest_parent_', None)
|
||||
if p is not None:
|
||||
p._updates = updates
|
||||
self._updates = updates
|
||||
def turn_updates(s):
|
||||
s._update_on = updates
|
||||
p.traverse(turn_updates)
|
||||
self.trigger_update()
|
||||
|
||||
def toggle_update(self):
|
||||
print "deprecated: toggle_update was renamed to update_toggle for easier access"
|
||||
self.update_toggle()
|
||||
def update_toggle(self):
|
||||
self.update_model(not self.update_model())
|
||||
|
||||
def trigger_update(self, trigger_parent=True):
|
||||
|
|
|
|||
|
|
@ -178,6 +178,24 @@ class MiscTests(unittest.TestCase):
|
|||
m.optimize()
|
||||
print m
|
||||
|
||||
def test_model_updates(self):
|
||||
Y1 = np.random.normal(0, 1, (40, 13))
|
||||
Y2 = np.random.normal(0, 1, (40, 6))
|
||||
m = GPy.models.MRD([Y1, Y2], 5)
|
||||
self.count = 0
|
||||
m.add_observer(self, self._count_updates, -2000)
|
||||
m.update_model(False)
|
||||
m['.*Gaussian'] = .001
|
||||
self.assertEquals(self.count, 0)
|
||||
m['.*Gaussian'].constrain_bounded(0,.01)
|
||||
self.assertEquals(self.count, 0)
|
||||
m.Z.fix()
|
||||
self.assertEquals(self.count, 0)
|
||||
m.update_model(True)
|
||||
self.assertEquals(self.count, 1)
|
||||
def _count_updates(self, me, which):
|
||||
self.count+=1
|
||||
|
||||
def test_model_optimize(self):
|
||||
X = np.random.uniform(-3., 3., (20, 1))
|
||||
Y = np.sin(X) + np.random.randn(20, 1) * 0.05
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue