[caching] done right

This commit is contained in:
Max Zwiessele 2014-05-09 09:04:56 +01:00
parent 37e46b5da3
commit 64fb6ddc4c

View file

@ -1,84 +1,107 @@
from ..core.parameterization.parameter_core import Observable
import itertools
import itertools, collections, weakref
class Cacher(object):
    """
    Cache the return values of ``operation``, keyed on the identities of its
    arguments.

    Every cached input is observed (via its ``add_observer`` method), so that
    when an input changes the corresponding cache entries are marked stale and
    recomputed on the next call. Entries are evicted FIFO once ``limit``
    distinct argument combinations are held.
    """
    def __init__(self, operation, limit=5, ignore_args=(), force_kwargs=()):
        """
        Parameters:
        ***********
        :param callable operation: function to cache
        :param int limit: depth of cacher (max number of cached argument combinations)
        :param [int] ignore_args: list of indices, pointing at arguments to ignore in *args of operation(*args). This includes self!
        :param [str] force_kwargs: list of kwarg names (strings). If a kwarg with that name is given, the cacher will force recompute and wont cache anything.
        """
        self.limit = int(limit)
        self.ignore_args = ignore_args
        self.force_kwargs = force_kwargs
        self.operation = operation
        # FIFO of cache_ids, oldest first -- used for eviction at the limit.
        self.order = collections.deque()
        # cache_id -> tuple of the input objects which were used in that entry
        self.cached_inputs = {}
        #=======================================================================
        # id(input) -> [ref, cache_ids]
        # 0: a weak reference to the input object itself
        # 1: the cache_ids in which this input is used (len is how many entries share it)
        self.cached_input_ids = {}
        #=======================================================================
        self.cached_outputs = {}  # cache_id -> cached output
        self.inputs_changed = {}  # cache_id -> True when an input changed since caching

    def combine_args_kw(self, args, kw):
        "Combines the args and kw in a unique way, such that ordering of kwargs does not lead to recompute"
        return args + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0]))

    def preprocess(self, combined_args_kw, ignore_args):
        "get the cacheid (conc. string of argument ids in order) ignoring ignore_args"
        return "".join(str(id(a)) for i, a in enumerate(combined_args_kw) if i not in ignore_args)

    def ensure_cache_length(self, cache_id):
        "Ensures the cache is within its limits and has one place free"
        if len(self.order) == self.limit:
            # we have reached the limit, so lets release the oldest entry
            oldest = self.order.popleft()
            for a in self.cached_inputs[oldest]:
                # cached_input_ids is keyed on id(a), while cached_inputs
                # stores the objects themselves (was: indexed by the object,
                # which raised KeyError on every eviction)
                ind_id = id(a)
                ref, cache_ids = self.cached_input_ids[ind_id]
                if len(cache_ids) == 1:
                    # last entry using this input: stop observing and forget it
                    # (even if the weakref died, drop the stale record)
                    if ref() is not None:
                        ref().remove_observer(self, self.on_cache_changed)
                    del self.cached_input_ids[ind_id]
                else:
                    cache_ids.remove(oldest)
                    self.cached_input_ids[ind_id] = [ref, cache_ids]
            del self.cached_outputs[oldest]
            del self.inputs_changed[oldest]
            del self.cached_inputs[oldest]

    def add_to_cache(self, cache_id, combined_args_kw, output):
        "Store ``output`` under ``cache_id`` and start observing its inputs."
        self.inputs_changed[cache_id] = False
        self.cached_outputs[cache_id] = output
        self.order.append(cache_id)
        self.cached_inputs[cache_id] = combined_args_kw
        for a in combined_args_kw:
            ind_id = id(a)
            v = self.cached_input_ids.get(ind_id, [weakref.ref(a), []])
            v[1].append(cache_id)
            # only hook the observer up the first time we see this input
            if len(v[1]) == 1:
                a.add_observer(self, self.on_cache_changed)
            self.cached_input_ids[ind_id] = v

    def __call__(self, *args, **kw):
        """
        A wrapper function for self.operation: returns the cached result when
        the same argument identities were seen before and none of them changed,
        otherwise (re)computes and caches.
        """
        # 1: Check whether we have forced recompute arguments:
        if len(self.force_kwargs) != 0:
            # check if there are force args, which force reloading
            for k in self.force_kwargs:
                if k in kw and kw[k] is not None:
                    return self.operation(*args, **kw)
        # 2: preprocess and get the unique id string for this call
        combined_args_kw = self.combine_args_kw(args, kw)
        cache_id = self.preprocess(combined_args_kw, self.ignore_args)
        # 2b: if anything is not cachable (not Observable, incl. None), we just
        # return the operation without caching
        # (was: reduce(lambda a,b: ..., ...) -- any() is equivalent and does
        # not rely on the Python-2-only builtin reduce)
        if any(not isinstance(b, Observable) for b in combined_args_kw):
            return self.operation(*args, **kw)
        # 3&4: check whether this cache_id has been cached, then has it changed?
        try:
            if self.inputs_changed[cache_id]:
                # 4: one of the inputs changed since we last computed: update
                self.inputs_changed[cache_id] = False
                self.cached_outputs[cache_id] = self.operation(*args, **kw)
        except KeyError:
            # 3: this is when we never saw this cache_id: make room and cache
            self.ensure_cache_length(cache_id)
            self.add_to_cache(cache_id, combined_args_kw, self.operation(*args, **kw))
        except:
            # any other failure leaves the cache in an unknown state: wipe it
            self.reset()
            raise
        # 5: We have seen this cache_id and it is cached:
        return self.cached_outputs[cache_id]

    def on_cache_changed(self, direct, which=None):
        """
        this function gets 'hooked up' to the inputs when we cache them, and
        upon their elements being changed we mark every cache entry using the
        changed object (``direct`` or ``which``) as stale here.
        """
        for ind_id in [id(direct), id(which)]:
            _, cache_ids = self.cached_input_ids.get(ind_id, [None, []])
            for cache_id in cache_ids:
                self.inputs_changed[cache_id] = True

    def reset(self):
        """
        Totally reset the cache: detach from all observed inputs and drop
        every entry.
        """
        for ref, _ in self.cached_input_ids.values():
            # values are [weakref, cache_ids] pairs; the referent may already
            # be gone (was: called the pair list itself -> TypeError)
            if ref() is not None:
                ref().remove_observer(self, self.on_cache_changed)
        # clear ALL bookkeeping (was: order and cached_inputs kept stale ids,
        # corrupting later evictions)
        self.order = collections.deque()
        self.cached_inputs = {}
        self.cached_input_ids = {}
        self.cached_outputs = {}
        self.inputs_changed = {}

    def __deepcopy__(self, memo=None):
        # a deepcopy gets a fresh, empty cache for the same operation
        return Cacher(self.operation, self.limit, self.ignore_args, self.force_kwargs)