mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-05 09:42:39 +02:00
fix stick man example
This commit is contained in:
parent
5cfc250ad1
commit
9d12c83935
6 changed files with 15 additions and 12 deletions
|
|
@ -121,7 +121,7 @@ class GP(Model):
|
|||
If full_cov and self.input_dim > 1, the return shape of var is Nnew x Nnew x self.input_dim. If self.input_dim == 1, the return shape is Nnew x Nnew.
|
||||
This is to allow for different normalizations of the output dimensions.
|
||||
|
||||
"""
|
||||
"""
|
||||
#predict the latent function values
|
||||
mu, var = self._raw_predict(Xnew, full_cov=full_cov)
|
||||
|
||||
|
|
|
|||
|
|
@ -409,12 +409,12 @@ def stick(kernel=None, optimize=True, verbose=True, plot=True):
|
|||
# optimize
|
||||
m = GPy.models.GPLVM(data['Y'], 2, kernel=kernel)
|
||||
if optimize: m.optimize(messages=verbose, max_f_eval=10000)
|
||||
if plot and GPy.plotting.matplot_dep.visualize.visual_available:
|
||||
if plot:
|
||||
plt.clf
|
||||
ax = m.plot_latent()
|
||||
y = m.likelihood.Y[0, :]
|
||||
y = m.Y[0, :]
|
||||
data_show = GPy.plotting.matplot_dep.visualize.stick_show(y[None, :], connect=data['connect'])
|
||||
GPy.plotting.matplot_dep.visualize.lvm(m.X[0, :].copy(), m, data_show, ax)
|
||||
vis = GPy.plotting.matplot_dep.visualize.lvm(m.X[0, :].copy(), m, data_show, latent_axes=ax)
|
||||
raw_input('Press enter to finish')
|
||||
|
||||
return m
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class ExactGaussianInference(object):
|
|||
return Y
|
||||
else:
|
||||
#if Y in self.cache, return self.Cache[Y], else store Y in cache and return L.
|
||||
print "WARNING: N>D of Y, we need caching of L, such that L*L^T = Y, returning Y still!"
|
||||
#print "WARNING: N>D of Y, we need caching of L, such that L*L^T = Y, returning Y still!"
|
||||
return Y
|
||||
|
||||
def inference(self, kern, X, likelihood, Y, Y_metadata=None):
|
||||
|
|
|
|||
|
|
@ -109,7 +109,9 @@ class VarDTC_GPU(object):
|
|||
x0, x1 = 0.,0.
|
||||
y0, y1 = self._estimateMemoryOccupation(N, M, D)
|
||||
|
||||
return int((self.gpu_memory-y0-x0)/(x1+y1))
|
||||
opt_batchsize = min(int((self.gpu_memory-y0-x0)/(x1+y1)), N)
|
||||
|
||||
return opt_batchsize
|
||||
|
||||
def _get_YYTfactor(self, Y):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ def plot_2D_images(figure, arr, symmetric=False, pad=None, zoom=None, mode=None,
|
|||
pad = max(int(min(y_size,x_size)/10),1)
|
||||
|
||||
figsize = _calculateFigureSize(x_size, y_size, fig_ncols, fig_nrows, pad)
|
||||
figure.set_size_inches(figsize,forward=True)
|
||||
#figure.set_size_inches(figsize,forward=True)
|
||||
#figure.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
|
||||
|
||||
if symmetric:
|
||||
|
|
@ -37,7 +37,7 @@ def plot_2D_images(figure, arr, symmetric=False, pad=None, zoom=None, mode=None,
|
|||
mval = max(abs(arr.max()),abs(arr.min()))
|
||||
arr = arr/(2.*mval)+0.5
|
||||
else:
|
||||
minval,maxval = arr.max(),arr.min()
|
||||
minval,maxval = arr.min(),arr.max()
|
||||
arr = (arr-minval)/(maxval-minval)
|
||||
|
||||
if mode=='L':
|
||||
|
|
|
|||
|
|
@ -85,6 +85,7 @@ class vector_show(matplotlib_show):
|
|||
|
||||
|
||||
class lvm(matplotlib_show):
|
||||
|
||||
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
|
||||
"""Visualize a latent variable model
|
||||
|
||||
|
|
@ -98,7 +99,7 @@ class lvm(matplotlib_show):
|
|||
vals = param_to_array(model.X.mean)
|
||||
else:
|
||||
vals = param_to_array(model.X)
|
||||
|
||||
|
||||
vals = param_to_array(vals)
|
||||
matplotlib_show.__init__(self, vals, axes=latent_axes)
|
||||
|
||||
|
|
@ -121,7 +122,7 @@ class lvm(matplotlib_show):
|
|||
self.move_on = False
|
||||
self.latent_index = latent_index
|
||||
self.latent_dim = model.input_dim
|
||||
|
||||
|
||||
# The red cross which shows current latent point.
|
||||
self.latent_values = vals
|
||||
self.latent_handle = self.latent_axes.plot([0],[0],'rx',mew=2)[0]
|
||||
|
|
@ -130,10 +131,10 @@ class lvm(matplotlib_show):
|
|||
|
||||
def modify(self, vals):
|
||||
"""When latent values are modified update the latent representation and ulso update the output visualization."""
|
||||
self.vals = vals.copy()
|
||||
self.vals = vals[None,:].copy()
|
||||
y = self.model.predict(self.vals)[0]
|
||||
self.data_visualize.modify(y)
|
||||
self.latent_handle.set_data(self.vals[self.latent_index[0]], self.vals[self.latent_index[1]])
|
||||
self.latent_handle.set_data(self.vals[:,self.latent_index[0]], self.vals[:,self.latent_index[1]])
|
||||
self.axes.figure.canvas.draw()
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue