mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-11 15:15:15 +02:00
fix the bug of caching w.r.t. ignore arguments
This commit is contained in:
parent
1061bf5248
commit
2fb86f9b51
1 changed files with 11 additions and 7 deletions
|
|
@ -33,13 +33,15 @@ class Cacher(object):
|
|||
"""returns the self.id of an object, to be used in caching individual self.ids"""
|
||||
return hex(id(obj))
|
||||
|
||||
def combine_inputs(self, args, kw):
|
||||
def combine_inputs(self, args, kw, ignore_args):
|
||||
"Combines the args and kw in a unique way, such that ordering of kwargs does not lead to recompute"
|
||||
return args + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0]))
|
||||
inputs= args + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0]))
|
||||
# REMOVE the ignored arguments from input and PREVENT it from being checked!!!
|
||||
return [a for i,a in enumerate(inputs) if i not in ignore_args]
|
||||
|
||||
def prepare_cache_id(self, combined_args_kw, ignore_args):
|
||||
"get the cacheid (conc. string of argument self.ids in order) ignoring ignore_args"
|
||||
cache_id = "".join(self.id(a) for i, a in enumerate(combined_args_kw) if i not in ignore_args)
|
||||
def prepare_cache_id(self, combined_args_kw):
|
||||
"get the cacheid (conc. string of argument self.ids in order)"
|
||||
cache_id = "".join(self.id(a) for a in combined_args_kw)
|
||||
return cache_id
|
||||
|
||||
def ensure_cache_length(self, cache_id):
|
||||
|
|
@ -95,10 +97,12 @@ class Cacher(object):
|
|||
return self.operation(*args, **kw)
|
||||
|
||||
# 2: prepare_cache_id and get the unique self.id string for this call
|
||||
inputs = self.combine_inputs(args, kw)
|
||||
cache_id = self.prepare_cache_id(inputs, self.ignore_args)
|
||||
inputs = self.combine_inputs(args, kw, self.ignore_args)
|
||||
cache_id = self.prepare_cache_id(inputs)
|
||||
# 2: if anything is not cachable, we will just return the operation, without caching
|
||||
if reduce(lambda a, b: a or (not (isinstance(b, Observable) or b is None)), inputs, False):
|
||||
#print 'WARNING: '+self.operation.__name__ + ' not cacheable!'
|
||||
#print [not (isinstance(b, Observable)) for b in inputs]
|
||||
return self.operation(*args, **kw)
|
||||
# 3&4: check whether this cache_id has been cached, then has it changed?
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue