This commit is contained in:
James Hensman 2014-04-28 16:09:11 +01:00
commit 7d223d8df0
5 changed files with 14 additions and 10 deletions

View file

@ -176,7 +176,7 @@ def bgplvm_oil(optimize=True, verbose=1, plot=True, N=200, Q=7, num_inducing=40,
if plot: if plot:
y = m.Y y = m.Y
fig, (latent_axes, sense_axes) = plt.subplots(1, 2) fig, (latent_axes, sense_axes) = plt.subplots(1, 2)
m.plot_latent(ax=latent_axes) m.plot_latent(ax=latent_axes, labels=m.data_labels)
data_show = GPy.plotting.matplot_dep.visualize.vector_show(y) data_show = GPy.plotting.matplot_dep.visualize.vector_show(y)
lvm_visualizer = GPy.plotting.matplot_dep.visualize.lvm_dimselect(param_to_array(m.X.mean), # @UnusedVariable lvm_visualizer = GPy.plotting.matplot_dep.visualize.lvm_dimselect(param_to_array(m.X.mean), # @UnusedVariable
m, data_show, latent_axes=latent_axes, sense_axes=sense_axes) m, data_show, latent_axes=latent_axes, sense_axes=sense_axes)

View file

@ -167,4 +167,10 @@ class Add(CombinationKernel):
else: else:
self.add_parameter(other) self.add_parameter(other)
self.input_dim, self.active_dims = self.get_input_dim_active_dims(self.parts) self.input_dim, self.active_dims = self.get_input_dim_active_dims(self.parts)
return self return self
def input_sensitivity(self):
in_sen = np.zeros(self.input_dim)
for i, p in enumerate(self.parts):
in_sen[p.active_dims] += p.input_sensitivity()
return in_sen

View file

@ -238,7 +238,4 @@ class CombinationKernel(Kern):
return input_dim, active_dims return input_dim, active_dims
def input_sensitivity(self): def input_sensitivity(self):
in_sen = np.zeros((self.num_params, self.input_dim)) raise NotImplementedError("Choose the kernel you want to get the sensitivity for. You need to override the default behaviour for getting the input sensitivity to be able to get the input sensitivity. For sum kernel it is the sum of all sensitivities, TODO: product kernel? Other kernels?, also TODO: shall we return all the sensitivities here in the combination kernel? So we can combine them however we want? This could lead to just plot all the sensitivities here...")
for i, p in enumerate(self.parts):
in_sen[i, p.active_dims] = p.input_sensitivity()
return in_sen

View file

@ -42,6 +42,7 @@ class _Slice_wrap(object):
self.X2 = self.k._slice_X(X2) if X2 is not None else X2 self.X2 = self.k._slice_X(X2) if X2 is not None else X2
self.ret = True self.ret = True
else: else:
assert X.shape[1] == self.k.input_dim, "You did not specify active_dims and X has wrong shape: X_dim={} -- input_dim={}".format(X.shape[1], self.input_dim)
self.X = X self.X = X
self.X2 = X2 self.X2 = X2
self.ret = False self.ret = False

View file

@ -131,10 +131,10 @@ class lvm(matplotlib_show):
def modify(self, vals): def modify(self, vals):
"""When latent values are modified update the latent representation and ulso update the output visualization.""" """When latent values are modified update the latent representation and ulso update the output visualization."""
self.vals = vals[None,:].copy() self.vals = vals.copy()
y = self.model.predict(self.vals)[0] y = self.model.predict(self.vals)[0]
self.data_visualize.modify(y) self.data_visualize.modify(y)
self.latent_handle.set_data(self.vals[:,self.latent_index[0]], self.vals[:,self.latent_index[1]]) self.latent_handle.set_data(self.vals[0,self.latent_index[0]], self.vals[0,self.latent_index[1]])
self.axes.figure.canvas.draw() self.axes.figure.canvas.draw()
@ -153,8 +153,8 @@ class lvm(matplotlib_show):
if event.inaxes!=self.latent_axes: return if event.inaxes!=self.latent_axes: return
if self.called and self.move_on: if self.called and self.move_on:
# Call modify code on move # Call modify code on move
self.latent_values[self.latent_index[0]]=event.xdata self.latent_values[:, self.latent_index[0]]=event.xdata
self.latent_values[self.latent_index[1]]=event.ydata self.latent_values[:, self.latent_index[1]]=event.ydata
self.modify(self.latent_values) self.modify(self.latent_values)
def show_sensitivities(self): def show_sensitivities(self):