bugfix: slicing checks needed to be suspended for combination kernels, checks are done in inner kernels now

This commit is contained in:
mzwiessele 2014-04-28 18:22:53 +01:00
parent 2a36b5afee
commit d4d54bacb3
2 changed files with 16 additions and 2 deletions

View file

@@ -201,6 +201,13 @@ class Kern(Parameterized):
        #else: kernels.append(other)
        return Prod([self, other], name)
def _check_input_dim(self, X):
assert X.shape[1] == self.input_dim, "You did not specify active_dims and X has wrong shape: X_dim={}, whereas input_dim={}".format(X.shape[1], self.input_dim)
def _check_active_dims(self, X):
assert X.shape[1] >= len(np.r_[self.active_dims]), "At least {} dimensional X needed, X.shape={!s}".format(len(np.r_[self.active_dims]), X.shape)
class CombinationKernel(Kern):
    """
    Abstract super class for combination kernels.
@@ -239,3 +246,10 @@ class CombinationKernel(Kern):
    def input_sensitivity(self):
        raise NotImplementedError("Choose the kernel you want to get the sensitivity for. You need to override the default behaviour for getting the input sensitivity to be able to get the input sensitivity. For sum kernel it is the sum of all sensitivities, TODO: product kernel? Other kernels?, also TODO: shall we return all the sensitivities here in the combination kernel? So we can combine them however we want? This could lead to just plot all the sensitivities here...")
def _check_input_dim(self, X):
return
def _check_input_dim(self, X):
# As combination kernels cannot always know, what their inner kernels have as input dims, the check will be done inside them, respectively
return

View file

@@ -37,12 +37,12 @@ class _Slice_wrap(object):
        if X2 is not None:
            assert X2.ndim == 2, "only matrices are allowed as inputs to kernels for now, given X2.shape={!s}".format(X2.shape)
        if (self.k.active_dims is not None) and (self.k._sliced_X == 0):
-           assert X.shape[1] >= len(np.r_[self.k.active_dims]), "At least {} dimensional X needed, X.shape={!s}".format(len(np.r_[self.k.active_dims]), X.shape)
+           self.k._check_active_dims(X)
            self.X = self.k._slice_X(X)
            self.X2 = self.k._slice_X(X2) if X2 is not None else X2
            self.ret = True
        else:
-           assert X.shape[1] == self.k.input_dim, "You did not specify active_dims and X has wrong shape: X_dim={} -- input_dim={}".format(X.shape[1], self.k.input_dim)
+           self.k._check_input_dim(X)
            self.X = X
            self.X2 = X2
            self.ret = False