From 9c2eac6e1e7916c0b90d41269ec3ee3d7c2b121f Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Fri, 26 Feb 2016 12:01:00 +0000 Subject: [PATCH] [latent plots] legend was always plotted --- GPy/plotting/gpy_plot/latent_plots.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/GPy/plotting/gpy_plot/latent_plots.py b/GPy/plotting/gpy_plot/latent_plots.py index 54982fc5..df5e239a 100644 --- a/GPy/plotting/gpy_plot/latent_plots.py +++ b/GPy/plotting/gpy_plot/latent_plots.py @@ -163,7 +163,8 @@ def plot_magnification(self, labels=None, which_indices=None, updates=False, mean=True, covariance=True, kern=None, num_samples=1000, - scatter_kwargs=None, **imshow_kwargs): + scatter_kwargs=None, plot_scatter=True, + **imshow_kwargs): """ Plot the magnification factor of the GP on the inputs. This is the density of the GP as a gray scale. @@ -188,18 +189,20 @@ def plot_magnification(self, labels=None, which_indices=None, _, _, Xgrid, _, _, xmin, xmax, resolution = helper_for_plot_data(self, X, plot_limits, which_indices, None, resolution) canvas, imshow_kwargs = pl().new_canvas(xlim=(xmin[0], xmax[0]), ylim=(xmin[1], xmax[1]), xlabel='latent dimension %i' % input_1, ylabel='latent dimension %i' % input_2, **imshow_kwargs) + plots = {} if legend: if (labels is not None): legend = find_best_layout_for_subplots(len(np.unique(labels)))[1] else: labels = np.ones(self.num_data) legend = False - scatters = _plot_latent_scatter(canvas, X, which_indices, labels, marker, num_samples, projection='2d', **scatter_kwargs or {}) - view = _plot_magnification(self, canvas, which_indices, Xgrid, xmin, xmax, resolution, updates, mean, covariance, kern, **imshow_kwargs) - retval = pl().add_to_canvas(canvas, dict(scatter=scatters, imshow=view), + if plot_scatter: + plots['scatters'] = _plot_latent_scatter(canvas, X, which_indices, labels, marker, num_samples, projection='2d', **scatter_kwargs or {}) + plots['view'] = _plot_magnification(self, canvas, which_indices, Xgrid, xmin, xmax, resolution, updates, mean, covariance, kern, **imshow_kwargs) + retval = pl().add_to_canvas(canvas, plots, legend=legend, ) - _wait_for_updates(view, updates) + _wait_for_updates(plots['view'], updates) return retval