diff --git a/GPy/plotting/gpy_plot/latent_plots.py b/GPy/plotting/gpy_plot/latent_plots.py index 40ad0251..92035d65 100644 --- a/GPy/plotting/gpy_plot/latent_plots.py +++ b/GPy/plotting/gpy_plot/latent_plots.py @@ -256,7 +256,7 @@ def plot_latent(self, labels=None, which_indices=None, xlabel='latent dimension %i' % input_1, ylabel='latent dimension %i' % input_2, **imshow_kwargs) if (labels is not None) and legend: legend = find_best_layout_for_subplots(len(np.unique(labels)))[1] - else: + elif legend: labels = np.ones(self.num_data) legend = False scatters = _plot_latent_scatter(canvas, X, which_indices, labels, marker, num_samples, projection='2d', **scatter_kwargs or {})