diff --git a/GPy/core/model.py b/GPy/core/model.py index d27cbc69..bf8915c6 100644 --- a/GPy/core/model.py +++ b/GPy/core/model.py @@ -147,6 +147,12 @@ class Model(Parameterized): """ raise DeprecationWarning, 'parameters now have default constraints' + def input_sensitivity(self): + """ + Returns the sensitivity for each dimension of this kernel. + """ + return self.kern.input_sensitivity() + def objective_function(self, x): """ The objective function passed to the optimizer. It combines diff --git a/GPy/kern/_src/kern.py b/GPy/kern/_src/kern.py index 07f3fdf7..f632783b 100644 --- a/GPy/kern/_src/kern.py +++ b/GPy/kern/_src/kern.py @@ -89,7 +89,7 @@ class Kern(Parameterized): """ Returns the sensitivity for each dimension of this kernel. """ - return np.zeros(self.input_dim) + return self.kern.input_sensitivity() def __add__(self, other): """ Overloading of the '+' operator. for more control, see self.add """ diff --git a/GPy/models/ss_gplvm.py b/GPy/models/ss_gplvm.py index 37309c94..5994814b 100644 --- a/GPy/models/ss_gplvm.py +++ b/GPy/models/ss_gplvm.py @@ -66,6 +66,12 @@ class SSGPLVM(SparseGP): # update for the KL divergence self.variational_prior.update_gradients_KL(self.X) + def input_sensitivity(self): + if self.kern.ARD: + return self.kern.input_sensitivity() + else: + return self.variational_prior.pi + def plot_latent(self, plot_inducing=True, *args, **kwargs): import sys assert "matplotlib" in sys.modules, "matplotlib package has not been imported." diff --git a/GPy/plotting/matplot_dep/dim_reduction_plots.py b/GPy/plotting/matplot_dep/dim_reduction_plots.py index 10b352d3..bf9297b9 100644 --- a/GPy/plotting/matplot_dep/dim_reduction_plots.py +++ b/GPy/plotting/matplot_dep/dim_reduction_plots.py @@ -20,7 +20,7 @@ def most_significant_input_dimensions(model, which_indices): input_1, input_2 = 0, 1 else: try: - input_1, input_2 = np.argsort(model.kern.input_sensitivity())[::-1][:2] + input_1, input_2 = np.argsort(model.input_sensitivity())[::-1][:2] except: raise ValueError, "cannot automatically determine which dimensions to plot, please pass 'which_indices'" else: