switch input_sensitivity function to model

This commit is contained in:
Zhenwen Dai 2014-03-04 14:25:11 +00:00
parent 0f37cc721b
commit b9dcb7f640
4 changed files with 14 additions and 2 deletions

View file

@ -147,6 +147,12 @@ class Model(Parameterized):
""" """
raise DeprecationWarning, 'parameters now have default constraints' raise DeprecationWarning, 'parameters now have default constraints'
def input_sensitivity(self):
"""
Returns the sensitivity for each dimension of this kernel.
"""
return self.kern.input_sensitivity()
def objective_function(self, x): def objective_function(self, x):
""" """
The objective function passed to the optimizer. It combines The objective function passed to the optimizer. It combines

View file

@ -89,7 +89,7 @@ class Kern(Parameterized):
""" """
Returns the sensitivity for each dimension of this kernel. Returns the sensitivity for each dimension of this kernel.
""" """
return np.zeros(self.input_dim) return self.kern.input_sensitivity()
def __add__(self, other): def __add__(self, other):
""" Overloading of the '+' operator. for more control, see self.add """ """ Overloading of the '+' operator. for more control, see self.add """

View file

@ -66,6 +66,12 @@ class SSGPLVM(SparseGP):
# update for the KL divergence # update for the KL divergence
self.variational_prior.update_gradients_KL(self.X) self.variational_prior.update_gradients_KL(self.X)
def input_sensitivity(self):
if self.kern.ARD:
return self.kern.input_sensitivity()
else:
return self.variational_prior.pi
def plot_latent(self, plot_inducing=True, *args, **kwargs): def plot_latent(self, plot_inducing=True, *args, **kwargs):
import sys import sys
assert "matplotlib" in sys.modules, "matplotlib package has not been imported." assert "matplotlib" in sys.modules, "matplotlib package has not been imported."

View file

@ -20,7 +20,7 @@ def most_significant_input_dimensions(model, which_indices):
input_1, input_2 = 0, 1 input_1, input_2 = 0, 1
else: else:
try: try:
input_1, input_2 = np.argsort(model.kern.input_sensitivity())[::-1][:2] input_1, input_2 = np.argsort(model.input_sensitivity())[::-1][:2]
except: except:
raise ValueError, "cannot automatically determine which dimensions to plot, please pass 'which_indices'" raise ValueError, "cannot automatically determine which dimensions to plot, please pass 'which_indices'"
else: else: