Added visualization for motion capture data using python visual module.

This commit is contained in:
Neil Lawrence 2013-06-06 04:55:01 +01:00
parent eb5f2ff5f0
commit cd6c28bc6a

View file

@ -5,6 +5,7 @@ import numpy as np
import matplotlib as mpl import matplotlib as mpl
import time import time
import Image import Image
import visual
class data_show: class data_show:
""" """
@ -13,26 +14,35 @@ class data_show:
stick figure, or images are shown using imshow. This class enables latent stick figure, or images are shown using imshow. This class enables latent
to data visualizations for the GP-LVM. to data visualizations for the GP-LVM.
""" """
def __init__(self, vals):
def __init__(self, vals, axes=None):
self.vals = vals.copy() self.vals = vals.copy()
# If no axes are defined, create some. # If no axes are defined, create some.
def modify(self, vals):
raise NotImplementedError, "this needs to be implemented to use the data_show class"
class matplotlib_show(data_show):
"""
the matplotlib_show class is a base class for all visualization methods that use matplotlib. It is initialized with an axis. If the axis is set to None it creates a figure window.
"""
def __init__(self, vals, axes=None):
data_show.__init__(self, vals)
# If no axes are defined, create some.
if axes==None: if axes==None:
fig = plt.figure() fig = plt.figure()
self.axes = fig.add_subplot(111) self.axes = fig.add_subplot(111)
else: else:
self.axes = axes self.axes = axes
def modify(self, vals): class vector_show(matplotlib_show):
raise NotImplementedError, "this needs to be implemented to use the data_show class"
class vector_show(data_show):
""" """
A base visualization class that just shows a data vector as a plot of A base visualization class that just shows a data vector as a plot of
vector elements alongside their indices. vector elements alongside their indices.
""" """
def __init__(self, vals, axes=None): def __init__(self, vals, axes=None):
data_show.__init__(self, vals, axes) matplotlib_show.__init__(self, vals, axes)
self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals.T)[0] self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals.T)[0]
def modify(self, vals): def modify(self, vals):
@ -42,7 +52,7 @@ class vector_show(data_show):
self.axes.figure.canvas.draw() self.axes.figure.canvas.draw()
class lvm(data_show): class lvm(matplotlib_show):
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]): def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
"""Visualize a latent variable model """Visualize a latent variable model
@ -54,7 +64,7 @@ class lvm(data_show):
if vals == None: if vals == None:
vals = model.X[0] vals = model.X[0]
data_show.__init__(self, vals, axes=latent_axes) matplotlib_show.__init__(self, vals, axes=latent_axes)
if isinstance(latent_axes,mpl.axes.Axes): if isinstance(latent_axes,mpl.axes.Axes):
self.cid = latent_axes.figure.canvas.mpl_connect('button_press_event', self.on_click) self.cid = latent_axes.figure.canvas.mpl_connect('button_press_event', self.on_click)
@ -204,10 +214,10 @@ class lvm_dimselect(lvm):
class image_show(data_show): class image_show(matplotlib_show):
"""Show a data vector as an image.""" """Show a data vector as an image."""
def __init__(self, vals, axes=None, dimensions=(16,16), transpose=False, invert=False, scale=False, palette=[], presetMean = 0., presetSTD = -1., selectImage=0): def __init__(self, vals, axes=None, dimensions=(16,16), transpose=False, invert=False, scale=False, palette=[], presetMean = 0., presetSTD = -1., selectImage=0):
data_show.__init__(self, vals, axes) matplotlib_show.__init__(self, vals, axes)
self.dimensions = dimensions self.dimensions = dimensions
self.transpose = transpose self.transpose = transpose
self.invert = invert self.invert = invert
@ -266,14 +276,72 @@ class image_show(data_show):
self.vals = Image.fromarray(self.vals.astype('uint8')) self.vals = Image.fromarray(self.vals.astype('uint8'))
self.vals.putpalette(self.palette) # palette is a list, must be loaded before calling this function self.vals.putpalette(self.palette) # palette is a list, must be loaded before calling this function
class mocap_data_show(data_show): class mocap_data_show_visual(data_show):
"""Base class for visualizing motion capture data using visual module."""
def __init__(self, vals, connect=None, radius=0.1):
data_show.__init__(self, vals)
self.radius = radius
self.connect = connect
self.process_values()
self.draw_edges()
self.draw_vertices()
def draw_vertices(self):
self.spheres = []
for i in range(self.vals.shape[0]):
self.spheres.append(visual.sphere(pos=(self.vals[i, 0], self.vals[i, 2], self.vals[i, 1]), radius=self.radius))
def draw_edges(self):
self.rods = []
self.line_handle = []
if not self.connect==None:
self.I, self.J = np.nonzero(self.connect)
for i, j in zip(self.I, self.J):
pos, axis = self.pos_axis(i, j)
self.rods.append(visual.cylinder(pos=pos, axis=axis, radius=self.radius))
def modify_vertices(self):
for i in range(self.vals.shape[0]):
self.spheres[i].pos = (self.vals[i, 0], self.vals[i, 2], self.vals[i, 1])
def modify_edges(self):
self.line_handle = []
if not self.connect==None:
self.I, self.J = np.nonzero(self.connect)
for rod, i, j in zip(self.rods, self.I, self.J):
rod.pos, rod.axis = self.pos_axis(i, j)
def pos_axis(self, i, j):
pos = []
axis = []
pos.append(self.vals[i, 0])
axis.append(self.vals[j, 0]-self.vals[i,0])
pos.append(self.vals[i, 2])
axis.append(self.vals[j, 2]-self.vals[i,2])
pos.append(self.vals[i, 1])
axis.append(self.vals[j, 1]-self.vals[i,1])
return pos, axis
def modify(self, vals):
self.vals = vals.copy()
self.process_values()
self.modify_edges()
self.modify_vertices()
def process_values(self):
raise NotImplementedError, "this needs to be implemented to use the data_show class"
class mocap_data_show(matplotlib_show):
"""Base class for visualizing motion capture data.""" """Base class for visualizing motion capture data."""
def __init__(self, vals, axes=None, connect=None): def __init__(self, vals, axes=None, connect=None):
if axes==None: if axes==None:
fig = plt.figure() fig = plt.figure()
axes = fig.add_subplot(111, projection='3d') axes = fig.add_subplot(111, projection='3d')
data_show.__init__(self, vals, axes) matplotlib_show.__init__(self, vals, axes)
self.connect = connect self.connect = connect
self.process_values() self.process_values()
@ -342,17 +410,17 @@ class mocap_data_show(data_show):
self.axes.set_zlim(self.z_lim) self.axes.set_zlim(self.z_lim)
class stick_show(mocap_data_show): class stick_show(mocap_data_show_visual):
"""Show a three dimensional point cloud as a figure. Connect elements of the figure together using the matrix connect.""" """Show a three dimensional point cloud as a figure. Connect elements of the figure together using the matrix connect."""
def __init__(self, vals, axes=None, connect=None): def __init__(self, vals, connect=None):
mocap_data_show.__init__(self, vals, axes, connect) mocap_data_show_visual.__init__(self, vals, connect, radius=0.04)
def process_values(self): def process_values(self):
self.vals = self.vals.reshape((3, self.vals.shape[1]/3)).T self.vals = self.vals.reshape((3, self.vals.shape[1]/3)).T
class skeleton_show(mocap_data_show): class skeleton_show(mocap_data_show_visual):
"""data_show class for visualizing motion capture data encoded as a skeleton with angles.""" """data_show class for visualizing motion capture data encoded as a skeleton with angles."""
def __init__(self, vals, skel, padding=0, axes=None): def __init__(self, vals, skel, padding=0):
"""data_show class for visualizing motion capture data encoded as a skeleton with angles. """data_show class for visualizing motion capture data encoded as a skeleton with angles.
:param vals: set of modeled angles to use for printing in the axis when it's first created. :param vals: set of modeled angles to use for printing in the axis when it's first created.
:type vals: np.array :type vals: np.array
@ -364,8 +432,7 @@ class skeleton_show(mocap_data_show):
self.skel = skel self.skel = skel
self.padding = padding self.padding = padding
connect = skel.connection_matrix() connect = skel.connection_matrix()
mocap_data_show.__init__(self, vals, axes, connect) mocap_data_show_visual.__init__(self, vals, connect, radius=0.4)
def process_values(self): def process_values(self):
"""Takes a set of angles and converts them to the x,y,z coordinates in the internal prepresentation of the class, ready for plotting. """Takes a set of angles and converts them to the x,y,z coordinates in the internal prepresentation of the class, ready for plotting.