From d5a8e5dd3aa0983c71cbd0cba31b5e4053a1e4e5 Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Thu, 27 Mar 2014 13:08:54 +0000 Subject: [PATCH] right active dims when adding kernels --- GPy/kern/_src/add.py | 4 +++- GPy/kern/_src/kern.py | 10 +++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/GPy/kern/_src/add.py b/GPy/kern/_src/add.py index 57e611ed..ddc480de 100644 --- a/GPy/kern/_src/add.py +++ b/GPy/kern/_src/add.py @@ -176,5 +176,7 @@ class Add(CombinationKernel): for p in other_params: other.remove_parameter(p) self.add_parameters(*other_params) - else: self.add_parameter(other) + else: + self.add_parameter(other) + self.input_dim, self.active_dims = self.get_input_dim_active_dims(self.parts) return self \ No newline at end of file diff --git a/GPy/kern/_src/kern.py b/GPy/kern/_src/kern.py index 9d8d3f7b..5cd71215 100644 --- a/GPy/kern/_src/kern.py +++ b/GPy/kern/_src/kern.py @@ -208,9 +208,7 @@ class CombinationKernel(Kern): :param array-like|slice extra_dims: if needed extra dimensions for the combination kernel to work on """ assert all([isinstance(k, Kern) for k in kernels]) - active_dims = reduce(np.union1d, (np.r_[x.active_dims] for x in kernels), np.array([], dtype=int)) - input_dim = active_dims.max()+1 + len(extra_dims) - active_dims = slice(active_dims.max()+1+len(extra_dims)) + input_dim, active_dims = self.get_input_dim_active_dims(kernels, extra_dims) # initialize the kernel with the full input_dim super(CombinationKernel, self).__init__(input_dim, active_dims, name) self.extra_dims = extra_dims @@ -220,6 +218,12 @@ class CombinationKernel(Kern): def parts(self): return self._parameters_ + def get_input_dim_active_dims(self, kernels, extra_dims = None): + active_dims = reduce(np.union1d, (np.r_[x.active_dims] for x in kernels), np.array([], dtype=int)) + input_dim = active_dims.max()+1 + (len(extra_dims) if extra_dims is not None else 0) + active_dims = slice(input_dim) + return input_dim, active_dims + def input_sensitivity(self): in_sen = np.zeros((self.num_params, self.input_dim)) for i, p in enumerate(self.parts):