allow cache supporting boolean and integers

This commit is contained in:
Zhenwen Dai 2015-09-04 17:27:22 +01:00
parent d7655e4407
commit 9145a87ee8
2 changed files with 5 additions and 4 deletions

View file

@ -26,6 +26,7 @@ class KernCallsViaSlicerMeta(ParametersChangedMeta):
put_clean(dct, 'psi0', _slice_psi) put_clean(dct, 'psi0', _slice_psi)
put_clean(dct, 'psi1', _slice_psi) put_clean(dct, 'psi1', _slice_psi)
put_clean(dct, 'psi2', _slice_psi) put_clean(dct, 'psi2', _slice_psi)
put_clean(dct, 'psi2n', _slice_psi)
put_clean(dct, 'update_gradients_expectations', _slice_update_gradients_expectations) put_clean(dct, 'update_gradients_expectations', _slice_update_gradients_expectations)
put_clean(dct, 'gradients_Z_expectations', _slice_gradients_Z_expectations) put_clean(dct, 'gradients_Z_expectations', _slice_gradients_Z_expectations)
put_clean(dct, 'gradients_qX_expectations', _slice_gradients_qX_expectations) put_clean(dct, 'gradients_qX_expectations', _slice_gradients_qX_expectations)

View file

@ -76,7 +76,7 @@ class Cacher(object):
self.order.append(cache_id) self.order.append(cache_id)
self.cached_inputs[cache_id] = inputs self.cached_inputs[cache_id] = inputs
for a in inputs: for a in inputs:
if a is not None: if a is not None and not isinstance(a, int):
ind_id = self.id(a) ind_id = self.id(a)
v = self.cached_input_ids.get(ind_id, [weakref.ref(a), []]) v = self.cached_input_ids.get(ind_id, [weakref.ref(a), []])
v[1].append(cache_id) v[1].append(cache_id)
@ -103,9 +103,9 @@ class Cacher(object):
inputs = self.combine_inputs(args, kw, self.ignore_args) inputs = self.combine_inputs(args, kw, self.ignore_args)
cache_id = self.prepare_cache_id(inputs) cache_id = self.prepare_cache_id(inputs)
# 2: if anything is not cachable, we will just return the operation, without caching # 2: if anything is not cachable, we will just return the operation, without caching
if reduce(lambda a, b: a or (not (isinstance(b, Observable) or b is None)), inputs, False): if reduce(lambda a, b: a or (not (isinstance(b, Observable) or b is None or isinstance(b,int))), inputs, False):
#print 'WARNING: '+self.operation.__name__ + ' not cacheable!' # print 'WARNING: '+self.operation.__name__ + ' not cacheable!'
#print [not (isinstance(b, Observable)) for b in inputs] # print [not (isinstance(b, Observable)) for b in inputs]
return self.operation(*args, **kw) return self.operation(*args, **kw)
# 3&4: check whether this cache_id has been cached, then has it changed? # 3&4: check whether this cache_id has been cached, then has it changed?
try: try: