diff --git a/GPy/plotting/gpy_plot/latent_plots.py b/GPy/plotting/gpy_plot/latent_plots.py index ef8f3072..5427e013 100644 --- a/GPy/plotting/gpy_plot/latent_plots.py +++ b/GPy/plotting/gpy_plot/latent_plots.py @@ -147,6 +147,7 @@ def _plot_magnification(self, canvas, which_indices, Xgrid, def plot_function(x): Xtest_full = np.zeros((x.shape[0], Xgrid.shape[1])) Xtest_full[:, which_indices] = x + mf = self.predict_magnification(Xtest_full, kern=kern, mean=mean, covariance=covariance) return mf.reshape(resolution, resolution).T imshow_kwargs = update_not_existing_kwargs(imshow_kwargs, pl().defaults.magnification) @@ -215,7 +216,12 @@ def _plot_latent(self, canvas, which_indices, Xgrid, def plot_function(x): Xtest_full = np.zeros((x.shape[0], Xgrid.shape[1])) Xtest_full[:, which_indices] = x - mf = np.log(self.predict(Xtest_full, kern=kern)[1]) + mf = self.predict(Xtest_full, kern=kern)[1] + if mf.shape[1]==self.output_dim: + mf = mf.sum(-1) + else: + mf *= self.output_dim + mf = np.log(mf) return mf.reshape(resolution, resolution).T imshow_kwargs = update_not_existing_kwargs(imshow_kwargs, pl().defaults.latent)