import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D import GPy import numpy as np import matplotlib as mpl import time import Image class data_show: """ The data show class is a base class which describes how to visualize a particular data set. For example, motion capture data can be plotted as a stick figure, or images are shown using imshow. This class enables latent to data visualizations for the GP-LVM. """ def __init__(self, vals, axes=None): self.vals = vals.copy() # If no axes are defined, create some. if axes==None: fig = plt.figure() self.axes = fig.add_subplot(111) else: self.axes = axes def modify(self, vals): raise NotImplementedError, "this needs to be implemented to use the data_show class" class vector_show(data_show): """ A base visualization class that just shows a data vector as a plot of vector elements alongside their indices. """ def __init__(self, vals, axes=None): data_show.__init__(self, vals, axes) self.vals = vals.T.copy() self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals)[0] def modify(self, vals): xdata, ydata = self.handle.get_data() self.vals = vals.T.copy() self.handle.set_data(xdata, self.vals) self.axes.figure.canvas.draw() class lvm(data_show): def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]): """Visualize a latent variable model :param model: the latent variable model to visualize. :param data_visualize: the object used to visualize the data which has been modelled. :type data_visualize: visualize.data_show type. :param latent_axes: the axes where the latent visualization should be plotted. """ if vals == None: vals = model.X[0].copy() data_show.__init__(self, vals, axes=latent_axes) if isinstance(latent_axes,mpl.axes.Axes): self.cid = latent_axes.figure.canvas.mpl_connect('button_press_event', self.on_click) self.cid = latent_axes.figure.canvas.mpl_connect('motion_notify_event', self.on_move) self.cid = latent_axes.figure.canvas.mpl_connect('axes_leave_event', self.on_leave) self.cid = latent_axes.figure.canvas.mpl_connect('axes_enter_event', self.on_enter) else: self.cid = latent_axes[0].figure.canvas.mpl_connect('button_press_event', self.on_click) self.cid = latent_axes[0].figure.canvas.mpl_connect('motion_notify_event', self.on_move) self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_leave_event', self.on_leave) self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_enter_event', self.on_enter) self.data_visualize = data_visualize self.model = model self.latent_axes = latent_axes self.sense_axes = sense_axes self.called = False self.move_on = False self.latent_index = latent_index self.latent_dim = model.Q # The red cross which shows current latent point. self.latent_values = vals self.latent_handle = self.latent_axes.plot([0],[0],'rx',mew=2)[0] self.modify(vals) self.show_sensitivities() def modify(self, vals): """When latent values are modified update the latent representation and ulso update the output visualization.""" y = self.model.predict(vals)[0] print y self.data_visualize.modify(y) self.latent_handle.set_data(vals[self.latent_index[0]], vals[self.latent_index[1]]) self.axes.figure.canvas.draw() def on_enter(self,event): pass def on_leave(self,event): pass def on_click(self, event): if event.inaxes!=self.latent_axes: return self.move_on = not self.move_on self.called = True def on_move(self, event): if event.inaxes!=self.latent_axes: return if self.called and self.move_on: # Call modify code on move self.latent_values[self.latent_index[0]]=event.xdata self.latent_values[self.latent_index[1]]=event.ydata self.modify(self.latent_values) def show_sensitivities(self): # A click in the bar chart axis for selection a dimension. if self.sense_axes != None: self.sense_axes.cla() self.sense_axes.bar(np.arange(self.model.Q),1./self.model.input_sensitivity(),color='b') if self.latent_index[1] == self.latent_index[0]: self.sense_axes.bar(np.array(self.latent_index[0]),1./self.model.input_sensitivity()[self.latent_index[0]],color='y') self.sense_axes.bar(np.array(self.latent_index[1]),1./self.model.input_sensitivity()[self.latent_index[1]],color='y') else: self.sense_axes.bar(np.array(self.latent_index[0]),1./self.model.input_sensitivity()[self.latent_index[0]],color='g') self.sense_axes.bar(np.array(self.latent_index[1]),1./self.model.input_sensitivity()[self.latent_index[1]],color='r') self.sense_axes.figure.canvas.draw() class lvm_subplots(lvm): """ latent_axes is a np array of dimension np.ceil(Q/2), one for each pair of the latent dimensions. """ def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None): self.nplots = int(np.ceil(model.Q/2.))+1 assert len(latent_axes)==self.nplots if vals==None: vals = model.X[0, :] self.latent_values = vals for i, axis in enumerate(latent_axes): if i == self.nplots-1: if self.nplots*2!=model.Q: latent_index = [i*2, i*2] lvm.__init__(self, self.latent_vals, model, data_visualize, axis, sense_axes, latent_index=latent_index) else: latent_index = [i*2, i*2+1] lvm.__init__(self, self.latent_vals, model, data_visualize, axis, latent_index=latent_index) class lvm_dimselect(lvm): """ A visualizer for latent variable models which allows selection of the latent dimensions to use by clicking on a bar chart of their length scales. For an example of the visualizer's use try: GPy.examples.dimensionality_reduction.BGPVLM_oil() """ def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0, 1]): if latent_axes==None and sense_axes==None: self.fig,(latent_axes,self.sense_axes) = plt.subplots(1,2) elif sense_axes==None: fig=plt.figure() self.sense_axes = fig.add_subplot(111) else: self.sense_axes = sense_axes lvm.__init__(self,vals,model,data_visualize,latent_axes,sense_axes,latent_index) print "use left and right mouse butons to select dimensions" def on_click(self, event): if event.inaxes==self.sense_axes: new_index = max(0,min(int(np.round(event.xdata-0.5)),self.model.Q-1)) if event.button == 1: # Make it red if and y-axis (red=port=left) if it is a left button click self.latent_index[1] = new_index else: # Make it green and x-axis (green=starboard=right) if it is a right button click self.latent_index[0] = new_index self.show_sensitivities() self.latent_axes.cla() self.model.plot_latent(which_indices=self.latent_index, ax=self.latent_axes) self.latent_handle = self.latent_axes.plot([0],[0],'rx',mew=2)[0] self.modify(self.latent_values) elif event.inaxes==self.latent_axes: self.move_on = not self.move_on self.called = True def on_leave(self,event): latent_values = self.latent_values.copy() y = self.model.predict(latent_values[None,:])[0] self.data_visualize.modify(y) class image_show(data_show): """Show a data vector as an image.""" def __init__(self, vals, axes=None, dimensions=(16,16), transpose=False, invert=False, scale=False, palette=[], presetMean = 0., presetSTD = -1., selectImage=0): data_show.__init__(self, vals, axes) self.dimensions = dimensions self.transpose = transpose self.invert = invert self.scale = scale self.palette = palette self.presetMean = presetMean self.presetSTD = presetSTD self.selectImage = selectImage # This is used when the y vector contains multiple images concatenated. self.set_image(vals) if not self.palette == []: # Can just show the image (self.set_image() took care of setting the palette) self.handle = self.axes.imshow(self.vals, interpolation='nearest') else: # Use a boring gray map. self.handle = self.axes.imshow(self.vals, cmap=plt.cm.gray, interpolation='nearest') plt.show() def modify(self, vals): self.set_image(vals) self.handle.set_array(self.vals) self.axes.figure.canvas.draw() def set_image(self, vals): dim = self.dimensions[0] * self.dimensions[1] nImg = np.sqrt(vals[0,].size/dim) if nImg > 1 and nImg.is_integer(): # Show a mosaic of images nImg = np.int(nImg) self.vals = np.zeros((self.dimensions[0]*nImg, self.dimensions[1]*nImg)) for iR in range(nImg): for iC in range(nImg): currImgId = iR*nImg + iC currImg = np.reshape(vals[0,dim*currImgId+np.array(range(dim))], self.dimensions, order='F') firstRow = iR*self.dimensions[0] lastRow = (iR+1)*self.dimensions[0] firstCol = iC*self.dimensions[1] lastCol = (iC+1)*self.dimensions[1] self.vals[firstRow:lastRow, firstCol:lastCol] = currImg else: self.vals = np.reshape(vals[0,dim*self.selectImage+np.array(range(dim))], self.dimensions, order='F') if self.transpose: self.vals = self.vals.T.copy() # if not self.scale: # self.vals = self.vals if self.invert: self.vals = -self.vals # un-normalizing, for visualisation purposes: if self.presetSTD >= 0: # The Mean is assumed to be in the range (0,255) self.vals = self.vals*self.presetSTD + self.presetMean # Clipping the values: self.vals[self.vals < 0] = 0 self.vals[self.vals > 255] = 255 else: self.vals = 255*(self.vals - self.vals.min())/(self.vals.max() - self.vals.min()) if not self.palette == []: # applying using an image palette (e.g. if the image has been quantized) self.vals = Image.fromarray(self.vals.astype('uint8')) self.vals.putpalette(self.palette) # palette is a list, must be loaded before calling this function class mocap_data_show(data_show): """Base class for visualizing motion capture data.""" def __init__(self, vals, axes=None, connect=None): if axes==None: fig = plt.figure() axes = fig.add_subplot(111, projection='3d') data_show.__init__(self, vals, axes) self.connect = connect self.process_values(vals) self.initialize_axes() self.draw_vertices() self.finalize_axes() self.draw_edges() self.axes.figure.canvas.draw() def draw_vertices(self): self.points_handle = self.axes.scatter(self.vals[:, 0], self.vals[:, 1], self.vals[:, 2]) def draw_edges(self): self.line_handle = [] if not self.connect==None: x = [] y = [] z = [] self.I, self.J = np.nonzero(self.connect) for i, j in zip(self.I, self.J): x.append(self.vals[i, 0]) x.append(self.vals[j, 0]) x.append(np.NaN) y.append(self.vals[i, 1]) y.append(self.vals[j, 1]) y.append(np.NaN) z.append(self.vals[i, 2]) z.append(self.vals[j, 2]) z.append(np.NaN) self.line_handle = self.axes.plot(np.array(x), np.array(y), np.array(z), 'b-') def modify(self, vals): self.process_values(vals) self.initialize_axes_modify() self.draw_vertices() self.finalize_axes_modify() self.draw_edges() self.axes.figure.canvas.draw() def process_values(self, vals): raise NotImplementedError, "this needs to be implemented to use the data_show class" def initialize_axes(self): """Set up the axes with the right limits and scaling.""" self.x_lim = np.array([self.vals[:, 0].min(), self.vals[:, 0].max()]) self.y_lim = np.array([self.vals[:, 1].min(), self.vals[:, 1].max()]) self.z_lim = np.array([self.vals[:, 2].min(), self.vals[:, 2].max()]) def initialize_axes_modify(self): self.points_handle.remove() self.line_handle[0].remove() def finalize_axes(self): self.axes.set_xlim(self.x_lim) self.axes.set_ylim(self.y_lim) self.axes.set_zlim(self.z_lim) self.axes.set_aspect(1) self.axes.autoscale(enable=False) def finalize_axes_modify(self): self.axes.set_xlim(self.x_lim) self.axes.set_ylim(self.y_lim) self.axes.set_zlim(self.z_lim) class stick_show(mocap_data_show): """Show a three dimensional point cloud as a figure. Connect elements of the figure together using the matrix connect.""" def __init__(self, vals, axes=None, connect=None): mocap_data_show.__init__(self, vals, axes, connect) def process_values(self, vals): self.vals = vals.reshape((3, vals.shape[1]/3)).T.copy() class skeleton_show(mocap_data_show): """data_show class for visualizing motion capture data encoded as a skeleton with angles.""" def __init__(self, vals, skel, padding=0, axes=None): """data_show class for visualizing motion capture data encoded as a skeleton with angles. :param vals: set of modeled angles to use for printing in the axis when it's first created. :type vals: np.array :param skel: skeleton object that has the parameters of the motion capture skeleton associated with it. :type skel: mocap.skeleton object :param padding: :type int """ self.skel = skel self.padding = padding connect = skel.connection_matrix() mocap_data_show.__init__(self, vals, axes, connect) def process_values(self, vals): """Takes a set of angles and converts them to the x,y,z coordinates in the internal prepresentation of the class, ready for plotting. :param vals: the values that are being modelled.""" if self.padding>0: channels = np.zeros((vals.shape[0], vals.shape[1]+self.padding)) channels[:, 0:vals.shape[0]] = vals else: channels = vals vals_mat = self.skel.to_xyz(channels.flatten()) self.vals = vals_mat # Flip the Y and Z axes self.vals[:, 0] = vals_mat[:, 0] self.vals[:, 1] = vals_mat[:, 2] self.vals[:, 2] = vals_mat[:, 1] def wrap_around(vals, lim, connect): quot = lim[1] - lim[0] vals = rem(vals, quot)+lim[0] nVals = floor(vals/quot) for i in range(connect.shape[0]): for j in find(connect[i, :]): if nVals[i] != nVals[j]: connect[i, j] = False return vals, connect def data_play(Y, visualizer, frame_rate=30): """Play a data set using the data_show object given. :Y: the data set to be visualized. :param visualizer: the data show objectwhether to display during optimisation :type visualizer: data_show Example usage: This example loads in the CMU mocap database (http://mocap.cs.cmu.edu) subject number 35 motion number 01. It then plays it using the mocap_show visualize object. data = GPy.util.datasets.cmu_mocap(subject='35', train_motions=['01']) Y = data['Y'] Y[:, 0:3] = 0. # Make figure walk in place visualize = GPy.util.visualize.skeleton_show(Y[0, :], data['skel']) GPy.util.visualize.data_play(Y, visualize) """ for y in Y: visualizer.modify(y) time.sleep(1./float(frame_rate))