fix stick man example

This commit is contained in:
Zhenwen Dai 2014-04-09 12:22:46 +01:00
parent 5cfc250ad1
commit 9d12c83935
6 changed files with 15 additions and 12 deletions

View file

@ -29,7 +29,7 @@ def plot_2D_images(figure, arr, symmetric=False, pad=None, zoom=None, mode=None,
pad = max(int(min(y_size,x_size)/10),1)
figsize = _calculateFigureSize(x_size, y_size, fig_ncols, fig_nrows, pad)
figure.set_size_inches(figsize,forward=True)
#figure.set_size_inches(figsize,forward=True)
#figure.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
if symmetric:
@ -37,7 +37,7 @@ def plot_2D_images(figure, arr, symmetric=False, pad=None, zoom=None, mode=None,
mval = max(abs(arr.max()),abs(arr.min()))
arr = arr/(2.*mval)+0.5
else:
minval,maxval = arr.max(),arr.min()
minval,maxval = arr.min(),arr.max()
arr = (arr-minval)/(maxval-minval)
if mode=='L':

View file

@ -85,6 +85,7 @@ class vector_show(matplotlib_show):
class lvm(matplotlib_show):
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
"""Visualize a latent variable model
@ -98,7 +99,7 @@ class lvm(matplotlib_show):
vals = param_to_array(model.X.mean)
else:
vals = param_to_array(model.X)
vals = param_to_array(vals)
matplotlib_show.__init__(self, vals, axes=latent_axes)
@ -121,7 +122,7 @@ class lvm(matplotlib_show):
self.move_on = False
self.latent_index = latent_index
self.latent_dim = model.input_dim
# The red cross which shows current latent point.
self.latent_values = vals
self.latent_handle = self.latent_axes.plot([0],[0],'rx',mew=2)[0]
@ -130,10 +131,10 @@ class lvm(matplotlib_show):
def modify(self, vals):
"""When latent values are modified update the latent representation and ulso update the output visualization."""
self.vals = vals.copy()
self.vals = vals[None,:].copy()
y = self.model.predict(self.vals)[0]
self.data_visualize.modify(y)
self.latent_handle.set_data(self.vals[self.latent_index[0]], self.vals[self.latent_index[1]])
self.latent_handle.set_data(self.vals[:,self.latent_index[0]], self.vals[:,self.latent_index[1]])
self.axes.figure.canvas.draw()