mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
[caching] renaming of helper methods to make intention clear
This commit is contained in:
parent
c291d5b2ad
commit
73f23c0a0d
3 changed files with 13 additions and 13 deletions
|
|
@ -316,7 +316,7 @@ class Param(OptimizationHandlable, ObsAr):
|
|||
class ParamConcatenation(object):
|
||||
def __init__(self, params):
|
||||
"""
|
||||
Parameter concatenation for convienience of printing regular expression matched arrays
|
||||
Parameter concatenation for convenience of printing regular expression matched arrays
|
||||
you can index this concatenation as if it was the flattened concatenation
|
||||
of all the parameters it contains, same for setting parameters (Broadcasting enabled).
|
||||
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class Kern(Parameterized):
|
|||
self._sliced_X = 0
|
||||
self.useGPU = self._support_GPU and useGPU
|
||||
|
||||
@Cache_this(limit=10)
|
||||
@Cache_this(limit=20)
|
||||
def _slice_X(self, X):
|
||||
return X[:, self.active_dims]
|
||||
|
||||
|
|
|
|||
|
|
@ -30,11 +30,11 @@ class Cacher(object):
|
|||
self.cached_outputs = {} # point from cache_ids to outputs
|
||||
self.inputs_changed = {} # point from cache_ids to bools
|
||||
|
||||
def combine_args_kw(self, args, kw):
|
||||
def combine_inputs(self, args, kw):
|
||||
"Combines the args and kw in a unique way, such that ordering of kwargs does not lead to recompute"
|
||||
return args + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0]))
|
||||
|
||||
def preprocess(self, combined_args_kw, ignore_args):
|
||||
def prepare_cache_id(self, combined_args_kw, ignore_args):
|
||||
"get the cacheid (conc. string of argument ids in order) ignoring ignore_args"
|
||||
return "".join(str(id(a)) for i,a in enumerate(combined_args_kw) if i not in ignore_args)
|
||||
|
||||
|
|
@ -81,23 +81,23 @@ class Cacher(object):
|
|||
if k in kw and kw[k] is not None:
|
||||
return self.operation(*args, **kw)
|
||||
|
||||
# 2: preprocess and get the unique id string for this call
|
||||
combined_args_kw = self.combine_args_kw(args, kw)
|
||||
cache_id = self.preprocess(combined_args_kw, self.ignore_args)
|
||||
# 2: prepare_cache_id and get the unique id string for this call
|
||||
inputs = self.combine_inputs(args, kw)
|
||||
cache_id = self.prepare_cache_id(inputs, self.ignore_args)
|
||||
|
||||
# 2: if anything is not cachable, we will just return the operation, without caching
|
||||
if reduce(lambda a,b: a or (not isinstance(b, Observable)), combined_args_kw, False):
|
||||
if reduce(lambda a,b: a or (not isinstance(b, Observable)), inputs, False):
|
||||
return self.operation(*args, **kw)
|
||||
# 3&4: check whether this cache_id has been cached, then has it changed?
|
||||
try:
|
||||
if(self.inputs_changed[cache_id]):
|
||||
# 4: This happens, when one element has changed for this cache id
|
||||
# 4: This happens, when elements have changed for this cache id
|
||||
self.inputs_changed[cache_id] = False
|
||||
self.cached_outputs[cache_id] = self.operation(*args, **kw)
|
||||
except KeyError:
|
||||
# 3: This is when we never saw this chache_id:
|
||||
self.ensure_cache_length(cache_id)
|
||||
self.add_to_cache(cache_id, combined_args_kw, self.operation(*args, **kw))
|
||||
self.add_to_cache(cache_id, inputs, self.operation(*args, **kw))
|
||||
except:
|
||||
self.reset()
|
||||
raise
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue