mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00
messing with caching
This commit is contained in:
parent
2feb849bf7
commit
5cf792504a
3 changed files with 64 additions and 34 deletions
|
|
@ -40,18 +40,18 @@ class Stationary(Kern):
|
|||
def dK_dr(self, r):
|
||||
raise NotImplementedError, "implement the covaraiance function as a fn of r to use this class"
|
||||
|
||||
@Cache_this(limit=5, ignore_args=())
|
||||
#@Cache_this(limit=5, ignore_args=())
|
||||
def K(self, X, X2=None):
|
||||
r = self._scaled_dist(X, X2)
|
||||
return self.K_of_r(r)
|
||||
|
||||
@Cache_this(limit=5, ignore_args=(0,))
|
||||
#@Cache_this(limit=5, ignore_args=(0,))
|
||||
def _dist(self, X, X2):
|
||||
if X2 is None:
|
||||
X2 = X
|
||||
return X[:, None, :] - X2[None, :, :]
|
||||
|
||||
@Cache_this(limit=5, ignore_args=(0,))
|
||||
#@Cache_this(limit=5, ignore_args=(0,))
|
||||
def _unscaled_dist(self, X, X2=None):
|
||||
"""
|
||||
Compute the square distance between each row of X and X2, or between
|
||||
|
|
@ -65,7 +65,7 @@ class Stationary(Kern):
|
|||
X2sq = np.sum(np.square(X2),1)
|
||||
return np.sqrt(-2.*np.dot(X, X2.T) + (X1sq[:,None] + X2sq[None,:]))
|
||||
|
||||
@Cache_this(limit=5, ignore_args=())
|
||||
#@Cache_this(limit=5, ignore_args=())
|
||||
def _scaled_dist(self, X, X2=None):
|
||||
"""
|
||||
Efficiently compute the scaled distance, r.
|
||||
|
|
|
|||
|
|
@ -8,13 +8,6 @@ import sys
|
|||
|
||||
verbose = True
|
||||
|
||||
try:
|
||||
import sympy
|
||||
SYMPY_AVAILABLE=True
|
||||
except ImportError:
|
||||
SYMPY_AVAILABLE=False
|
||||
|
||||
|
||||
class Kern_check_model(GPy.core.Model):
|
||||
"""
|
||||
This is a dummy model class used as a base class for checking that the
|
||||
|
|
@ -70,14 +63,11 @@ class Kern_check_dKdiag_dtheta(Kern_check_model):
|
|||
Kern_check_model.__init__(self,kernel=kernel,dL_dK=dL_dK, X=X, X2=None)
|
||||
self.add_parameter(self.kernel)
|
||||
|
||||
def parameters_changed(self):
|
||||
self.kernel.update_gradients_diag(self.dL_dK, self.X)
|
||||
|
||||
def log_likelihood(self):
|
||||
return (np.diag(self.dL_dK)*self.kernel.Kdiag(self.X)).sum()
|
||||
|
||||
def parameters_changed(self):
|
||||
return self.kernel.update_gradients_diag(np.diag(self.dL_dK), self.X)
|
||||
self.kernel.update_gradients_diag(np.diag(self.dL_dK), self.X)
|
||||
|
||||
class Kern_check_dK_dX(Kern_check_model):
|
||||
"""This class allows gradient checks for the gradient of a kernel with respect to X. """
|
||||
|
|
@ -99,6 +89,8 @@ class Kern_check_dKdiag_dX(Kern_check_dK_dX):
|
|||
def parameters_changed(self):
|
||||
self.X.gradient = self.kernel.gradients_X_diag(self.dL_dK, self.X)
|
||||
|
||||
|
||||
|
||||
def kern_test(kern, X=None, X2=None, output_ind=None, verbose=False):
|
||||
"""
|
||||
This function runs on kernels to check the correctness of their
|
||||
|
|
@ -217,11 +209,15 @@ def kern_test(kern, X=None, X2=None, output_ind=None, verbose=False):
|
|||
return pass_checks
|
||||
|
||||
|
||||
|
||||
class KernelTestsContinuous(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.X = np.random.randn(100,2)
|
||||
self.X2 = np.random.randn(110,2)
|
||||
|
||||
continuous_kerns = ['RBF', 'Linear']
|
||||
self.kernclasses = [getattr(GPy.kern, s) for s in continuous_kerns]
|
||||
|
||||
def test_Matern32(self):
|
||||
k = GPy.kern.Matern32(2)
|
||||
self.assertTrue(kern_test(k, X=self.X, X2=self.X2, verbose=verbose))
|
||||
|
|
@ -234,6 +230,7 @@ class KernelTestsContinuous(unittest.TestCase):
|
|||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print "Running unit tests, please be (very) patient..."
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,13 @@
|
|||
from ..core.parameterization.parameter_core import Observable
|
||||
|
||||
class Cacher(object):
|
||||
"""
|
||||
|
||||
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, operation, limit=5, ignore_args=()):
|
||||
self.limit = int(limit)
|
||||
self.ignore_args = ignore_args
|
||||
|
|
@ -10,50 +17,75 @@ class Cacher(object):
|
|||
self.inputs_changed = []
|
||||
|
||||
def __call__(self, *args):
|
||||
"""
|
||||
A wrapper function for self.operation,
|
||||
"""
|
||||
|
||||
#ensure that specified arguments are ignored
|
||||
if len(self.ignore_args) != 0:
|
||||
ca = [a for i,a in enumerate(args) if i not in self.ignore_args]
|
||||
oa = [a for i,a in enumerate(args) if i not in self.ignore_args]
|
||||
else:
|
||||
ca = args
|
||||
oa = args
|
||||
|
||||
# this makes sure we only add an observer once, and that None can be in args
|
||||
cached_args = []
|
||||
for a in ca:
|
||||
if (not any(a is ai for ai in cached_args)) and a is not None:
|
||||
cached_args.append(a)
|
||||
if not all([isinstance(arg, Observable) for arg in cached_args]):
|
||||
print cached_args
|
||||
import ipdb;ipdb.set_trace()
|
||||
observable_args = []
|
||||
for a in oa:
|
||||
if (not any(a is ai for ai in observable_args)) and a is not None:
|
||||
observable_args.append(a)
|
||||
|
||||
#make sure that all the found argument really are observable:
|
||||
#otherswise don't cache anything, pass args straight though
|
||||
if not all([isinstance(arg, Observable) for arg in observable_args]):
|
||||
return self.operation(*args)
|
||||
|
||||
if cached_args in self.cached_inputs:
|
||||
i = self.cached_inputs.index(cached_args)
|
||||
#if the result is cached, return the cached computation
|
||||
state = [all(a is b for a, b in zip(args, cached_i)) for cached_i in self.cached_inputs]
|
||||
if any(state):
|
||||
i = state.index(True)
|
||||
if self.inputs_changed[i]:
|
||||
#(elements of) the args have changed since we last computed: update
|
||||
self.cached_outputs[i] = self.operation(*args)
|
||||
self.inputs_changed[i] = False
|
||||
return self.cached_outputs[i]
|
||||
else:
|
||||
#first time we've seen these arguments: compute
|
||||
|
||||
#first make sure the depth limit isn't exceeded
|
||||
if len(self.cached_inputs) == self.limit:
|
||||
args_ = self.cached_inputs.pop(0)
|
||||
[a.remove_observer(self, self.on_cache_changed) for a in args_]
|
||||
[a.remove_observer(self, self.on_cache_changed) for a in args_ if a is not None]
|
||||
self.inputs_changed.pop(0)
|
||||
self.cached_outputs.pop(0)
|
||||
|
||||
self.cached_inputs.append(cached_args)
|
||||
#compute
|
||||
self.cached_inputs.append(args)
|
||||
self.cached_outputs.append(self.operation(*args))
|
||||
self.inputs_changed.append(False)
|
||||
[a.add_observer(self, self.on_cache_changed) for a in cached_args]
|
||||
return self.cached_outputs[-1]
|
||||
[a.add_observer(self, self.on_cache_changed) for a in observable_args]
|
||||
return self.cached_outputs[-1]#Max says return.
|
||||
|
||||
def on_cache_changed(self, arg):
|
||||
"""
|
||||
A callback funtion, which sets local flags when the elements of some cached inputs change
|
||||
|
||||
this function gets 'hooked up' to the inputs when we cache them, and upon their elements being changed we update here.
|
||||
"""
|
||||
self.inputs_changed = [any([a is arg for a in args]) or old_ic for args, old_ic in zip(self.cached_inputs, self.inputs_changed)]
|
||||
|
||||
def reset(self, obj):
|
||||
[[a.remove_observer(self, self.on_cache_changed) for a in args] for args in self.cached_inputs]
|
||||
[[a.remove_observer(self, self.reset) for a in args] for args in self.cached_inputs]
|
||||
"""
|
||||
Totally reset the cache
|
||||
"""
|
||||
[[a.remove_observer(self, self.on_cache_changed) for a in args if isinstance(a, Observable)] for args in self.cached_inputs]
|
||||
[[a.remove_observer(self, self.reset) for a in args if isinstance(a, Observable)] for args in self.cached_inputs]
|
||||
self.cached_inputs = []
|
||||
self.cached_outputs = []
|
||||
self.inputs_changed = []
|
||||
|
||||
class Cache_this(object):
|
||||
"""
|
||||
A decorator which can be applied to bound methods in order to cache them
|
||||
"""
|
||||
def __init__(self, limit=5, ignore_args=()):
|
||||
self.limit = limit
|
||||
self.ignore_args = ignore_args
|
||||
|
|
@ -64,4 +96,5 @@ class Cache_this(object):
|
|||
self.c = Cacher(f, self.limit, ignore_args=self.ignore_args)
|
||||
return self.c(*args)
|
||||
f_wrap._cacher = self
|
||||
f_wrap.__doc__ = "**cached**\n\n" + (f.__doc__ or "")
|
||||
return f_wrap
|
||||
Loading…
Add table
Add a link
Reference in a new issue