mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 03:52:39 +02:00
observer pattern now tested and fully operational. needed the good night rest : )
This commit is contained in:
parent
a35464b32f
commit
999d2419dd
6 changed files with 158 additions and 20 deletions
|
|
@ -62,7 +62,7 @@ class ObservableArray(np.ndarray, Observable):
|
||||||
def __setitem__(self, s, val):
|
def __setitem__(self, s, val):
|
||||||
if self._s_not_empty(s):
|
if self._s_not_empty(s):
|
||||||
super(ObservableArray, self).__setitem__(s, val)
|
super(ObservableArray, self).__setitem__(s, val)
|
||||||
self._notify_observers()
|
self._notify_observers(self[s])
|
||||||
|
|
||||||
def __getslice__(self, start, stop):
|
def __getslice__(self, start, stop):
|
||||||
return self.__getitem__(slice(start, stop))
|
return self.__getitem__(slice(start, stop))
|
||||||
|
|
|
||||||
18
GPy/core/parameterization/lists_and_dicts.py
Normal file
18
GPy/core/parameterization/lists_and_dicts.py
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
'''
|
||||||
|
Created on 27 Feb 2014
|
||||||
|
|
||||||
|
@author: maxz
|
||||||
|
'''
|
||||||
|
|
||||||
|
class ParamList(list):
|
||||||
|
"""
|
||||||
|
List to store ndarray-likes in.
|
||||||
|
It will look for 'is' instead of calling __eq__ on each element.
|
||||||
|
"""
|
||||||
|
def __contains__(self, other):
|
||||||
|
for el in self:
|
||||||
|
if el is other:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
@ -172,6 +172,7 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
|
||||||
try: new_arr._current_slice_ = s; new_arr._original_ = self.base is new_arr.base
|
try: new_arr._current_slice_ = s; new_arr._original_ = self.base is new_arr.base
|
||||||
except AttributeError: pass # returning 0d array or float, double etc
|
except AttributeError: pass # returning 0d array or float, double etc
|
||||||
return new_arr
|
return new_arr
|
||||||
|
|
||||||
def __setitem__(self, s, val):
|
def __setitem__(self, s, val):
|
||||||
super(Param, self).__setitem__(s, val)
|
super(Param, self).__setitem__(s, val)
|
||||||
if self.has_parent():
|
if self.has_parent():
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
|
from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
|
||||||
|
import heapq
|
||||||
|
|
||||||
__updated__ = '2013-12-16'
|
__updated__ = '2013-12-16'
|
||||||
|
|
||||||
|
|
@ -11,25 +12,29 @@ def adjust_name_for_printing(name):
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
class Observable(object):
|
class Observable(object):
|
||||||
|
_updated = True
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
from collections import defaultdict
|
self._observer_callables_ = []
|
||||||
self._observer_callables_ = defaultdict(list)
|
|
||||||
|
|
||||||
def add_observer(self, observer, callble):
|
def add_observer(self, observer, callble, priority=0):
|
||||||
self._observer_callables_[observer].append(callble)
|
heapq.heappush(self._observer_callables_, (priority, observer, callble))
|
||||||
|
|
||||||
def remove_observer(self, observer, callble=None):
|
def remove_observer(self, observer, callble=None):
|
||||||
if observer in self._observer_callables_:
|
to_remove = []
|
||||||
if callble is None:
|
for p, obs, clble in self._observer_callables_:
|
||||||
del self._observer_callables_[observer]
|
if callble is not None:
|
||||||
elif callble in self._observer_callables_[observer]:
|
if (obs == observer) and (callble == clble):
|
||||||
self._observer_callables_[observer].remove(callble)
|
to_remove.append((p, obs, clble))
|
||||||
if len(self._observer_callables_[observer]) == 0:
|
else:
|
||||||
self.remove_observer(observer)
|
if obs is observer:
|
||||||
|
to_remove.append((p, obs, clble))
|
||||||
|
for r in to_remove:
|
||||||
|
self._observer_callables_.remove(r)
|
||||||
|
|
||||||
def _notify_observers(self):
|
def _notify_observers(self, which=None):
|
||||||
[[callble(self) for callble in callables]
|
if which is None:
|
||||||
for callables in self._observer_callables_.itervalues()]
|
which = self
|
||||||
|
[callble(which) for _, _, callble in heapq.nlargest(len(self._observer_callables_), self._observer_callables_)]
|
||||||
|
|
||||||
class Pickleable(object):
|
class Pickleable(object):
|
||||||
def _getstate(self):
|
def _getstate(self):
|
||||||
|
|
@ -333,7 +338,7 @@ class Constrainable(Nameable, Indexable):
|
||||||
class Parameterizable(Constrainable, Observable):
|
class Parameterizable(Constrainable, Observable):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Parameterizable, self).__init__(*args, **kwargs)
|
super(Parameterizable, self).__init__(*args, **kwargs)
|
||||||
from GPy.core.parameterization.array_core import ParamList
|
from GPy.core.parameterization.lists_and_dicts import ParamList
|
||||||
_parameters_ = ParamList()
|
_parameters_ = ParamList()
|
||||||
self._added_names_ = set()
|
self._added_names_ = set()
|
||||||
|
|
||||||
|
|
@ -398,7 +403,7 @@ class Parameterizable(Constrainable, Observable):
|
||||||
"""Returns a (deep) copy of the current model"""
|
"""Returns a (deep) copy of the current model"""
|
||||||
import copy
|
import copy
|
||||||
from .index_operations import ParameterIndexOperations, ParameterIndexOperationsView
|
from .index_operations import ParameterIndexOperations, ParameterIndexOperationsView
|
||||||
from .array_core import ParamList
|
from .lists_and_dicts import ParamList
|
||||||
|
|
||||||
dc = dict()
|
dc = dict()
|
||||||
for k, v in self.__dict__.iteritems():
|
for k, v in self.__dict__.iteritems():
|
||||||
|
|
@ -427,7 +432,6 @@ class Parameterizable(Constrainable, Observable):
|
||||||
|
|
||||||
def _notify_parameters_changed(self):
|
def _notify_parameters_changed(self):
|
||||||
self.parameters_changed()
|
self.parameters_changed()
|
||||||
self._notify_observers()
|
|
||||||
if self.has_parent():
|
if self.has_parent():
|
||||||
self._direct_parent_._notify_parameters_changed()
|
self._direct_parent_._notify_parameters_changed()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -116,6 +116,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
||||||
self.constraints.update(param.constraints, start)
|
self.constraints.update(param.constraints, start)
|
||||||
self.priors.update(param.priors, start)
|
self.priors.update(param.priors, start)
|
||||||
self._parameters_.insert(index, param)
|
self._parameters_.insert(index, param)
|
||||||
|
param.add_observer(self, self._pass_through_notify, -1)
|
||||||
self.size += param.size
|
self.size += param.size
|
||||||
else:
|
else:
|
||||||
raise RuntimeError, """Parameter exists already added and no copy made"""
|
raise RuntimeError, """Parameter exists already added and no copy made"""
|
||||||
|
|
@ -169,6 +170,12 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
||||||
self._param_slices_.append(slice(sizes[-2], sizes[-1]))
|
self._param_slices_.append(slice(sizes[-2], sizes[-1]))
|
||||||
self._add_parameter_name(p)
|
self._add_parameter_name(p)
|
||||||
|
|
||||||
|
#===========================================================================
|
||||||
|
# notification system
|
||||||
|
#===========================================================================
|
||||||
|
def _pass_through_notify(self, which):
|
||||||
|
self._notify_observers(which)
|
||||||
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Pickling operations
|
# Pickling operations
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
|
|
|
||||||
108
GPy/testing/observable_tests.py
Normal file
108
GPy/testing/observable_tests.py
Normal file
|
|
@ -0,0 +1,108 @@
|
||||||
|
'''
|
||||||
|
Created on 27 Feb 2014
|
||||||
|
|
||||||
|
@author: maxz
|
||||||
|
'''
|
||||||
|
import unittest
|
||||||
|
from GPy.core.parameterization.parameterized import Parameterized
|
||||||
|
from GPy.core.parameterization.param import Param
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
|
||||||
|
class ParamTestParent(Parameterized):
|
||||||
|
parent_changed_count = 0
|
||||||
|
def parameters_changed(self):
|
||||||
|
self.parent_changed_count += 1
|
||||||
|
|
||||||
|
class ParameterizedTest(Parameterized):
|
||||||
|
params_changed_count = 0
|
||||||
|
def parameters_changed(self):
|
||||||
|
self.params_changed_count += 1
|
||||||
|
|
||||||
|
class Test(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.parent = ParamTestParent('test parent')
|
||||||
|
self.par = ParameterizedTest('test model')
|
||||||
|
self.p = Param('test parameter', numpy.random.normal(1,2,(10,3)))
|
||||||
|
|
||||||
|
self.par.add_parameter(self.p)
|
||||||
|
self.parent.add_parameter(self.par)
|
||||||
|
|
||||||
|
self._observer_triggered = None
|
||||||
|
self._trigger_count = 0
|
||||||
|
self._first = None
|
||||||
|
self._second = None
|
||||||
|
|
||||||
|
def _trigger(self, which):
|
||||||
|
self._observer_triggered = float(which)
|
||||||
|
self._trigger_count += 1
|
||||||
|
if self._first is not None:
|
||||||
|
self._second = self._trigger
|
||||||
|
else:
|
||||||
|
self._first = self._trigger
|
||||||
|
|
||||||
|
def _trigger_priority(self, which):
|
||||||
|
if self._first is not None:
|
||||||
|
self._second = self._trigger_priority
|
||||||
|
else:
|
||||||
|
self._first = self._trigger_priority
|
||||||
|
|
||||||
|
def test_observable(self):
|
||||||
|
self.par.add_observer(self, self._trigger, -1)
|
||||||
|
self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet')
|
||||||
|
self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')
|
||||||
|
|
||||||
|
self.p[0,1] = 3 # trigger observers
|
||||||
|
self.assertEqual(self._observer_triggered, 3, 'observer should have triggered')
|
||||||
|
self.assertEqual(self._trigger_count, 1, 'observer should have triggered once')
|
||||||
|
self.assertEqual(self.par.params_changed_count, 1, 'params changed once')
|
||||||
|
self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')
|
||||||
|
|
||||||
|
self.par.remove_observer(self)
|
||||||
|
self.p[2,1] = 4
|
||||||
|
self.assertEqual(self._observer_triggered, 3, 'observer should not have triggered')
|
||||||
|
self.assertEqual(self._trigger_count, 1, 'observer should have triggered once')
|
||||||
|
self.assertEqual(self.par.params_changed_count, 2, 'params changed second')
|
||||||
|
self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')
|
||||||
|
|
||||||
|
self.par.add_observer(self, self._trigger, -1)
|
||||||
|
self.p[2,1] = 4
|
||||||
|
self.assertEqual(self._observer_triggered, 4, 'observer should have triggered')
|
||||||
|
self.assertEqual(self._trigger_count, 2, 'observer should have triggered once')
|
||||||
|
self.assertEqual(self.par.params_changed_count, 3, 'params changed second')
|
||||||
|
self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')
|
||||||
|
|
||||||
|
self.par.remove_observer(self, self._trigger)
|
||||||
|
self.p[0,1] = 3
|
||||||
|
self.assertEqual(self._observer_triggered, 4, 'observer should not have triggered')
|
||||||
|
self.assertEqual(self._trigger_count, 2, 'observer should have triggered once')
|
||||||
|
self.assertEqual(self.par.params_changed_count, 4, 'params changed second')
|
||||||
|
self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')
|
||||||
|
|
||||||
|
def test_set_params(self):
|
||||||
|
self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet')
|
||||||
|
self.par._set_params(numpy.ones(self.par.size))
|
||||||
|
self.assertEqual(self.par.params_changed_count, 1, 'now params changed')
|
||||||
|
self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')
|
||||||
|
|
||||||
|
def test_priority(self):
|
||||||
|
self.par.add_observer(self, self._trigger, -1)
|
||||||
|
self.par.add_observer(self, self._trigger_priority, 0)
|
||||||
|
self.par._notify_observers(0)
|
||||||
|
self.assertEqual(self._first, self._trigger_priority, 'priority should be first')
|
||||||
|
self.assertEqual(self._second, self._trigger, 'priority should be first')
|
||||||
|
|
||||||
|
self.par.remove_observer(self)
|
||||||
|
self._first = self._second = None
|
||||||
|
|
||||||
|
self.par.add_observer(self, self._trigger, 1)
|
||||||
|
self.par.add_observer(self, self._trigger_priority, 0)
|
||||||
|
self.par._notify_observers(0)
|
||||||
|
self.assertEqual(self._first, self._trigger, 'priority should be second')
|
||||||
|
self.assertEqual(self._second, self._trigger_priority, 'priority should be second')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
#import sys;sys.argv = ['', 'Test.testName']
|
||||||
|
unittest.main()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue