From c6b1f513d3569cc55eb204b8d4fdd6f9649bd741 Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Thu, 13 Mar 2014 12:29:14 +0000 Subject: [PATCH] caching now resets cache on error --- GPy/util/caching.py | 50 ++++++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/GPy/util/caching.py b/GPy/util/caching.py index 5de03059..ec8f9754 100644 --- a/GPy/util/caching.py +++ b/GPy/util/caching.py @@ -52,29 +52,33 @@ class Cacher(object): #if the result is cached, return the cached computation state = [all(a is b for a, b in itertools.izip_longest(args, cached_i)) for cached_i in self.cached_inputs] - if any(state): - i = state.index(True) - if self.inputs_changed[i]: - #(elements of) the args have changed since we last computed: update - self.cached_outputs[i] = self.operation(*args, **kw) - self.inputs_changed[i] = False - return self.cached_outputs[i] - else: - #first time we've seen these arguments: compute + try: + if any(state): + i = state.index(True) + if self.inputs_changed[i]: + #(elements of) the args have changed since we last computed: update + self.cached_outputs[i] = self.operation(*args, **kw) + self.inputs_changed[i] = False + return self.cached_outputs[i] + else: + #first time we've seen these arguments: compute - #first make sure the depth limit isn't exceeded - if len(self.cached_inputs) == self.limit: - args_ = self.cached_inputs.pop(0) - [a.remove_observer(self, self.on_cache_changed) for a in args_ if a is not None] - self.inputs_changed.pop(0) - self.cached_outputs.pop(0) - - #compute - self.cached_inputs.append(oa_all) - self.cached_outputs.append(self.operation(*args, **kw)) - self.inputs_changed.append(False) - [a.add_observer(self, self.on_cache_changed) for a in observable_args] - return self.cached_outputs[-1]#return + #first make sure the depth limit isn't exceeded + if len(self.cached_inputs) == self.limit: + args_ = self.cached_inputs.pop(0) + [a.remove_observer(self, self.on_cache_changed) for a in args_ if a is not None] + self.inputs_changed.pop(0) + self.cached_outputs.pop(0) + #compute + self.cached_inputs.append(oa_all) + self.cached_outputs.append(self.operation(*args, **kw)) + self.inputs_changed.append(False) + [a.add_observer(self, self.on_cache_changed) for a in observable_args] + return self.cached_outputs[-1]#return + except: + raise + finally: + self.reset() def on_cache_changed(self, arg): """ @@ -84,7 +88,7 @@ class Cacher(object): """ self.inputs_changed = [any([a is arg for a in args]) or old_ic for args, old_ic in zip(self.cached_inputs, self.inputs_changed)] - def reset(self, obj): + def reset(self): """ Totally reset the cache """