mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-27 14:25:16 +02:00
[dim_reduce examples] updated swiss roll
This commit is contained in:
parent
fa750d4ba4
commit
d6bf3b0d79
3 changed files with 80 additions and 6 deletions
|
|
@ -99,7 +99,7 @@ def sparse_gplvm_oil(optimize=True, verbose=0, plot=True, N=100, Q=6, num_induci
|
||||||
m.kern.plot_ARD()
|
m.kern.plot_ARD()
|
||||||
return m
|
return m
|
||||||
|
|
||||||
def swiss_roll(optimize=True, verbose=1, plot=True, N=1000, num_inducing=15, Q=4, sigma=.2):
|
def swiss_roll(optimize=True, verbose=1, plot=True, N=1000, num_inducing=25, Q=4, sigma=.2):
|
||||||
import GPy
|
import GPy
|
||||||
from GPy.util.datasets import swiss_roll_generated
|
from GPy.util.datasets import swiss_roll_generated
|
||||||
from GPy.models import BayesianGPLVM
|
from GPy.models import BayesianGPLVM
|
||||||
|
|
@ -144,16 +144,15 @@ def swiss_roll(optimize=True, verbose=1, plot=True, N=1000, num_inducing=15, Q=4
|
||||||
m = BayesianGPLVM(Y, Q, X=X, X_variance=S, num_inducing=num_inducing, Z=Z, kernel=kernel)
|
m = BayesianGPLVM(Y, Q, X=X, X_variance=S, num_inducing=num_inducing, Z=Z, kernel=kernel)
|
||||||
m.data_colors = c
|
m.data_colors = c
|
||||||
m.data_t = t
|
m.data_t = t
|
||||||
m['noise_variance'] = Y.var() / 100.
|
|
||||||
|
|
||||||
if optimize:
|
if optimize:
|
||||||
m.optimize('scg', messages=verbose, max_iters=2e3)
|
m.optimize('bfgs', messages=verbose, max_iters=2e3)
|
||||||
|
|
||||||
if plot:
|
if plot:
|
||||||
fig = plt.figure('fitted')
|
fig = plt.figure('fitted')
|
||||||
ax = fig.add_subplot(111)
|
ax = fig.add_subplot(111)
|
||||||
s = m.input_sensitivity().argsort()[::-1][:2]
|
s = m.input_sensitivity().argsort()[::-1][:2]
|
||||||
ax.scatter(*m.X.T[s], c=c)
|
ax.scatter(*m.X.mean.T[s], c=c)
|
||||||
|
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -635,7 +635,7 @@ def osu_run1(data_set='osu_run1', sample_every=4):
|
||||||
return data_details_return({'Y': Y, 'connect' : connect}, data_set)
|
return data_details_return({'Y': Y, 'connect' : connect}, data_set)
|
||||||
|
|
||||||
def swiss_roll_generated(num_samples=1000, sigma=0.0):
|
def swiss_roll_generated(num_samples=1000, sigma=0.0):
|
||||||
with open(os.path.join(data_path, 'swiss_roll.pickle')) as f:
|
with open(os.path.join(os.path.dirname(__file__), 'datasets', 'swiss_roll.pickle')) as f:
|
||||||
data = pickle.load(f)
|
data = pickle.load(f)
|
||||||
Na = data['Y'].shape[0]
|
Na = data['Y'].shape[0]
|
||||||
perm = np.random.permutation(np.r_[:Na])[:num_samples]
|
perm = np.random.permutation(np.r_[:Na])[:num_samples]
|
||||||
|
|
|
||||||
75
GPy/util/datasets/swiss_roll.pickle
Normal file
75
GPy/util/datasets/swiss_roll.pickle
Normal file
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue