diff --git a/GPy/core/__init__.py b/GPy/core/__init__.py index a42d76ed..25651827 100644 --- a/GPy/core/__init__.py +++ b/GPy/core/__init__.py @@ -4,6 +4,7 @@ from model import * from parameterization.parameterized import adjust_name_for_printing, Parameterizable from parameterization.param import Param, ParamConcatenation +from parameterization.observable_array import ObsAr from gp import GP from sparse_gp import SparseGP diff --git a/GPy/core/parameterization/lists_and_dicts.py b/GPy/core/parameterization/lists_and_dicts.py index 6902c249..dd93c5ba 100644 --- a/GPy/core/parameterization/lists_and_dicts.py +++ b/GPy/core/parameterization/lists_and_dicts.py @@ -5,6 +5,7 @@ Created on 27 Feb 2014 ''' from collections import defaultdict +import weakref def intarray_default_factory(): import numpy as np @@ -41,49 +42,40 @@ class ObservablesList(object): def __init__(self): self._poc = [] - def remove(self, value): - return self._poc.remove(value) - - - def __delitem__(self, ind): - return self._poc.__delitem__(ind) - - - def __setitem__(self, ind, item): - return self._poc.__setitem__(ind, item) - - - def __getitem__(self, ind): - return self._poc.__getitem__(ind) - + def remove(self, priority, observable, callble): + """ + """ + self._poc.remove((priority, observable, callble)) def __repr__(self): return self._poc.__repr__() - - def append(self, obj): - return self._poc.append(obj) - - - def index(self, value): - return self._poc.index(value) - - - def extend(self, iterable): - return self._poc.extend(iterable) - - + def add(self, priority, observable, callble): + i = 0 + for i, [p, _, _] in enumerate(self._poc): + if p < priority: + break + self._poc.insert(i, (priority, weakref.ref(observable), callble)) + def __str__(self): - return self._poc.__str__() - + ret = [] + curr_p = None + for p, o, c in self: + curr = '' + if curr_p != p: + pre = "{!s}: ".format(p) + curr_pre = pre + else: curr_pre = " "*len(pre) + curr_p = p + curr += curr_pre + ret.append(curr + ", ".join(map(str, [o,c]))) + return '\n'.join(ret) def __iter__(self): - return self._poc.__iter__() - - - def insert(self, index, obj): - return self._poc.insert(index, obj) - + self._poc = [(p,o,c) for p,o,c in self._poc if o() is not None] + for p, o, c in self._poc: + if o() is not None: + yield p, o(), c def __len__(self): return self._poc.__len__() @@ -106,6 +98,6 @@ class ObservablesList(object): def __setstate__(self, state): self._poc = [] for p, o, c in state: - self._poc.append((p,o,getattr(o, c))) + self.add(p,o,getattr(o, c)) pass diff --git a/GPy/core/parameterization/observable_array.py b/GPy/core/parameterization/observable_array.py index fc9d6cf2..56d33bfc 100644 --- a/GPy/core/parameterization/observable_array.py +++ b/GPy/core/parameterization/observable_array.py @@ -25,7 +25,7 @@ class ObsAr(np.ndarray, Pickleable, Observable): def __array_finalize__(self, obj): # see InfoArray.__array_finalize__ for comments if obj is None: return - self._observer_callables_ = getattr(obj, '_observer_callables_', None) + self.observers = getattr(obj, 'observers', None) def __array_wrap__(self, out_arr, context=None): return out_arr.view(np.ndarray) diff --git a/GPy/core/parameterization/param.py b/GPy/core/parameterization/param.py index 60bdfe9d..4490a8ee 100644 --- a/GPy/core/parameterization/param.py +++ b/GPy/core/parameterization/param.py @@ -59,7 +59,7 @@ class Param(OptimizationHandlable, ObsAr): import pydot node = pydot.Node(id(self), shape='record', label=self.name) G.add_node(node) - for o in self._observer_callables_.keys(): + for o in self.observers.keys(): label = o.name if hasattr(o, 'name') else str(o) observed_node = pydot.Node(id(o), label=label) G.add_node(observed_node) diff --git a/GPy/core/parameterization/parameter_core.py b/GPy/core/parameterization/parameter_core.py index 2dac9bf3..43bc7177 100644 --- a/GPy/core/parameterization/parameter_core.py +++ b/GPy/core/parameterization/parameter_core.py @@ -44,22 +44,23 @@ class Observable(object): def __init__(self, *args, **kwargs): super(Observable, self).__init__() from lists_and_dicts import ObservablesList - self._observer_callables_ = ObservablesList() + self.observers = ObservablesList() def add_observer(self, observer, callble, priority=0): - self._insert_sorted(priority, observer, callble) + self.observers.add(priority, observer, callble) def remove_observer(self, observer, callble=None): to_remove = [] - for p, obs, clble in self._observer_callables_: + for poc in self.observers: + _, obs, clble = poc if callble is not None: if (obs == observer) and (callble == clble): - to_remove.append((p, obs, clble)) + to_remove.append(poc) else: if obs is observer: - to_remove.append((p, obs, clble)) + to_remove.append(poc) for r in to_remove: - self._observer_callables_.remove(r) + self.observers.remove(*r) def notify_observers(self, which=None, min_priority=None): """ @@ -74,21 +75,13 @@ class Observable(object): if which is None: which = self if min_priority is None: - [callble(self, which=which) for _, _, callble in self._observer_callables_] + [callble(self, which=which) for _, _, callble in self.observers] else: - for p, _, callble in self._observer_callables_: + for p, _, callble in self.observers: if p <= min_priority: break callble(self, which=which) - def _insert_sorted(self, p, o, c): - ins = 0 - for pr, _, _ in self._observer_callables_: - if p > pr: - break - ins += 1 - self._observer_callables_.insert(ins, (p, o, c)) - #=============================================================================== # Foundation framework for parameterized and param objects: #=============================================================================== @@ -192,7 +185,7 @@ class Pickleable(object): def __getstate__(self): ignore_list = ([#'_parent_', '_parent_index_', - #'_observer_callables_', + #'observers', '_param_array_', '_gradient_array_', '_fixes_', '_Cacher_wrap__cachers'] #+ self.parameter_names(recursive=False) diff --git a/GPy/core/parameterization/parameterized.py b/GPy/core/parameterization/parameterized.py index 75085ca2..a794ab40 100644 --- a/GPy/core/parameterization/parameterized.py +++ b/GPy/core/parameterization/parameterized.py @@ -90,7 +90,7 @@ class Parameterized(Parameterizable): child_node = child.build_pydot(G) G.add_edge(pydot.Edge(node, child_node)) - for o in self._observer_callables_.keys(): + for o in self.observers.keys(): label = o.name if hasattr(o, 'name') else str(o) observed_node = pydot.Node(id(o), label=label) G.add_node(observed_node) diff --git a/GPy/testing/pickle_tests.py b/GPy/testing/pickle_tests.py index fc52581a..b888353c 100644 --- a/GPy/testing/pickle_tests.py +++ b/GPy/testing/pickle_tests.py @@ -191,13 +191,13 @@ class Test(ListDictTestCase): par.count = 0 par.add_observer(self, self._callback, 1) pcopy = GPRegression(par.X.copy(), par.Y.copy(), kernel=par.kern.copy()) - self.assertNotIn(par._observer_callables_[0], pcopy._observer_callables_) + self.assertNotIn(par.observers[0], pcopy.observers) pcopy = par.copy() pcopy.name = "copy" self.assertTrue(par.checkgrad()) self.assertTrue(pcopy.checkgrad()) self.assertTrue(pcopy.kern.checkgrad()) - self.assertIn(par._observer_callables_[0], pcopy._observer_callables_) + self.assertIn(par.observers[0], pcopy.observers) self.assertEqual(par.count, 3) self.assertEqual(pcopy.count, 6) # 3 of each call to checkgrad