From cd6c28bc6a07b349cc5fd5d2f2c8e5e1d4c7ab35 Mon Sep 17 00:00:00 2001 From: Neil Lawrence Date: Thu, 6 Jun 2013 04:55:01 +0100 Subject: [PATCH] Added visualization for motion capture data using python visual module. --- GPy/util/visualize.py | 107 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 87 insertions(+), 20 deletions(-) diff --git a/GPy/util/visualize.py b/GPy/util/visualize.py index b3429850..684bb0ce 100644 --- a/GPy/util/visualize.py +++ b/GPy/util/visualize.py @@ -5,6 +5,7 @@ import numpy as np import matplotlib as mpl import time import Image +import visual class data_show: """ @@ -13,26 +14,35 @@ class data_show: stick figure, or images are shown using imshow. This class enables latent to data visualizations for the GP-LVM. """ - - def __init__(self, vals, axes=None): + def __init__(self, vals): self.vals = vals.copy() # If no axes are defined, create some. + + def modify(self, vals): + raise NotImplementedError, "this needs to be implemented to use the data_show class" + + +class matplotlib_show(data_show): + """ + the matplotlib_show class is a base class for all visualization methods that use matplotlib. It is initialized with an axis. If the axis is set to None it creates a figure window. + """ + def __init__(self, vals, axes=None): + data_show.__init__(self, vals) + # If no axes are defined, create some. + if axes==None: fig = plt.figure() self.axes = fig.add_subplot(111) else: self.axes = axes - def modify(self, vals): - raise NotImplementedError, "this needs to be implemented to use the data_show class" - -class vector_show(data_show): +class vector_show(matplotlib_show): """ A base visualization class that just shows a data vector as a plot of vector elements alongside their indices. """ def __init__(self, vals, axes=None): - data_show.__init__(self, vals, axes) + matplotlib_show.__init__(self, vals, axes) self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals.T)[0] def modify(self, vals): @@ -42,7 +52,7 @@ class vector_show(data_show): self.axes.figure.canvas.draw() -class lvm(data_show): +class lvm(matplotlib_show): def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]): """Visualize a latent variable model @@ -54,7 +64,7 @@ class lvm(data_show): if vals == None: vals = model.X[0] - data_show.__init__(self, vals, axes=latent_axes) + matplotlib_show.__init__(self, vals, axes=latent_axes) if isinstance(latent_axes,mpl.axes.Axes): self.cid = latent_axes.figure.canvas.mpl_connect('button_press_event', self.on_click) @@ -204,10 +214,10 @@ class lvm_dimselect(lvm): -class image_show(data_show): +class image_show(matplotlib_show): """Show a data vector as an image.""" def __init__(self, vals, axes=None, dimensions=(16,16), transpose=False, invert=False, scale=False, palette=[], presetMean = 0., presetSTD = -1., selectImage=0): - data_show.__init__(self, vals, axes) + matplotlib_show.__init__(self, vals, axes) self.dimensions = dimensions self.transpose = transpose self.invert = invert @@ -266,14 +276,72 @@ class image_show(data_show): self.vals = Image.fromarray(self.vals.astype('uint8')) self.vals.putpalette(self.palette) # palette is a list, must be loaded before calling this function -class mocap_data_show(data_show): +class mocap_data_show_visual(data_show): + """Base class for visualizing motion capture data using visual module.""" + + def __init__(self, vals, connect=None, radius=0.1): + data_show.__init__(self, vals) + self.radius = radius + self.connect = connect + self.process_values() + self.draw_edges() + self.draw_vertices() + + def draw_vertices(self): + self.spheres = [] + for i in range(self.vals.shape[0]): + self.spheres.append(visual.sphere(pos=(self.vals[i, 0], self.vals[i, 2], self.vals[i, 1]), radius=self.radius)) + + + def draw_edges(self): + self.rods = [] + self.line_handle = [] + if not self.connect==None: + self.I, self.J = np.nonzero(self.connect) + for i, j in zip(self.I, self.J): + pos, axis = self.pos_axis(i, j) + self.rods.append(visual.cylinder(pos=pos, axis=axis, radius=self.radius)) + + def modify_vertices(self): + for i in range(self.vals.shape[0]): + self.spheres[i].pos = (self.vals[i, 0], self.vals[i, 2], self.vals[i, 1]) + + def modify_edges(self): + self.line_handle = [] + if not self.connect==None: + self.I, self.J = np.nonzero(self.connect) + for rod, i, j in zip(self.rods, self.I, self.J): + rod.pos, rod.axis = self.pos_axis(i, j) + + def pos_axis(self, i, j): + pos = [] + axis = [] + pos.append(self.vals[i, 0]) + axis.append(self.vals[j, 0]-self.vals[i,0]) + pos.append(self.vals[i, 2]) + axis.append(self.vals[j, 2]-self.vals[i,2]) + pos.append(self.vals[i, 1]) + axis.append(self.vals[j, 1]-self.vals[i,1]) + return pos, axis + + def modify(self, vals): + self.vals = vals.copy() + self.process_values() + self.modify_edges() + self.modify_vertices() + + def process_values(self): + raise NotImplementedError, "this needs to be implemented to use the data_show class" + + +class mocap_data_show(matplotlib_show): """Base class for visualizing motion capture data.""" def __init__(self, vals, axes=None, connect=None): if axes==None: fig = plt.figure() axes = fig.add_subplot(111, projection='3d') - data_show.__init__(self, vals, axes) + matplotlib_show.__init__(self, vals, axes) self.connect = connect self.process_values() @@ -342,17 +410,17 @@ class mocap_data_show(data_show): self.axes.set_zlim(self.z_lim) -class stick_show(mocap_data_show): +class stick_show(mocap_data_show_visual): """Show a three dimensional point cloud as a figure. Connect elements of the figure together using the matrix connect.""" - def __init__(self, vals, axes=None, connect=None): - mocap_data_show.__init__(self, vals, axes, connect) + def __init__(self, vals, connect=None): + mocap_data_show_visual.__init__(self, vals, connect, radius=0.04) def process_values(self): self.vals = self.vals.reshape((3, self.vals.shape[1]/3)).T -class skeleton_show(mocap_data_show): +class skeleton_show(mocap_data_show_visual): """data_show class for visualizing motion capture data encoded as a skeleton with angles.""" - def __init__(self, vals, skel, padding=0, axes=None): + def __init__(self, vals, skel, padding=0): """data_show class for visualizing motion capture data encoded as a skeleton with angles. :param vals: set of modeled angles to use for printing in the axis when it's first created. :type vals: np.array @@ -364,8 +432,7 @@ class skeleton_show(mocap_data_show): self.skel = skel self.padding = padding connect = skel.connection_matrix() - mocap_data_show.__init__(self, vals, axes, connect) - + mocap_data_show_visual.__init__(self, vals, connect, radius=0.4) def process_values(self): """Takes a set of angles and converts them to the x,y,z coordinates in the internal prepresentation of the class, ready for plotting.