making observables accessable

This commit is contained in:
mzwiessele 2014-04-04 13:47:02 +01:00
parent 22e4f8a1e8
commit 04a889b3a9
7 changed files with 45 additions and 59 deletions

View file

@ -4,6 +4,7 @@
from model import *
from parameterization.parameterized import adjust_name_for_printing, Parameterizable
from parameterization.param import Param, ParamConcatenation
from parameterization.observable_array import ObsAr
from gp import GP
from sparse_gp import SparseGP

View file

@ -5,6 +5,7 @@ Created on 27 Feb 2014
'''
from collections import defaultdict
import weakref
def intarray_default_factory():
import numpy as np
@ -41,49 +42,40 @@ class ObservablesList(object):
def __init__(self):
self._poc = []
def remove(self, value):
return self._poc.remove(value)
def __delitem__(self, ind):
return self._poc.__delitem__(ind)
def __setitem__(self, ind, item):
return self._poc.__setitem__(ind, item)
def __getitem__(self, ind):
return self._poc.__getitem__(ind)
def remove(self, priority, observable, callble):
"""
"""
self._poc.remove((priority, observable, callble))
def __repr__(self):
return self._poc.__repr__()
def append(self, obj):
return self._poc.append(obj)
def index(self, value):
return self._poc.index(value)
def extend(self, iterable):
return self._poc.extend(iterable)
def add(self, priority, observable, callble):
i = 0
for i, [p, _, _] in enumerate(self._poc):
if p < priority:
break
self._poc.insert(i, (priority, weakref.ref(observable), callble))
def __str__(self):
return self._poc.__str__()
ret = []
curr_p = None
for p, o, c in self:
curr = ''
if curr_p != p:
pre = "{!s}: ".format(p)
curr_pre = pre
else: curr_pre = " "*len(pre)
curr_p = p
curr += curr_pre
ret.append(curr + ", ".join(map(str, [o,c])))
return '\n'.join(ret)
def __iter__(self):
return self._poc.__iter__()
def insert(self, index, obj):
return self._poc.insert(index, obj)
self._poc = [(p,o,c) for p,o,c in self._poc if o() is not None]
for p, o, c in self._poc:
if o() is not None:
yield p, o(), c
def __len__(self):
return self._poc.__len__()
@ -106,6 +98,6 @@ class ObservablesList(object):
def __setstate__(self, state):
self._poc = []
for p, o, c in state:
self._poc.append((p,o,getattr(o, c)))
self.add(p,o,getattr(o, c))
pass

View file

@ -25,7 +25,7 @@ class ObsAr(np.ndarray, Pickleable, Observable):
def __array_finalize__(self, obj):
# see InfoArray.__array_finalize__ for comments
if obj is None: return
self._observer_callables_ = getattr(obj, '_observer_callables_', None)
self.observers = getattr(obj, 'observers', None)
def __array_wrap__(self, out_arr, context=None):
return out_arr.view(np.ndarray)

View file

@ -59,7 +59,7 @@ class Param(OptimizationHandlable, ObsAr):
import pydot
node = pydot.Node(id(self), shape='record', label=self.name)
G.add_node(node)
for o in self._observer_callables_.keys():
for o in self.observers.keys():
label = o.name if hasattr(o, 'name') else str(o)
observed_node = pydot.Node(id(o), label=label)
G.add_node(observed_node)

View file

@ -44,22 +44,23 @@ class Observable(object):
def __init__(self, *args, **kwargs):
super(Observable, self).__init__()
from lists_and_dicts import ObservablesList
self._observer_callables_ = ObservablesList()
self.observers = ObservablesList()
def add_observer(self, observer, callble, priority=0):
self._insert_sorted(priority, observer, callble)
self.observers.add(priority, observer, callble)
def remove_observer(self, observer, callble=None):
to_remove = []
for p, obs, clble in self._observer_callables_:
for poc in self.observers:
_, obs, clble = poc
if callble is not None:
if (obs == observer) and (callble == clble):
to_remove.append((p, obs, clble))
to_remove.append(poc)
else:
if obs is observer:
to_remove.append((p, obs, clble))
to_remove.append(poc)
for r in to_remove:
self._observer_callables_.remove(r)
self.observers.remove(*r)
def notify_observers(self, which=None, min_priority=None):
"""
@ -74,21 +75,13 @@ class Observable(object):
if which is None:
which = self
if min_priority is None:
[callble(self, which=which) for _, _, callble in self._observer_callables_]
[callble(self, which=which) for _, _, callble in self.observers]
else:
for p, _, callble in self._observer_callables_:
for p, _, callble in self.observers:
if p <= min_priority:
break
callble(self, which=which)
def _insert_sorted(self, p, o, c):
ins = 0
for pr, _, _ in self._observer_callables_:
if p > pr:
break
ins += 1
self._observer_callables_.insert(ins, (p, o, c))
#===============================================================================
# Foundation framework for parameterized and param objects:
#===============================================================================
@ -192,7 +185,7 @@ class Pickleable(object):
def __getstate__(self):
ignore_list = ([#'_parent_', '_parent_index_',
#'_observer_callables_',
#'observers',
'_param_array_', '_gradient_array_', '_fixes_',
'_Cacher_wrap__cachers']
#+ self.parameter_names(recursive=False)

View file

@ -90,7 +90,7 @@ class Parameterized(Parameterizable):
child_node = child.build_pydot(G)
G.add_edge(pydot.Edge(node, child_node))
for o in self._observer_callables_.keys():
for o in self.observers.keys():
label = o.name if hasattr(o, 'name') else str(o)
observed_node = pydot.Node(id(o), label=label)
G.add_node(observed_node)

View file

@ -191,13 +191,13 @@ class Test(ListDictTestCase):
par.count = 0
par.add_observer(self, self._callback, 1)
pcopy = GPRegression(par.X.copy(), par.Y.copy(), kernel=par.kern.copy())
self.assertNotIn(par._observer_callables_[0], pcopy._observer_callables_)
self.assertNotIn(par.observers[0], pcopy.observers)
pcopy = par.copy()
pcopy.name = "copy"
self.assertTrue(par.checkgrad())
self.assertTrue(pcopy.checkgrad())
self.assertTrue(pcopy.kern.checkgrad())
self.assertIn(par._observer_callables_[0], pcopy._observer_callables_)
self.assertIn(par.observers[0], pcopy.observers)
self.assertEqual(par.count, 3)
self.assertEqual(pcopy.count, 6) # 3 of each call to checkgrad