diff --git a/GPy/plotting/matplot_dep/models_plots.py b/GPy/plotting/matplot_dep/models_plots.py index 7926410e..46a79ad8 100644 --- a/GPy/plotting/matplot_dep/models_plots.py +++ b/GPy/plotting/matplot_dep/models_plots.py @@ -150,7 +150,11 @@ def plot_fit(model, plot_limits=None, which_data_rows='all', if plot_raw: m, _ = model._raw_predict(Xgrid) else: - m, _ = model.predict(Xgrid) + if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression): + meta = {'output_index': Xgrid[:,-1:].astype(np.int)} + else: + meta = None + m, v = model.predict(Xgrid, full_cov=False, Y_metadata=meta) for d in which_data_ycols: m_d = m[:,d].reshape(resolution, resolution).T plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet)