diff --git a/GPy/plotting/matplot_dep/models_plots.py b/GPy/plotting/matplot_dep/models_plots.py index 7507c376..ae79569b 100644 --- a/GPy/plotting/matplot_dep/models_plots.py +++ b/GPy/plotting/matplot_dep/models_plots.py @@ -6,6 +6,7 @@ import numpy as np import Tango from base_plots import gpplot, x_frame1D, x_frame2D from ...util.misc import param_to_array +from ...models.gp_coregionalized_regression import GPCoregionalizedRegression def plot_fit(model, plot_limits=None, which_data_rows='all', @@ -85,8 +86,9 @@ def plot_fit(model, plot_limits=None, which_data_rows='all', lower = m - 2*np.sqrt(v) upper = m + 2*np.sqrt(v) else: - m, v = model.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata) - lower, upper = model.predict_quantiles(Xgrid, Y_metadata=Y_metadata) + meta = {'output_index': Xgrid[:,-1:].astype(np.int)} if isinstance(model,GPCoregionalizedRegression) else None + m, v = model.predict(Xgrid, full_cov=False, Y_metadata=meta) + lower, upper = model.predict_quantiles(Xgrid, Y_metadata=meta) for d in which_data_ycols: