fix stick man example

This commit is contained in:
Zhenwen Dai 2014-04-09 12:22:46 +01:00
parent 5cfc250ad1
commit 9d12c83935
6 changed files with 15 additions and 12 deletions

View file

@ -409,12 +409,12 @@ def stick(kernel=None, optimize=True, verbose=True, plot=True):
# optimize # optimize
m = GPy.models.GPLVM(data['Y'], 2, kernel=kernel) m = GPy.models.GPLVM(data['Y'], 2, kernel=kernel)
if optimize: m.optimize(messages=verbose, max_f_eval=10000) if optimize: m.optimize(messages=verbose, max_f_eval=10000)
if plot and GPy.plotting.matplot_dep.visualize.visual_available: if plot:
plt.clf plt.clf
ax = m.plot_latent() ax = m.plot_latent()
y = m.likelihood.Y[0, :] y = m.Y[0, :]
data_show = GPy.plotting.matplot_dep.visualize.stick_show(y[None, :], connect=data['connect']) data_show = GPy.plotting.matplot_dep.visualize.stick_show(y[None, :], connect=data['connect'])
GPy.plotting.matplot_dep.visualize.lvm(m.X[0, :].copy(), m, data_show, ax) vis = GPy.plotting.matplot_dep.visualize.lvm(m.X[0, :].copy(), m, data_show, latent_axes=ax)
raw_input('Press enter to finish') raw_input('Press enter to finish')
return m return m

View file

@ -32,7 +32,7 @@ class ExactGaussianInference(object):
return Y return Y
else: else:
#if Y in self.cache, return self.Cache[Y], else store Y in cache and return L. #if Y in self.cache, return self.Cache[Y], else store Y in cache and return L.
print "WARNING: N>D of Y, we need caching of L, such that L*L^T = Y, returning Y still!" #print "WARNING: N>D of Y, we need caching of L, such that L*L^T = Y, returning Y still!"
return Y return Y
def inference(self, kern, X, likelihood, Y, Y_metadata=None): def inference(self, kern, X, likelihood, Y, Y_metadata=None):

View file

@ -109,7 +109,9 @@ class VarDTC_GPU(object):
x0, x1 = 0.,0. x0, x1 = 0.,0.
y0, y1 = self._estimateMemoryOccupation(N, M, D) y0, y1 = self._estimateMemoryOccupation(N, M, D)
return int((self.gpu_memory-y0-x0)/(x1+y1)) opt_batchsize = min(int((self.gpu_memory-y0-x0)/(x1+y1)), N)
return opt_batchsize
def _get_YYTfactor(self, Y): def _get_YYTfactor(self, Y):
""" """

View file

@ -29,7 +29,7 @@ def plot_2D_images(figure, arr, symmetric=False, pad=None, zoom=None, mode=None,
pad = max(int(min(y_size,x_size)/10),1) pad = max(int(min(y_size,x_size)/10),1)
figsize = _calculateFigureSize(x_size, y_size, fig_ncols, fig_nrows, pad) figsize = _calculateFigureSize(x_size, y_size, fig_ncols, fig_nrows, pad)
figure.set_size_inches(figsize,forward=True) #figure.set_size_inches(figsize,forward=True)
#figure.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95) #figure.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
if symmetric: if symmetric:
@ -37,7 +37,7 @@ def plot_2D_images(figure, arr, symmetric=False, pad=None, zoom=None, mode=None,
mval = max(abs(arr.max()),abs(arr.min())) mval = max(abs(arr.max()),abs(arr.min()))
arr = arr/(2.*mval)+0.5 arr = arr/(2.*mval)+0.5
else: else:
minval,maxval = arr.max(),arr.min() minval,maxval = arr.min(),arr.max()
arr = (arr-minval)/(maxval-minval) arr = (arr-minval)/(maxval-minval)
if mode=='L': if mode=='L':

View file

@ -85,6 +85,7 @@ class vector_show(matplotlib_show):
class lvm(matplotlib_show): class lvm(matplotlib_show):
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]): def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
"""Visualize a latent variable model """Visualize a latent variable model
@ -130,10 +131,10 @@ class lvm(matplotlib_show):
def modify(self, vals): def modify(self, vals):
"""When latent values are modified update the latent representation and ulso update the output visualization.""" """When latent values are modified update the latent representation and ulso update the output visualization."""
self.vals = vals.copy() self.vals = vals[None,:].copy()
y = self.model.predict(self.vals)[0] y = self.model.predict(self.vals)[0]
self.data_visualize.modify(y) self.data_visualize.modify(y)
self.latent_handle.set_data(self.vals[self.latent_index[0]], self.vals[self.latent_index[1]]) self.latent_handle.set_data(self.vals[:,self.latent_index[0]], self.vals[:,self.latent_index[1]])
self.axes.figure.canvas.draw() self.axes.figure.canvas.draw()