observer pattern now tested and fully operational. needed the good night rest : )

This commit is contained in:
Max Zwiessele 2014-02-28 08:58:34 +00:00
parent a35464b32f
commit 999d2419dd
6 changed files with 158 additions and 20 deletions

View file

@ -62,7 +62,7 @@ class ObservableArray(np.ndarray, Observable):
def __setitem__(self, s, val):
if self._s_not_empty(s):
super(ObservableArray, self).__setitem__(s, val)
self._notify_observers()
self._notify_observers(self[s])
def __getslice__(self, start, stop):
return self.__getitem__(slice(start, stop))

View file

@ -0,0 +1,18 @@
'''
Created on 27 Feb 2014
@author: maxz
'''
class ParamList(list):
"""
List to store ndarray-likes in.
It will look for 'is' instead of calling __eq__ on each element.
"""
def __contains__(self, other):
for el in self:
if el is other:
return True
return False
pass

View file

@ -172,6 +172,7 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
try: new_arr._current_slice_ = s; new_arr._original_ = self.base is new_arr.base
except AttributeError: pass # returning 0d array or float, double etc
return new_arr
def __setitem__(self, s, val):
super(Param, self).__setitem__(s, val)
if self.has_parent():

View file

@ -2,6 +2,7 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt)
from transformations import Transformation, Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
import heapq
__updated__ = '2013-12-16'
@ -11,25 +12,29 @@ def adjust_name_for_printing(name):
return ''
class Observable(object):
_updated = True
def __init__(self, *args, **kwargs):
from collections import defaultdict
self._observer_callables_ = defaultdict(list)
def add_observer(self, observer, callble):
self._observer_callables_[observer].append(callble)
self._observer_callables_ = []
def add_observer(self, observer, callble, priority=0):
heapq.heappush(self._observer_callables_, (priority, observer, callble))
def remove_observer(self, observer, callble=None):
if observer in self._observer_callables_:
if callble is None:
del self._observer_callables_[observer]
elif callble in self._observer_callables_[observer]:
self._observer_callables_[observer].remove(callble)
if len(self._observer_callables_[observer]) == 0:
self.remove_observer(observer)
def _notify_observers(self):
[[callble(self) for callble in callables]
for callables in self._observer_callables_.itervalues()]
to_remove = []
for p, obs, clble in self._observer_callables_:
if callble is not None:
if (obs == observer) and (callble == clble):
to_remove.append((p, obs, clble))
else:
if obs is observer:
to_remove.append((p, obs, clble))
for r in to_remove:
self._observer_callables_.remove(r)
def _notify_observers(self, which=None):
if which is None:
which = self
[callble(which) for _, _, callble in heapq.nlargest(len(self._observer_callables_), self._observer_callables_)]
class Pickleable(object):
def _getstate(self):
@ -333,7 +338,7 @@ class Constrainable(Nameable, Indexable):
class Parameterizable(Constrainable, Observable):
def __init__(self, *args, **kwargs):
super(Parameterizable, self).__init__(*args, **kwargs)
from GPy.core.parameterization.array_core import ParamList
from GPy.core.parameterization.lists_and_dicts import ParamList
_parameters_ = ParamList()
self._added_names_ = set()
@ -398,7 +403,7 @@ class Parameterizable(Constrainable, Observable):
"""Returns a (deep) copy of the current model"""
import copy
from .index_operations import ParameterIndexOperations, ParameterIndexOperationsView
from .array_core import ParamList
from .lists_and_dicts import ParamList
dc = dict()
for k, v in self.__dict__.iteritems():
@ -427,7 +432,6 @@ class Parameterizable(Constrainable, Observable):
def _notify_parameters_changed(self):
self.parameters_changed()
self._notify_observers()
if self.has_parent():
self._direct_parent_._notify_parameters_changed()

View file

@ -116,6 +116,7 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
self.constraints.update(param.constraints, start)
self.priors.update(param.priors, start)
self._parameters_.insert(index, param)
param.add_observer(self, self._pass_through_notify, -1)
self.size += param.size
else:
raise RuntimeError, """Parameter exists already added and no copy made"""
@ -169,6 +170,12 @@ class Parameterized(Parameterizable, Pickleable, Gradcheckable):
self._param_slices_.append(slice(sizes[-2], sizes[-1]))
self._add_parameter_name(p)
#===========================================================================
# notification system
#===========================================================================
def _pass_through_notify(self, which):
self._notify_observers(which)
#===========================================================================
# Pickling operations
#===========================================================================