Neil's flailing attempts to update the flailing stick man.

This commit is contained in:
Neil Lawrence 2013-06-04 17:20:46 +01:00
parent 95d5bcc1b9
commit 500bd8f4b8
3 changed files with 33 additions and 30 deletions

View file

@ -184,7 +184,7 @@ class Bayesian_GPLVM(sparse_GP, GPLVM):
return self._clipped(np.hstack((self.dbound_dmuS.flatten(), self.dbound_dZtheta))) return self._clipped(np.hstack((self.dbound_dmuS.flatten(), self.dbound_dZtheta)))
def plot_latent(self, *args, **kwargs): def plot_latent(self, *args, **kwargs):
plot_latent.plot_latent_indices(self, *args, **kwargs) return plot_latent.plot_latent_indices(self, *args, **kwargs)
def do_test_latents(self, Y): def do_test_latents(self, Y):
""" """

View file

@ -63,4 +63,4 @@ class GPLVM(GP):
pb.plot(mu[:,0], mu[:,1],'k',linewidth=1.5) pb.plot(mu[:,0], mu[:,1],'k',linewidth=1.5)
def plot_latent(self, *args, **kwargs): def plot_latent(self, *args, **kwargs):
util.plot_latent.plot_latent(self, *args, **kwargs) return util.plot_latent.plot_latent(self, *args, **kwargs)

View file

