mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-27 14:25:16 +02:00
Merge branch 'devel' into mrd
This commit is contained in:
commit
3baeeb1e35
6 changed files with 170 additions and 75 deletions
|
|
@ -42,12 +42,12 @@ def BGPLVM(seed=default_seed):
|
|||
|
||||
return m
|
||||
|
||||
def GPLVM_oil_100(optimize=True, M=15):
|
||||
def GPLVM_oil_100(optimize=True):
|
||||
data = GPy.util.datasets.oil_100()
|
||||
|
||||
# create simple GP model
|
||||
kernel = GPy.kern.rbf(6, ARD=True) + GPy.kern.bias(6)
|
||||
m = GPy.models.GPLVM(data['X'], 6, kernel=kernel, M=M)
|
||||
kernel = GPy.kern.rbf(6, ARD = True) + GPy.kern.bias(6)
|
||||
m = GPy.models.GPLVM(data['X'], 6, kernel=kernel)
|
||||
m.data_labels = data['Y'].argmax(axis=1)
|
||||
|
||||
# optimize
|
||||
|
|
|
|||
|
|
@ -41,20 +41,21 @@ class Brownian(kernpart):
|
|||
def Kdiag(self,X,target):
|
||||
target += self.variance*X.flatten()
|
||||
|
||||
def dK_dtheta(self,X,X2,target):
|
||||
target += np.fmin(X,X2.T)
|
||||
def dK_dtheta(self,dL_dK,X,X2,target):
|
||||
target += np.sum(np.fmin(X,X2.T)*dL_dK)
|
||||
|
||||
def dKdiag_dtheta(self,X,target):
|
||||
target += X.flatten()
|
||||
def dKdiag_dtheta(self,dL_dKdiag,X,target):
|
||||
target += np.dot(X.flatten(), dL_dKdiag)
|
||||
|
||||
def dK_dX(self,X,X2,target):
|
||||
target += self.variance
|
||||
target -= self.variance*theta(X-X2.T)
|
||||
if X.shape==X2.shape:
|
||||
if np.all(X==X2):
|
||||
np.add(target[:,:,0],self.variance*np.diag(X2.flatten()-X.flatten()),target[:,:,0])
|
||||
def dK_dX(self,dL_dK,X,X2,target):
|
||||
raise NotImplementedError, "TODO"
|
||||
#target += self.variance
|
||||
#target -= self.variance*theta(X-X2.T)
|
||||
#if X.shape==X2.shape:
|
||||
#if np.all(X==X2):
|
||||
#np.add(target[:,:,0],self.variance*np.diag(X2.flatten()-X.flatten()),target[:,:,0])
|
||||
|
||||
|
||||
def dKdiag_dX(self,X,target):
|
||||
target += self.variance
|
||||
def dKdiag_dX(self,dL_dKdiag,X,target):
|
||||
target += self.variance*dL_dKdiag[:,None]
|
||||
|
||||
|
|
|
|||
|
|
@ -153,7 +153,7 @@ class rbf(kernpart):
|
|||
"""Shape N,M,M,Ntheta"""
|
||||
self._psi_computations(Z,mu,S)
|
||||
d_var = 2.*self._psi2/self.variance
|
||||
d_length = self._psi2[:,:,:,None]*(0.5*self._psi2_Zdist_sq*self._psi2_denom + 2.*self._psi2_mudist_sq + 2.*S[:,None,None,:]/self.lengthscale2)/(self.lengthscale*self._psi2_denom)
|
||||
d_length = 2.*self._psi2[:,:,:,None]*(self._psi2_Zdist_sq*self._psi2_denom + self._psi2_mudist_sq + S[:,None,None,:]/self.lengthscale2)/(self.lengthscale*self._psi2_denom)
|
||||
|
||||
target[0] += np.sum(dL_dpsi2*d_var)
|
||||
dpsi2_dlength = d_length*dL_dpsi2[:,:,:,None]
|
||||
|
|
@ -164,7 +164,7 @@ class rbf(kernpart):
|
|||
|
||||
def dpsi2_dZ(self,dL_dpsi2,Z,mu,S,target):
|
||||
self._psi_computations(Z,mu,S)
|
||||
term1 = 0.5*self._psi2_Zdist/self.lengthscale2 # M, M, Q
|
||||
term1 = self._psi2_Zdist/self.lengthscale2 # M, M, Q
|
||||
term2 = self._psi2_mudist/self._psi2_denom/self.lengthscale2 # N, M, M, Q
|
||||
dZ = self._psi2[:,:,:,None] * (term1[None] + term2)
|
||||
target += (dL_dpsi2[:,:,:,None]*dZ).sum(0).sum(0)
|
||||
|
|
@ -200,8 +200,8 @@ class rbf(kernpart):
|
|||
if not np.all(Z==self._Z):
|
||||
#Z has changed, compute Z specific stuff
|
||||
self._psi2_Zhat = 0.5*(Z[:,None,:] +Z[None,:,:]) # M,M,Q
|
||||
self._psi2_Zdist = Z[:,None,:]-Z[None,:,:] # M,M,Q
|
||||
self._psi2_Zdist_sq = np.square(self._psi2_Zdist)/self.lengthscale2 # M,M,Q
|
||||
self._psi2_Zdist = 0.5*(Z[:,None,:]-Z[None,:,:]) # M,M,Q
|
||||
self._psi2_Zdist_sq = np.square(self._psi2_Zdist/self.lengthscale) # M,M,Q
|
||||
self._Z = Z
|
||||
|
||||
if not (np.all(Z==self._Z) and np.all(mu==self._mu) and np.all(S==self._S)):
|
||||
|
|
@ -219,7 +219,7 @@ class rbf(kernpart):
|
|||
self._psi2_mudist, self._psi2_mudist_sq, self._psi2_exponent, _ = self.weave_psi2(mu,self._psi2_Zhat)
|
||||
#self._psi2_mudist = mu[:,None,None,:]-self._psi2_Zhat #N,M,M,Q
|
||||
#self._psi2_mudist_sq = np.square(self._psi2_mudist)/(self.lengthscale2*self._psi2_denom)
|
||||
#self._psi2_exponent = np.sum(-self._psi2_Zdist_sq/4. -self._psi2_mudist_sq -0.5*np.log(self._psi2_denom),-1) #N,M,M
|
||||
#self._psi2_exponent = np.sum(-self._psi2_Zdist_sq -self._psi2_mudist_sq -0.5*np.log(self._psi2_denom),-1) #N,M,M
|
||||
self._psi2 = np.square(self.variance)*np.exp(self._psi2_exponent) # N,M,M
|
||||
|
||||
#store matrices for caching
|
||||
|
|
@ -239,13 +239,13 @@ class rbf(kernpart):
|
|||
psi2 = np.empty((N,M,M))
|
||||
|
||||
psi2_Zdist_sq = self._psi2_Zdist_sq
|
||||
half_log_psi2_denom = 0.5*np.log(self._psi2_denom).squeeze()
|
||||
_psi2_denom = self._psi2_denom.squeeze().reshape(N,self.D)
|
||||
half_log_psi2_denom = 0.5*np.log(self._psi2_denom).squeeze().reshape(N,self.D)
|
||||
variance_sq = float(np.square(self.variance))
|
||||
if self.ARD:
|
||||
lengthscale2 = self.lengthscale2
|
||||
else:
|
||||
lengthscale2 = np.ones(Q)*self.lengthscale2
|
||||
_psi2_denom = self._psi2_denom.squeeze()
|
||||
code = """
|
||||
double tmp;
|
||||
|
||||
|
|
@ -265,7 +265,7 @@ class rbf(kernpart):
|
|||
mudist_sq(n,mm,m,q) = tmp;
|
||||
|
||||
//now psi2_exponent
|
||||
tmp = -psi2_Zdist_sq(m,mm,q)/4.0 - tmp - half_log_psi2_denom(n,q);
|
||||
tmp = -psi2_Zdist_sq(m,mm,q) - tmp - half_log_psi2_denom(n,q);
|
||||
psi2_exponent(n,mm,m) += tmp;
|
||||
if (m !=mm){
|
||||
psi2_exponent(n,m,mm) += tmp;
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class GPLVM(GP):
|
|||
Xtest_full = np.zeros((Xtest.shape[0], self.X.shape[1]))
|
||||
Xtest_full[:, :2] = Xtest
|
||||
mu, var, low, up = self.predict(Xtest_full)
|
||||
var = var[:, :1]
|
||||
var = var[:, :1]
|
||||
ax.imshow(var.reshape(resolution, resolution).T[::-1, :],
|
||||
extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear')
|
||||
|
||||
|
|
@ -109,17 +109,16 @@ class GPLVM(GP):
|
|||
else:
|
||||
x = self.X[index,input_1]
|
||||
y = self.X[index,input_2]
|
||||
pb.plot(x,y,marker='o',color=util.plot.Tango.nextMedium(),mew=0,label=this_label,linewidth=0)
|
||||
ax.plot(x,y,marker='o',color=util.plot.Tango.nextMedium(),mew=0,label=this_label,linewidth=0)
|
||||
|
||||
ax.set_xlabel('latent dimension %i' % input_1)
|
||||
ax.set_ylabel('latent dimension %i' % input_2)
|
||||
ax.set_xlabel('latent dimension %i'%input_1)
|
||||
ax.set_ylabel('latent dimension %i'%input_2)
|
||||
|
||||
if not np.all(labels==1.):
|
||||
ax.legend(loc=0, numpoints=1)
|
||||
ax.legend(loc=0,numpoints=1)
|
||||
|
||||
ax.set_xlim(xmin[0], xmax[0])
|
||||
ax.set_ylim(xmin[1], xmax[1])
|
||||
ax.grid(b=False) # remove the grid if present, it doesn't look good
|
||||
# ax = pb.gca()
|
||||
ax.set_xlim(xmin[0],xmax[0])
|
||||
ax.set_ylim(xmin[1],xmax[1])
|
||||
ax.grid(b=False) # remove the grid if present, it doesn't look good
|
||||
ax.set_aspect('auto') # set a nice aspect ratio
|
||||
return ax
|
||||
|
|
|
|||
|
|
@ -99,6 +99,7 @@ class sparse_GP(GP):
|
|||
self.V = (self.likelihood.precision/self.scale_factor)*self.likelihood.Y
|
||||
|
||||
#Compute A = L^-1 psi2 beta L^-T
|
||||
#self. A = mdot(self.Lmi,self.psi2_beta_scaled,self.Lmi.T)
|
||||
tmp = linalg.lapack.flapack.dtrtrs(self.Lm,self.psi2_beta_scaled.T,lower=1)[0]
|
||||
self.A = linalg.lapack.flapack.dtrtrs(self.Lm,np.asarray(tmp.T,order='F'),lower=1)[0]
|
||||
|
||||
|
|
@ -142,9 +143,11 @@ class sparse_GP(GP):
|
|||
|
||||
|
||||
# Compute dL_dKmm
|
||||
self.dL_dKmm = -0.5 * self.D * mdot(self.Lmi.T, self.A, self.Lmi)*sf2 # dB
|
||||
#self.dL_dKmm_old = -0.5 * self.D * mdot(self.Lmi.T, self.A, self.Lmi)*sf2 # dB
|
||||
#self.dL_dKmm += -0.5 * self.D * (- self.C/sf2 - 2.*mdot(self.C, self.psi2_beta_scaled, self.Kmmi) + self.Kmmi) # dC
|
||||
#self.dL_dKmm += np.dot(np.dot(self.E*sf2, self.psi2_beta_scaled) - self.Cpsi1VVpsi1, self.Kmmi) + 0.5*self.E # dD
|
||||
tmp = linalg.lapack.flapack.dtrtrs(self.Lm,np.asfortranarray(self.A),lower=1,trans=1)[0]
|
||||
self.dL_dKmm = -0.5*self.D*sf2*linalg.lapack.flapack.dtrtrs(self.Lm,np.asfortranarray(tmp.T),lower=1,trans=1)[0] #dA
|
||||
self.dL_dKmm += 0.5*(self.D*(self.C/sf2 -self.Kmmi) + self.E) + np.dot(np.dot(self.D*self.C + self.E*sf2,self.psi2_beta_scaled) - self.Cpsi1VVpsi1,self.Kmmi) # d(C+D)
|
||||
|
||||
#the partial derivative vector for the likelihood
|
||||
|
|
@ -186,6 +189,11 @@ class sparse_GP(GP):
|
|||
self._compute_kernel_matrices()
|
||||
if self.auto_scale_factor:
|
||||
self.scale_factor = np.sqrt(self.psi2.sum(0).mean()*self.likelihood.precision)
|
||||
#if self.auto_scale_factor:
|
||||
# if self.likelihood.is_heteroscedastic:
|
||||
# self.scale_factor = max(1,np.sqrt(self.psi2_beta_scaled.sum(0).mean()))
|
||||
# else:
|
||||
# self.scale_factor = np.sqrt(self.psi2.sum(0).mean()*self.likelihood.precision)
|
||||
self._computations()
|
||||
|
||||
def _get_params(self):
|
||||
|
|
|
|||
|
|
@ -2,22 +2,37 @@ import matplotlib.pyplot as plt
|
|||
from mpl_toolkits.mplot3d import Axes3D
|
||||
import GPy
|
||||
import numpy as np
|
||||
import matplotlib as mpl
|
||||
|
||||
class lvm:
|
||||
def __init__(self, model, data_visualize, latent_axis, latent_index=[0,1], latent_dim=2):
|
||||
self.cid = latent_axis.figure.canvas.mpl_connect('button_press_event', self.on_click)
|
||||
self.cid = latent_axis.figure.canvas.mpl_connect('motion_notify_event', self.on_move)
|
||||
def __init__(self, model, data_visualize, latent_axes, latent_index=[0,1]):
|
||||
if isinstance(latent_axes,mpl.axes.Axes):
|
||||
self.cid = latent_axes.figure.canvas.mpl_connect('button_press_event', self.on_click)
|
||||
self.cid = latent_axes.figure.canvas.mpl_connect('motion_notify_event', self.on_move)
|
||||
self.cid = latent_axes.figure.canvas.mpl_connect('axes_leave_event', self.on_leave)
|
||||
self.cid = latent_axes.figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
|
||||
else:
|
||||
self.cid = latent_axes[0].figure.canvas.mpl_connect('button_press_event', self.on_click)
|
||||
self.cid = latent_axes[0].figure.canvas.mpl_connect('motion_notify_event', self.on_move)
|
||||
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_leave_event', self.on_leave)
|
||||
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
|
||||
self.data_visualize = data_visualize
|
||||
self.model = model
|
||||
self.latent_axis = latent_axis
|
||||
self.latent_axes = latent_axes
|
||||
|
||||
self.called = False
|
||||
self.move_on = False
|
||||
self.latent_index = latent_index
|
||||
self.latent_dim = latent_dim
|
||||
self.latent_dim = model.Q
|
||||
|
||||
def on_enter(self,event):
|
||||
pass
|
||||
def on_leave(self,event):
|
||||
pass
|
||||
|
||||
def on_click(self, event):
|
||||
#print 'click', event.xdata, event.ydata
|
||||
if event.inaxes!=self.latent_axis: return
|
||||
if event.inaxes!=self.latent_axes: return
|
||||
self.move_on = not self.move_on
|
||||
# if self.called:
|
||||
# self.xs.append(event.xdata)
|
||||
|
|
@ -27,10 +42,10 @@ class lvm:
|
|||
# else:
|
||||
# self.xs = [event.xdata]
|
||||
# self.ys = [event.ydata]
|
||||
# self.line, = self.latent_axis.plot(event.xdata, event.ydata)
|
||||
# self.line, = self.latent_axes.plot(event.xdata, event.ydata)
|
||||
self.called = True
|
||||
def on_move(self, event):
|
||||
if event.inaxes!=self.latent_axis: return
|
||||
if event.inaxes!=self.latent_axes: return
|
||||
if self.called and self.move_on:
|
||||
# Call modify code on move
|
||||
#print 'move', event.xdata, event.ydata
|
||||
|
|
@ -40,52 +55,124 @@ class lvm:
|
|||
self.data_visualize.modify(y)
|
||||
#print 'y', y
|
||||
|
||||
class data_show:
|
||||
"""The data show class is a base class which describes how to visualize a particular data set. For example, motion capture data can be plotted as a stick figure, or images are shown using imshow. This class enables latent to data visualizations for the GP-LVM."""
|
||||
class lvm_subplots(lvm):
|
||||
"""
|
||||
latent_axes is a np array of dimension np.ceil(Q/2) + 1,
|
||||
one for each pair of the axes, and the last one for the sensitiity histogram
|
||||
"""
|
||||
def __init__(self, model, data_visualize, latent_axes=None, latent_index=[0,1]):
|
||||
self.nplots = int(np.ceil(model.Q/2.))+1
|
||||
lvm.__init__(self,model,data_visualize,latent_axes,latent_index)
|
||||
self.latent_values = np.zeros(2*np.ceil(self.model.Q/2.)) # possibly an extra dimension on this
|
||||
assert latent_axes.size == self.nplots
|
||||
|
||||
def __init__(self, vals, axis=None):
|
||||
|
||||
class lvm_dimselect(lvm):
|
||||
"""
|
||||
A visualizer for latent variable models
|
||||
with selection by clicking on the histogram
|
||||
"""
|
||||
def __init__(self, model, data_visualize):
|
||||
self.fig,(latent_axes,self.hist_axes) = plt.subplots(1,2)
|
||||
|
||||
lvm.__init__(self,model,data_visualize,latent_axes,[0,1])
|
||||
self.latent_values_clicked = np.zeros(model.Q)
|
||||
self.clicked_handle = self.latent_axes.plot([0],[0],'rx',mew=2)[0]
|
||||
print "use left and right mouse butons to select dimensions"
|
||||
|
||||
def on_click(self, event):
|
||||
#print "click"
|
||||
if event.inaxes==self.hist_axes:
|
||||
self.hist_axes.cla()
|
||||
self.hist_axes.bar(np.arange(self.model.Q),1./self.model.input_sensitivity(),color='b')
|
||||
new_index = max(0,min(int(np.round(event.xdata-0.5)),self.model.Q-1))
|
||||
self.latent_index[(0 if event.button==1 else 1)] = new_index
|
||||
self.hist_axes.bar(np.array(self.latent_index),1./self.model.input_sensitivity()[self.latent_index],color='r')
|
||||
self.latent_axes.cla()
|
||||
self.model.plot_latent(which_indices = self.latent_index,ax=self.latent_axes)
|
||||
self.clicked_handle = self.latent_axes.plot([self.latent_values_clicked[self.latent_index[0]]],self.latent_values_clicked[self.latent_index[1]],'rx',mew=2)[0]
|
||||
if event.inaxes==self.latent_axes:
|
||||
self.clicked_handle.set_visible(False)
|
||||
self.latent_values_clicked[self.latent_index] = np.array([event.xdata,event.ydata])
|
||||
self.clicked_handle = self.latent_axes.plot([self.latent_values_clicked[self.latent_index[0]]],self.latent_values_clicked[self.latent_index[1]],'rx',mew=2)[0]
|
||||
self.fig.canvas.draw()
|
||||
self.move_on=True
|
||||
self.called = True
|
||||
|
||||
|
||||
def on_move(self, event):
|
||||
#print "move"
|
||||
if event.inaxes!=self.latent_axes: return
|
||||
if self.called and self.move_on:
|
||||
latent_values = self.latent_values_clicked.copy()
|
||||
latent_values[self.latent_index] = np.array([event.xdata, event.ydata])
|
||||
y = self.model.predict(latent_values[None,:])[0]
|
||||
self.data_visualize.modify(y)
|
||||
|
||||
def on_leave(self,event):
|
||||
latent_values = self.latent_values_clicked.copy()
|
||||
y = self.model.predict(latent_values[None,:])[0]
|
||||
self.data_visualize.modify(y)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class data_show:
|
||||
"""
|
||||
The data show class is a base class which describes how to visualize a
|
||||
particular data set. For example, motion capture data can be plotted as a
|
||||
stick figure, or images are shown using imshow. This class enables latent
|
||||
to data visualizations for the GP-LVM.
|
||||
"""
|
||||
|
||||
def __init__(self, vals, axes=None):
|
||||
self.vals = vals
|
||||
# If no axes are defined, create some.
|
||||
if axis==None:
|
||||
if axes==None:
|
||||
fig = plt.figure()
|
||||
self.axis = fig.add_subplot(111)
|
||||
self.axes = fig.add_subplot(111)
|
||||
else:
|
||||
self.axis = axis
|
||||
self.axes = axes
|
||||
|
||||
def modify(self, vals):
|
||||
raise NotImplementedError, "this needs to be implemented to use the data_show class"
|
||||
|
||||
class vector_show(data_show):
|
||||
"""A base visualization class that just shows a data vector as a plot of vector elements alongside their indices."""
|
||||
def __init__(self, vals, axis=None):
|
||||
data_show.__init__(self, vals, axis)
|
||||
"""
|
||||
A base visualization class that just shows a data vector as a plot of
|
||||
vector elements alongside their indices.
|
||||
"""
|
||||
def __init__(self, vals, axes=None):
|
||||
data_show.__init__(self, vals, axes)
|
||||
self.vals = vals.T
|
||||
self.handle = self.axis.plot(np.arange(0, len(vals))[:, None], self.vals)[0]
|
||||
self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals)[0]
|
||||
|
||||
def modify(self, vals):
|
||||
xdata, ydata = self.handle.get_data()
|
||||
self.vals = vals.T
|
||||
self.handle.set_data(xdata, self.vals)
|
||||
self.axis.figure.canvas.draw()
|
||||
self.axes.figure.canvas.draw()
|
||||
|
||||
class image_show(data_show):
|
||||
"""Show a data vector as an image."""
|
||||
def __init__(self, vals, axis=None, dimensions=(16,16), transpose=False, invert=False, scale=False):
|
||||
data_show.__init__(self, vals, axis)
|
||||
def __init__(self, vals, axes=None, dimensions=(16,16), transpose=False, invert=False, scale=False):
|
||||
data_show.__init__(self, vals, axes)
|
||||
self.dimensions = dimensions
|
||||
self.transpose = transpose
|
||||
self.invert = invert
|
||||
self.scale = scale
|
||||
self.set_image(vals/255.)
|
||||
self.handle = self.axis.imshow(self.vals, cmap=plt.cm.gray, interpolation='nearest')
|
||||
self.handle = self.axes.imshow(self.vals, cmap=plt.cm.gray, interpolation='nearest')
|
||||
plt.show()
|
||||
|
||||
def modify(self, vals):
|
||||
self.set_image(vals/255.)
|
||||
#self.handle.remove()
|
||||
#self.handle = self.axis.imshow(self.vals)
|
||||
#self.handle = self.axes.imshow(self.vals)
|
||||
self.handle.set_array(self.vals)
|
||||
#self.axis.figure.canvas.draw()
|
||||
#self.axes.figure.canvas.draw()
|
||||
plt.show()
|
||||
|
||||
def set_image(self, vals):
|
||||
|
|
@ -100,21 +187,21 @@ class image_show(data_show):
|
|||
class stick_show(data_show):
|
||||
"""Show a three dimensional point cloud as a figure. Connect elements of the figure together using the matrix connect."""
|
||||
|
||||
def __init__(self, vals, axis=None, connect=None):
|
||||
if axis==None:
|
||||
def __init__(self, vals, axes=None, connect=None):
|
||||
if axes==None:
|
||||
fig = plt.figure()
|
||||
axis = fig.add_subplot(111, projection='3d')
|
||||
data_show.__init__(self, vals, axis)
|
||||
axes = fig.add_subplot(111, projection='3d')
|
||||
data_show.__init__(self, vals, axes)
|
||||
self.vals = vals.reshape((3, vals.shape[1]/3)).T
|
||||
self.x_lim = np.array([self.vals[:, 0].min(), self.vals[:, 0].max()])
|
||||
self.y_lim = np.array([self.vals[:, 1].min(), self.vals[:, 1].max()])
|
||||
self.z_lim = np.array([self.vals[:, 2].min(), self.vals[:, 2].max()])
|
||||
self.points_handle = self.axis.scatter(self.vals[:, 0], self.vals[:, 1], self.vals[:, 2])
|
||||
self.axis.set_xlim(self.x_lim)
|
||||
self.axis.set_ylim(self.y_lim)
|
||||
self.axis.set_zlim(self.z_lim)
|
||||
self.axis.set_aspect(1)
|
||||
self.axis.autoscale(enable=False)
|
||||
self.points_handle = self.axes.scatter(self.vals[:, 0], self.vals[:, 1], self.vals[:, 2])
|
||||
self.axes.set_xlim(self.x_lim)
|
||||
self.axes.set_ylim(self.y_lim)
|
||||
self.axes.set_zlim(self.z_lim)
|
||||
self.axes.set_aspect(1)
|
||||
self.axes.autoscale(enable=False)
|
||||
|
||||
self.connect = connect
|
||||
if not self.connect==None:
|
||||
|
|
@ -132,17 +219,17 @@ class stick_show(data_show):
|
|||
z.append(self.vals[self.I[i], 2])
|
||||
z.append(self.vals[self.J[i], 2])
|
||||
z.append(np.NaN)
|
||||
self.line_handle = self.axis.plot(np.array(x), np.array(y), np.array(z), 'b-')
|
||||
self.axis.figure.canvas.draw()
|
||||
self.line_handle = self.axes.plot(np.array(x), np.array(y), np.array(z), 'b-')
|
||||
self.axes.figure.canvas.draw()
|
||||
|
||||
def modify(self, vals):
|
||||
self.points_handle.remove()
|
||||
self.line_handle[0].remove()
|
||||
self.vals = vals.reshape((3, vals.shape[1]/3)).T
|
||||
self.points_handle = self.axis.scatter(self.vals[:, 0], self.vals[:, 1], self.vals[:, 2])
|
||||
self.axis.set_xlim(self.x_lim)
|
||||
self.axis.set_ylim(self.y_lim)
|
||||
self.axis.set_zlim(self.z_lim)
|
||||
self.points_handle = self.axes.scatter(self.vals[:, 0], self.vals[:, 1], self.vals[:, 2])
|
||||
self.axes.set_xlim(self.x_lim)
|
||||
self.axes.set_ylim(self.y_lim)
|
||||
self.axes.set_zlim(self.z_lim)
|
||||
self.line_handle = []
|
||||
if not self.connect==None:
|
||||
x = []
|
||||
|
|
@ -159,9 +246,9 @@ class stick_show(data_show):
|
|||
z.append(self.vals[self.I[i], 2])
|
||||
z.append(self.vals[self.J[i], 2])
|
||||
z.append(np.NaN)
|
||||
self.line_handle = self.axis.plot(np.array(x), np.array(y), np.array(z), 'b-')
|
||||
self.line_handle = self.axes.plot(np.array(x), np.array(y), np.array(z), 'b-')
|
||||
|
||||
self.axis.figure.canvas.draw()
|
||||
self.axes.figure.canvas.draw()
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue