more fun with vizualize

This commit is contained in:
James Hensman 2013-04-12 22:41:37 +01:00
parent 738d53d4bc
commit a7a5480e6e
2 changed files with 28 additions and 9 deletions

View file

@ -42,12 +42,12 @@ def BGPLVM(seed = default_seed):
return m
def GPLVM_oil_100(optimize=True,M=15):
def GPLVM_oil_100(optimize=True):
data = GPy.util.datasets.oil_100()
# create simple GP model
kernel = GPy.kern.rbf(6, ARD = True) + GPy.kern.bias(6)
m = GPy.models.GPLVM(data['X'], 6, kernel=kernel, M=M)
m = GPy.models.GPLVM(data['X'], 6, kernel=kernel)
m.data_labels = data['Y'].argmax(axis=1)
# optimize

View file

@ -9,9 +9,13 @@ class lvm:
if isinstance(latent_axes,mpl.axes.Axes):
self.cid = latent_axes.figure.canvas.mpl_connect('button_press_event', self.on_click)
self.cid = latent_axes.figure.canvas.mpl_connect('motion_notify_event', self.on_move)
self.cid = latent_axes.figure.canvas.mpl_connect('axes_leave_event', self.on_leave)
self.cid = latent_axes.figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
else:
self.cid = latent_axes[0].figure.canvas.mpl_connect('button_press_event', self.on_click)
self.cid = latent_axes[0].figure.canvas.mpl_connect('motion_notify_event', self.on_move)
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_leave_event', self.on_leave)
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
self.data_visualize = data_visualize
self.model = model
self.latent_axes = latent_axes
@ -21,6 +25,11 @@ class lvm:
self.latent_index = latent_index
self.latent_dim = model.Q
def on_enter(self,event):
pass
def on_leave(self,event):
pass
def on_click(self, event):
#print 'click', event.xdata, event.ydata
if event.inaxes!=self.latent_axes: return
@ -68,23 +77,25 @@ class lvm_dimselect(lvm):
lvm.__init__(self,model,data_visualize,latent_axes,[0,1])
self.latent_values_clicked = np.zeros(model.Q)
self._first_index_next = False
self.clicked_handle = self.latent_axes.plot([0],[0],'rx',mew=2)[0]
print "use left and right mouse butons to select dimensions"
def on_click(self, event):
#print "click"
if event.inaxes==self.hist_axes:
self.hist_axes.cla()
self.hist_axes.bar(np.arange(self.model.Q),1./self.model.input_sensitivity(),color='b')
new_index = int(np.round(event.xdata))
self.latent_index[int(self._first_index_next)] = new_index
self._first_index_next = not self._first_index_next
new_index = max(0,min(int(np.round(event.xdata-0.5)),self.model.Q-1))
self.latent_index[(0 if event.button==1 else 1)] = new_index
self.hist_axes.bar(np.array(self.latent_index),1./self.model.input_sensitivity()[self.latent_index],color='r')
self.latent_axes.cla()
self.model.plot_latent(which_indices = self.latent_index,ax=self.latent_axes)
self.fig.canvas.draw()
self.clicked_handle = self.latent_axes.plot([self.latent_values_clicked[self.latent_index[0]]],self.latent_values_clicked[self.latent_index[1]],'rx',mew=2)[0]
if event.inaxes==self.latent_axes:
#self.latent_values_clicked[self.latent_index] = np.array([event.xdata,event.ydata])
pass
self.clicked_handle.set_visible(False)
self.latent_values_clicked[self.latent_index] = np.array([event.xdata,event.ydata])
self.clicked_handle = self.latent_axes.plot([self.latent_values_clicked[self.latent_index[0]]],self.latent_values_clicked[self.latent_index[1]],'rx',mew=2)[0]
self.fig.canvas.draw()
self.move_on=True
self.called = True
@ -98,6 +109,14 @@ class lvm_dimselect(lvm):
y = self.model.predict(latent_values[None,:])[0]
self.data_visualize.modify(y)
def on_leave(self,event):
latent_values = self.latent_values_clicked.copy()
y = self.model.predict(latent_values[None,:])[0]
self.data_visualize.modify(y)
class data_show: