[caching] done right

This commit is contained in:
Max Zwiessele 2014-05-09 09:04:56 +01:00
parent 37e46b5da3
commit 64fb6ddc4c

View file

@ -1,84 +1,107 @@
from ..core.parameterization.parameter_core import Observable from ..core.parameterization.parameter_core import Observable
import itertools import itertools, collections, weakref
class Cacher(object): class Cacher(object):
"""
"""
def __init__(self, operation, limit=5, ignore_args=(), force_kwargs=()): def __init__(self, operation, limit=5, ignore_args=(), force_kwargs=()):
"""
Parameters:
***********
:param callable operation: function to cache
:param int limit: depth of cacher
:param [int] ignore_args: list of indices, pointing at arguments to ignore in *args of operation(*args). This includes self!
:param [str] force_kwargs: list of kwarg names (strings). If a kwarg with that name is given, the cacher will force recompute and wont cache anything.
"""
self.limit = int(limit) self.limit = int(limit)
self.ignore_args = ignore_args self.ignore_args = ignore_args
self.force_kwargs = force_kwargs self.force_kwargs = force_kwargs
self.operation=operation self.operation=operation
self.cached_inputs = [] self.order = collections.deque()
self.cached_outputs = [] self.cached_inputs = {} # point from cache_ids to a list of [ind_ids], which where used in cache cache_id
self.inputs_changed = []
#=======================================================================
# point from each ind_id to [ref(obj), cache_ids]
# 0: a weak reference to the object itself
# 1: the cache_ids in which this ind_id is used (len will be how many times we have seen this ind_id)
self.cached_input_ids = {}
#=======================================================================
self.cached_outputs = {} # point from cache_ids to outputs
self.inputs_changed = {} # point from cache_ids to bools
def combine_args_kw(self, args, kw):
"Combines the args and kw in a unique way, such that ordering of kwargs does not lead to recompute"
return args + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0]))
def preprocess(self, combined_args_kw, ignore_args):
"get the cacheid (conc. string of argument ids in order) ignoring ignore_args"
return "".join(str(id(a)) for i,a in enumerate(combined_args_kw) if i not in ignore_args)
def ensure_cache_length(self, cache_id):
"Ensures the cache is within its limits and has one place free"
if len(self.order) == self.limit:
# we have reached the limit, so lets release one element
cache_id = self.order.popleft()
combined_args_kw = self.cached_inputs[cache_id]
for ind_id in combined_args_kw:
ref, cache_ids = self.cached_input_ids[ind_id]
if len(cache_ids) == 1 and ref() is not None:
ref().remove_observer(self, self.on_cache_changed)
del self.cached_input_ids[ind_id]
else:
cache_ids.remove(cache_id)
self.cached_input_ids[ind_id] = [ref, cache_ids]
del self.cached_outputs[cache_id]
del self.inputs_changed[cache_id]
del self.cached_inputs[cache_id]
def add_to_cache(self, cache_id, combined_args_kw, output):
self.inputs_changed[cache_id] = False
self.cached_outputs[cache_id] = output
self.order.append(cache_id)
self.cached_inputs[cache_id] = combined_args_kw
for a in combined_args_kw:
ind_id = id(a)
v = self.cached_input_ids.get(ind_id, [weakref.ref(a), []])
v[1].append(cache_id)
if len(v[1]) == 1:
a.add_observer(self, self.on_cache_changed)
self.cached_input_ids[ind_id] = v
def __call__(self, *args, **kw): def __call__(self, *args, **kw):
""" """
A wrapper function for self.operation, A wrapper function for self.operation,
""" """
#ensure that specified arguments are ignored # 1: Check whether we have forced recompute arguments:
items = sorted(kw.items(), key=lambda x: x[0])
oa_all = args + tuple(a for _,a in items)
if len(self.ignore_args) != 0:
oa = [a for i,a in itertools.chain(enumerate(args), items) if i not in self.ignore_args and i not in self.force_kwargs]
else:
oa = oa_all
# this makes sure we only add an observer once, and that None can be in args
observable_args = []
for a in oa:
if (not any(a is ai for ai in observable_args)) and a is not None:
observable_args.append(a)
#make sure that all the found argument really are observable:
#otherswise don't cache anything, pass args straight though
if not all([isinstance(arg, Observable) for arg in observable_args]):
return self.operation(*args, **kw)
if len(self.force_kwargs) != 0: if len(self.force_kwargs) != 0:
# check if there are force args, which force reloading
for k in self.force_kwargs: for k in self.force_kwargs:
if k in kw and kw[k] is not None: if k in kw and kw[k] is not None:
return self.operation(*args, **kw) return self.operation(*args, **kw)
# TODO: WARNING !!! Cache OFFSWITCH !!! WARNING
# return self.operation(*args, **kw)
#if the result is cached, return the cached computation # 2: preprocess and get the unique id string for this call
state = [all(a is b for a, b in itertools.izip_longest(args, cached_i)) for cached_i in self.cached_inputs] combined_args_kw = self.combine_args_kw(args, kw)
cache_id = self.preprocess(combined_args_kw, self.ignore_args)
# 2: if anything is not cachable, we will just return the operation, without caching
if reduce(lambda a,b: a or (not isinstance(b, Observable)), combined_args_kw, False):
return self.operation(*args, **kw)
# 3&4: check whether this cache_id has been cached, then has it changed?
try: try:
if any(state): if(self.inputs_changed[cache_id]):
i = state.index(True) # 4: This happens, when one element has changed for this cache id
if self.inputs_changed[i]: self.inputs_changed[cache_id] = False
#(elements of) the args have changed since we last computed: update self.cached_outputs[cache_id] = self.operation(*args, **kw)
self.cached_outputs[i] = self.operation(*args, **kw) except KeyError:
self.inputs_changed[i] = False # 3: This is when we never saw this chache_id:
return self.cached_outputs[i] self.ensure_cache_length(cache_id)
else: self.add_to_cache(cache_id, combined_args_kw, self.operation(*args, **kw))
#first time we've seen these arguments: compute
#first make sure the depth limit isn't exceeded
if len(self.cached_inputs) == self.limit:
args_ = self.cached_inputs.pop(0)
args_ = [a for i,a in enumerate(args_) if i not in self.ignore_args and i not in self.force_kwargs]
[a.remove_observer(self, self.on_cache_changed) for a in args_ if a is not None]
self.inputs_changed.pop(0)
self.cached_outputs.pop(0)
#compute
self.cached_inputs.append(oa_all)
self.cached_outputs.append(self.operation(*args, **kw))
self.inputs_changed.append(False)
[a.add_observer(self, self.on_cache_changed) for a in observable_args]
return self.cached_outputs[-1]#return
except: except:
self.reset() self.reset()
raise raise
# 5: We have seen this cache_id and it is cached:
return self.cached_outputs[cache_id]
def on_cache_changed(self, direct, which=None): def on_cache_changed(self, direct, which=None):
""" """
@ -86,17 +109,19 @@ class Cacher(object):
this function gets 'hooked up' to the inputs when we cache them, and upon their elements being changed we update here. this function gets 'hooked up' to the inputs when we cache them, and upon their elements being changed we update here.
""" """
self.inputs_changed = [any([a is direct or a is which for a in args]) or old_ic for args, old_ic in zip(self.cached_inputs, self.inputs_changed)] for ind_id in [id(direct), id(which)]:
_, cache_ids = self.cached_input_ids.get(ind_id, [None, []])
for cache_id in cache_ids:
self.inputs_changed[cache_id] = True
def reset(self): def reset(self):
""" """
Totally reset the cache Totally reset the cache
""" """
[[a.remove_observer(self, self.on_cache_changed) for a in args if isinstance(a, Observable)] for args in self.cached_inputs] [a().remove_observer(self, self.on_cache_changed) if (a() is not None) else None for a in self.cached_input_ids.values()]
[[a.remove_observer(self, self.reset) for a in args if isinstance(a, Observable)] for args in self.cached_inputs] self.cached_input_ids = {}
self.cached_inputs = [] self.cached_outputs = {}
self.cached_outputs = [] self.inputs_changed = {}
self.inputs_changed = []
def __deepcopy__(self, memo=None): def __deepcopy__(self, memo=None):
return Cacher(self.operation, self.limit, self.ignore_args, self.force_kwargs) return Cacher(self.operation, self.limit, self.ignore_args, self.force_kwargs)