mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 04:22:38 +02:00
made observers accessible and observers now only weak reference the observables
This commit is contained in:
parent
04a889b3a9
commit
11059fb615
3 changed files with 24 additions and 8 deletions
|
|
@ -42,20 +42,29 @@ class ObservablesList(object):
|
|||
def __init__(self):
|
||||
self._poc = []
|
||||
|
||||
def __getitem__(self, ind):
|
||||
p,o,c = self._poc[ind]
|
||||
return p, o(), c
|
||||
|
||||
def remove(self, priority, observable, callble):
|
||||
"""
|
||||
"""
|
||||
self._poc.remove((priority, observable, callble))
|
||||
self.flush()
|
||||
for i in range(len(self) - 1, -1, -1):
|
||||
p,o,c = self[i]
|
||||
if priority==p and observable==o and callble==c:
|
||||
del self._poc[i]
|
||||
|
||||
def __repr__(self):
|
||||
return self._poc.__repr__()
|
||||
|
||||
def add(self, priority, observable, callble):
|
||||
i = 0
|
||||
for i, [p, _, _] in enumerate(self._poc):
|
||||
if p < priority:
|
||||
ins = 0
|
||||
for pr, _, _ in self:
|
||||
if priority > pr:
|
||||
break
|
||||
self._poc.insert(i, (priority, weakref.ref(observable), callble))
|
||||
ins += 1
|
||||
self._poc.insert(ins, (priority, weakref.ref(observable), callble))
|
||||
|
||||
def __str__(self):
|
||||
ret = []
|
||||
|
|
@ -68,25 +77,31 @@ class ObservablesList(object):
|
|||
else: curr_pre = " "*len(pre)
|
||||
curr_p = p
|
||||
curr += curr_pre
|
||||
ret.append(curr + ", ".join(map(str, [o,c])))
|
||||
ret.append(curr + ", ".join(map(repr, [o,c])))
|
||||
return '\n'.join(ret)
|
||||
|
||||
def __iter__(self):
|
||||
def flush(self):
|
||||
self._poc = [(p,o,c) for p,o,c in self._poc if o() is not None]
|
||||
|
||||
def __iter__(self):
|
||||
self.flush()
|
||||
for p, o, c in self._poc:
|
||||
if o() is not None:
|
||||
yield p, o(), c
|
||||
|
||||
def __len__(self):
|
||||
self.flush()
|
||||
return self._poc.__len__()
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
self.flush()
|
||||
s = ObservablesList()
|
||||
import copy
|
||||
s._poc = copy.deepcopy(self._poc, memo)
|
||||
return s
|
||||
|
||||
def __getstate__(self):
|
||||
self.flush()
|
||||
from ...util.caching import Cacher
|
||||
obs = []
|
||||
for p, o, c in self:
|
||||
|
|
|
|||
|
|
@ -324,7 +324,7 @@ class ParamConcatenation(object):
|
|||
if update:
|
||||
self.update_all_params()
|
||||
def values(self):
|
||||
return numpy.hstack([p.param_array for p in self.params])
|
||||
return numpy.hstack([p.param_array.flat for p in self.params])
|
||||
#===========================================================================
|
||||
# parameter operations:
|
||||
#===========================================================================
|
||||
|
|
|
|||
|
|
@ -197,6 +197,7 @@ class Test(ListDictTestCase):
|
|||
self.assertTrue(par.checkgrad())
|
||||
self.assertTrue(pcopy.checkgrad())
|
||||
self.assertTrue(pcopy.kern.checkgrad())
|
||||
import ipdb;ipdb.set_trace()
|
||||
self.assertIn(par.observers[0], pcopy.observers)
|
||||
self.assertEqual(par.count, 3)
|
||||
self.assertEqual(pcopy.count, 6) # 3 of each call to checkgrad
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue