mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-05 14:55:15 +02:00
plotting labels are now in order as passed in and marker can be passed with as many markers as there are labels
This commit is contained in:
parent
cbdb89ffe8
commit
b7de50b5b3
1 changed files with 13 additions and 3 deletions
|
|
@ -60,7 +60,7 @@ class GPLVM(GP):
|
|||
mu, var, upper, lower = self.predict(Xnew)
|
||||
pb.plot(mu[:,0], mu[:,1],'k',linewidth=1.5)
|
||||
|
||||
def plot_latent(self, labels=None, which_indices=None, resolution=50, ax=None):
|
||||
def plot_latent(self, labels=None, which_indices=None, resolution=50, ax=None, marker='o', s=40):
|
||||
"""
|
||||
:param labels: a np.array of size self.N containing labels for the points (can be number, strings, etc)
|
||||
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
||||
|
|
@ -94,13 +94,23 @@ class GPLVM(GP):
|
|||
ax.imshow(var.reshape(resolution, resolution).T,
|
||||
extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear',origin='lower')
|
||||
|
||||
for i,ul in enumerate(np.unique(labels)):
|
||||
# make sure labels are in order of input:
|
||||
ulabels = []
|
||||
for lab in labels:
|
||||
if not lab in ulabels:
|
||||
ulabels.append(lab)
|
||||
|
||||
for i, ul in enumerate(ulabels):
|
||||
if type(ul) is np.string_:
|
||||
this_label = ul
|
||||
elif type(ul) is np.int64:
|
||||
this_label = 'class %i'%ul
|
||||
else:
|
||||
this_label = 'class %i'%i
|
||||
if len(marker) == len(ulabels):
|
||||
m = marker[i]
|
||||
else:
|
||||
m = marker
|
||||
|
||||
index = np.nonzero(labels==ul)[0]
|
||||
if self.Q==1:
|
||||
|
|
@ -109,7 +119,7 @@ class GPLVM(GP):
|
|||
else:
|
||||
x = self.X[index,input_1]
|
||||
y = self.X[index,input_2]
|
||||
ax.plot(x,y,marker='o',color=util.plot.Tango.nextMedium(),mew=0,label=this_label,linewidth=0)
|
||||
ax.scatter(x, y, marker=m, s=s, color=util.plot.Tango.nextMedium(), mew=1.3, label=this_label)
|
||||
|
||||
ax.set_xlabel('latent dimension %i'%input_1)
|
||||
ax.set_ylabel('latent dimension %i'%input_2)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue