[lmv_dimselect] we need to keep a pointer to the lvm_dimselect object, as the updates are weak references: dim_select = ...

This commit is contained in:
Max Zwiessele 2014-05-20 14:43:18 +01:00
parent 01d6b91f90
commit dafad62363
2 changed files with 8 additions and 11 deletions

View file

@ -518,21 +518,21 @@ def stick_bgplvm(model=None, optimize=True, verbose=True, plot=True):
Q = 6
kernel = GPy.kern.RBF(Q, lengthscale=np.repeat(.5, Q), ARD=True)
m = BayesianGPLVM(data['Y'], Q, init="PCA", num_inducing=20, kernel=kernel)
m.data = data
m.likelihood.variance = 0.001
# optimize
if optimize: m.optimize('bfgs', messages=verbose, max_iters=800, xtol=1e-300, ftol=1e-300)
if optimize: m.optimize('bfgs', messages=verbose, max_iters=5e3, bfgs_factor=10)
if plot:
plt.clf, (latent_axes, sense_axes) = plt.subplots(1, 2)
fig, (latent_axes, sense_axes) = plt.subplots(1, 2)
plt.sca(latent_axes)
m.plot_latent(ax=latent_axes)
y = m.Y[:1, :].copy()
data_show = GPy.plotting.matplot_dep.visualize.stick_show(y, connect=data['connect'])
GPy.plotting.matplot_dep.visualize.lvm_dimselect(m.X.mean[:1, :].copy(), m, data_show, latent_axes=latent_axes, sense_axes=sense_axes)
plt.draw()
plt.show()
dim_select = GPy.plotting.matplot_dep.visualize.lvm_dimselect(m.X.mean[:1, :].copy(), m, data_show, latent_axes=latent_axes, sense_axes=sense_axes)
fig.canvas.draw()
fig.canvas.show()
raw_input('Press enter to finish')
return m

View file

@ -88,7 +88,6 @@ class vector_show(matplotlib_show):
class lvm(matplotlib_show):
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
"""Visualize a latent variable model
@ -147,7 +146,6 @@ class lvm(matplotlib_show):
pass
def on_click(self, event):
print 'click!'
if event.inaxes!=self.latent_axes: return
self.move_on = not self.move_on
self.called = True
@ -220,11 +218,10 @@ class lvm_dimselect(lvm):
self.labels = labels
lvm.__init__(self,vals,model,data_visualize,latent_axes,sense_axes,latent_index)
self.show_sensitivities()
print "use left and right mouse butons to select dimensions"
print "use left and right mouse buttons to select dimensions"
def on_click(self, event):
if event.inaxes==self.sense_axes:
new_index = max(0,min(int(np.round(event.xdata-0.5)),self.model.input_dim-1))
if event.button == 1: