mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
[caching] renaming of helper methods to make intention clear
This commit is contained in:
parent
c291d5b2ad
commit
73f23c0a0d
3 changed files with 13 additions and 13 deletions
|
|
@ -316,7 +316,7 @@ class Param(OptimizationHandlable, ObsAr):
|
||||||
class ParamConcatenation(object):
|
class ParamConcatenation(object):
|
||||||
def __init__(self, params):
|
def __init__(self, params):
|
||||||
"""
|
"""
|
||||||
Parameter concatenation for convienience of printing regular expression matched arrays
|
Parameter concatenation for convenience of printing regular expression matched arrays
|
||||||
you can index this concatenation as if it was the flattened concatenation
|
you can index this concatenation as if it was the flattened concatenation
|
||||||
of all the parameters it contains, same for setting parameters (Broadcasting enabled).
|
of all the parameters it contains, same for setting parameters (Broadcasting enabled).
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ class Kern(Parameterized):
|
||||||
self._sliced_X = 0
|
self._sliced_X = 0
|
||||||
self.useGPU = self._support_GPU and useGPU
|
self.useGPU = self._support_GPU and useGPU
|
||||||
|
|
||||||
@Cache_this(limit=10)
|
@Cache_this(limit=20)
|
||||||
def _slice_X(self, X):
|
def _slice_X(self, X):
|
||||||
return X[:, self.active_dims]
|
return X[:, self.active_dims]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,11 +30,11 @@ class Cacher(object):
|
||||||
self.cached_outputs = {} # point from cache_ids to outputs
|
self.cached_outputs = {} # point from cache_ids to outputs
|
||||||
self.inputs_changed = {} # point from cache_ids to bools
|
self.inputs_changed = {} # point from cache_ids to bools
|
||||||
|
|
||||||
def combine_args_kw(self, args, kw):
|
def combine_inputs(self, args, kw):
|
||||||
"Combines the args and kw in a unique way, such that ordering of kwargs does not lead to recompute"
|
"Combines the args and kw in a unique way, such that ordering of kwargs does not lead to recompute"
|
||||||
return args + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0]))
|
return args + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0]))
|
||||||
|
|
||||||
def preprocess(self, combined_args_kw, ignore_args):
|
def prepare_cache_id(self, combined_args_kw, ignore_args):
|
||||||
"get the cacheid (conc. string of argument ids in order) ignoring ignore_args"
|
"get the cacheid (conc. string of argument ids in order) ignoring ignore_args"
|
||||||
return "".join(str(id(a)) for i,a in enumerate(combined_args_kw) if i not in ignore_args)
|
return "".join(str(id(a)) for i,a in enumerate(combined_args_kw) if i not in ignore_args)
|
||||||
|
|
||||||
|
|
@ -81,23 +81,23 @@ class Cacher(object):
|
||||||
if k in kw and kw[k] is not None:
|
if k in kw and kw[k] is not None:
|
||||||
return self.operation(*args, **kw)
|
return self.operation(*args, **kw)
|
||||||
|
|
||||||
# 2: preprocess and get the unique id string for this call
|
# 2: prepare_cache_id and get the unique id string for this call
|
||||||
combined_args_kw = self.combine_args_kw(args, kw)
|
inputs = self.combine_inputs(args, kw)
|
||||||
cache_id = self.preprocess(combined_args_kw, self.ignore_args)
|
cache_id = self.prepare_cache_id(inputs, self.ignore_args)
|
||||||
|
|
||||||
# 2: if anything is not cachable, we will just return the operation, without caching
|
# 2: if anything is not cachable, we will just return the operation, without caching
|
||||||
if reduce(lambda a,b: a or (not isinstance(b, Observable)), combined_args_kw, False):
|
if reduce(lambda a,b: a or (not isinstance(b, Observable)), inputs, False):
|
||||||
return self.operation(*args, **kw)
|
return self.operation(*args, **kw)
|
||||||
# 3&4: check whether this cache_id has been cached, then has it changed?
|
# 3&4: check whether this cache_id has been cached, then has it changed?
|
||||||
try:
|
try:
|
||||||
if(self.inputs_changed[cache_id]):
|
if(self.inputs_changed[cache_id]):
|
||||||
# 4: This happens, when one element has changed for this cache id
|
# 4: This happens, when elements have changed for this cache id
|
||||||
self.inputs_changed[cache_id] = False
|
self.inputs_changed[cache_id] = False
|
||||||
self.cached_outputs[cache_id] = self.operation(*args, **kw)
|
self.cached_outputs[cache_id] = self.operation(*args, **kw)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
# 3: This is when we never saw this chache_id:
|
# 3: This is when we never saw this chache_id:
|
||||||
self.ensure_cache_length(cache_id)
|
self.ensure_cache_length(cache_id)
|
||||||
self.add_to_cache(cache_id, combined_args_kw, self.operation(*args, **kw))
|
self.add_to_cache(cache_id, inputs, self.operation(*args, **kw))
|
||||||
except:
|
except:
|
||||||
self.reset()
|
self.reset()
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue