mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 03:52:39 +02:00
gradient operations and cachong
This commit is contained in:
parent
0c92fca31a
commit
b19f9b9f33
8 changed files with 151 additions and 138 deletions
|
|
@ -12,7 +12,7 @@ class Cacher(object):
|
|||
def __call__(self, *args):
|
||||
if self._reset_on_first:
|
||||
assert isinstance(args[0], Observable)
|
||||
args[0].add_observer(self.reset)
|
||||
args[0].add_observer(self, self.reset)
|
||||
cached_args = args
|
||||
else:
|
||||
cached_args = args[1:]
|
||||
|
|
@ -29,21 +29,21 @@ class Cacher(object):
|
|||
else:
|
||||
if len(self.cached_inputs) == self.limit:
|
||||
args_ = self.cached_inputs.pop(0)
|
||||
[a.remove_observer(self.on_cache_changed) for a in args_]
|
||||
[a.remove_observer(self, self.on_cache_changed) for a in args_]
|
||||
self.inputs_changed.pop(0)
|
||||
self.cached_outputs.pop(0)
|
||||
|
||||
self.cached_inputs.append(cached_args)
|
||||
self.cached_outputs.append(self.operation(*args))
|
||||
self.inputs_changed.append(False)
|
||||
[a.add_observer(self.on_cache_changed) for a in args]
|
||||
[a.add_observer(self, self.on_cache_changed) for a in args]
|
||||
return self.cached_outputs[-1]
|
||||
|
||||
def on_cache_changed(self, arg):
|
||||
self.inputs_changed = [any([a is arg for a in args]) or old_ic for args, old_ic in zip(self.cached_inputs, self.inputs_changed)]
|
||||
|
||||
def reset(self, obj):
|
||||
[[a.remove_observer(self.reset) for a in args] for args in self.cached_inputs]
|
||||
[[a.remove_observer(self, self.reset) for a in args] for args in self.cached_inputs]
|
||||
self.cached_inputs = []
|
||||
self.cached_outputs = []
|
||||
self.inputs_changed = []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue