From b7de50b5b3c3b50ec424d57c0f243bc1ced4f47f Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Wed, 22 May 2013 12:38:03 +0100 Subject: [PATCH] plotting labels are now in order as passed in and marker can be passed with as many markers as there are labels --- GPy/models/GPLVM.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/GPy/models/GPLVM.py b/GPy/models/GPLVM.py index 7543e111..ff2be732 100644 --- a/GPy/models/GPLVM.py +++ b/GPy/models/GPLVM.py @@ -60,7 +60,7 @@ class GPLVM(GP): mu, var, upper, lower = self.predict(Xnew) pb.plot(mu[:,0], mu[:,1],'k',linewidth=1.5) - def plot_latent(self, labels=None, which_indices=None, resolution=50, ax=None): + def plot_latent(self, labels=None, which_indices=None, resolution=50, ax=None, marker='o', s=40): """ :param labels: a np.array of size self.N containing labels for the points (can be number, strings, etc) :param resolution: the resolution of the grid on which to evaluate the predictive variance @@ -94,13 +94,23 @@ class GPLVM(GP): ax.imshow(var.reshape(resolution, resolution).T, extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear',origin='lower') - for i,ul in enumerate(np.unique(labels)): + # make sure labels are in order of input: + ulabels = [] + for lab in labels: + if not lab in ulabels: + ulabels.append(lab) + + for i, ul in enumerate(ulabels): if type(ul) is np.string_: this_label = ul elif type(ul) is np.int64: this_label = 'class %i'%ul else: this_label = 'class %i'%i + if len(marker) == len(ulabels): + m = marker[i] + else: + m = marker index = np.nonzero(labels==ul)[0] if self.Q==1: @@ -109,7 +119,7 @@ class GPLVM(GP): else: x = self.X[index,input_1] y = self.X[index,input_2] - ax.plot(x,y,marker='o',color=util.plot.Tango.nextMedium(),mew=0,label=this_label,linewidth=0) + ax.scatter(x, y, marker=m, s=s, color=util.plot.Tango.nextMedium(), mew=1.3, label=this_label) ax.set_xlabel('latent dimension %i'%input_1) ax.set_ylabel('latent dimension %i'%input_2)