From 64fb6ddc4cd6c726119c4b7f72ef9b0ef85c3421 Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Fri, 9 May 2014 09:04:56 +0100 Subject: [PATCH] [caching] done right --- GPy/util/caching.py | 149 ++++++++++++++++++++++++++------------------ 1 file changed, 87 insertions(+), 62 deletions(-) diff --git a/GPy/util/caching.py b/GPy/util/caching.py index bb162ee3..533d287a 100644 --- a/GPy/util/caching.py +++ b/GPy/util/caching.py @@ -1,84 +1,107 @@ from ..core.parameterization.parameter_core import Observable -import itertools +import itertools, collections, weakref class Cacher(object): - """ - - - """ - def __init__(self, operation, limit=5, ignore_args=(), force_kwargs=()): + """ + Parameters: + *********** + :param callable operation: function to cache + :param int limit: depth of cacher + :param [int] ignore_args: list of indices, pointing at arguments to ignore in *args of operation(*args). This includes self! + :param [str] force_kwargs: list of kwarg names (strings). If a kwarg with that name is given, the cacher will force recompute and won't cache anything. 
+ """ self.limit = int(limit) self.ignore_args = ignore_args self.force_kwargs = force_kwargs self.operation=operation - self.cached_inputs = [] - self.cached_outputs = [] - self.inputs_changed = [] + self.order = collections.deque() + self.cached_inputs = {} # point from cache_ids to a list of [ind_ids], which were used in cache cache_id + + #======================================================================= + # point from each ind_id to [ref(obj), cache_ids] + # 0: a weak reference to the object itself + # 1: the cache_ids in which this ind_id is used (len will be how many times we have seen this ind_id) + self.cached_input_ids = {} + #======================================================================= + + self.cached_outputs = {} # point from cache_ids to outputs + self.inputs_changed = {} # point from cache_ids to bools + + def combine_args_kw(self, args, kw): + "Combines the args and kw in a unique way, such that ordering of kwargs does not lead to recompute" + return args + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0])) + + def preprocess(self, combined_args_kw, ignore_args): + "get the cacheid (conc. 
string of argument ids in order) ignoring ignore_args" + return "".join(str(id(a)) for i,a in enumerate(combined_args_kw) if i not in ignore_args) + + def ensure_cache_length(self, cache_id): + "Ensures the cache is within its limits and has one place free" + if len(self.order) == self.limit: + # we have reached the limit, so lets release one element + cache_id = self.order.popleft() + combined_args_kw = self.cached_inputs[cache_id] + for ind_id in combined_args_kw: + ref, cache_ids = self.cached_input_ids[ind_id] + if len(cache_ids) == 1 and ref() is not None: + ref().remove_observer(self, self.on_cache_changed) + del self.cached_input_ids[ind_id] + else: + cache_ids.remove(cache_id) + self.cached_input_ids[ind_id] = [ref, cache_ids] + del self.cached_outputs[cache_id] + del self.inputs_changed[cache_id] + del self.cached_inputs[cache_id] + + def add_to_cache(self, cache_id, combined_args_kw, output): + self.inputs_changed[cache_id] = False + self.cached_outputs[cache_id] = output + self.order.append(cache_id) + self.cached_inputs[cache_id] = combined_args_kw + for a in combined_args_kw: + ind_id = id(a) + v = self.cached_input_ids.get(ind_id, [weakref.ref(a), []]) + v[1].append(cache_id) + if len(v[1]) == 1: + a.add_observer(self, self.on_cache_changed) + self.cached_input_ids[ind_id] = v def __call__(self, *args, **kw): """ A wrapper function for self.operation, """ - #ensure that specified arguments are ignored - items = sorted(kw.items(), key=lambda x: x[0]) - oa_all = args + tuple(a for _,a in items) - if len(self.ignore_args) != 0: - oa = [a for i,a in itertools.chain(enumerate(args), items) if i not in self.ignore_args and i not in self.force_kwargs] - else: - oa = oa_all - - # this makes sure we only add an observer once, and that None can be in args - observable_args = [] - for a in oa: - if (not any(a is ai for ai in observable_args)) and a is not None: - observable_args.append(a) - - #make sure that all the found argument really are observable: - 
#otherswise don't cache anything, pass args straight though - if not all([isinstance(arg, Observable) for arg in observable_args]): - return self.operation(*args, **kw) - + # 1: Check whether we have forced recompute arguments: if len(self.force_kwargs) != 0: - # check if there are force args, which force reloading for k in self.force_kwargs: if k in kw and kw[k] is not None: return self.operation(*args, **kw) - # TODO: WARNING !!! Cache OFFSWITCH !!! WARNING - # return self.operation(*args, **kw) - #if the result is cached, return the cached computation - state = [all(a is b for a, b in itertools.izip_longest(args, cached_i)) for cached_i in self.cached_inputs] + # 2: preprocess and get the unique id string for this call + combined_args_kw = self.combine_args_kw(args, kw) + cache_id = self.preprocess(combined_args_kw, self.ignore_args) + + # 2: if anything is not cachable, we will just return the operation, without caching + if reduce(lambda a,b: a or (not isinstance(b, Observable)), combined_args_kw, False): + return self.operation(*args, **kw) + # 3&4: check whether this cache_id has been cached, then has it changed? 
try: - if any(state): - i = state.index(True) - if self.inputs_changed[i]: - #(elements of) the args have changed since we last computed: update - self.cached_outputs[i] = self.operation(*args, **kw) - self.inputs_changed[i] = False - return self.cached_outputs[i] - else: - #first time we've seen these arguments: compute - - #first make sure the depth limit isn't exceeded - if len(self.cached_inputs) == self.limit: - args_ = self.cached_inputs.pop(0) - args_ = [a for i,a in enumerate(args_) if i not in self.ignore_args and i not in self.force_kwargs] - [a.remove_observer(self, self.on_cache_changed) for a in args_ if a is not None] - self.inputs_changed.pop(0) - self.cached_outputs.pop(0) - #compute - self.cached_inputs.append(oa_all) - self.cached_outputs.append(self.operation(*args, **kw)) - self.inputs_changed.append(False) - [a.add_observer(self, self.on_cache_changed) for a in observable_args] - return self.cached_outputs[-1]#return + if(self.inputs_changed[cache_id]): + # 4: This happens, when one element has changed for this cache id + self.inputs_changed[cache_id] = False + self.cached_outputs[cache_id] = self.operation(*args, **kw) + except KeyError: + # 3: This is when we never saw this cache_id: + self.ensure_cache_length(cache_id) + self.add_to_cache(cache_id, combined_args_kw, self.operation(*args, **kw)) + except: + self.reset() + raise + # 5: We have seen this cache_id and it is cached: + return self.cached_outputs[cache_id] def on_cache_changed(self, direct, which=None): """ @@ -86,17 +109,19 @@ class Cacher(object): this function gets 'hooked up' to the inputs when we cache them, and upon their elements being changed we update here. 
""" - self.inputs_changed = [any([a is direct or a is which for a in args]) or old_ic for args, old_ic in zip(self.cached_inputs, self.inputs_changed)] + for ind_id in [id(direct), id(which)]: + _, cache_ids = self.cached_input_ids.get(ind_id, [None, []]) + for cache_id in cache_ids: + self.inputs_changed[cache_id] = True def reset(self): """ Totally reset the cache """ - [[a.remove_observer(self, self.on_cache_changed) for a in args if isinstance(a, Observable)] for args in self.cached_inputs] - [[a.remove_observer(self, self.reset) for a in args if isinstance(a, Observable)] for args in self.cached_inputs] - self.cached_inputs = [] - self.cached_outputs = [] - self.inputs_changed = [] + [a().remove_observer(self, self.on_cache_changed) if (a() is not None) else None for a in self.cached_input_ids.values()] + self.cached_input_ids = {} + self.cached_outputs = {} + self.inputs_changed = {} def __deepcopy__(self, memo=None): return Cacher(self.operation, self.limit, self.ignore_args, self.force_kwargs)