plotting labels are now in order as passed in and marker can be passed with as many markers as there are labels

This commit is contained in:
Max Zwiessele 2013-05-22 12:38:03 +01:00
parent cbdb89ffe8
commit b7de50b5b3

View file

@ -60,7 +60,7 @@ class GPLVM(GP):
mu, var, upper, lower = self.predict(Xnew) mu, var, upper, lower = self.predict(Xnew)
pb.plot(mu[:,0], mu[:,1],'k',linewidth=1.5) pb.plot(mu[:,0], mu[:,1],'k',linewidth=1.5)
def plot_latent(self, labels=None, which_indices=None, resolution=50, ax=None): def plot_latent(self, labels=None, which_indices=None, resolution=50, ax=None, marker='o', s=40):
""" """
:param labels: a np.array of size self.N containing labels for the points (can be number, strings, etc) :param labels: a np.array of size self.N containing labels for the points (can be number, strings, etc)
:param resolution: the resolution of the grid on which to evaluate the predictive variance :param resolution: the resolution of the grid on which to evaluate the predictive variance
@ -94,13 +94,23 @@ class GPLVM(GP):
ax.imshow(var.reshape(resolution, resolution).T, ax.imshow(var.reshape(resolution, resolution).T,
extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear',origin='lower') extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear',origin='lower')
for i,ul in enumerate(np.unique(labels)): # make sure labels are in order of input:
ulabels = []
for lab in labels:
if not lab in ulabels:
ulabels.append(lab)
for i, ul in enumerate(ulabels):
if type(ul) is np.string_: if type(ul) is np.string_:
this_label = ul this_label = ul
elif type(ul) is np.int64: elif type(ul) is np.int64:
this_label = 'class %i'%ul this_label = 'class %i'%ul
else: else:
this_label = 'class %i'%i this_label = 'class %i'%i
if len(marker) == len(ulabels):
m = marker[i]
else:
m = marker
index = np.nonzero(labels==ul)[0] index = np.nonzero(labels==ul)[0]
if self.Q==1: if self.Q==1:
@ -109,7 +119,7 @@ class GPLVM(GP):
else: else:
x = self.X[index,input_1] x = self.X[index,input_1]
y = self.X[index,input_2] y = self.X[index,input_2]
ax.plot(x,y,marker='o',color=util.plot.Tango.nextMedium(),mew=0,label=this_label,linewidth=0) ax.scatter(x, y, marker=m, s=s, color=util.plot.Tango.nextMedium(), mew=1.3, label=this_label)
ax.set_xlabel('latent dimension %i'%input_1) ax.set_xlabel('latent dimension %i'%input_1)
ax.set_ylabel('latent dimension %i'%input_2) ax.set_ylabel('latent dimension %i'%input_2)