diff --git a/GPy/plotting/gpy_plot/latent_plots.py b/GPy/plotting/gpy_plot/latent_plots.py index 976641b2..85a98f49 100644 --- a/GPy/plotting/gpy_plot/latent_plots.py +++ b/GPy/plotting/gpy_plot/latent_plots.py @@ -229,7 +229,7 @@ def plot_latent(self, labels=None, which_indices=None, plot_limits=None, updates=False, kern=None, marker='<>^vsd', - num_samples=1000, + num_samples=1000, projection='2d', scatter_kwargs=None, **imshow_kwargs): """ Plot the latent space of the GP on the inputs. This is the @@ -249,6 +249,8 @@ def plot_latent(self, labels=None, which_indices=None, :param imshow_kwargs: the kwargs for the imshow (magnification factor) :param scatter_kwargs: the kwargs for the scatter plots """ + if projection != '2d': + raise ValueError('Cannot plot latent in other then 2 dimensions, consider plot_scatter') input_1, input_2 = which_indices = self.get_most_significant_input_dimensions(which_indices)[:2] X = get_x_y_var(self)[0] _, _, Xgrid, _, _, xmin, xmax, resolution = helper_for_plot_data(self, X, plot_limits, which_indices, None, resolution)