automatic slicing

This commit is contained in:
Max Zwiessele 2014-03-11 16:24:09 +00:00
parent e078bb47e1
commit 01f5d789c5
3 changed files with 72 additions and 144 deletions

View file

@ -3,25 +3,18 @@
import sys
import numpy as np
from ...core.parameterization.parameterized import ParametersChangedMeta, Parameterized
from ...core.parameterization.parameterized import Parameterized
from kernel_slice_operations import KernCallsViaSlicerMeta
from ...util.caching import Cache_this
class KernCallsViaSlicerMeta(ParametersChangedMeta):
def __call__(self, *args, **kw):
instance = super(KernCallsViaSlicerMeta, self).__call__(*args, **kw)
instance.K = instance._slice_wrapper(instance.K)
instance.Kdiag = instance._slice_wrapper(instance.Kdiag, True)
instance.update_gradients_full = instance._slice_wrapper(instance.update_gradients_full, False, True)
instance.update_gradients_diag = instance._slice_wrapper(instance.update_gradients_diag, True, True)
instance.gradients_X = instance._slice_wrapper(instance.gradients_X, False, True)
instance.gradients_X_diag = instance._slice_wrapper(instance.gradients_X_diag, True, True)
instance.psi0 = instance._slice_wrapper(instance.psi0, False, False)
instance.psi1 = instance._slice_wrapper(instance.psi1, False, False)
instance.psi2 = instance._slice_wrapper(instance.psi2, False, False)
return instance
class Kern(Parameterized):
#===========================================================================
# This adds input slice support. The rather ugly code for slicing can be
# found in kernel_slice_operations
__metaclass__ = KernCallsViaSlicerMeta
#===========================================================================
def __init__(self, input_dim, name, *a, **kw):
"""
The base class for a kernel: a positive definite function
@ -40,76 +33,11 @@ class Kern(Parameterized):
self.active_dims = input_dim
self.input_dim = len(self.active_dims)
self._sliced_X = False
self._sliced_X2 = False
@Cache_this(limit=10)#, ignore_args = (0,))
def _slice_X(self, X):
return X[:, self.active_dims]
def _slice_wrapper(self, operation, diag=False, derivative=False):
"""
This method wraps the functions in kernel to make sure all kernels allways see their respective input dimension.
The different switches are:
diag: if X2 exists
derivative: if firest arg is dL_dK
"""
if derivative:
if diag:
def x_slice_wrapper(dL_dK, X, *args, **kw):
X = self._slice_X(X) if not self._sliced_X else X
self._sliced_X = True
try:
ret = operation(dL_dK, X, *args, **kw)
except:
raise
finally:
self._sliced_X = False
return ret
else:
def x_slice_wrapper(dL_dK, X, X2=None, *args, **kw):
X, X2 = self._slice_X(X) if not self._sliced_X else X, self._slice_X(X2) if X2 is not None and not self._sliced_X2 else X2
self._sliced_X = True
self._sliced_X2 = True
try:
ret = operation(dL_dK, X, X2, *args, **kw)
except:
raise
finally:
self._sliced_X = False
self._sliced_X2 = False
return ret
else:
if diag:
def x_slice_wrapper(X, *args, **kw):
X = self._slice_X(X) if not self._sliced_X else X
self._sliced_X = True
try:
ret = operation(X, *args, **kw)
except:
raise
finally:
self._sliced_X = False
return ret
else:
def x_slice_wrapper(X, X2=None, *args, **kw):
X, X2 = self._slice_X(X) if not self._sliced_X else X, self._slice_X(X2) if X2 is not None and not self._sliced_X2 else X2
self._sliced_X = True
self._sliced_X2 = True
try:
ret = operation(X, X2, *args, **kw)
except: raise
finally:
self._sliced_X = False
self._sliced_X2 = False
return ret
x_slice_wrapper._operation = operation
x_slice_wrapper.__name__ = ("slicer("+operation.__name__
+(","+str(bool(diag)) if diag else'')
+(','+str(bool(derivative)) if derivative else '')
+')')
x_slice_wrapper.__doc__ = "**sliced**\n" + (operation.__doc__ or "")
return x_slice_wrapper
def K(self, X, X2):
"""
Compute the kernel function.
@ -241,6 +169,21 @@ class Kern(Parameterized):
else: kernels.append(other)
return Prod(self, other, name)
def _getstate(self):
"""
Get the current state of the class,
here just all the indices, rest can get recomputed
"""
return super(Kern, self)._getstate() + [
self.active_dims,
self.input_dim,
self._sliced_X]
def _setstate(self, state):
self._sliced_X = state.pop()
self.input_dim = state.pop()
self.active_dims = state.pop()
super(Kern, self)._setstate(state)
class CombinationKernel(Kern):
def __init__(self, kernels, name):
@ -258,3 +201,9 @@ class CombinationKernel(Kern):
def update_gradients_diag(self, dL_dK, X):
[p.update_gradients_diag(dL_dK, X) for p in self.parts]
def input_sensitivity(self):
in_sen = np.zeros((self.num_params, self.input_dim))
for i, p in enumerate(self.parts):
in_sen[i, p.active_dims] = p.input_sensitivity()
return in_sen