mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-05 14:55:15 +02:00
plotting labels are now in order as passed in and marker can be passed with as many markers as there are labels
This commit is contained in:
parent
cbdb89ffe8
commit
b7de50b5b3
1 changed files with 13 additions and 3 deletions
|
|
@ -60,7 +60,7 @@ class GPLVM(GP):
|
||||||
mu, var, upper, lower = self.predict(Xnew)
|
mu, var, upper, lower = self.predict(Xnew)
|
||||||
pb.plot(mu[:,0], mu[:,1],'k',linewidth=1.5)
|
pb.plot(mu[:,0], mu[:,1],'k',linewidth=1.5)
|
||||||
|
|
||||||
def plot_latent(self, labels=None, which_indices=None, resolution=50, ax=None):
|
def plot_latent(self, labels=None, which_indices=None, resolution=50, ax=None, marker='o', s=40):
|
||||||
"""
|
"""
|
||||||
:param labels: a np.array of size self.N containing labels for the points (can be number, strings, etc)
|
:param labels: a np.array of size self.N containing labels for the points (can be number, strings, etc)
|
||||||
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
||||||
|
|
@ -94,13 +94,23 @@ class GPLVM(GP):
|
||||||
ax.imshow(var.reshape(resolution, resolution).T,
|
ax.imshow(var.reshape(resolution, resolution).T,
|
||||||
extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear',origin='lower')
|
extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary,interpolation='bilinear',origin='lower')
|
||||||
|
|
||||||
for i,ul in enumerate(np.unique(labels)):
|
# make sure labels are in order of input:
|
||||||
|
ulabels = []
|
||||||
|
for lab in labels:
|
||||||
|
if not lab in ulabels:
|
||||||
|
ulabels.append(lab)
|
||||||
|
|
||||||
|
for i, ul in enumerate(ulabels):
|
||||||
if type(ul) is np.string_:
|
if type(ul) is np.string_:
|
||||||
this_label = ul
|
this_label = ul
|
||||||
elif type(ul) is np.int64:
|
elif type(ul) is np.int64:
|
||||||
this_label = 'class %i'%ul
|
this_label = 'class %i'%ul
|
||||||
else:
|
else:
|
||||||
this_label = 'class %i'%i
|
this_label = 'class %i'%i
|
||||||
|
if len(marker) == len(ulabels):
|
||||||
|
m = marker[i]
|
||||||
|
else:
|
||||||
|
m = marker
|
||||||
|
|
||||||
index = np.nonzero(labels==ul)[0]
|
index = np.nonzero(labels==ul)[0]
|
||||||
if self.Q==1:
|
if self.Q==1:
|
||||||
|
|
@ -109,7 +119,7 @@ class GPLVM(GP):
|
||||||
else:
|
else:
|
||||||
x = self.X[index,input_1]
|
x = self.X[index,input_1]
|
||||||
y = self.X[index,input_2]
|
y = self.X[index,input_2]
|
||||||
ax.plot(x,y,marker='o',color=util.plot.Tango.nextMedium(),mew=0,label=this_label,linewidth=0)
|
ax.scatter(x, y, marker=m, s=s, color=util.plot.Tango.nextMedium(), mew=1.3, label=this_label)
|
||||||
|
|
||||||
ax.set_xlabel('latent dimension %i'%input_1)
|
ax.set_xlabel('latent dimension %i'%input_1)
|
||||||
ax.set_ylabel('latent dimension %i'%input_2)
|
ax.set_ylabel('latent dimension %i'%input_2)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue