mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00
[caching] done right
This commit is contained in:
parent
37e46b5da3
commit
64fb6ddc4c
1 changed files with 87 additions and 62 deletions
|
|
@ -1,84 +1,107 @@
|
||||||
from ..core.parameterization.parameter_core import Observable
|
from ..core.parameterization.parameter_core import Observable
|
||||||
import itertools
|
import itertools, collections, weakref
|
||||||
|
|
||||||
class Cacher(object):
    """
    Cache the return values of ``operation`` for recently seen arguments.

    Cache entries are keyed by a string of the ``id()``s of the call
    arguments (the ``cache_id``).  All arguments must be ``Observable``; the
    cacher registers itself as an observer on each of them, so a change to
    any argument marks the matching entries stale and forces a recompute on
    the next call.  At most ``limit`` entries are kept, oldest evicted first.
    """

    def __init__(self, operation, limit=5, ignore_args=(), force_kwargs=()):
        """
        Parameters:
        ***********

        :param callable operation: function to cache
        :param int limit: depth of cacher
        :param [int] ignore_args: list of indices, pointing at arguments to ignore in *args of operation(*args). This includes self!
        :param [str] force_kwargs: list of kwarg names (strings). If a kwarg with that name is given, the cacher will force recompute and wont cache anything.
        """
        self.limit = int(limit)
        self.ignore_args = ignore_args
        self.force_kwargs = force_kwargs
        self.operation = operation
        # cache_ids in insertion order, oldest first; used for eviction
        self.order = collections.deque()
        # point from cache_ids to the combined args/kwargs tuple of that entry
        self.cached_inputs = {}
        #=======================================================================
        # point from each ind_id (= id(arg)) to [ref(obj), cache_ids]
        # 0: a weak reference to the object itself
        # 1: the cache_ids in which this ind_id is used (len will be how many
        #    times we have seen this ind_id)
        self.cached_input_ids = {}
        #=======================================================================
        self.cached_outputs = {}  # point from cache_ids to outputs
        self.inputs_changed = {}  # point from cache_ids to bools

    def combine_args_kw(self, args, kw):
        "Combines the args and kw in a unique way, such that ordering of kwargs does not lead to recompute"
        return args + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0]))

    def preprocess(self, combined_args_kw, ignore_args):
        "get the cacheid (conc. string of argument ids in order) ignoring ignore_args"
        return "".join(str(id(a)) for i, a in enumerate(combined_args_kw) if i not in ignore_args)

    def ensure_cache_length(self, cache_id=None):
        """
        Ensures the cache is within its limits and has one place free.

        :param cache_id: unused; kept (with a default) for backward
            compatibility with callers passing the new entry's id.
        """
        if len(self.order) == self.limit:
            # we have reached the limit, so lets release one element
            evict_id = self.order.popleft()
            combined_args_kw = self.cached_inputs[evict_id]
            for a in combined_args_kw:
                # BUGFIX: cached_input_ids is keyed by id(a) (see add_to_cache);
                # looking up the object itself raised KeyError on every eviction.
                ind_id = id(a)
                ref, cache_ids = self.cached_input_ids[ind_id]
                if len(cache_ids) == 1 and ref() is not None:
                    # last entry using this argument: stop observing it
                    ref().remove_observer(self, self.on_cache_changed)
                    del self.cached_input_ids[ind_id]
                else:
                    cache_ids.remove(evict_id)
                    self.cached_input_ids[ind_id] = [ref, cache_ids]
            del self.cached_outputs[evict_id]
            del self.inputs_changed[evict_id]
            del self.cached_inputs[evict_id]

    def add_to_cache(self, cache_id, combined_args_kw, output):
        """Store ``output`` for ``cache_id`` and start observing its inputs."""
        self.inputs_changed[cache_id] = False
        self.cached_outputs[cache_id] = output
        self.order.append(cache_id)
        self.cached_inputs[cache_id] = combined_args_kw
        for a in combined_args_kw:
            ind_id = id(a)
            v = self.cached_input_ids.get(ind_id, [weakref.ref(a), []])
            v[1].append(cache_id)
            if len(v[1]) == 1:
                # first time we see this argument: hook up the change callback
                a.add_observer(self, self.on_cache_changed)
            self.cached_input_ids[ind_id] = v

    def __call__(self, *args, **kw):
        """
        A wrapper function for self.operation: recompute only when the
        arguments are unseen or have changed since they were cached.
        """
        # 1: Check whether we have forced recompute arguments:
        if len(self.force_kwargs) != 0:
            for k in self.force_kwargs:
                if k in kw and kw[k] is not None:
                    # a force kwarg was given: bypass (and do not touch) the cache
                    return self.operation(*args, **kw)

        # 2: preprocess and get the unique id string for this call
        combined_args_kw = self.combine_args_kw(args, kw)
        cache_id = self.preprocess(combined_args_kw, self.ignore_args)

        # 2b: if anything is not cachable (not Observable), just run the
        # operation without caching.  BUGFIX: `any` replaces the former bare
        # `reduce`, which is not a builtin on Python 3 (same truth value).
        if any(not isinstance(b, Observable) for b in combined_args_kw):
            return self.operation(*args, **kw)

        # 3&4: check whether this cache_id has been cached, then has it changed?
        try:
            if self.inputs_changed[cache_id]:
                # 4: an input changed for this cache id: recompute and refresh
                self.inputs_changed[cache_id] = False
                self.cached_outputs[cache_id] = self.operation(*args, **kw)
        except KeyError:
            # 3: we have never seen this cache_id: make room, then cache it
            self.ensure_cache_length(cache_id)
            self.add_to_cache(cache_id, combined_args_kw, self.operation(*args, **kw))
        except:
            # anything else leaves the cache in an unknown state: clear it and
            # re-raise (bare except is deliberate so even KeyboardInterrupt
            # cannot leave stale entries behind)
            self.reset()
            raise
        # 5: We have seen this cache_id and it is cached:
        return self.cached_outputs[cache_id]

    def on_cache_changed(self, direct, which=None):
        """
        this function gets 'hooked up' to the inputs when we cache them, and upon their elements being changed we update here.
        """
        for ind_id in [id(direct), id(which)]:
            _, cache_ids = self.cached_input_ids.get(ind_id, [None, []])
            for cache_id in cache_ids:
                self.inputs_changed[cache_id] = True

    def reset(self):
        """
        Totally reset the cache: detach all observers and drop every entry.
        """
        for ref, _ in self.cached_input_ids.values():
            # BUGFIX: values are [weakref, cache_ids] pairs; calling the pair
            # itself (as before) raised TypeError on any non-empty cache.
            obj = ref()
            if obj is not None:
                obj.remove_observer(self, self.on_cache_changed)
        # BUGFIX: also clear order and cached_inputs; leaving them populated
        # made a later eviction look up already-deleted cache_ids.
        self.order = collections.deque()
        self.cached_inputs = {}
        self.cached_input_ids = {}
        self.cached_outputs = {}
        self.inputs_changed = {}

    def __deepcopy__(self, memo=None):
        # a deepcopy shares the operation but starts with a fresh, empty cache
        return Cacher(self.operation, self.limit, self.ignore_args, self.force_kwargs)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue