messing with caching

This commit is contained in:
James Hensman 2014-02-27 10:47:27 +00:00
parent 2feb849bf7
commit 5cf792504a
3 changed files with 64 additions and 34 deletions

View file

@ -40,18 +40,18 @@ class Stationary(Kern):
def dK_dr(self, r):
raise NotImplementedError, "implement the covaraiance function as a fn of r to use this class"
@Cache_this(limit=5, ignore_args=())
#@Cache_this(limit=5, ignore_args=())
def K(self, X, X2=None):
r = self._scaled_dist(X, X2)
return self.K_of_r(r)
@Cache_this(limit=5, ignore_args=(0,))
#@Cache_this(limit=5, ignore_args=(0,))
def _dist(self, X, X2):
if X2 is None:
X2 = X
return X[:, None, :] - X2[None, :, :]
@Cache_this(limit=5, ignore_args=(0,))
#@Cache_this(limit=5, ignore_args=(0,))
def _unscaled_dist(self, X, X2=None):
"""
Compute the square distance between each row of X and X2, or between
@ -65,7 +65,7 @@ class Stationary(Kern):
X2sq = np.sum(np.square(X2),1)
return np.sqrt(-2.*np.dot(X, X2.T) + (X1sq[:,None] + X2sq[None,:]))
@Cache_this(limit=5, ignore_args=())
#@Cache_this(limit=5, ignore_args=())
def _scaled_dist(self, X, X2=None):
"""
Efficiently compute the scaled distance, r.

View file

@ -8,13 +8,6 @@ import sys
verbose = True
try:
import sympy
SYMPY_AVAILABLE=True
except ImportError:
SYMPY_AVAILABLE=False
class Kern_check_model(GPy.core.Model):
"""
This is a dummy model class used as a base class for checking that the
@ -70,14 +63,11 @@ class Kern_check_dKdiag_dtheta(Kern_check_model):
Kern_check_model.__init__(self,kernel=kernel,dL_dK=dL_dK, X=X, X2=None)
self.add_parameter(self.kernel)
def parameters_changed(self):
self.kernel.update_gradients_diag(self.dL_dK, self.X)
def log_likelihood(self):
return (np.diag(self.dL_dK)*self.kernel.Kdiag(self.X)).sum()
def parameters_changed(self):
return self.kernel.update_gradients_diag(np.diag(self.dL_dK), self.X)
self.kernel.update_gradients_diag(np.diag(self.dL_dK), self.X)
class Kern_check_dK_dX(Kern_check_model):
"""This class allows gradient checks for the gradient of a kernel with respect to X. """
@ -99,6 +89,8 @@ class Kern_check_dKdiag_dX(Kern_check_dK_dX):
def parameters_changed(self):
self.X.gradient = self.kernel.gradients_X_diag(self.dL_dK, self.X)
def kern_test(kern, X=None, X2=None, output_ind=None, verbose=False):
"""
This function runs on kernels to check the correctness of their
@ -217,11 +209,15 @@ def kern_test(kern, X=None, X2=None, output_ind=None, verbose=False):
return pass_checks
class KernelTestsContinuous(unittest.TestCase):
def setUp(self):
self.X = np.random.randn(100,2)
self.X2 = np.random.randn(110,2)
continuous_kerns = ['RBF', 'Linear']
self.kernclasses = [getattr(GPy.kern, s) for s in continuous_kerns]
def test_Matern32(self):
k = GPy.kern.Matern32(2)
self.assertTrue(kern_test(k, X=self.X, X2=self.X2, verbose=verbose))
@ -234,6 +230,7 @@ class KernelTestsContinuous(unittest.TestCase):
if __name__ == "__main__":
print "Running unit tests, please be (very) patient..."
unittest.main()

View file

@ -1,6 +1,13 @@
from ..core.parameterization.parameter_core import Observable
class Cacher(object):
"""
"""
def __init__(self, operation, limit=5, ignore_args=()):
self.limit = int(limit)
self.ignore_args = ignore_args
@ -10,50 +17,75 @@ class Cacher(object):
self.inputs_changed = []
def __call__(self, *args):
"""
A wrapper function for self.operation,
"""
#ensure that specified arguments are ignored
if len(self.ignore_args) != 0:
ca = [a for i,a in enumerate(args) if i not in self.ignore_args]
oa = [a for i,a in enumerate(args) if i not in self.ignore_args]
else:
ca = args
oa = args
# this makes sure we only add an observer once, and that None can be in args
cached_args = []
for a in ca:
if (not any(a is ai for ai in cached_args)) and a is not None:
cached_args.append(a)
if not all([isinstance(arg, Observable) for arg in cached_args]):
print cached_args
import ipdb;ipdb.set_trace()
observable_args = []
for a in oa:
if (not any(a is ai for ai in observable_args)) and a is not None:
observable_args.append(a)
#make sure that all the found argument really are observable:
#otherswise don't cache anything, pass args straight though
if not all([isinstance(arg, Observable) for arg in observable_args]):
return self.operation(*args)
if cached_args in self.cached_inputs:
i = self.cached_inputs.index(cached_args)
#if the result is cached, return the cached computation
state = [all(a is b for a, b in zip(args, cached_i)) for cached_i in self.cached_inputs]
if any(state):
i = state.index(True)
if self.inputs_changed[i]:
#(elements of) the args have changed since we last computed: update
self.cached_outputs[i] = self.operation(*args)
self.inputs_changed[i] = False
return self.cached_outputs[i]
else:
#first time we've seen these arguments: compute
#first make sure the depth limit isn't exceeded
if len(self.cached_inputs) == self.limit:
args_ = self.cached_inputs.pop(0)
[a.remove_observer(self, self.on_cache_changed) for a in args_]
[a.remove_observer(self, self.on_cache_changed) for a in args_ if a is not None]
self.inputs_changed.pop(0)
self.cached_outputs.pop(0)
self.cached_inputs.append(cached_args)
#compute
self.cached_inputs.append(args)
self.cached_outputs.append(self.operation(*args))
self.inputs_changed.append(False)
[a.add_observer(self, self.on_cache_changed) for a in cached_args]
return self.cached_outputs[-1]
[a.add_observer(self, self.on_cache_changed) for a in observable_args]
return self.cached_outputs[-1]#Max says return.
def on_cache_changed(self, arg):
"""
A callback funtion, which sets local flags when the elements of some cached inputs change
this function gets 'hooked up' to the inputs when we cache them, and upon their elements being changed we update here.
"""
self.inputs_changed = [any([a is arg for a in args]) or old_ic for args, old_ic in zip(self.cached_inputs, self.inputs_changed)]
def reset(self, obj):
[[a.remove_observer(self, self.on_cache_changed) for a in args] for args in self.cached_inputs]
[[a.remove_observer(self, self.reset) for a in args] for args in self.cached_inputs]
"""
Totally reset the cache
"""
[[a.remove_observer(self, self.on_cache_changed) for a in args if isinstance(a, Observable)] for args in self.cached_inputs]
[[a.remove_observer(self, self.reset) for a in args if isinstance(a, Observable)] for args in self.cached_inputs]
self.cached_inputs = []
self.cached_outputs = []
self.inputs_changed = []
class Cache_this(object):
"""
A decorator which can be applied to bound methods in order to cache them
"""
def __init__(self, limit=5, ignore_args=()):
self.limit = limit
self.ignore_args = ignore_args
@ -64,4 +96,5 @@ class Cache_this(object):
self.c = Cacher(f, self.limit, ignore_args=self.ignore_args)
return self.c(*args)
f_wrap._cacher = self
f_wrap.__doc__ = "**cached**\n\n" + (f.__doc__ or "")
return f_wrap