Added CIFAR-10 data to data sets.

This commit is contained in:
Neil Lawrence 2014-05-27 16:27:27 +01:00
parent e3b6d9c9c5
commit 9ea236112e
2 changed files with 39 additions and 1 deletions

View file

@ -409,7 +409,7 @@ def lee_yeast_ChIP(data_set='lee_yeast_ChIP'):
transcription_factors = [col for col in X.columns if col[:7] != 'Unnamed']
annotations = X[['Unnamed: 1', 'Unnamed: 2', 'Unnamed: 3']]
X = X[transcription_factors]
return data_details_return({'annotations' : annotations, 'X' : X, 'transcription_factors', transcription_factors}, data_set)
return data_details_return({'annotations' : annotations, 'X' : X, 'transcription_factors': transcription_factors}, data_set)
def fruitfly_tomancak(data_set='fruitfly_tomancak', gene_number=None):
@ -1145,6 +1145,30 @@ def creep_data(data_set='creep_rupture'):
X = all_data[:, features].copy()
return data_details_return({'X': X, 'y': y}, data_set)
def cifar10(data_set='cifar-10'):
"""The Candian Institute for Advanced Research 10 image data set. Code for loading in this data is taken from this Boris Babenko's blog post, original code available here: http://bbabenko.tumblr.com/post/86756017649/learning-low-level-vision-feautres-in-10-lines-of-code"""
dirpath = os.path.join(data_path, data_set)
filename = os.path.join(dirpath, 'cifar-10-python.tar.gz')
if not data_available(data_set):
download_data(data_set)
import tarfile
# This code is from Boris Babenko's blog post.
# http://bbabenko.tumblr.com/post/86756017649/learning-low-level-vision-feautres-in-10-lines-of-code
tfile = tarfile.open(filename, 'r:gz')
tfile.extractall(dirpath)
with open(os.path.join(dirpath, 'cifar-10-batches-py','data_batch_1'),'rb') as f:
data = pickle.load(f)
images = data['data'].reshape((-1,3,32,32)).astype('float32')/255
images = np.rollaxis(images, 1, 4)
patches = np.zeros((0,5,5,3))
for x in range(0,32-5,5):
for y in range(0,32-5,5):
patches = np.concatenate((patches, images[:,x:x+5,y:y+5,:]), axis=0)
patches = patches.reshape((patches.shape[0],-1))
return data_details_return({'Y': patches}, data_set)
def cmu_mocap_49_balance(data_set='cmu_mocap'):
"""Load CMU subject 49's one legged balancing motion that was used by Alvarez, Luengo and Lawrence at AISTATS 2009."""
train_motions = ['18', '19']