mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 20:12:38 +02:00
making observables accessable
This commit is contained in:
parent
22e4f8a1e8
commit
04a889b3a9
7 changed files with 45 additions and 59 deletions
|
|
@ -4,6 +4,7 @@
|
|||
from model import *
|
||||
from parameterization.parameterized import adjust_name_for_printing, Parameterizable
|
||||
from parameterization.param import Param, ParamConcatenation
|
||||
from parameterization.observable_array import ObsAr
|
||||
|
||||
from gp import GP
|
||||
from sparse_gp import SparseGP
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Created on 27 Feb 2014
|
|||
'''
|
||||
|
||||
from collections import defaultdict
|
||||
import weakref
|
||||
|
||||
def intarray_default_factory():
|
||||
import numpy as np
|
||||
|
|
@ -41,49 +42,40 @@ class ObservablesList(object):
|
|||
def __init__(self):
|
||||
self._poc = []
|
||||
|
||||
def remove(self, value):
|
||||
return self._poc.remove(value)
|
||||
|
||||
|
||||
def __delitem__(self, ind):
|
||||
return self._poc.__delitem__(ind)
|
||||
|
||||
|
||||
def __setitem__(self, ind, item):
|
||||
return self._poc.__setitem__(ind, item)
|
||||
|
||||
|
||||
def __getitem__(self, ind):
|
||||
return self._poc.__getitem__(ind)
|
||||
|
||||
def remove(self, priority, observable, callble):
|
||||
"""
|
||||
"""
|
||||
self._poc.remove((priority, observable, callble))
|
||||
|
||||
def __repr__(self):
|
||||
return self._poc.__repr__()
|
||||
|
||||
|
||||
def append(self, obj):
|
||||
return self._poc.append(obj)
|
||||
|
||||
|
||||
def index(self, value):
|
||||
return self._poc.index(value)
|
||||
|
||||
|
||||
def extend(self, iterable):
|
||||
return self._poc.extend(iterable)
|
||||
|
||||
|
||||
def add(self, priority, observable, callble):
|
||||
i = 0
|
||||
for i, [p, _, _] in enumerate(self._poc):
|
||||
if p < priority:
|
||||
break
|
||||
self._poc.insert(i, (priority, weakref.ref(observable), callble))
|
||||
|
||||
def __str__(self):
|
||||
return self._poc.__str__()
|
||||
|
||||
ret = []
|
||||
curr_p = None
|
||||
for p, o, c in self:
|
||||
curr = ''
|
||||
if curr_p != p:
|
||||
pre = "{!s}: ".format(p)
|
||||
curr_pre = pre
|
||||
else: curr_pre = " "*len(pre)
|
||||
curr_p = p
|
||||
curr += curr_pre
|
||||
ret.append(curr + ", ".join(map(str, [o,c])))
|
||||
return '\n'.join(ret)
|
||||
|
||||
def __iter__(self):
|
||||
return self._poc.__iter__()
|
||||
|
||||
|
||||
def insert(self, index, obj):
|
||||
return self._poc.insert(index, obj)
|
||||
|
||||
self._poc = [(p,o,c) for p,o,c in self._poc if o() is not None]
|
||||
for p, o, c in self._poc:
|
||||
if o() is not None:
|
||||
yield p, o(), c
|
||||
|
||||
def __len__(self):
|
||||
return self._poc.__len__()
|
||||
|
|
@ -106,6 +98,6 @@ class ObservablesList(object):
|
|||
def __setstate__(self, state):
|
||||
self._poc = []
|
||||
for p, o, c in state:
|
||||
self._poc.append((p,o,getattr(o, c)))
|
||||
self.add(p,o,getattr(o, c))
|
||||
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class ObsAr(np.ndarray, Pickleable, Observable):
|
|||
def __array_finalize__(self, obj):
|
||||
# see InfoArray.__array_finalize__ for comments
|
||||
if obj is None: return
|
||||
self._observer_callables_ = getattr(obj, '_observer_callables_', None)
|
||||
self.observers = getattr(obj, 'observers', None)
|
||||
|
||||
def __array_wrap__(self, out_arr, context=None):
|
||||
return out_arr.view(np.ndarray)
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class Param(OptimizationHandlable, ObsAr):
|
|||
import pydot
|
||||
node = pydot.Node(id(self), shape='record', label=self.name)
|
||||
G.add_node(node)
|
||||
for o in self._observer_callables_.keys():
|
||||
for o in self.observers.keys():
|
||||
label = o.name if hasattr(o, 'name') else str(o)
|
||||
observed_node = pydot.Node(id(o), label=label)
|
||||
G.add_node(observed_node)
|
||||
|
|
|
|||
|
|
@ -44,22 +44,23 @@ class Observable(object):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super(Observable, self).__init__()
|
||||
from lists_and_dicts import ObservablesList
|
||||
self._observer_callables_ = ObservablesList()
|
||||
self.observers = ObservablesList()
|
||||
|
||||
def add_observer(self, observer, callble, priority=0):
|
||||
self._insert_sorted(priority, observer, callble)
|
||||
self.observers.add(priority, observer, callble)
|
||||
|
||||
def remove_observer(self, observer, callble=None):
|
||||
to_remove = []
|
||||
for p, obs, clble in self._observer_callables_:
|
||||
for poc in self.observers:
|
||||
_, obs, clble = poc
|
||||
if callble is not None:
|
||||
if (obs == observer) and (callble == clble):
|
||||
to_remove.append((p, obs, clble))
|
||||
to_remove.append(poc)
|
||||
else:
|
||||
if obs is observer:
|
||||
to_remove.append((p, obs, clble))
|
||||
to_remove.append(poc)
|
||||
for r in to_remove:
|
||||
self._observer_callables_.remove(r)
|
||||
self.observers.remove(*r)
|
||||
|
||||
def notify_observers(self, which=None, min_priority=None):
|
||||
"""
|
||||
|
|
@ -74,21 +75,13 @@ class Observable(object):
|
|||
if which is None:
|
||||
which = self
|
||||
if min_priority is None:
|
||||
[callble(self, which=which) for _, _, callble in self._observer_callables_]
|
||||
[callble(self, which=which) for _, _, callble in self.observers]
|
||||
else:
|
||||
for p, _, callble in self._observer_callables_:
|
||||
for p, _, callble in self.observers:
|
||||
if p <= min_priority:
|
||||
break
|
||||
callble(self, which=which)
|
||||
|
||||
def _insert_sorted(self, p, o, c):
|
||||
ins = 0
|
||||
for pr, _, _ in self._observer_callables_:
|
||||
if p > pr:
|
||||
break
|
||||
ins += 1
|
||||
self._observer_callables_.insert(ins, (p, o, c))
|
||||
|
||||
#===============================================================================
|
||||
# Foundation framework for parameterized and param objects:
|
||||
#===============================================================================
|
||||
|
|
@ -192,7 +185,7 @@ class Pickleable(object):
|
|||
|
||||
def __getstate__(self):
|
||||
ignore_list = ([#'_parent_', '_parent_index_',
|
||||
#'_observer_callables_',
|
||||
#'observers',
|
||||
'_param_array_', '_gradient_array_', '_fixes_',
|
||||
'_Cacher_wrap__cachers']
|
||||
#+ self.parameter_names(recursive=False)
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class Parameterized(Parameterizable):
|
|||
child_node = child.build_pydot(G)
|
||||
G.add_edge(pydot.Edge(node, child_node))
|
||||
|
||||
for o in self._observer_callables_.keys():
|
||||
for o in self.observers.keys():
|
||||
label = o.name if hasattr(o, 'name') else str(o)
|
||||
observed_node = pydot.Node(id(o), label=label)
|
||||
G.add_node(observed_node)
|
||||
|
|
|
|||
|
|
@ -191,13 +191,13 @@ class Test(ListDictTestCase):
|
|||
par.count = 0
|
||||
par.add_observer(self, self._callback, 1)
|
||||
pcopy = GPRegression(par.X.copy(), par.Y.copy(), kernel=par.kern.copy())
|
||||
self.assertNotIn(par._observer_callables_[0], pcopy._observer_callables_)
|
||||
self.assertNotIn(par.observers[0], pcopy.observers)
|
||||
pcopy = par.copy()
|
||||
pcopy.name = "copy"
|
||||
self.assertTrue(par.checkgrad())
|
||||
self.assertTrue(pcopy.checkgrad())
|
||||
self.assertTrue(pcopy.kern.checkgrad())
|
||||
self.assertIn(par._observer_callables_[0], pcopy._observer_callables_)
|
||||
self.assertIn(par.observers[0], pcopy.observers)
|
||||
self.assertEqual(par.count, 3)
|
||||
self.assertEqual(pcopy.count, 6) # 3 of each call to checkgrad
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue