[dim_reduce examples] updated swiss roll

This commit is contained in:
mzwiessele 2014-05-22 15:25:26 +01:00
parent fa750d4ba4
commit d6bf3b0d79
3 changed files with 80 additions and 6 deletions

View file

@ -99,7 +99,7 @@ def sparse_gplvm_oil(optimize=True, verbose=0, plot=True, N=100, Q=6, num_induci
m.kern.plot_ARD()
return m
def swiss_roll(optimize=True, verbose=1, plot=True, N=1000, num_inducing=15, Q=4, sigma=.2):
def swiss_roll(optimize=True, verbose=1, plot=True, N=1000, num_inducing=25, Q=4, sigma=.2):
import GPy
from GPy.util.datasets import swiss_roll_generated
from GPy.models import BayesianGPLVM
@ -144,16 +144,15 @@ def swiss_roll(optimize=True, verbose=1, plot=True, N=1000, num_inducing=15, Q=4
m = BayesianGPLVM(Y, Q, X=X, X_variance=S, num_inducing=num_inducing, Z=Z, kernel=kernel)
m.data_colors = c
m.data_t = t
m['noise_variance'] = Y.var() / 100.
if optimize:
m.optimize('scg', messages=verbose, max_iters=2e3)
m.optimize('bfgs', messages=verbose, max_iters=2e3)
if plot:
fig = plt.figure('fitted')
ax = fig.add_subplot(111)
s = m.input_sensitivity().argsort()[::-1][:2]
ax.scatter(*m.X.T[s], c=c)
ax.scatter(*m.X.mean.T[s], c=c)
return m

View file

@ -635,7 +635,7 @@ def osu_run1(data_set='osu_run1', sample_every=4):
return data_details_return({'Y': Y, 'connect' : connect}, data_set)
def swiss_roll_generated(num_samples=1000, sigma=0.0):
with open(os.path.join(data_path, 'swiss_roll.pickle')) as f:
with open(os.path.join(os.path.dirname(__file__), 'datasets', 'swiss_roll.pickle')) as f:
data = pickle.load(f)
Na = data['Y'].shape[0]
perm = np.random.permutation(np.r_[:Na])[:num_samples]

File diff suppressed because one or more lines are too long