@ -33,13 +33,12 @@ class vector_show(data_show):
""" """
def __init__(self, vals, axes=None): def __init__(self, vals, axes=None):
data_show.__init__(self, vals, axes) data_show.__init__(self, vals, axes)
self.vals = vals.T.copy() self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals.T)[0]
self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals)[0]
def modify(self, vals): def modify(self, vals):
self.vals = vals.copy()
xdata, ydata = self.handle.get_data() xdata, ydata = self.handle.get_data()
self.vals = vals.T.copy() self.handle.set_data(xdata, self.vals.T)
self.handle.set_data(xdata, self.vals)
self.axes.figure.canvas.draw() self.axes.figure.canvas.draw()
@ -53,7 +52,7 @@ class lvm(data_show):
:param latent_axes: the axes where the latent visualization should be plotted. :param latent_axes: the axes where the latent visualization should be plotted.
""" """
if vals == None: if vals == None:
vals = model.X[0].copy() vals = model.X[0]
data_show.__init__(self, vals, axes=latent_axes) data_show.__init__(self, vals, axes=latent_axes)
@ -85,9 +84,10 @@ class lvm(data_show):
def modify(self, vals): def modify(self, vals):
"""When latent values are modified update the latent representation and ulso update the output visualization.""" """When latent values are modified update the latent representation and ulso update the output visualization."""
y = self.model.predict(vals)[0] self.vals = vals.copy()
y = self.model.predict(self.vals)[0]
self.data_visualize.modify(y) self.data_visualize.modify(y)
self.latent_handle.set_data(vals[self.latent_index[0]], vals[self.latent_index[1]]) self.latent_handle.set_data(self.vals[self.latent_index[0]], self.vals[self.latent_index[1]])
self.axes.figure.canvas.draw() self.axes.figure.canvas.draw()
@ -217,7 +217,7 @@ class image_show(data_show):
self.presetSTD = presetSTD self.presetSTD = presetSTD
self.selectImage = selectImage # This is used when the y vector contains multiple images concatenated. self.selectImage = selectImage # This is used when the y vector contains multiple images concatenated.
self.set_image(vals) self.set_image(self.vals)
if not self.palette == []: # Can just show the image (self.set_image() took care of setting the palette) if not self.palette == []: # Can just show the image (self.set_image() took care of setting the palette)
self.handle = self.axes.imshow(self.vals, interpolation='nearest') self.handle = self.axes.imshow(self.vals, interpolation='nearest')
else: # Use a boring gray map. else: # Use a boring gray map.
@ -225,7 +225,7 @@ class image_show(data_show):
plt.show() plt.show()
def modify(self, vals): def modify(self, vals):
self.set_image(vals) self.set_image(vals.copy())
self.handle.set_array(self.vals) self.handle.set_array(self.vals)
self.axes.figure.canvas.draw() self.axes.figure.canvas.draw()
@ -248,7 +248,7 @@ class image_show(data_show):
else: else:
self.vals = np.reshape(vals[0,dim*self.selectImage+np.array(range(dim))], self.dimensions, order='F') self.vals = np.reshape(vals[0,dim*self.selectImage+np.array(range(dim))], self.dimensions, order='F')
if self.transpose: if self.transpose:
self.vals = self.vals.T.copy() self.vals = self.vals.T
# if not self.scale: # if not self.scale:
# self.vals = self.vals # self.vals = self.vals
if self.invert: if self.invert:
@ -276,7 +276,7 @@ class mocap_data_show(data_show):
data_show.__init__(self, vals, axes) data_show.__init__(self, vals, axes)
self.connect = connect self.connect = connect
self.process_values(vals) self.process_values()
self.initialize_axes() self.initialize_axes()
self.draw_vertices() self.draw_vertices()
self.finalize_axes() self.finalize_axes()
@ -306,14 +306,15 @@ class mocap_data_show(data_show):
self.line_handle = self.axes.plot(np.array(x), np.array(y), np.array(z), 'b-') self.line_handle = self.axes.plot(np.array(x), np.array(y), np.array(z), 'b-')
def modify(self, vals): def modify(self, vals):
self.process_values(vals) self.vals = vals.copy()
self.process_values()
self.initialize_axes_modify() self.initialize_axes_modify()
self.draw_vertices() self.draw_vertices()
self.finalize_axes_modify() self.finalize_axes_modify()
self.draw_edges() self.draw_edges()
self.axes.figure.canvas.draw() self.axes.figure.canvas.draw()
def process_values(self, vals): def process_values(self):
raise NotImplementedError, "this needs to be implemented to use the data_show class" raise NotImplementedError, "this needs to be implemented to use the data_show class"
def initialize_axes(self): def initialize_axes(self):
@ -330,7 +331,9 @@ class mocap_data_show(data_show):
self.axes.set_xlim(self.x_lim) self.axes.set_xlim(self.x_lim)
self.axes.set_ylim(self.y_lim) self.axes.set_ylim(self.y_lim)
self.axes.set_zlim(self.z_lim) self.axes.set_zlim(self.z_lim)
self.axes.set_aspect(1) self.axes.auto_scale_xyz([-1., 1.], [-1., 1.], [-1.5, 1.5])
#self.axes.set_aspect('equal')
self.axes.autoscale(enable=False) self.axes.autoscale(enable=False)
def finalize_axes_modify(self): def finalize_axes_modify(self):
@ -344,8 +347,8 @@ class stick_show(mocap_data_show):
def __init__(self, vals, axes=None, connect=None): def __init__(self, vals, axes=None, connect=None):
mocap_data_show.__init__(self, vals, axes, connect) mocap_data_show.__init__(self, vals, axes, connect)
def process_values(self, vals): def process_values(self):
self.vals = vals.reshape((3, vals.shape[1]/3)).T.copy() self.vals = self.vals.reshape((3, self.vals.shape[1]/3)).T
class skeleton_show(mocap_data_show): class skeleton_show(mocap_data_show):
"""data_show class for visualizing motion capture data encoded as a skeleton with angles.""" """data_show class for visualizing motion capture data encoded as a skeleton with angles."""
@ -363,32 +366,32 @@ class skeleton_show(mocap_data_show):
connect = skel.connection_matrix() connect = skel.connection_matrix()
mocap_data_show.__init__(self, vals, axes, connect) mocap_data_show.__init__(self, vals, axes, connect)
def process_values(self, vals): def process_values(self):
"""Takes a set of angles and converts them to the x,y,z coordinates in the internal prepresentation of the class, ready for plotting. """Takes a set of angles and converts them to the x,y,z coordinates in the internal prepresentation of the class, ready for plotting.
:param vals: the values that are being modelled.""" :param vals: the values that are being modelled."""
if self.padding>0: if self.padding>0:
channels = np.zeros((vals.shape[0], vals.shape[1]+self.padding)) channels = np.zeros((self.vals.shape[0], self.vals.shape[1]+self.padding))
channels[:, 0:vals.shape[0]] = vals channels[:, 0:self.vals.shape[0]] = self.vals
else: else:
channels = vals channels = self.vals
vals_mat = self.skel.to_xyz(channels.flatten()) vals_mat = self.skel.to_xyz(channels.flatten())
self.vals = vals_mat self.vals = np.zeros_like(vals_mat)
# Flip the Y and Z axes # Flip the Y and Z axes
self.vals[:, 0] = vals_mat[:, 0] self.vals[:, 0] = vals_mat[:, 0].copy()
self.vals[:, 1] = vals_mat[:, 2] self.vals[:, 1] = vals_mat[:, 2].copy()
self.vals[:, 2] = vals_mat[:, 1] self.vals[:, 2] = vals_mat[:, 1].copy()
def wrap_around(vals, lim, connect): def wrap_around(self, lim, connect):
quot = lim[1] - lim[0] quot = lim[1] - lim[0]
vals = rem(vals, quot)+lim[0] self.vals = rem(self.vals, quot)+lim[0]
nVals = floor(vals/quot) nVals = floor(self.vals/quot)
for i in range(connect.shape[0]): for i in range(connect.shape[0]):
for j in find(connect[i, :]): for j in find(connect[i, :]):
if nVals[i] != nVals[j]: if nVals[i] != nVals[j]:
connect[i, j] = False connect[i, j] = False
return vals, connect return connect
def data_play(Y, visualizer, frame_rate=30): def data_play(Y, visualizer, frame_rate=30):