linear without caching, derivatives done

This commit is contained in:
Max Zwiessele 2014-02-21 09:14:31 +00:00
parent 1d722c4f28
commit 0c92fca31a
7 changed files with 71 additions and 54 deletions

View file

@ -1,5 +1,4 @@
from ..core.parameterization.parameter_core import Observable
from ..core.parameterization.array_core import ParamList
class Cacher(object):
def __init__(self, operation, limit=5, reset_on_first=False):
@ -13,7 +12,7 @@ class Cacher(object):
def __call__(self, *args):
if self._reset_on_first:
assert isinstance(args[0], Observable)
args[0].add_observer(args[0], self.reset)
args[0].add_observer(self.reset)
cached_args = args
else:
cached_args = args[1:]
@ -30,21 +29,21 @@ class Cacher(object):
else:
if len(self.cached_inputs) == self.limit:
args_ = self.cached_inputs.pop(0)
[a.remove_observer(self) for a in args_]
[a.remove_observer(self.on_cache_changed) for a in args_]
self.inputs_changed.pop(0)
self.cached_outputs.pop(0)
self.cached_inputs.append(cached_args)
self.cached_outputs.append(self.operation(*args))
self.inputs_changed.append(False)
[a.add_observer(self, self.on_cache_changed) for a in args]
[a.add_observer(self.on_cache_changed) for a in args]
return self.cached_outputs[-1]
def on_cache_changed(self, arg):
self.inputs_changed = [any([a is arg for a in args]) or old_ic for args, old_ic in zip(self.cached_inputs, self.inputs_changed)]
def reset(self, obj):
[[a.remove_observer(self) for a in args] for args in self.cached_inputs]
[[a.remove_observer(self.reset) for a in args] for args in self.cached_inputs]
self.cached_inputs = []
self.cached_outputs = []
self.inputs_changed = []