mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-01 15:52:39 +02:00
swiss_roll example
This commit is contained in:
parent
a0df861e8c
commit
d6c790ae9c
3 changed files with 19 additions and 18 deletions
|
|
@ -102,9 +102,10 @@ def swiss_roll_1000():
|
|||
Y = mat_data['X_data'][:, 0:1000].transpose()
|
||||
return {'Y': Y, 'info': "Subsample of the swiss roll data extracting only the first 1000 values."}
|
||||
|
||||
def swiss_roll():
|
||||
def swiss_roll(N=3000):
|
||||
mat_data = scipy.io.loadmat(os.path.join(data_path, 'swiss_roll_data.mat'))
|
||||
Y = mat_data['X_data'][:, 0:3000].transpose()
|
||||
Y = mat_data['X_data'][:, 0:N].transpose()
|
||||
import ipdb;ipdb.set_trace()
|
||||
return {'Y': Y, 'info': "The first 3,000 points from the swiss roll data of Tennenbaum, de Silva and Langford (2001)."}
|
||||
|
||||
def toy_rbf_1d(seed=default_seed):
|
||||
|
|
@ -270,13 +271,13 @@ def cmu_mocap(subject, train_motions, test_motions=[], sample_every=4):
|
|||
|
||||
end_ind = 0
|
||||
for i in range(len(temp_Y)):
|
||||
start_ind = end_ind
|
||||
start_ind = end_ind
|
||||
end_ind += temp_Y[i].shape[0]
|
||||
Y[start_ind:end_ind, :] = temp_Y[i]
|
||||
lbls[start_ind:end_ind, :] = temp_lbls[i]
|
||||
if len(test_motions)>0:
|
||||
if len(test_motions) > 0:
|
||||
temp_Ytest = []
|
||||
temp_lblstest = []
|
||||
temp_lblstest = []
|
||||
|
||||
testexlbls = np.eye(len(test_motions))
|
||||
tot_test_length = 0
|
||||
|
|
@ -292,7 +293,7 @@ def cmu_mocap(subject, train_motions, test_motions=[], sample_every=4):
|
|||
|
||||
end_ind = 0
|
||||
for i in range(len(temp_Ytest)):
|
||||
start_ind = end_ind
|
||||
start_ind = end_ind
|
||||
end_ind += temp_Ytest[i].shape[0]
|
||||
Ytest[start_ind:end_ind, :] = temp_Ytest[i]
|
||||
lblstest[start_ind:end_ind, :] = temp_lblstest[i]
|
||||
|
|
@ -304,7 +305,7 @@ def cmu_mocap(subject, train_motions, test_motions=[], sample_every=4):
|
|||
for motion in train_motions:
|
||||
info += motion + ', '
|
||||
info = info[:-2]
|
||||
if len(test_motions)>0:
|
||||
if len(test_motions) > 0:
|
||||
info += '. Test motions: '
|
||||
for motion in test_motions:
|
||||
info += motion + ', '
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue