diff --git a/GPy/plotting/matplot_dep/models_plots.py b/GPy/plotting/matplot_dep/models_plots.py index ae79569b..cbb213b1 100644 --- a/GPy/plotting/matplot_dep/models_plots.py +++ b/GPy/plotting/matplot_dep/models_plots.py @@ -7,6 +7,7 @@ import Tango from base_plots import gpplot, x_frame1D, x_frame2D from ...util.misc import param_to_array from ...models.gp_coregionalized_regression import GPCoregionalizedRegression +from ...models.sparse_gp_coregionalized_regression import SparseGPCoregionalizedRegression def plot_fit(model, plot_limits=None, which_data_rows='all', @@ -86,7 +87,10 @@ def plot_fit(model, plot_limits=None, which_data_rows='all', lower = m - 2*np.sqrt(v) upper = m + 2*np.sqrt(v) else: - meta = {'output_index': Xgrid[:,-1:].astype(np.int)} if isinstance(model,GPCoregionalizedRegression) else None + if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression): + meta = {'output_index': Xgrid[:,-1:].astype(np.int)} + else: + meta = None m, v = model.predict(Xgrid, full_cov=False, Y_metadata=meta) lower, upper = model.predict_quantiles(Xgrid, Y_metadata=meta)