assertion checks for all kernels

This commit is contained in:
Max Zwiessele 2014-03-28 12:02:34 +00:00
parent 40ade9e157
commit 1dabf67c93
3 changed files with 2 additions and 14 deletions

View file

@ -23,7 +23,6 @@ class Add(CombinationKernel):
If a list of parts (of this kernel!) `which_parts` is given, only If a list of parts (of this kernel!) `which_parts` is given, only
the parts of the list are taken to compute the covariance. the parts of the list are taken to compute the covariance.
""" """
assert X.shape[1] > max(np.r_[self.active_dims])
if which_parts is None: if which_parts is None:
which_parts = self.parts which_parts = self.parts
elif not isinstance(which_parts, (list, tuple)): elif not isinstance(which_parts, (list, tuple)):
@ -33,7 +32,6 @@ class Add(CombinationKernel):
@Cache_this(limit=2, force_kwargs=['which_parts']) @Cache_this(limit=2, force_kwargs=['which_parts'])
def Kdiag(self, X, which_parts=None): def Kdiag(self, X, which_parts=None):
assert X.shape[1] > max(np.r_[self.active_dims])
if which_parts is None: if which_parts is None:
which_parts = self.parts which_parts = self.parts
elif not isinstance(which_parts, (list, tuple)): elif not isinstance(which_parts, (list, tuple)):
@ -160,16 +158,6 @@ class Add(CombinationKernel):
target_S += b target_S += b
return target_mu, target_S return target_mu, target_S
def _getstate(self):
"""
Get the current state of the class,
here just all the indices, rest can get recomputed
"""
return super(Add, self)._getstate()
def _setstate(self, state):
super(Add, self)._setstate(state)
def add(self, other, name='sum'): def add(self, other, name='sum'):
if isinstance(other, Add): if isinstance(other, Add):
other_params = other._parameters_[:] other_params = other._parameters_[:]

View file

@ -56,6 +56,7 @@ class _Slice_wrap(object):
def _slice_K(f): def _slice_K(f):
@wraps(f) @wraps(f)
def wrap(self, X, X2 = None, *a, **kw): def wrap(self, X, X2 = None, *a, **kw):
assert X.shape[1] > max(np.r_[self.active_dims]), "At least {} dimensional X needed".format(max(np.r_[self.active_dims]))
with _Slice_wrap(self, X, X2) as s: with _Slice_wrap(self, X, X2) as s:
ret = f(self, s.X, s.X2, *a, **kw) ret = f(self, s.X, s.X2, *a, **kw)
return ret return ret
@ -64,6 +65,7 @@ def _slice_K(f):
def _slice_Kdiag(f): def _slice_Kdiag(f):
@wraps(f) @wraps(f)
def wrap(self, X, *a, **kw): def wrap(self, X, *a, **kw):
assert X.shape[1] > max(np.r_[self.active_dims]), "At least {} dimensional X needed".format(max(np.r_[self.active_dims]))
with _Slice_wrap(self, X, None) as s: with _Slice_wrap(self, X, None) as s:
ret = f(self, s.X, *a, **kw) ret = f(self, s.X, *a, **kw)
return ret return ret

View file

@ -23,7 +23,6 @@ class Prod(CombinationKernel):
@Cache_this(limit=2, force_kwargs=['which_parts']) @Cache_this(limit=2, force_kwargs=['which_parts'])
def K(self, X, X2=None, which_parts=None): def K(self, X, X2=None, which_parts=None):
assert X.shape[1] == self.input_dim
if which_parts is None: if which_parts is None:
which_parts = self.parts which_parts = self.parts
elif not isinstance(which_parts, (list, tuple)): elif not isinstance(which_parts, (list, tuple)):
@ -33,7 +32,6 @@ class Prod(CombinationKernel):
@Cache_this(limit=2, force_kwargs=['which_parts']) @Cache_this(limit=2, force_kwargs=['which_parts'])
def Kdiag(self, X, which_parts=None): def Kdiag(self, X, which_parts=None):
assert X.shape[1] == self.input_dim
if which_parts is None: if which_parts is None:
which_parts = self.parts which_parts = self.parts
return reduce(np.multiply, (p.Kdiag(X) for p in which_parts)) return reduce(np.multiply, (p.Kdiag(X) for p in which_parts))