mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-04-28 06:16:24 +02:00
more work on visualize
This commit is contained in:
parent
a42ec441e4
commit
738d53d4bc
2 changed files with 120 additions and 53 deletions
|
|
@ -60,7 +60,7 @@ class GPLVM(GP):
|
||||||
mu, var, upper, lower = self.predict(Xnew)
|
mu, var, upper, lower = self.predict(Xnew)
|
||||||
pb.plot(mu[:,0], mu[:,1],'k',linewidth=1.5)
|
pb.plot(mu[:,0], mu[:,1],'k',linewidth=1.5)
|
||||||
|
|
||||||
def plot_latent(self,labels=None, which_indices=None, resolution=50):
|
def plot_latent(self,labels=None, which_indices=None, resolution=50,ax=pb.gca()):
|
||||||
"""
|
"""
|
||||||
:param labels: a np.array of size self.N containing labels for the points (can be number, strings, etc)
|
:param labels: a np.array of size self.N containing labels for the points (can be number, strings, etc)
|
||||||
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
||||||
|
|
@ -89,8 +89,8 @@ class GPLVM(GP):
|
||||||
Xtest_full = np.zeros((Xtest.shape[0], self.X.shape[1]))
|
Xtest_full = np.zeros((Xtest.shape[0], self.X.shape[1]))
|
||||||
Xtest_full[:, :2] = Xtest
|
Xtest_full[:, :2] = Xtest
|
||||||
mu, var, low, up = self.predict(Xtest_full)
|
mu, var, low, up = self.predict(Xtest_full)
|
||||||
var = var[:, :1]
|
var = var[:, :1]
|
||||||
pb.imshow(var.reshape(resolution,resolution).T[::-1,:],
|
ax.imshow(var.reshape(resolution,resolution).T[::-1,:],
|
||||||
extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear')
|
extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear')
|
||||||
|
|
||||||
for i,ul in enumerate(np.unique(labels)):
|
for i,ul in enumerate(np.unique(labels)):
|
||||||
|
|
@ -108,17 +108,16 @@ class GPLVM(GP):
|
||||||
else:
|
else:
|
||||||
x = self.X[index,input_1]
|
x = self.X[index,input_1]
|
||||||
y = self.X[index,input_2]
|
y = self.X[index,input_2]
|
||||||
pb.plot(x,y,marker='o',color=util.plot.Tango.nextMedium(),mew=0,label=this_label,linewidth=0)
|
ax.plot(x,y,marker='o',color=util.plot.Tango.nextMedium(),mew=0,label=this_label,linewidth=0)
|
||||||
|
|
||||||
pb.xlabel('latent dimension %i'%input_1)
|
ax.set_xlabel('latent dimension %i'%input_1)
|
||||||
pb.ylabel('latent dimension %i'%input_2)
|
ax.set_ylabel('latent dimension %i'%input_2)
|
||||||
|
|
||||||
if not np.all(labels==1.):
|
if not np.all(labels==1.):
|
||||||
pb.legend(loc=0,numpoints=1)
|
ax.legend(loc=0,numpoints=1)
|
||||||
|
|
||||||
pb.xlim(xmin[0],xmax[0])
|
ax.set_xlim(xmin[0],xmax[0])
|
||||||
pb.ylim(xmin[1],xmax[1])
|
ax.set_ylim(xmin[1],xmax[1])
|
||||||
pb.grid(b=False) # remove the grid if present, it doesn't look good
|
ax.grid(b=False) # remove the grid if present, it doesn't look good
|
||||||
ax = pb.gca()
|
|
||||||
ax.set_aspect('auto') # set a nice aspect ratio
|
ax.set_aspect('auto') # set a nice aspect ratio
|
||||||
return ax
|
return ax
|
||||||
|
|
|
||||||
|
|
@ -2,22 +2,28 @@ import matplotlib.pyplot as plt
|
||||||
from mpl_toolkits.mplot3d import Axes3D
|
from mpl_toolkits.mplot3d import Axes3D
|
||||||
import GPy
|
import GPy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import matplotlib as mpl
|
||||||
|
|
||||||
class lvm:
|
class lvm:
|
||||||
def __init__(self, model, data_visualize, latent_axis, latent_index=[0,1], latent_dim=2):
|
def __init__(self, model, data_visualize, latent_axes, latent_index=[0,1]):
|
||||||
self.cid = latent_axis.figure.canvas.mpl_connect('button_press_event', self.on_click)
|
if isinstance(latent_axes,mpl.axes.Axes):
|
||||||
self.cid = latent_axis.figure.canvas.mpl_connect('motion_notify_event', self.on_move)
|
self.cid = latent_axes.figure.canvas.mpl_connect('button_press_event', self.on_click)
|
||||||
|
self.cid = latent_axes.figure.canvas.mpl_connect('motion_notify_event', self.on_move)
|
||||||
|
else:
|
||||||
|
self.cid = latent_axes[0].figure.canvas.mpl_connect('button_press_event', self.on_click)
|
||||||
|
self.cid = latent_axes[0].figure.canvas.mpl_connect('motion_notify_event', self.on_move)
|
||||||
self.data_visualize = data_visualize
|
self.data_visualize = data_visualize
|
||||||
self.model = model
|
self.model = model
|
||||||
self.latent_axis = latent_axis
|
self.latent_axes = latent_axes
|
||||||
|
|
||||||
self.called = False
|
self.called = False
|
||||||
self.move_on = False
|
self.move_on = False
|
||||||
self.latent_index = latent_index
|
self.latent_index = latent_index
|
||||||
self.latent_dim = latent_dim
|
self.latent_dim = model.Q
|
||||||
|
|
||||||
def on_click(self, event):
|
def on_click(self, event):
|
||||||
#print 'click', event.xdata, event.ydata
|
#print 'click', event.xdata, event.ydata
|
||||||
if event.inaxes!=self.latent_axis: return
|
if event.inaxes!=self.latent_axes: return
|
||||||
self.move_on = not self.move_on
|
self.move_on = not self.move_on
|
||||||
# if self.called:
|
# if self.called:
|
||||||
# self.xs.append(event.xdata)
|
# self.xs.append(event.xdata)
|
||||||
|
|
@ -27,10 +33,10 @@ class lvm:
|
||||||
# else:
|
# else:
|
||||||
# self.xs = [event.xdata]
|
# self.xs = [event.xdata]
|
||||||
# self.ys = [event.ydata]
|
# self.ys = [event.ydata]
|
||||||
# self.line, = self.latent_axis.plot(event.xdata, event.ydata)
|
# self.line, = self.latent_axes.plot(event.xdata, event.ydata)
|
||||||
self.called = True
|
self.called = True
|
||||||
def on_move(self, event):
|
def on_move(self, event):
|
||||||
if event.inaxes!=self.latent_axis: return
|
if event.inaxes!=self.latent_axes: return
|
||||||
if self.called and self.move_on:
|
if self.called and self.move_on:
|
||||||
# Call modify code on move
|
# Call modify code on move
|
||||||
#print 'move', event.xdata, event.ydata
|
#print 'move', event.xdata, event.ydata
|
||||||
|
|
@ -40,52 +46,114 @@ class lvm:
|
||||||
self.data_visualize.modify(y)
|
self.data_visualize.modify(y)
|
||||||
#print 'y', y
|
#print 'y', y
|
||||||
|
|
||||||
class data_show:
|
class lvm_subplots(lvm):
|
||||||
"""The data show class is a base class which describes how to visualize a particular data set. For example, motion capture data can be plotted as a stick figure, or images are shown using imshow. This class enables latent to data visualizations for the GP-LVM."""
|
"""
|
||||||
|
latent_axes is a np array of dimension np.ceil(Q/2) + 1,
|
||||||
|
one for each pair of the axes, and the last one for the sensitiity histogram
|
||||||
|
"""
|
||||||
|
def __init__(self, model, data_visualize, latent_axes=None, latent_index=[0,1]):
|
||||||
|
self.nplots = int(np.ceil(model.Q/2.))+1
|
||||||
|
lvm.__init__(self,model,data_visualize,latent_axes,latent_index)
|
||||||
|
self.latent_values = np.zeros(2*np.ceil(self.model.Q/2.)) # possibly an extra dimension on this
|
||||||
|
assert latent_axes.size == self.nplots
|
||||||
|
|
||||||
def __init__(self, vals, axis=None):
|
|
||||||
|
class lvm_dimselect(lvm):
|
||||||
|
"""
|
||||||
|
A visualizer for latent variable models
|
||||||
|
with selection by clicking on the histogram
|
||||||
|
"""
|
||||||
|
def __init__(self, model, data_visualize):
|
||||||
|
self.fig,(latent_axes,self.hist_axes) = plt.subplots(1,2)
|
||||||
|
|
||||||
|
lvm.__init__(self,model,data_visualize,latent_axes,[0,1])
|
||||||
|
self.latent_values_clicked = np.zeros(model.Q)
|
||||||
|
self._first_index_next = False
|
||||||
|
|
||||||
|
def on_click(self, event):
|
||||||
|
#print "click"
|
||||||
|
if event.inaxes==self.hist_axes:
|
||||||
|
self.hist_axes.cla()
|
||||||
|
self.hist_axes.bar(np.arange(self.model.Q),1./self.model.input_sensitivity(),color='b')
|
||||||
|
new_index = int(np.round(event.xdata))
|
||||||
|
self.latent_index[int(self._first_index_next)] = new_index
|
||||||
|
self._first_index_next = not self._first_index_next
|
||||||
|
self.hist_axes.bar(np.array(self.latent_index),1./self.model.input_sensitivity()[self.latent_index],color='r')
|
||||||
|
self.latent_axes.cla()
|
||||||
|
self.model.plot_latent(which_indices = self.latent_index,ax=self.latent_axes)
|
||||||
|
self.fig.canvas.draw()
|
||||||
|
if event.inaxes==self.latent_axes:
|
||||||
|
#self.latent_values_clicked[self.latent_index] = np.array([event.xdata,event.ydata])
|
||||||
|
pass
|
||||||
|
self.move_on=True
|
||||||
|
self.called = True
|
||||||
|
|
||||||
|
|
||||||
|
def on_move(self, event):
|
||||||
|
#print "move"
|
||||||
|
if event.inaxes!=self.latent_axes: return
|
||||||
|
if self.called and self.move_on:
|
||||||
|
latent_values = self.latent_values_clicked.copy()
|
||||||
|
latent_values[self.latent_index] = np.array([event.xdata, event.ydata])
|
||||||
|
y = self.model.predict(latent_values[None,:])[0]
|
||||||
|
self.data_visualize.modify(y)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class data_show:
|
||||||
|
"""
|
||||||
|
The data show class is a base class which describes how to visualize a
|
||||||
|
particular data set. For example, motion capture data can be plotted as a
|
||||||
|
stick figure, or images are shown using imshow. This class enables latent
|
||||||
|
to data visualizations for the GP-LVM.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, vals, axes=None):
|
||||||
self.vals = vals
|
self.vals = vals
|
||||||
# If no axes are defined, create some.
|
# If no axes are defined, create some.
|
||||||
if axis==None:
|
if axes==None:
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
self.axis = fig.add_subplot(111)
|
self.axes = fig.add_subplot(111)
|
||||||
else:
|
else:
|
||||||
self.axis = axis
|
self.axes = axes
|
||||||
|
|
||||||
def modify(self, vals):
|
def modify(self, vals):
|
||||||
raise NotImplementedError, "this needs to be implemented to use the data_show class"
|
raise NotImplementedError, "this needs to be implemented to use the data_show class"
|
||||||
|
|
||||||
class vector_show(data_show):
|
class vector_show(data_show):
|
||||||
"""A base visualization class that just shows a data vector as a plot of vector elements alongside their indices."""
|
"""
|
||||||
def __init__(self, vals, axis=None):
|
A base visualization class that just shows a data vector as a plot of
|
||||||
data_show.__init__(self, vals, axis)
|
vector elements alongside their indices.
|
||||||
|
"""
|
||||||
|
def __init__(self, vals, axes=None):
|
||||||
|
data_show.__init__(self, vals, axes)
|
||||||
self.vals = vals.T
|
self.vals = vals.T
|
||||||
self.handle = self.axis.plot(np.arange(0, len(vals))[:, None], self.vals)[0]
|
self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals)[0]
|
||||||
|
|
||||||
def modify(self, vals):
|
def modify(self, vals):
|
||||||
xdata, ydata = self.handle.get_data()
|
xdata, ydata = self.handle.get_data()
|
||||||
self.vals = vals.T
|
self.vals = vals.T
|
||||||
self.handle.set_data(xdata, self.vals)
|
self.handle.set_data(xdata, self.vals)
|
||||||
self.axis.figure.canvas.draw()
|
self.axes.figure.canvas.draw()
|
||||||
|
|
||||||
class image_show(data_show):
|
class image_show(data_show):
|
||||||
"""Show a data vector as an image."""
|
"""Show a data vector as an image."""
|
||||||
def __init__(self, vals, axis=None, dimensions=(16,16), transpose=False, invert=False, scale=False):
|
def __init__(self, vals, axes=None, dimensions=(16,16), transpose=False, invert=False, scale=False):
|
||||||
data_show.__init__(self, vals, axis)
|
data_show.__init__(self, vals, axes)
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
self.transpose = transpose
|
self.transpose = transpose
|
||||||
self.invert = invert
|
self.invert = invert
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.set_image(vals/255.)
|
self.set_image(vals/255.)
|
||||||
self.handle = self.axis.imshow(self.vals, cmap=plt.cm.gray, interpolation='nearest')
|
self.handle = self.axes.imshow(self.vals, cmap=plt.cm.gray, interpolation='nearest')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
def modify(self, vals):
|
def modify(self, vals):
|
||||||
self.set_image(vals/255.)
|
self.set_image(vals/255.)
|
||||||
#self.handle.remove()
|
#self.handle.remove()
|
||||||
#self.handle = self.axis.imshow(self.vals)
|
#self.handle = self.axes.imshow(self.vals)
|
||||||
self.handle.set_array(self.vals)
|
self.handle.set_array(self.vals)
|
||||||
#self.axis.figure.canvas.draw()
|
#self.axes.figure.canvas.draw()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
def set_image(self, vals):
|
def set_image(self, vals):
|
||||||
|
|
@ -100,21 +168,21 @@ class image_show(data_show):
|
||||||
class stick_show(data_show):
|
class stick_show(data_show):
|
||||||
"""Show a three dimensional point cloud as a figure. Connect elements of the figure together using the matrix connect."""
|
"""Show a three dimensional point cloud as a figure. Connect elements of the figure together using the matrix connect."""
|
||||||
|
|
||||||
def __init__(self, vals, axis=None, connect=None):
|
def __init__(self, vals, axes=None, connect=None):
|
||||||
if axis==None:
|
if axes==None:
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
axis = fig.add_subplot(111, projection='3d')
|
axes = fig.add_subplot(111, projection='3d')
|
||||||
data_show.__init__(self, vals, axis)
|
data_show.__init__(self, vals, axes)
|
||||||
self.vals = vals.reshape((3, vals.shape[1]/3)).T
|
self.vals = vals.reshape((3, vals.shape[1]/3)).T
|
||||||
self.x_lim = np.array([self.vals[:, 0].min(), self.vals[:, 0].max()])
|
self.x_lim = np.array([self.vals[:, 0].min(), self.vals[:, 0].max()])
|
||||||
self.y_lim = np.array([self.vals[:, 1].min(), self.vals[:, 1].max()])
|
self.y_lim = np.array([self.vals[:, 1].min(), self.vals[:, 1].max()])
|
||||||
self.z_lim = np.array([self.vals[:, 2].min(), self.vals[:, 2].max()])
|
self.z_lim = np.array([self.vals[:, 2].min(), self.vals[:, 2].max()])
|
||||||
self.points_handle = self.axis.scatter(self.vals[:, 0], self.vals[:, 1], self.vals[:, 2])
|
self.points_handle = self.axes.scatter(self.vals[:, 0], self.vals[:, 1], self.vals[:, 2])
|
||||||
self.axis.set_xlim(self.x_lim)
|
self.axes.set_xlim(self.x_lim)
|
||||||
self.axis.set_ylim(self.y_lim)
|
self.axes.set_ylim(self.y_lim)
|
||||||
self.axis.set_zlim(self.z_lim)
|
self.axes.set_zlim(self.z_lim)
|
||||||
self.axis.set_aspect(1)
|
self.axes.set_aspect(1)
|
||||||
self.axis.autoscale(enable=False)
|
self.axes.autoscale(enable=False)
|
||||||
|
|
||||||
self.connect = connect
|
self.connect = connect
|
||||||
if not self.connect==None:
|
if not self.connect==None:
|
||||||
|
|
@ -132,17 +200,17 @@ class stick_show(data_show):
|
||||||
z.append(self.vals[self.I[i], 2])
|
z.append(self.vals[self.I[i], 2])
|
||||||
z.append(self.vals[self.J[i], 2])
|
z.append(self.vals[self.J[i], 2])
|
||||||
z.append(np.NaN)
|
z.append(np.NaN)
|
||||||
self.line_handle = self.axis.plot(np.array(x), np.array(y), np.array(z), 'b-')
|
self.line_handle = self.axes.plot(np.array(x), np.array(y), np.array(z), 'b-')
|
||||||
self.axis.figure.canvas.draw()
|
self.axes.figure.canvas.draw()
|
||||||
|
|
||||||
def modify(self, vals):
|
def modify(self, vals):
|
||||||
self.points_handle.remove()
|
self.points_handle.remove()
|
||||||
self.line_handle[0].remove()
|
self.line_handle[0].remove()
|
||||||
self.vals = vals.reshape((3, vals.shape[1]/3)).T
|
self.vals = vals.reshape((3, vals.shape[1]/3)).T
|
||||||
self.points_handle = self.axis.scatter(self.vals[:, 0], self.vals[:, 1], self.vals[:, 2])
|
self.points_handle = self.axes.scatter(self.vals[:, 0], self.vals[:, 1], self.vals[:, 2])
|
||||||
self.axis.set_xlim(self.x_lim)
|
self.axes.set_xlim(self.x_lim)
|
||||||
self.axis.set_ylim(self.y_lim)
|
self.axes.set_ylim(self.y_lim)
|
||||||
self.axis.set_zlim(self.z_lim)
|
self.axes.set_zlim(self.z_lim)
|
||||||
self.line_handle = []
|
self.line_handle = []
|
||||||
if not self.connect==None:
|
if not self.connect==None:
|
||||||
x = []
|
x = []
|
||||||
|
|
@ -159,9 +227,9 @@ class stick_show(data_show):
|
||||||
z.append(self.vals[self.I[i], 2])
|
z.append(self.vals[self.I[i], 2])
|
||||||
z.append(self.vals[self.J[i], 2])
|
z.append(self.vals[self.J[i], 2])
|
||||||
z.append(np.NaN)
|
z.append(np.NaN)
|
||||||
self.line_handle = self.axis.plot(np.array(x), np.array(y), np.array(z), 'b-')
|
self.line_handle = self.axes.plot(np.array(x), np.array(y), np.array(z), 'b-')
|
||||||
|
|
||||||
self.axis.figure.canvas.draw()
|
self.axes.figure.canvas.draw()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue