diff --git a/GPy/core/model.py b/GPy/core/model.py index 0804f277..9216aea6 100644 --- a/GPy/core/model.py +++ b/GPy/core/model.py @@ -23,13 +23,13 @@ class model(parameterised): self._set_params(self._get_params()) self.preferred_optimizer = 'tnc' def _get_params(self): - raise NotImplementedError, "this needs to be implemented to utilise the model class" + raise NotImplementedError, "this needs to be implemented to use the model class" def _set_params(self,x): - raise NotImplementedError, "this needs to be implemented to utilise the model class" + raise NotImplementedError, "this needs to be implemented to use the model class" def log_likelihood(self): - raise NotImplementedError, "this needs to be implemented to utilise the model class" + raise NotImplementedError, "this needs to be implemented to use the model class" def _log_likelihood_gradients(self): - raise NotImplementedError, "this needs to be implemented to utilise the model class" + raise NotImplementedError, "this needs to be implemented to use the model class" def set_prior(self,which,what): """ diff --git a/GPy/examples/dimensionality_reduction.py b/GPy/examples/dimensionality_reduction.py index d7610acb..e4308465 100644 --- a/GPy/examples/dimensionality_reduction.py +++ b/GPy/examples/dimensionality_reduction.py @@ -3,6 +3,8 @@ import numpy as np import pylab as pb +from matplotlib import pyplot as plt + import GPy default_seed = np.random.seed(123344) @@ -55,3 +57,51 @@ def GPLVM_oil_100(): print(m) m.plot_latent(labels=data['Y'].argmax(axis=1)) return m + +def oil_100(): + data = GPy.util.datasets.oil_100() + m = GPy.models.GPLVM(data['X'], 2) + + # optimize + m.ensure_default_constraints() + m.optimize(messages=1, max_iters=2) + + # plot + print(m) + #m.plot_latent(labels=data['Y'].argmax(axis=1)) + return m + +def brendan_faces(): + data = GPy.util.datasets.brendan_faces() + Y = data['Y'][0:-1:10, :] + m = GPy.models.GPLVM(data['Y'], 2) + + # optimize + m.ensure_default_constraints() + m.optimize(messages=1, max_f_eval=10000) + + ax = m.plot_latent() + y = m.likelihood.Y[0,:] + data_show = GPy.util.visualize.image_show(y[None, :], dimensions=(20, 28), transpose=True, invert=False, scale=False) + lvm_visualizer = GPy.util.visualize.lvm(m, data_show, ax) + raw_input('Press enter to finish') + plt.close('all') + + return m + +def stick(): + data = GPy.util.datasets.stick() + m = GPy.models.GPLVM(data['Y'], 2) + + # optimize + m.ensure_default_constraints() + m.optimize(messages=1, max_f_eval=10000) + + ax = m.plot_latent() + y = m.likelihood.Y[0,:] + data_show = GPy.util.visualize.stick_show(y[None, :], connect=data['connect']) + lvm_visualizer = GPy.util.visualize.lvm(m, data_show, ax) + raw_input('Press enter to finish') + plt.close('all') + + return m diff --git a/GPy/examples/regression.py b/GPy/examples/regression.py index 7de95d20..1a35df2f 100644 --- a/GPy/examples/regression.py +++ b/GPy/examples/regression.py @@ -73,7 +73,7 @@ def silhouette(): def coregionalisation_toy2(): """ - A simple demonstration of coregionalisation on two sinusoidal functions + A simple demonstration of coregionalisation on two sinusoidal functions. """ X1 = np.random.rand(50,1)*8 X2 = np.random.rand(30,1)*5 @@ -106,7 +106,7 @@ def coregionalisation_toy2(): def coregionalisation_toy(): """ - A simple demonstration of coregionalisation on two sinusoidal functions + A simple demonstration of coregionalisation on two sinusoidal functions. """ X1 = np.random.rand(50,1)*8 X2 = np.random.rand(30,1)*5 @@ -139,7 +139,7 @@ def coregionalisation_toy(): def coregionalisation_sparse(): """ - A simple demonstration of coregionalisation on two sinusoidal functions + A simple demonstration of coregionalisation on two sinusoidal functions using sparse approximations. """ X1 = np.random.rand(500,1)*8 X2 = np.random.rand(300,1)*5 diff --git a/GPy/kern/__init__.py b/GPy/kern/__init__.py index 5d8a7d15..f062ee56 100644 --- a/GPy/kern/__init__.py +++ b/GPy/kern/__init__.py @@ -2,5 +2,5 @@ # Licensed under the BSD 3-clause license (see LICENSE.txt) -from constructors import rbf, Matern32, Matern52, exponential, linear, white, bias, finite_dimensional, spline, Brownian, rbf_sympy, sympykern, periodic_exponential, periodic_Matern32, periodic_Matern52, prod, prod_orthogonal, symmetric, coregionalise, rational_quadratic, fixed +from constructors import rbf, Matern32, Matern52, exponential, linear, white, bias, finite_dimensional, spline, Brownian, rbf_sympy, sympykern, periodic_exponential, periodic_Matern32, periodic_Matern52, prod, prod_orthogonal, symmetric, coregionalise, rational_quadratic, fixed, rbfcos from kern import kern diff --git a/GPy/kern/constructors.py b/GPy/kern/constructors.py index 90b13600..6a968da4 100644 --- a/GPy/kern/constructors.py +++ b/GPy/kern/constructors.py @@ -24,6 +24,7 @@ from prod_orthogonal import prod_orthogonal as prod_orthogonalpart from symmetric import symmetric as symmetric_part from coregionalise import coregionalise as coregionalise_part from rational_quadratic import rational_quadratic as rational_quadraticpart +from rbfcos import rbfcos as rbfcospart #TODO these s=constructors are not as clean as we'd like. Tidy the code up #using meta-classes to make the objects construct properly wthout them. @@ -310,3 +311,10 @@ def fixed(D, K, variance=1.): """ part = fixedpart(D, K, variance) return kern(D, [part]) + +def rbfcos(D,variance=1.,frequencies=None,bandwidths=None,ARD=False): + """ + construct a rbfcos kernel + """ + part = rbfcospart(D,variance,frequencies,bandwidths,ARD) + return kern(D,[part]) diff --git a/GPy/kern/rbfcos.py b/GPy/kern/rbfcos.py new file mode 100644 index 00000000..094b806b --- /dev/null +++ b/GPy/kern/rbfcos.py @@ -0,0 +1,117 @@ + +# Copyright (c) 2012, James Hensman and Andrew Gordon Wilson +# Licensed under the BSD 3-clause license (see LICENSE.txt) + + +from kernpart import kernpart +import numpy as np + +class rbfcos(kernpart): + def __init__(self,D,variance=1.,frequencies=None,bandwidths=None,ARD=False): + self.D = D + self.name = 'rbfcos' + if self.D>10: + print "Warning: the rbfcos kernel requires a lot of memory for high dimensional inputs" + self.ARD = ARD + + #set the default frequencies and bandwidths, appropriate Nparam + if ARD: + self.Nparam = 2*self.D + 1 + if frequencies is not None: + frequencies = np.asarray(frequencies) + assert frequencies.size == self.D, "bad number of frequencies" + else: + frequencies = np.ones(self.D) + if bandwidths is not None: + bandwidths = np.asarray(bandwidths) + assert bandwidths.size == self.D, "bad number of bandwidths" + else: + bandwidths = np.ones(self.D) + else: + self.Nparam = 3 + if frequencies is not None: + frequencies = np.asarray(frequencies) + assert frequencies.size == 1, "Exactly one frequency needed for non-ARD kernel" + else: + frequencies = np.ones(1) + + if bandwidths is not None: + bandwidths = np.asarray(bandwidths) + assert bandwidths.size == 1, "Exactly one bandwidth needed for non-ARD kernel" + else: + bandwidths = np.ones(1) + + #initialise cache + self._X, self._X2, self._params = np.empty(shape=(3,1)) + + self._set_params(np.hstack((variance,frequencies.flatten(),bandwidths.flatten()))) + + + def _get_params(self): + return np.hstack((self.variance,self.frequencies, self.bandwidths)) + + def _set_params(self,x): + assert x.size==(self.Nparam) + if self.ARD: + self.variance = x[0] + self.frequencies = x[1:1+self.D] + self.bandwidths = x[1+self.D:] + else: + self.variance, self.frequencies, self.bandwidths = x + + def _get_param_names(self): + if self.Nparam == 3: + return ['variance','frequency','bandwidth'] + else: + return ['variance']+['frequency_%i'%i for i in range(self.D)]+['bandwidth_%i'%i for i in range(self.D)] + + def K(self,X,X2,target): + self._K_computations(X,X2) + target += self.variance*self._dvar + + def Kdiag(self,X,target): + np.add(target,self.variance,target) + + def dK_dtheta(self,dL_dK,X,X2,target): + self._K_computations(X,X2) + target[0] += np.sum(dL_dK*self._dvar) + if self.ARD: + for q in xrange(self.D): + target[q+1] += -2.*np.pi*self.variance*np.sum(dL_dK*self._dvar*np.tan(2.*np.pi*self._dist[:,:,q]*self.frequencies[q])*self._dist[:,:,q]) + target[q+1+self.D] += -2.*np.pi**2*self.variance*np.sum(dL_dK*self._dvar*self._dist2[:,:,q]) + else: + target[1] += -2.*np.pi*self.variance*np.sum(dL_dK*self._dvar*np.sum(np.tan(2.*np.pi*self._dist*self.frequencies)*self._dist,-1)) + target[2] += -2.*np.pi**2*self.variance*np.sum(dL_dK*self._dvar*self._dist2.sum(-1)) + + + def dKdiag_dtheta(self,dL_dKdiag,X,target): + target[0] += np.sum(dL_dKdiag) + + def dK_dX(self,dL_dK,X,X2,target): + #TODO!!! + raise NotImplementedError + + def dKdiag_dX(self,dL_dKdiag,X,target): + pass + + def _K_computations(self,X,X2): + if not (np.all(X==self._X) and np.all(X2==self._X2)): + if X2 is None: X2 = X + self._X = X.copy() + self._X2 = X2.copy() + + #do the distances: this will be high memory for large D + #NB: we don't take the abs of the dist because cos is symmetric + self._dist = X[:,None,:] - X2[None,:,:] + self._dist2 = np.square(self._dist) + + #ensure the next section is computed: + self._params = np.empty(self.Nparam) + + if not np.all(self._params == self._get_params()): + self._params == self._get_params().copy() + + self._rbf_part = np.exp(-2.*np.pi**2*np.sum(self._dist2*self.bandwidths,-1)) + self._cos_part = np.prod(np.cos(2.*np.pi*self._dist*self.frequencies),-1) + self._dvar = self._rbf_part*self._cos_part + diff --git a/GPy/models/GPLVM.py b/GPy/models/GPLVM.py index 32594594..03e9b715 100644 --- a/GPy/models/GPLVM.py +++ b/GPy/models/GPLVM.py @@ -81,6 +81,7 @@ class GPLVM(GP): k = [p for p in self.kern.parts if p.name in ['rbf','linear']] if (not len(k)==1) or (not k[0].ARD): raise ValueError, "cannot Atomatically determine which dimensions to plot, please pass 'which_indices'" + input_1, input_2 = self.lengthscale_order() k = k[0] if k.name=='rbf': input_1, input_2 = np.argsort(k.lengthscale)[:2] @@ -92,7 +93,7 @@ class GPLVM(GP): Xtest_full = np.zeros((Xtest.shape[0], self.X.shape[1])) Xtest_full[:, :2] = Xtest mu, var, low, up = self.predict(Xtest_full) - var = var[:, :2] + var = var.mean(axis=1) # this was var[:, :2] edit by Neil pb.imshow(var.reshape(resolution,resolution).T[::-1,:],extent=[xmin[0],xmax[0],xmin[1],xmax[1]],cmap=pb.cm.binary,interpolation='bilinear') @@ -122,4 +123,4 @@ class GPLVM(GP): pb.xlim(xmin[0],xmax[0]) pb.ylim(xmin[1],xmax[1]) - return input_1, input_2 + return pb.gca() #input_1, input_2 temporary removal, to return axes. diff --git a/GPy/util/__init__.py b/GPy/util/__init__.py index c91557d0..56dbd5b9 100644 --- a/GPy/util/__init__.py +++ b/GPy/util/__init__.py @@ -10,4 +10,6 @@ import Tango import misc import warping_functions import datasets +import mocap +import visualize import decorators diff --git a/GPy/util/datasets.py b/GPy/util/datasets.py index ed808f1b..932690ec 100644 --- a/GPy/util/datasets.py +++ b/GPy/util/datasets.py @@ -15,12 +15,12 @@ def sample_class(f): return c def della_gatta_TRP63_gene_expression(gene_number=None): - matData = scipy.io.loadmat(os.path.join(data_path, 'DellaGattadata.mat')) - X = np.double(matData['timepoints']) + mat_data = scipy.io.loadmat(os.path.join(data_path, 'DellaGattadata.mat')) + X = np.double(mat_data['timepoints']) if gene_number == None: - Y = matData['exprs_tp53_RMA'] + Y = mat_data['exprs_tp53_RMA'] else: - Y = matData['exprs_tp53_RMA'][:, gene_number] + Y = mat_data['exprs_tp53_RMA'][:, gene_number] if len(Y.shape) == 1: Y = Y[:, None] return {'X': X, 'Y': Y, 'info': "The full gene expression data set from della Gatta et al (http://www.ncbi.nlm.nih.gov/pmc/articles/PMC2413161/) processed by RMA."} @@ -60,28 +60,42 @@ def pumadyn(seed=default_seed): return {'X': X, 'Y': Y, 'Xtest': Xtest, 'Ytest': Ytest, 'info': "The puma robot arm data with 32 inputs. This data is the non linear case with medium noise (pumadyn-32nm). For training 7,168 examples are sampled without replacement."} +def brendan_faces(): + mat_data = scipy.io.loadmat(os.path.join(data_path, 'frey_rawface.mat')) + Y = mat_data['ff'].T + return {'Y': Y, 'info': "Face data made available by Brendan Frey"} + + + def silhouette(): # Ankur Agarwal and Bill Trigg's silhoutte data. - matData = scipy.io.loadmat(os.path.join(data_path, 'mocap', 'ankur', 'ankurDataPoseSilhouette.mat')) - inMean = np.mean(matData['Y']) - inScales = np.sqrt(np.var(matData['Y'])) - X = matData['Y'] - inMean + mat_data = scipy.io.loadmat(os.path.join(data_path, 'mocap', 'ankur', 'ankurDataPoseSilhouette.mat')) + inMean = np.mean(mat_data['Y']) + inScales = np.sqrt(np.var(mat_data['Y'])) + X = mat_data['Y'] - inMean X = X/inScales - Xtest = matData['Y_test'] - inMean + Xtest = mat_data['Y_test'] - inMean Xtest = Xtest/inScales - Y = matData['Z'] - Ytest = matData['Z_test'] + Y = mat_data['Z'] + Ytest = mat_data['Z_test'] return {'X': X, 'Y': Y, 'Xtest': Xtest, 'Ytest': Ytest, 'info': "Artificial silhouette simulation data developed from Agarwal and Triggs (2004)."} +def stick(): + Y, connect = GPy.util.mocap.load_text_data('run1', data_path) + Y = Y[0:-1:4, :] + lbls = 'connect' + return {'Y': Y, 'connect' : connect, 'info': "Stick man data from Ohio."} + + def swiss_roll_1000(): - matData = scipy.io.loadmat(os.path.join(data_path, 'swiss_roll_data')) - Y = matData['X_data'][:, 0:1000].transpose() + mat_data = scipy.io.loadmat(os.path.join(data_path, 'swiss_roll_data')) + Y = mat_data['X_data'][:, 0:1000].transpose() return {'Y': Y, 'info': "Subsample of the swiss roll data extracting only the first 1000 values."} def swiss_roll(): - matData = scipy.io.loadmat(os.path.join(data_path, 'swiss_roll_data.mat')) - Y = matData['X_data'][:, 0:3000].transpose() + mat_data = scipy.io.loadmat(os.path.join(data_path, 'swiss_roll_data.mat')) + Y = mat_data['X_data'][:, 0:3000].transpose() return {'Y': Y, 'info': "The first 3,000 points from the swiss roll data of Tennenbaum, de Silva and Langford (2001)."} def toy_rbf_1d(seed=default_seed): @@ -202,3 +216,4 @@ def creep_data(): features.extend(range(2, 31)) X = all_data[:,features].copy() return {'X': X, 'y' : y} + diff --git a/GPy/util/linalg.py b/GPy/util/linalg.py index f21502a5..59f598f9 100644 --- a/GPy/util/linalg.py +++ b/GPy/util/linalg.py @@ -145,9 +145,10 @@ def PCA(Y, Q): """ if not np.allclose(Y.mean(axis=0), 0.0): print "Y is not zero mean, centering it locally (GPy.util.linalg.PCA)" - Y -= Y.mean(axis=0) + + #Y -= Y.mean(axis=0) - Z = linalg.svd(Y, full_matrices = False) + Z = linalg.svd(Y-Y.mean(axis=0), full_matrices = False) [X, W] = [Z[0][:,0:Q], np.dot(np.diag(Z[1]), Z[2]).T[:,0:Q]] v = X.std(axis=0) X /= v; diff --git a/GPy/util/mocap.py b/GPy/util/mocap.py new file mode 100644 index 00000000..e66a36b9 --- /dev/null +++ b/GPy/util/mocap.py @@ -0,0 +1,74 @@ +import os +import numpy as np + +def load_text_data(dataset, directory, centre=True): + """Load in a data set of marker points from the Ohio State University C3D motion capture files (http://accad.osu.edu/research/mocap/mocap_data.htm).""" + + points, point_names = parse_text(os.path.join(directory, dataset + '.txt'))[0:2] + # Remove markers where there is a NaN + present_index = [i for i in range(points[0].shape[1]) if not (np.any(np.isnan(points[0][:, i])) or np.any(np.isnan(points[0][:, i])) or np.any(np.isnan(points[0][:, i])))] + + point_names = point_names[present_index] + for i in range(3): + points[i] = points[i][:, present_index] + if centre: + points[i] = (points[i].T - points[i].mean(axis=1)).T + + # Concatanate the X, Y and Z markers together + Y = np.concatenate((points[0], points[1], points[2]), axis=1) + Y = Y/400. + connect = read_connections(os.path.join(directory, 'connections.txt'), point_names) + return Y, connect + +def parse_text(file_name): + """Parse data from Ohio State University text mocap files (http://accad.osu.edu/research/mocap/mocap_data.htm).""" + + # Read the header + fid = open(file_name, 'r') + point_names = np.array(fid.readline().split())[2:-1:3] + fid.close() + for i in range(len(point_names)): + point_names[i] = point_names[i][0:-2] + + # Read the matrix data + S = np.loadtxt(file_name, skiprows=1) + field = np.uint(S[:, 0]) + times = S[:, 1] + S = S[:, 2:] + + # Set the -9999.99 markers to be not present + S[S==-9999.99] = np.NaN + + # Store x, y and z in different arrays + points = [] + points.append(S[:, 0:-1:3]) + points.append(S[:, 1:-1:3]) + points.append(S[:, 2:-1:3]) + + return points, point_names, times + +def read_connections(file_name, point_names): + """Read a file detailing which markers should be connected to which for motion capture data.""" + + connections = [] + fid = open(file_name, 'r') + line=fid.readline() + while(line): + connections.append(np.array(line.split(','))) + connections[-1][0] = connections[-1][0].strip() + connections[-1][1] = connections[-1][1].strip() + line = fid.readline() + connect = np.zeros((len(point_names), len(point_names)),dtype=bool) + for i in range(len(point_names)): + for j in range(len(point_names)): + for k in range(len(connections)): + if connections[k][0] == point_names[i] and connections[k][1] == point_names[j]: + + connect[i,j]=True + connect[j,i]=True + break + + return connect + + + diff --git a/GPy/util/visualize.py b/GPy/util/visualize.py new file mode 100644 index 00000000..dde9cd32 --- /dev/null +++ b/GPy/util/visualize.py @@ -0,0 +1,164 @@ +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +import GPy +import numpy as np + +class lvm: + def __init__(self, model, data_visualize, latent_axis): + self.cid = latent_axis.figure.canvas.mpl_connect('button_press_event', self.on_click) + self.cid = latent_axis.figure.canvas.mpl_connect('motion_notify_event', self.on_move) + self.data_visualize = data_visualize + self.model = model + self.latent_axis = latent_axis + self.called = False + self.move_on = False + + def on_click(self, event): + #print 'click', event.xdata, event.ydata + if event.inaxes!=self.latent_axis: return + self.move_on = not self.move_on + # if self.called: + # self.xs.append(event.xdata) + # self.ys.append(event.ydata) + # self.line.set_data(self.xs, self.ys) + # self.line.figure.canvas.draw() + # else: + # self.xs = [event.xdata] + # self.ys = [event.ydata] + # self.line, = self.latent_axis.plot(event.xdata, event.ydata) + self.called = True + def on_move(self, event): + if event.inaxes!=self.latent_axis: return + if self.called and self.move_on: + # Call modify code on move + #print 'move', event.xdata, event.ydata + latent_values = np.array((event.xdata, event.ydata)) + y = self.model.predict(latent_values)[0] + self.data_visualize.modify(y) + #print 'y', y + +class data_show: + """The data show class is a base class which describes how to visualize a particular data set. For example, motion capture data can be plotted as a stick figure, or images are shown using imshow. This class enables latent to data visualizations for the GP-LVM.""" + + def __init__(self, vals, axis=None): + self.vals = vals + # If no axes are defined, create some. + if axis==None: + fig = plt.figure() + self.axis = fig.add_subplot(111) + else: + self.axis = axis + + def modify(self, vals): + raise NotImplementedError, "this needs to be implemented to use the data_show class" + +class vector_show(data_show): + """A base visualization class that just shows a data vector as a plot of vector elements alongside their indices.""" + def __init__(self, vals, axis=None): + data_show.__init__(self, vals, axis) + self.vals = vals.T + self.handle = plt.plot(np.arange(0, len(vals))[:, None], self.vals)[0] + + def modify(self, vals): + xdata, ydata = self.handle.get_data() + self.vals = vals.T + self.handle.set_data(xdata, self.vals) + self.axis.figure.canvas.draw() + +class image_show(data_show): + """Show a data vector as an image.""" + def __init__(self, vals, axis=None, dimensions=(16,16), transpose=False, invert=False, scale=False): + data_show.__init__(self, vals, axis) + self.dimensions = dimensions + self.transpose = transpose + self.invert = invert + self.scale = scale + self.set_image(vals/255.) + self.handle = self.axis.imshow(self.vals, cmap=plt.cm.gray, interpolation='nearest') + plt.show() + + def modify(self, vals): + self.set_image(vals/255.) + #self.handle.remove() + #self.handle = self.axis.imshow(self.vals) + self.handle.set_array(self.vals) + #self.axis.figure.canvas.draw() + plt.show() + + def set_image(self, vals): + self.vals = np.reshape(vals, self.dimensions, order='F') + if self.transpose: + self.vals = self.vals.T + if not self.scale: + self.vals = self.vals + #if self.invert: + # self.vals = -self.vals + +class stick_show(data_show): + """Show a three dimensional point cloud as a figure. Connect elements of the figure together using the matrix connect.""" + + def __init__(self, vals, axis=None, connect=None): + if axis==None: + fig = plt.figure() + axis = fig.add_subplot(111, projection='3d') + data_show.__init__(self, vals, axis) + self.vals = vals.reshape((3, vals.shape[1]/3)).T + self.x_lim = np.array([self.vals[:, 0].min(), self.vals[:, 0].max()]) + self.y_lim = np.array([self.vals[:, 1].min(), self.vals[:, 1].max()]) + self.z_lim = np.array([self.vals[:, 2].min(), self.vals[:, 2].max()]) + self.points_handle = self.axis.scatter(self.vals[:, 0], self.vals[:, 1], self.vals[:, 2]) + self.axis.set_xlim(self.x_lim) + self.axis.set_ylim(self.y_lim) + self.axis.set_zlim(self.z_lim) + self.axis.set_aspect(1) + self.axis.autoscale(enable=False) + + self.connect = connect + if not self.connect==None: + x = [] + y = [] + z = [] + self.I, self.J = np.nonzero(self.connect) + for i in range(len(self.I)): + x.append(self.vals[self.I[i], 0]) + x.append(self.vals[self.J[i], 0]) + x.append(np.NaN) + y.append(self.vals[self.I[i], 1]) + y.append(self.vals[self.J[i], 1]) + y.append(np.NaN) + z.append(self.vals[self.I[i], 2]) + z.append(self.vals[self.J[i], 2]) + z.append(np.NaN) + self.line_handle = self.axis.plot(np.array(x), np.array(y), np.array(z), 'b-') + self.axis.figure.canvas.draw() + + def modify(self, vals): + self.points_handle.remove() + self.line_handle[0].remove() + self.vals = vals.reshape((3, vals.shape[1]/3)).T + self.points_handle = self.axis.scatter(self.vals[:, 0], self.vals[:, 1], self.vals[:, 2]) + self.axis.set_xlim(self.x_lim) + self.axis.set_ylim(self.y_lim) + self.axis.set_zlim(self.z_lim) + self.line_handle = [] + if not self.connect==None: + x = [] + y = [] + z = [] + self.I, self.J = np.nonzero(self.connect) + for i in range(len(self.I)): + x.append(self.vals[self.I[i], 0]) + x.append(self.vals[self.J[i], 0]) + x.append(np.NaN) + y.append(self.vals[self.I[i], 1]) + y.append(self.vals[self.J[i], 1]) + y.append(np.NaN) + z.append(self.vals[self.I[i], 2]) + z.append(self.vals[self.J[i], 2]) + z.append(np.NaN) + self.line_handle = self.axis.plot(np.array(x), np.array(y), np.array(z), 'b-') + + self.axis.figure.canvas.draw() + + +