From 5cf792504a653b5ae6cc194b66f0eb7e04a3b273 Mon Sep 17 00:00:00 2001
From: James Hensman
Date: Thu, 27 Feb 2014 10:47:27 +0000
Subject: [PATCH] messing with caching

---
 GPy/kern/_src/stationary.py |  8 ++---
 GPy/testing/kernel_tests.py | 19 +++++-----
 GPy/util/caching.py         | 71 +++++++++++++++++++++++++++----------
 3 files changed, 64 insertions(+), 34 deletions(-)

diff --git a/GPy/kern/_src/stationary.py b/GPy/kern/_src/stationary.py
index ebc6b880..8d8ae476 100644
--- a/GPy/kern/_src/stationary.py
+++ b/GPy/kern/_src/stationary.py
@@ -40,18 +40,18 @@ class Stationary(Kern):
     def dK_dr(self, r):
         raise NotImplementedError, "implement the covariance function as a fn of r to use this class"
 
-    @Cache_this(limit=5, ignore_args=())
+    #@Cache_this(limit=5, ignore_args=())
     def K(self, X, X2=None):
         r = self._scaled_dist(X, X2)
         return self.K_of_r(r)
 
-    @Cache_this(limit=5, ignore_args=(0,))
+    #@Cache_this(limit=5, ignore_args=(0,))
    def _dist(self, X, X2):
         if X2 is None:
             X2 = X
         return X[:, None, :] - X2[None, :, :]
 
-    @Cache_this(limit=5, ignore_args=(0,))
+    #@Cache_this(limit=5, ignore_args=(0,))
     def _unscaled_dist(self, X, X2=None):
         """
         Compute the square distance between each row of X and X2, or between
@@ -65,7 +65,7 @@ class Stationary(Kern):
             X2sq = np.sum(np.square(X2),1)
             return np.sqrt(-2.*np.dot(X, X2.T) + (X1sq[:,None] + X2sq[None,:]))
 
-    @Cache_this(limit=5, ignore_args=())
+    #@Cache_this(limit=5, ignore_args=())
     def _scaled_dist(self, X, X2=None):
         """
         Efficiently compute the scaled distance, r.
diff --git a/GPy/testing/kernel_tests.py b/GPy/testing/kernel_tests.py
index e5985145..d373a546 100644
--- a/GPy/testing/kernel_tests.py
+++ b/GPy/testing/kernel_tests.py
@@ -8,13 +8,6 @@ import sys
 
 verbose = True
 
-try:
-    import sympy
-    SYMPY_AVAILABLE=True
-except ImportError:
-    SYMPY_AVAILABLE=False
-
-
 class Kern_check_model(GPy.core.Model):
     """
     This is a dummy model class used as a base class for checking that the
@@ -70,14 +63,11 @@ class Kern_check_dKdiag_dtheta(Kern_check_model):
         Kern_check_model.__init__(self,kernel=kernel,dL_dK=dL_dK, X=X, X2=None)
         self.add_parameter(self.kernel)
 
-    def parameters_changed(self):
-        self.kernel.update_gradients_diag(self.dL_dK, self.X)
-
     def log_likelihood(self):
         return (np.diag(self.dL_dK)*self.kernel.Kdiag(self.X)).sum()
 
     def parameters_changed(self):
-        return self.kernel.update_gradients_diag(np.diag(self.dL_dK), self.X)
+        self.kernel.update_gradients_diag(np.diag(self.dL_dK), self.X)
 
 class Kern_check_dK_dX(Kern_check_model):
     """This class allows gradient checks for the gradient of a kernel with respect to X. """
@@ -99,6 +89,8 @@ class Kern_check_dKdiag_dX(Kern_check_dK_dX):
     def parameters_changed(self):
         self.X.gradient = self.kernel.gradients_X_diag(self.dL_dK, self.X)
 
+
+
 def kern_test(kern, X=None, X2=None, output_ind=None, verbose=False):
     """
     This function runs on kernels to check the correctness of their
@@ -217,11 +209,15 @@ def kern_test(kern, X=None, X2=None, output_ind=None, verbose=False):
     return pass_checks
 
+
 class KernelTestsContinuous(unittest.TestCase):
     def setUp(self):
         self.X = np.random.randn(100,2)
         self.X2 = np.random.randn(110,2)
+        continuous_kerns = ['RBF', 'Linear']
+        self.kernclasses = [getattr(GPy.kern, s) for s in continuous_kerns]
+
     def test_Matern32(self):
         k = GPy.kern.Matern32(2)
         self.assertTrue(kern_test(k, X=self.X, X2=self.X2, verbose=verbose))
 
 
@@ -234,6 +230,7 @@
 
 
 
+
 if __name__ == "__main__":
     print "Running unit tests, please be (very) patient..."
     unittest.main()
diff --git a/GPy/util/caching.py b/GPy/util/caching.py
index 8e60cf26..2899cb33 100644
--- a/GPy/util/caching.py
+++ b/GPy/util/caching.py
@@ -1,6 +1,13 @@
 from ..core.parameterization.parameter_core import Observable
 
 class Cacher(object):
+    """
+    Cache the output of an operation: when the operation is called again
+    with the same (unchanged) arguments, return the stored result instead
+    of recomputing it.  The arguments must be Observable, so that the
+    cacher can watch them and mark the matching cache entry stale on change.
+    """
+
     def __init__(self, operation, limit=5, ignore_args=()):
         self.limit = int(limit)
         self.ignore_args = ignore_args
@@ -10,50 +17,75 @@ class Cacher(object):
         self.inputs_changed = []
 
     def __call__(self, *args):
+        """
+        A wrapper function for self.operation: return a cached result where possible, recompute (and re-cache) where needed.
+        """
+
+        #ensure that specified arguments are ignored
         if len(self.ignore_args) != 0:
-            ca = [a for i,a in enumerate(args) if i not in self.ignore_args]
+            oa = [a for i,a in enumerate(args) if i not in self.ignore_args]
         else:
-            ca = args
+            oa = args
+
         # this makes sure we only add an observer once, and that None can be in args
-        cached_args = []
-        for a in ca:
-            if (not any(a is ai for ai in cached_args)) and a is not None:
-                cached_args.append(a)
-        if not all([isinstance(arg, Observable) for arg in cached_args]):
-            print cached_args
-            import ipdb;ipdb.set_trace()
+        observable_args = []
+        for a in oa:
+            if (not any(a is ai for ai in observable_args)) and a is not None:
+                observable_args.append(a)
+
+        #make sure that all the found arguments really are observable:
+        #otherwise don't cache anything, pass args straight through
+        if not all([isinstance(arg, Observable) for arg in observable_args]):
             return self.operation(*args)
-
-        if cached_args in self.cached_inputs:
-            i = self.cached_inputs.index(cached_args)
+
+        #if the result is cached, return the cached computation
+        state = [len(args) == len(cached_i) and all(a is b for a, b in zip(args, cached_i)) for cached_i in self.cached_inputs]
+        if any(state):
+            i = state.index(True)
             if self.inputs_changed[i]:
+                #(elements of) the args have changed since we last computed: update
                 self.cached_outputs[i] = self.operation(*args)
                 self.inputs_changed[i] = False
             return self.cached_outputs[i]
         else:
+            #first time we've seen these arguments: compute
+
+            #first make sure the depth limit isn't exceeded
             if len(self.cached_inputs) == self.limit:
                 args_ = self.cached_inputs.pop(0)
-                [a.remove_observer(self, self.on_cache_changed) for a in args_]
+                [a.remove_observer(self, self.on_cache_changed) for a in args_ if isinstance(a, Observable)]
                 self.inputs_changed.pop(0)
                 self.cached_outputs.pop(0)
-            self.cached_inputs.append(cached_args)
+            #compute
+            self.cached_inputs.append(args)
             self.cached_outputs.append(self.operation(*args))
             self.inputs_changed.append(False)
-            [a.add_observer(self, self.on_cache_changed) for a in cached_args]
-            return self.cached_outputs[-1]
+            [a.add_observer(self, self.on_cache_changed) for a in observable_args]
+            return self.cached_outputs[-1]
 
     def on_cache_changed(self, arg):
+        """
+        A callback function which sets local flags when the elements of some cached inputs change.
+
+        This function gets 'hooked up' to the inputs when we cache them; when their elements change, the flags are updated here.
+ """ self.inputs_changed = [any([a is arg for a in args]) or old_ic for args, old_ic in zip(self.cached_inputs, self.inputs_changed)] def reset(self, obj): - [[a.remove_observer(self, self.on_cache_changed) for a in args] for args in self.cached_inputs] - [[a.remove_observer(self, self.reset) for a in args] for args in self.cached_inputs] + """ + Totally reset the cache + """ + [[a.remove_observer(self, self.on_cache_changed) for a in args if isinstance(a, Observable)] for args in self.cached_inputs] + [[a.remove_observer(self, self.reset) for a in args if isinstance(a, Observable)] for args in self.cached_inputs] self.cached_inputs = [] self.cached_outputs = [] self.inputs_changed = [] class Cache_this(object): + """ + A decorator which can be applied to bound methods in order to cache them + """ def __init__(self, limit=5, ignore_args=()): self.limit = limit self.ignore_args = ignore_args @@ -64,4 +96,5 @@ class Cache_this(object): self.c = Cacher(f, self.limit, ignore_args=self.ignore_args) return self.c(*args) f_wrap._cacher = self - return f_wrap \ No newline at end of file + f_wrap.__doc__ = "**cached**\n\n" + (f.__doc__ or "") + return f_wrap