plotting, allot of plotting

This commit is contained in:
mzwiessele 2014-03-20 16:20:39 +00:00
parent 6f9c97ee72
commit ce728d8465
9 changed files with 97 additions and 39 deletions

View file

@ -64,8 +64,8 @@ class SparseGP(GP):
self.kern.gradient += target self.kern.gradient += target
#gradients wrt Z #gradients wrt Z
self.Z.gradient[:,self.kern.active_dims] = self.kern.gradients_X(dL_dKmm, self.Z) self.Z.gradient = self.kern.gradients_X(dL_dKmm, self.Z)
self.Z.gradient[:,self.kern.active_dims] += self.kern.gradients_Z_expectations( self.Z.gradient += self.kern.gradients_Z_expectations(
self.grad_dict['dL_dpsi1'], self.grad_dict['dL_dpsi2'], Z=self.Z, variational_posterior=self.X) self.grad_dict['dL_dpsi1'], self.grad_dict['dL_dpsi2'], Z=self.Z, variational_posterior=self.X)
else: else:
#gradients wrt kernel #gradients wrt kernel
@ -76,8 +76,8 @@ class SparseGP(GP):
self.kern.update_gradients_full(self.grad_dict['dL_dKmm'], self.Z, None) self.kern.update_gradients_full(self.grad_dict['dL_dKmm'], self.Z, None)
self.kern.gradient += target self.kern.gradient += target
#gradients wrt Z #gradients wrt Z
self.Z.gradient[:,self.kern.active_dims] = self.kern.gradients_X(self.grad_dict['dL_dKmm'], self.Z) self.Z.gradient = self.kern.gradients_X(self.grad_dict['dL_dKmm'], self.Z)
self.Z.gradient[:,self.kern.active_dims] += self.kern.gradients_X(self.grad_dict['dL_dKnm'].T, self.Z, self.X) self.Z.gradient += self.kern.gradients_X(self.grad_dict['dL_dKnm'].T, self.Z, self.X)
def _raw_predict(self, Xnew, full_cov=False): def _raw_predict(self, Xnew, full_cov=False):
""" """

View file

@ -160,6 +160,7 @@ def swiss_roll(optimize=True, verbose=1, plot=True, N=1000, num_inducing=15, Q=4
def bgplvm_oil(optimize=True, verbose=1, plot=True, N=200, Q=7, num_inducing=40, max_iters=1000, **k): def bgplvm_oil(optimize=True, verbose=1, plot=True, N=200, Q=7, num_inducing=40, max_iters=1000, **k):
import GPy import GPy
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from ..util.misc import param_to_array
_np.random.seed(0) _np.random.seed(0)
data = GPy.util.datasets.oil() data = GPy.util.datasets.oil()
@ -173,11 +174,11 @@ def bgplvm_oil(optimize=True, verbose=1, plot=True, N=200, Q=7, num_inducing=40,
m.optimize('scg', messages=verbose, max_iters=max_iters, gtol=.05) m.optimize('scg', messages=verbose, max_iters=max_iters, gtol=.05)
if plot: if plot:
y = m.Y[0, :] y = m.Y
fig, (latent_axes, sense_axes) = plt.subplots(1, 2) fig, (latent_axes, sense_axes) = plt.subplots(1, 2)
m.plot_latent(ax=latent_axes) m.plot_latent(ax=latent_axes)
data_show = GPy.plotting.matplot_dep.visualize.vector_show(y) data_show = GPy.plotting.matplot_dep.visualize.vector_show(y)
lvm_visualizer = GPy.plotting.matplot_dep.visualize.lvm_dimselect(m.X[0, :], # @UnusedVariable lvm_visualizer = GPy.plotting.matplot_dep.visualize.lvm_dimselect(param_to_array(m.X.mean), # @UnusedVariable
m, data_show, latent_axes=latent_axes, sense_axes=sense_axes) m, data_show, latent_axes=latent_axes, sense_axes=sense_axes)
raw_input('Press enter to finish') raw_input('Press enter to finish')
plt.close(fig) plt.close(fig)

View file

@ -312,5 +312,4 @@ class Linear(Kern):
return np.dot(ZA, inner).swapaxes(0, 1) # NOTE: self.ZAinner \in [num_inducing x num_data x input_dim]! return np.dot(ZA, inner).swapaxes(0, 1) # NOTE: self.ZAinner \in [num_inducing x num_data x input_dim]!
def input_sensitivity(self): def input_sensitivity(self):
if self.ARD: return self.variances return np.ones(self.input_dim) * self.variances
else: return self.variances.repeat(self.input_dim)

View file

@ -72,15 +72,19 @@ class BayesianGPLVM(SparseGP):
self.variational_prior.update_gradients_KL(self.X) self.variational_prior.update_gradients_KL(self.X)
def plot_latent(self, plot_inducing=True, *args, **kwargs): def plot_latent(self, labels=None, which_indices=None,
""" resolution=50, ax=None, marker='o', s=40,
See GPy.plotting.matplot_dep.dim_reduction_plots.plot_latent fignum=None, plot_inducing=True, legend=True,
""" plot_limits=None,
aspect='auto', updates=False, **kwargs):
import sys import sys
assert "matplotlib" in sys.modules, "matplotlib package has not been imported." assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import dim_reduction_plots from ..plotting.matplot_dep import dim_reduction_plots
return dim_reduction_plots.plot_latent(self, plot_inducing=plot_inducing, *args, **kwargs) return dim_reduction_plots.plot_latent(self, labels, which_indices,
resolution, ax, marker, s,
fignum, plot_inducing, legend,
plot_limits, aspect, updates, **kwargs)
def do_test_latents(self, Y): def do_test_latents(self, Y):
""" """

View file

@ -67,12 +67,22 @@ class GPLVM(GP):
assert self.likelihood.Y.shape[1] == 2 assert self.likelihood.Y.shape[1] == 2
pb.scatter(self.likelihood.Y[:, 0], self.likelihood.Y[:, 1], 40, self.X[:, 0].copy(), linewidth=0, cmap=pb.cm.jet) # @UndefinedVariable pb.scatter(self.likelihood.Y[:, 0], self.likelihood.Y[:, 1], 40, self.X[:, 0].copy(), linewidth=0, cmap=pb.cm.jet) # @UndefinedVariable
Xnew = np.linspace(self.X.min(), self.X.max(), 200)[:, None] Xnew = np.linspace(self.X.min(), self.X.max(), 200)[:, None]
mu, var, upper, lower = self.predict(Xnew) mu, _ = self.predict(Xnew)
pb.plot(mu[:, 0], mu[:, 1], 'k', linewidth=1.5) pb.plot(mu[:, 0], mu[:, 1], 'k', linewidth=1.5)
def plot_latent(self, *args, **kwargs): def plot_latent(self, labels=None, which_indices=None,
resolution=50, ax=None, marker='o', s=40,
fignum=None, legend=True,
plot_limits=None,
aspect='auto', updates=False, **kwargs):
import sys
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import dim_reduction_plots from ..plotting.matplot_dep import dim_reduction_plots
return dim_reduction_plots.plot_latent(self, *args, **kwargs) return dim_reduction_plots.plot_latent(self, labels, which_indices,
resolution, ax, marker, s,
fignum, False, legend,
plot_limits, aspect, updates, **kwargs)
def plot_magnification(self, *args, **kwargs): def plot_magnification(self, *args, **kwargs):
return util.plot_latent.plot_magnification(self, *args, **kwargs) return util.plot_latent.plot_magnification(self, *args, **kwargs)

View file

@ -30,7 +30,8 @@ def most_significant_input_dimensions(model, which_indices):
def plot_latent(model, labels=None, which_indices=None, def plot_latent(model, labels=None, which_indices=None,
resolution=50, ax=None, marker='o', s=40, resolution=50, ax=None, marker='o', s=40,
fignum=None, plot_inducing=False, legend=True, fignum=None, plot_inducing=False, legend=True,
aspect='auto', updates=False): plot_limits=None,
aspect='auto', updates=False, **kwargs):
""" """
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc) :param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
:param resolution: the resolution of the grid on which to evaluate the predictive variance :param resolution: the resolution of the grid on which to evaluate the predictive variance
@ -38,6 +39,8 @@ def plot_latent(model, labels=None, which_indices=None,
if ax is None: if ax is None:
fig = pb.figure(num=fignum) fig = pb.figure(num=fignum)
ax = fig.add_subplot(111) ax = fig.add_subplot(111)
else:
fig = ax.figure
Tango.reset() Tango.reset()
if labels is None: if labels is None:
@ -57,15 +60,28 @@ def plot_latent(model, labels=None, which_indices=None,
def plot_function(x): def plot_function(x):
Xtest_full = np.zeros((x.shape[0], model.X.shape[1])) Xtest_full = np.zeros((x.shape[0], model.X.shape[1]))
Xtest_full[:, [input_1, input_2]] = x Xtest_full[:, [input_1, input_2]] = x
mu, var, low, up = model.predict(Xtest_full) _, var = model.predict(Xtest_full)
var = var[:, :1] var = var[:, :1]
return np.log(var) return np.log(var)
#Create an IMshow controller that can re-plot the latent space shading at a good resolution #Create an IMshow controller that can re-plot the latent space shading at a good resolution
if plot_limits is None:
xmin, ymin = X[:, [input_1, input_2]].min(0)
xmax, ymax = X[:, [input_1, input_2]].max(0)
x_r, y_r = xmax-xmin, ymax-ymin
xmin -= .1*x_r
xmax += .1*x_r
ymin -= .1*y_r
ymax += .1*y_r
else:
try:
xmin, xmax, ymin, ymax = plot_limits
except (TypeError, ValueError) as e:
raise e.__class__, "Wrong plot limits: {} given -> need (xmin, xmax, ymin, ymax)".format(plot_limits)
view = ImshowController(ax, plot_function, view = ImshowController(ax, plot_function,
tuple(X[:, [input_1, input_2]].min(0)) + tuple(X[:, [input_1, input_2]].max(0)), (xmin, ymin, xmax, ymax),
resolution, aspect=aspect, interpolation='bilinear', resolution, aspect=aspect, interpolation='bilinear',
cmap=pb.cm.binary) cmap=pb.cm.binary, **kwargs)
# make sure labels are in order of input: # make sure labels are in order of input:
ulabels = [] ulabels = []
@ -99,18 +115,31 @@ def plot_latent(model, labels=None, which_indices=None,
if not np.all(labels == 1.) and legend: if not np.all(labels == 1.) and legend:
ax.legend(loc=0, numpoints=1) ax.legend(loc=0, numpoints=1)
#ax.set_xlim(xmin[0], xmax[0])
#ax.set_ylim(xmin[1], xmax[1])
ax.grid(b=False) # remove the grid if present, it doesn't look good ax.grid(b=False) # remove the grid if present, it doesn't look good
ax.set_aspect('auto') # set a nice aspect ratio ax.set_aspect('auto') # set a nice aspect ratio
if plot_inducing: if plot_inducing:
Z = param_to_array(model.Z) Z = param_to_array(model.Z)
ax.plot(Z[:, input_1], Z[:, input_2], '^w') ax.plot(Z[:, input_1], Z[:, input_2], '^w')
ax.set_xlim((xmin, xmax))
ax.set_ylim((ymin, ymax))
try:
fig.canvas.draw()
fig.tight_layout()
fig.canvas.draw()
except Exception as e:
print "Could not invoke tight layout: {}".format(e)
pass
if updates: if updates:
ax.figure.canvas.show() try:
ax.figure.canvas.show()
except Exception as e:
print "Could not invoke show: {}".format(e)
raw_input('Enter to continue') raw_input('Enter to continue')
view.deactivate()
return ax return ax
def plot_magnification(model, labels=None, which_indices=None, def plot_magnification(model, labels=None, which_indices=None,
@ -186,7 +215,7 @@ def plot_magnification(model, labels=None, which_indices=None,
ax.plot(model.Z[:, input_1], model.Z[:, input_2], '^w') ax.plot(model.Z[:, input_1], model.Z[:, input_2], '^w')
if updates: if updates:
ax.figure.canvas.show() fig.canvas.show()
raw_input('Enter to continue') raw_input('Enter to continue')
pb.title('Magnification Factor') pb.title('Magnification Factor')

View file

@ -33,7 +33,7 @@ class AxisChangedController(AxisEventController):
Constructor Constructor
''' '''
super(AxisChangedController, self).__init__(ax) super(AxisChangedController, self).__init__(ax)
self._lim_ratio_threshold = update_lim or .8 self._lim_ratio_threshold = update_lim or .95
self._x_lim = self.ax.get_xlim() self._x_lim = self.ax.get_xlim()
self._y_lim = self.ax.get_ylim() self._y_lim = self.ax.get_ylim()
@ -80,6 +80,10 @@ class AxisChangedController(AxisEventController):
class BufferedAxisChangedController(AxisChangedController): class BufferedAxisChangedController(AxisChangedController):
def __init__(self, ax, plot_function, plot_limits, resolution=50, update_lim=None, **kwargs): def __init__(self, ax, plot_function, plot_limits, resolution=50, update_lim=None, **kwargs):
""" """
Buffered axis changed controller. Controls the buffer and handles update events for when the axes changed.
Updated plotting will be after first reload (first time will be within plot limits, after that the limits will be buffered)
:param plot_function: :param plot_function:
function to use for creating image for plotting (return ndarray-like) function to use for creating image for plotting (return ndarray-like)
plot_function gets called with (2D!) Xtest grid if replotting required plot_function gets called with (2D!) Xtest grid if replotting required
@ -91,11 +95,13 @@ class BufferedAxisChangedController(AxisChangedController):
""" """
super(BufferedAxisChangedController, self).__init__(ax, update_lim=update_lim) super(BufferedAxisChangedController, self).__init__(ax, update_lim=update_lim)
self.plot_function = plot_function self.plot_function = plot_function
xmin, xmax = self._x_lim # self._compute_buffered(*self._x_lim) xmin, ymin, xmax, ymax = plot_limits#self._x_lim # self._compute_buffered(*self._x_lim)
ymin, ymax = self._y_lim # self._compute_buffered(*self._y_lim) # imshow acts on the limits of the plot, this is why we need to override the limits here, to make sure the right plot limits are used:
self._x_lim = xmin, xmax
self._y_lim = ymin, ymax
self.resolution = resolution self.resolution = resolution
self._not_init = False self._not_init = False
self.view = self._init_view(self.ax, self.recompute_X(), xmin, xmax, ymin, ymax, **kwargs) self.view = self._init_view(self.ax, self.recompute_X(buffered=False), xmin, xmax, ymin, ymax, **kwargs)
self._not_init = True self._not_init = True
def update(self, ax): def update(self, ax):
@ -111,14 +117,16 @@ class BufferedAxisChangedController(AxisChangedController):
def update_view(self, view, X, xmin, xmax, ymin, ymax): def update_view(self, view, X, xmin, xmax, ymin, ymax):
raise NotImplementedError('update view given in here') raise NotImplementedError('update view given in here')
def get_grid(self): def get_grid(self, buffered=True):
xmin, xmax = self._compute_buffered(*self._x_lim) if buffered: comp = self._compute_buffered
ymin, ymax = self._compute_buffered(*self._y_lim) else: comp = lambda a,b: (a,b)
xmin, xmax = comp(*self._x_lim)
ymin, ymax = comp(*self._y_lim)
x, y = numpy.mgrid[xmin:xmax:1j * self.resolution, ymin:ymax:1j * self.resolution] x, y = numpy.mgrid[xmin:xmax:1j * self.resolution, ymin:ymax:1j * self.resolution]
return numpy.hstack((x.flatten()[:, None], y.flatten()[:, None])) return numpy.hstack((x.flatten()[:, None], y.flatten()[:, None]))
def recompute_X(self): def recompute_X(self, buffered=True):
X = self.plot_function(self.get_grid()) X = self.plot_function(self.get_grid(buffered))
if isinstance(X, (tuple, list)): if isinstance(X, (tuple, list)):
for x in X: for x in X:
x.shape = [self.resolution, self.resolution] x.shape = [self.resolution, self.resolution]

View file

@ -9,7 +9,7 @@ import numpy
class ImshowController(BufferedAxisChangedController): class ImshowController(BufferedAxisChangedController):
def __init__(self, ax, plot_function, plot_limits, resolution=50, update_lim=.5, **kwargs): def __init__(self, ax, plot_function, plot_limits, resolution=50, update_lim=.8, **kwargs):
""" """
:param plot_function: :param plot_function:
function to use for creating image for plotting (return ndarray-like) function to use for creating image for plotting (return ndarray-like)

View file

@ -4,6 +4,8 @@ import GPy
import numpy as np import numpy as np
import matplotlib as mpl import matplotlib as mpl
import time import time
from ...util.misc import param_to_array
from GPy.core.parameterization.variational import VariationalPosterior
try: try:
import visual import visual
visual_available = True visual_available = True
@ -72,12 +74,13 @@ class vector_show(matplotlib_show):
""" """
def __init__(self, vals, axes=None): def __init__(self, vals, axes=None):
matplotlib_show.__init__(self, vals, axes) matplotlib_show.__init__(self, vals, axes)
self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals.T)[0] self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals)
def modify(self, vals): def modify(self, vals):
self.vals = vals.copy() self.vals = vals.copy()
xdata, ydata = self.handle.get_data() for handle, vals in zip(self.handle, self.vals.T):
self.handle.set_data(xdata, self.vals.T) xdata, ydata = handle.get_data()
handle.set_data(xdata, vals)
self.axes.figure.canvas.draw() self.axes.figure.canvas.draw()
@ -91,8 +94,12 @@ class lvm(matplotlib_show):
:param latent_axes: the axes where the latent visualization should be plotted. :param latent_axes: the axes where the latent visualization should be plotted.
""" """
if vals == None: if vals == None:
vals = model.X[0] if isinstance(model.X, VariationalPosterior):
vals = param_to_array(model.X.mean)
else:
vals = param_to_array(model.X)
vals = param_to_array(vals)
matplotlib_show.__init__(self, vals, axes=latent_axes) matplotlib_show.__init__(self, vals, axes=latent_axes)
if isinstance(latent_axes,mpl.axes.Axes): if isinstance(latent_axes,mpl.axes.Axes):