mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-05 14:55:15 +02:00
more fun with vizualize
This commit is contained in:
parent
738d53d4bc
commit
a7a5480e6e
2 changed files with 28 additions and 9 deletions
|
|
@ -42,12 +42,12 @@ def BGPLVM(seed = default_seed):
|
|||
|
||||
return m
|
||||
|
||||
def GPLVM_oil_100(optimize=True,M=15):
|
||||
def GPLVM_oil_100(optimize=True):
|
||||
data = GPy.util.datasets.oil_100()
|
||||
|
||||
# create simple GP model
|
||||
kernel = GPy.kern.rbf(6, ARD = True) + GPy.kern.bias(6)
|
||||
m = GPy.models.GPLVM(data['X'], 6, kernel=kernel, M=M)
|
||||
m = GPy.models.GPLVM(data['X'], 6, kernel=kernel)
|
||||
m.data_labels = data['Y'].argmax(axis=1)
|
||||
|
||||
# optimize
|
||||
|
|
|
|||
|
|
@ -9,9 +9,13 @@ class lvm:
|
|||
if isinstance(latent_axes,mpl.axes.Axes):
|
||||
self.cid = latent_axes.figure.canvas.mpl_connect('button_press_event', self.on_click)
|
||||
self.cid = latent_axes.figure.canvas.mpl_connect('motion_notify_event', self.on_move)
|
||||
self.cid = latent_axes.figure.canvas.mpl_connect('axes_leave_event', self.on_leave)
|
||||
self.cid = latent_axes.figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
|
||||
else:
|
||||
self.cid = latent_axes[0].figure.canvas.mpl_connect('button_press_event', self.on_click)
|
||||
self.cid = latent_axes[0].figure.canvas.mpl_connect('motion_notify_event', self.on_move)
|
||||
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_leave_event', self.on_leave)
|
||||
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
|
||||
self.data_visualize = data_visualize
|
||||
self.model = model
|
||||
self.latent_axes = latent_axes
|
||||
|
|
@ -21,6 +25,11 @@ class lvm:
|
|||
self.latent_index = latent_index
|
||||
self.latent_dim = model.Q
|
||||
|
||||
def on_enter(self,event):
|
||||
pass
|
||||
def on_leave(self,event):
|
||||
pass
|
||||
|
||||
def on_click(self, event):
|
||||
#print 'click', event.xdata, event.ydata
|
||||
if event.inaxes!=self.latent_axes: return
|
||||
|
|
@ -68,23 +77,25 @@ class lvm_dimselect(lvm):
|
|||
|
||||
lvm.__init__(self,model,data_visualize,latent_axes,[0,1])
|
||||
self.latent_values_clicked = np.zeros(model.Q)
|
||||
self._first_index_next = False
|
||||
self.clicked_handle = self.latent_axes.plot([0],[0],'rx',mew=2)[0]
|
||||
print "use left and right mouse butons to select dimensions"
|
||||
|
||||
def on_click(self, event):
|
||||
#print "click"
|
||||
if event.inaxes==self.hist_axes:
|
||||
self.hist_axes.cla()
|
||||
self.hist_axes.bar(np.arange(self.model.Q),1./self.model.input_sensitivity(),color='b')
|
||||
new_index = int(np.round(event.xdata))
|
||||
self.latent_index[int(self._first_index_next)] = new_index
|
||||
self._first_index_next = not self._first_index_next
|
||||
new_index = max(0,min(int(np.round(event.xdata-0.5)),self.model.Q-1))
|
||||
self.latent_index[(0 if event.button==1 else 1)] = new_index
|
||||
self.hist_axes.bar(np.array(self.latent_index),1./self.model.input_sensitivity()[self.latent_index],color='r')
|
||||
self.latent_axes.cla()
|
||||
self.model.plot_latent(which_indices = self.latent_index,ax=self.latent_axes)
|
||||
self.fig.canvas.draw()
|
||||
self.clicked_handle = self.latent_axes.plot([self.latent_values_clicked[self.latent_index[0]]],self.latent_values_clicked[self.latent_index[1]],'rx',mew=2)[0]
|
||||
if event.inaxes==self.latent_axes:
|
||||
#self.latent_values_clicked[self.latent_index] = np.array([event.xdata,event.ydata])
|
||||
pass
|
||||
self.clicked_handle.set_visible(False)
|
||||
self.latent_values_clicked[self.latent_index] = np.array([event.xdata,event.ydata])
|
||||
self.clicked_handle = self.latent_axes.plot([self.latent_values_clicked[self.latent_index[0]]],self.latent_values_clicked[self.latent_index[1]],'rx',mew=2)[0]
|
||||
self.fig.canvas.draw()
|
||||
self.move_on=True
|
||||
self.called = True
|
||||
|
||||
|
|
@ -98,6 +109,14 @@ class lvm_dimselect(lvm):
|
|||
y = self.model.predict(latent_values[None,:])[0]
|
||||
self.data_visualize.modify(y)
|
||||
|
||||
def on_leave(self,event):
|
||||
latent_values = self.latent_values_clicked.copy()
|
||||
y = self.model.predict(latent_values[None,:])[0]
|
||||
self.data_visualize.modify(y)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class data_show:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue