From d4d54bacb3ff8d40104eeeff13ecba1ef04304fb Mon Sep 17 00:00:00 2001 From: mzwiessele Date: Mon, 28 Apr 2014 18:22:53 +0100 Subject: [PATCH] bugfix: slicing checks needed to be suspended for combination kernels, checks are done in inner kernels now --- GPy/kern/_src/kern.py | 14 ++++++++++++++ GPy/kern/_src/kernel_slice_operations.py | 4 ++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/GPy/kern/_src/kern.py b/GPy/kern/_src/kern.py index 1ce7cd1e..70bd42b9 100644 --- a/GPy/kern/_src/kern.py +++ b/GPy/kern/_src/kern.py @@ -201,6 +201,13 @@ class Kern(Parameterized): #else: kernels.append(other) return Prod([self, other], name) + def _check_input_dim(self, X): + assert X.shape[1] == self.input_dim, "You did not specify active_dims and X has wrong shape: X_dim={}, whereas input_dim={}".format(X.shape[1], self.input_dim) + + def _check_active_dims(self, X): + assert X.shape[1] >= len(np.r_[self.active_dims]), "At least {} dimensional X needed, X.shape={!s}".format(len(np.r_[self.active_dims]), X.shape) + + class CombinationKernel(Kern): """ Abstract super class for combination kernels. @@ -239,3 +246,10 @@ class CombinationKernel(Kern): def input_sensitivity(self): raise NotImplementedError("Choose the kernel you want to get the sensitivity for. You need to override the default behaviour for getting the input sensitivity to be able to get the input sensitivity. For sum kernel it is the sum of all sensitivities, TODO: product kernel? Other kernels?, also TODO: shall we return all the sensitivities here in the combination kernel? So we can combine them however we want? This could lead to just plot all the sensitivities here...") + + def _check_input_dim(self, X): + return + + def _check_input_dim(self, X): + # As combination kernels cannot always know, what their inner kernels have as input dims, the check will be done inside them, respectively + return diff --git a/GPy/kern/_src/kernel_slice_operations.py b/GPy/kern/_src/kernel_slice_operations.py index ce7f587c..c1c8d7f1 100644 --- a/GPy/kern/_src/kernel_slice_operations.py +++ b/GPy/kern/_src/kernel_slice_operations.py @@ -37,12 +37,12 @@ class _Slice_wrap(object): if X2 is not None: assert X2.ndim == 2, "only matrices are allowed as inputs to kernels for now, given X2.shape={!s}".format(X2.shape) if (self.k.active_dims is not None) and (self.k._sliced_X == 0): - assert X.shape[1] >= len(np.r_[self.k.active_dims]), "At least {} dimensional X needed, X.shape={!s}".format(len(np.r_[self.k.active_dims]), X.shape) + self.k._check_active_dims(X) self.X = self.k._slice_X(X) self.X2 = self.k._slice_X(X2) if X2 is not None else X2 self.ret = True else: - assert X.shape[1] == self.k.input_dim, "You did not specify active_dims and X has wrong shape: X_dim={} -- input_dim={}".format(X.shape[1], self.k.input_dim) + self.k._check_input_dim(X) self.X = X self.X2 = X2 self.ret = False