mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-23 15:48:09 +02:00
Further edits on visualization code for faces example.
This commit is contained in:
parent
3fd0672092
commit
fce4dd7fde
9 changed files with 151 additions and 80 deletions
|
|
@ -23,13 +23,13 @@ class model(parameterised):
|
||||||
self._set_params(self._get_params())
|
self._set_params(self._get_params())
|
||||||
self.preferred_optimizer = 'tnc'
|
self.preferred_optimizer = 'tnc'
|
||||||
def _get_params(self):
|
def _get_params(self):
|
||||||
raise NotImplementedError, "this needs to be implemented to utilise the model class"
|
raise NotImplementedError, "this needs to be implemented to use the model class"
|
||||||
def _set_params(self,x):
|
def _set_params(self,x):
|
||||||
raise NotImplementedError, "this needs to be implemented to utilise the model class"
|
raise NotImplementedError, "this needs to be implemented to use the model class"
|
||||||
def log_likelihood(self):
|
def log_likelihood(self):
|
||||||
raise NotImplementedError, "this needs to be implemented to utilise the model class"
|
raise NotImplementedError, "this needs to be implemented to use the model class"
|
||||||
def _log_likelihood_gradients(self):
|
def _log_likelihood_gradients(self):
|
||||||
raise NotImplementedError, "this needs to be implemented to utilise the model class"
|
raise NotImplementedError, "this needs to be implemented to use the model class"
|
||||||
|
|
||||||
def set_prior(self,which,what):
|
def set_prior(self,which,what):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pylab as pb
|
import pylab as pb
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
import GPy
|
import GPy
|
||||||
|
|
||||||
default_seed = np.random.seed(123344)
|
default_seed = np.random.seed(123344)
|
||||||
|
|
@ -55,3 +57,51 @@ def GPLVM_oil_100():
|
||||||
print(m)
|
print(m)
|
||||||
m.plot_latent(labels=data['Y'].argmax(axis=1))
|
m.plot_latent(labels=data['Y'].argmax(axis=1))
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
def oil_100():
|
||||||
|
data = GPy.util.datasets.oil_100()
|
||||||
|
m = GPy.models.GPLVM(data['X'], 2)
|
||||||
|
|
||||||
|
# optimize
|
||||||
|
m.ensure_default_constraints()
|
||||||
|
m.optimize(messages=1, max_iters=2)
|
||||||
|
|
||||||
|
# plot
|
||||||
|
print(m)
|
||||||
|
#m.plot_latent(labels=data['Y'].argmax(axis=1))
|
||||||
|
return m
|
||||||
|
|
||||||
|
def brendan_faces():
|
||||||
|
data = GPy.util.datasets.brendan_faces()
|
||||||
|
Y = data['Y'][0:500, :]
|
||||||
|
m = GPy.models.GPLVM(Y, 2, init='rand')
|
||||||
|
|
||||||
|
# optimize
|
||||||
|
m.ensure_default_constraints()
|
||||||
|
m.optimize(messages=1, max_f_eval=40)
|
||||||
|
|
||||||
|
ax = m.plot_latent()
|
||||||
|
y = m.likelihood.Y[0,:]
|
||||||
|
data_show = GPy.util.visualize.image_show(y[None, :], dimensions=(20, 28), transpose=True, invert=False, scale=False)
|
||||||
|
lvm_visualizer = GPy.util.visualize.lvm(m, data_show, ax)
|
||||||
|
raw_input('Press enter to finish')
|
||||||
|
plt.close('all')
|
||||||
|
|
||||||
|
return m
|
||||||
|
|
||||||
|
def stick():
|
||||||
|
data = GPy.util.datasets.stick()
|
||||||
|
m = GPy.models.GPLVM(data['Y'], 2, init='rand')
|
||||||
|
|
||||||
|
# optimize
|
||||||
|
m.ensure_default_constraints()
|
||||||
|
m.optimize(messages=1, max_f_eval=10000)
|
||||||
|
|
||||||
|
ax = m.plot_latent()
|
||||||
|
y = m.likelihood.Y[0,:]
|
||||||
|
data_show = GPy.util.visualize.stick_show(y[None, :], connect=data['connect'])
|
||||||
|
lvm_visualizer = GPy.util.visualize.lvm(m, data_show, ax)
|
||||||
|
raw_input('Press enter to finish')
|
||||||
|
plt.close('all')
|
||||||
|
|
||||||
|
return m
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,7 @@ def silhouette():
|
||||||
|
|
||||||
def coregionalisation_toy2():
|
def coregionalisation_toy2():
|
||||||
"""
|
"""
|
||||||
A simple demonstration of coregionalisation on two sinusoidal functions
|
A simple demonstration of coregionalisation on two sinusoidal functions.
|
||||||
"""
|
"""
|
||||||
X1 = np.random.rand(50,1)*8
|
X1 = np.random.rand(50,1)*8
|
||||||
X2 = np.random.rand(30,1)*5
|
X2 = np.random.rand(30,1)*5
|
||||||
|
|
@ -106,7 +106,7 @@ def coregionalisation_toy2():
|
||||||
|
|
||||||
def coregionalisation_toy():
|
def coregionalisation_toy():
|
||||||
"""
|
"""
|
||||||
A simple demonstration of coregionalisation on two sinusoidal functions
|
A simple demonstration of coregionalisation on two sinusoidal functions.
|
||||||
"""
|
"""
|
||||||
X1 = np.random.rand(50,1)*8
|
X1 = np.random.rand(50,1)*8
|
||||||
X2 = np.random.rand(30,1)*5
|
X2 = np.random.rand(30,1)*5
|
||||||
|
|
@ -139,7 +139,7 @@ def coregionalisation_toy():
|
||||||
|
|
||||||
def coregionalisation_sparse():
|
def coregionalisation_sparse():
|
||||||
"""
|
"""
|
||||||
A simple demonstration of coregionalisation on two sinusoidal functions
|
A simple demonstration of coregionalisation on two sinusoidal functions using sparse approximations.
|
||||||
"""
|
"""
|
||||||
X1 = np.random.rand(500,1)*8
|
X1 = np.random.rand(500,1)*8
|
||||||
X2 = np.random.rand(300,1)*5
|
X2 = np.random.rand(300,1)*5
|
||||||
|
|
|
||||||
|
|
@ -81,6 +81,7 @@ class GPLVM(GP):
|
||||||
k = [p for p in self.kern.parts if p.name in ['rbf','linear']]
|
k = [p for p in self.kern.parts if p.name in ['rbf','linear']]
|
||||||
if (not len(k)==1) or (not k[0].ARD):
|
if (not len(k)==1) or (not k[0].ARD):
|
||||||
raise ValueError, "cannot Atomatically determine which dimensions to plot, please pass 'which_indices'"
|
raise ValueError, "cannot Atomatically determine which dimensions to plot, please pass 'which_indices'"
|
||||||
|
input_1, input_2 = self.lengthscale_order()
|
||||||
k = k[0]
|
k = k[0]
|
||||||
if k.name=='rbf':
|
if k.name=='rbf':
|
||||||
input_1, input_2 = np.argsort(k.lengthscale)[:2]
|
input_1, input_2 = np.argsort(k.lengthscale)[:2]
|
||||||
|
|
@ -92,7 +93,7 @@ class GPLVM(GP):
|
||||||
Xtest_full = np.zeros((Xtest.shape[0], self.X.shape[1]))
|
Xtest_full = np.zeros((Xtest.shape[0], self.X.shape[1]))
|
||||||
Xtest_full[:, :2] = Xtest
|
Xtest_full[:, :2] = Xtest
|
||||||
mu, var, low, up = self.predict(Xtest_full)
|
mu, var, low, up = self.predict(Xtest_full)
|
||||||
var = var[:, :2]
|
var = var.mean(axis=1) # this was var[:, :2] edit by Neil
|
||||||
pb.imshow(var.reshape(resolution,resolution).T[::-1,:],extent=[xmin[0],xmax[0],xmin[1],xmax[1]],cmap=pb.cm.binary,interpolation='bilinear')
|
pb.imshow(var.reshape(resolution,resolution).T[::-1,:],extent=[xmin[0],xmax[0],xmin[1],xmax[1]],cmap=pb.cm.binary,interpolation='bilinear')
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -122,4 +123,4 @@ class GPLVM(GP):
|
||||||
pb.xlim(xmin[0],xmax[0])
|
pb.xlim(xmin[0],xmax[0])
|
||||||
pb.ylim(xmin[1],xmax[1])
|
pb.ylim(xmin[1],xmax[1])
|
||||||
|
|
||||||
return input_1, input_2
|
return pb.gca() #input_1, input_2 temporary removal, to return axes.
|
||||||
|
|
|
||||||
|
|
@ -10,4 +10,6 @@ import Tango
|
||||||
import misc
|
import misc
|
||||||
import warping_functions
|
import warping_functions
|
||||||
import datasets
|
import datasets
|
||||||
|
import mocap
|
||||||
|
import visualize
|
||||||
import decorators
|
import decorators
|
||||||
|
|
|
||||||
|
|
@ -15,12 +15,12 @@ def sample_class(f):
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def della_gatta_TRP63_gene_expression(gene_number=None):
|
def della_gatta_TRP63_gene_expression(gene_number=None):
|
||||||
matData = scipy.io.loadmat(os.path.join(data_path, 'DellaGattadata.mat'))
|
mat_data = scipy.io.loadmat(os.path.join(data_path, 'DellaGattadata.mat'))
|
||||||
X = np.double(matData['timepoints'])
|
X = np.double(mat_data['timepoints'])
|
||||||
if gene_number == None:
|
if gene_number == None:
|
||||||
Y = matData['exprs_tp53_RMA']
|
Y = mat_data['exprs_tp53_RMA']
|
||||||
else:
|
else:
|
||||||
Y = matData['exprs_tp53_RMA'][:, gene_number]
|
Y = mat_data['exprs_tp53_RMA'][:, gene_number]
|
||||||
if len(Y.shape) == 1:
|
if len(Y.shape) == 1:
|
||||||
Y = Y[:, None]
|
Y = Y[:, None]
|
||||||
return {'X': X, 'Y': Y, 'info': "The full gene expression data set from della Gatta et al (http://www.ncbi.nlm.nih.gov/pmc/articles/PMC2413161/) processed by RMA."}
|
return {'X': X, 'Y': Y, 'info': "The full gene expression data set from della Gatta et al (http://www.ncbi.nlm.nih.gov/pmc/articles/PMC2413161/) processed by RMA."}
|
||||||
|
|
@ -60,28 +60,42 @@ def pumadyn(seed=default_seed):
|
||||||
return {'X': X, 'Y': Y, 'Xtest': Xtest, 'Ytest': Ytest, 'info': "The puma robot arm data with 32 inputs. This data is the non linear case with medium noise (pumadyn-32nm). For training 7,168 examples are sampled without replacement."}
|
return {'X': X, 'Y': Y, 'Xtest': Xtest, 'Ytest': Ytest, 'info': "The puma robot arm data with 32 inputs. This data is the non linear case with medium noise (pumadyn-32nm). For training 7,168 examples are sampled without replacement."}
|
||||||
|
|
||||||
|
|
||||||
|
def brendan_faces():
|
||||||
|
mat_data = scipy.io.loadmat(os.path.join(data_path, 'frey_rawface.mat'))
|
||||||
|
Y = mat_data['ff'].T
|
||||||
|
return {'Y': Y, 'info': "Face data made available by Brendan Frey"}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def silhouette():
|
def silhouette():
|
||||||
# Ankur Agarwal and Bill Trigg's silhoutte data.
|
# Ankur Agarwal and Bill Trigg's silhoutte data.
|
||||||
matData = scipy.io.loadmat(os.path.join(data_path, 'mocap', 'ankur', 'ankurDataPoseSilhouette.mat'))
|
mat_data = scipy.io.loadmat(os.path.join(data_path, 'mocap', 'ankur', 'ankurDataPoseSilhouette.mat'))
|
||||||
inMean = np.mean(matData['Y'])
|
inMean = np.mean(mat_data['Y'])
|
||||||
inScales = np.sqrt(np.var(matData['Y']))
|
inScales = np.sqrt(np.var(mat_data['Y']))
|
||||||
X = matData['Y'] - inMean
|
X = mat_data['Y'] - inMean
|
||||||
X = X/inScales
|
X = X/inScales
|
||||||
Xtest = matData['Y_test'] - inMean
|
Xtest = mat_data['Y_test'] - inMean
|
||||||
Xtest = Xtest/inScales
|
Xtest = Xtest/inScales
|
||||||
Y = matData['Z']
|
Y = mat_data['Z']
|
||||||
Ytest = matData['Z_test']
|
Ytest = mat_data['Z_test']
|
||||||
return {'X': X, 'Y': Y, 'Xtest': Xtest, 'Ytest': Ytest, 'info': "Artificial silhouette simulation data developed from Agarwal and Triggs (2004)."}
|
return {'X': X, 'Y': Y, 'Xtest': Xtest, 'Ytest': Ytest, 'info': "Artificial silhouette simulation data developed from Agarwal and Triggs (2004)."}
|
||||||
|
|
||||||
|
def stick():
|
||||||
|
Y, connect = GPy.util.mocap.load_text_data('run1', data_path)
|
||||||
|
Y = Y[0:-1:4, :]
|
||||||
|
lbls = 'connect'
|
||||||
|
return {'Y': Y, 'connect' : connect, 'info': "Stick man data from Ohio."}
|
||||||
|
|
||||||
|
|
||||||
def swiss_roll_1000():
|
def swiss_roll_1000():
|
||||||
matData = scipy.io.loadmat(os.path.join(data_path, 'swiss_roll_data'))
|
mat_data = scipy.io.loadmat(os.path.join(data_path, 'swiss_roll_data'))
|
||||||
Y = matData['X_data'][:, 0:1000].transpose()
|
Y = mat_data['X_data'][:, 0:1000].transpose()
|
||||||
return {'Y': Y, 'info': "Subsample of the swiss roll data extracting only the first 1000 values."}
|
return {'Y': Y, 'info': "Subsample of the swiss roll data extracting only the first 1000 values."}
|
||||||
|
|
||||||
def swiss_roll():
|
def swiss_roll():
|
||||||
matData = scipy.io.loadmat(os.path.join(data_path, 'swiss_roll_data.mat'))
|
mat_data = scipy.io.loadmat(os.path.join(data_path, 'swiss_roll_data.mat'))
|
||||||
Y = matData['X_data'][:, 0:3000].transpose()
|
Y = mat_data['X_data'][:, 0:3000].transpose()
|
||||||
return {'Y': Y, 'info': "The first 3,000 points from the swiss roll data of Tennenbaum, de Silva and Langford (2001)."}
|
return {'Y': Y, 'info': "The first 3,000 points from the swiss roll data of Tennenbaum, de Silva and Langford (2001)."}
|
||||||
|
|
||||||
def toy_rbf_1d(seed=default_seed):
|
def toy_rbf_1d(seed=default_seed):
|
||||||
|
|
@ -202,3 +216,4 @@ def creep_data():
|
||||||
features.extend(range(2, 31))
|
features.extend(range(2, 31))
|
||||||
X = all_data[:,features].copy()
|
X = all_data[:,features].copy()
|
||||||
return {'X': X, 'y' : y}
|
return {'X': X, 'y' : y}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -145,7 +145,8 @@ def PCA(Y, Q):
|
||||||
"""
|
"""
|
||||||
if not np.allclose(Y.mean(axis=0), 0.0):
|
if not np.allclose(Y.mean(axis=0), 0.0):
|
||||||
print "Y is not zero mean, centering it locally (GPy.util.linalg.PCA)"
|
print "Y is not zero mean, centering it locally (GPy.util.linalg.PCA)"
|
||||||
Y -= Y.mean(axis=0)
|
|
||||||
|
Y -= Y.mean(axis=0)
|
||||||
|
|
||||||
Z = linalg.svd(Y, full_matrices = False)
|
Z = linalg.svd(Y, full_matrices = False)
|
||||||
[X, W] = [Z[0][:,0:Q], np.dot(np.diag(Z[1]), Z[2]).T[:,0:Q]]
|
[X, W] = [Z[0][:,0:Q], np.dot(np.diag(Z[1]), Z[2]).T[:,0:Q]]
|
||||||
|
|
|
||||||
|
|
@ -8,12 +8,17 @@ def load_text_data(dataset, directory, centre=True):
|
||||||
# Remove markers where there is a NaN
|
# Remove markers where there is a NaN
|
||||||
present_index = [i for i in range(points[0].shape[1]) if not (np.any(np.isnan(points[0][:, i])) or np.any(np.isnan(points[0][:, i])) or np.any(np.isnan(points[0][:, i])))]
|
present_index = [i for i in range(points[0].shape[1]) if not (np.any(np.isnan(points[0][:, i])) or np.any(np.isnan(points[0][:, i])) or np.any(np.isnan(points[0][:, i])))]
|
||||||
|
|
||||||
|
point_names = point_names[present_index]
|
||||||
|
for i in range(3):
|
||||||
|
points[i] = points[i][:, present_index]
|
||||||
|
if centre:
|
||||||
|
points[i] = (points[i].T - points[i].mean(axis=1)).T
|
||||||
|
|
||||||
# Concatanate the X, Y and Z markers together
|
# Concatanate the X, Y and Z markers together
|
||||||
Y = np.concatenate((points[0][:, present_index], points[1][:, present_index], points[2][:, present_index]), axis=1)
|
Y = np.concatenate((points[0], points[1], points[2]), axis=1)
|
||||||
if centre:
|
|
||||||
Y = Y - Y.mean(axis=0)
|
|
||||||
Y = Y/400.
|
Y = Y/400.
|
||||||
return Y
|
connect = read_connections(os.path.join(directory, 'connections.txt'), point_names)
|
||||||
|
return Y, connect
|
||||||
|
|
||||||
def parse_text(file_name):
|
def parse_text(file_name):
|
||||||
"""Parse data from Ohio State University text mocap files (http://accad.osu.edu/research/mocap/mocap_data.htm)."""
|
"""Parse data from Ohio State University text mocap files (http://accad.osu.edu/research/mocap/mocap_data.htm)."""
|
||||||
|
|
@ -23,7 +28,7 @@ def parse_text(file_name):
|
||||||
point_names = np.array(fid.readline().split())[2:-1:3]
|
point_names = np.array(fid.readline().split())[2:-1:3]
|
||||||
fid.close()
|
fid.close()
|
||||||
for i in range(len(point_names)):
|
for i in range(len(point_names)):
|
||||||
point_names[i] = point_names[i][0:-3]
|
point_names[i] = point_names[i][0:-2]
|
||||||
|
|
||||||
# Read the matrix data
|
# Read the matrix data
|
||||||
S = np.loadtxt(file_name, skiprows=1)
|
S = np.loadtxt(file_name, skiprows=1)
|
||||||
|
|
@ -42,37 +47,28 @@ def parse_text(file_name):
|
||||||
|
|
||||||
return points, point_names, times
|
return points, point_names, times
|
||||||
|
|
||||||
#def read_connections():
|
def read_connections(file_name, point_names):
|
||||||
|
"""Read a file detailing which markers should be connected to which for motion capture data."""
|
||||||
|
|
||||||
|
connections = []
|
||||||
|
fid = open(file_name, 'r')
|
||||||
|
line=fid.readline()
|
||||||
|
while(line):
|
||||||
|
connections.append(np.array(line.split(',')))
|
||||||
|
connections[-1][0] = connections[-1][0].strip()
|
||||||
|
connections[-1][1] = connections[-1][1].strip()
|
||||||
|
line = fid.readline()
|
||||||
|
connect = np.zeros((len(point_names), len(point_names)),dtype=bool)
|
||||||
|
for i in range(len(point_names)):
|
||||||
|
for j in range(len(point_names)):
|
||||||
|
for k in range(len(connections)):
|
||||||
|
if connections[k][0] == point_names[i] and connections[k][1] == point_names[j]:
|
||||||
|
|
||||||
|
connect[i,j]=True
|
||||||
|
connect[j,i]=True
|
||||||
|
break
|
||||||
|
|
||||||
|
return connect
|
||||||
|
|
||||||
# fid = fopen(fileName);
|
|
||||||
# i = 1;
|
|
||||||
# rem = fgets(fid);
|
|
||||||
# while(rem ~= -1)
|
|
||||||
# [token, rem] = strtok(rem, ',');
|
|
||||||
# connections{i, 1} = fliplr(deblank(fliplr(deblank(token))));
|
|
||||||
# [token, rem] = strtok(rem, ',');
|
|
||||||
# connections{i, 2} = fliplr(deblank(fliplr(deblank(token))));
|
|
||||||
# i = i + 1;
|
|
||||||
# rem = fgets(fid);
|
|
||||||
# end
|
|
||||||
|
|
||||||
# connect = zeros(length(pointNames));
|
|
||||||
# fclose(fid);
|
|
||||||
# for i = 1:size(connections, 1);
|
|
||||||
# for j = 1:length(pointNames)
|
|
||||||
# if strcmp(pointNames{j}, connections{i, 1}) | ...
|
|
||||||
# strcmp(pointNames{j}, connections{i, 2})
|
|
||||||
# for k = 1:length(pointNames)
|
|
||||||
# if k == j
|
|
||||||
# break
|
|
||||||
# end
|
|
||||||
# if strcmp(pointNames{k}, connections{i, 1}) | ...
|
|
||||||
# strcmp(pointNames{k}, connections{i, 2})
|
|
||||||
# connect(j, k) = 1;
|
|
||||||
# end
|
|
||||||
# end
|
|
||||||
# end
|
|
||||||
# end
|
|
||||||
# end
|
|
||||||
# connect = sparse(connect);
|
|
||||||
|
|
|
||||||
|
|
@ -4,18 +4,18 @@ import GPy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
class lvm:
|
class lvm:
|
||||||
def __init__(self, model, data_visualize, ax):
|
def __init__(self, model, data_visualize, latent_axis):
|
||||||
self.cid = ax.figure.canvas.mpl_connect('button_press_event', self.on_click)
|
self.cid = latent_axis.figure.canvas.mpl_connect('button_press_event', self.on_click)
|
||||||
self.cid = ax.figure.canvas.mpl_connect('motion_notify_event', self.on_move)
|
self.cid = latent_axis.figure.canvas.mpl_connect('motion_notify_event', self.on_move)
|
||||||
self.data_visualize = data_visualize
|
self.data_visualize = data_visualize
|
||||||
self.model = model
|
self.model = model
|
||||||
self.ax = ax
|
self.latent_axis = latent_axis
|
||||||
self.called = False
|
self.called = False
|
||||||
self.move_on = False
|
self.move_on = False
|
||||||
|
|
||||||
def on_click(self, event):
|
def on_click(self, event):
|
||||||
print 'click', event.xdata, event.ydata
|
print 'click', event.xdata, event.ydata
|
||||||
if event.inaxes!=self.ax: return
|
if event.inaxes!=self.latent_axis: return
|
||||||
self.move_on = not self.move_on
|
self.move_on = not self.move_on
|
||||||
print
|
print
|
||||||
if self.called:
|
if self.called:
|
||||||
|
|
@ -26,10 +26,10 @@ class lvm:
|
||||||
else:
|
else:
|
||||||
self.xs = [event.xdata]
|
self.xs = [event.xdata]
|
||||||
self.ys = [event.ydata]
|
self.ys = [event.ydata]
|
||||||
self.line, = ax.plot(event.xdata, event.ydata)
|
self.line, = self.latent_axis.plot(event.xdata, event.ydata)
|
||||||
self.called = True
|
self.called = True
|
||||||
def on_move(self, event):
|
def on_move(self, event):
|
||||||
if event.inaxes!=self.ax: return
|
if event.inaxes!=self.latent_axis: return
|
||||||
if self.called and self.move_on:
|
if self.called and self.move_on:
|
||||||
# Call modify code on move
|
# Call modify code on move
|
||||||
#print 'move', event.xdata, event.ydata
|
#print 'move', event.xdata, event.ydata
|
||||||
|
|
@ -68,28 +68,34 @@ class vector_show(data_show):
|
||||||
|
|
||||||
class image_show(data_show):
|
class image_show(data_show):
|
||||||
"""Show a data vector as an image."""
|
"""Show a data vector as an image."""
|
||||||
def __init__(self, vals, axis=None, dimensions=(16,16), transpose=False, invert=False):
|
def __init__(self, vals, axis=None, dimensions=(16,16), transpose=False, invert=False, scale=False):
|
||||||
data_show.__init__(self, vals, axis)
|
data_show.__init__(self, vals, axis)
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
self.fig_display = plt.figure()
|
|
||||||
self.set_image(vals)
|
|
||||||
self.handle = plt.imshow(self.vals)
|
|
||||||
self.transpose = transpose
|
self.transpose = transpose
|
||||||
self.invert = invert
|
self.invert = invert
|
||||||
|
self.scale = scale
|
||||||
|
self.set_image(vals/255.)
|
||||||
|
self.handle = self.axis.imshow(self.vals, interpolation='nearest')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
def modify(self, vals):
|
def modify(self, vals):
|
||||||
self.set_image(vals)
|
self.set_image(vals/255.)
|
||||||
|
#self.handle.remove()
|
||||||
|
#self.handle = self.axis.imshow(self.vals)
|
||||||
self.handle.set_array(self.vals)
|
self.handle.set_array(self.vals)
|
||||||
self.axis.figure.canvas.draw()
|
#self.axis.figure.canvas.draw()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
def set_image(self, vals):
|
def set_image(self, vals):
|
||||||
self.vals = np.reshape(vals, self.dimensions)
|
self.vals = np.reshape(vals, self.dimensions, order='F')
|
||||||
if self.transpose:
|
if self.transpose:
|
||||||
self.vals = self.vals.T
|
self.vals = self.vals.T
|
||||||
if self.invert:
|
if not self.scale:
|
||||||
self.vals = -self.vals
|
self.vals = self.vals
|
||||||
|
#if self.invert:
|
||||||
|
# self.vals = -self.vals
|
||||||
|
|
||||||
class stick_show(data_show):
|
class stick_show(data_show):
|
||||||
"""Show a three dimensional point cloud as a figure. Connect elements of the figure together using the matrix connect."""
|
"""Show a three dimensional point cloud as a figure. Connect elements of the figure together using the matrix connect."""
|
||||||
|
|
||||||
def __init__(self, vals, axis=None, connect=None):
|
def __init__(self, vals, axis=None, connect=None):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue