diff --git a/GPy/kern/_src/kern.py b/GPy/kern/_src/kern.py index 4fcbf31f..e2647901 100644 --- a/GPy/kern/_src/kern.py +++ b/GPy/kern/_src/kern.py @@ -162,6 +162,10 @@ class Kern(Parameterized): def __mul__(self, other): """ Here we overload the '*' operator. See self.prod for more information""" return self.prod(other) + + def __imul__(self, other): + """ Here we overload the '*' operator. See self.prod for more information""" + return self.prod(other) def __pow__(self, other): """ @@ -183,7 +187,7 @@ class Kern(Parameterized): :type tensor: bool """ - assert isinstance(other, Kern), "only kernels can be added to kernels..." + assert isinstance(other, Kern), "only kernels can be multiplied to kernels..." from prod import Prod #kernels = [] #if isinstance(self, Prod): kernels.extend(self.parameters) diff --git a/GPy/kern/_src/prod.py b/GPy/kern/_src/prod.py index ba7f9450..dd9a5fe4 100644 --- a/GPy/kern/_src/prod.py +++ b/GPy/kern/_src/prod.py @@ -18,6 +18,12 @@ class Prod(CombinationKernel): """ def __init__(self, kernels, name='mul'): + for i, kern in enumerate(kernels[:]): + if isinstance(kern, Prod): + del kernels[i] + for part in kern.parts[::-1]: + kern.unlink_parameter(part) + kernels.insert(i, part) super(Prod, self).__init__(kernels, name) @Cache_this(limit=2, force_kwargs=['which_parts'])