mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-11 04:52:37 +02:00
slicing support for kernel input dimension
This commit is contained in:
parent
5f3524e7da
commit
db5fd17609
10 changed files with 178 additions and 65 deletions
|
|
@ -2,13 +2,22 @@
|
|||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
import itertools
|
||||
from ...core.parameterization import Parameterized
|
||||
from ...core.parameterization.param import Param
|
||||
|
||||
from ...core.parameterization.parameterized import ParametersChangedMeta, Parameterized
|
||||
from ...util.caching import Cache_this
|
||||
|
||||
class KernCallsViaSlicerMeta(ParametersChangedMeta):
|
||||
def __call__(self, *args, **kw):
|
||||
instance = super(KernCallsViaSlicerMeta, self).__call__(*args, **kw)
|
||||
instance.K = instance._slice_wrapper(instance.K)
|
||||
instance.Kdiag = instance._slice_wrapper(instance.Kdiag, True)
|
||||
instance.update_gradients_full = instance._slice_wrapper(instance.update_gradients_full, False, True)
|
||||
instance.update_gradients_diag = instance._slice_wrapper(instance.update_gradients_diag, True, True)
|
||||
instance.gradients_X = instance._slice_wrapper(instance.gradients_X, False, True)
|
||||
instance.gradients_X_diag = instance._slice_wrapper(instance.gradients_X_diag, True, True)
|
||||
return instance
|
||||
|
||||
class Kern(Parameterized):
|
||||
__metaclass__ = KernCallsViaSlicerMeta
|
||||
def __init__(self, input_dim, name, *a, **kw):
|
||||
"""
|
||||
The base class for a kernel: a positive definite function
|
||||
|
|
@ -20,11 +29,83 @@ class Kern(Parameterized):
|
|||
Do not instantiate.
|
||||
"""
|
||||
super(Kern, self).__init__(name=name, *a, **kw)
|
||||
self.input_dim = input_dim
|
||||
|
||||
if isinstance(input_dim, int):
|
||||
self.active_dims = slice(0, input_dim)
|
||||
self.input_dim = input_dim
|
||||
else:
|
||||
self.active_dims = input_dim
|
||||
self.input_dim = len(self.active_dims)
|
||||
self._sliced_X = False
|
||||
self._sliced_X2 = False
|
||||
|
||||
@Cache_this(limit=10, ignore_args = (0,))
|
||||
def _slice_X(self, X):
|
||||
return X[:, self.active_dims]
|
||||
|
||||
def _slice_wrapper(self, operation, diag=False, derivative=False):
|
||||
"""
|
||||
This method wraps the functions in kernel to make sure all kernels allways see their respective input dimension.
|
||||
The different switches are:
|
||||
diag: if X2 exists
|
||||
derivative: if firest arg is dL_dK
|
||||
"""
|
||||
if derivative:
|
||||
if diag:
|
||||
def x_slice_wrapper(dL_dK, X, *args, **kw):
|
||||
X = self._slice_X(X) if not self._sliced_X else X
|
||||
self._sliced_X = True
|
||||
try:
|
||||
ret = operation(dL_dK, X, *args, **kw)
|
||||
except: raise
|
||||
finally:
|
||||
self._sliced_X = False
|
||||
return ret
|
||||
else:
|
||||
def x_slice_wrapper(dL_dK, X, X2=None, *args, **kw):
|
||||
X, X2 = self._slice_X(X) if not self._sliced_X else X, self._slice_X(X2) if X2 is not None and not self._sliced_X2 else X2
|
||||
self._sliced_X = True
|
||||
self._sliced_X2 = True
|
||||
try:
|
||||
ret = operation(dL_dK, X, X2, *args, **kw)
|
||||
except: raise
|
||||
finally:
|
||||
self._sliced_X = False
|
||||
self._sliced_X2 = False
|
||||
return ret
|
||||
else:
|
||||
if diag:
|
||||
def x_slice_wrapper(X, *args, **kw):
|
||||
X = self._slice_X(X) if not self._sliced_X else X
|
||||
self._sliced_X = True
|
||||
try:
|
||||
ret = operation(X, *args, **kw)
|
||||
except: raise
|
||||
finally:
|
||||
self._sliced_X = False
|
||||
return ret
|
||||
else:
|
||||
def x_slice_wrapper(X, X2=None, *args, **kw):
|
||||
X, X2 = self._slice_X(X) if not self._sliced_X else X, self._slice_X(X2) if X2 is not None and not self._sliced_X2 else X2
|
||||
self._sliced_X = True
|
||||
self._sliced_X2 = True
|
||||
try:
|
||||
ret = operation(X, X2, *args, **kw)
|
||||
except: raise
|
||||
finally:
|
||||
self._sliced_X = False
|
||||
self._sliced_X2 = False
|
||||
return ret
|
||||
x_slice_wrapper._operation = operation
|
||||
x_slice_wrapper.__name__ = ("slicer("+operation.__name__
|
||||
+(","+str(bool(diag)) if diag else'')
|
||||
+(','+str(bool(derivative)) if derivative else '')
|
||||
+')')
|
||||
x_slice_wrapper.__doc__ = "**sliced**\n\n" + (operation.__doc__ or "")
|
||||
return x_slice_wrapper
|
||||
|
||||
def K(self, X, X2):
|
||||
raise NotImplementedError
|
||||
def Kdiag(self, Xa):
|
||||
def Kdiag(self, X):
|
||||
raise NotImplementedError
|
||||
def psi0(self, Z, variational_posterior):
|
||||
raise NotImplementedError
|
||||
|
|
@ -34,13 +115,16 @@ class Kern(Parameterized):
|
|||
raise NotImplementedError
|
||||
def gradients_X(self, dL_dK, X, X2):
|
||||
raise NotImplementedError
|
||||
def gradients_X_diag(self, dL_dK, X):
|
||||
def gradients_X_diag(self, dL_dKdiag, X):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def update_gradients_full(self, dL_dK, X, X2):
|
||||
"""Set the gradients of all parameters when doing full (N) inference."""
|
||||
raise NotImplementedError
|
||||
|
||||
def update_gradients_diag(self, dL_dKdiag, X):
|
||||
"""Set the gradients for all parameters for the derivative of the diagonal of the covariance w.r.t the kernel parameters."""
|
||||
raise NotImplementedError
|
||||
|
||||
def update_gradients_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
"""
|
||||
Set the gradients of all parameters when doing inference with
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue