mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-08 15:05:15 +02:00
merged master back into devel (to sync bugfixes)
This commit is contained in:
parent
f6879ef768
commit
87304a0778
2 changed files with 3 additions and 31 deletions
|
|
@ -133,32 +133,3 @@ def stick():
|
||||||
plt.close('all')
|
plt.close('all')
|
||||||
|
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
def BGPLVM_oil():
|
|
||||||
data = GPy.util.datasets.oil()
|
|
||||||
Y, X = data['Y'], data['X']
|
|
||||||
X -= X.mean(axis=0)
|
|
||||||
X /= X.std(axis=0)
|
|
||||||
|
|
||||||
Q = 10
|
|
||||||
M = 30
|
|
||||||
|
|
||||||
kernel = GPy.kern.rbf(Q, ARD = True) + GPy.kern.bias(Q) + GPy.kern.white(Q)
|
|
||||||
m = GPy.models.Bayesian_GPLVM(X, Q, kernel=kernel, M=M)
|
|
||||||
# m.scale_factor = 100.0
|
|
||||||
m.constrain_positive('(white|noise|bias|X_variance|rbf_variance|rbf_length)')
|
|
||||||
from sklearn import cluster
|
|
||||||
km = cluster.KMeans(M, verbose=10)
|
|
||||||
Z = km.fit(m.X).cluster_centers_
|
|
||||||
# Z = GPy.util.misc.kmm_init(m.X, M)
|
|
||||||
m.set('iip', Z)
|
|
||||||
m.set('bias', 1e-4)
|
|
||||||
# optimize
|
|
||||||
# m.ensure_default_constraints()
|
|
||||||
|
|
||||||
import pdb; pdb.set_trace()
|
|
||||||
m.optimize('tnc', messages=1)
|
|
||||||
print m
|
|
||||||
m.plot_latent(labels=data['Y'].argmax(axis=1))
|
|
||||||
return m
|
|
||||||
|
|
|
||||||
|
|
@ -89,7 +89,7 @@ class GPLVM(GP):
|
||||||
Xtest_full = np.zeros((Xtest.shape[0], self.X.shape[1]))
|
Xtest_full = np.zeros((Xtest.shape[0], self.X.shape[1]))
|
||||||
Xtest_full[:, :2] = Xtest
|
Xtest_full[:, :2] = Xtest
|
||||||
mu, var, low, up = self.predict(Xtest_full)
|
mu, var, low, up = self.predict(Xtest_full)
|
||||||
var = var[:, :1] # FIXME: this was a :2
|
var = var[:, :1]
|
||||||
pb.imshow(var.reshape(resolution,resolution).T[::-1,:],
|
pb.imshow(var.reshape(resolution,resolution).T[::-1,:],
|
||||||
extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear')
|
extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear')
|
||||||
|
|
||||||
|
|
@ -119,5 +119,6 @@ class GPLVM(GP):
|
||||||
pb.xlim(xmin[0],xmax[0])
|
pb.xlim(xmin[0],xmax[0])
|
||||||
pb.ylim(xmin[1],xmax[1])
|
pb.ylim(xmin[1],xmax[1])
|
||||||
pb.grid(b=False) # remove the grid if present, it doesn't look good
|
pb.grid(b=False) # remove the grid if present, it doesn't look good
|
||||||
|
ax = pb.gca()
|
||||||
ax.set_aspect('auto') # set a nice aspect ratio
|
ax.set_aspect('auto') # set a nice aspect ratio
|
||||||
return pb.gca() #input_1, input_2 temporary removal, to return axes.
|
return ax
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue