messing with caching

James Hensman 2014-02-27 10:47:27 +00:00
parent 2feb849bf7
commit 5cf792504a
3 changed files with 64 additions and 34 deletions

View file

@@ -40,18 +40,18 @@ class Stationary(Kern):
     def dK_dr(self, r):
         raise NotImplementedError, "implement the covaraiance function as a fn of r to use this class"

-    @Cache_this(limit=5, ignore_args=())
+    #@Cache_this(limit=5, ignore_args=())
     def K(self, X, X2=None):
         r = self._scaled_dist(X, X2)
         return self.K_of_r(r)

-    @Cache_this(limit=5, ignore_args=(0,))
+    #@Cache_this(limit=5, ignore_args=(0,))
     def _dist(self, X, X2):
         if X2 is None:
             X2 = X
         return X[:, None, :] - X2[None, :, :]

-    @Cache_this(limit=5, ignore_args=(0,))
+    #@Cache_this(limit=5, ignore_args=(0,))
     def _unscaled_dist(self, X, X2=None):
         """
         Compute the square distance between each row of X and X2, or between
@@ -65,7 +65,7 @@ class Stationary(Kern):
         X2sq = np.sum(np.square(X2),1)
         return np.sqrt(-2.*np.dot(X, X2.T) + (X1sq[:,None] + X2sq[None,:]))

-    @Cache_this(limit=5, ignore_args=())
+    #@Cache_this(limit=5, ignore_args=())
     def _scaled_dist(self, X, X2=None):
         """
         Efficiently compute the scaled distance, r.
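The `_unscaled_dist` context above relies on the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, so all pairwise distances come from one matrix product instead of a double loop. A minimal standalone sketch of that trick, checked against the naive loop (NumPy only; the function name is illustrative, and the clip guarding round-off negatives is an addition, not in the original):

import numpy as np

def unscaled_dist(X, X2=None):
    # vectorised pairwise distances, mirroring the expression in the hunk above
    if X2 is None:
        X2 = X
    X1sq = np.sum(np.square(X), 1)
    X2sq = np.sum(np.square(X2), 1)
    r2 = -2. * np.dot(X, X2.T) + (X1sq[:, None] + X2sq[None, :])
    # clip tiny negative values caused by floating-point round-off (assumption,
    # not in the original code, which takes the sqrt directly)
    return np.sqrt(np.clip(r2, 0., np.inf))

X, X2 = np.random.randn(5, 3), np.random.randn(4, 3)
naive = np.array([[np.sqrt(np.sum((x - y) ** 2)) for y in X2] for x in X])
assert np.allclose(unscaled_dist(X, X2), naive)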

View file

@@ -8,13 +8,6 @@ import sys
 verbose = True

-try:
-    import sympy
-    SYMPY_AVAILABLE=True
-except ImportError:
-    SYMPY_AVAILABLE=False

 class Kern_check_model(GPy.core.Model):
     """
     This is a dummy model class used as a base class for checking that the
@@ -70,14 +63,11 @@ class Kern_check_dKdiag_dtheta(Kern_check_model):
         Kern_check_model.__init__(self,kernel=kernel,dL_dK=dL_dK, X=X, X2=None)
         self.add_parameter(self.kernel)

-    def parameters_changed(self):
-        self.kernel.update_gradients_diag(self.dL_dK, self.X)
-
     def log_likelihood(self):
         return (np.diag(self.dL_dK)*self.kernel.Kdiag(self.X)).sum()

     def parameters_changed(self):
-        return self.kernel.update_gradients_diag(np.diag(self.dL_dK), self.X)
+        self.kernel.update_gradients_diag(np.diag(self.dL_dK), self.X)

 class Kern_check_dK_dX(Kern_check_model):
     """This class allows gradient checks for the gradient of a kernel with respect to X. """
@@ -99,6 +89,8 @@ class Kern_check_dKdiag_dX(Kern_check_dK_dX):
     def parameters_changed(self):
         self.X.gradient = self.kernel.gradients_X_diag(self.dL_dK, self.X)

 def kern_test(kern, X=None, X2=None, output_ind=None, verbose=False):
     """
     This function runs on kernels to check the correctness of their
@@ -217,11 +209,15 @@ def kern_test(kern, X=None, X2=None, output_ind=None, verbose=False):
     return pass_checks

 class KernelTestsContinuous(unittest.TestCase):
     def setUp(self):
         self.X = np.random.randn(100,2)
         self.X2 = np.random.randn(110,2)
+        continuous_kerns = ['RBF', 'Linear']
+        self.kernclasses = [getattr(GPy.kern, s) for s in continuous_kerns]

     def test_Matern32(self):
         k = GPy.kern.Matern32(2)
         self.assertTrue(kern_test(k, X=self.X, X2=self.X2, verbose=verbose))
@@ -234,6 +230,7 @@ class KernelTestsContinuous(unittest.TestCase):
 if __name__ == "__main__":
     print "Running unit tests, please be (very) patient..."
     unittest.main()

View file

@@ -1,6 +1,13 @@
 from ..core.parameterization.parameter_core import Observable

 class Cacher(object):
+    """
+    """
     def __init__(self, operation, limit=5, ignore_args=()):
         self.limit = int(limit)
         self.ignore_args = ignore_args
@@ -10,50 +17,75 @@ class Cacher(object):
         self.inputs_changed = []

     def __call__(self, *args):
+        """
+        A wrapper function for self.operation,
+        """
+        #ensure that specified arguments are ignored
         if len(self.ignore_args) != 0:
-            ca = [a for i,a in enumerate(args) if i not in self.ignore_args]
+            oa = [a for i,a in enumerate(args) if i not in self.ignore_args]
         else:
-            ca = args
+            oa = args

         # this makes sure we only add an observer once, and that None can be in args
-        cached_args = []
-        for a in ca:
-            if (not any(a is ai for ai in cached_args)) and a is not None:
-                cached_args.append(a)
-        if not all([isinstance(arg, Observable) for arg in cached_args]):
-            print cached_args
-            import ipdb;ipdb.set_trace()
+        observable_args = []
+        for a in oa:
+            if (not any(a is ai for ai in observable_args)) and a is not None:
+                observable_args.append(a)
+
+        #make sure that all the found argument really are observable:
+        #otherswise don't cache anything, pass args straight though
+        if not all([isinstance(arg, Observable) for arg in observable_args]):
             return self.operation(*args)

-        if cached_args in self.cached_inputs:
-            i = self.cached_inputs.index(cached_args)
+        #if the result is cached, return the cached computation
+        state = [all(a is b for a, b in zip(args, cached_i)) for cached_i in self.cached_inputs]
+        if any(state):
+            i = state.index(True)
             if self.inputs_changed[i]:
+                #(elements of) the args have changed since we last computed: update
                 self.cached_outputs[i] = self.operation(*args)
                 self.inputs_changed[i] = False
             return self.cached_outputs[i]
         else:
+            #first time we've seen these arguments: compute
+            #first make sure the depth limit isn't exceeded
             if len(self.cached_inputs) == self.limit:
                 args_ = self.cached_inputs.pop(0)
-                [a.remove_observer(self, self.on_cache_changed) for a in args_]
+                [a.remove_observer(self, self.on_cache_changed) for a in args_ if a is not None]
                 self.inputs_changed.pop(0)
                 self.cached_outputs.pop(0)
-            self.cached_inputs.append(cached_args)
+            #compute
+            self.cached_inputs.append(args)
             self.cached_outputs.append(self.operation(*args))
             self.inputs_changed.append(False)
-            [a.add_observer(self, self.on_cache_changed) for a in cached_args]
-            return self.cached_outputs[-1]
+            [a.add_observer(self, self.on_cache_changed) for a in observable_args]
+            return self.cached_outputs[-1]#Max says return.

     def on_cache_changed(self, arg):
+        """
+        A callback funtion, which sets local flags when the elements of some cached inputs change
+        this function gets 'hooked up' to the inputs when we cache them, and upon their elements being changed we update here.
+        """
         self.inputs_changed = [any([a is arg for a in args]) or old_ic for args, old_ic in zip(self.cached_inputs, self.inputs_changed)]

     def reset(self, obj):
-        [[a.remove_observer(self, self.on_cache_changed) for a in args] for args in self.cached_inputs]
-        [[a.remove_observer(self, self.reset) for a in args] for args in self.cached_inputs]
+        """
+        Totally reset the cache
+        """
+        [[a.remove_observer(self, self.on_cache_changed) for a in args if isinstance(a, Observable)] for args in self.cached_inputs]
+        [[a.remove_observer(self, self.reset) for a in args if isinstance(a, Observable)] for args in self.cached_inputs]
         self.cached_inputs = []
         self.cached_outputs = []
         self.inputs_changed = []

 class Cache_this(object):
+    """
+    A decorator which can be applied to bound methods in order to cache them
+    """
     def __init__(self, limit=5, ignore_args=()):
         self.limit = limit
         self.ignore_args = ignore_args
@@ -64,4 +96,5 @@ class Cache_this(object):
             self.c = Cacher(f, self.limit, ignore_args=self.ignore_args)
             return self.c(*args)
         f_wrap._cacher = self
+        f_wrap.__doc__ = "**cached**\n\n" + (f.__doc__ or "")
         return f_wrap
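The heart of this commit is the new cache lookup: instead of list membership (`in`, which calls __eq__ and misbehaves for array-valued arguments), entries are matched by object identity with `is`, and observers mark entries stale rather than evicting them, so the next call recomputes in place. A self-contained toy reproducing just those two mechanics; ToyObservable and ToyCacher are stand-ins for illustration, not GPy's Observable or the Cacher above:

class ToyObservable(object):
    """Stand-in for GPy's Observable: lets a cacher subscribe to changes."""
    def __init__(self):
        self._callbacks = []
    def add_observer(self, who, callback):
        self._callbacks.append(callback)
    def remove_observer(self, who, callback=None):
        self._callbacks = []
    def notify(self):
        for cb in self._callbacks:
            cb(self)

class ToyCacher(object):
    """Toy version of Cacher: identity-keyed entries plus staleness flags."""
    def __init__(self, operation):
        self.operation = operation
        self.cached_inputs, self.cached_outputs, self.inputs_changed = [], [], []
    def __call__(self, *args):
        # identity-based lookup, as in the `state = [...]` line of the commit
        hits = [all(a is b for a, b in zip(args, ci)) for ci in self.cached_inputs]
        if any(hits):
            i = hits.index(True)
            if self.inputs_changed[i]:
                # an input changed since we cached: recompute in place
                self.cached_outputs[i] = self.operation(*args)
                self.inputs_changed[i] = False
            return self.cached_outputs[i]
        self.cached_inputs.append(args)
        self.cached_outputs.append(self.operation(*args))
        self.inputs_changed.append(False)
        for a in args:
            a.add_observer(self, self.on_cache_changed)
        return self.cached_outputs[-1]
    def on_cache_changed(self, arg):
        # flag every entry that contains the changed object as stale
        self.inputs_changed = [any(a is arg for a in ci) or old
                               for ci, old in zip(self.cached_inputs, self.inputs_changed)]

calls = []
def expensive(x):
    calls.append(1)
    return sum(calls)

cache, x = ToyCacher(expensive), ToyObservable()
cache(x); cache(x)            # second call is a pure cache hit
assert len(calls) == 1
x.notify()                    # fires on_cache_changed -> marks the entry stale
cache(x)                      # recomputes exactly once
assert len(calls) == 2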