diff --git a/GPy/plotting/gpy_plot/latent_plots.py b/GPy/plotting/gpy_plot/latent_plots.py index 92035d65..54982fc5 100644 --- a/GPy/plotting/gpy_plot/latent_plots.py +++ b/GPy/plotting/gpy_plot/latent_plots.py @@ -188,11 +188,12 @@ def plot_magnification(self, labels=None, which_indices=None, _, _, Xgrid, _, _, xmin, xmax, resolution = helper_for_plot_data(self, X, plot_limits, which_indices, None, resolution) canvas, imshow_kwargs = pl().new_canvas(xlim=(xmin[0], xmax[0]), ylim=(xmin[1], xmax[1]), xlabel='latent dimension %i' % input_1, ylabel='latent dimension %i' % input_2, **imshow_kwargs) - if (labels is not None) and legend: - legend = find_best_layout_for_subplots(len(np.unique(labels)))[1] - else: - labels = np.ones(self.num_data) - legend = False + if legend: + if (labels is not None): + legend = find_best_layout_for_subplots(len(np.unique(labels)))[1] + else: + labels = np.ones(self.num_data) + legend = False scatters = _plot_latent_scatter(canvas, X, which_indices, labels, marker, num_samples, projection='2d', **scatter_kwargs or {}) view = _plot_magnification(self, canvas, which_indices, Xgrid, xmin, xmax, resolution, updates, mean, covariance, kern, **imshow_kwargs) retval = pl().add_to_canvas(canvas, dict(scatter=scatters, imshow=view), @@ -254,11 +255,12 @@ def plot_latent(self, labels=None, which_indices=None, _, _, Xgrid, _, _, xmin, xmax, resolution = helper_for_plot_data(self, X, plot_limits, which_indices, None, resolution) canvas, imshow_kwargs = pl().new_canvas(xlim=(xmin[0], xmax[0]), ylim=(xmin[1], xmax[1]), xlabel='latent dimension %i' % input_1, ylabel='latent dimension %i' % input_2, **imshow_kwargs) - if (labels is not None) and legend: - legend = find_best_layout_for_subplots(len(np.unique(labels)))[1] - elif legend: - labels = np.ones(self.num_data) - legend = False + if legend: + if (labels is not None): + legend = find_best_layout_for_subplots(len(np.unique(labels)))[1] + else: + labels = np.ones(self.num_data) + legend = False scatters = _plot_latent_scatter(canvas, X, which_indices, labels, marker, num_samples, projection='2d', **scatter_kwargs or {}) view = _plot_latent(self, canvas, which_indices, Xgrid, xmin, xmax, resolution, updates, kern, **imshow_kwargs) retval = pl().add_to_canvas(canvas, dict(scatter=scatters, imshow=view), legend=legend)