plot latent updated

This commit is contained in:
James Hensman 2014-02-26 09:16:31 +00:00
parent 3c0f89bf53
commit f6da5345d9
3 changed files with 32 additions and 203 deletions

View file

@ -2,6 +2,7 @@ import pylab as pb
import numpy as np
from latent_space_visualizations.controllers.imshow_controller import ImshowController,ImAnnotateController
from ...util.misc import param_to_array
from ...core.parameterization.variational import VariationalPosterior
from .base_plots import x_frame2D
import itertools
import Tango
@ -19,7 +20,7 @@ def most_significant_input_dimensions(model, which_indices):
input_1, input_2 = 0, 1
else:
try:
input_1, input_2 = np.argsort(model.input_sensitivity())[::-1][:2]
input_1, input_2 = np.argsort(model.kern.input_sensitivity())[::-1][:2]
except:
raise ValueError, "cannot automatically determine which dimensions to plot, please pass 'which_indices'"
else:
@ -43,26 +44,29 @@ def plot_latent(model, labels=None, which_indices=None,
labels = np.ones(model.num_data)
input_1, input_2 = most_significant_input_dimensions(model, which_indices)
X = param_to_array(model.X)
# first, plot the output variance as a function of the latent space
Xtest, xx, yy, xmin, xmax = x_frame2D(X[:, [input_1, input_2]], resolution=resolution)
Xtest_full = np.zeros((Xtest.shape[0], model.X.shape[1]))
#fethch the data points X that we'd like to plot
X = model.X
if isinstance(X, VariationalPosterior):
X = param_to_array(X.mean)
else:
X = param_to_array(X)
# create a function which computes the shading of latent space according to the output variance
def plot_function(x):
Xtest_full = np.zeros((x.shape[0], model.X.shape[1]))
Xtest_full[:, [input_1, input_2]] = x
mu, var, low, up = model.predict(Xtest_full)
var = var[:, :1]
return np.log(var)
#Create an IMshow controller that can re-plot the latent space shading at a good resolution
view = ImshowController(ax, plot_function,
tuple(X[:, [input_1, input_2]].min(0)) + tuple(X[:, [input_1, input_2]].max(0)),
resolution, aspect=aspect, interpolation='bilinear',
cmap=pb.cm.binary)
# ax.imshow(var.reshape(resolution, resolution).T,
# extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary, interpolation='bilinear', origin='lower')
# make sure labels are in order of input:
ulabels = []
for lab in labels:
@ -95,8 +99,8 @@ def plot_latent(model, labels=None, which_indices=None,
if not np.all(labels == 1.) and legend:
ax.legend(loc=0, numpoints=1)
ax.set_xlim(xmin[0], xmax[0])
ax.set_ylim(xmin[1], xmax[1])
#ax.set_xlim(xmin[0], xmax[0])
#ax.set_ylim(xmin[1], xmax[1])
ax.grid(b=False) # remove the grid if present, it doesn't look good
ax.set_aspect('auto') # set a nice aspect ratio