mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00
caching can handle None
This commit is contained in:
parent
ebea658f5c
commit
a6d3fda234
3 changed files with 11 additions and 8 deletions
|
|
@ -4,7 +4,7 @@
|
||||||
__updated__ = '2013-12-16'
|
__updated__ = '2013-12-16'
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from parameter_core import Observable, Parameterizable
|
from parameter_core import Observable
|
||||||
|
|
||||||
class ParamList(list):
|
class ParamList(list):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -19,12 +19,13 @@ class Observable(object):
|
||||||
self._observer_callables_[observer].append(callble)
|
self._observer_callables_[observer].append(callble)
|
||||||
|
|
||||||
def remove_observer(self, observer, callble=None):
|
def remove_observer(self, observer, callble=None):
|
||||||
if callble is None:
|
if observer in self._observer_callables_:
|
||||||
del self._observer_callables_[observer]
|
if callble is None:
|
||||||
else:
|
del self._observer_callables_[observer]
|
||||||
self._observer_callables_[observer].remove(callble)
|
elif callble in self._observer_callables_[observer]:
|
||||||
if len(self._observer_callables_[observer]) == 0:
|
self._observer_callables_[observer].remove(callble)
|
||||||
self.remove_observer(observer)
|
if len(self._observer_callables_[observer]) == 0:
|
||||||
|
self.remove_observer(observer)
|
||||||
|
|
||||||
def _notify_observers(self):
|
def _notify_observers(self):
|
||||||
[[callble(self) for callble in callables]
|
[[callble(self) for callble in callables]
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,9 @@ class Cacher(object):
|
||||||
for a in ca:
|
for a in ca:
|
||||||
if (not any(a is ai for ai in cached_args)) and a is not None:
|
if (not any(a is ai for ai in cached_args)) and a is not None:
|
||||||
cached_args.append(a)
|
cached_args.append(a)
|
||||||
|
|
||||||
if not all([isinstance(arg, Observable) for arg in cached_args]):
|
if not all([isinstance(arg, Observable) for arg in cached_args]):
|
||||||
|
print cached_args
|
||||||
|
import ipdb;ipdb.set_trace()
|
||||||
return self.operation(*args)
|
return self.operation(*args)
|
||||||
|
|
||||||
if cached_args in self.cached_inputs:
|
if cached_args in self.cached_inputs:
|
||||||
|
|
@ -46,6 +47,7 @@ class Cacher(object):
|
||||||
self.inputs_changed = [any([a is arg for a in args]) or old_ic for args, old_ic in zip(self.cached_inputs, self.inputs_changed)]
|
self.inputs_changed = [any([a is arg for a in args]) or old_ic for args, old_ic in zip(self.cached_inputs, self.inputs_changed)]
|
||||||
|
|
||||||
def reset(self, obj):
|
def reset(self, obj):
|
||||||
|
[[a.remove_observer(self, self.on_cache_changed) for a in args] for args in self.cached_inputs]
|
||||||
[[a.remove_observer(self, self.reset) for a in args] for args in self.cached_inputs]
|
[[a.remove_observer(self, self.reset) for a in args] for args in self.cached_inputs]
|
||||||
self.cached_inputs = []
|
self.cached_inputs = []
|
||||||
self.cached_outputs = []
|
self.cached_outputs = []
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue