bug fix: slicing was not checking dimensions

This commit is contained in:
Max Zwiessele 2014-04-16 10:12:02 +01:00
parent a57ca26c89
commit 541aa1c8b6
2 changed files with 28 additions and 7 deletions

View file

@ -21,14 +21,37 @@ class Kern(Parameterized):
The base class for a kernel: a positive definite function The base class for a kernel: a positive definite function
which forms of a covariance function (kernel). which forms of a covariance function (kernel).
input_dim:
is the number of dimensions to work on. Make sure to give the
tight dimensionality of inputs.
You moset likely want this to be the integer telling the number of
input dimensions of the kernel.
If this is not an integer (!) we will work on the whole input matrix X,
and not check whether dimensions match or not (!).
active_dims:
is the active_dimensions of inputs X we will work on.
All kernels will get sliced Xes as inputs, if active_dims is not None
if active_dims is None, slicing is switched off and all X will be passed through as given.
:param int input_dim: the number of input dimensions to the function :param int input_dim: the number of input dimensions to the function
:param array-like|slice active_dims: list of indices on which dimensions this kernel works on :param array-like|slice|None active_dims: list of indices on which dimensions this kernel works on, or none if no slicing
Do not instantiate. Do not instantiate.
""" """
super(Kern, self).__init__(name=name, *a, **kw) super(Kern, self).__init__(name=name, *a, **kw)
self.active_dims = active_dims# if active_dims is not None else slice(0, input_dim, 1) try:
self.input_dim = input_dim self.input_dim = int(input_dim)
self.active_dims = active_dims if active_dims is not None else slice(0, input_dim, 1)
except TypeError:
# input_dim is something else then an integer
self.input_dim = input_dim
if active_dims is not None:
print "WARNING: given input_dim={} is not an integer and active_dims={} is given, switching off slicing"
self.active_dims = None
if self.active_dims is not None and self.input_dim is not None: if self.active_dims is not None and self.input_dim is not None:
assert isinstance(self.active_dims, (slice, list, tuple, np.ndarray)), 'active_dims needs to be an array-like or slice object over dimensions, {} given'.format(self.active_dims.__class__) assert isinstance(self.active_dims, (slice, list, tuple, np.ndarray)), 'active_dims needs to be an array-like or slice object over dimensions, {} given'.format(self.active_dims.__class__)
if isinstance(self.active_dims, slice): if isinstance(self.active_dims, slice):
@ -46,9 +69,7 @@ class Kern(Parameterized):
@Cache_this(limit=10) @Cache_this(limit=10)
def _slice_X(self, X): def _slice_X(self, X):
if self.active_dims is not None: return X[:, self.active_dims]
return X[:, self.active_dims]
return X
def K(self, X, X2): def K(self, X, X2):
""" """

View file

@ -34,7 +34,7 @@ class _Slice_wrap(object):
self.k = k self.k = k
self.shape = X.shape self.shape = X.shape
if (self.k.active_dims is not None) and (self.k._sliced_X == 0): if (self.k.active_dims is not None) and (self.k._sliced_X == 0):
#assert X.shape[1] > len(np.r_[self.k.active_dims]), "At least {} dimensional X needed".format(len(np.r_[self.k.active_dims])) assert X.shape[1] >= len(np.r_[self.k.active_dims]), "At least {} dimensional X needed, X.shape={!s}".format(len(np.r_[self.k.active_dims]), X.shape)
self.X = self.k._slice_X(X) self.X = self.k._slice_X(X)
self.X2 = self.k._slice_X(X2) if X2 is not None else X2 self.X2 = self.k._slice_X(X2) if X2 is not None else X2
self.ret = True self.ret = True