mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-06 10:32:39 +02:00
switch input_sensitivity function to model
This commit is contained in:
parent
0f37cc721b
commit
b9dcb7f640
4 changed files with 14 additions and 2 deletions
|
|
@ -147,6 +147,12 @@ class Model(Parameterized):
|
||||||
"""
|
"""
|
||||||
raise DeprecationWarning, 'parameters now have default constraints'
|
raise DeprecationWarning, 'parameters now have default constraints'
|
||||||
|
|
||||||
|
def input_sensitivity(self):
|
||||||
|
"""
|
||||||
|
Returns the sensitivity for each dimension of this kernel.
|
||||||
|
"""
|
||||||
|
return self.kern.input_sensitivity()
|
||||||
|
|
||||||
def objective_function(self, x):
|
def objective_function(self, x):
|
||||||
"""
|
"""
|
||||||
The objective function passed to the optimizer. It combines
|
The objective function passed to the optimizer. It combines
|
||||||
|
|
|
||||||
|
|
@ -89,7 +89,7 @@ class Kern(Parameterized):
|
||||||
"""
|
"""
|
||||||
Returns the sensitivity for each dimension of this kernel.
|
Returns the sensitivity for each dimension of this kernel.
|
||||||
"""
|
"""
|
||||||
return np.zeros(self.input_dim)
|
return self.kern.input_sensitivity()
|
||||||
|
|
||||||
def __add__(self, other):
|
def __add__(self, other):
|
||||||
""" Overloading of the '+' operator. for more control, see self.add """
|
""" Overloading of the '+' operator. for more control, see self.add """
|
||||||
|
|
|
||||||
|
|
@ -66,6 +66,12 @@ class SSGPLVM(SparseGP):
|
||||||
# update for the KL divergence
|
# update for the KL divergence
|
||||||
self.variational_prior.update_gradients_KL(self.X)
|
self.variational_prior.update_gradients_KL(self.X)
|
||||||
|
|
||||||
|
def input_sensitivity(self):
|
||||||
|
if self.kern.ARD:
|
||||||
|
return self.kern.input_sensitivity()
|
||||||
|
else:
|
||||||
|
return self.variational_prior.pi
|
||||||
|
|
||||||
def plot_latent(self, plot_inducing=True, *args, **kwargs):
|
def plot_latent(self, plot_inducing=True, *args, **kwargs):
|
||||||
import sys
|
import sys
|
||||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ def most_significant_input_dimensions(model, which_indices):
|
||||||
input_1, input_2 = 0, 1
|
input_1, input_2 = 0, 1
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
input_1, input_2 = np.argsort(model.kern.input_sensitivity())[::-1][:2]
|
input_1, input_2 = np.argsort(model.input_sensitivity())[::-1][:2]
|
||||||
except:
|
except:
|
||||||
raise ValueError, "cannot automatically determine which dimensions to plot, please pass 'which_indices'"
|
raise ValueError, "cannot automatically determine which dimensions to plot, please pass 'which_indices'"
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue