From 22221565bbcdfcb4d4d7f6576fe789cce813a95a Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Mon, 12 May 2014 12:05:06 +0100 Subject: [PATCH] [visualize] some adjustments to vector_show --- GPy/examples/dimensionality_reduction.py | 6 +++--- GPy/plotting/matplot_dep/visualize.py | 16 ++++++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/GPy/examples/dimensionality_reduction.py b/GPy/examples/dimensionality_reduction.py index 8a31968e..ac1c50ee 100644 --- a/GPy/examples/dimensionality_reduction.py +++ b/GPy/examples/dimensionality_reduction.py @@ -161,6 +161,7 @@ def bgplvm_oil(optimize=True, verbose=1, plot=True, N=200, Q=7, num_inducing=40, import GPy from matplotlib import pyplot as plt from ..util.misc import param_to_array + import numpy as np _np.random.seed(0) data = GPy.util.datasets.oil() @@ -174,11 +175,10 @@ def bgplvm_oil(optimize=True, verbose=1, plot=True, N=200, Q=7, num_inducing=40, m.optimize('scg', messages=verbose, max_iters=max_iters, gtol=.05) if plot: - y = m.Y fig, (latent_axes, sense_axes) = plt.subplots(1, 2) m.plot_latent(ax=latent_axes, labels=m.data_labels) - data_show = GPy.plotting.matplot_dep.visualize.vector_show(y) - lvm_visualizer = GPy.plotting.matplot_dep.visualize.lvm_dimselect(param_to_array(m.X.mean), # @UnusedVariable + data_show = GPy.plotting.matplot_dep.visualize.vector_show(np.zeros((m.Y.shape[1], 1))) + lvm_visualizer = GPy.plotting.matplot_dep.visualize.lvm_dimselect(param_to_array(m.X.mean)[0:1,:], # @UnusedVariable m, data_show, latent_axes=latent_axes, sense_axes=sense_axes) raw_input('Press enter to finish') plt.close(fig) diff --git a/GPy/plotting/matplot_dep/visualize.py b/GPy/plotting/matplot_dep/visualize.py index fae05ff3..b26910c4 100644 --- a/GPy/plotting/matplot_dep/visualize.py +++ b/GPy/plotting/matplot_dep/visualize.py @@ -74,13 +74,17 @@ class vector_show(matplotlib_show): """ def __init__(self, vals, axes=None): matplotlib_show.__init__(self, vals, axes) - self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals) + #assert vals.ndim == 2, "Please give a vector in [n x 1] to plot" + #assert vals.shape[1] == 1, "only showing a vector in one dimension" + self.size = vals.size + + self.handle = self.axes.plot(np.arange(0, vals.size)[:, None], self.vals)[0] def modify(self, vals): self.vals = vals.copy() - for handle, vals in zip(self.handle, self.vals.T): - xdata, ydata = handle.get_data() - handle.set_data(xdata, vals) + xdata, ydata = self.handle.get_data() + assert vals.size == self.size, "values passed into modify changed size! vals:{} != in:{}".format(vals.size, self.size) + self.handle.set_data(xdata, self.vals) self.axes.figure.canvas.draw() @@ -94,12 +98,12 @@ class lvm(matplotlib_show): :type data_visualize: visualize.data_show type. :param latent_axes: the axes where the latent visualization should be plotted. """ - if vals == None: + if vals is None: if isinstance(model.X, VariationalPosterior): vals = param_to_array(model.X.mean) else: vals = param_to_array(model.X) - + vals = param_to_array(vals) matplotlib_show.__init__(self, vals, axes=latent_axes)