mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-05 14:55:15 +02:00
better GPLVM oil flow demo
This commit is contained in:
parent
c4d190e0fd
commit
cbcb3cb337
2 changed files with 7 additions and 3 deletions
|
|
@ -44,12 +44,14 @@ def GPLVM_oil_100():
|
||||||
data = GPy.util.datasets.oil_100()
|
data = GPy.util.datasets.oil_100()
|
||||||
|
|
||||||
# create simple GP model
|
# create simple GP model
|
||||||
m = GPy.models.GPLVM(data['X'], 2)
|
kernel = GPy.kern.rbf(6, ARD = True) + GPy.kern.bias(6)
|
||||||
|
m = GPy.models.GPLVM(data['X'], 6, kernel = kernel)
|
||||||
|
|
||||||
# optimize
|
# optimize
|
||||||
m.ensure_default_constraints()
|
m.ensure_default_constraints()
|
||||||
m.optimize()
|
m.optimize(messages=1)
|
||||||
|
|
||||||
# plot
|
# plot
|
||||||
print(m)
|
print(m)
|
||||||
|
m.plot_latent(labels=data['Y'].argmax(axis=1))
|
||||||
return m
|
return m
|
||||||
|
|
|
||||||
|
|
@ -66,6 +66,8 @@ class GPLVM(GP):
|
||||||
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
util.plot.Tango.reset()
|
||||||
|
|
||||||
if labels is None:
|
if labels is None:
|
||||||
labels = np.ones(self.N)
|
labels = np.ones(self.N)
|
||||||
if which_indices is None:
|
if which_indices is None:
|
||||||
|
|
@ -88,7 +90,7 @@ class GPLVM(GP):
|
||||||
#first, plot the output variance as a function of the latent space
|
#first, plot the output variance as a function of the latent space
|
||||||
Xtest, xx,yy,xmin,xmax = util.plot.x_frame2D(self.X[:,[input_1, input_2]],resolution=resolution)
|
Xtest, xx,yy,xmin,xmax = util.plot.x_frame2D(self.X[:,[input_1, input_2]],resolution=resolution)
|
||||||
Xtest_full = np.zeros((Xtest.shape[0], self.X.shape[1]))
|
Xtest_full = np.zeros((Xtest.shape[0], self.X.shape[1]))
|
||||||
Xtest_full[:, :2] = Xtest
|
Xtest_full[:, :2] = Xtest
|
||||||
mu, var, low, up = self.predict(Xtest_full)
|
mu, var, low, up = self.predict(Xtest_full)
|
||||||
var = var[:, :2]
|
var = var[:, :2]
|
||||||
pb.imshow(var.reshape(resolution,resolution).T[::-1,:],extent=[xmin[0],xmax[0],xmin[1],xmax[1]],cmap=pb.cm.binary,interpolation='bilinear')
|
pb.imshow(var.reshape(resolution,resolution).T[::-1,:],extent=[xmin[0],xmax[0],xmin[1],xmax[1]],cmap=pb.cm.binary,interpolation='bilinear')
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue