mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00
[dim_reduce examples] updated swiss roll
This commit is contained in:
parent
fa750d4ba4
commit
d6bf3b0d79
3 changed files with 80 additions and 6 deletions
|
|
@ -99,7 +99,7 @@ def sparse_gplvm_oil(optimize=True, verbose=0, plot=True, N=100, Q=6, num_induci
|
|||
m.kern.plot_ARD()
|
||||
return m
|
||||
|
||||
def swiss_roll(optimize=True, verbose=1, plot=True, N=1000, num_inducing=15, Q=4, sigma=.2):
|
||||
def swiss_roll(optimize=True, verbose=1, plot=True, N=1000, num_inducing=25, Q=4, sigma=.2):
|
||||
import GPy
|
||||
from GPy.util.datasets import swiss_roll_generated
|
||||
from GPy.models import BayesianGPLVM
|
||||
|
|
@ -144,16 +144,15 @@ def swiss_roll(optimize=True, verbose=1, plot=True, N=1000, num_inducing=15, Q=4
|
|||
m = BayesianGPLVM(Y, Q, X=X, X_variance=S, num_inducing=num_inducing, Z=Z, kernel=kernel)
|
||||
m.data_colors = c
|
||||
m.data_t = t
|
||||
m['noise_variance'] = Y.var() / 100.
|
||||
|
||||
|
||||
if optimize:
|
||||
m.optimize('scg', messages=verbose, max_iters=2e3)
|
||||
m.optimize('bfgs', messages=verbose, max_iters=2e3)
|
||||
|
||||
if plot:
|
||||
fig = plt.figure('fitted')
|
||||
ax = fig.add_subplot(111)
|
||||
s = m.input_sensitivity().argsort()[::-1][:2]
|
||||
ax.scatter(*m.X.T[s], c=c)
|
||||
ax.scatter(*m.X.mean.T[s], c=c)
|
||||
|
||||
return m
|
||||
|
||||
|
|
|
|||
|
|
@ -635,7 +635,7 @@ def osu_run1(data_set='osu_run1', sample_every=4):
|
|||
return data_details_return({'Y': Y, 'connect' : connect}, data_set)
|
||||
|
||||
def swiss_roll_generated(num_samples=1000, sigma=0.0):
|
||||
with open(os.path.join(data_path, 'swiss_roll.pickle')) as f:
|
||||
with open(os.path.join(os.path.dirname(__file__), 'datasets', 'swiss_roll.pickle')) as f:
|
||||
data = pickle.load(f)
|
||||
Na = data['Y'].shape[0]
|
||||
perm = np.random.permutation(np.r_[:Na])[:num_samples]
|
||||
|
|
|
|||
75
GPy/util/datasets/swiss_roll.pickle
Normal file
75
GPy/util/datasets/swiss_roll.pickle
Normal file
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue