swiss_roll example added, BGPLVM_oil now working

This commit is contained in:
Max Zwiessele 2013-05-16 13:47:55 +01:00
parent 61a79c5041
commit 93d517f24e
5 changed files with 137 additions and 173 deletions

View file

@ -4,6 +4,7 @@ import numpy as np
import GPy
import scipy.sparse
import scipy.io
import cPickle as pickle
data_path = os.path.join(os.path.dirname(__file__), 'datasets')
default_seed = 10000
@ -96,6 +97,19 @@ def stick():
lbls = 'connect'
return {'Y': Y, 'connect' : connect, 'info': "Stick man data from Ohio."}
def swiss_roll_generated(N=1000, sigma=0.0):
with open(os.path.join(data_path, 'swiss_roll.pickle')) as f:
data = pickle.load(f)
Na = data['Y'].shape[0]
perm = np.random.permutation(np.r_[:Na])[:N]
Y = data['Y'][perm, :]
t = data['t'][perm]
c = data['colors'][perm, :]
so = np.argsort(t)
Y = Y[so, :]
t = t[so]
c = c[so, :]
return {'Y':Y, 't':t, 'colors':c}
def swiss_roll_1000():
mat_data = scipy.io.loadmat(os.path.join(data_path, 'swiss_roll_data'))
@ -105,8 +119,7 @@ def swiss_roll_1000():
def swiss_roll(N=3000):
mat_data = scipy.io.loadmat(os.path.join(data_path, 'swiss_roll_data.mat'))
Y = mat_data['X_data'][:, 0:N].transpose()
import ipdb;ipdb.set_trace()
return {'Y': Y, 'info': "The first 3,000 points from the swiss roll data of Tennenbaum, de Silva and Langford (2001)."}
return {'Y': Y, 'X': mat_data['X_data'], 'info': "The first 3,000 points from the swiss roll data of Tennenbaum, de Silva and Langford (2001)."}
def toy_rbf_1d(seed=default_seed):
np.random.seed(seed=seed)