diff --git a/GPy/kern/_src/kern.py b/GPy/kern/_src/kern.py index 1ce7cd1e..70bd42b9 100644 --- a/GPy/kern/_src/kern.py +++ b/GPy/kern/_src/kern.py @@ -201,6 +201,13 @@ class Kern(Parameterized): #else: kernels.append(other) return Prod([self, other], name) + def _check_input_dim(self, X): + assert X.shape[1] == self.input_dim, "You did not specify active_dims and X has wrong shape: X_dim={}, whereas input_dim={}".format(X.shape[1], self.input_dim) + + def _check_active_dims(self, X): + assert X.shape[1] >= len(np.r_[self.active_dims]), "At least {} dimensional X needed, X.shape={!s}".format(len(np.r_[self.active_dims]), X.shape) + + class CombinationKernel(Kern): """ Abstract super class for combination kernels. @@ -239,3 +246,10 @@ class CombinationKernel(Kern): def input_sensitivity(self): raise NotImplementedError("Choose the kernel you want to get the sensitivity for. You need to override the default behaviour for getting the input sensitivity to be able to get the input sensitivity. For sum kernel it is the sum of all sensitivities, TODO: product kernel? Other kernels?, also TODO: shall we return all the sensitivities here in the combination kernel? So we can combine them however we want? This could lead to just plot all the sensitivities here...") + + def _check_input_dim(self, X): + return + + def _check_input_dim(self, X): + # As combination kernels cannot always know, what their inner kernels have as input dims, the check will be done inside them, respectively + return diff --git a/GPy/kern/_src/kernel_slice_operations.py b/GPy/kern/_src/kernel_slice_operations.py index ce7f587c..c1c8d7f1 100644 --- a/GPy/kern/_src/kernel_slice_operations.py +++ b/GPy/kern/_src/kernel_slice_operations.py @@ -37,12 +37,12 @@ class _Slice_wrap(object): if X2 is not None: assert X2.ndim == 2, "only matrices are allowed as inputs to kernels for now, given X2.shape={!s}".format(X2.shape) if (self.k.active_dims is not None) and (self.k._sliced_X == 0): - assert X.shape[1] >= len(np.r_[self.k.active_dims]), "At least {} dimensional X needed, X.shape={!s}".format(len(np.r_[self.k.active_dims]), X.shape) + self.k._check_active_dims(X) self.X = self.k._slice_X(X) self.X2 = self.k._slice_X(X2) if X2 is not None else X2 self.ret = True else: - assert X.shape[1] == self.k.input_dim, "You did not specify active_dims and X has wrong shape: X_dim={} -- input_dim={}".format(X.shape[1], self.k.input_dim) + self.k._check_input_dim(X) self.X = X self.X2 = X2 self.ret = False