better GPLVM oil flow demo

This commit is contained in:
Nicolo Fusi 2013-03-12 14:01:32 +00:00
parent c4d190e0fd
commit cbcb3cb337
2 changed files with 7 additions and 3 deletions

View file

@ -44,12 +44,14 @@ def GPLVM_oil_100():
data = GPy.util.datasets.oil_100() data = GPy.util.datasets.oil_100()
# create simple GP model # create simple GP model
m = GPy.models.GPLVM(data['X'], 2) kernel = GPy.kern.rbf(6, ARD = True) + GPy.kern.bias(6)
m = GPy.models.GPLVM(data['X'], 6, kernel = kernel)
# optimize # optimize
m.ensure_default_constraints() m.ensure_default_constraints()
m.optimize() m.optimize(messages=1)
# plot # plot
print(m) print(m)
m.plot_latent(labels=data['Y'].argmax(axis=1))
return m return m

View file

@ -66,6 +66,8 @@ class GPLVM(GP):
:param resolution: the resolution of the grid on which to evaluate the predictive variance :param resolution: the resolution of the grid on which to evaluate the predictive variance
""" """
util.plot.Tango.reset()
if labels is None: if labels is None:
labels = np.ones(self.N) labels = np.ones(self.N)
if which_indices is None: if which_indices is None:
@ -88,7 +90,7 @@ class GPLVM(GP):
#first, plot the output variance as a function of the latent space #first, plot the output variance as a function of the latent space
Xtest, xx,yy,xmin,xmax = util.plot.x_frame2D(self.X[:,[input_1, input_2]],resolution=resolution) Xtest, xx,yy,xmin,xmax = util.plot.x_frame2D(self.X[:,[input_1, input_2]],resolution=resolution)
Xtest_full = np.zeros((Xtest.shape[0], self.X.shape[1])) Xtest_full = np.zeros((Xtest.shape[0], self.X.shape[1]))
Xtest_full[:, :2] = Xtest Xtest_full[:, :2] = Xtest
mu, var, low, up = self.predict(Xtest_full) mu, var, low, up = self.predict(Xtest_full)
var = var[:, :2] var = var[:, :2]
pb.imshow(var.reshape(resolution,resolution).T[::-1,:],extent=[xmin[0],xmax[0],xmin[1],xmax[1]],cmap=pb.cm.binary,interpolation='bilinear') pb.imshow(var.reshape(resolution,resolution).T[::-1,:],extent=[xmin[0],xmax[0],xmin[1],xmax[1]],cmap=pb.cm.binary,interpolation='bilinear')